1use crate::cli::MemoryType;
4use crate::errors::AppError;
5use crate::graph::traverse_from_memories_with_hops;
6use crate::i18n::errors_msg;
7use crate::output::{self, JsonOutputFormat, RecallItem, RecallResponse};
8use crate::paths::AppPaths;
9use crate::storage::connection::open_ro;
10use crate::storage::entities;
11use crate::storage::memories;
12
// Command-line arguments for the `recall` subcommand (consumed by `run` in this module).
//
// NOTE: plain `//` comments are used on purpose. With clap's derive API, `///` doc comments on
// the struct or its fields are turned into `about`/`help` text and would change `--help` output.
#[derive(clap::Args)]
#[command(after_long_help = "EXAMPLES:\n \
    # Semantic search for top 5 matches\n \
    sqlite-graphrag recall \"authentication design\" --k 5\n\n \
    # Disable automatic graph expansion\n \
    sqlite-graphrag recall \"JWT tokens\" --k 3 --no-graph\n\n \
    # Limit graph traversal depth and minimum edge weight\n \
    sqlite-graphrag recall \"auth\" --k 5 --max-hops 2 --min-weight 0.3\n\n \
    # Filter by memory type\n \
    sqlite-graphrag recall \"deployment\" --type decision --k 10\n\n \
    # Cap results by distance threshold\n \
    sqlite-graphrag recall \"API design\" --k 5 --max-distance 0.8\n\n \
    NOTES:\n \
    When --no-graph is active, graph traversal is skipped and every result has\n \
    source=\"direct\". The source field is therefore redundant with --no-graph and\n \
    may be ignored by callers in that mode.")]
