1use std::path::Path;
39
40use anyhow::Result;
41use clap::Args;
42use serde_json::json;
43
44use crate::cli::CliOutput;
45use crate::config::AppConfig;
46use crate::db;
47use crate::embeddings::Embed;
48
/// Exit code returned by [`cmd_reembed`] when no embedder is available:
/// either the effective tier yields no embedding model, or the resolver
/// returns no embedder for it.
pub const EXIT_NO_EMBEDDER: i32 = 2;
52
/// Exit code returned by [`cmd_reembed`] when constructing the embedder
/// fails with an error.
pub const EXIT_EMBEDDER_INIT_FAILED: i32 = 3;
56
/// Arguments accepted by [`cmd_reembed`].
#[derive(Args, Debug, Clone)]
pub struct ReembedArgs {
    /// Restrict the run to one namespace; every namespace is covered when
    /// omitted.
    #[arg(long)]
    pub namespace: Option<String>,

    /// Print the plan (row counts, target model and dimension, backend) and
    /// exit without writing any embedding.
    #[arg(long)]
    pub dry_run: bool,

    /// Number of rows processed per batch. When omitted or zero, the batch
    /// size from the resolved embeddings configuration is used instead.
    #[arg(long)]
    pub batch: Option<usize>,

    /// Emit the plan or the final summary as a single JSON object on stdout
    /// instead of a human-readable line.
    #[arg(long)]
    pub json: bool,
}
82
/// Read-only summary of what a re-embed run would cover, produced by
/// [`build_plan`] and printed by the dry-run path of [`cmd_reembed`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub(crate) struct ReembedPlan {
    /// Number of rows in scope (after the optional namespace filter).
    pub(crate) total_rows: u64,
    /// Rows in scope that already have a stored embedding.
    pub(crate) rows_with_embeddings: u64,
    /// `total_rows - rows_with_embeddings`, saturating at zero.
    pub(crate) rows_missing_embeddings: u64,
    /// Description of the model the run would embed with.
    pub(crate) target_model: String,
    /// Dimension of the vectors the run would write.
    pub(crate) target_dim: usize,
    /// Name of the embedding backend the run would use.
    pub(crate) backend: String,
}
94
/// Counters accumulated by [`run_reembed_live`] over a whole sweep.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub(crate) struct ReembedOutcome {
    /// Rows fetched from the database across all batches.
    pub(crate) total: usize,
    /// Sum of the counts returned by the batch writes of new embeddings.
    pub(crate) reembedded: usize,
    /// Rows the embedding step reported as skipped; nothing is written for
    /// them.
    pub(crate) skipped: usize,
}
106
107pub(crate) fn build_plan(
114 conn: &rusqlite::Connection,
115 namespace: Option<&str>,
116 target_model: &str,
117 target_dim: usize,
118 backend: &str,
119) -> Result<ReembedPlan> {
120 let (total_rows, rows_with_embeddings) = db::embedding_coverage(conn, namespace)?;
121 Ok(ReembedPlan {
122 total_rows,
123 rows_with_embeddings,
124 rows_missing_embeddings: total_rows.saturating_sub(rows_with_embeddings),
125 target_model: target_model.to_string(),
126 target_dim,
127 backend: backend.to_string(),
128 })
129}
130
131pub(crate) fn resolve_batch_size(batch_flag: Option<usize>, resolved_default: usize) -> usize {
136 batch_flag
137 .filter(|&n| n > 0)
138 .or(Some(resolved_default).filter(|&n| n > 0))
139 .unwrap_or(crate::mcp::DEFAULT_EMBED_BACKFILL_BATCH_SIZE)
140}
141
142pub(crate) fn run_reembed_live(
155 conn: &mut rusqlite::Connection,
156 emb: &dyn Embed,
157 namespace: Option<&str>,
158 batch_size: usize,
159 out: &mut CliOutput<'_>,
160) -> Result<ReembedOutcome> {
161 let mut outcome = ReembedOutcome::default();
162 let mut cursor: Option<String> = None;
163 loop {
164 let chunk = db::get_memory_texts_batch(conn, namespace, cursor.as_deref(), batch_size)?;
165 if chunk.is_empty() {
166 break;
167 }
168 outcome.total += chunk.len();
169 cursor = chunk.last().map(|(id, _, _)| id.clone());
170
171 let embedded = crate::mcp::embed_rows_with_fallback(emb, &chunk);
172 for (id, reason) in &embedded.skipped {
173 writeln!(
174 out.stderr,
175 "reembed: skipped row {id}: {reason} (previous vector kept, #1598)"
176 )?;
177 }
178 outcome.skipped += embedded.skipped.len();
179 if embedded.entries.is_empty() {
180 continue;
181 }
182 outcome.reembedded += db::set_embeddings_batch_reembed(conn, &embedded.entries)?;
183 }
184 Ok(outcome)
185}
186
187pub async fn cmd_reembed(
203 db_path: &Path,
204 args: &ReembedArgs,
205 app_config: &AppConfig,
206 out: &mut CliOutput<'_>,
207) -> Result<i32> {
208 let feature_tier = app_config.effective_tier(None);
209 let tier_config = feature_tier.config();
210 let resolved = app_config.resolve_embeddings();
211
212 let tier_model = if crate::config::is_api_embed_backend(&resolved.backend) {
216 tier_config.embedding_model
217 } else {
218 crate::daemon_runtime::resolve_embedder_model(&tier_config, app_config)
219 };
220 let Some(tier_model) = tier_model else {
221 writeln!(
222 out.stderr,
223 "reembed: tier '{}' is keyword-only (no embedding model) — reembed \
224 requires an embedding-capable tier (set `tier = \"semantic\"` or \
225 above in config.toml, or configure [embeddings] / \
226 AI_MEMORY_EMBED_* for an API backend)",
227 feature_tier.as_str()
228 )?;
229 return Ok(EXIT_NO_EMBEDDER);
230 };
231
232 let resolved_for_build = resolved.clone();
235 let built = tokio::task::spawn_blocking(move || {
236 crate::embeddings::Embedder::from_resolved(&resolved_for_build, Some(tier_model))
237 })
238 .await?;
239 let embedder = match built {
240 Ok(Some(emb)) => emb,
241 Ok(None) => {
245 writeln!(
246 out.stderr,
247 "reembed: resolver returned no embedder for tier '{}'",
248 feature_tier.as_str()
249 )?;
250 return Ok(EXIT_NO_EMBEDDER);
251 }
252 Err(e) => {
253 writeln!(
254 out.stderr,
255 "reembed: embedder init failed (backend={}, model={}, url={}, \
256 source={}): {e:#}",
257 resolved.backend,
258 resolved.model,
259 resolved.url,
260 resolved.source.as_str(),
261 )?;
262 return Ok(EXIT_EMBEDDER_INIT_FAILED);
263 }
264 };
265
266 let mut conn = db::open(db_path)?;
267 let ns = args.namespace.as_deref();
268 let target_model = embedder.model_description();
269 let target_dim = embedder.dim();
270 let plan = build_plan(&conn, ns, &target_model, target_dim, &resolved.backend)?;
271
272 let stored_dims = db::distinct_embedding_dims(&conn, ns)?;
274 writeln!(
275 out.stderr,
276 "reembed: PRE-FLIGHT — stored embedding dims: {stored_dims:?}; target: \
277 {target_dim}-dim ({target_model}); every scanned row's vector will be \
278 REPLACED (vector-space migration, #1598)"
279 )?;
280 if stored_dims.iter().any(|&d| d != target_dim) {
281 writeln!(
282 out.stderr,
283 "reembed: NOTE — stored dims {stored_dims:?} differ from target \
284 {target_dim}; recall dim-guards skip mismatched vectors until the \
285 sweep completes"
286 )?;
287 }
288
289 if args.dry_run {
290 if args.json {
291 writeln!(
292 out.stdout,
293 "{}",
294 serde_json::to_string(&json!({
295 "total_rows": plan.total_rows,
296 "rows_with_embeddings": plan.rows_with_embeddings,
297 "rows_missing_embeddings": plan.rows_missing_embeddings,
298 "target_model": plan.target_model,
299 "target_dim": plan.target_dim,
300 "backend": plan.backend,
301 }))?
302 )?;
303 } else {
304 writeln!(
305 out.stdout,
306 "reembed plan: total_rows={} rows_with_embeddings={} \
307 rows_missing_embeddings={} target_model='{}' target_dim={} \
308 backend={} (dry-run: nothing written)",
309 plan.total_rows,
310 plan.rows_with_embeddings,
311 plan.rows_missing_embeddings,
312 plan.target_model,
313 plan.target_dim,
314 plan.backend,
315 )?;
316 }
317 return Ok(0);
318 }
319
320 let batch_size = resolve_batch_size(args.batch, resolved.backfill_batch as usize);
321 let started = std::time::Instant::now();
322 let outcome = run_reembed_live(&mut conn, &embedder, ns, batch_size, out)?;
323 let duration_ms = u64::try_from(started.elapsed().as_millis()).unwrap_or(u64::MAX);
324
325 if args.json {
326 writeln!(
327 out.stdout,
328 "{}",
329 serde_json::to_string(&json!({
330 "total": outcome.total,
331 "reembedded": outcome.reembedded,
332 "skipped": outcome.skipped,
333 "model": target_model,
334 "dim": target_dim,
335 "duration_ms": duration_ms,
336 }))?
337 )?;
338 } else {
339 writeln!(
340 out.stdout,
341 "reembed: {}/{} re-embedded, {} skipped (model {target_model}, \
342 {target_dim}-dim, {duration_ms} ms)",
343 outcome.reembedded, outcome.total, outcome.skipped,
344 )?;
345 }
346 Ok(0)
347}
348
#[cfg(test)]
mod tests {
    use super::*;
    use crate::models::{Memory, Tier};

