1use crate::cli::MemoryType;
2use crate::errors::AppError;
3use crate::graph::traverse_from_memories_with_hops;
4use crate::i18n::erros;
5use crate::output::{self, JsonOutputFormat, RecallItem, RecallResponse};
6use crate::paths::AppPaths;
7use crate::storage::connection::open_ro;
8use crate::storage::entities;
9use crate::storage::memories;
10
// Command-line arguments for the `recall` command (consumed by `run`).
// NOTE: plain `//` comments are used deliberately here; `///` doc comments on a
// clap derive struct would change the generated `--help` text.
#[derive(clap::Args)]
pub struct RecallArgs {
    // Free-text query. `run` rejects it when empty/whitespace-only, then embeds it
    // for the vector search.
    pub query: String,
    // Number of nearest memories requested from the vector index (default 10).
    // Replaced by 100_000 when `--precise` is set; echoed back as `k` in the response.
    #[arg(short = 'k', long, default_value = "10")]
    pub k: usize,
    // Optional memory-type filter forwarded to the memory KNN search.
    #[arg(long, value_enum)]
    pub r#type: Option<MemoryType>,
    // Namespace to search. Resolved through `crate::namespace::resolve_namespace`
    // (presumably that applies a default when omitted — TODO confirm).
    #[arg(long)]
    pub namespace: Option<String>,
    // Skip entity search and graph expansion; only direct vector matches are returned.
    #[arg(long)]
    pub no_graph: bool,
    // Request 100_000 KNN candidates instead of `k`.
    // NOTE(review): every live candidate becomes a direct match; nothing truncates back to `k`.
    #[arg(long)]
    pub precise: bool,
    // Maximum traversal depth forwarded to the graph expansion (default 2).
    #[arg(long, default_value = "2")]
    pub max_hops: u32,
    // Edge-weight threshold forwarded to the graph expansion (default 0.3).
    // Presumably lighter edges are not followed — TODO confirm in the graph module.
    #[arg(long, default_value = "0.3")]
    pub min_weight: f64,
    // Optional cap on how many graph matches are collected; unlimited when omitted.
    #[arg(long, value_name = "N")]
    pub max_graph_results: Option<usize>,
    // When below 1.0, `run` fails with NotFound unless at least one direct match has
    // distance <= this value. Matches are NOT filtered by it.
    // `--min-distance` is accepted as an alias.
    #[arg(long, alias = "min-distance", default_value = "1.0")]
    pub max_distance: f32,
    // Output format selector; currently unused by `run`, which always emits JSON.
    #[arg(long, value_enum, default_value_t = JsonOutputFormat::Json)]
    pub format: JsonOutputFormat,
    // Database path override; may also be supplied via SQLITE_GRAPHRAG_DB_PATH.
    #[arg(long, env = "SQLITE_GRAPHRAG_DB_PATH")]
    pub db: Option<String>,
    // Accepted for compatibility only; has no effect.
    #[arg(long, help = "No-op; JSON is always emitted on stdout")]
    pub json: bool,
    // Search all namespaces. The memory KNN receives an empty namespace list, while
    // entity search and graph traversal fall back to "global". Conflicts with --namespace.
    #[arg(long, conflicts_with = "namespace")]
    pub all_namespaces: bool,
}
74
75pub fn run(args: RecallArgs) -> Result<(), AppError> {
76 let start = std::time::Instant::now();
77 let _ = args.format;
78 if args.query.trim().is_empty() {
79 return Err(AppError::Validation(
80 "query não pode estar vazia".to_string(),
81 ));
82 }
83 let namespaces: Vec<String> = if args.all_namespaces {
87 Vec::new()
88 } else {
89 vec![crate::namespace::resolve_namespace(
90 args.namespace.as_deref(),
91 )?]
92 };
93 let namespace_for_graph = namespaces
95 .first()
96 .cloned()
97 .unwrap_or_else(|| "global".to_string());
98 let paths = AppPaths::resolve(args.db.as_deref())?;
99
100 if !paths.db.exists() {
101 return Err(AppError::NotFound(erros::banco_nao_encontrado(
102 &paths.db.display().to_string(),
103 )));
104 }
105
106 output::emit_progress_i18n(
107 "Computing query embedding...",
108 "Calculando embedding da consulta...",
109 );
110 let embedding = crate::daemon::embed_query_or_local(&paths.models, &args.query)?;
111
112 let conn = open_ro(&paths.db)?;
113
114 let memory_type_str = args.r#type.map(|t| t.as_str());
115 let effective_k = if args.precise { 100_000 } else { args.k };
118 let knn_results =
119 memories::knn_search(&conn, &embedding, &namespaces, memory_type_str, effective_k)?;
120
121 let mut direct_matches = Vec::new();
122 let mut memory_ids: Vec<i64> = Vec::new();
123 for (memory_id, distance) in knn_results {
124 let row = {
125 let mut stmt = conn.prepare_cached(
126 "SELECT id, namespace, name, type, description, body, body_hash,
127 session_id, source, metadata, created_at, updated_at
128 FROM memories WHERE id=?1 AND deleted_at IS NULL",
129 )?;
130 stmt.query_row(rusqlite::params![memory_id], |r| {
131 Ok(memories::MemoryRow {
132 id: r.get(0)?,
133 namespace: r.get(1)?,
134 name: r.get(2)?,
135 memory_type: r.get(3)?,
136 description: r.get(4)?,
137 body: r.get(5)?,
138 body_hash: r.get(6)?,
139 session_id: r.get(7)?,
140 source: r.get(8)?,
141 metadata: r.get(9)?,
142 created_at: r.get(10)?,
143 updated_at: r.get(11)?,
144 })
145 })
146 .ok()
147 };
148 if let Some(row) = row {
149 let snippet: String = row.body.chars().take(300).collect();
150 direct_matches.push(RecallItem {
151 memory_id: row.id,
152 name: row.name,
153 namespace: row.namespace,
154 memory_type: row.memory_type,
155 description: row.description,
156 snippet,
157 distance,
158 source: "direct".to_string(),
159 graph_depth: None,
161 });
162 memory_ids.push(memory_id);
163 }
164 }
165
166 let mut graph_matches = Vec::new();
167 if !args.no_graph {
168 let entity_knn = entities::knn_search(&conn, &embedding, &namespace_for_graph, 5)?;
169 let entity_ids: Vec<i64> = entity_knn.iter().map(|(id, _)| *id).collect();
170
171 let all_seed_ids: Vec<i64> = memory_ids
172 .iter()
173 .chain(entity_ids.iter())
174 .copied()
175 .collect();
176
177 if !all_seed_ids.is_empty() {
178 let graph_memory_ids = traverse_from_memories_with_hops(
179 &conn,
180 &all_seed_ids,
181 &namespace_for_graph,
182 args.min_weight,
183 args.max_hops,
184 )?;
185
186 for (graph_mem_id, hop) in graph_memory_ids {
187 if let Some(cap) = args.max_graph_results {
190 if graph_matches.len() >= cap {
191 break;
192 }
193 }
194 let row = {
195 let mut stmt = conn.prepare_cached(
196 "SELECT id, namespace, name, type, description, body, body_hash,
197 session_id, source, metadata, created_at, updated_at
198 FROM memories WHERE id=?1 AND deleted_at IS NULL",
199 )?;
200 stmt.query_row(rusqlite::params![graph_mem_id], |r| {
201 Ok(memories::MemoryRow {
202 id: r.get(0)?,
203 namespace: r.get(1)?,
204 name: r.get(2)?,
205 memory_type: r.get(3)?,
206 description: r.get(4)?,
207 body: r.get(5)?,
208 body_hash: r.get(6)?,
209 session_id: r.get(7)?,
210 source: r.get(8)?,
211 metadata: r.get(9)?,
212 created_at: r.get(10)?,
213 updated_at: r.get(11)?,
214 })
215 })
216 .ok()
217 };
218 if let Some(row) = row {
219 let snippet: String = row.body.chars().take(300).collect();
220 let graph_distance = 1.0 - 1.0 / (hop as f32 + 1.0);
224 graph_matches.push(RecallItem {
225 memory_id: row.id,
226 name: row.name,
227 namespace: row.namespace,
228 memory_type: row.memory_type,
229 description: row.description,
230 snippet,
231 distance: graph_distance,
232 source: "graph".to_string(),
233 graph_depth: Some(hop),
234 });
235 }
236 }
237 }
238 }
239
240 if args.max_distance < 1.0 {
242 let has_relevant = direct_matches
243 .iter()
244 .any(|item| item.distance <= args.max_distance);
245 if !has_relevant {
246 return Err(AppError::NotFound(erros::sem_resultados_recall(
247 args.max_distance,
248 &args.query,
249 &namespace_for_graph,
250 )));
251 }
252 }
253
254 let results: Vec<RecallItem> = direct_matches
255 .iter()
256 .cloned()
257 .chain(graph_matches.iter().cloned())
258 .collect();
259
260 output::emit_json(&RecallResponse {
261 query: args.query,
262 k: args.k,
263 direct_matches,
264 graph_matches,
265 results,
266 elapsed_ms: start.elapsed().as_millis() as u64,
267 })?;
268
269 Ok(())
270}
271
#[cfg(test)]
mod testes {
    use crate::output::{RecallItem, RecallResponse};

