// sqlite_graphrag/commands/recall.rs

use rusqlite::OptionalExtension;

use crate::cli::MemoryType;
use crate::errors::AppError;
use crate::graph::traverse_from_memories;
use crate::i18n::erros;
use crate::output::{self, JsonOutputFormat, RecallItem, RecallResponse};
use crate::paths::AppPaths;
use crate::storage::connection::open_ro;
use crate::storage::entities;
use crate::storage::memories;
10
11#[derive(clap::Args)]
12pub struct RecallArgs {
13    pub query: String,
14    #[arg(short = 'k', long, default_value = "10")]
15    pub k: usize,
16    #[arg(long, value_enum)]
17    pub r#type: Option<MemoryType>,
18    #[arg(long)]
19    pub namespace: Option<String>,
20    #[arg(long)]
21    pub no_graph: bool,
22    #[arg(long)]
23    pub precise: bool,
24    #[arg(long, default_value = "2")]
25    pub max_hops: u32,
26    #[arg(long, default_value = "0.3")]
27    pub min_weight: f64,
28    /// Filtrar resultados por distance máxima. Se todos os matches tiverem distance > min_distance,
29    /// comando sai com exit 4 (not found) conforme contrato documentado em AGENT_PROTOCOL.md.
30    /// Default 1.0 (desativado, mantém comportamento v2.0.0 de sempre retornar top-k).
31    #[arg(long, default_value = "1.0")]
32    pub min_distance: f32,
33    #[arg(long, value_enum, default_value_t = JsonOutputFormat::Json)]
34    pub format: JsonOutputFormat,
35    #[arg(long, env = "SQLITE_GRAPHRAG_DB_PATH")]
36    pub db: Option<String>,
37    /// Aceita --json como no-op: output já é JSON por default.
38    #[arg(long, hide = true)]
39    pub json: bool,
40}
41
42pub fn run(args: RecallArgs) -> Result<(), AppError> {
43    let start = std::time::Instant::now();
44    let _ = args.format;
45    let namespace = crate::namespace::resolve_namespace(args.namespace.as_deref())?;
46    let paths = AppPaths::resolve(args.db.as_deref())?;
47
48    output::emit_progress_i18n(
49        "Computing query embedding...",
50        "Calculando embedding da consulta...",
51    );
52    let embedder = crate::embedder::get_embedder(&paths.models)?;
53    let embedding = crate::embedder::embed_query(embedder, &args.query)?;
54
55    let conn = open_ro(&paths.db)?;
56
57    let memory_type_str = args.r#type.map(|t| t.as_str());
58    let knn_results = memories::knn_search(&conn, &embedding, &namespace, memory_type_str, args.k)?;
59
60    let mut direct_matches = Vec::new();
61    let mut memory_ids: Vec<i64> = Vec::new();
62    for (memory_id, distance) in knn_results {
63        let row = {
64            let mut stmt = conn.prepare_cached(
65                "SELECT id, namespace, name, type, description, body, body_hash,
66                        session_id, source, metadata, created_at, updated_at
67                 FROM memories WHERE id=?1 AND deleted_at IS NULL",
68            )?;
69            stmt.query_row(rusqlite::params![memory_id], |r| {
70                Ok(memories::MemoryRow {
71                    id: r.get(0)?,
72                    namespace: r.get(1)?,
73                    name: r.get(2)?,
74                    memory_type: r.get(3)?,
75                    description: r.get(4)?,
76                    body: r.get(5)?,
77                    body_hash: r.get(6)?,
78                    session_id: r.get(7)?,
79                    source: r.get(8)?,
80                    metadata: r.get(9)?,
81                    created_at: r.get(10)?,
82                    updated_at: r.get(11)?,
83                })
84            })
85            .ok()
86        };
87        if let Some(row) = row {
88            let snippet: String = row.body.chars().take(300).collect();
89            direct_matches.push(RecallItem {
90                memory_id: row.id,
91                name: row.name,
92                namespace: row.namespace,
93                memory_type: row.memory_type,
94                description: row.description,
95                snippet,
96                distance,
97                source: "direct".to_string(),
98            });
99            memory_ids.push(memory_id);
100        }
101    }
102
103    let mut graph_matches = Vec::new();
104    if !args.no_graph {
105        let entity_knn = entities::knn_search(&conn, &embedding, &namespace, 5)?;
106        let entity_ids: Vec<i64> = entity_knn.iter().map(|(id, _)| *id).collect();
107
108        let all_seed_ids: Vec<i64> = memory_ids
109            .iter()
110            .chain(entity_ids.iter())
111            .copied()
112            .collect();
113
114        if !all_seed_ids.is_empty() {
115            let graph_memory_ids = traverse_from_memories(
116                &conn,
117                &all_seed_ids,
118                &namespace,
119                args.min_weight,
120                args.max_hops,
121            )?;
122
123            for graph_mem_id in graph_memory_ids {
124                let row = {
125                    let mut stmt = conn.prepare_cached(
126                        "SELECT id, namespace, name, type, description, body, body_hash,
127                                session_id, source, metadata, created_at, updated_at
128                         FROM memories WHERE id=?1 AND deleted_at IS NULL",
129                    )?;
130                    stmt.query_row(rusqlite::params![graph_mem_id], |r| {
131                        Ok(memories::MemoryRow {
132                            id: r.get(0)?,
133                            namespace: r.get(1)?,
134                            name: r.get(2)?,
135                            memory_type: r.get(3)?,
136                            description: r.get(4)?,
137                            body: r.get(5)?,
138                            body_hash: r.get(6)?,
139                            session_id: r.get(7)?,
140                            source: r.get(8)?,
141                            metadata: r.get(9)?,
142                            created_at: r.get(10)?,
143                            updated_at: r.get(11)?,
144                        })
145                    })
146                    .ok()
147                };
148                if let Some(row) = row {
149                    let snippet: String = row.body.chars().take(300).collect();
150                    graph_matches.push(RecallItem {
151                        memory_id: row.id,
152                        name: row.name,
153                        namespace: row.namespace,
154                        memory_type: row.memory_type,
155                        description: row.description,
156                        snippet,
157                        distance: 0.0,
158                        source: "graph".to_string(),
159                    });
160                }
161            }
162        }
163    }
164
165    // Filtrar por min_distance se < 1.0 (ativado). Se nenhum hit dentro do threshold, exit 4.
166    if args.min_distance < 1.0 {
167        let has_relevant = direct_matches
168            .iter()
169            .any(|item| item.distance <= args.min_distance);
170        if !has_relevant {
171            return Err(AppError::NotFound(erros::sem_resultados_recall(
172                args.min_distance,
173                &args.query,
174                &namespace,
175            )));
176        }
177    }
178
179    let results: Vec<RecallItem> = direct_matches
180        .iter()
181        .cloned()
182        .chain(graph_matches.iter().cloned())
183        .collect();
184
185    output::emit_json(&RecallResponse {
186        query: args.query,
187        k: args.k,
188        direct_matches,
189        graph_matches,
190        results,
191        elapsed_ms: start.elapsed().as_millis() as u64,
192    })?;
193
194    Ok(())
195}
196
#[cfg(test)]
mod testes {
    //! Serialization tests for the JSON shape of `RecallItem` and `RecallResponse`.