    /// Opens a fresh in-memory database through the regular `db::open` path.
    fn test_conn() -> rusqlite::Connection {
        db::open(std::path::Path::new(":memory:")).unwrap()
    }

    /// Inserts one long-tier memory into `ns` and returns the id reported by
    /// `db::insert`.
    fn seed(conn: &rusqlite::Connection, ns: &str, title: &str, content: &str) -> String {
        let timestamp = chrono::Utc::now().to_rfc3339();
        let memory = Memory {
            id: uuid::Uuid::new_v4().to_string(),
            tier: Tier::Long,
            namespace: ns.to_owned(),
            title: title.to_owned(),
            content: content.to_owned(),
            tags: Vec::new(),
            priority: 5,
            confidence: 1.0,
            source: "test".to_owned(),
            access_count: 0,
            created_at: timestamp.clone(),
            updated_at: timestamp,
            last_accessed_at: None,
            expires_at: None,
            metadata: serde_json::json!({}),
            reflection_depth: 0,
            memory_kind: crate::models::MemoryKind::Observation,
            entity_id: None,
            persona_version: None,
            citations: Vec::new(),
            source_uri: None,
            source_span: None,
            confidence_source: crate::models::ConfidenceSource::CallerProvided,
            confidence_signals: None,
            confidence_decayed_at: None,
            version: 1,
        };
        db::insert(conn, &memory).unwrap()
    }

