1use rusqlite::Connection;
8use serde::{Deserialize, Serialize};
9use std::collections::{HashMap, HashSet, VecDeque};
10
11use crate::error::Result;
12use crate::types::{CrossReference, EdgeType, MemoryId, RelationSource};
13use chrono::{DateTime, Utc};
14
15#[derive(Debug, Clone, Serialize, Deserialize)]
17pub struct TraversalOptions {
18 #[serde(default = "default_depth")]
20 pub depth: usize,
21 #[serde(default)]
23 pub edge_types: Vec<EdgeType>,
24 #[serde(default)]
26 pub min_score: f32,
27 #[serde(default)]
29 pub min_confidence: f32,
30 #[serde(default = "default_limit_per_hop")]
32 pub limit_per_hop: usize,
33 #[serde(default = "default_include_entities")]
35 pub include_entities: bool,
36 #[serde(default)]
38 pub direction: TraversalDirection,
39}
40
/// Serde fallback for `TraversalOptions::depth` when the field is omitted.
fn default_depth() -> usize {
    const DEFAULT_DEPTH: usize = 2;
    DEFAULT_DEPTH
}
44
/// Serde fallback for `TraversalOptions::limit_per_hop` when the field is omitted.
fn default_limit_per_hop() -> usize {
    const DEFAULT_LIMIT_PER_HOP: usize = 50;
    DEFAULT_LIMIT_PER_HOP
}
48
/// Serde fallback for `TraversalOptions::include_entities` when the field is omitted.
fn default_include_entities() -> bool {
    const INCLUDE_ENTITIES_BY_DEFAULT: bool = true;
    INCLUDE_ENTITIES_BY_DEFAULT
}
52
53impl Default for TraversalOptions {
54 fn default() -> Self {
55 Self {
56 depth: 2,
57 edge_types: vec![],
58 min_score: 0.0,
59 min_confidence: 0.0,
60 limit_per_hop: 50,
61 include_entities: true,
62 direction: TraversalDirection::Both,
63 }
64 }
65}
66
67#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
69#[serde(rename_all = "lowercase")]
70pub enum TraversalDirection {
71 Outgoing,
73 Incoming,
75 #[default]
77 Both,
78}
79
80#[derive(Debug, Clone, Serialize, Deserialize)]
82pub struct TraversalNode {
83 pub memory_id: MemoryId,
85 pub depth: usize,
87 pub path: Vec<MemoryId>,
89 pub edge_path: Vec<String>,
91 pub cumulative_score: f32,
93 pub connection_type: ConnectionType,
95}
96
97#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
99#[serde(rename_all = "snake_case")]
100pub enum ConnectionType {
101 Origin,
103 CrossReference,
105 SharedEntity { entity_name: String },
107}
108
109#[derive(Debug, Clone, Serialize, Deserialize)]
111pub struct TraversalResult {
112 pub start_id: MemoryId,
114 pub nodes: Vec<TraversalNode>,
116 pub discovery_edges: Vec<CrossReference>,
118 pub stats: TraversalStats,
120}
121
122#[derive(Debug, Clone, Serialize, Deserialize, Default)]
123pub struct TraversalStats {
124 pub nodes_visited: usize,
126 pub nodes_per_depth: HashMap<usize, usize>,
128 pub connection_type_counts: HashMap<String, usize>,
130 pub max_depth_reached: usize,
132}
133
134pub fn get_related_multi_hop(
136 conn: &Connection,
137 start_id: MemoryId,
138 options: &TraversalOptions,
139) -> Result<TraversalResult> {
140 let mut visited: HashSet<MemoryId> = HashSet::new();
141 let mut nodes: Vec<TraversalNode> = Vec::new();
142 let mut discovery_edges: Vec<CrossReference> = Vec::new();
143 let mut stats = TraversalStats::default();
144
145 let mut queue: VecDeque<(MemoryId, usize, Vec<MemoryId>, Vec<String>, f32)> = VecDeque::new();
147
148 visited.insert(start_id);
150 nodes.push(TraversalNode {
151 memory_id: start_id,
152 depth: 0,
153 path: vec![start_id],
154 edge_path: vec![],
155 cumulative_score: 1.0,
156 connection_type: ConnectionType::Origin,
157 });
158 queue.push_back((start_id, 0, vec![start_id], vec![], 1.0));
159
160 *stats.nodes_per_depth.entry(0).or_insert(0) += 1;
161 *stats
162 .connection_type_counts
163 .entry("origin".to_string())
164 .or_insert(0) += 1;
165
166 while !queue.is_empty() {
168 let level_size = queue.len();
169 let mut current_batch = Vec::with_capacity(level_size);
170 for _ in 0..level_size {
171 if let Some(item) = queue.pop_front() {
172 current_batch.push(item);
173 }
174 }
175
176 if current_batch.is_empty() {
177 break;
178 }
179
180 let current_depth = current_batch[0].1;
182
183 if current_depth >= options.depth {
184 continue;
185 }
186
187 let node_ids: Vec<MemoryId> = current_batch.iter().map(|(id, _, _, _, _)| *id).collect();
188
189 let crossrefs_map = get_edges_for_traversal_batch(
191 conn,
192 &node_ids,
193 &options.edge_types,
194 options.min_score,
195 options.min_confidence,
196 options.direction,
197 options.limit_per_hop,
198 )?;
199
200 let entity_connections_map = if options.include_entities {
202 get_entity_connections_batch(conn, &node_ids, options.limit_per_hop)?
203 } else {
204 HashMap::new()
205 };
206
207 for (current_id, _current_depth, current_path, current_edge_path, current_score) in
209 current_batch
210 {
211 if let Some(crossrefs) = crossrefs_map.get(¤t_id) {
213 for crossref in crossrefs.iter() {
214 let neighbor_id = if crossref.from_id == current_id {
216 crossref.to_id
217 } else {
218 crossref.from_id
219 };
220
221 if visited.contains(&neighbor_id) {
222 continue;
223 }
224
225 visited.insert(neighbor_id);
226
227 let mut new_path = current_path.clone();
228 new_path.push(neighbor_id);
229
230 let mut new_edge_path = current_edge_path.clone();
231 new_edge_path.push(crossref.edge_type.as_str().to_string());
232
233 let new_score = current_score * crossref.score * crossref.confidence;
234 let new_depth = current_depth + 1;
235
236 nodes.push(TraversalNode {
237 memory_id: neighbor_id,
238 depth: new_depth,
239 path: new_path.clone(),
240 edge_path: new_edge_path.clone(),
241 cumulative_score: new_score,
242 connection_type: ConnectionType::CrossReference,
243 });
244
245 discovery_edges.push(crossref.clone());
246
247 *stats.nodes_per_depth.entry(new_depth).or_insert(0) += 1;
248 *stats
249 .connection_type_counts
250 .entry("cross_reference".to_string())
251 .or_insert(0) += 1;
252
253 if new_depth < options.depth {
254 queue.push_back((
255 neighbor_id,
256 new_depth,
257 new_path,
258 new_edge_path,
259 new_score,
260 ));
261 }
262
263 stats.max_depth_reached = stats.max_depth_reached.max(new_depth);
264 }
265 }
266
267 if let Some(entity_connections) = entity_connections_map.get(¤t_id) {
269 for (neighbor_id, entity_name) in
270 entity_connections.iter().take(options.limit_per_hop)
271 {
272 let neighbor_id = *neighbor_id;
273 if visited.contains(&neighbor_id) {
274 continue;
275 }
276
277 visited.insert(neighbor_id);
278
279 let mut new_path = current_path.clone();
280 new_path.push(neighbor_id);
281
282 let mut new_edge_path = current_edge_path.clone();
283 new_edge_path.push(format!("entity:{}", entity_name));
284
285 let new_depth = current_depth + 1;
286 let new_score = current_score * 0.5;
288
289 nodes.push(TraversalNode {
290 memory_id: neighbor_id,
291 depth: new_depth,
292 path: new_path.clone(),
293 edge_path: new_edge_path.clone(),
294 cumulative_score: new_score,
295 connection_type: ConnectionType::SharedEntity {
296 entity_name: entity_name.clone(),
297 },
298 });
299
300 *stats.nodes_per_depth.entry(new_depth).or_insert(0) += 1;
301 *stats
302 .connection_type_counts
303 .entry("shared_entity".to_string())
304 .or_insert(0) += 1;
305
306 if new_depth < options.depth {
307 queue.push_back((
308 neighbor_id,
309 new_depth,
310 new_path,
311 new_edge_path,
312 new_score,
313 ));
314 }
315
316 stats.max_depth_reached = stats.max_depth_reached.max(new_depth);
317 }
318 }
319 }
320 }
321
322 stats.nodes_visited = nodes.len();
323
324 Ok(TraversalResult {
325 start_id,
326 nodes,
327 discovery_edges,
328 stats,
329 })
330}
331
332fn get_edges_for_traversal_batch(
337 conn: &Connection,
338 memory_ids: &[MemoryId],
339 edge_types: &[EdgeType],
340 min_score: f32,
341 min_confidence: f32,
342 direction: TraversalDirection,
343 limit_per_node: usize,
344) -> Result<HashMap<MemoryId, Vec<CrossReference>>> {
345 if memory_ids.is_empty() {
346 return Ok(HashMap::new());
347 }
348
349 let mut result: HashMap<MemoryId, Vec<CrossReference>> = HashMap::new();
350 let id_set: HashSet<MemoryId> = memory_ids.iter().cloned().collect();
351
352 for chunk in memory_ids.chunks(100) {
354 let placeholders = chunk.iter().map(|_| "?").collect::<Vec<_>>().join(", ");
355
356 let edge_type_clause = if edge_types.is_empty() {
357 String::new()
358 } else {
359 let types: Vec<String> = edge_types
360 .iter()
361 .map(|e| format!("'{}'", e.as_str()))
362 .collect();
363 format!(" AND edge_type IN ({})", types.join(", "))
364 };
365
366 let (partition_col, filter_clause) = match direction {
368 TraversalDirection::Outgoing => ("from_id", format!("from_id IN ({})", placeholders)),
369 TraversalDirection::Incoming => ("to_id", format!("to_id IN ({})", placeholders)),
370 TraversalDirection::Both => {
371 let query = format!(
374 r#"
375 WITH ranked_edges AS (
376 SELECT *, ROW_NUMBER() OVER (
377 PARTITION BY from_id ORDER BY score * confidence DESC
378 ) as rn
379 FROM crossrefs
380 WHERE from_id IN ({placeholders}) AND valid_to IS NULL
381 AND score >= ? AND confidence >= ?
382 {edge_type_clause}
383 UNION ALL
384 SELECT *, ROW_NUMBER() OVER (
385 PARTITION BY to_id ORDER BY score * confidence DESC
386 ) as rn
387 FROM crossrefs
388 WHERE to_id IN ({placeholders}) AND from_id NOT IN ({placeholders}) AND valid_to IS NULL
389 AND score >= ? AND confidence >= ?
390 {edge_type_clause}
391 )
392 SELECT from_id, to_id, edge_type, score, confidence, strength, source,
393 source_context, created_at, valid_from, valid_to, pinned, metadata
394 FROM ranked_edges
395 WHERE rn <= ?
396 "#,
397 placeholders = placeholders,
398 edge_type_clause = edge_type_clause,
399 );
400
401 let mut stmt = conn.prepare(&query)?;
402 let mut params: Vec<Box<dyn rusqlite::ToSql>> = Vec::new();
403
404 for id in chunk {
406 params.push(Box::new(*id));
407 }
408 params.push(Box::new(min_score));
409 params.push(Box::new(min_confidence));
410
411 for id in chunk {
413 params.push(Box::new(*id));
414 }
415 for id in chunk {
416 params.push(Box::new(*id));
417 }
418 params.push(Box::new(min_score));
419 params.push(Box::new(min_confidence));
420
421 params.push(Box::new(limit_per_node as i64));
423
424 let param_refs: Vec<&dyn rusqlite::ToSql> =
425 params.iter().map(|p| p.as_ref()).collect();
426
427 let crossrefs = stmt
428 .query_map(param_refs.as_slice(), crossref_from_row)?
429 .filter_map(|r| r.ok());
430
431 for crossref in crossrefs {
432 if id_set.contains(&crossref.from_id) {
433 result
434 .entry(crossref.from_id)
435 .or_default()
436 .push(crossref.clone());
437 }
438 if id_set.contains(&crossref.to_id) && crossref.from_id != crossref.to_id {
439 result.entry(crossref.to_id).or_default().push(crossref);
440 }
441 }
442
443 continue; }
445 };
446
447 let query = format!(
449 r#"
450 WITH ranked_edges AS (
451 SELECT *, ROW_NUMBER() OVER (
452 PARTITION BY {partition_col} ORDER BY score * confidence DESC
453 ) as rn
454 FROM crossrefs
455 WHERE {filter_clause} AND valid_to IS NULL
456 AND score >= ? AND confidence >= ?
457 {edge_type_clause}
458 )
459 SELECT from_id, to_id, edge_type, score, confidence, strength, source,
460 source_context, created_at, valid_from, valid_to, pinned, metadata
461 FROM ranked_edges
462 WHERE rn <= ?
463 "#,
464 partition_col = partition_col,
465 filter_clause = filter_clause,
466 edge_type_clause = edge_type_clause,
467 );
468
469 let mut stmt = conn.prepare(&query)?;
470
471 let mut params: Vec<Box<dyn rusqlite::ToSql>> = Vec::new();
472 for id in chunk {
473 params.push(Box::new(*id));
474 }
475 params.push(Box::new(min_score));
476 params.push(Box::new(min_confidence));
477 params.push(Box::new(limit_per_node as i64));
478
479 let param_refs: Vec<&dyn rusqlite::ToSql> = params.iter().map(|p| p.as_ref()).collect();
480
481 let crossrefs = stmt
482 .query_map(param_refs.as_slice(), crossref_from_row)?
483 .filter_map(|r| r.ok());
484
485 for crossref in crossrefs {
486 match direction {
487 TraversalDirection::Outgoing => {
488 if id_set.contains(&crossref.from_id) {
489 result.entry(crossref.from_id).or_default().push(crossref);
490 }
491 }
492 TraversalDirection::Incoming => {
493 if id_set.contains(&crossref.to_id) {
494 result.entry(crossref.to_id).or_default().push(crossref);
495 }
496 }
497 TraversalDirection::Both => unreachable!(), }
499 }
500 }
501
502 Ok(result)
503}
504
505fn get_entity_connections_batch(
507 conn: &Connection,
508 memory_ids: &[MemoryId],
509 _limit: usize,
510) -> Result<HashMap<MemoryId, Vec<(MemoryId, String)>>> {
511 if memory_ids.is_empty() {
512 return Ok(HashMap::new());
513 }
514
515 let mut result: HashMap<MemoryId, Vec<(MemoryId, String)>> = HashMap::new();
516 let id_set: HashSet<MemoryId> = memory_ids.iter().cloned().collect();
517
518 for chunk in memory_ids.chunks(100) {
519 let placeholders = chunk.iter().map(|_| "?").collect::<Vec<_>>().join(", ");
520
521 let query = format!(
522 r#"
523 SELECT DISTINCT me1.memory_id, me2.memory_id, e.name
524 FROM memory_entities me1
525 JOIN memory_entities me2 ON me1.entity_id = me2.entity_id
526 JOIN entities e ON me1.entity_id = e.id
527 WHERE me1.memory_id IN ({}) AND me2.memory_id != me1.memory_id
528 ORDER BY e.mention_count DESC
529 "#,
530 placeholders
531 );
532
533 let mut stmt = conn.prepare(&query)?;
534
535 let mut params: Vec<Box<dyn rusqlite::ToSql>> = Vec::new();
536 for id in chunk {
537 params.push(Box::new(*id));
538 }
539
540 let param_refs: Vec<&dyn rusqlite::ToSql> = params.iter().map(|p| p.as_ref()).collect();
541
542 let rows = stmt
543 .query_map(param_refs.as_slice(), |row| {
544 Ok((
545 row.get::<_, i64>(0)?,
546 row.get::<_, i64>(1)?,
547 row.get::<_, String>(2)?,
548 ))
549 })?
550 .filter_map(|r| r.ok());
551
552 for (source_id, target_id, entity_name) in rows {
553 if id_set.contains(&source_id) {
554 result
555 .entry(source_id)
556 .or_default()
557 .push((target_id, entity_name));
558 }
559 }
560 }
561
562 Ok(result)
563}
564
565fn crossref_from_row(row: &rusqlite::Row) -> rusqlite::Result<CrossReference> {
567 let edge_type_str: String = row.get("edge_type")?;
568 let source_str: String = row.get("source")?;
569 let created_at_str: String = row.get("created_at")?;
570 let valid_from_str: String = row.get("valid_from")?;
571 let valid_to_str: Option<String> = row.get("valid_to")?;
572 let metadata_str: String = row.get("metadata")?;
573
574 Ok(CrossReference {
575 from_id: row.get("from_id")?,
576 to_id: row.get("to_id")?,
577 edge_type: edge_type_str.parse().unwrap_or(EdgeType::RelatedTo),
578 score: row.get("score")?,
579 confidence: row.get("confidence")?,
580 strength: row.get("strength")?,
581 source: match source_str.as_str() {
582 "manual" => RelationSource::Manual,
583 "llm" => RelationSource::Llm,
584 _ => RelationSource::Auto,
585 },
586 source_context: row.get("source_context")?,
587 created_at: DateTime::parse_from_rfc3339(&created_at_str)
588 .map(|dt| dt.with_timezone(&Utc))
589 .unwrap_or_else(|_| Utc::now()),
590 valid_from: DateTime::parse_from_rfc3339(&valid_from_str)
591 .map(|dt| dt.with_timezone(&Utc))
592 .unwrap_or_else(|_| Utc::now()),
593 valid_to: valid_to_str.and_then(|s| {
594 DateTime::parse_from_rfc3339(&s)
595 .map(|dt| dt.with_timezone(&Utc))
596 .ok()
597 }),
598 pinned: row.get::<_, i32>("pinned")? != 0,
599 metadata: serde_json::from_str(&metadata_str).unwrap_or_default(),
600 })
601}
602
603pub fn find_path(
605 conn: &Connection,
606 from_id: MemoryId,
607 to_id: MemoryId,
608 max_depth: usize,
609) -> Result<Option<TraversalNode>> {
610 let options = TraversalOptions {
611 depth: max_depth,
612 include_entities: true,
613 ..Default::default()
614 };
615
616 let result = get_related_multi_hop(conn, from_id, &options)?;
617
618 Ok(result.nodes.into_iter().find(|n| n.memory_id == to_id))
620}
621
622pub fn get_neighborhood(
624 conn: &Connection,
625 center_id: MemoryId,
626 radius: usize,
627) -> Result<Vec<MemoryId>> {
628 let options = TraversalOptions {
629 depth: radius,
630 include_entities: true,
631 ..Default::default()
632 };
633
634 let result = get_related_multi_hop(conn, center_id, &options)?;
635
636 Ok(result.nodes.into_iter().map(|n| n.memory_id).collect())
637}
638
#[cfg(test)]
mod tests {
    use super::*;
    use crate::intelligence::entities::{EntityRelation, EntityType, ExtractedEntity};
    use crate::storage::entity_queries::{link_entity_to_memory, upsert_entity};
    use crate::storage::queries::{create_crossref, create_memory};
    use crate::storage::Storage;
    use crate::types::{CreateCrossRefInput, CreateMemoryInput, MemoryType};