pub struct RecallArgs {
    // Positional free-text query. `allow_hyphen_values` lets a query start with `-`;
    // `run` rejects queries that are empty after trimming.
    #[arg(
        allow_hyphen_values = true,
        help = "Search query string (semantic vector search via sqlite-vec)"
    )]
    pub query: String,
    // Number of direct matches requested (aliases `--limit` / `--top-k`, default 10),
    // range-checked by `parse_k_range`. `run` replaces it with 100_000 as the search limit
    // when `--precise` is set, but still echoes this value as `k` in the response.
    #[arg(short = 'k', long, aliases = ["limit", "top-k"], default_value = "10", value_parser = crate::parsers::parse_k_range)]
    pub k: usize,
    // Optional memory-type filter. `run` passes it to the vector KNN and FTS searches;
    // graph-expanded matches are not filtered by it.
    #[arg(long, value_enum)]
    pub r#type: Option<MemoryType>,
    // Namespace to search; when absent `run` resolves one via `crate::namespace::resolve_namespace`.
    #[arg(long)]
    pub namespace: Option<String>,
    // Skip graph expansion. `run` rejects this combined with a non-default
    // `--max-hops` or `--min-weight`.
    #[arg(long)]
    pub no_graph: bool,
    // Use 100_000 as the search limit instead of `--k`.
    #[arg(long)]
    pub precise: bool,
    // Maximum graph traversal depth (default 2).
    #[arg(long, default_value = "2")]
    pub max_hops: u32,
    // Minimum edge weight passed to graph traversal (default 0.3).
    #[arg(long, default_value = "0.3")]
    pub min_weight: f64,
    // Optional cap on the number of graph matches; no cap when absent.
    #[arg(long, value_name = "N")]
    pub max_graph_results: Option<usize>,
    // Relevance threshold (alias `--min-distance`, default 1.0). When below 1.0 and vector
    // search was not degraded, `run` returns `NotFound` if no direct match lies within it.
    // Matches beyond the threshold are not removed from the output.
    #[arg(long, alias = "min-distance", default_value = "1.0")]
    pub max_distance: f32,
    // Accepted but not consulted by `run`; output is always emitted as JSON.
    #[arg(long, value_enum, default_value_t = JsonOutputFormat::Json)]
    pub format: JsonOutputFormat,
    // Database path override; can also be supplied via `SQLITE_GRAPHRAG_DB_PATH`.
    #[arg(long, env = "SQLITE_GRAPHRAG_DB_PATH")]
    pub db: Option<String>,
    // Hidden no-op so that `--json` is still accepted on the command line.
    #[arg(long, hide = true, help = "No-op; JSON is always emitted on stdout")]
    pub json: bool,
    // Search every namespace; mutually exclusive with `--namespace`.
    #[arg(long, conflicts_with = "namespace")]
    pub all_namespaces: bool,
    // Skip the embedding model entirely; `run` then reports `vec_degraded = true`.
    #[arg(
        long,
        help = "Skip live query embedding; use FTS5 BM25 + LIKE prefix only"
    )]
    pub fallback_fts_only: bool,
}
105
106#[tracing::instrument(skip_all, level = "debug", name = "recall")]
107pub fn run(args: RecallArgs) -> Result<(), AppError> {
108 let start = std::time::Instant::now();
109 let _ = args.format;
110 tracing::debug!(target: "recall", query = %args.query, k = args.k, "searching");
111
112 if args.no_graph {
114 if args.max_hops != 2 {
115 return Err(AppError::Validation(
116 "--max-hops has no effect with --no-graph; remove one".to_string(),
117 ));
118 }
119 if (args.min_weight - 0.3).abs() > f64::EPSILON {
120 return Err(AppError::Validation(
121 "--min-weight has no effect with --no-graph; remove one".to_string(),
122 ));
123 }
124 }
125
126 if args.query.trim().is_empty() {
127 return Err(AppError::Validation(crate::i18n::validation::empty_query()));
128 }
129 let namespaces: Vec<String> = if args.all_namespaces {
133 Vec::new()
134 } else {
135 vec![crate::namespace::resolve_namespace(
136 args.namespace.as_deref(),
137 )?]
138 };
139 let namespace_for_graph = namespaces
141 .first()
142 .cloned()
143 .unwrap_or_else(|| "global".to_string());
144 let paths = AppPaths::resolve(args.db.as_deref())?;
145
146 crate::storage::connection::ensure_db_ready(&paths)?;
147
148 output::emit_progress_i18n(
149 "Computing query embedding...",
150 "Calculando embedding da consulta...",
151 );
152 let conn = open_ro(&paths.db)?;
153 let (embedding, vec_degraded, vec_error) = if args.fallback_fts_only {
159 (None, true, Some("fallback_fts_only requested".to_string()))
160 } else {
161 match crate::embedder::try_embed_query_with_fallback(&paths.models, &args.query) {
162 Ok(v) => (Some(v), false, None),
163 Err(reason) => {
164 let msg = reason.to_string();
165 tracing::warn!(target: "recall", fallback_reason = %msg, "live embedding failed; falling back to FTS5");
166 (None, true, Some(msg))
167 }
168 }
169 };
170
171 let memory_type_str = args.r#type.map(|t| t.as_str());
172 let effective_k = if args.precise { 100_000 } else { args.k };
175
176 let (direct_matches, memory_ids): (Vec<RecallItem>, Vec<i64>) =
181 if let Some(emb) = embedding.as_ref() {
182 let knn_results =
183 memories::knn_search(&conn, emb, &namespaces, memory_type_str, effective_k)?;
184 let mut items: Vec<RecallItem> = Vec::with_capacity(knn_results.len());
185 let mut memory_ids: Vec<i64> = Vec::with_capacity(knn_results.len());
186 for (memory_id, distance) in knn_results {
187 let row = {
188 let mut stmt = conn.prepare_cached(
189 "SELECT id, namespace, name, type, description, body, body_hash,
190 session_id, source, metadata, created_at, updated_at
191 FROM memories WHERE id=?1 AND deleted_at IS NULL",
192 )?;
193 stmt.query_row(rusqlite::params![memory_id], |r| {
194 Ok(memories::MemoryRow {
195 id: r.get(0)?,
196 namespace: r.get(1)?,
197 name: r.get(2)?,
198 memory_type: r.get(3)?,
199 description: r.get(4)?,
200 body: r.get(5)?,
201 body_hash: r.get(6)?,
202 session_id: r.get(7)?,
203 source: r.get(8)?,
204 metadata: r.get(9)?,
205 created_at: r.get(10)?,
206 updated_at: r.get(11)?,
207 deleted_at: None,
208 })
209 })
210 .ok()
211 };
212 if let Some(row) = row {
213 let snippet: String = row.body.chars().take(300).collect();
214 items.push(RecallItem {
215 memory_id: row.id,
216 name: row.name,
217 namespace: row.namespace,
218 memory_type: row.memory_type,
219 description: row.description,
220 snippet,
221 distance,
222 score: RecallItem::score_from_distance(distance),
223 source: "direct".to_string(),
224 graph_depth: None,
225 });
226 memory_ids.push(memory_id);
227 }
228 }
229 (items, memory_ids)
230 } else {
231 let fts_rows = memories::fts_search(
237 &conn,
238 &args.query,
239 &namespace_for_graph,
240 memory_type_str,
241 effective_k,
242 )?;
243 let mut items: Vec<RecallItem> = Vec::with_capacity(fts_rows.len());
244 for (rank, row) in fts_rows.into_iter().enumerate() {
245 let dist = 1.0 - 1.0 / (rank as f32 + 1.0);
246 let snippet: String = row.body.chars().take(300).collect();
247 items.push(RecallItem {
248 memory_id: row.id,
249 name: row.name,
250 namespace: row.namespace,
251 memory_type: row.memory_type,
252 description: row.description,
253 snippet,
254 distance: dist,
255 score: RecallItem::score_from_distance(dist),
256 source: "fts_fallback".to_string(),
257 graph_depth: None,
258 });
259 }
260 (items, Vec::new())
261 };
262
263 let mut graph_matches = Vec::with_capacity(8);
264 if let Some(emb) = (!args.no_graph).then_some(()).and(embedding.as_ref()) {
265 let entity_knn = entities::knn_search(&conn, emb, &namespace_for_graph, 5)?;
266 let entity_ids: Vec<i64> = entity_knn.iter().map(|(id, _)| *id).collect();
267
268 let all_seed_ids: Vec<i64> = memory_ids
269 .iter()
270 .chain(entity_ids.iter())
271 .copied()
272 .collect();
273
274 if !all_seed_ids.is_empty() {
275 let graph_memory_ids = traverse_from_memories_with_hops(
276 &conn,
277 &all_seed_ids,
278 &namespace_for_graph,
279 args.min_weight,
280 args.max_hops,
281 )?;
282
283 for (graph_mem_id, hop) in graph_memory_ids {
284 if let Some(cap) = args.max_graph_results {
287 if graph_matches.len() >= cap {
288 break;
289 }
290 }
291 let row = {
292 let mut stmt = conn.prepare_cached(
293 "SELECT id, namespace, name, type, description, body, body_hash,
294 session_id, source, metadata, created_at, updated_at
295 FROM memories WHERE id=?1 AND deleted_at IS NULL",
296 )?;
297 stmt.query_row(rusqlite::params![graph_mem_id], |r| {
298 Ok(memories::MemoryRow {
299 id: r.get(0)?,
300 namespace: r.get(1)?,
301 name: r.get(2)?,
302 memory_type: r.get(3)?,
303 description: r.get(4)?,
304 body: r.get(5)?,
305 body_hash: r.get(6)?,
306 session_id: r.get(7)?,
307 source: r.get(8)?,
308 metadata: r.get(9)?,
309 created_at: r.get(10)?,
310 updated_at: r.get(11)?,
311 deleted_at: None,
312 })
313 })
314 .ok()
315 };
316 if let Some(row) = row {
317 let snippet: String = row.body.chars().take(300).collect();
318 let graph_distance = 1.0 - 1.0 / (hop as f32 + 1.0);
319 graph_matches.push(RecallItem {
320 memory_id: row.id,
321 name: row.name,
322 namespace: row.namespace,
323 memory_type: row.memory_type,
324 description: row.description,
325 snippet,
326 distance: graph_distance,
327 score: RecallItem::score_from_distance(graph_distance),
328 source: "graph".to_string(),
329 graph_depth: Some(hop),
330 });
331 }
332 }
333 }
334 }
335
336 if args.max_distance < 1.0 && !vec_degraded {
338 let has_relevant = direct_matches
339 .iter()
340 .any(|item| item.distance <= args.max_distance);
341 if !has_relevant {
342 return Err(AppError::NotFound(errors_msg::no_recall_results(
343 args.max_distance,
344 &args.query,
345 &namespace_for_graph,
346 )));
347 }
348 }
349
350 let results: Vec<RecallItem> = direct_matches
351 .iter()
352 .cloned()
353 .chain(graph_matches.iter().cloned())
354 .collect();
355
356 let warning = if vec_degraded {
357 Some(
358 "live query embedding unavailable; results are FTS5 BM25 only (semantic relevance reduced)"
359 .to_string(),
360 )
361 } else {
362 None
363 };
364
365 output::emit_json(&RecallResponse {
366 query: args.query,
367 k: args.k,
368 direct_matches,
369 graph_matches,
370 results,
371 elapsed_ms: start.elapsed().as_millis() as u64,
372 vec_degraded,
373 vec_error,
374 warning,
375 })?;
376
377 Ok(())
378}
379
#[cfg(test)]
mod tests {
    use crate::output::{RecallItem, RecallResponse};

