Skip to main content

sqlite_graphrag/commands/
recall.rs

1use crate::cli::MemoryType;
2use crate::errors::AppError;
3use crate::graph::traverse_from_memories_with_hops;
4use crate::i18n::erros;
5use crate::output::{self, JsonOutputFormat, RecallItem, RecallResponse};
6use crate::paths::AppPaths;
7use crate::storage::connection::open_ro;
8use crate::storage::entities;
9use crate::storage::memories;
10
/// Arguments for the `recall` subcommand.
///
/// When `--namespace` is omitted the query runs against the `global` namespace,
/// which is the default namespace used by `remember` when no `--namespace` flag
/// is provided. Pass an explicit `--namespace` value to search a different
/// isolated namespace.
#[derive(clap::Args)]
pub struct RecallArgs {
    // Maintainer notes in this struct use `//` instead of `///` on purpose:
    // clap's derive turns `///` doc comments into `--help` text, so adding or
    // editing them would change the CLI's user-facing output.

    // Search text. It carries no `#[arg]` attribute, so clap treats it as a
    // positional argument. `run` rejects it when it is blank after trimming and
    // embeds it to drive the vector searches.
    pub query: String,
    /// Maximum number of direct vector matches to return.
    ///
    /// Note: this flag controls only `direct_matches`. Graph traversal results
    /// (`graph_matches`) are unbounded by default; use `--max-graph-results` to
    /// cap them independently. The `results` field aggregates both lists.
    #[arg(short = 'k', long, default_value = "10")]
    pub k: usize,
    /// Filter by memory.type. Note: distinct from graph entity_type
    /// (project/tool/person/file/concept/incident/decision/memory/dashboard/issue_tracker)
    /// used in --entities-file.
    #[arg(long, value_enum)]
    pub r#type: Option<MemoryType>,
    // Namespace to search. `run` resolves it through
    // `crate::namespace::resolve_namespace`, which also handles the `None` case.
    // Mutually exclusive with `--all-namespaces` (see `conflicts_with` below).
    #[arg(long)]
    pub namespace: Option<String>,
    // When set, `run` skips the entity KNN lookup and the graph traversal, so
    // `graph_matches` stays empty.
    #[arg(long)]
    pub no_graph: bool,
    /// Disable -k cap and return all direct matches without truncation.
    ///
    /// When set, the `-k`/`--k` flag is ignored for `direct_matches` and the
    /// response includes every match above the distance threshold. Useful when
    /// callers need the complete set rather than a top-N preview.
    // Implementation detail: `run` does not remove the limit entirely; it asks
    // the KNN search for up to 100_000 rows instead of `k`.
    #[arg(long)]
    pub precise: bool,
    // Forwarded to `traverse_from_memories_with_hops`. Presumably the maximum
    // number of hops followed from each seed -- confirm in `crate::graph`.
    #[arg(long, default_value = "2")]
    pub max_hops: u32,
    // Forwarded to `traverse_from_memories_with_hops`. Presumably the minimum
    // edge weight a relationship needs to be followed -- confirm in `crate::graph`.
    #[arg(long, default_value = "0.3")]
    pub min_weight: f64,
    /// Cap the size of `graph_matches` to at most N entries.
    ///
    /// Defaults to unbounded (`None`) so existing pipelines see the same shape
    /// as in v1.0.22 and earlier. Set this when a query touches a dense graph
    /// neighbourhood and the caller only needs a top-N preview. Added in v1.0.23.
    #[arg(long, value_name = "N")]
    pub max_graph_results: Option<usize>,
    /// Filter results by maximum distance. Results with distance greater than this value
    /// are excluded. If all matches exceed this threshold, the command exits with code 4
    /// (`not found`) per the documented public contract.
    /// Default `1.0` disables the filter and preserves the top-k behavior.
    // `--min-distance` is accepted as an alias of this flag. `run` only applies
    // the threshold when the value is strictly below 1.0.
    #[arg(long, alias = "min-distance", default_value = "1.0")]
    pub max_distance: f32,
    // Parsed for CLI compatibility; `run` reads the field and discards it
    // (`let _ = args.format`), so it currently has no effect on the output.
    #[arg(long, value_enum, default_value_t = JsonOutputFormat::Json)]
    pub format: JsonOutputFormat,
    // Optional database location, also read from the `SQLITE_GRAPHRAG_DB_PATH`
    // environment variable. Passed to `AppPaths::resolve`, which decides the
    // final path when this is `None`.
    #[arg(long, env = "SQLITE_GRAPHRAG_DB_PATH")]
    pub db: Option<String>,
    /// Accept `--json` as a no-op because output is already JSON by default.
    #[arg(long, help = "No-op; JSON is always emitted on stdout")]
    pub json: bool,
    /// Search across all namespaces instead of a single namespace.
    ///
    /// Cannot be combined with `--namespace`. When set, the query runs against
    /// every namespace and results include a `namespace` field to identify origin.
    // In `run` this turns the namespace list passed to `memories::knn_search`
    // into an empty vector, which that function treats as "all namespaces".
    #[arg(long, conflicts_with = "namespace")]
    pub all_namespaces: bool,
}
74
75pub fn run(args: RecallArgs) -> Result<(), AppError> {
76    let start = std::time::Instant::now();
77    let _ = args.format;
78    if args.query.trim().is_empty() {
79        return Err(AppError::Validation(
80            "query não pode estar vazia".to_string(),
81        ));
82    }
83    // Resolve the list of namespaces to search:
84    // - empty vec  => all namespaces (sentinel used by knn_search)
85    // - single vec => one namespace (default or --namespace value)
86    let namespaces: Vec<String> = if args.all_namespaces {
87        Vec::new()
88    } else {
89        vec![crate::namespace::resolve_namespace(
90            args.namespace.as_deref(),
91        )?]
92    };
93    // Single namespace string used for graph traversal and error messages.
94    let namespace_for_graph = namespaces
95        .first()
96        .cloned()
97        .unwrap_or_else(|| "global".to_string());
98    let paths = AppPaths::resolve(args.db.as_deref())?;
99
100    if !paths.db.exists() {
101        return Err(AppError::NotFound(erros::banco_nao_encontrado(
102            &paths.db.display().to_string(),
103        )));
104    }
105
106    output::emit_progress_i18n(
107        "Computing query embedding...",
108        "Calculando embedding da consulta...",
109    );
110    let embedding = crate::daemon::embed_query_or_local(&paths.models, &args.query)?;
111
112    let conn = open_ro(&paths.db)?;
113
114    let memory_type_str = args.r#type.map(|t| t.as_str());
115    // When --precise is set, lift the -k cap so every match is returned; the
116    // max_distance filter below will trim irrelevant results instead.
117    let effective_k = if args.precise { 100_000 } else { args.k };
118    let knn_results =
119        memories::knn_search(&conn, &embedding, &namespaces, memory_type_str, effective_k)?;
120
121    let mut direct_matches = Vec::new();
122    let mut memory_ids: Vec<i64> = Vec::new();
123    for (memory_id, distance) in knn_results {
124        let row = {
125            let mut stmt = conn.prepare_cached(
126                "SELECT id, namespace, name, type, description, body, body_hash,
127                        session_id, source, metadata, created_at, updated_at
128                 FROM memories WHERE id=?1 AND deleted_at IS NULL",
129            )?;
130            stmt.query_row(rusqlite::params![memory_id], |r| {
131                Ok(memories::MemoryRow {
132                    id: r.get(0)?,
133                    namespace: r.get(1)?,
134                    name: r.get(2)?,
135                    memory_type: r.get(3)?,
136                    description: r.get(4)?,
137                    body: r.get(5)?,
138                    body_hash: r.get(6)?,
139                    session_id: r.get(7)?,
140                    source: r.get(8)?,
141                    metadata: r.get(9)?,
142                    created_at: r.get(10)?,
143                    updated_at: r.get(11)?,
144                })
145            })
146            .ok()
147        };
148        if let Some(row) = row {
149            let snippet: String = row.body.chars().take(300).collect();
150            direct_matches.push(RecallItem {
151                memory_id: row.id,
152                name: row.name,
153                namespace: row.namespace,
154                memory_type: row.memory_type,
155                description: row.description,
156                snippet,
157                distance,
158                source: "direct".to_string(),
159                // Direct vector matches do not have a graph depth; rely on `distance`.
160                graph_depth: None,
161            });
162            memory_ids.push(memory_id);
163        }
164    }
165
166    let mut graph_matches = Vec::new();
167    if !args.no_graph {
168        let entity_knn = entities::knn_search(&conn, &embedding, &namespace_for_graph, 5)?;
169        let entity_ids: Vec<i64> = entity_knn.iter().map(|(id, _)| *id).collect();
170
171        let all_seed_ids: Vec<i64> = memory_ids
172            .iter()
173            .chain(entity_ids.iter())
174            .copied()
175            .collect();
176
177        if !all_seed_ids.is_empty() {
178            let graph_memory_ids = traverse_from_memories_with_hops(
179                &conn,
180                &all_seed_ids,
181                &namespace_for_graph,
182                args.min_weight,
183                args.max_hops,
184            )?;
185
186            for (graph_mem_id, hop) in graph_memory_ids {
187                // v1.0.23: respect the optional cap on graph results so dense
188                // neighbourhoods do not flood the response unintentionally.
189                if let Some(cap) = args.max_graph_results {
190                    if graph_matches.len() >= cap {
191                        break;
192                    }
193                }
194                let row = {
195                    let mut stmt = conn.prepare_cached(
196                        "SELECT id, namespace, name, type, description, body, body_hash,
197                                session_id, source, metadata, created_at, updated_at
198                         FROM memories WHERE id=?1 AND deleted_at IS NULL",
199                    )?;
200                    stmt.query_row(rusqlite::params![graph_mem_id], |r| {
201                        Ok(memories::MemoryRow {
202                            id: r.get(0)?,
203                            namespace: r.get(1)?,
204                            name: r.get(2)?,
205                            memory_type: r.get(3)?,
206                            description: r.get(4)?,
207                            body: r.get(5)?,
208                            body_hash: r.get(6)?,
209                            session_id: r.get(7)?,
210                            source: r.get(8)?,
211                            metadata: r.get(9)?,
212                            created_at: r.get(10)?,
213                            updated_at: r.get(11)?,
214                        })
215                    })
216                    .ok()
217                };
218                if let Some(row) = row {
219                    let snippet: String = row.body.chars().take(300).collect();
220                    // Compute approximate distance from graph hop count.
221                    // Real cosine distance for graph matches is reserved for v1.0.26
222                    // (would require re-embedding, which adds 200-500ms latency).
223                    let graph_distance = 1.0 - 1.0 / (hop as f32 + 1.0);
224                    graph_matches.push(RecallItem {
225                        memory_id: row.id,
226                        name: row.name,
227                        namespace: row.namespace,
228                        memory_type: row.memory_type,
229                        description: row.description,
230                        snippet,
231                        distance: graph_distance,
232                        source: "graph".to_string(),
233                        graph_depth: Some(hop),
234                    });
235                }
236            }
237        }
238    }
239
240    // Filtrar por max_distance se < 1.0 (ativado). Se nenhum hit dentro do threshold, exit 4.
241    if args.max_distance < 1.0 {
242        let has_relevant = direct_matches
243            .iter()
244            .any(|item| item.distance <= args.max_distance);
245        if !has_relevant {
246            return Err(AppError::NotFound(erros::sem_resultados_recall(
247                args.max_distance,
248                &args.query,
249                &namespace_for_graph,
250            )));
251        }
252    }
253
254    let results: Vec<RecallItem> = direct_matches
255        .iter()
256        .cloned()
257        .chain(graph_matches.iter().cloned())
258        .collect();
259
260    output::emit_json(&RecallResponse {
261        query: args.query,
262        k: args.k,
263        direct_matches,
264        graph_matches,
265        results,
266        elapsed_ms: start.elapsed().as_millis() as u64,
267    })?;
268
269    Ok(())
270}
271
#[cfg(test)]
mod testes {
    use crate::output::{RecallItem, RecallResponse};