    /// Test embedder returning a constant vector of `dim` elements, failing
    /// for any text that contains `poison_marker` (when one is set).
    struct FixedDimEmbedder {
        dim: usize,
        poison_marker: Option<&'static str>,
    }

    impl FixedDimEmbedder {
        /// Embedder that never fails.
        fn healthy(dim: usize) -> Self {
            Self {
                dim,
                poison_marker: None,
            }
        }

        /// Embedder that fails on every text containing `marker`.
        fn poisoned(dim: usize, marker: &'static str) -> Self {
            Self {
                dim,
                poison_marker: Some(marker),
            }
        }
    }

    impl Embed for FixedDimEmbedder {
        fn embed(&self, text: &str) -> anyhow::Result<Vec<f32>> {
            let is_poisoned = self
                .poison_marker
                .is_some_and(|marker| text.contains(marker));
            if is_poisoned {
                anyhow::bail!("test: synthetic per-row embed failure");
            }
            Ok(vec![0.5_f32; self.dim])
        }
    }

    /// Runs `run_reembed_live` against in-memory output buffers and returns
    /// the outcome together with everything written to stderr.
    fn run_live(
        conn: &mut rusqlite::Connection,
        emb: &dyn Embed,
        namespace: Option<&str>,
        batch_size: usize,
    ) -> (ReembedOutcome, String) {
        let mut stdout = Vec::<u8>::new();
        let mut stderr = Vec::<u8>::new();
        let outcome = {
            let mut out = CliOutput::from_std(&mut stdout, &mut stderr);
            run_reembed_live(conn, emb, namespace, batch_size, &mut out).unwrap()
        };
        (outcome, String::from_utf8(stderr).unwrap())
    }

    /// The plan counts rows with and without embeddings and honours the
    /// namespace filter, including a namespace with no rows.
    #[test]
    fn build_plan_counts_and_namespace_filter_1598() {
        let conn = test_conn();
        let id_a = seed(&conn, "plan-a", "a-1", "content");
        seed(&conn, "plan-a", "a-2", "content");
        seed(&conn, "plan-b", "b-1", "content");
        db::set_embedding(&conn, &id_a, &[0.1, 0.2]).unwrap();

        let expected_all = ReembedPlan {
            total_rows: 3,
            rows_with_embeddings: 1,
            rows_missing_embeddings: 2,
            target_model: "model-x (8-dim, remote)".to_string(),
            target_dim: 8,
            backend: "openrouter".to_string(),
        };
        let all = build_plan(&conn, None, "model-x (8-dim, remote)", 8, "openrouter").unwrap();
        assert_eq!(all, expected_all);

        let only_a = build_plan(&conn, Some("plan-a"), "m", 8, "b").unwrap();
        assert_eq!(only_a.total_rows, 2);
        assert_eq!(only_a.rows_with_embeddings, 1);
        assert_eq!(only_a.rows_missing_embeddings, 1);

        let none = build_plan(&conn, Some("plan-nope"), "m", 8, "b").unwrap();
        assert_eq!(none.total_rows, 0);
        assert_eq!(none.rows_missing_embeddings, 0);
    }

