Skip to main content

sqlite_graphrag/commands/
recall.rs

1use crate::cli::MemoryType;
2use crate::errors::AppError;
3use crate::graph::traverse_from_memories;
4use crate::i18n::erros;
5use crate::output::{self, JsonOutputFormat, RecallItem, RecallResponse};
6use crate::paths::AppPaths;
7use crate::storage::connection::open_ro;
8use crate::storage::entities;
9use crate::storage::memories;
10
// Command-line arguments for the `recall` command (consumed by `run` below).
//
// NOTE: with clap's derive API, `///` doc comments on fields become `--help` text, so the
// existing `///` lines are user-visible output and are intentionally left verbatim
// (in Portuguese). Explanations for maintainers are plain `//` comments, which clap ignores.
#[derive(clap::Args)]
pub struct RecallArgs {
    // Free-text query. `run` embeds it and uses the embedding for KNN search over both
    // memories and entities.
    pub query: String,
    // Number of nearest memories requested from `memories::knn_search` (`-k`, default 10).
    #[arg(short = 'k', long, default_value = "10")]
    pub k: usize,
    // Optional memory-type filter, forwarded as a string to `memories::knn_search`.
    #[arg(long, value_enum)]
    pub r#type: Option<MemoryType>,
    // Namespace override; resolved through `crate::namespace::resolve_namespace`.
    #[arg(long)]
    pub namespace: Option<String>,
    // When set, `run` skips graph expansion and returns only direct KNN matches.
    #[arg(long)]
    pub no_graph: bool,
    // NOTE(review): `run` in this file never reads this flag — confirm whether it is used
    // elsewhere or is a placeholder.
    #[arg(long)]
    pub precise: bool,
    // Maximum hop count passed to `traverse_from_memories` (default 2).
    #[arg(long, default_value = "2")]
    pub max_hops: u32,
    // Minimum edge weight passed to `traverse_from_memories` (default 0.3).
    #[arg(long, default_value = "0.3")]
    pub min_weight: f64,
    // Despite its name, this is a MAXIMUM-distance threshold. `run` treats values < 1.0
    // as "enabled" and then returns `AppError::NotFound` unless at least one direct match
    // has `distance <= min_distance`. It gates the whole command; it does not drop
    // individual matches above the threshold.
    // English translation of the help text below: "Filter results by maximum distance. If
    // every match has distance > min_distance, the command exits with exit 4 (not found)
    // per the contract documented in AGENT_PROTOCOL.md. Default 1.0 (disabled, keeps the
    // v2.0.0 behavior of always returning top-k)."
    /// Filtrar resultados por distance máxima. Se todos os matches tiverem distance > min_distance,
    /// comando sai com exit 4 (not found) conforme contrato documentado em AGENT_PROTOCOL.md.
    /// Default 1.0 (desativado, mantém comportamento v2.0.0 de sempre retornar top-k).
    #[arg(long, default_value = "1.0")]
    pub min_distance: f32,
    // Output format. `run` currently ignores it (`let _ = args.format`) and always emits
    // JSON via `output::emit_json`.
    #[arg(long, value_enum, default_value_t = JsonOutputFormat::Json)]
    pub format: JsonOutputFormat,
    // Database path; may also come from the `SQLITE_GRAPHRAG_DB_PATH` environment variable.
    #[arg(long, env = "SQLITE_GRAPHRAG_DB_PATH")]
    pub db: Option<String>,
    // Hidden compatibility flag. English translation of the help text below: "Accepts
    // --json as a no-op: output is already JSON by default."
    /// Aceita --json como no-op: output já é JSON por default.
    #[arg(long, hide = true)]
    pub json: bool,
}
41
42pub fn run(args: RecallArgs) -> Result<(), AppError> {
43    let start = std::time::Instant::now();
44    let _ = args.format;
45    let namespace = crate::namespace::resolve_namespace(args.namespace.as_deref())?;
46    let paths = AppPaths::resolve(args.db.as_deref())?;
47
48    output::emit_progress_i18n(
49        "Computing query embedding...",
50        "Calculando embedding da consulta...",
51    );
52    let embedding = crate::daemon::embed_query_or_local(&paths.models, &args.query)?;
53
54    let conn = open_ro(&paths.db)?;
55
56    let memory_type_str = args.r#type.map(|t| t.as_str());
57    let knn_results = memories::knn_search(&conn, &embedding, &namespace, memory_type_str, args.k)?;
58
59    let mut direct_matches = Vec::new();
60    let mut memory_ids: Vec<i64> = Vec::new();
61    for (memory_id, distance) in knn_results {
62        let row = {
63            let mut stmt = conn.prepare_cached(
64                "SELECT id, namespace, name, type, description, body, body_hash,
65                        session_id, source, metadata, created_at, updated_at
66                 FROM memories WHERE id=?1 AND deleted_at IS NULL",
67            )?;
68            stmt.query_row(rusqlite::params![memory_id], |r| {
69                Ok(memories::MemoryRow {
70                    id: r.get(0)?,
71                    namespace: r.get(1)?,
72                    name: r.get(2)?,
73                    memory_type: r.get(3)?,
74                    description: r.get(4)?,
75                    body: r.get(5)?,
76                    body_hash: r.get(6)?,
77                    session_id: r.get(7)?,
78                    source: r.get(8)?,
79                    metadata: r.get(9)?,
80                    created_at: r.get(10)?,
81                    updated_at: r.get(11)?,
82                })
83            })
84            .ok()
85        };
86        if let Some(row) = row {
87            let snippet: String = row.body.chars().take(300).collect();
88            direct_matches.push(RecallItem {
89                memory_id: row.id,
90                name: row.name,
91                namespace: row.namespace,
92                memory_type: row.memory_type,
93                description: row.description,
94                snippet,
95                distance,
96                source: "direct".to_string(),
97            });
98            memory_ids.push(memory_id);
99        }
100    }
101
102    let mut graph_matches = Vec::new();
103    if !args.no_graph {
104        let entity_knn = entities::knn_search(&conn, &embedding, &namespace, 5)?;
105        let entity_ids: Vec<i64> = entity_knn.iter().map(|(id, _)| *id).collect();
106
107        let all_seed_ids: Vec<i64> = memory_ids
108            .iter()
109            .chain(entity_ids.iter())
110            .copied()
111            .collect();
112
113        if !all_seed_ids.is_empty() {
114            let graph_memory_ids = traverse_from_memories(
115                &conn,
116                &all_seed_ids,
117                &namespace,
118                args.min_weight,
119                args.max_hops,
120            )?;
121
122            for graph_mem_id in graph_memory_ids {
123                let row = {
124                    let mut stmt = conn.prepare_cached(
125                        "SELECT id, namespace, name, type, description, body, body_hash,
126                                session_id, source, metadata, created_at, updated_at
127                         FROM memories WHERE id=?1 AND deleted_at IS NULL",
128                    )?;
129                    stmt.query_row(rusqlite::params![graph_mem_id], |r| {
130                        Ok(memories::MemoryRow {
131                            id: r.get(0)?,
132                            namespace: r.get(1)?,
133                            name: r.get(2)?,
134                            memory_type: r.get(3)?,
135                            description: r.get(4)?,
136                            body: r.get(5)?,
137                            body_hash: r.get(6)?,
138                            session_id: r.get(7)?,
139                            source: r.get(8)?,
140                            metadata: r.get(9)?,
141                            created_at: r.get(10)?,
142                            updated_at: r.get(11)?,
143                        })
144                    })
145                    .ok()
146                };
147                if let Some(row) = row {
148                    let snippet: String = row.body.chars().take(300).collect();
149                    graph_matches.push(RecallItem {
150                        memory_id: row.id,
151                        name: row.name,
152                        namespace: row.namespace,
153                        memory_type: row.memory_type,
154                        description: row.description,
155                        snippet,
156                        distance: 0.0,
157                        source: "graph".to_string(),
158                    });
159                }
160            }
161        }
162    }
163
164    // Filtrar por min_distance se < 1.0 (ativado). Se nenhum hit dentro do threshold, exit 4.
165    if args.min_distance < 1.0 {
166        let has_relevant = direct_matches
167            .iter()
168            .any(|item| item.distance <= args.min_distance);
169        if !has_relevant {
170            return Err(AppError::NotFound(erros::sem_resultados_recall(
171                args.min_distance,
172                &args.query,
173                &namespace,
174            )));
175        }
176    }
177
178    let results: Vec<RecallItem> = direct_matches
179        .iter()
180        .cloned()
181        .chain(graph_matches.iter().cloned())
182        .collect();
183
184    output::emit_json(&RecallResponse {
185        query: args.query,
186        k: args.k,
187        direct_matches,
188        graph_matches,
189        results,
190        elapsed_ms: start.elapsed().as_millis() as u64,
191    })?;
192
193    Ok(())
194}
195
#[cfg(test)]
mod testes {
    use crate::output::{RecallItem, RecallResponse};

