1use crate::errors::AppError;
4use rusqlite::{params, Connection};
5
6pub fn traverse_from_memories(
40 conn: &Connection,
41 seed_memory_ids: &[i64],
42 namespace: &str,
43 min_weight: f64,
44 max_hops: u32,
45) -> Result<Vec<i64>, AppError> {
46 if seed_memory_ids.is_empty() || max_hops == 0 {
47 return Ok(vec![]);
48 }
49
50 let mut seed_entities: Vec<i64> = Vec::new();
52 for &mem_id in seed_memory_ids {
53 let mut stmt =
54 conn.prepare_cached("SELECT entity_id FROM memory_entities WHERE memory_id = ?1")?;
55 let ids: Vec<i64> = stmt
56 .query_map(params![mem_id], |r| r.get(0))?
57 .filter_map(|r| r.ok())
58 .collect();
59 seed_entities.extend(ids);
60 }
61 seed_entities.sort_unstable();
62 seed_entities.dedup();
63
64 if seed_entities.is_empty() {
65 return Ok(vec![]);
66 }
67
68 use std::collections::HashSet;
70 let mut visited: HashSet<i64> = seed_entities.iter().cloned().collect();
71 let mut frontier = seed_entities.clone();
72
73 for _ in 0..max_hops {
74 if frontier.is_empty() {
75 break;
76 }
77 let mut next_frontier = Vec::new();
78
79 for &entity_id in &frontier {
80 let mut stmt = conn.prepare_cached(
81 "SELECT target_id FROM relationships
82 WHERE source_id = ?1 AND weight >= ?2 AND namespace = ?3",
83 )?;
84 let neighbors: Vec<i64> = stmt
85 .query_map(params![entity_id, min_weight, namespace], |r| r.get(0))?
86 .filter_map(|r| r.ok())
87 .filter(|id| !visited.contains(id))
88 .collect();
89
90 for id in neighbors {
91 visited.insert(id);
92 next_frontier.push(id);
93 }
94 }
95 frontier = next_frontier;
96 }
97
98 let seed_set: HashSet<i64> = seed_memory_ids.iter().cloned().collect();
100 let graph_only_entities: Vec<i64> = visited
101 .into_iter()
102 .filter(|id| !seed_entities.contains(id))
103 .collect();
104
105 let mut result_ids: Vec<i64> = Vec::new();
106 for &entity_id in &graph_only_entities {
107 let mut stmt = conn.prepare_cached(
108 "SELECT DISTINCT me.memory_id
109 FROM memory_entities me
110 JOIN memories m ON m.id = me.memory_id
111 WHERE me.entity_id = ?1 AND m.deleted_at IS NULL",
112 )?;
113 let mem_ids: Vec<i64> = stmt
114 .query_map(params![entity_id], |r| r.get(0))?
115 .filter_map(|r| r.ok())
116 .filter(|id| !seed_set.contains(id))
117 .collect();
118 result_ids.extend(mem_ids);
119 }
120
121 result_ids.sort_unstable();
122 result_ids.dedup();
123 Ok(result_ids)
124}
125
#[cfg(test)]
mod tests {
    // Unit tests for `traverse_from_memories`. They run against an in-memory SQLite
    // database whose schema contains only the columns the function reads.
    use super::*;
    use rusqlite::Connection;

