1use crate::errors::AppError;
4use rusqlite::{params, Connection};
5
6pub fn traverse_from_memories(
40 conn: &Connection,
41 seed_memory_ids: &[i64],
42 namespace: &str,
43 min_weight: f64,
44 max_hops: u32,
45) -> Result<Vec<i64>, AppError> {
46 if seed_memory_ids.is_empty() || max_hops == 0 {
47 return Ok(vec![]);
48 }
49
50 let mut seed_entities: Vec<i64> = Vec::new();
52 for &mem_id in seed_memory_ids {
53 let mut stmt =
54 conn.prepare_cached("SELECT entity_id FROM memory_entities WHERE memory_id = ?1")?;
55 let ids: Vec<i64> = stmt
56 .query_map(params![mem_id], |r| r.get(0))?
57 .filter_map(|r| r.ok())
58 .collect();
59 seed_entities.extend(ids);
60 }
61 seed_entities.sort_unstable();
62 seed_entities.dedup();
63
64 if seed_entities.is_empty() {
65 return Ok(vec![]);
66 }
67
68 use std::collections::HashSet;
70 let mut visited: HashSet<i64> = seed_entities.iter().cloned().collect();
71 let mut frontier = seed_entities.clone();
72
73 for _ in 0..max_hops {
74 if frontier.is_empty() {
75 break;
76 }
77 let mut next_frontier = Vec::new();
78
79 for &entity_id in &frontier {
80 let mut stmt = conn.prepare_cached(
81 "SELECT target_id FROM relationships
82 WHERE source_id = ?1 AND weight >= ?2 AND namespace = ?3",
83 )?;
84 let neighbors: Vec<i64> = stmt
85 .query_map(params![entity_id, min_weight, namespace], |r| r.get(0))?
86 .filter_map(|r| r.ok())
87 .filter(|id| !visited.contains(id))
88 .collect();
89
90 for id in neighbors {
91 visited.insert(id);
92 next_frontier.push(id);
93 }
94 }
95 frontier = next_frontier;
96 }
97
98 let seed_set: HashSet<i64> = seed_memory_ids.iter().cloned().collect();
100 let graph_only_entities: Vec<i64> = visited
101 .into_iter()
102 .filter(|id| !seed_entities.contains(id))
103 .collect();
104
105 let mut result_ids: Vec<i64> = Vec::new();
106 for &entity_id in &graph_only_entities {
107 let mut stmt = conn.prepare_cached(
108 "SELECT DISTINCT me.memory_id
109 FROM memory_entities me
110 JOIN memories m ON m.id = me.memory_id
111 WHERE me.entity_id = ?1 AND m.deleted_at IS NULL",
112 )?;
113 let mem_ids: Vec<i64> = stmt
114 .query_map(params![entity_id], |r| r.get(0))?
115 .filter_map(|r| r.ok())
116 .filter(|id| !seed_set.contains(id))
117 .collect();
118 result_ids.extend(mem_ids);
119 }
120
121 result_ids.sort_unstable();
122 result_ids.dedup();
123 Ok(result_ids)
124}
125
126pub fn traverse_from_memories_with_hops(
136 conn: &Connection,
137 seed_memory_ids: &[i64],
138 namespace: &str,
139 min_weight: f64,
140 max_hops: u32,
141) -> Result<Vec<(i64, u32)>, AppError> {
142 if seed_memory_ids.is_empty() || max_hops == 0 {
143 return Ok(vec![]);
144 }
145
146 let mut seed_entities: Vec<i64> = Vec::new();
148 for &mem_id in seed_memory_ids {
149 let mut stmt =
150 conn.prepare_cached("SELECT entity_id FROM memory_entities WHERE memory_id = ?1")?;
151 let ids: Vec<i64> = stmt
152 .query_map(params![mem_id], |r| r.get(0))?
153 .filter_map(|r| r.ok())
154 .collect();
155 seed_entities.extend(ids);
156 }
157 seed_entities.sort_unstable();
158 seed_entities.dedup();
159
160 if seed_entities.is_empty() {
161 return Ok(vec![]);
162 }
163
164 use std::collections::HashMap;
166 let mut entity_depth: HashMap<i64, u32> = seed_entities.iter().map(|&id| (id, 0)).collect();
167 let mut frontier = seed_entities.clone();
168
169 for hop in 1..=max_hops {
170 if frontier.is_empty() {
171 break;
172 }
173 let mut next_frontier = Vec::new();
174
175 for &entity_id in &frontier {
176 let mut stmt = conn.prepare_cached(
177 "SELECT target_id FROM relationships
178 WHERE source_id = ?1 AND weight >= ?2 AND namespace = ?3",
179 )?;
180 let neighbors: Vec<i64> = stmt
181 .query_map(params![entity_id, min_weight, namespace], |r| r.get(0))?
182 .filter_map(|r| r.ok())
183 .filter(|id| !entity_depth.contains_key(id))
184 .collect();
185
186 for id in neighbors {
187 entity_depth.insert(id, hop);
188 next_frontier.push(id);
189 }
190 }
191 frontier = next_frontier;
192 }
193
194 let seed_set: std::collections::HashSet<i64> = seed_memory_ids.iter().cloned().collect();
196 let seed_entity_set: std::collections::HashSet<i64> = seed_entities.iter().cloned().collect();
197
198 let mut result: Vec<(i64, u32)> = Vec::new();
199 let mut seen_memories: std::collections::HashSet<i64> = std::collections::HashSet::new();
200
201 for (&entity_id, &hop) in &entity_depth {
202 if seed_entity_set.contains(&entity_id) {
203 continue;
204 }
205 let mut stmt = conn.prepare_cached(
206 "SELECT DISTINCT me.memory_id
207 FROM memory_entities me
208 JOIN memories m ON m.id = me.memory_id
209 WHERE me.entity_id = ?1 AND m.deleted_at IS NULL",
210 )?;
211 let mem_ids: Vec<i64> = stmt
212 .query_map(params![entity_id], |r| r.get(0))?
213 .filter_map(|r| r.ok())
214 .filter(|id| !seed_set.contains(id) && !seen_memories.contains(id))
215 .collect();
216
217 for mem_id in mem_ids {
218 seen_memories.insert(mem_id);
219 result.push((mem_id, hop));
220 }
221 }
222
223 result.sort_unstable_by_key(|&(id, _)| id);
224 Ok(result)
225}
226
#[cfg(test)]
mod tests {
    use super::*;
    use rusqlite::Connection;