    /// Fixture: a `RecallItem` with fixed id/namespace/type/description/snippet; only the
    /// name, distance and source vary between tests.
    fn make_item(name: &str, distance: f32, source: &str) -> RecallItem {
        RecallItem {
            memory_id: 1,
            name: String::from(name),
            namespace: String::from("global"),
            memory_type: String::from("fact"),
            description: String::from("desc"),
            snippet: String::from("snippet"),
            distance,
            source: String::from(source),
        }
    }

    /// Length of the JSON array stored under `key`; panics if the value is not an array.
    fn array_len(value: &serde_json::Value, key: &str) -> usize {
        value[key].as_array().unwrap().len()
    }

    #[test]
    fn recall_response_serializa_campos_obrigatorios() {
        let hit = make_item("mem-a", 0.12, "direct");
        let response = RecallResponse {
            query: String::from("rust memory"),
            k: 5,
            direct_matches: vec![hit.clone()],
            graph_matches: Vec::new(),
            results: vec![hit],
            elapsed_ms: 42,
        };

        let value = serde_json::to_value(&response).expect("serialização falhou");
        assert_eq!(value["query"], "rust memory");
        assert_eq!(value["k"], 5);
        assert_eq!(value["elapsed_ms"], 42u64);
        for key in ["direct_matches", "graph_matches", "results"] {
            assert!(value[key].is_array());
        }
    }