    /// A live run overwrites an existing vector and fills a missing one,
    /// even with a batch size of one.
    #[test]
    fn live_run_replaces_existing_vectors_1598() {
        let mut conn = test_conn();
        let id_old = seed(&conn, "live-ns", "old", "already embedded");
        let id_new = seed(&conn, "live-ns", "new", "never embedded");
        db::set_embedding(&conn, &id_old, &[0.1, 0.2, 0.3, 0.4]).unwrap();

        let emb = FixedDimEmbedder::healthy(8);
        let (outcome, _) = run_live(&mut conn, &emb, Some("live-ns"), 1);

        assert_eq!(
            outcome,
            ReembedOutcome {
                total: 2,
                reembedded: 2,
                skipped: 0,
            }
        );
        assert_eq!(
            db::get_embedding(&conn, &id_old).unwrap().unwrap().len(),
            8,
            "existing vector replaced at the new dim"
        );
        assert_eq!(db::get_embedding(&conn, &id_new).unwrap().unwrap().len(), 8);
    }

    /// Rows outside the requested namespace keep their stored vector.
    #[test]
    fn live_run_namespace_filter_leaves_others_untouched_1598() {
        let mut conn = test_conn();
        let id_in = seed(&conn, "ns-in", "in", "inside the filter");
        let id_out = seed(&conn, "ns-out", "out", "outside the filter");
        db::set_embedding(&conn, &id_out, &[0.9, 0.8]).unwrap();

        let emb = FixedDimEmbedder::healthy(4);
        let (outcome, _) = run_live(&mut conn, &emb, Some("ns-in"), 16);

        assert_eq!(outcome.total, 1);
        assert_eq!(outcome.reembedded, 1);
        assert_eq!(db::get_embedding(&conn, &id_in).unwrap().unwrap().len(), 4);
        let untouched = db::get_embedding(&conn, &id_out).unwrap().unwrap();
        assert_eq!(untouched.len(), 2, "out-of-namespace vector untouched");
    }

    /// A row whose embedding fails is skipped and reported on stderr, keeps
    /// its previous vector, and does not stop the other rows of the batch.
    #[test]
    fn live_run_per_row_fallback_skips_poison_row_1598() {
        const MARKER: &str = "reembed-poison-marker";
        let mut conn = test_conn();
        let id_ok_a = seed(&conn, "fb-ns", "ok-a", "healthy");
        let id_bad = seed(&conn, "fb-ns", "bad", MARKER);
        let id_ok_b = seed(&conn, "fb-ns", "ok-b", "healthy");
        db::set_embedding(&conn, &id_bad, &[0.7, 0.7]).unwrap();

        let emb = FixedDimEmbedder::poisoned(4, MARKER);
        let (outcome, warn) = run_live(&mut conn, &emb, Some("fb-ns"), 16);

        assert_eq!(
            outcome,
            ReembedOutcome {
                total: 3,
                reembedded: 2,
                skipped: 1,
            }
        );
        assert_eq!(
            db::get_embedding(&conn, &id_ok_a).unwrap().unwrap().len(),
            4
        );
        assert_eq!(
            db::get_embedding(&conn, &id_ok_b).unwrap().unwrap().len(),
            4
        );
        assert_eq!(
            db::get_embedding(&conn, &id_bad).unwrap().unwrap().len(),
            2,
            "poison row keeps its previous vector"
        );
        assert!(
            warn.contains(&id_bad) && warn.contains("skipped row"),
            "WARN must name the skipped row id, got: {warn}"
        );
    }

    /// Flag beats config, config beats the compiled default, and zero counts
    /// as "not set" at both levels.
    #[test]
    fn resolve_batch_size_precedence_1598() {
        assert_eq!(resolve_batch_size(Some(7), 100), 7);
        assert_eq!(
            resolve_batch_size(Some(0), 100),
            100,
            "0 flag falls through"
        );
        assert_eq!(resolve_batch_size(None, 100), 100);
        assert_eq!(
            resolve_batch_size(None, 0),
            crate::mcp::DEFAULT_EMBED_BACKFILL_BATCH_SIZE,
            "double-degenerate input coerces to the compiled default"
        );
    }

    /// With no rows at all, the sweep returns all-zero counters.
    #[test]
    fn live_run_empty_corpus_is_noop_1598() {
        let mut conn = test_conn();
        let emb = FixedDimEmbedder::healthy(4);
        let (outcome, _) = run_live(&mut conn, &emb, None, 16);
        assert_eq!(outcome, ReembedOutcome::default());
    }
}