    /// Builds a `RecallItem` fixture; only `"graph"` items get a graph depth (0).
    fn make_item(name: &str, distance: f32, source: &str) -> RecallItem {
        let graph_depth = (source == "graph").then_some(0);
        RecallItem {
            memory_id: 1,
            name: name.to_owned(),
            namespace: String::from("global"),
            memory_type: String::from("fact"),
            description: String::from("desc"),
            snippet: String::from("snippet"),
            distance,
            source: source.to_owned(),
            graph_depth,
        }
    }

    #[test]
    fn recall_response_serializa_campos_obrigatorios() {
        let direct = make_item("mem-a", 0.12, "direct");
        let resp = RecallResponse {
            query: String::from("rust memory"),
            k: 5,
            direct_matches: vec![direct.clone()],
            graph_matches: Vec::new(),
            results: vec![direct],
            elapsed_ms: 42,
        };

        let json = serde_json::to_value(&resp).expect("serialização falhou");
        assert_eq!(json["query"], "rust memory");
        assert_eq!(json["k"], 5);
        assert_eq!(json["elapsed_ms"], 42u64);
        for field in ["direct_matches", "graph_matches", "results"] {
            assert!(json[field].is_array());
        }
    }

    #[test]
    fn recall_item_serializa_type_renomeado() {
        let json = serde_json::to_value(make_item("mem-teste", 0.25, "direct"))
            .expect("serialização falhou");

        // The `memory_type` field is serialized under the key "type".
        assert_eq!(json["type"], "fact");
        assert_eq!(json["distance"], 0.25f32);
        assert_eq!(json["source"], "direct");
    }