    /// Stores a note-type memory with the given content and returns its id.
    fn create_test_memory(conn: &Connection, content: &str) -> MemoryId {
        let new_memory = CreateMemoryInput {
            content: content.to_string(),
            memory_type: MemoryType::Note,
            tags: vec![],
            importance: None,
            metadata: Default::default(),
            scope: Default::default(),
            workspace: None,
            tier: Default::default(),
            defer_embedding: false,
            ttl_seconds: None,
            dedup_mode: Default::default(),
            dedup_threshold: None,
            event_time: None,
            event_duration_seconds: None,
            trigger_pattern: None,
            summary_of_id: None,
        };
        create_memory(conn, &new_memory).unwrap().id
    }

    /// Stores an unpinned cross-reference `from_id -> to_id` of the given type.
    fn create_test_crossref(
        conn: &Connection,
        from_id: MemoryId,
        to_id: MemoryId,
        edge_type: EdgeType,
    ) -> crate::error::Result<()> {
        let new_edge = CreateCrossRefInput {
            from_id,
            to_id,
            edge_type,
            strength: None,
            source_context: None,
            pinned: false,
        };
        create_crossref(conn, &new_edge)?;
        Ok(())
    }

    /// Options that follow cross-references only, up to `depth` hops.
    fn crossrefs_only(depth: usize) -> TraversalOptions {
        TraversalOptions {
            depth,
            include_entities: false,
            ..Default::default()
        }
    }

