1use crate::cli::MemoryType;
2use crate::errors::AppError;
3use crate::graph::traverse_from_memories;
4use crate::i18n::erros;
5use crate::output::{self, OutputFormat, RecallItem, RecallResponse};
6use crate::paths::AppPaths;
7use crate::storage::connection::open_ro;
8use crate::storage::entities;
9use crate::storage::memories;
10
// Arguments accepted by `run`, parsed by clap through `#[derive(clap::Args)]`.
//
// Plain `//` comments are used on purpose: clap's derive turns `///` doc
// comments into help text, so doc comments here would change the CLI output.
#[derive(clap::Args)]
pub struct RecallArgs {
    // Free-text query (no `#[arg]` attribute, so clap treats it as positional).
    // `run` embeds it and echoes it back in the response.
    pub query: String,
    // Number of nearest neighbours requested from `memories::knn_search`
    // (`-k` / `--k`, default 10). Also echoed back in the response.
    #[arg(short = 'k', long, default_value = "10")]
    pub k: usize,
    // Optional memory-type filter; `run` converts it with `as_str()` and hands
    // it to `memories::knn_search`.
    #[arg(long, value_enum)]
    pub r#type: Option<MemoryType>,
    // Optional namespace; resolved through `crate::namespace::resolve_namespace`.
    #[arg(long)]
    pub namespace: Option<String>,
    // When set, `run` skips the entity search and graph traversal entirely.
    #[arg(long)]
    pub no_graph: bool,
    // NOTE(review): not read by `run` in the visible code — confirm whether
    // this flag is still meant to have an effect.
    #[arg(long)]
    pub precise: bool,
    // Passed to `traverse_from_memories` as the hop limit (default 2).
    #[arg(long, default_value = "2")]
    pub max_hops: u32,
    // Passed to `traverse_from_memories` as the weight threshold (default 0.3).
    #[arg(long, default_value = "0.3")]
    pub min_weight: f64,
    // Distance cutoff (default 1.0). When the value is below 1.0, `run` returns
    // `AppError::NotFound` unless at least one direct match has
    // `distance <= min_distance`. With the default the check is skipped.
    // NOTE(review): despite the name it acts as an upper bound on distance.
    #[arg(long, default_value = "1.0")]
    pub min_distance: f32,
    // Output format (default "json").
    // NOTE(review): `run` never reads this field and always calls
    // `output::emit_json` — confirm whether other formats should be honoured.
    #[arg(long, value_enum, default_value = "json")]
    pub format: OutputFormat,
    // Database location; may also come from the `SQLITE_GRAPHRAG_DB_PATH`
    // environment variable. Handed to `AppPaths::resolve`.
    #[arg(long, env = "SQLITE_GRAPHRAG_DB_PATH")]
    pub db: Option<String>,
    // Hidden flag (omitted from help output). Not read by `run` in the visible
    // code — presumably kept for backwards compatibility; verify.
    #[arg(long, hide = true)]
    pub json: bool,
}
41
42pub fn run(args: RecallArgs) -> Result<(), AppError> {
43 let start = std::time::Instant::now();
44 let namespace = crate::namespace::resolve_namespace(args.namespace.as_deref())?;
45 let paths = AppPaths::resolve(args.db.as_deref())?;
46
47 output::emit_progress_i18n(
48 "Computing query embedding...",
49 "Calculando embedding da consulta...",
50 );
51 let embedder = crate::embedder::get_embedder(&paths.models)?;
52 let embedding = crate::embedder::embed_query(embedder, &args.query)?;
53
54 let conn = open_ro(&paths.db)?;
55
56 let memory_type_str = args.r#type.map(|t| t.as_str());
57 let knn_results = memories::knn_search(&conn, &embedding, &namespace, memory_type_str, args.k)?;
58
59 let mut direct_matches = Vec::new();
60 let mut memory_ids: Vec<i64> = Vec::new();
61 for (memory_id, distance) in knn_results {
62 let row = {
63 let mut stmt = conn.prepare_cached(
64 "SELECT id, namespace, name, type, description, body, body_hash,
65 session_id, source, metadata, created_at, updated_at
66 FROM memories WHERE id=?1 AND deleted_at IS NULL",
67 )?;
68 stmt.query_row(rusqlite::params![memory_id], |r| {
69 Ok(memories::MemoryRow {
70 id: r.get(0)?,
71 namespace: r.get(1)?,
72 name: r.get(2)?,
73 memory_type: r.get(3)?,
74 description: r.get(4)?,
75 body: r.get(5)?,
76 body_hash: r.get(6)?,
77 session_id: r.get(7)?,
78 source: r.get(8)?,
79 metadata: r.get(9)?,
80 created_at: r.get(10)?,
81 updated_at: r.get(11)?,
82 })
83 })
84 .ok()
85 };
86 if let Some(row) = row {
87 let snippet: String = row.body.chars().take(300).collect();
88 direct_matches.push(RecallItem {
89 memory_id: row.id,
90 name: row.name,
91 namespace: row.namespace,
92 memory_type: row.memory_type,
93 description: row.description,
94 snippet,
95 distance,
96 source: "direct".to_string(),
97 });
98 memory_ids.push(memory_id);
99 }
100 }
101
102 let mut graph_matches = Vec::new();
103 if !args.no_graph {
104 let entity_knn = entities::knn_search(&conn, &embedding, &namespace, 5)?;
105 let entity_ids: Vec<i64> = entity_knn.iter().map(|(id, _)| *id).collect();
106
107 let all_seed_ids: Vec<i64> = memory_ids
108 .iter()
109 .chain(entity_ids.iter())
110 .copied()
111 .collect();
112
113 if !all_seed_ids.is_empty() {
114 let graph_memory_ids = traverse_from_memories(
115 &conn,
116 &all_seed_ids,
117 &namespace,
118 args.min_weight,
119 args.max_hops,
120 )?;
121
122 for graph_mem_id in graph_memory_ids {
123 let row = {
124 let mut stmt = conn.prepare_cached(
125 "SELECT id, namespace, name, type, description, body, body_hash,
126 session_id, source, metadata, created_at, updated_at
127 FROM memories WHERE id=?1 AND deleted_at IS NULL",
128 )?;
129 stmt.query_row(rusqlite::params![graph_mem_id], |r| {
130 Ok(memories::MemoryRow {
131 id: r.get(0)?,
132 namespace: r.get(1)?,
133 name: r.get(2)?,
134 memory_type: r.get(3)?,
135 description: r.get(4)?,
136 body: r.get(5)?,
137 body_hash: r.get(6)?,
138 session_id: r.get(7)?,
139 source: r.get(8)?,
140 metadata: r.get(9)?,
141 created_at: r.get(10)?,
142 updated_at: r.get(11)?,
143 })
144 })
145 .ok()
146 };
147 if let Some(row) = row {
148 let snippet: String = row.body.chars().take(300).collect();
149 graph_matches.push(RecallItem {
150 memory_id: row.id,
151 name: row.name,
152 namespace: row.namespace,
153 memory_type: row.memory_type,
154 description: row.description,
155 snippet,
156 distance: 0.0,
157 source: "graph".to_string(),
158 });
159 }
160 }
161 }
162 }
163
164 if args.min_distance < 1.0 {
166 let has_relevant = direct_matches
167 .iter()
168 .any(|item| item.distance <= args.min_distance);
169 if !has_relevant {
170 return Err(AppError::NotFound(erros::sem_resultados_recall(
171 args.min_distance,
172 &args.query,
173 &namespace,
174 )));
175 }
176 }
177
178 let results: Vec<RecallItem> = direct_matches
179 .iter()
180 .cloned()
181 .chain(graph_matches.iter().cloned())
182 .collect();
183
184 output::emit_json(&RecallResponse {
185 query: args.query,
186 k: args.k,
187 direct_matches,
188 graph_matches,
189 results,
190 elapsed_ms: start.elapsed().as_millis() as u64,
191 })?;
192
193 Ok(())
194}
195
#[cfg(test)]
mod testes {
    use crate::output::{RecallItem, RecallResponse};