    /// Builds a minimal `RecallItem` fixture.
    ///
    /// Items whose `source` is `"graph"` get a graph depth of 0; every other
    /// source gets no depth.
    fn make_item(name: &str, distance: f32, source: &str) -> RecallItem {
        let graph_depth = match source {
            "graph" => Some(0),
            _ => None,
        };
        RecallItem {
            memory_id: 1,
            name: String::from(name),
            namespace: String::from("global"),
            memory_type: String::from("fact"),
            description: String::from("desc"),
            snippet: String::from("snippet"),
            distance,
            source: String::from(source),
            graph_depth,
        }
    }

    /// The response exposes the scalar fields and all three match arrays.
    #[test]
    fn recall_response_serializa_campos_obrigatorios() {
        let only_hit = make_item("mem-a", 0.12, "direct");
        let response = RecallResponse {
            query: String::from("rust memory"),
            k: 5,
            direct_matches: vec![only_hit.clone()],
            graph_matches: Vec::new(),
            results: vec![only_hit],
            elapsed_ms: 42,
        };

        let value = serde_json::to_value(&response).expect("serialização falhou");
        assert_eq!(value["query"], "rust memory");
        assert_eq!(value["k"], 5);
        assert_eq!(value["elapsed_ms"], 42u64);
        for key in ["direct_matches", "graph_matches", "results"] {
            assert!(value[key].is_array());
        }
    }