    /// Builds a `RecallItem` with fixed metadata; only graph-sourced items carry a depth (0).
    fn fixture(label: &str, dist: f32, origin: &str) -> RecallItem {
        let graph_depth = (origin == "graph").then_some(0);
        RecallItem {
            memory_id: 1,
            name: label.to_owned(),
            namespace: String::from("global"),
            memory_type: String::from("fact"),
            description: String::from("desc"),
            snippet: String::from("snippet"),
            distance: dist,
            score: RecallItem::score_from_distance(dist),
            source: origin.to_owned(),
            graph_depth,
        }
    }

    /// Builds a non-degraded response whose `results` is `direct` followed by `graph`.
    fn response(
        query: &str,
        k: usize,
        direct: Vec<RecallItem>,
        graph: Vec<RecallItem>,
        elapsed_ms: u64,
    ) -> RecallResponse {
        let results: Vec<RecallItem> = direct.iter().chain(graph.iter()).cloned().collect();
        RecallResponse {
            query: query.to_owned(),
            k,
            direct_matches: direct,
            graph_matches: graph,
            results,
            elapsed_ms,
            vec_degraded: false,
            vec_error: None,
            warning: None,
        }
    }

    #[test]
    fn recall_item_score_is_present_and_finite_for_direct_match() {
        let encoded =
            serde_json::to_value(fixture("mem", 0.25, "direct")).expect("serialization failed");
        let score = encoded["score"].as_f64().expect("score must be a number");
        assert!(
            (0.0..=1.0).contains(&score),
            "score must be in [0, 1], got {score}"
        );
        let delta = (score - 0.75).abs();
        assert!(
            delta < 1e-6,
            "score must equal 1 - distance for canonical case"
        );
    }