    /// Whether the traversal result contains a node for `id`.
    fn reached(result: &TraversalResult, id: MemoryId) -> bool {
        result.nodes.iter().any(|node| node.memory_id == id)
    }

    #[test]
    fn test_multi_hop_traversal() {
        let storage = Storage::open_in_memory().unwrap();
        storage
            .with_transaction(|conn| {
                let id_a = create_test_memory(conn, "Memory A");
                let id_b = create_test_memory(conn, "Memory B");
                let id_c = create_test_memory(conn, "Memory C");
                let id_d = create_test_memory(conn, "Memory D");

                // Chain: A -> B -> C -> D
                create_test_crossref(conn, id_a, id_b, EdgeType::RelatedTo)?;
                create_test_crossref(conn, id_b, id_c, EdgeType::RelatedTo)?;
                create_test_crossref(conn, id_c, id_d, EdgeType::RelatedTo)?;

                // One hop reaches A and B.
                let one_hop = get_related_multi_hop(conn, id_a, &crossrefs_only(1))?;
                assert_eq!(one_hop.nodes.len(), 2);
                assert!(reached(&one_hop, id_a));
                assert!(reached(&one_hop, id_b));

                // Two hops additionally reach C.
                let two_hops = get_related_multi_hop(conn, id_a, &crossrefs_only(2))?;
                assert_eq!(two_hops.nodes.len(), 3);
                assert!(reached(&two_hops, id_c));

                // Three hops cover the whole chain.
                let three_hops = get_related_multi_hop(conn, id_a, &crossrefs_only(3))?;
                assert_eq!(three_hops.nodes.len(), 4);

                Ok(())
            })
            .unwrap();
    }

