1use crate::errors::AppError;
9use rusqlite::{params, Connection};
10
11pub fn traverse_from_memories(
45 conn: &Connection,
46 seed_memory_ids: &[i64],
47 namespace: &str,
48 min_weight: f64,
49 max_hops: u32,
50) -> Result<Vec<i64>, AppError> {
51 if seed_memory_ids.is_empty() || max_hops == 0 {
52 return Ok(vec![]);
53 }
54
55 let mut seed_entities: Vec<i64> = Vec::new();
57 for &mem_id in seed_memory_ids {
58 let mut stmt =
59 conn.prepare_cached("SELECT entity_id FROM memory_entities WHERE memory_id = ?1")?;
60 let ids: Vec<i64> = stmt
61 .query_map(params![mem_id], |r| r.get(0))?
62 .filter_map(|r| r.ok())
63 .collect();
64 seed_entities.extend(ids);
65 }
66 seed_entities.sort_unstable();
67 seed_entities.dedup();
68
69 if seed_entities.is_empty() {
70 return Ok(vec![]);
71 }
72
73 use std::collections::HashSet;
75 let mut visited: HashSet<i64> = seed_entities.iter().cloned().collect();
76 let mut frontier = seed_entities.clone();
77
78 for _ in 0..max_hops {
79 if frontier.is_empty() {
80 break;
81 }
82 let mut next_frontier = Vec::new();
83
84 for &entity_id in &frontier {
85 let mut stmt = conn.prepare_cached(
86 "SELECT target_id FROM relationships
87 WHERE source_id = ?1 AND weight >= ?2 AND namespace = ?3",
88 )?;
89 let neighbors: Vec<i64> = stmt
90 .query_map(params![entity_id, min_weight, namespace], |r| r.get(0))?
91 .filter_map(|r| r.ok())
92 .filter(|id| !visited.contains(id))
93 .collect();
94
95 for id in neighbors {
96 visited.insert(id);
97 next_frontier.push(id);
98 }
99 }
100 frontier = next_frontier;
101 }
102
103 let seed_set: HashSet<i64> = seed_memory_ids.iter().cloned().collect();
105 let graph_only_entities: Vec<i64> = visited
106 .into_iter()
107 .filter(|id| !seed_entities.contains(id))
108 .collect();
109
110 let mut result_ids: Vec<i64> = Vec::new();
111 for &entity_id in &graph_only_entities {
112 let mut stmt = conn.prepare_cached(
113 "SELECT DISTINCT me.memory_id
114 FROM memory_entities me
115 JOIN memories m ON m.id = me.memory_id
116 WHERE me.entity_id = ?1 AND m.deleted_at IS NULL",
117 )?;
118 let mem_ids: Vec<i64> = stmt
119 .query_map(params![entity_id], |r| r.get(0))?
120 .filter_map(|r| r.ok())
121 .filter(|id| !seed_set.contains(id))
122 .collect();
123 result_ids.extend(mem_ids);
124 }
125
126 result_ids.sort_unstable();
127 result_ids.dedup();
128 Ok(result_ids)
129}
130
131pub fn traverse_from_memories_with_hops(
141 conn: &Connection,
142 seed_memory_ids: &[i64],
143 namespace: &str,
144 min_weight: f64,
145 max_hops: u32,
146) -> Result<Vec<(i64, u32)>, AppError> {
147 if seed_memory_ids.is_empty() || max_hops == 0 {
148 return Ok(vec![]);
149 }
150
151 let mut seed_entities: Vec<i64> = Vec::new();
153 for &mem_id in seed_memory_ids {
154 let mut stmt =
155 conn.prepare_cached("SELECT entity_id FROM memory_entities WHERE memory_id = ?1")?;
156 let ids: Vec<i64> = stmt
157 .query_map(params![mem_id], |r| r.get(0))?
158 .filter_map(|r| r.ok())
159 .collect();
160 seed_entities.extend(ids);
161 }
162 seed_entities.sort_unstable();
163 seed_entities.dedup();
164
165 if seed_entities.is_empty() {
166 return Ok(vec![]);
167 }
168
169 use std::collections::HashMap;
171 let mut entity_depth: HashMap<i64, u32> = seed_entities.iter().map(|&id| (id, 0)).collect();
172 let mut frontier = seed_entities.clone();
173
174 for hop in 1..=max_hops {
175 if frontier.is_empty() {
176 break;
177 }
178 let mut next_frontier = Vec::new();
179
180 for &entity_id in &frontier {
181 let mut stmt = conn.prepare_cached(
182 "SELECT target_id FROM relationships
183 WHERE source_id = ?1 AND weight >= ?2 AND namespace = ?3",
184 )?;
185 let neighbors: Vec<i64> = stmt
186 .query_map(params![entity_id, min_weight, namespace], |r| r.get(0))?
187 .filter_map(|r| r.ok())
188 .filter(|id| !entity_depth.contains_key(id))
189 .collect();
190
191 for id in neighbors {
192 entity_depth.insert(id, hop);
193 next_frontier.push(id);
194 }
195 }
196 frontier = next_frontier;
197 }
198
199 let seed_set: std::collections::HashSet<i64> = seed_memory_ids.iter().cloned().collect();
201 let seed_entity_set: std::collections::HashSet<i64> = seed_entities.iter().cloned().collect();
202
203 let mut result: Vec<(i64, u32)> = Vec::new();
204 let mut seen_memories: std::collections::HashSet<i64> = std::collections::HashSet::new();
205
206 for (&entity_id, &hop) in &entity_depth {
207 if seed_entity_set.contains(&entity_id) {
208 continue;
209 }
210 let mut stmt = conn.prepare_cached(
211 "SELECT DISTINCT me.memory_id
212 FROM memory_entities me
213 JOIN memories m ON m.id = me.memory_id
214 WHERE me.entity_id = ?1 AND m.deleted_at IS NULL",
215 )?;
216 let mem_ids: Vec<i64> = stmt
217 .query_map(params![entity_id], |r| r.get(0))?
218 .filter_map(|r| r.ok())
219 .filter(|id| !seed_set.contains(id) && !seen_memories.contains(id))
220 .collect();
221
222 for mem_id in mem_ids {
223 seen_memories.insert(mem_id);
224 result.push((mem_id, hop));
225 }
226 }
227
228 result.sort_unstable_by_key(|&(id, _)| id);
229 Ok(result)
230}
231
#[cfg(test)]
mod tests {
    use super::*;
    use rusqlite::Connection;