    #[test]
    fn recall_item_score_clamps_distance_outside_unit_range() {
        for (input, expected) in [(2.0_f32, 0.0), (-0.5, 1.0), (f32::NAN, 0.0)] {
            assert_eq!(
                RecallItem::score_from_distance(input),
                expected,
                "distance={input}"
            );
        }
    }

    #[test]
    fn recall_response_serializes_required_fields() {
        let resp = response(
            "rust memory",
            5,
            vec![fixture("mem-a", 0.12, "direct")],
            Vec::new(),
            42,
        );
        let encoded = serde_json::to_value(&resp).expect("serialization failed");

        assert_eq!(encoded["query"], "rust memory");
        assert_eq!(encoded["k"], 5);
        assert_eq!(encoded["elapsed_ms"], 42u64);
        for key in ["direct_matches", "graph_matches", "results"] {
            assert!(encoded[key].is_array(), "{key} must be an array");
        }
    }

    #[test]
    fn recall_item_serializes_renamed_type() {
        let encoded = serde_json::to_value(fixture("mem-test", 0.25, "direct"))
            .expect("serialization failed");

        assert_eq!(encoded["type"], "fact");
        assert_eq!(encoded["distance"], 0.25f32);
        assert_eq!(encoded["source"], "direct");
    }

    #[test]
    fn recall_response_results_contains_direct_and_graph() {
        let resp = response(
            "query",
            10,
            vec![fixture("d-mem", 0.10, "direct")],
            vec![fixture("g-mem", 0.0, "graph")],
            10,
        );
        let encoded = serde_json::to_value(&resp).expect("serialization failed");
        let count = |key: &str| encoded[key].as_array().map(Vec::len);

        assert_eq!(count("direct_matches"), Some(1));
        assert_eq!(count("graph_matches"), Some(1));
        assert_eq!(count("results"), Some(2));
        assert_eq!(encoded["results"][0]["source"], "direct");
        assert_eq!(encoded["results"][1]["source"], "graph");
    }

    #[test]
    fn recall_response_empty_serializes_empty_arrays() {
        let resp = response("nothing", 3, Vec::new(), Vec::new(), 1);
        let encoded = serde_json::to_value(&resp).expect("serialization failed");

        for key in ["direct_matches", "results"] {
            assert_eq!(
                encoded[key].as_array().map(Vec::len),
                Some(0),
                "{key} must be empty"
            );
        }
    }

    #[test]
    fn graph_matches_distance_uses_hop_count_proxy() {
        let hops = [0u32, 1, 2, 3];
        let expected = [0.0_f32, 0.5, 0.6667, 0.75];
        for (&hop, &want) in hops.iter().zip(expected.iter()) {
            let got = 1.0_f32 - 1.0 / (hop as f32 + 1.0);
            assert!(
                (got - want).abs() < 0.001,
                "hop={hop} expected={want} got={got}"
            );
        }
    }
}