1use crate::cli::MemoryType;
2use crate::errors::AppError;
3use crate::graph::traverse_from_memories;
4use crate::i18n::erros;
5use crate::output::{self, JsonOutputFormat, RecallItem, RecallResponse};
6use crate::paths::AppPaths;
7use crate::storage::connection::open_ro;
8use crate::storage::entities;
9use crate::storage::memories;
10
// CLI arguments for the `recall` subcommand.
//
// NOTE: plain `//` comments are used deliberately: clap's derive turns `///`
// doc comments into `--help` text, so adding them would change CLI output.
#[derive(clap::Args)]
pub struct RecallArgs {
    // Free-text query; `run` rejects it when blank and embeds it for KNN search.
    pub query: String,
    // Number of nearest memories requested from the vector KNN (direct matches).
    #[arg(short = 'k', long, default_value = "10")]
    pub k: usize,
    // Optional memory-type filter passed to the direct memory KNN search.
    #[arg(long, value_enum)]
    pub r#type: Option<MemoryType>,
    // Namespace to search; `run` resolves it via `crate::namespace::resolve_namespace`.
    #[arg(long)]
    pub namespace: Option<String>,
    // Skip the graph-expansion step (entity KNN + traversal) entirely.
    #[arg(long)]
    pub no_graph: bool,
    // NOTE(review): not read anywhere in `run` in this file — confirm whether it is still needed.
    #[arg(long)]
    pub precise: bool,
    // Hop limit forwarded to `traverse_from_memories`.
    #[arg(long, default_value = "2")]
    pub max_hops: u32,
    // Minimum edge weight forwarded to `traverse_from_memories`.
    #[arg(long, default_value = "0.3")]
    pub min_weight: f64,
    // When below 1.0, `run` fails with NotFound unless at least one direct match has
    // distance <= this value. It does not filter the returned matches.
    // `--min-distance` is accepted as an alias (presumably a legacy name — confirm).
    #[arg(long, alias = "min-distance", default_value = "1.0")]
    pub max_distance: f32,
    // Output format; `run` only touches it as a no-op and always emits JSON.
    #[arg(long, value_enum, default_value_t = JsonOutputFormat::Json)]
    pub format: JsonOutputFormat,
    // Database path override (also read from SQLITE_GRAPHRAG_DB_PATH), passed to `AppPaths::resolve`.
    #[arg(long, env = "SQLITE_GRAPHRAG_DB_PATH")]
    pub db: Option<String>,
    // Accepted for compatibility; never read by `run`.
    #[arg(long, help = "No-op; JSON is always emitted on stdout")]
    pub json: bool,
}
42
43pub fn run(args: RecallArgs) -> Result<(), AppError> {
44 let start = std::time::Instant::now();
45 let _ = args.format;
46 if args.query.trim().is_empty() {
47 return Err(AppError::Validation(
48 "query não pode estar vazia".to_string(),
49 ));
50 }
51 let namespace = crate::namespace::resolve_namespace(args.namespace.as_deref())?;
52 let paths = AppPaths::resolve(args.db.as_deref())?;
53
54 if !paths.db.exists() {
55 return Err(AppError::NotFound(erros::banco_nao_encontrado(
56 &paths.db.display().to_string(),
57 )));
58 }
59
60 output::emit_progress_i18n(
61 "Computing query embedding...",
62 "Calculando embedding da consulta...",
63 );
64 let embedding = crate::daemon::embed_query_or_local(&paths.models, &args.query)?;
65
66 let conn = open_ro(&paths.db)?;
67
68 let memory_type_str = args.r#type.map(|t| t.as_str());
69 let knn_results = memories::knn_search(&conn, &embedding, &namespace, memory_type_str, args.k)?;
70
71 let mut direct_matches = Vec::new();
72 let mut memory_ids: Vec<i64> = Vec::new();
73 for (memory_id, distance) in knn_results {
74 let row = {
75 let mut stmt = conn.prepare_cached(
76 "SELECT id, namespace, name, type, description, body, body_hash,
77 session_id, source, metadata, created_at, updated_at
78 FROM memories WHERE id=?1 AND deleted_at IS NULL",
79 )?;
80 stmt.query_row(rusqlite::params![memory_id], |r| {
81 Ok(memories::MemoryRow {
82 id: r.get(0)?,
83 namespace: r.get(1)?,
84 name: r.get(2)?,
85 memory_type: r.get(3)?,
86 description: r.get(4)?,
87 body: r.get(5)?,
88 body_hash: r.get(6)?,
89 session_id: r.get(7)?,
90 source: r.get(8)?,
91 metadata: r.get(9)?,
92 created_at: r.get(10)?,
93 updated_at: r.get(11)?,
94 })
95 })
96 .ok()
97 };
98 if let Some(row) = row {
99 let snippet: String = row.body.chars().take(300).collect();
100 direct_matches.push(RecallItem {
101 memory_id: row.id,
102 name: row.name,
103 namespace: row.namespace,
104 memory_type: row.memory_type,
105 description: row.description,
106 snippet,
107 distance,
108 source: "direct".to_string(),
109 });
110 memory_ids.push(memory_id);
111 }
112 }
113
114 let mut graph_matches = Vec::new();
115 if !args.no_graph {
116 let entity_knn = entities::knn_search(&conn, &embedding, &namespace, 5)?;
117 let entity_ids: Vec<i64> = entity_knn.iter().map(|(id, _)| *id).collect();
118
119 let all_seed_ids: Vec<i64> = memory_ids
120 .iter()
121 .chain(entity_ids.iter())
122 .copied()
123 .collect();
124
125 if !all_seed_ids.is_empty() {
126 let graph_memory_ids = traverse_from_memories(
127 &conn,
128 &all_seed_ids,
129 &namespace,
130 args.min_weight,
131 args.max_hops,
132 )?;
133
134 for graph_mem_id in graph_memory_ids {
135 let row = {
136 let mut stmt = conn.prepare_cached(
137 "SELECT id, namespace, name, type, description, body, body_hash,
138 session_id, source, metadata, created_at, updated_at
139 FROM memories WHERE id=?1 AND deleted_at IS NULL",
140 )?;
141 stmt.query_row(rusqlite::params![graph_mem_id], |r| {
142 Ok(memories::MemoryRow {
143 id: r.get(0)?,
144 namespace: r.get(1)?,
145 name: r.get(2)?,
146 memory_type: r.get(3)?,
147 description: r.get(4)?,
148 body: r.get(5)?,
149 body_hash: r.get(6)?,
150 session_id: r.get(7)?,
151 source: r.get(8)?,
152 metadata: r.get(9)?,
153 created_at: r.get(10)?,
154 updated_at: r.get(11)?,
155 })
156 })
157 .ok()
158 };
159 if let Some(row) = row {
160 let snippet: String = row.body.chars().take(300).collect();
161 graph_matches.push(RecallItem {
162 memory_id: row.id,
163 name: row.name,
164 namespace: row.namespace,
165 memory_type: row.memory_type,
166 description: row.description,
167 snippet,
168 distance: 0.0,
169 source: "graph".to_string(),
170 });
171 }
172 }
173 }
174 }
175
176 if args.max_distance < 1.0 {
178 let has_relevant = direct_matches
179 .iter()
180 .any(|item| item.distance <= args.max_distance);
181 if !has_relevant {
182 return Err(AppError::NotFound(erros::sem_resultados_recall(
183 args.max_distance,
184 &args.query,
185 &namespace,
186 )));
187 }
188 }
189
190 let results: Vec<RecallItem> = direct_matches
191 .iter()
192 .cloned()
193 .chain(graph_matches.iter().cloned())
194 .collect();
195
196 output::emit_json(&RecallResponse {
197 query: args.query,
198 k: args.k,
199 direct_matches,
200 graph_matches,
201 results,
202 elapsed_ms: start.elapsed().as_millis() as u64,
203 })?;
204
205 Ok(())
206}
207
#[cfg(test)]
mod testes {
    use crate::output::{RecallItem, RecallResponse};
    use serde_json::Value;