    /// Creates an in-memory database with the minimal schema the traversal queries touch.
    fn setup_db() -> Connection {
        let conn = Connection::open_in_memory().unwrap();
        conn.execute_batch(
            "CREATE TABLE memories (
                id INTEGER PRIMARY KEY,
                namespace TEXT NOT NULL,
                deleted_at TEXT
            );
            CREATE TABLE memory_entities (
                memory_id INTEGER NOT NULL,
                entity_id INTEGER NOT NULL
            );
            CREATE TABLE relationships (
                source_id INTEGER NOT NULL,
                target_id INTEGER NOT NULL,
                weight REAL NOT NULL,
                namespace TEXT NOT NULL
            );",
        )
        .unwrap();
        conn
    }

    /// Inserts a memory row; `deleted` sets a non-NULL `deleted_at`.
    fn insert_memory(conn: &Connection, id: i64, namespace: &str, deleted: bool) {
        conn.execute(
            "INSERT INTO memories (id, namespace, deleted_at) VALUES (?1, ?2, ?3)",
            params![
                id,
                namespace,
                if deleted { Some("2024-01-01") } else { None }
            ],
        )
        .unwrap();
    }

    /// Links a memory to an entity.
    fn link_memory_entity(conn: &Connection, memory_id: i64, entity_id: i64) {
        conn.execute(
            "INSERT INTO memory_entities (memory_id, entity_id) VALUES (?1, ?2)",
            params![memory_id, entity_id],
        )
        .unwrap();
    }

    /// Inserts a directed, weighted relationship `src -> tgt` in namespace `ns`.
    fn insert_relationship(conn: &Connection, src: i64, tgt: i64, weight: f64, ns: &str) {
        conn.execute(
            "INSERT INTO relationships (source_id, target_id, weight, namespace) VALUES (?1, ?2, ?3, ?4)",
            params![src, tgt, weight, ns],
        )
        .unwrap();
    }

    #[test]
    fn returns_empty_when_seeds_empty() {
        let conn = setup_db();
        let result = traverse_from_memories(&conn, &[], "ns", 0.5, 3).unwrap();
        assert!(result.is_empty());
    }

    #[test]
    fn returns_empty_when_max_hops_zero() {
        let conn = setup_db();
        insert_memory(&conn, 1, "ns", false);
        link_memory_entity(&conn, 1, 10);
        let result = traverse_from_memories(&conn, &[1], "ns", 0.5, 0).unwrap();
        assert!(result.is_empty());
    }

    #[test]
    fn returns_empty_when_seed_has_no_entities() {
        let conn = setup_db();
        insert_memory(&conn, 1, "ns", false);
        let result = traverse_from_memories(&conn, &[1], "ns", 0.5, 3).unwrap();
        assert!(result.is_empty());
    }

    #[test]
    fn returns_empty_when_no_relationships() {
        let conn = setup_db();
        insert_memory(&conn, 1, "ns", false);
        link_memory_entity(&conn, 1, 10);
        let result = traverse_from_memories(&conn, &[1], "ns", 0.5, 3).unwrap();
        assert!(result.is_empty());
    }

    #[test]
    fn traversal_basic_one_hop() {
        let conn = setup_db();

        insert_memory(&conn, 1, "ns", false);
        link_memory_entity(&conn, 1, 10);

        insert_memory(&conn, 2, "ns", false);
        link_memory_entity(&conn, 2, 20);

        insert_relationship(&conn, 10, 20, 1.0, "ns");

        let result = traverse_from_memories(&conn, &[1], "ns", 0.5, 1).unwrap();
        assert_eq!(result, vec![2]);
    }

    #[test]
    fn traversal_two_hops() {
        let conn = setup_db();

        insert_memory(&conn, 1, "ns", false);
        link_memory_entity(&conn, 1, 10);

        insert_memory(&conn, 2, "ns", false);
        link_memory_entity(&conn, 2, 20);

        insert_memory(&conn, 3, "ns", false);
        link_memory_entity(&conn, 3, 30);

        insert_relationship(&conn, 10, 20, 1.0, "ns");
        insert_relationship(&conn, 20, 30, 1.0, "ns");

        let mut result = traverse_from_memories(&conn, &[1], "ns", 0.5, 2).unwrap();
        result.sort_unstable();
        assert_eq!(result, vec![2, 3]);
    }

    #[test]
    fn max_hops_limits_depth() {
        let conn = setup_db();

        insert_memory(&conn, 1, "ns", false);
        link_memory_entity(&conn, 1, 10);

        insert_memory(&conn, 2, "ns", false);
        link_memory_entity(&conn, 2, 20);

        insert_memory(&conn, 3, "ns", false);
        link_memory_entity(&conn, 3, 30);

        insert_relationship(&conn, 10, 20, 1.0, "ns");
        insert_relationship(&conn, 20, 30, 1.0, "ns");

        let result = traverse_from_memories(&conn, &[1], "ns", 0.5, 1).unwrap();
        assert_eq!(result, vec![2]);
        assert!(!result.contains(&3));
    }

    #[test]
    fn relationship_with_weight_below_min_ignored() {
        let conn = setup_db();

        insert_memory(&conn, 1, "ns", false);
        link_memory_entity(&conn, 1, 10);

        insert_memory(&conn, 2, "ns", false);
        link_memory_entity(&conn, 2, 20);

        insert_relationship(&conn, 10, 20, 0.3, "ns");

        let result = traverse_from_memories(&conn, &[1], "ns", 0.5, 3).unwrap();
        assert!(result.is_empty());
    }

    #[test]
    fn relationship_with_weight_exactly_at_min_included() {
        let conn = setup_db();

        insert_memory(&conn, 1, "ns", false);
        link_memory_entity(&conn, 1, 10);

        insert_memory(&conn, 2, "ns", false);
        link_memory_entity(&conn, 2, 20);

        insert_relationship(&conn, 10, 20, 0.5, "ns");

        let result = traverse_from_memories(&conn, &[1], "ns", 0.5, 1).unwrap();
        assert_eq!(result, vec![2]);
    }

    #[test]
    fn relationship_from_different_namespace_ignored() {
        let conn = setup_db();

        insert_memory(&conn, 1, "ns_a", false);
        link_memory_entity(&conn, 1, 10);

        insert_memory(&conn, 2, "ns_a", false);
        link_memory_entity(&conn, 2, 20);

        insert_relationship(&conn, 10, 20, 1.0, "ns_b");

        let result = traverse_from_memories(&conn, &[1], "ns_a", 0.5, 3).unwrap();
        assert!(result.is_empty());
    }

    #[test]
    fn seeds_do_not_appear_in_result() {
        let conn = setup_db();

        insert_memory(&conn, 1, "ns", false);
        link_memory_entity(&conn, 1, 10);

        insert_memory(&conn, 2, "ns", false);
        link_memory_entity(&conn, 2, 20);

        insert_relationship(&conn, 10, 20, 1.0, "ns");
        insert_relationship(&conn, 20, 10, 1.0, "ns");

        let result = traverse_from_memories(&conn, &[1], "ns", 0.5, 3).unwrap();
        assert!(!result.contains(&1));
        assert_eq!(result, vec![2]);
    }

    #[test]
    fn deleted_memories_not_included() {
        let conn = setup_db();

        insert_memory(&conn, 1, "ns", false);
        link_memory_entity(&conn, 1, 10);

        insert_memory(&conn, 2, "ns", true);
        link_memory_entity(&conn, 2, 20);

        insert_relationship(&conn, 10, 20, 1.0, "ns");

        let result = traverse_from_memories(&conn, &[1], "ns", 0.5, 3).unwrap();
        assert!(result.is_empty());
    }

    #[test]
    fn multiple_seeds_merged_in_result() {
        let conn = setup_db();

        insert_memory(&conn, 1, "ns", false);
        link_memory_entity(&conn, 1, 10);

        insert_memory(&conn, 2, "ns", false);
        link_memory_entity(&conn, 2, 20);

        insert_memory(&conn, 3, "ns", false);
        link_memory_entity(&conn, 3, 30);

        insert_memory(&conn, 4, "ns", false);
        link_memory_entity(&conn, 4, 40);

        insert_relationship(&conn, 10, 30, 1.0, "ns");
        insert_relationship(&conn, 20, 40, 1.0, "ns");

        let mut result = traverse_from_memories(&conn, &[1, 2], "ns", 0.5, 1).unwrap();
        result.sort_unstable();
        assert_eq!(result, vec![3, 4]);
    }

    #[test]
    fn result_without_duplicates() {
        let conn = setup_db();

        insert_memory(&conn, 1, "ns", false);
        link_memory_entity(&conn, 1, 10);
        link_memory_entity(&conn, 1, 11);

        insert_memory(&conn, 2, "ns", false);
        link_memory_entity(&conn, 2, 20);

        insert_relationship(&conn, 10, 20, 1.0, "ns");
        insert_relationship(&conn, 11, 20, 1.0, "ns");

        let result = traverse_from_memories(&conn, &[1], "ns", 0.5, 1).unwrap();
        assert_eq!(result.len(), 1);
        assert_eq!(result, vec![2]);
    }

    #[test]
    fn duplicate_relationship_rows_do_not_duplicate_results() {
        let conn = setup_db();

        insert_memory(&conn, 1, "ns", false);
        link_memory_entity(&conn, 1, 10);

        insert_memory(&conn, 2, "ns", false);
        link_memory_entity(&conn, 2, 20);

        insert_relationship(&conn, 10, 20, 1.0, "ns");
        insert_relationship(&conn, 10, 20, 1.0, "ns");

        let result = traverse_from_memories(&conn, &[1], "ns", 0.5, 2).unwrap();
        assert_eq!(result, vec![2]);
    }

    #[test]
    fn single_node_without_neighbors_returns_empty() {
        let conn = setup_db();

        insert_memory(&conn, 1, "ns", false);
        link_memory_entity(&conn, 1, 10);
        let result = traverse_from_memories(&conn, &[1], "ns", 0.5, 5).unwrap();
        assert!(result.is_empty());
    }

    #[test]
    fn cycle_does_not_cause_infinite_loop() {
        let conn = setup_db();

        insert_memory(&conn, 1, "ns", false);
        link_memory_entity(&conn, 1, 10);

        insert_memory(&conn, 2, "ns", false);
        link_memory_entity(&conn, 2, 20);

        insert_memory(&conn, 3, "ns", false);
        link_memory_entity(&conn, 3, 30);

        insert_relationship(&conn, 10, 20, 1.0, "ns");
        insert_relationship(&conn, 20, 30, 1.0, "ns");
        insert_relationship(&conn, 30, 10, 1.0, "ns");

        let mut result = traverse_from_memories(&conn, &[1], "ns", 0.5, 10).unwrap();
        result.sort_unstable();
        assert_eq!(result, vec![2, 3]);
    }

    #[test]
    fn with_hops_returns_empty_for_trivial_inputs() {
        let conn = setup_db();
        insert_memory(&conn, 1, "ns", false);
        link_memory_entity(&conn, 1, 10);

        assert!(traverse_from_memories_with_hops(&conn, &[], "ns", 0.5, 3)
            .unwrap()
            .is_empty());
        assert!(traverse_from_memories_with_hops(&conn, &[1], "ns", 0.5, 0)
            .unwrap()
            .is_empty());
        assert!(traverse_from_memories_with_hops(&conn, &[1], "ns", 0.5, 3)
            .unwrap()
            .is_empty());
    }

    #[test]
    fn with_hops_reports_hop_distance() {
        let conn = setup_db();

        insert_memory(&conn, 1, "ns", false);
        link_memory_entity(&conn, 1, 10);

        insert_memory(&conn, 2, "ns", false);
        link_memory_entity(&conn, 2, 20);

        insert_memory(&conn, 3, "ns", false);
        link_memory_entity(&conn, 3, 30);

        insert_relationship(&conn, 10, 20, 1.0, "ns");
        insert_relationship(&conn, 20, 30, 1.0, "ns");

        let result = traverse_from_memories_with_hops(&conn, &[1], "ns", 0.5, 3).unwrap();
        assert_eq!(result, vec![(2, 1), (3, 2)]);

        let limited = traverse_from_memories_with_hops(&conn, &[1], "ns", 0.5, 1).unwrap();
        assert_eq!(limited, vec![(2, 1)]);
    }

    #[test]
    fn with_hops_excludes_seeds_and_deleted() {
        let conn = setup_db();

        insert_memory(&conn, 1, "ns", false);
        link_memory_entity(&conn, 1, 10);

        insert_memory(&conn, 2, "ns", true);
        link_memory_entity(&conn, 2, 20);

        insert_memory(&conn, 3, "ns", false);
        link_memory_entity(&conn, 3, 20);

        insert_relationship(&conn, 10, 20, 1.0, "ns");
        insert_relationship(&conn, 20, 10, 1.0, "ns");

        let result = traverse_from_memories_with_hops(&conn, &[1], "ns", 0.5, 3).unwrap();
        assert_eq!(result, vec![(3, 1)]);
    }

    #[test]
    fn with_hops_memory_linked_to_several_entities_listed_once() {
        let conn = setup_db();

        insert_memory(&conn, 1, "ns", false);
        link_memory_entity(&conn, 1, 10);

        insert_memory(&conn, 4, "ns", false);
        link_memory_entity(&conn, 4, 20);
        link_memory_entity(&conn, 4, 30);

        insert_relationship(&conn, 10, 20, 1.0, "ns");
        insert_relationship(&conn, 10, 30, 1.0, "ns");

        let result = traverse_from_memories_with_hops(&conn, &[1], "ns", 0.5, 2).unwrap();
        assert_eq!(result, vec![(4, 1)]);
    }
}