// sqlite_graphrag/commands/recall.rs

1use crate::cli::MemoryType;
2use crate::errors::AppError;
3use crate::graph::traverse_from_memories;
4use crate::i18n::erros;
5use crate::output::{self, JsonOutputFormat, RecallItem, RecallResponse};
6use crate::paths::AppPaths;
7use crate::storage::connection::open_ro;
8use crate::storage::entities;
9use crate::storage::memories;
10
11#[derive(clap::Args)]
12pub struct RecallArgs {
13    pub query: String,
14    #[arg(short = 'k', long, default_value = "10")]
15    pub k: usize,
16    #[arg(long, value_enum)]
17    pub r#type: Option<MemoryType>,
18    #[arg(long)]
19    pub namespace: Option<String>,
20    #[arg(long)]
21    pub no_graph: bool,
22    #[arg(long)]
23    pub precise: bool,
24    #[arg(long, default_value = "2")]
25    pub max_hops: u32,
26    #[arg(long, default_value = "0.3")]
27    pub min_weight: f64,
28    /// Filter results by maximum distance. Results with distance greater than this value
29    /// are excluded. If all matches exceed this threshold, the command exits with code 4
30    /// (`not found`) per the documented public contract.
31    /// Default `1.0` disables the filter and preserves the top-k behavior.
32    #[arg(long, alias = "min-distance", default_value = "1.0")]
33    pub max_distance: f32,
34    #[arg(long, value_enum, default_value_t = JsonOutputFormat::Json)]
35    pub format: JsonOutputFormat,
36    #[arg(long, env = "SQLITE_GRAPHRAG_DB_PATH")]
37    pub db: Option<String>,
38    /// Accept `--json` as a no-op because output is already JSON by default.
39    #[arg(long, help = "No-op; JSON is always emitted on stdout")]
40    pub json: bool,
41}
42
43pub fn run(args: RecallArgs) -> Result<(), AppError> {
44    let start = std::time::Instant::now();
45    let _ = args.format;
46    if args.query.trim().is_empty() {
47        return Err(AppError::Validation(
48            "query não pode estar vazia".to_string(),
49        ));
50    }
51    let namespace = crate::namespace::resolve_namespace(args.namespace.as_deref())?;
52    let paths = AppPaths::resolve(args.db.as_deref())?;
53
54    if !paths.db.exists() {
55        return Err(AppError::NotFound(erros::banco_nao_encontrado(
56            &paths.db.display().to_string(),
57        )));
58    }
59
60    output::emit_progress_i18n(
61        "Computing query embedding...",
62        "Calculando embedding da consulta...",
63    );
64    let embedding = crate::daemon::embed_query_or_local(&paths.models, &args.query)?;
65
66    let conn = open_ro(&paths.db)?;
67
68    let memory_type_str = args.r#type.map(|t| t.as_str());
69    let knn_results = memories::knn_search(&conn, &embedding, &namespace, memory_type_str, args.k)?;
70
71    let mut direct_matches = Vec::new();
72    let mut memory_ids: Vec<i64> = Vec::new();
73    for (memory_id, distance) in knn_results {
74        let row = {
75            let mut stmt = conn.prepare_cached(
76                "SELECT id, namespace, name, type, description, body, body_hash,
77                        session_id, source, metadata, created_at, updated_at
78                 FROM memories WHERE id=?1 AND deleted_at IS NULL",
79            )?;
80            stmt.query_row(rusqlite::params![memory_id], |r| {
81                Ok(memories::MemoryRow {
82                    id: r.get(0)?,
83                    namespace: r.get(1)?,
84                    name: r.get(2)?,
85                    memory_type: r.get(3)?,
86                    description: r.get(4)?,
87                    body: r.get(5)?,
88                    body_hash: r.get(6)?,
89                    session_id: r.get(7)?,
90                    source: r.get(8)?,
91                    metadata: r.get(9)?,
92                    created_at: r.get(10)?,
93                    updated_at: r.get(11)?,
94                })
95            })
96            .ok()
97        };
98        if let Some(row) = row {
99            let snippet: String = row.body.chars().take(300).collect();
100            direct_matches.push(RecallItem {
101                memory_id: row.id,
102                name: row.name,
103                namespace: row.namespace,
104                memory_type: row.memory_type,
105                description: row.description,
106                snippet,
107                distance,
108                source: "direct".to_string(),
109            });
110            memory_ids.push(memory_id);
111        }
112    }
113
114    let mut graph_matches = Vec::new();
115    if !args.no_graph {
116        let entity_knn = entities::knn_search(&conn, &embedding, &namespace, 5)?;
117        let entity_ids: Vec<i64> = entity_knn.iter().map(|(id, _)| *id).collect();
118
119        let all_seed_ids: Vec<i64> = memory_ids
120            .iter()
121            .chain(entity_ids.iter())
122            .copied()
123            .collect();
124
125        if !all_seed_ids.is_empty() {
126            let graph_memory_ids = traverse_from_memories(
127                &conn,
128                &all_seed_ids,
129                &namespace,
130                args.min_weight,
131                args.max_hops,
132            )?;
133
134            for graph_mem_id in graph_memory_ids {
135                let row = {
136                    let mut stmt = conn.prepare_cached(
137                        "SELECT id, namespace, name, type, description, body, body_hash,
138                                session_id, source, metadata, created_at, updated_at
139                         FROM memories WHERE id=?1 AND deleted_at IS NULL",
140                    )?;
141                    stmt.query_row(rusqlite::params![graph_mem_id], |r| {
142                        Ok(memories::MemoryRow {
143                            id: r.get(0)?,
144                            namespace: r.get(1)?,
145                            name: r.get(2)?,
146                            memory_type: r.get(3)?,
147                            description: r.get(4)?,
148                            body: r.get(5)?,
149                            body_hash: r.get(6)?,
150                            session_id: r.get(7)?,
151                            source: r.get(8)?,
152                            metadata: r.get(9)?,
153                            created_at: r.get(10)?,
154                            updated_at: r.get(11)?,
155                        })
156                    })
157                    .ok()
158                };
159                if let Some(row) = row {
160                    let snippet: String = row.body.chars().take(300).collect();
161                    graph_matches.push(RecallItem {
162                        memory_id: row.id,
163                        name: row.name,
164                        namespace: row.namespace,
165                        memory_type: row.memory_type,
166                        description: row.description,
167                        snippet,
168                        distance: 0.0,
169                        source: "graph".to_string(),
170                    });
171                }
172            }
173        }
174    }
175
176    // Filtrar por max_distance se < 1.0 (ativado). Se nenhum hit dentro do threshold, exit 4.
177    if args.max_distance < 1.0 {
178        let has_relevant = direct_matches
179            .iter()
180            .any(|item| item.distance <= args.max_distance);
181        if !has_relevant {
182            return Err(AppError::NotFound(erros::sem_resultados_recall(
183                args.max_distance,
184                &args.query,
185                &namespace,
186            )));
187        }
188    }
189
190    let results: Vec<RecallItem> = direct_matches
191        .iter()
192        .cloned()
193        .chain(graph_matches.iter().cloned())
194        .collect();
195
196    output::emit_json(&RecallResponse {
197        query: args.query,
198        k: args.k,
199        direct_matches,
200        graph_matches,
201        results,
202        elapsed_ms: start.elapsed().as_millis() as u64,
203    })?;
204
205    Ok(())
206}
207
#[cfg(test)]
mod testes {
    use crate::output::{RecallItem, RecallResponse};