    #[test]
    fn test_entity_based_connections() {
        let storage = Storage::open_in_memory().unwrap();
        storage
            .with_transaction(|conn| {
                let id_a = create_test_memory(conn, "Memory about Rust programming");
                let id_b = create_test_memory(conn, "Another memory about Rust");
                let id_c = create_test_memory(conn, "Memory about Python");

                // A and B are both linked to the "Rust" entity; C is not.
                let rust_entity = ExtractedEntity {
                    text: "Rust".to_string(),
                    normalized: "rust".to_string(),
                    entity_type: EntityType::Concept,
                    confidence: 0.9,
                    offset: 0,
                    length: 4,
                    suggested_relation: EntityRelation::Mentions,
                };
                let entity_id = upsert_entity(conn, &rust_entity)?;
                let _ = link_entity_to_memory(
                    conn,
                    id_a,
                    entity_id,
                    EntityRelation::Mentions,
                    0.9,
                    None,
                )?;
                let _ = link_entity_to_memory(
                    conn,
                    id_b,
                    entity_id,
                    EntityRelation::Mentions,
                    0.8,
                    None,
                )?;

                let options = TraversalOptions {
                    depth: 1,
                    include_entities: true,
                    ..Default::default()
                };
                let result = get_related_multi_hop(conn, id_a, &options)?;

                // B is reached through the shared entity.
                assert!(reached(&result, id_b));
                let b_node = result
                    .nodes
                    .iter()
                    .find(|node| node.memory_id == id_b)
                    .unwrap();
                assert!(matches!(
                    &b_node.connection_type,
                    ConnectionType::SharedEntity { entity_name } if entity_name == "Rust"
                ));

                // C shares nothing with A and stays out of the result.
                assert!(!reached(&result, id_c));

                Ok(())
            })
            .unwrap();
    }