    use crate::output::{RecallItem, RecallResponse};

    /// Builds a `RecallItem` fixture; only `name`, `distance` and `source` vary.
    fn make_item(name: &str, distance: f32, source: &str) -> RecallItem {
        RecallItem {
            memory_id: 1,
            name: name.to_string(),
            namespace: "global".to_string(),
            memory_type: "fact".to_string(),
            description: "desc".to_string(),
            snippet: "snippet".to_string(),
            distance,
            source: source.to_string(),
        }
    }

    /// The response exposes its scalar fields and all three list fields as arrays.
    #[test]
    fn recall_response_serializa_campos_obrigatorios() {
        let resp = RecallResponse {
            query: "rust memory".to_string(),
            k: 5,
            direct_matches: vec![make_item("mem-a", 0.12, "direct")],
            graph_matches: vec![],
            results: vec![make_item("mem-a", 0.12, "direct")],
            elapsed_ms: 42,
        };

        let json = serde_json::to_value(&resp).expect("serialização falhou");
        assert_eq!(json["query"], "rust memory");
        assert_eq!(json["k"], 5);
        assert_eq!(json["elapsed_ms"], 42u64);
        assert!(json["direct_matches"].is_array());
        assert!(json["graph_matches"].is_array());
        assert!(json["results"].is_array());
    }

    /// An item serializes its memory type under the JSON key `"type"`.
    #[test]
    fn recall_item_serializa_type_renomeado() {
        let item = make_item("mem-teste", 0.25, "direct");
        let json = serde_json::to_value(&item).expect("serialização falhou");

        // The `memory_type` field is renamed to "type" in the JSON.
        assert_eq!(json["type"], "fact");
        assert_eq!(json["distance"], 0.25f32);
        assert_eq!(json["source"], "direct");
    }

    /// `results` keeps both kinds of match, direct entries first.
    #[test]
    fn recall_response_results_contem_direct_e_graph() {
        let direct = make_item("d-mem", 0.10, "direct");
        let graph = make_item("g-mem", 0.0, "graph");

        let resp = RecallResponse {
            query: "query".to_string(),
            k: 10,
            direct_matches: vec![direct.clone()],
            graph_matches: vec![graph.clone()],
            results: vec![direct, graph],
            elapsed_ms: 10,
        };

        let json = serde_json::to_value(&resp).expect("serialização falhou");
        assert_eq!(json["direct_matches"].as_array().unwrap().len(), 1);
        assert_eq!(json["graph_matches"].as_array().unwrap().len(), 1);
        assert_eq!(json["results"].as_array().unwrap().len(), 2);
        assert_eq!(json["results"][0]["source"], "direct");
        assert_eq!(json["results"][1]["source"], "graph");
    }

    /// Empty lists serialize as empty arrays rather than being omitted.
    #[test]
    fn recall_response_vazio_serializa_arrays_vazios() {
        let resp = RecallResponse {
            query: "nada".to_string(),
            k: 3,
            direct_matches: vec![],
            graph_matches: vec![],
            results: vec![],
            elapsed_ms: 1,
        };

        let json = serde_json::to_value(&resp).expect("serialização falhou");
        assert_eq!(json["direct_matches"].as_array().unwrap().len(), 0);
        // Previously unchecked even though the fixture sets it to an empty list.
        assert_eq!(json["graph_matches"].as_array().unwrap().len(), 0);
        assert_eq!(json["results"].as_array().unwrap().len(), 0);
    }
}