1use crate::cli::MemoryType;
2use crate::errors::AppError;
3use crate::graph::traverse_from_memories;
4use crate::i18n::erros;
5use crate::output::{self, JsonOutputFormat, RecallItem, RecallResponse};
6use crate::paths::AppPaths;
7use crate::storage::connection::open_ro;
8use crate::storage::entities;
9use crate::storage::memories;
10
11#[derive(clap::Args)]
12pub struct RecallArgs {
13 pub query: String,
14 #[arg(short = 'k', long, default_value = "10")]
15 pub k: usize,
16 #[arg(long, value_enum)]
17 pub r#type: Option<MemoryType>,
18 #[arg(long)]
19 pub namespace: Option<String>,
20 #[arg(long)]
21 pub no_graph: bool,
22 #[arg(long)]
23 pub precise: bool,
24 #[arg(long, default_value = "2")]
25 pub max_hops: u32,
26 #[arg(long, default_value = "0.3")]
27 pub min_weight: f64,
28 #[arg(long, default_value = "1.0")]
32 pub min_distance: f32,
33 #[arg(long, value_enum, default_value_t = JsonOutputFormat::Json)]
34 pub format: JsonOutputFormat,
35 #[arg(long, env = "SQLITE_GRAPHRAG_DB_PATH")]
36 pub db: Option<String>,
37 #[arg(long, hide = true)]
39 pub json: bool,
40}
41
42pub fn run(args: RecallArgs) -> Result<(), AppError> {
43 let start = std::time::Instant::now();
44 let _ = args.format;
45 let namespace = crate::namespace::resolve_namespace(args.namespace.as_deref())?;
46 let paths = AppPaths::resolve(args.db.as_deref())?;
47
48 output::emit_progress_i18n(
49 "Computing query embedding...",
50 "Calculando embedding da consulta...",
51 );
52 let embedder = crate::embedder::get_embedder(&paths.models)?;
53 let embedding = crate::embedder::embed_query(embedder, &args.query)?;
54
55 let conn = open_ro(&paths.db)?;
56
57 let memory_type_str = args.r#type.map(|t| t.as_str());
58 let knn_results = memories::knn_search(&conn, &embedding, &namespace, memory_type_str, args.k)?;
59
60 let mut direct_matches = Vec::new();
61 let mut memory_ids: Vec<i64> = Vec::new();
62 for (memory_id, distance) in knn_results {
63 let row = {
64 let mut stmt = conn.prepare_cached(
65 "SELECT id, namespace, name, type, description, body, body_hash,
66 session_id, source, metadata, created_at, updated_at
67 FROM memories WHERE id=?1 AND deleted_at IS NULL",
68 )?;
69 stmt.query_row(rusqlite::params![memory_id], |r| {
70 Ok(memories::MemoryRow {
71 id: r.get(0)?,
72 namespace: r.get(1)?,
73 name: r.get(2)?,
74 memory_type: r.get(3)?,
75 description: r.get(4)?,
76 body: r.get(5)?,
77 body_hash: r.get(6)?,
78 session_id: r.get(7)?,
79 source: r.get(8)?,
80 metadata: r.get(9)?,
81 created_at: r.get(10)?,
82 updated_at: r.get(11)?,
83 })
84 })
85 .ok()
86 };
87 if let Some(row) = row {
88 let snippet: String = row.body.chars().take(300).collect();
89 direct_matches.push(RecallItem {
90 memory_id: row.id,
91 name: row.name,
92 namespace: row.namespace,
93 memory_type: row.memory_type,
94 description: row.description,
95 snippet,
96 distance,
97 source: "direct".to_string(),
98 });
99 memory_ids.push(memory_id);
100 }
101 }
102
103 let mut graph_matches = Vec::new();
104 if !args.no_graph {
105 let entity_knn = entities::knn_search(&conn, &embedding, &namespace, 5)?;
106 let entity_ids: Vec<i64> = entity_knn.iter().map(|(id, _)| *id).collect();
107
108 let all_seed_ids: Vec<i64> = memory_ids
109 .iter()
110 .chain(entity_ids.iter())
111 .copied()
112 .collect();
113
114 if !all_seed_ids.is_empty() {
115 let graph_memory_ids = traverse_from_memories(
116 &conn,
117 &all_seed_ids,
118 &namespace,
119 args.min_weight,
120 args.max_hops,
121 )?;
122
123 for graph_mem_id in graph_memory_ids {
124 let row = {
125 let mut stmt = conn.prepare_cached(
126 "SELECT id, namespace, name, type, description, body, body_hash,
127 session_id, source, metadata, created_at, updated_at
128 FROM memories WHERE id=?1 AND deleted_at IS NULL",
129 )?;
130 stmt.query_row(rusqlite::params![graph_mem_id], |r| {
131 Ok(memories::MemoryRow {
132 id: r.get(0)?,
133 namespace: r.get(1)?,
134 name: r.get(2)?,
135 memory_type: r.get(3)?,
136 description: r.get(4)?,
137 body: r.get(5)?,
138 body_hash: r.get(6)?,
139 session_id: r.get(7)?,
140 source: r.get(8)?,
141 metadata: r.get(9)?,
142 created_at: r.get(10)?,
143 updated_at: r.get(11)?,
144 })
145 })
146 .ok()
147 };
148 if let Some(row) = row {
149 let snippet: String = row.body.chars().take(300).collect();
150 graph_matches.push(RecallItem {
151 memory_id: row.id,
152 name: row.name,
153 namespace: row.namespace,
154 memory_type: row.memory_type,
155 description: row.description,
156 snippet,
157 distance: 0.0,
158 source: "graph".to_string(),
159 });
160 }
161 }
162 }
163 }
164
165 if args.min_distance < 1.0 {
167 let has_relevant = direct_matches
168 .iter()
169 .any(|item| item.distance <= args.min_distance);
170 if !has_relevant {
171 return Err(AppError::NotFound(erros::sem_resultados_recall(
172 args.min_distance,
173 &args.query,
174 &namespace,
175 )));
176 }
177 }
178
179 let results: Vec<RecallItem> = direct_matches
180 .iter()
181 .cloned()
182 .chain(graph_matches.iter().cloned())
183 .collect();
184
185 output::emit_json(&RecallResponse {
186 query: args.query,
187 k: args.k,
188 direct_matches,
189 graph_matches,
190 results,
191 elapsed_ms: start.elapsed().as_millis() as u64,
192 })?;
193
194 Ok(())
195}
196
#[cfg(test)]
mod testes {
    use crate::output::{RecallItem, RecallResponse};