    // Fixture item: only `name`, `distance` and `source` vary between tests.
    fn make_item(name: &str, distance: f32, source: &str) -> RecallItem {
        RecallItem {
            memory_id: 1,
            name: String::from(name),
            namespace: String::from("global"),
            memory_type: String::from("fact"),
            description: String::from("desc"),
            snippet: String::from("snippet"),
            distance,
            source: String::from(source),
        }
    }

    // Serializes a response into a JSON value, panicking if serialization fails.
    fn response_json(resp: &RecallResponse) -> Value {
        serde_json::to_value(resp).expect("serialização falhou")
    }

    // Length of the JSON array stored under `key`; panics if it is not an array.
    fn array_len(json: &Value, key: &str) -> usize {
        json[key].as_array().unwrap().len()
    }

    #[test]
    fn recall_response_serializa_campos_obrigatorios() {
        let hits = vec![make_item("mem-a", 0.12, "direct")];
        let json = response_json(&RecallResponse {
            query: "rust memory".to_string(),
            k: 5,
            direct_matches: hits.clone(),
            graph_matches: Vec::new(),
            results: hits,
            elapsed_ms: 42,
        });

        assert_eq!(json["query"], "rust memory");
        assert_eq!(json["k"], 5);
        assert_eq!(json["elapsed_ms"], 42u64);
        for key in ["direct_matches", "graph_matches", "results"] {
            assert!(json[key].is_array());
        }
    }

    #[test]
    fn recall_item_serializa_type_renomeado() {
        let json = serde_json::to_value(&make_item("mem-teste", 0.25, "direct"))
            .expect("serialização falhou");

        assert_eq!(json["type"], "fact");
        assert_eq!(json["distance"], 0.25f32);
        assert_eq!(json["source"], "direct");
    }

    #[test]
    fn recall_response_results_contem_direct_e_graph() {
        let direct = make_item("d-mem", 0.10, "direct");
        let graph = make_item("g-mem", 0.0, "graph");
        let json = response_json(&RecallResponse {
            query: "query".to_string(),
            k: 10,
            direct_matches: vec![direct.clone()],
            graph_matches: vec![graph.clone()],
            results: vec![direct, graph],
            elapsed_ms: 10,
        });

        assert_eq!(array_len(&json, "direct_matches"), 1);
        assert_eq!(array_len(&json, "graph_matches"), 1);
        assert_eq!(array_len(&json, "results"), 2);
        assert_eq!(json["results"][0]["source"], "direct");
        assert_eq!(json["results"][1]["source"], "graph");
    }

    #[test]
    fn recall_response_vazio_serializa_arrays_vazios() {
        let json = response_json(&RecallResponse {
            query: "nada".to_string(),
            k: 3,
            direct_matches: vec![],
            graph_matches: vec![],
            results: vec![],
            elapsed_ms: 1,
        });

        assert_eq!(array_len(&json, "direct_matches"), 0);
        assert_eq!(array_len(&json, "results"), 0);
    }
}
293}