    /// `memory_type` is serialized under the JSON key `type`.
    #[test]
    fn recall_item_serializa_type_renomeado() {
        let value = serde_json::to_value(&make_item("mem-teste", 0.25, "direct"))
            .expect("serialização falhou");

        // The `memory_type` field is renamed to "type" in the JSON output.
        assert_eq!(value["type"], "fact");
        assert_eq!(value["distance"], 0.25f32);
        assert_eq!(value["source"], "direct");
    }

    /// `results` lists the direct match first and the graph match second.
    #[test]
    fn recall_response_results_contem_direct_e_graph() {
        let direct_hit = make_item("d-mem", 0.10, "direct");
        let graph_hit = make_item("g-mem", 0.0, "graph");

        let response = RecallResponse {
            query: String::from("query"),
            k: 10,
            direct_matches: vec![direct_hit.clone()],
            graph_matches: vec![graph_hit.clone()],
            results: vec![direct_hit, graph_hit],
            elapsed_ms: 10,
        };

        let value = serde_json::to_value(&response).expect("serialização falhou");
        let len_of = |key: &str| value[key].as_array().unwrap().len();
        assert_eq!(len_of("direct_matches"), 1);
        assert_eq!(len_of("graph_matches"), 1);
        assert_eq!(len_of("results"), 2);
        assert_eq!(value["results"][0]["source"], "direct");
        assert_eq!(value["results"][1]["source"], "graph");
    }