    /// Builds a `RecallItem` fixture; only name, distance and source vary.
    fn make_item(name: &str, distance: f32, source: &str) -> RecallItem {
        RecallItem {
            memory_id: 1,
            name: String::from(name),
            namespace: String::from("global"),
            memory_type: String::from("fact"),
            description: String::from("desc"),
            snippet: String::from("snippet"),
            distance,
            source: String::from(source),
        }
    }

    /// Builds a `RecallResponse` whose `results` is direct matches followed by
    /// graph matches.
    fn make_response(
        query: &str,
        k: usize,
        direct: Vec<RecallItem>,
        graph: Vec<RecallItem>,
        elapsed_ms: u64,
    ) -> RecallResponse {
        let mut results = direct.clone();
        results.extend(graph.iter().cloned());
        RecallResponse {
            query: String::from(query),
            k,
            direct_matches: direct,
            graph_matches: graph,
            results,
            elapsed_ms,
        }
    }

    /// Serializes a response to a JSON value, panicking on failure.
    fn to_json(resp: &RecallResponse) -> serde_json::Value {
        serde_json::to_value(resp).expect("serialização falhou")
    }

    /// Length of the JSON array stored under `key`; panics if it is not an array.
    fn array_len(json: &serde_json::Value, key: &str) -> usize {
        json[key].as_array().unwrap().len()
    }

    // The response must expose the scalar fields and the three match arrays.
    #[test]
    fn recall_response_serializa_campos_obrigatorios() {
        let resp = make_response(
            "rust memory",
            5,
            vec![make_item("mem-a", 0.12, "direct")],
            vec![],
            42,
        );

        let json = to_json(&resp);
        assert_eq!(json["query"], "rust memory");
        assert_eq!(json["k"], 5);
        assert_eq!(json["elapsed_ms"], 42u64);
        for key in ["direct_matches", "graph_matches", "results"] {
            assert!(json[key].is_array());
        }
    }

    // The item's memory type must be serialized under the JSON key `type`.
    #[test]
    fn recall_item_serializa_type_renomeado() {
        let json = serde_json::to_value(&make_item("mem-teste", 0.25, "direct"))
            .expect("serialização falhou");

        assert_eq!(json["type"], "fact");
        assert_eq!(json["distance"], 0.25f32);
        assert_eq!(json["source"], "direct");
    }

    // `results` keeps direct matches ahead of graph matches.
    #[test]
    fn recall_response_results_contem_direct_e_graph() {
        let resp = make_response(
            "query",
            10,
            vec![make_item("d-mem", 0.10, "direct")],
            vec![make_item("g-mem", 0.0, "graph")],
            10,
        );

        let json = to_json(&resp);
        assert_eq!(array_len(&json, "direct_matches"), 1);
        assert_eq!(array_len(&json, "graph_matches"), 1);
        assert_eq!(array_len(&json, "results"), 2);
        assert_eq!(json["results"][0]["source"], "direct");
        assert_eq!(json["results"][1]["source"], "graph");
    }

    // Empty match lists must serialize as empty arrays.
    #[test]
    fn recall_response_vazio_serializa_arrays_vazios() {
        let resp = make_response("nada", 3, vec![], vec![], 1);

        let json = to_json(&resp);
        assert_eq!(array_len(&json, "direct_matches"), 0);
        assert_eq!(array_len(&json, "results"), 0);
    }
}