    /// Creates an in-memory database with the three tables the traversal queries.
    fn setup_db() -> Connection {
        let conn = Connection::open_in_memory().unwrap();
        conn.execute_batch(
            "CREATE TABLE memories (
                id INTEGER PRIMARY KEY,
                namespace TEXT NOT NULL,
                deleted_at TEXT
            );
            CREATE TABLE memory_entities (
                memory_id INTEGER NOT NULL,
                entity_id INTEGER NOT NULL
            );
            CREATE TABLE relationships (
                source_id INTEGER NOT NULL,
                target_id INTEGER NOT NULL,
                weight REAL NOT NULL,
                namespace TEXT NOT NULL
            );",
        )
        .unwrap();
        conn
    }

    /// Inserts a memory row. When `deleted` is true, `deleted_at` is set so the memory
    /// counts as soft-deleted.
    fn insert_memory(conn: &Connection, id: i64, namespace: &str, deleted: bool) {
        conn.execute(
            "INSERT INTO memories (id, namespace, deleted_at) VALUES (?1, ?2, ?3)",
            params![
                id,
                namespace,
                if deleted { Some("2024-01-01") } else { None }
            ],
        )
        .unwrap();
    }

    /// Links a memory to an entity.
    fn link_memory_entity(conn: &Connection, memory_id: i64, entity_id: i64) {
        conn.execute(
            "INSERT INTO memory_entities (memory_id, entity_id) VALUES (?1, ?2)",
            params![memory_id, entity_id],
        )
        .unwrap();
    }

    /// Inserts a directed, weighted edge `src -> tgt` in namespace `ns`.
    fn insert_relationship(conn: &Connection, src: i64, tgt: i64, weight: f64, ns: &str) {
        conn.execute(
            "INSERT INTO relationships (source_id, target_id, weight, namespace) VALUES (?1, ?2, ?3, ?4)",
            params![src, tgt, weight, ns],
        )
        .unwrap();
    }

    #[test]
    fn retorna_vazio_quando_seeds_vazio() {
        let conn = setup_db();
        let resultado = traverse_from_memories(&conn, &[], "ns", 0.5, 3).unwrap();
        assert!(resultado.is_empty());
    }

    #[test]
    fn retorna_vazio_quando_max_hops_zero() {
        let conn = setup_db();
        insert_memory(&conn, 1, "ns", false);
        link_memory_entity(&conn, 1, 10);
        let resultado = traverse_from_memories(&conn, &[1], "ns", 0.5, 0).unwrap();
        assert!(resultado.is_empty());
    }

    #[test]
    fn retorna_vazio_quando_seed_sem_entidades() {
        let conn = setup_db();
        insert_memory(&conn, 1, "ns", false);
        let resultado = traverse_from_memories(&conn, &[1], "ns", 0.5, 3).unwrap();
        assert!(resultado.is_empty());
    }

    #[test]
    fn retorna_vazio_quando_sem_relacionamentos() {
        let conn = setup_db();
        insert_memory(&conn, 1, "ns", false);
        link_memory_entity(&conn, 1, 10);
        let resultado = traverse_from_memories(&conn, &[1], "ns", 0.5, 3).unwrap();
        assert!(resultado.is_empty());
    }

    #[test]
    fn traversal_basico_um_hop() {
        let conn = setup_db();

        insert_memory(&conn, 1, "ns", false);
        link_memory_entity(&conn, 1, 10);

        insert_memory(&conn, 2, "ns", false);
        link_memory_entity(&conn, 2, 20);

        insert_relationship(&conn, 10, 20, 1.0, "ns");

        let resultado = traverse_from_memories(&conn, &[1], "ns", 0.5, 1).unwrap();
        assert_eq!(resultado, vec![2]);
    }

    #[test]
    fn traversal_dois_hops() {
        let conn = setup_db();

        insert_memory(&conn, 1, "ns", false);
        link_memory_entity(&conn, 1, 10);

        insert_memory(&conn, 2, "ns", false);
        link_memory_entity(&conn, 2, 20);

        insert_memory(&conn, 3, "ns", false);
        link_memory_entity(&conn, 3, 30);

        // Chain 10 -> 20 -> 30.
        insert_relationship(&conn, 10, 20, 1.0, "ns");
        insert_relationship(&conn, 20, 30, 1.0, "ns");

        let mut resultado = traverse_from_memories(&conn, &[1], "ns", 0.5, 2).unwrap();
        resultado.sort_unstable();
        assert_eq!(resultado, vec![2, 3]);
    }

    #[test]
    fn max_hops_limita_profundidade() {
        let conn = setup_db();

        insert_memory(&conn, 1, "ns", false);
        link_memory_entity(&conn, 1, 10);

        insert_memory(&conn, 2, "ns", false);
        link_memory_entity(&conn, 2, 20);

        insert_memory(&conn, 3, "ns", false);
        link_memory_entity(&conn, 3, 30);

        insert_relationship(&conn, 10, 20, 1.0, "ns");
        insert_relationship(&conn, 20, 30, 1.0, "ns");

        // With a single hop, entity 30 (two hops away) must not be reached.
        let resultado = traverse_from_memories(&conn, &[1], "ns", 0.5, 1).unwrap();
        assert_eq!(resultado, vec![2]);
        assert!(!resultado.contains(&3));
    }

    #[test]
    fn relacionamento_com_peso_abaixo_do_minimo_ignorado() {
        let conn = setup_db();

        insert_memory(&conn, 1, "ns", false);
        link_memory_entity(&conn, 1, 10);

        insert_memory(&conn, 2, "ns", false);
        link_memory_entity(&conn, 2, 20);

        insert_relationship(&conn, 10, 20, 0.3, "ns");

        let resultado = traverse_from_memories(&conn, &[1], "ns", 0.5, 3).unwrap();
        assert!(resultado.is_empty());
    }

    #[test]
    fn relacionamento_com_peso_exatamente_no_minimo_incluido() {
        let conn = setup_db();

        insert_memory(&conn, 1, "ns", false);
        link_memory_entity(&conn, 1, 10);

        insert_memory(&conn, 2, "ns", false);
        link_memory_entity(&conn, 2, 20);

        // `min_weight` is an inclusive bound.
        insert_relationship(&conn, 10, 20, 0.5, "ns");

        let resultado = traverse_from_memories(&conn, &[1], "ns", 0.5, 1).unwrap();
        assert_eq!(resultado, vec![2]);
    }

    #[test]
    fn relacionamento_de_namespace_diferente_ignorado() {
        let conn = setup_db();

        insert_memory(&conn, 1, "ns_a", false);
        link_memory_entity(&conn, 1, 10);

        insert_memory(&conn, 2, "ns_a", false);
        link_memory_entity(&conn, 2, 20);

        insert_relationship(&conn, 10, 20, 1.0, "ns_b");

        let resultado = traverse_from_memories(&conn, &[1], "ns_a", 0.5, 3).unwrap();
        assert!(resultado.is_empty());
    }

    #[test]
    fn seeds_nao_aparecem_no_resultado() {
        let conn = setup_db();

        insert_memory(&conn, 1, "ns", false);
        link_memory_entity(&conn, 1, 10);

        insert_memory(&conn, 2, "ns", false);
        link_memory_entity(&conn, 2, 20);

        // Edges in both directions: the walk leads back to the seed entity.
        insert_relationship(&conn, 10, 20, 1.0, "ns");
        insert_relationship(&conn, 20, 10, 1.0, "ns");

        let resultado = traverse_from_memories(&conn, &[1], "ns", 0.5, 3).unwrap();
        assert!(!resultado.contains(&1));
        assert_eq!(resultado, vec![2]);
    }

    #[test]
    fn memorias_deletadas_nao_incluidas() {
        let conn = setup_db();

        insert_memory(&conn, 1, "ns", false);
        link_memory_entity(&conn, 1, 10);

        insert_memory(&conn, 2, "ns", true);
        link_memory_entity(&conn, 2, 20);

        insert_relationship(&conn, 10, 20, 1.0, "ns");

        let resultado = traverse_from_memories(&conn, &[1], "ns", 0.5, 3).unwrap();
        assert!(resultado.is_empty());
    }

    #[test]
    fn multiplos_seeds_unidos_no_resultado() {
        let conn = setup_db();

        insert_memory(&conn, 1, "ns", false);
        link_memory_entity(&conn, 1, 10);

        insert_memory(&conn, 2, "ns", false);
        link_memory_entity(&conn, 2, 20);

        insert_memory(&conn, 3, "ns", false);
        link_memory_entity(&conn, 3, 30);

        insert_memory(&conn, 4, "ns", false);
        link_memory_entity(&conn, 4, 40);

        insert_relationship(&conn, 10, 30, 1.0, "ns");
        insert_relationship(&conn, 20, 40, 1.0, "ns");

        let mut resultado = traverse_from_memories(&conn, &[1, 2], "ns", 0.5, 1).unwrap();
        resultado.sort_unstable();
        assert_eq!(resultado, vec![3, 4]);
    }

    #[test]
    fn resultado_sem_duplicatas() {
        let conn = setup_db();

        // Seed memory 1 has two entities, and both point at entity 20.
        insert_memory(&conn, 1, "ns", false);
        link_memory_entity(&conn, 1, 10);
        link_memory_entity(&conn, 1, 11);

        insert_memory(&conn, 2, "ns", false);
        link_memory_entity(&conn, 2, 20);

        insert_relationship(&conn, 10, 20, 1.0, "ns");
        insert_relationship(&conn, 11, 20, 1.0, "ns");

        let resultado = traverse_from_memories(&conn, &[1], "ns", 0.5, 1).unwrap();
        assert_eq!(resultado.len(), 1);
        assert_eq!(resultado, vec![2]);
    }

    #[test]
    fn arestas_paralelas_e_caminhos_convergentes_nao_duplicam() {
        let conn = setup_db();

        insert_memory(&conn, 1, "ns", false);
        link_memory_entity(&conn, 1, 10);

        insert_memory(&conn, 2, "ns", false);
        link_memory_entity(&conn, 2, 20);

        insert_memory(&conn, 3, "ns", false);
        link_memory_entity(&conn, 3, 30);

        insert_memory(&conn, 4, "ns", false);
        link_memory_entity(&conn, 4, 40);

        // Parallel edges 10 -> 20 plus a diamond 10 -> {20, 30} -> 40.
        insert_relationship(&conn, 10, 20, 1.0, "ns");
        insert_relationship(&conn, 10, 20, 0.9, "ns");
        insert_relationship(&conn, 10, 30, 1.0, "ns");
        insert_relationship(&conn, 20, 40, 1.0, "ns");
        insert_relationship(&conn, 30, 40, 1.0, "ns");

        let resultado = traverse_from_memories(&conn, &[1], "ns", 0.5, 2).unwrap();
        assert_eq!(resultado, vec![2, 3, 4]);
    }

    #[test]
    fn single_node_sem_vizinhos_retorna_vazio() {
        let conn = setup_db();

        insert_memory(&conn, 1, "ns", false);
        link_memory_entity(&conn, 1, 10);
        let resultado = traverse_from_memories(&conn, &[1], "ns", 0.5, 5).unwrap();
        assert!(resultado.is_empty());
    }

    #[test]
    fn ciclo_nao_causa_loop_infinito() {
        let conn = setup_db();

        insert_memory(&conn, 1, "ns", false);
        link_memory_entity(&conn, 1, 10);

        insert_memory(&conn, 2, "ns", false);
        link_memory_entity(&conn, 2, 20);

        insert_memory(&conn, 3, "ns", false);
        link_memory_entity(&conn, 3, 30);

        // Cycle 10 -> 20 -> 30 -> 10.
        insert_relationship(&conn, 10, 20, 1.0, "ns");
        insert_relationship(&conn, 20, 30, 1.0, "ns");
        insert_relationship(&conn, 30, 10, 1.0, "ns");

        let mut resultado = traverse_from_memories(&conn, &[1], "ns", 0.5, 10).unwrap();
        resultado.sort_unstable();
        assert_eq!(resultado, vec![2, 3]);
    }
}