    #[test]
    fn test_find_path() {
        let storage = Storage::open_in_memory().unwrap();
        storage
            .with_transaction(|conn| {
                let id_a = create_test_memory(conn, "Start");
                let id_b = create_test_memory(conn, "Middle");
                let id_c = create_test_memory(conn, "End");

                create_test_crossref(conn, id_a, id_b, EdgeType::RelatedTo)?;
                create_test_crossref(conn, id_b, id_c, EdgeType::DependsOn)?;

                let found = find_path(conn, id_a, id_c, 5)?;
                assert!(found.is_some());
                let route = found.unwrap();
                assert_eq!(route.memory_id, id_c);
                assert_eq!(route.depth, 2);
                assert_eq!(route.path.len(), 3);
                assert_eq!(route.path, vec![id_a, id_b, id_c]);

                Ok(())
            })
            .unwrap();
    }

    #[test]
    fn test_traversal_direction() {
        let storage = Storage::open_in_memory().unwrap();
        storage
            .with_transaction(|conn| {
                let id_a = create_test_memory(conn, "A");
                let id_b = create_test_memory(conn, "B");
                let id_c = create_test_memory(conn, "C");

                // Both edges point into B: A -> B and C -> B.
                create_test_crossref(conn, id_a, id_b, EdgeType::RelatedTo)?;
                create_test_crossref(conn, id_c, id_b, EdgeType::RelatedTo)?;

                // B has no outgoing edges, so only B itself is returned.
                let outgoing = TraversalOptions {
                    direction: TraversalDirection::Outgoing,
                    ..crossrefs_only(1)
                };
                let result = get_related_multi_hop(conn, id_b, &outgoing)?;
                assert_eq!(result.nodes.len(), 1);

                // Following incoming edges reaches A and C as well.
                let incoming = TraversalOptions {
                    direction: TraversalDirection::Incoming,
                    ..crossrefs_only(1)
                };
                let result = get_related_multi_hop(conn, id_b, &incoming)?;
                assert_eq!(result.nodes.len(), 3);

                Ok(())
            })
            .unwrap();
    }

    #[test]
    fn test_edge_type_filter() {
        let storage = Storage::open_in_memory().unwrap();
        storage
            .with_transaction(|conn| {
                let id_a = create_test_memory(conn, "A");
                let id_b = create_test_memory(conn, "B");
                let id_c = create_test_memory(conn, "C");

                create_test_crossref(conn, id_a, id_b, EdgeType::RelatedTo)?;
                create_test_crossref(conn, id_a, id_c, EdgeType::DependsOn)?;

                // Only `RelatedTo` edges are followed, so C is filtered out.
                let related_only = TraversalOptions {
                    edge_types: vec![EdgeType::RelatedTo],
                    ..crossrefs_only(1)
                };
                let result = get_related_multi_hop(conn, id_a, &related_only)?;
                assert_eq!(result.nodes.len(), 2);
                assert!(reached(&result, id_b));
                assert!(!reached(&result, id_c));

                Ok(())
            })
            .unwrap();
    }
}