    #[test]
    fn recall_item_serializa_type_renomeado() {
        let value = serde_json::to_value(&make_item("mem-teste", 0.25, "direct"))
            .expect("serialização falhou");

        // `memory_type` is serialized under the JSON key "type".
        assert_eq!(value["type"], "fact");
        assert_eq!(value["distance"], 0.25f32);
        assert_eq!(value["source"], "direct");
    }

    #[test]
    fn recall_response_results_contem_direct_e_graph() {
        let via_knn = make_item("d-mem", 0.10, "direct");
        let via_graph = make_item("g-mem", 0.0, "graph");

        let response = RecallResponse {
            query: String::from("query"),
            k: 10,
            direct_matches: vec![via_knn.clone()],
            graph_matches: vec![via_graph.clone()],
            results: vec![via_knn, via_graph],
            elapsed_ms: 10,
        };

        let value = serde_json::to_value(&response).expect("serialização falhou");
        assert_eq!(array_len(&value, "direct_matches"), 1);
        assert_eq!(array_len(&value, "graph_matches"), 1);
        assert_eq!(array_len(&value, "results"), 2);

        let merged = &value["results"];
        assert_eq!(merged[0]["source"], "direct");
        assert_eq!(merged[1]["source"], "graph");
    }

    #[test]
    fn recall_response_vazio_serializa_arrays_vazios() {
        let response = RecallResponse {
            query: String::from("nada"),
            k: 3,
            direct_matches: Vec::new(),
            graph_matches: Vec::new(),
            results: Vec::new(),
            elapsed_ms: 1,
        };

        let value = serde_json::to_value(&response).expect("serialização falhou");
        assert_eq!(array_len(&value, "direct_matches"), 0);
        assert_eq!(array_len(&value, "results"), 0);
    }
}