    /// Builds a `RecallItem` with fixed metadata; only name, distance and source vary.
    fn sample_item(name: &str, distance: f32, source: &str) -> RecallItem {
        RecallItem {
            memory_id: 1,
            namespace: String::from("global"),
            name: String::from(name),
            memory_type: String::from("fact"),
            description: String::from("desc"),
            snippet: String::from("snippet"),
            source: String::from(source),
            distance,
        }
    }

    /// Builds a response whose `results` is `direct` followed by `graph`, the same
    /// concatenation order `run` uses.
    fn sample_response(
        query: &str,
        k: usize,
        direct: Vec<RecallItem>,
        graph: Vec<RecallItem>,
        elapsed_ms: u64,
    ) -> RecallResponse {
        let results: Vec<RecallItem> = direct.iter().chain(graph.iter()).cloned().collect();
        RecallResponse {
            query: query.to_string(),
            k,
            direct_matches: direct,
            graph_matches: graph,
            results,
            elapsed_ms,
        }
    }

    /// Length of the JSON array stored under `key`; panics if it is not an array.
    fn array_len(json: &serde_json::Value, key: &str) -> usize {
        json[key].as_array().unwrap().len()
    }

    #[test]
    fn recall_response_serializa_campos_obrigatorios() {
        let resp = sample_response(
            "rust memory",
            5,
            vec![sample_item("mem-a", 0.12, "direct")],
            Vec::new(),
            42,
        );

        let json = serde_json::to_value(&resp).expect("serialização falhou");
        assert_eq!(json["query"], "rust memory");
        assert_eq!(json["k"], 5);
        assert_eq!(json["elapsed_ms"], 42u64);
        for key in ["direct_matches", "graph_matches", "results"] {
            assert!(json[key].is_array(), "{key} should be an array");
        }
    }

    #[test]
    fn recall_item_serializa_type_renomeado() {
        let json = serde_json::to_value(&sample_item("mem-teste", 0.25, "direct"))
            .expect("serialização falhou");

        // `memory_type` is serialized under the JSON key "type".
        assert_eq!(json["type"], "fact");
        assert_eq!(json["distance"], 0.25f32);
        assert_eq!(json["source"], "direct");
    }

    #[test]
    fn recall_response_results_contem_direct_e_graph() {
        let resp = sample_response(
            "query",
            10,
            vec![sample_item("d-mem", 0.10, "direct")],
            vec![sample_item("g-mem", 0.0, "graph")],
            10,
        );

        let json = serde_json::to_value(&resp).expect("serialização falhou");
        assert_eq!(array_len(&json, "direct_matches"), 1);
        assert_eq!(array_len(&json, "graph_matches"), 1);
        assert_eq!(array_len(&json, "results"), 2);

        // Direct matches must come before graph matches in the combined list.
        let sources: Vec<&str> = json["results"]
            .as_array()
            .unwrap()
            .iter()
            .map(|entry| entry["source"].as_str().unwrap())
            .collect();
        assert_eq!(sources, ["direct", "graph"]);
    }

    #[test]
    fn recall_response_vazio_serializa_arrays_vazios() {
        let resp = sample_response("nada", 3, Vec::new(), Vec::new(), 1);

        let json = serde_json::to_value(&resp).expect("serialização falhou");
        assert_eq!(array_len(&json, "direct_matches"), 0);
        assert_eq!(array_len(&json, "results"), 0);
    }
}