    /// Empty match lists still serialize as (empty) JSON arrays.
    #[test]
    fn recall_response_vazio_serializa_arrays_vazios() {
        let response = RecallResponse {
            query: String::from("nada"),
            k: 3,
            direct_matches: Vec::new(),
            graph_matches: Vec::new(),
            results: Vec::new(),
            elapsed_ms: 1,
        };

        let value = serde_json::to_value(&response).expect("serialização falhou");
        assert_eq!(value["direct_matches"].as_array().unwrap().len(), 0);
        assert_eq!(value["results"].as_array().unwrap().len(), 0);
    }

    /// Pins the hop-count proxy formula `1.0 - 1.0 / (hop + 1.0)`.
    ///
    /// This recomputes the formula locally rather than calling production code.
    #[test]
    fn graph_matches_distance_uses_hop_count_proxy() {
        // hop=0 -> 0.0 (seed-level entity, identity distance)
        // hop=1 -> 0.5
        // hop=2 -> ~0.667
        // hop=3 -> 0.75
        let expectations: [(u32, f32); 4] = [(0, 0.0), (1, 0.5), (2, 0.6667), (3, 0.75)];
        for (hop, expected) in expectations {
            let d = 1.0_f32 - 1.0 / (hop as f32 + 1.0);
            let error = (d - expected).abs();
            assert!(error < 0.001, "hop={hop} expected={expected} got={d}");
        }
    }
}