use std::path::Path;
use anyhow::Result;
use clap::Args;
use serde_json::json;
use crate::cli::CliOutput;
use crate::config::AppConfig;
use crate::db;
use crate::embeddings::Embed;
/// Exit code `cmd_reembed` returns when no embedder is available: either the
/// effective tier resolves to no embedding model, or the embedder builder
/// returned `Ok(None)`.
pub const EXIT_NO_EMBEDDER: i32 = 2;
/// Exit code `cmd_reembed` returns when building the embedder fails with an error.
pub const EXIT_EMBEDDER_INIT_FAILED: i32 = 3;
// Command-line arguments consumed by `cmd_reembed`.
// NOTE(review): plain `//` comments are used on purpose here; clap's derive
// turns `///` doc comments into help text, which would change `--help` output.
#[derive(Args, Debug, Clone)]
pub struct ReembedArgs {
// Optional namespace filter; `cmd_reembed` forwards it to the plan, the
// stored-dimension probe and the live run. `None` means no filter is passed.
#[arg(long)]
pub namespace: Option<String>,
// When set, `cmd_reembed` prints the plan and returns 0 without starting the
// live re-embed run.
#[arg(long)]
pub dry_run: bool,
// Optional batch-size override; `resolve_batch_size` ignores it when it is
// absent or zero.
#[arg(long)]
pub batch: Option<usize>,
// When set, the plan / outcome summary is written to stdout as a JSON object
// instead of a human-readable line.
#[arg(long)]
pub json: bool,
}
/// Read-only summary produced by `build_plan` and printed by the dry-run path
/// of `cmd_reembed`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub(crate) struct ReembedPlan {
/// First value of the pair returned by `db::embedding_coverage` for the
/// requested namespace filter.
pub(crate) total_rows: u64,
/// Second value of the pair returned by `db::embedding_coverage`.
pub(crate) rows_with_embeddings: u64,
/// `total_rows - rows_with_embeddings`, saturating at zero.
pub(crate) rows_missing_embeddings: u64,
/// Caller-supplied description of the model the run would embed with.
pub(crate) target_model: String,
/// Caller-supplied vector dimension of the target model.
pub(crate) target_dim: usize,
/// Caller-supplied embedding backend name.
pub(crate) backend: String,
}
/// Counters accumulated by `run_reembed_live` over the whole sweep.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub(crate) struct ReembedOutcome {
/// Number of rows fetched across all batches.
pub(crate) total: usize,
/// Sum of the counts returned by `db::set_embeddings_batch_reembed`.
pub(crate) reembedded: usize,
/// Number of rows the embedding step reported as skipped.
pub(crate) skipped: usize,
}
/// Assembles the [`ReembedPlan`] for a (possibly namespace-filtered) corpus.
///
/// The row counts come from `db::embedding_coverage`; the missing-embedding
/// count is derived from them with a saturating subtraction. The target model,
/// dimension and backend are copied from the arguments unchanged.
///
/// # Errors
///
/// Propagates any error returned by `db::embedding_coverage`.
pub(crate) fn build_plan(
    conn: &rusqlite::Connection,
    namespace: Option<&str>,
    target_model: &str,
    target_dim: usize,
    backend: &str,
) -> Result<ReembedPlan> {
    let coverage = db::embedding_coverage(conn, namespace)?;
    let (total_rows, rows_with_embeddings) = coverage;
    // Saturate rather than underflow should the store ever report more
    // embedded rows than total rows.
    let rows_missing_embeddings = total_rows.saturating_sub(rows_with_embeddings);

    let plan = ReembedPlan {
        total_rows,
        rows_with_embeddings,
        rows_missing_embeddings,
        target_model: target_model.to_owned(),
        target_dim,
        backend: backend.to_owned(),
    };
    Ok(plan)
}
/// Picks the batch size for a live run.
///
/// Precedence: a non-zero `batch_flag`, then a non-zero `resolved_default`,
/// then `crate::mcp::DEFAULT_EMBED_BACKFILL_BATCH_SIZE`. Zero is treated as
/// "not set" at both of the first two levels.
pub(crate) fn resolve_batch_size(batch_flag: Option<usize>, resolved_default: usize) -> usize {
    match batch_flag {
        Some(explicit) if explicit > 0 => explicit,
        _ if resolved_default > 0 => resolved_default,
        _ => crate::mcp::DEFAULT_EMBED_BACKFILL_BATCH_SIZE,
    }
}
/// Sweeps the (optionally namespace-filtered) rows in batches and re-embeds them.
///
/// Each iteration fetches up to `batch_size` rows after the id of the last row
/// of the previous batch, embeds them via `crate::mcp::embed_rows_with_fallback`,
/// reports every skipped row on `out.stderr`, and writes the successful entries
/// with `db::set_embeddings_batch_reembed`. The sweep ends on the first empty
/// batch.
///
/// # Errors
///
/// Propagates errors from fetching a batch, writing the skip notices, or
/// storing the new embeddings.
pub(crate) fn run_reembed_live(
    conn: &mut rusqlite::Connection,
    emb: &dyn Embed,
    namespace: Option<&str>,
    batch_size: usize,
    out: &mut CliOutput<'_>,
) -> Result<ReembedOutcome> {
    let mut tally = ReembedOutcome::default();
    let mut resume_after: Option<String> = None;

    loop {
        let rows =
            db::get_memory_texts_batch(conn, namespace, resume_after.as_deref(), batch_size)?;
        // An empty batch has no last row: that is the end of the sweep.
        let Some((last_id, _, _)) = rows.last() else {
            break;
        };
        resume_after = Some(last_id.clone());
        tally.total += rows.len();

        let result = crate::mcp::embed_rows_with_fallback(emb, &rows);
        for (id, reason) in &result.skipped {
            writeln!(
                out.stderr,
                "reembed: skipped row {id}: {reason} (previous vector kept, #1598)"
            )?;
        }
        tally.skipped += result.skipped.len();

        // Nothing embedded in this batch: skip the write and move on.
        if !result.entries.is_empty() {
            tally.reembedded += db::set_embeddings_batch_reembed(conn, &result.entries)?;
        }
    }

    Ok(tally)
}
/// Entry point of the `reembed` command.
///
/// Steps, in order:
/// 1. Resolve the effective tier and embedding settings and pick the model.
/// 2. Build the embedder inside `tokio::task::spawn_blocking`.
/// 3. Open the database, compute the plan and print a pre-flight notice to
///    stderr (plus a note when any stored dimension differs from the target).
/// 4. With `--dry-run`, print the plan to stdout and stop.
/// 5. Otherwise run the live sweep and print its summary to stdout.
///
/// Returns `Ok(0)` after a dry run or a completed live run,
/// `Ok(EXIT_NO_EMBEDDER)` when no embedding model / embedder is available, and
/// `Ok(EXIT_EMBEDDER_INIT_FAILED)` when the embedder builder reports an error.
///
/// # Errors
///
/// Propagates failures from joining the blocking task, opening the database,
/// the database helper calls, JSON serialization, the live sweep, and writes
/// to `out`.
pub async fn cmd_reembed(
    db_path: &Path,
    args: &ReembedArgs,
    app_config: &AppConfig,
    out: &mut CliOutput<'_>,
) -> Result<i32> {
    let tier = app_config.effective_tier(None);
    let tier_cfg = tier.config();
    let embed_cfg = app_config.resolve_embeddings();

    // API backends use the tier's model directly; every other backend asks the
    // daemon runtime's resolver.
    let model_choice = if crate::config::is_api_embed_backend(&embed_cfg.backend) {
        tier_cfg.embedding_model
    } else {
        crate::daemon_runtime::resolve_embedder_model(&tier_cfg, app_config)
    };
    let Some(tier_model) = model_choice else {
        writeln!(
            out.stderr,
            "reembed: tier '{}' is keyword-only (no embedding model) — reembed \
             requires an embedding-capable tier (set `tier = \"semantic\"` or \
             above in config.toml, or configure [embeddings] / \
             AI_MEMORY_EMBED_* for an API backend)",
            tier.as_str()
        )?;
        return Ok(EXIT_NO_EMBEDDER);
    };

    // The closure must own its inputs, so it gets a private copy of the
    // settings; the original is still needed for diagnostics below.
    let builder_cfg = embed_cfg.clone();
    let build_result = tokio::task::spawn_blocking(move || {
        crate::embeddings::Embedder::from_resolved(&builder_cfg, Some(tier_model))
    })
    .await?;
    let embedder = match build_result {
        Err(e) => {
            writeln!(
                out.stderr,
                "reembed: embedder init failed (backend={}, model={}, url={}, \
                 source={}): {e:#}",
                embed_cfg.backend,
                embed_cfg.model,
                embed_cfg.url,
                embed_cfg.source.as_str(),
            )?;
            return Ok(EXIT_EMBEDDER_INIT_FAILED);
        }
        Ok(None) => {
            writeln!(
                out.stderr,
                "reembed: resolver returned no embedder for tier '{}'",
                tier.as_str()
            )?;
            return Ok(EXIT_NO_EMBEDDER);
        }
        Ok(Some(ready)) => ready,
    };

    let mut db_conn = db::open(db_path)?;
    let namespace = args.namespace.as_deref();
    let target_model = embedder.model_description();
    let target_dim = embedder.dim();
    let plan = build_plan(
        &db_conn,
        namespace,
        &target_model,
        target_dim,
        &embed_cfg.backend,
    )?;

    let stored_dims = db::distinct_embedding_dims(&db_conn, namespace)?;
    writeln!(
        out.stderr,
        "reembed: PRE-FLIGHT — stored embedding dims: {stored_dims:?}; target: \
         {target_dim}-dim ({target_model}); every scanned row's vector will be \
         REPLACED (vector-space migration, #1598)"
    )?;
    let dims_all_match = stored_dims.iter().all(|&d| d == target_dim);
    if !dims_all_match {
        writeln!(
            out.stderr,
            "reembed: NOTE — stored dims {stored_dims:?} differ from target \
             {target_dim}; recall dim-guards skip mismatched vectors until the \
             sweep completes"
        )?;
    }

    // Dry run: report the plan and stop before anything is written.
    if args.dry_run {
        if args.json {
            let payload = json!({
                "total_rows": plan.total_rows,
                "rows_with_embeddings": plan.rows_with_embeddings,
                "rows_missing_embeddings": plan.rows_missing_embeddings,
                "target_model": plan.target_model,
                "target_dim": plan.target_dim,
                "backend": plan.backend,
            });
            writeln!(out.stdout, "{}", serde_json::to_string(&payload)?)?;
        } else {
            writeln!(
                out.stdout,
                "reembed plan: total_rows={} rows_with_embeddings={} \
                 rows_missing_embeddings={} target_model='{}' target_dim={} \
                 backend={} (dry-run: nothing written)",
                plan.total_rows,
                plan.rows_with_embeddings,
                plan.rows_missing_embeddings,
                plan.target_model,
                plan.target_dim,
                plan.backend,
            )?;
        }
        return Ok(0);
    }

    let rows_per_batch = resolve_batch_size(args.batch, embed_cfg.backfill_batch as usize);
    let clock = std::time::Instant::now();
    let summary = run_reembed_live(&mut db_conn, &embedder, namespace, rows_per_batch, out)?;
    // `as_millis` yields a u128; clamp instead of truncating.
    let duration_ms = u64::try_from(clock.elapsed().as_millis()).unwrap_or(u64::MAX);

    if args.json {
        let payload = json!({
            "total": summary.total,
            "reembedded": summary.reembedded,
            "skipped": summary.skipped,
            "model": target_model,
            "dim": target_dim,
            "duration_ms": duration_ms,
        });
        writeln!(out.stdout, "{}", serde_json::to_string(&payload)?)?;
    } else {
        writeln!(
            out.stdout,
            "reembed: {}/{} re-embedded, {} skipped (model {target_model}, \
             {target_dim}-dim, {duration_ms} ms)",
            summary.reembedded, summary.total, summary.skipped,
        )?;
    }
    Ok(0)
}
// Unit tests for the plan builder, the batch-size resolver and the live
// re-embed sweep, all run against an in-memory database.
#[cfg(test)]
mod tests {
use super::*;
use crate::models::{Memory, Tier};
// Inserts one `Tier::Long` memory with the given namespace, title and content
// (every other field is a fixed test value) and returns what `db::insert`
// returns; the tests below use that value as the row id.
fn seed(conn: &rusqlite::Connection, ns: &str, title: &str, content: &str) -> String {
let now = chrono::Utc::now().to_rfc3339();
let mem = Memory {
id: uuid::Uuid::new_v4().to_string(),
tier: Tier::Long,
namespace: ns.to_string(),
title: title.to_string(),
content: content.to_string(),
tags: vec![],
priority: 5,
confidence: 1.0,
source: "test".to_string(),
access_count: 0,
created_at: now.clone(),
updated_at: now,
last_accessed_at: None,
expires_at: None,
metadata: serde_json::json!({}),
reflection_depth: 0,
memory_kind: crate::models::MemoryKind::Observation,
entity_id: None,
persona_version: None,
citations: Vec::new(),
source_uri: None,
source_span: None,
confidence_source: crate::models::ConfidenceSource::CallerProvided,
confidence_signals: None,
confidence_decayed_at: None,
version: 1,
};
db::insert(conn, &mem).unwrap()
}
// Opens a database at the special `:memory:` path.
fn test_conn() -> rusqlite::Connection {
db::open(std::path::Path::new(":memory:")).unwrap()
}
// Test embedder: returns a constant vector of length `dim`, and fails for any
// text containing `poison_marker` when one is configured.
struct FixedDimEmbedder {
dim: usize,
poison_marker: Option<&'static str>,
}
impl Embed for FixedDimEmbedder {
fn embed(&self, text: &str) -> anyhow::Result<Vec<f32>> {
// Simulate a per-row embedding failure for marked rows.
if let Some(marker) = self.poison_marker
&& text.contains(marker)
{
anyhow::bail!("test: synthetic per-row embed failure");
}
Ok(vec![0.5_f32; self.dim])
}
}
// `build_plan` reports coverage counts for the whole corpus, for a single
// namespace, and for a namespace with no rows, and echoes the target fields.
#[test]
fn build_plan_counts_and_namespace_filter_1598() {
let conn = test_conn();
let id_a = seed(&conn, "plan-a", "a-1", "content");
seed(&conn, "plan-a", "a-2", "content");
seed(&conn, "plan-b", "b-1", "content");
// Only one of the three rows starts with an embedding.
db::set_embedding(&conn, &id_a, &[0.1, 0.2]).unwrap();
let all = build_plan(&conn, None, "model-x (8-dim, remote)", 8, "openrouter").unwrap();
assert_eq!(
all,
ReembedPlan {
total_rows: 3,
rows_with_embeddings: 1,
rows_missing_embeddings: 2,
target_model: "model-x (8-dim, remote)".to_string(),
target_dim: 8,
backend: "openrouter".to_string(),
}
);
let only_a = build_plan(&conn, Some("plan-a"), "m", 8, "b").unwrap();
assert_eq!(only_a.total_rows, 2);
assert_eq!(only_a.rows_with_embeddings, 1);
assert_eq!(only_a.rows_missing_embeddings, 1);
let none = build_plan(&conn, Some("plan-nope"), "m", 8, "b").unwrap();
assert_eq!(none.total_rows, 0);
assert_eq!(none.rows_missing_embeddings, 0);
}
// A live run overwrites an existing 4-dim vector and fills a missing one.
// Batch size 1 with two rows makes the sweep fetch more than one batch.
#[test]
fn live_run_replaces_existing_vectors_1598() {
let mut conn = test_conn();
let id_old = seed(&conn, "live-ns", "old", "already embedded");
let id_new = seed(&conn, "live-ns", "new", "never embedded");
db::set_embedding(&conn, &id_old, &[0.1, 0.2, 0.3, 0.4]).unwrap();
let emb = FixedDimEmbedder {
dim: 8,
poison_marker: None,
};
let mut stdout = Vec::<u8>::new();
let mut stderr = Vec::<u8>::new();
let mut out = CliOutput::from_std(&mut stdout, &mut stderr);
let outcome = run_reembed_live(&mut conn, &emb, Some("live-ns"), 1, &mut out).unwrap();
assert_eq!(
outcome,
ReembedOutcome {
total: 2,
reembedded: 2,
skipped: 0,
}
);
assert_eq!(
db::get_embedding(&conn, &id_old).unwrap().unwrap().len(),
8,
"existing vector replaced at the new dim"
);
assert_eq!(db::get_embedding(&conn, &id_new).unwrap().unwrap().len(), 8);
}
// A namespace-filtered run only touches rows inside that namespace; the
// vector stored for the other namespace keeps its original length.
#[test]
fn live_run_namespace_filter_leaves_others_untouched_1598() {
let mut conn = test_conn();
let id_in = seed(&conn, "ns-in", "in", "inside the filter");
let id_out = seed(&conn, "ns-out", "out", "outside the filter");
db::set_embedding(&conn, &id_out, &[0.9, 0.8]).unwrap();
let emb = FixedDimEmbedder {
dim: 4,
poison_marker: None,
};
let mut stdout = Vec::<u8>::new();
let mut stderr = Vec::<u8>::new();
let mut out = CliOutput::from_std(&mut stdout, &mut stderr);
let outcome = run_reembed_live(&mut conn, &emb, Some("ns-in"), 16, &mut out).unwrap();
assert_eq!(outcome.total, 1);
assert_eq!(outcome.reembedded, 1);
assert_eq!(db::get_embedding(&conn, &id_in).unwrap().unwrap().len(), 4);
let untouched = db::get_embedding(&conn, &id_out).unwrap().unwrap();
assert_eq!(untouched.len(), 2, "out-of-namespace vector untouched");
}
// When one row in a batch fails to embed, the other rows are still written,
// the failing row keeps its previous vector, and stderr names the skipped id.
#[test]
fn live_run_per_row_fallback_skips_poison_row_1598() {
const MARKER: &str = "reembed-poison-marker";
let mut conn = test_conn();
let id_ok_a = seed(&conn, "fb-ns", "ok-a", "healthy");
let id_bad = seed(&conn, "fb-ns", "bad", MARKER);
let id_ok_b = seed(&conn, "fb-ns", "ok-b", "healthy");
// Give the poison row a 2-dim vector so "kept" is distinguishable from "replaced".
db::set_embedding(&conn, &id_bad, &[0.7, 0.7]).unwrap();
let emb = FixedDimEmbedder {
dim: 4,
poison_marker: Some(MARKER),
};
let mut stdout = Vec::<u8>::new();
let mut stderr = Vec::<u8>::new();
let mut out = CliOutput::from_std(&mut stdout, &mut stderr);
let outcome = run_reembed_live(&mut conn, &emb, Some("fb-ns"), 16, &mut out).unwrap();
assert_eq!(
outcome,
ReembedOutcome {
total: 3,
reembedded: 2,
skipped: 1,
}
);
assert_eq!(
db::get_embedding(&conn, &id_ok_a).unwrap().unwrap().len(),
4
);
assert_eq!(
db::get_embedding(&conn, &id_ok_b).unwrap().unwrap().len(),
4
);
assert_eq!(
db::get_embedding(&conn, &id_bad).unwrap().unwrap().len(),
2,
"poison row keeps its previous vector"
);
let warn = String::from_utf8(stderr).unwrap();
assert!(
warn.contains(&id_bad) && warn.contains("skipped row"),
"WARN must name the skipped row id, got: {warn}"
);
}
// Precedence of `resolve_batch_size`: non-zero flag, then non-zero resolved
// default, then the compiled-in default.
#[test]
fn resolve_batch_size_precedence_1598() {
assert_eq!(resolve_batch_size(Some(7), 100), 7);
assert_eq!(
resolve_batch_size(Some(0), 100),
100,
"0 flag falls through"
);
assert_eq!(resolve_batch_size(None, 100), 100);
assert_eq!(
resolve_batch_size(None, 0),
crate::mcp::DEFAULT_EMBED_BACKFILL_BATCH_SIZE,
"double-degenerate input coerces to the compiled default"
);
}
// With no rows seeded, a live run returns the all-zero default outcome.
#[test]
fn live_run_empty_corpus_is_noop_1598() {
let mut conn = test_conn();
let emb = FixedDimEmbedder {
dim: 4,
poison_marker: None,
};
let mut stdout = Vec::<u8>::new();
let mut stderr = Vec::<u8>::new();
let mut out = CliOutput::from_std(&mut stdout, &mut stderr);
let outcome = run_reembed_live(&mut conn, &emb, None, 16, &mut out).unwrap();
assert_eq!(outcome, ReembedOutcome::default());
}
}