    /// Builds a fixture item; only name, distance and source vary between tests.
    fn make_item(name: &str, distance: f32, source: &str) -> RecallItem {
        RecallItem {
            memory_id: 1,
            name: String::from(name),
            namespace: String::from("global"),
            memory_type: String::from("fact"),
            description: String::from("desc"),
            snippet: String::from("snippet"),
            distance,
            source: String::from(source),
        }
    }

    /// Serializes a response to a JSON value, panicking if serialization fails.
    fn response_json(response: &RecallResponse) -> serde_json::Value {
        serde_json::to_value(response).expect("serialização falhou")
    }

    /// Length of the JSON array stored under `key`; panics if it is not an array.
    fn array_len(value: &serde_json::Value, key: &str) -> usize {
        value[key].as_array().unwrap().len()
    }

    // The response must expose its scalar fields and all three match lists.
    #[test]
    fn recall_response_serializa_campos_obrigatorios() {
        let value = response_json(&RecallResponse {
            query: String::from("rust memory"),
            k: 5,
            direct_matches: vec![make_item("mem-a", 0.12, "direct")],
            graph_matches: Vec::new(),
            results: vec![make_item("mem-a", 0.12, "direct")],
            elapsed_ms: 42,
        });

        assert_eq!(value["query"], "rust memory");
        assert_eq!(value["k"], 5);
        assert_eq!(value["elapsed_ms"], 42u64);
        for key in ["direct_matches", "graph_matches", "results"].iter() {
            assert!(value[*key].is_array());
        }
    }

    // `memory_type` must be serialized under the JSON key `type`.
    #[test]
    fn recall_item_serializa_type_renomeado() {
        let value = serde_json::to_value(&make_item("mem-teste", 0.25, "direct"))
            .expect("serialização falhou");

        assert_eq!(value["type"], "fact");
        assert_eq!(value["distance"], 0.25f32);
        assert_eq!(value["source"], "direct");
    }

    // `results` holds the direct matches first, then the graph matches.
    #[test]
    fn recall_response_results_contem_direct_e_graph() {
        let direct_item = make_item("d-mem", 0.10, "direct");
        let graph_item = make_item("g-mem", 0.0, "graph");

        let value = response_json(&RecallResponse {
            query: String::from("query"),
            k: 10,
            direct_matches: vec![direct_item.clone()],
            graph_matches: vec![graph_item.clone()],
            results: vec![direct_item, graph_item],
            elapsed_ms: 10,
        });

        assert_eq!(array_len(&value, "direct_matches"), 1);
        assert_eq!(array_len(&value, "graph_matches"), 1);
        assert_eq!(array_len(&value, "results"), 2);
        assert_eq!(value["results"][0]["source"], "direct");
        assert_eq!(value["results"][1]["source"], "graph");
    }

    // Empty match lists must serialize as empty arrays, not be omitted.
    #[test]
    fn recall_response_vazio_serializa_arrays_vazios() {
        let value = response_json(&RecallResponse {
            query: String::from("nada"),
            k: 3,
            direct_matches: Vec::new(),
            graph_matches: Vec::new(),
            results: Vec::new(),
            elapsed_ms: 1,
        });

        assert_eq!(array_len(&value, "direct_matches"), 0);
        assert_eq!(array_len(&value, "results"), 0);
    }
}