    #[test]
    fn recall_response_results_contem_direct_e_graph() {
        let direct = make_item("d-mem", 0.10, "direct");
        let graph = make_item("g-mem", 0.0, "graph");
        let resp = RecallResponse {
            query: String::from("query"),
            k: 10,
            direct_matches: vec![direct.clone()],
            graph_matches: vec![graph.clone()],
            results: vec![direct, graph],
            elapsed_ms: 10,
        };

        let json = serde_json::to_value(&resp).expect("serialização falhou");
        let len_of = |key: &str| json[key].as_array().map(Vec::len);
        assert_eq!(len_of("direct_matches"), Some(1));
        assert_eq!(len_of("graph_matches"), Some(1));
        assert_eq!(len_of("results"), Some(2));
        // Direct matches come first in `results`, graph matches after.
        assert_eq!(json["results"][0]["source"], "direct");
        assert_eq!(json["results"][1]["source"], "graph");
    }

    #[test]
    fn recall_response_vazio_serializa_arrays_vazios() {
        let resp = RecallResponse {
            query: String::from("nada"),
            k: 3,
            direct_matches: Vec::new(),
            graph_matches: Vec::new(),
            results: Vec::new(),
            elapsed_ms: 1,
        };

        let json = serde_json::to_value(&resp).expect("serialização falhou");
        for key in ["direct_matches", "results"] {
            assert_eq!(json[key].as_array().map(Vec::len), Some(0));
        }
    }

    #[test]
    fn graph_matches_distance_uses_hop_count_proxy() {
        // (hop, expected distance) pairs for the 1 - 1/(hop + 1) proxy.
        const CASES: [(u32, f32); 4] = [(0, 0.0), (1, 0.5), (2, 0.6667), (3, 0.75)];
        for (hop, expected) in CASES {
            let d = 1.0_f32 - (hop as f32 + 1.0).recip();
            assert!(
                (d - expected).abs() < 0.001,
                "hop={hop} expected={expected} got={d}"
            );
        }
    }
}