    /// Creates an in-memory database with the minimal schema the traversal reads.
    fn setup_db() -> Connection {
        let conn = Connection::open_in_memory().unwrap();
        conn.execute_batch(
            "CREATE TABLE memories (
                id INTEGER PRIMARY KEY,
                namespace TEXT NOT NULL,
                deleted_at TEXT
            );
            CREATE TABLE memory_entities (
                memory_id INTEGER NOT NULL,
                entity_id INTEGER NOT NULL
            );
            CREATE TABLE relationships (
                source_id INTEGER NOT NULL,
                target_id INTEGER NOT NULL,
                weight REAL NOT NULL,
                namespace TEXT NOT NULL
            );",
        )
        .unwrap();
        conn
    }

    /// Inserts a memory row; `deleted` sets a non-NULL `deleted_at`.
    fn insert_memory(conn: &Connection, id: i64, namespace: &str, deleted: bool) {
        conn.execute(
            "INSERT INTO memories (id, namespace, deleted_at) VALUES (?1, ?2, ?3)",
            params![
                id,
                namespace,
                if deleted { Some("2024-01-01") } else { None }
            ],
        )
        .unwrap();
    }

    /// Links a memory to an entity.
    fn link_memory_entity(conn: &Connection, memory_id: i64, entity_id: i64) {
        conn.execute(
            "INSERT INTO memory_entities (memory_id, entity_id) VALUES (?1, ?2)",
            params![memory_id, entity_id],
        )
        .unwrap();
    }

    /// Inserts a directed relationship edge `src -> tgt`.
    fn insert_relationship(conn: &Connection, src: i64, tgt: i64, weight: f64, ns: &str) {
        conn.execute(
            "INSERT INTO relationships (source_id, target_id, weight, namespace) VALUES (?1, ?2, ?3, ?4)",
            params![src, tgt, weight, ns],
        )
        .unwrap();
    }

    // ---- traverse_from_memories ----

    // No seeds: nothing to traverse.
    #[test]
    fn retorna_vazio_quando_seeds_vazio() {
        let conn = setup_db();
        let resultado = traverse_from_memories(&conn, &[], "ns", 0.5, 3).unwrap();
        assert!(resultado.is_empty());
    }

    // Zero hops allowed: nothing can be reached.
    #[test]
    fn retorna_vazio_quando_max_hops_zero() {
        let conn = setup_db();
        insert_memory(&conn, 1, "ns", false);
        link_memory_entity(&conn, 1, 10);
        let resultado = traverse_from_memories(&conn, &[1], "ns", 0.5, 0).unwrap();
        assert!(resultado.is_empty());
    }

    // A seed memory with no linked entities has no starting point.
    #[test]
    fn retorna_vazio_quando_seed_sem_entidades() {
        let conn = setup_db();
        insert_memory(&conn, 1, "ns", false);
        let resultado = traverse_from_memories(&conn, &[1], "ns", 0.5, 3).unwrap();
        assert!(resultado.is_empty());
    }

    // Seed entity exists but has no outgoing edges.
    #[test]
    fn retorna_vazio_quando_sem_relacionamentos() {
        let conn = setup_db();
        insert_memory(&conn, 1, "ns", false);
        link_memory_entity(&conn, 1, 10);
        let resultado = traverse_from_memories(&conn, &[1], "ns", 0.5, 3).unwrap();
        assert!(resultado.is_empty());
    }

    // Memory 1 (entity 10) -> entity 20 -> memory 2 in one hop.
    #[test]
    fn traversal_basico_um_hop() {
        let conn = setup_db();

        insert_memory(&conn, 1, "ns", false);
        link_memory_entity(&conn, 1, 10);

        insert_memory(&conn, 2, "ns", false);
        link_memory_entity(&conn, 2, 20);

        insert_relationship(&conn, 10, 20, 1.0, "ns");

        let resultado = traverse_from_memories(&conn, &[1], "ns", 0.5, 1).unwrap();
        assert_eq!(resultado, vec![2]);
    }

    // Chain 10 -> 20 -> 30 reaches both downstream memories within two hops.
    #[test]
    fn traversal_dois_hops() {
        let conn = setup_db();

        insert_memory(&conn, 1, "ns", false);
        link_memory_entity(&conn, 1, 10);

        insert_memory(&conn, 2, "ns", false);
        link_memory_entity(&conn, 2, 20);

        insert_memory(&conn, 3, "ns", false);
        link_memory_entity(&conn, 3, 30);

        insert_relationship(&conn, 10, 20, 1.0, "ns");
        insert_relationship(&conn, 20, 30, 1.0, "ns");

        let mut resultado = traverse_from_memories(&conn, &[1], "ns", 0.5, 2).unwrap();
        resultado.sort_unstable();
        assert_eq!(resultado, vec![2, 3]);
    }

    // With max_hops = 1, the second link of the chain is not followed.
    #[test]
    fn max_hops_limita_profundidade() {
        let conn = setup_db();

        insert_memory(&conn, 1, "ns", false);
        link_memory_entity(&conn, 1, 10);

        insert_memory(&conn, 2, "ns", false);
        link_memory_entity(&conn, 2, 20);

        insert_memory(&conn, 3, "ns", false);
        link_memory_entity(&conn, 3, 30);

        insert_relationship(&conn, 10, 20, 1.0, "ns");
        insert_relationship(&conn, 20, 30, 1.0, "ns");

        let resultado = traverse_from_memories(&conn, &[1], "ns", 0.5, 1).unwrap();
        assert_eq!(resultado, vec![2]);
        assert!(!resultado.contains(&3));
    }

    // Edges lighter than min_weight are not followed.
    #[test]
    fn relacionamento_com_peso_abaixo_do_minimo_ignorado() {
        let conn = setup_db();

        insert_memory(&conn, 1, "ns", false);
        link_memory_entity(&conn, 1, 10);

        insert_memory(&conn, 2, "ns", false);
        link_memory_entity(&conn, 2, 20);

        insert_relationship(&conn, 10, 20, 0.3, "ns");

        let resultado = traverse_from_memories(&conn, &[1], "ns", 0.5, 3).unwrap();
        assert!(resultado.is_empty());
    }

    // The weight threshold is inclusive.
    #[test]
    fn relacionamento_com_peso_exatamente_no_minimo_incluido() {
        let conn = setup_db();

        insert_memory(&conn, 1, "ns", false);
        link_memory_entity(&conn, 1, 10);

        insert_memory(&conn, 2, "ns", false);
        link_memory_entity(&conn, 2, 20);

        insert_relationship(&conn, 10, 20, 0.5, "ns");

        let resultado = traverse_from_memories(&conn, &[1], "ns", 0.5, 1).unwrap();
        assert_eq!(resultado, vec![2]);
    }

    // Edges in another namespace are not followed.
    #[test]
    fn relacionamento_de_namespace_diferente_ignorado() {
        let conn = setup_db();

        insert_memory(&conn, 1, "ns_a", false);
        link_memory_entity(&conn, 1, 10);

        insert_memory(&conn, 2, "ns_a", false);
        link_memory_entity(&conn, 2, 20);

        insert_relationship(&conn, 10, 20, 1.0, "ns_b");

        let resultado = traverse_from_memories(&conn, &[1], "ns_a", 0.5, 3).unwrap();
        assert!(resultado.is_empty());
    }

    // A back-edge to the seed entity must not return the seed memory.
    #[test]
    fn seeds_nao_aparecem_no_resultado() {
        let conn = setup_db();

        insert_memory(&conn, 1, "ns", false);
        link_memory_entity(&conn, 1, 10);

        insert_memory(&conn, 2, "ns", false);
        link_memory_entity(&conn, 2, 20);

        insert_relationship(&conn, 10, 20, 1.0, "ns");
        insert_relationship(&conn, 20, 10, 1.0, "ns");

        let resultado = traverse_from_memories(&conn, &[1], "ns", 0.5, 3).unwrap();
        assert!(!resultado.contains(&1));
        assert_eq!(resultado, vec![2]);
    }

    // Soft-deleted memories (deleted_at set) are excluded.
    #[test]
    fn memorias_deletadas_nao_incluidas() {
        let conn = setup_db();

        insert_memory(&conn, 1, "ns", false);
        link_memory_entity(&conn, 1, 10);

        insert_memory(&conn, 2, "ns", true);
        link_memory_entity(&conn, 2, 20);

        insert_relationship(&conn, 10, 20, 1.0, "ns");

        let resultado = traverse_from_memories(&conn, &[1], "ns", 0.5, 3).unwrap();
        assert!(resultado.is_empty());
    }

    // Results from independent seeds are merged.
    #[test]
    fn multiplos_seeds_unidos_no_resultado() {
        let conn = setup_db();

        insert_memory(&conn, 1, "ns", false);
        link_memory_entity(&conn, 1, 10);

        insert_memory(&conn, 2, "ns", false);
        link_memory_entity(&conn, 2, 20);

        insert_memory(&conn, 3, "ns", false);
        link_memory_entity(&conn, 3, 30);

        insert_memory(&conn, 4, "ns", false);
        link_memory_entity(&conn, 4, 40);

        insert_relationship(&conn, 10, 30, 1.0, "ns");
        insert_relationship(&conn, 20, 40, 1.0, "ns");

        let mut resultado = traverse_from_memories(&conn, &[1, 2], "ns", 0.5, 1).unwrap();
        resultado.sort_unstable();
        assert_eq!(resultado, vec![3, 4]);
    }

    // Two seed entities pointing at the same target yield the memory only once.
    #[test]
    fn resultado_sem_duplicatas() {
        let conn = setup_db();

        insert_memory(&conn, 1, "ns", false);
        link_memory_entity(&conn, 1, 10);
        link_memory_entity(&conn, 1, 11);

        insert_memory(&conn, 2, "ns", false);
        link_memory_entity(&conn, 2, 20);

        insert_relationship(&conn, 10, 20, 1.0, "ns");
        insert_relationship(&conn, 11, 20, 1.0, "ns");

        let resultado = traverse_from_memories(&conn, &[1], "ns", 0.5, 1).unwrap();
        assert_eq!(resultado.len(), 1);
        assert_eq!(resultado, vec![2]);
    }

    // An isolated node returns nothing even with a large hop budget.
    #[test]
    fn single_node_sem_vizinhos_retorna_vazio() {
        let conn = setup_db();

        insert_memory(&conn, 1, "ns", false);
        link_memory_entity(&conn, 1, 10);
        let resultado = traverse_from_memories(&conn, &[1], "ns", 0.5, 5).unwrap();
        assert!(resultado.is_empty());
    }

    // A cycle 10 -> 20 -> 30 -> 10 terminates and visits each entity once.
    #[test]
    fn ciclo_nao_causa_loop_infinito() {
        let conn = setup_db();

        insert_memory(&conn, 1, "ns", false);
        link_memory_entity(&conn, 1, 10);

        insert_memory(&conn, 2, "ns", false);
        link_memory_entity(&conn, 2, 20);

        insert_memory(&conn, 3, "ns", false);
        link_memory_entity(&conn, 3, 30);

        insert_relationship(&conn, 10, 20, 1.0, "ns");
        insert_relationship(&conn, 20, 30, 1.0, "ns");
        insert_relationship(&conn, 30, 10, 1.0, "ns");

        let mut resultado = traverse_from_memories(&conn, &[1], "ns", 0.5, 10).unwrap();
        resultado.sort_unstable();
        assert_eq!(resultado, vec![2, 3]);
    }

    // ---- traverse_from_memories_with_hops ----

    // Empty seeds or a zero hop budget produce no results.
    #[test]
    fn hops_retorna_vazio_em_entradas_triviais() {
        let conn = setup_db();
        insert_memory(&conn, 1, "ns", false);
        link_memory_entity(&conn, 1, 10);

        let sem_seeds = traverse_from_memories_with_hops(&conn, &[], "ns", 0.5, 3).unwrap();
        assert!(sem_seeds.is_empty());

        let sem_hops = traverse_from_memories_with_hops(&conn, &[1], "ns", 0.5, 0).unwrap();
        assert!(sem_hops.is_empty());
    }

    // Along the chain 10 -> 20 -> 30, memories are tagged with their hop distance.
    #[test]
    fn hops_profundidade_por_memoria() {
        let conn = setup_db();

        insert_memory(&conn, 1, "ns", false);
        link_memory_entity(&conn, 1, 10);

        insert_memory(&conn, 2, "ns", false);
        link_memory_entity(&conn, 2, 20);

        insert_memory(&conn, 3, "ns", false);
        link_memory_entity(&conn, 3, 30);

        insert_relationship(&conn, 10, 20, 1.0, "ns");
        insert_relationship(&conn, 20, 30, 1.0, "ns");

        let resultado = traverse_from_memories_with_hops(&conn, &[1], "ns", 0.5, 3).unwrap();
        assert_eq!(resultado, vec![(2, 1), (3, 2)]);
    }

    // max_hops bounds the reported depth.
    #[test]
    fn hops_max_hops_limita_profundidade() {
        let conn = setup_db();

        insert_memory(&conn, 1, "ns", false);
        link_memory_entity(&conn, 1, 10);

        insert_memory(&conn, 2, "ns", false);
        link_memory_entity(&conn, 2, 20);

        insert_memory(&conn, 3, "ns", false);
        link_memory_entity(&conn, 3, 30);

        insert_relationship(&conn, 10, 20, 1.0, "ns");
        insert_relationship(&conn, 20, 30, 1.0, "ns");

        let resultado = traverse_from_memories_with_hops(&conn, &[1], "ns", 0.5, 1).unwrap();
        assert_eq!(resultado, vec![(2, 1)]);
    }

    // Seed memories and soft-deleted memories are excluded; a back-edge is harmless.
    #[test]
    fn hops_seeds_e_deletadas_excluidas() {
        let conn = setup_db();

        insert_memory(&conn, 1, "ns", false);
        link_memory_entity(&conn, 1, 10);

        insert_memory(&conn, 2, "ns", true);
        link_memory_entity(&conn, 2, 20);

        insert_memory(&conn, 3, "ns", false);
        link_memory_entity(&conn, 3, 20);

        insert_relationship(&conn, 10, 20, 1.0, "ns");
        insert_relationship(&conn, 20, 10, 1.0, "ns");

        let resultado = traverse_from_memories_with_hops(&conn, &[1], "ns", 0.5, 3).unwrap();
        assert_eq!(resultado, vec![(3, 1)]);
    }

    // Several memories on the same entity all share that entity's hop.
    #[test]
    fn hops_multiplas_memorias_mesma_entidade() {
        let conn = setup_db();

        insert_memory(&conn, 1, "ns", false);
        link_memory_entity(&conn, 1, 10);

        insert_memory(&conn, 2, "ns", false);
        link_memory_entity(&conn, 2, 20);

        insert_memory(&conn, 3, "ns", false);
        link_memory_entity(&conn, 3, 20);

        insert_relationship(&conn, 10, 20, 1.0, "ns");

        let resultado = traverse_from_memories_with_hops(&conn, &[1], "ns", 0.5, 2).unwrap();
        assert_eq!(resultado, vec![(2, 1), (3, 1)]);
    }

    // The ids returned with hops match those of `traverse_from_memories`.
    #[test]
    fn hops_ids_iguais_a_traverse_from_memories() {
        let conn = setup_db();

        insert_memory(&conn, 1, "ns", false);
        link_memory_entity(&conn, 1, 10);

        insert_memory(&conn, 2, "ns", false);
        link_memory_entity(&conn, 2, 20);

        insert_memory(&conn, 3, "ns", false);
        link_memory_entity(&conn, 3, 30);

        insert_relationship(&conn, 10, 20, 1.0, "ns");
        insert_relationship(&conn, 20, 30, 1.0, "ns");
        insert_relationship(&conn, 30, 10, 1.0, "ns");

        let com_hops = traverse_from_memories_with_hops(&conn, &[1], "ns", 0.5, 10).unwrap();
        let ids: Vec<i64> = com_hops.iter().map(|&(id, _)| id).collect();
        let sem_hops = traverse_from_memories(&conn, &[1], "ns", 0.5, 10).unwrap();
        assert_eq!(ids, sem_hops);
        assert_eq!(com_hops, vec![(2, 1), (3, 2)]);
    }
}