Skip to main content

engram/storage/
graph_queries.rs

1//! Multi-hop graph traversal queries
2//!
3//! Provides graph traversal capabilities for exploring memory relationships
4//! at various depths, with support for filtering by edge type and combining
5//! with entity-based connections.
6
7use rusqlite::Connection;
8use serde::{Deserialize, Serialize};
9use std::collections::{HashMap, HashSet, VecDeque};
10
11use crate::error::Result;
12use crate::types::{CrossReference, EdgeType, MemoryId, RelationSource};
13use chrono::{DateTime, Utc};
14
/// Options for multi-hop graph traversal
///
/// Every field has a serde default, so a partial (or empty) serialized object
/// deserializes into usable options.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TraversalOptions {
    /// Maximum traversal depth (1 = direct relations only; 0 = only the start node).
    /// Defaults to 2 when omitted.
    #[serde(default = "default_depth")]
    pub depth: usize,
    /// Filter cross-reference edges by type (empty = all types).
    /// Entity-based connections are not affected by this filter.
    #[serde(default)]
    pub edge_types: Vec<EdgeType>,
    /// Minimum edge `score` (inclusive) for a cross-reference to be followed
    #[serde(default)]
    pub min_score: f32,
    /// Minimum edge `confidence` (inclusive) for a cross-reference to be followed
    #[serde(default)]
    pub min_confidence: f32,
    /// Cap on neighbours taken from each expanded node per hop; cross-reference and
    /// entity neighbours are capped separately (see `get_edges_for_traversal_batch`
    /// for how the cross-reference cap is applied per direction). Defaults to 50.
    #[serde(default = "default_limit_per_hop")]
    pub limit_per_hop: usize,
    /// Also follow connections through shared entities. Defaults to `true`.
    #[serde(default = "default_include_entities")]
    pub include_entities: bool,
    /// Direction in which cross-reference edges are followed (defaults to both).
    /// Entity connections are symmetric and ignore this setting.
    #[serde(default)]
    pub direction: TraversalDirection,
}
40
/// Serde fallback for [`TraversalOptions::depth`]: explore up to two hops.
fn default_depth() -> usize {
    2_usize
}

/// Serde fallback for [`TraversalOptions::limit_per_hop`]: fifty neighbours per node.
fn default_limit_per_hop() -> usize {
    50_usize
}

/// Serde fallback for [`TraversalOptions::include_entities`]: entity links are followed.
fn default_include_entities() -> bool {
    !false
}
52
53impl Default for TraversalOptions {
54    fn default() -> Self {
55        Self {
56            depth: 2,
57            edge_types: vec![],
58            min_score: 0.0,
59            min_confidence: 0.0,
60            limit_per_hop: 50,
61            include_entities: true,
62            direction: TraversalDirection::Both,
63        }
64    }
65}
66
/// Direction of graph traversal
///
/// Applies to cross-reference edges only. Serialized in lowercase
/// (`"outgoing"`, `"incoming"`, `"both"`); `Both` is the default.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
#[serde(rename_all = "lowercase")]
pub enum TraversalDirection {
    /// Follow outgoing edges only (from -> to)
    Outgoing,
    /// Follow incoming edges only (to -> from)
    Incoming,
    /// Follow edges in both directions
    #[default]
    Both,
}
79
/// A node in the traversal result with path information
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TraversalNode {
    /// Memory ID
    pub memory_id: MemoryId,
    /// Depth from the starting node (0 = start node)
    pub depth: usize,
    /// Path of memory IDs from start to this node, inclusive of both ends
    pub path: Vec<MemoryId>,
    /// One label per hop along `path`: the edge type string for a cross-reference
    /// hop, or `"entity:<name>"` for a shared-entity hop
    pub edge_path: Vec<String>,
    /// Cumulative score: 1.0 at the origin, multiplied by `score * confidence` for
    /// each cross-reference hop and by 0.5 for each shared-entity hop
    pub cumulative_score: f32,
    /// How this node was reached
    pub connection_type: ConnectionType,
}
96
/// How a node was reached during traversal
///
/// Serialized in snake_case (`origin`, `cross_reference`, `shared_entity`).
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
#[serde(rename_all = "snake_case")]
pub enum ConnectionType {
    /// Starting node
    Origin,
    /// Connected via cross-reference edge
    CrossReference,
    /// Connected via shared entity; `entity_name` is the entity that linked the two memories
    SharedEntity { entity_name: String },
}
108
/// Result of a multi-hop traversal
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TraversalResult {
    /// Starting memory ID
    pub start_id: MemoryId,
    /// All nodes found during traversal, origin first, then in discovery order
    pub nodes: Vec<TraversalNode>,
    /// Cross-reference edges that led to newly discovered nodes
    /// (shared-entity discoveries have no entry here)
    pub discovery_edges: Vec<CrossReference>,
    /// Statistics about the traversal
    pub stats: TraversalStats,
}
121
/// Summary counters collected while traversing
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct TraversalStats {
    /// Total nodes recorded, including the origin
    pub nodes_visited: usize,
    /// Nodes at each depth level (depth 0 holds the origin)
    pub nodes_per_depth: HashMap<usize, usize>,
    /// Count by connection type, keyed by `"origin"`, `"cross_reference"` or `"shared_entity"`
    pub connection_type_counts: HashMap<String, usize>,
    /// Maximum depth reached (0 when nothing beyond the origin was found)
    pub max_depth_reached: usize,
}
133
/// Get related memories with multi-hop traversal
///
/// Runs a level-by-level breadth-first search from `start_id`. Each level's
/// frontier is expanded with one batched cross-reference query and, when
/// `options.include_entities` is set, one batched shared-entity query.
///
/// A memory is recorded at most once, the first time it is reached. For each
/// frontier node, its cross-reference neighbours are considered before its
/// shared-entity neighbours.
///
/// Scoring: the origin starts at 1.0; a cross-reference hop multiplies by
/// `score * confidence` of the edge, a shared-entity hop multiplies by 0.5.
///
/// With `options.depth == 0` only the origin node is returned.
///
/// # Errors
/// Propagates any database error from the batched edge and entity queries.
pub fn get_related_multi_hop(
    conn: &Connection,
    start_id: MemoryId,
    options: &TraversalOptions,
) -> Result<TraversalResult> {
    // Memories already recorded in `nodes`; guarantees each memory appears once.
    let mut visited: HashSet<MemoryId> = HashSet::new();
    let mut nodes: Vec<TraversalNode> = Vec::new();
    let mut discovery_edges: Vec<CrossReference> = Vec::new();
    let mut stats = TraversalStats::default();

    // Queue: (memory_id, depth, path, edge_path, cumulative_score)
    let mut queue: VecDeque<(MemoryId, usize, Vec<MemoryId>, Vec<String>, f32)> = VecDeque::new();

    // Start with the origin node
    visited.insert(start_id);
    nodes.push(TraversalNode {
        memory_id: start_id,
        depth: 0,
        path: vec![start_id],
        edge_path: vec![],
        cumulative_score: 1.0,
        connection_type: ConnectionType::Origin,
    });
    queue.push_back((start_id, 0, vec![start_id], vec![], 1.0));

    *stats.nodes_per_depth.entry(0).or_insert(0) += 1;
    *stats
        .connection_type_counts
        .entry("origin".to_string())
        .or_insert(0) += 1;

    // Level-based BFS traversal
    while !queue.is_empty() {
        // Drain exactly the nodes currently queued: they form one BFS level, and
        // anything pushed while processing them belongs to the next level.
        let level_size = queue.len();
        let mut current_batch = Vec::with_capacity(level_size);
        for _ in 0..level_size {
            if let Some(item) = queue.pop_front() {
                current_batch.push(item);
            }
        }

        // Defensive only: the loop condition guarantees at least one queued item.
        if current_batch.is_empty() {
            break;
        }

        // All nodes in this batch should be at the same depth
        let current_depth = current_batch[0].1;

        // Nodes at the depth limit are not expanded. Deeper nodes are only enqueued
        // when `new_depth < options.depth`, so in practice this skips just the
        // origin when `options.depth == 0`.
        if current_depth >= options.depth {
            continue;
        }

        let node_ids: Vec<MemoryId> = current_batch.iter().map(|(id, _, _, _, _)| *id).collect();

        // Batch fetch cross-reference edges (with SQL-level per-node limiting)
        let crossrefs_map = get_edges_for_traversal_batch(
            conn,
            &node_ids,
            &options.edge_types,
            options.min_score,
            options.min_confidence,
            options.direction,
            options.limit_per_hop,
        )?;

        // Batch fetch entity-based connections if enabled
        let entity_connections_map = if options.include_entities {
            get_entity_connections_batch(conn, &node_ids, options.limit_per_hop)?
        } else {
            HashMap::new()
        };

        // Process each node in the batch
        for (current_id, _current_depth, current_path, current_edge_path, current_score) in
            current_batch
        {
            // Process cross-references (already limited per-node in SQL)
            if let Some(crossrefs) = crossrefs_map.get(&current_id) {
                for crossref in crossrefs.iter() {
                    // Determine the neighbor ID based on direction: whichever
                    // endpoint is not the node being expanded.
                    let neighbor_id = if crossref.from_id == current_id {
                        crossref.to_id
                    } else {
                        crossref.from_id
                    };

                    // First reach wins; this also skips self-loops back to current_id.
                    if visited.contains(&neighbor_id) {
                        continue;
                    }

                    visited.insert(neighbor_id);

                    let mut new_path = current_path.clone();
                    new_path.push(neighbor_id);

                    let mut new_edge_path = current_edge_path.clone();
                    new_edge_path.push(crossref.edge_type.as_str().to_string());

                    // Weight the hop by both how strong and how certain the edge is.
                    let new_score = current_score * crossref.score * crossref.confidence;
                    let new_depth = current_depth + 1;

                    nodes.push(TraversalNode {
                        memory_id: neighbor_id,
                        depth: new_depth,
                        path: new_path.clone(),
                        edge_path: new_edge_path.clone(),
                        cumulative_score: new_score,
                        connection_type: ConnectionType::CrossReference,
                    });

                    discovery_edges.push(crossref.clone());

                    *stats.nodes_per_depth.entry(new_depth).or_insert(0) += 1;
                    *stats
                        .connection_type_counts
                        .entry("cross_reference".to_string())
                        .or_insert(0) += 1;

                    // Only enqueue nodes that can still be expanded.
                    if new_depth < options.depth {
                        queue.push_back((
                            neighbor_id,
                            new_depth,
                            new_path,
                            new_edge_path,
                            new_score,
                        ));
                    }

                    stats.max_depth_reached = stats.max_depth_reached.max(new_depth);
                }
            }

            // Process entity connections
            if let Some(entity_connections) = entity_connections_map.get(&current_id) {
                // Cap entity neighbours per node in Rust as well.
                for (neighbor_id, entity_name) in
                    entity_connections.iter().take(options.limit_per_hop)
                {
                    let neighbor_id = *neighbor_id;
                    if visited.contains(&neighbor_id) {
                        continue;
                    }

                    visited.insert(neighbor_id);

                    let mut new_path = current_path.clone();
                    new_path.push(neighbor_id);

                    let mut new_edge_path = current_edge_path.clone();
                    new_edge_path.push(format!("entity:{}", entity_name));

                    let new_depth = current_depth + 1;
                    // Entity connections get a base score of 0.5
                    let new_score = current_score * 0.5;

                    nodes.push(TraversalNode {
                        memory_id: neighbor_id,
                        depth: new_depth,
                        path: new_path.clone(),
                        edge_path: new_edge_path.clone(),
                        cumulative_score: new_score,
                        connection_type: ConnectionType::SharedEntity {
                            entity_name: entity_name.clone(),
                        },
                    });

                    *stats.nodes_per_depth.entry(new_depth).or_insert(0) += 1;
                    *stats
                        .connection_type_counts
                        .entry("shared_entity".to_string())
                        .or_insert(0) += 1;

                    // Only enqueue nodes that can still be expanded.
                    if new_depth < options.depth {
                        queue.push_back((
                            neighbor_id,
                            new_depth,
                            new_path,
                            new_edge_path,
                            new_score,
                        ));
                    }

                    stats.max_depth_reached = stats.max_depth_reached.max(new_depth);
                }
            }
        }
    }

    // Includes the origin node.
    stats.nodes_visited = nodes.len();

    Ok(TraversalResult {
        start_id,
        nodes,
        discovery_edges,
        stats,
    })
}
331
332/// Get edges for multiple memory IDs with per-node SQL limiting
333///
334/// Uses ROW_NUMBER() window function to limit results per source node in SQL,
335/// preventing memory/time blowup on high-degree nodes.
336fn get_edges_for_traversal_batch(
337    conn: &Connection,
338    memory_ids: &[MemoryId],
339    edge_types: &[EdgeType],
340    min_score: f32,
341    min_confidence: f32,
342    direction: TraversalDirection,
343    limit_per_node: usize,
344) -> Result<HashMap<MemoryId, Vec<CrossReference>>> {
345    if memory_ids.is_empty() {
346        return Ok(HashMap::new());
347    }
348
349    let mut result: HashMap<MemoryId, Vec<CrossReference>> = HashMap::new();
350    let id_set: HashSet<MemoryId> = memory_ids.iter().cloned().collect();
351
352    // SQLite limit safety: chunk the IDs
353    for chunk in memory_ids.chunks(100) {
354        let placeholders = chunk.iter().map(|_| "?").collect::<Vec<_>>().join(", ");
355
356        let edge_type_clause = if edge_types.is_empty() {
357            String::new()
358        } else {
359            let types: Vec<String> = edge_types
360                .iter()
361                .map(|e| format!("'{}'", e.as_str()))
362                .collect();
363            format!(" AND edge_type IN ({})", types.join(", "))
364        };
365
366        // Build query based on direction, using ROW_NUMBER() to limit per source node
367        let (partition_col, filter_clause) = match direction {
368            TraversalDirection::Outgoing => ("from_id", format!("from_id IN ({})", placeholders)),
369            TraversalDirection::Incoming => ("to_id", format!("to_id IN ({})", placeholders)),
370            TraversalDirection::Both => {
371                // For Both direction, we need a UNION approach to properly partition
372                // by source node from both directions
373                let query = format!(
374                    r#"
375                    WITH ranked_edges AS (
376                        SELECT *, ROW_NUMBER() OVER (
377                            PARTITION BY from_id ORDER BY score * confidence DESC
378                        ) as rn
379                        FROM crossrefs
380                        WHERE from_id IN ({placeholders}) AND valid_to IS NULL
381                          AND score >= ? AND confidence >= ?
382                          {edge_type_clause}
383                        UNION ALL
384                        SELECT *, ROW_NUMBER() OVER (
385                            PARTITION BY to_id ORDER BY score * confidence DESC
386                        ) as rn
387                        FROM crossrefs
388                        WHERE to_id IN ({placeholders}) AND from_id NOT IN ({placeholders}) AND valid_to IS NULL
389                          AND score >= ? AND confidence >= ?
390                          {edge_type_clause}
391                    )
392                    SELECT from_id, to_id, edge_type, score, confidence, strength, source,
393                           source_context, created_at, valid_from, valid_to, pinned, metadata
394                    FROM ranked_edges
395                    WHERE rn <= ?
396                    "#,
397                    placeholders = placeholders,
398                    edge_type_clause = edge_type_clause,
399                );
400
401                let mut stmt = conn.prepare(&query)?;
402                let mut params: Vec<Box<dyn rusqlite::ToSql>> = Vec::new();
403
404                // First subquery params: from_id IN, min_score, min_confidence
405                for id in chunk {
406                    params.push(Box::new(*id));
407                }
408                params.push(Box::new(min_score));
409                params.push(Box::new(min_confidence));
410
411                // Second subquery params: to_id IN, from_id NOT IN, min_score, min_confidence
412                for id in chunk {
413                    params.push(Box::new(*id));
414                }
415                for id in chunk {
416                    params.push(Box::new(*id));
417                }
418                params.push(Box::new(min_score));
419                params.push(Box::new(min_confidence));
420
421                // Limit param
422                params.push(Box::new(limit_per_node as i64));
423
424                let param_refs: Vec<&dyn rusqlite::ToSql> =
425                    params.iter().map(|p| p.as_ref()).collect();
426
427                let crossrefs = stmt
428                    .query_map(param_refs.as_slice(), crossref_from_row)?
429                    .filter_map(|r| r.ok());
430
431                for crossref in crossrefs {
432                    if id_set.contains(&crossref.from_id) {
433                        result
434                            .entry(crossref.from_id)
435                            .or_default()
436                            .push(crossref.clone());
437                    }
438                    if id_set.contains(&crossref.to_id) && crossref.from_id != crossref.to_id {
439                        result.entry(crossref.to_id).or_default().push(crossref);
440                    }
441                }
442
443                continue; // Skip the common path below for Both direction
444            }
445        };
446
447        // Common path for Outgoing and Incoming directions
448        let query = format!(
449            r#"
450            WITH ranked_edges AS (
451                SELECT *, ROW_NUMBER() OVER (
452                    PARTITION BY {partition_col} ORDER BY score * confidence DESC
453                ) as rn
454                FROM crossrefs
455                WHERE {filter_clause} AND valid_to IS NULL
456                  AND score >= ? AND confidence >= ?
457                  {edge_type_clause}
458            )
459            SELECT from_id, to_id, edge_type, score, confidence, strength, source,
460                   source_context, created_at, valid_from, valid_to, pinned, metadata
461            FROM ranked_edges
462            WHERE rn <= ?
463            "#,
464            partition_col = partition_col,
465            filter_clause = filter_clause,
466            edge_type_clause = edge_type_clause,
467        );
468
469        let mut stmt = conn.prepare(&query)?;
470
471        let mut params: Vec<Box<dyn rusqlite::ToSql>> = Vec::new();
472        for id in chunk {
473            params.push(Box::new(*id));
474        }
475        params.push(Box::new(min_score));
476        params.push(Box::new(min_confidence));
477        params.push(Box::new(limit_per_node as i64));
478
479        let param_refs: Vec<&dyn rusqlite::ToSql> = params.iter().map(|p| p.as_ref()).collect();
480
481        let crossrefs = stmt
482            .query_map(param_refs.as_slice(), crossref_from_row)?
483            .filter_map(|r| r.ok());
484
485        for crossref in crossrefs {
486            match direction {
487                TraversalDirection::Outgoing => {
488                    if id_set.contains(&crossref.from_id) {
489                        result.entry(crossref.from_id).or_default().push(crossref);
490                    }
491                }
492                TraversalDirection::Incoming => {
493                    if id_set.contains(&crossref.to_id) {
494                        result.entry(crossref.to_id).or_default().push(crossref);
495                    }
496                }
497                TraversalDirection::Both => unreachable!(), // Handled above with continue
498            }
499        }
500    }
501
502    Ok(result)
503}
504
505/// Get memories connected through shared entities for multiple memory IDs
506fn get_entity_connections_batch(
507    conn: &Connection,
508    memory_ids: &[MemoryId],
509    _limit: usize,
510) -> Result<HashMap<MemoryId, Vec<(MemoryId, String)>>> {
511    if memory_ids.is_empty() {
512        return Ok(HashMap::new());
513    }
514
515    let mut result: HashMap<MemoryId, Vec<(MemoryId, String)>> = HashMap::new();
516    let id_set: HashSet<MemoryId> = memory_ids.iter().cloned().collect();
517
518    for chunk in memory_ids.chunks(100) {
519        let placeholders = chunk.iter().map(|_| "?").collect::<Vec<_>>().join(", ");
520
521        let query = format!(
522            r#"
523            SELECT DISTINCT me1.memory_id, me2.memory_id, e.name
524            FROM memory_entities me1
525            JOIN memory_entities me2 ON me1.entity_id = me2.entity_id
526            JOIN entities e ON me1.entity_id = e.id
527            WHERE me1.memory_id IN ({}) AND me2.memory_id != me1.memory_id
528            ORDER BY e.mention_count DESC
529            "#,
530            placeholders
531        );
532
533        let mut stmt = conn.prepare(&query)?;
534
535        let mut params: Vec<Box<dyn rusqlite::ToSql>> = Vec::new();
536        for id in chunk {
537            params.push(Box::new(*id));
538        }
539
540        let param_refs: Vec<&dyn rusqlite::ToSql> = params.iter().map(|p| p.as_ref()).collect();
541
542        let rows = stmt
543            .query_map(param_refs.as_slice(), |row| {
544                Ok((
545                    row.get::<_, i64>(0)?,
546                    row.get::<_, i64>(1)?,
547                    row.get::<_, String>(2)?,
548                ))
549            })?
550            .filter_map(|r| r.ok());
551
552        for (source_id, target_id, entity_name) in rows {
553            if id_set.contains(&source_id) {
554                result
555                    .entry(source_id)
556                    .or_default()
557                    .push((target_id, entity_name));
558            }
559        }
560    }
561
562    Ok(result)
563}
564
565/// Helper to parse CrossReference from row
566fn crossref_from_row(row: &rusqlite::Row) -> rusqlite::Result<CrossReference> {
567    let edge_type_str: String = row.get("edge_type")?;
568    let source_str: String = row.get("source")?;
569    let created_at_str: String = row.get("created_at")?;
570    let valid_from_str: String = row.get("valid_from")?;
571    let valid_to_str: Option<String> = row.get("valid_to")?;
572    let metadata_str: String = row.get("metadata")?;
573
574    Ok(CrossReference {
575        from_id: row.get("from_id")?,
576        to_id: row.get("to_id")?,
577        edge_type: edge_type_str.parse().unwrap_or(EdgeType::RelatedTo),
578        score: row.get("score")?,
579        confidence: row.get("confidence")?,
580        strength: row.get("strength")?,
581        source: match source_str.as_str() {
582            "manual" => RelationSource::Manual,
583            "llm" => RelationSource::Llm,
584            _ => RelationSource::Auto,
585        },
586        source_context: row.get("source_context")?,
587        created_at: DateTime::parse_from_rfc3339(&created_at_str)
588            .map(|dt| dt.with_timezone(&Utc))
589            .unwrap_or_else(|_| Utc::now()),
590        valid_from: DateTime::parse_from_rfc3339(&valid_from_str)
591            .map(|dt| dt.with_timezone(&Utc))
592            .unwrap_or_else(|_| Utc::now()),
593        valid_to: valid_to_str.and_then(|s| {
594            DateTime::parse_from_rfc3339(&s)
595                .map(|dt| dt.with_timezone(&Utc))
596                .ok()
597        }),
598        pinned: row.get::<_, i32>("pinned")? != 0,
599        metadata: serde_json::from_str(&metadata_str).unwrap_or_default(),
600    })
601}
602
603/// Find shortest path between two memories
604pub fn find_path(
605    conn: &Connection,
606    from_id: MemoryId,
607    to_id: MemoryId,
608    max_depth: usize,
609) -> Result<Option<TraversalNode>> {
610    let options = TraversalOptions {
611        depth: max_depth,
612        include_entities: true,
613        ..Default::default()
614    };
615
616    let result = get_related_multi_hop(conn, from_id, &options)?;
617
618    // Find the target node in results
619    Ok(result.nodes.into_iter().find(|n| n.memory_id == to_id))
620}
621
622/// Get all memories within a certain graph distance
623pub fn get_neighborhood(
624    conn: &Connection,
625    center_id: MemoryId,
626    radius: usize,
627) -> Result<Vec<MemoryId>> {
628    let options = TraversalOptions {
629        depth: radius,
630        include_entities: true,
631        ..Default::default()
632    };
633
634    let result = get_related_multi_hop(conn, center_id, &options)?;
635
636    Ok(result.nodes.into_iter().map(|n| n.memory_id).collect())
637}
638
// Integration tests for multi-hop traversal against an in-memory `Storage`.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::intelligence::entities::{EntityRelation, EntityType, ExtractedEntity};
    use crate::storage::entity_queries::{link_entity_to_memory, upsert_entity};
    use crate::storage::queries::{create_crossref, create_memory};
    use crate::storage::Storage;
    use crate::types::{CreateCrossRefInput, CreateMemoryInput, MemoryType};

    /// Insert a `Note` memory with `content` and default settings; returns its ID.
    fn create_test_memory(conn: &Connection, content: &str) -> MemoryId {
        let input = CreateMemoryInput {
            content: content.to_string(),
            memory_type: MemoryType::Note,
            tags: vec![],
            importance: None,
            metadata: Default::default(),
            scope: Default::default(),
            workspace: None,
            tier: Default::default(),
            defer_embedding: false,
            ttl_seconds: None,
            dedup_mode: Default::default(),
            dedup_threshold: None,
            event_time: None,
            event_duration_seconds: None,
            trigger_pattern: None,
            summary_of_id: None,
        };
        create_memory(conn, &input).unwrap().id
    }

    /// Insert an unpinned cross-reference `from_id -> to_id` of `edge_type`.
    fn create_test_crossref(
        conn: &Connection,
        from_id: MemoryId,
        to_id: MemoryId,
        edge_type: EdgeType,
    ) -> crate::error::Result<()> {
        let input = CreateCrossRefInput {
            from_id,
            to_id,
            edge_type,
            strength: None,
            source_context: None,
            pinned: false,
        };
        create_crossref(conn, &input)?;
        Ok(())
    }

    // `depth` bounds how far along a chain the traversal reaches.
    #[test]
    fn test_multi_hop_traversal() {
        let storage = Storage::open_in_memory().unwrap();
        storage
            .with_transaction(|conn| {
                // Create a chain: A -> B -> C -> D
                let id_a = create_test_memory(conn, "Memory A");
                let id_b = create_test_memory(conn, "Memory B");
                let id_c = create_test_memory(conn, "Memory C");
                let id_d = create_test_memory(conn, "Memory D");

                // Create edges
                create_test_crossref(conn, id_a, id_b, EdgeType::RelatedTo)?;
                create_test_crossref(conn, id_b, id_c, EdgeType::RelatedTo)?;
                create_test_crossref(conn, id_c, id_d, EdgeType::RelatedTo)?;

                // Traverse from A with depth 1 - should only reach B
                let options = TraversalOptions {
                    depth: 1,
                    include_entities: false,
                    ..Default::default()
                };
                let result = get_related_multi_hop(conn, id_a, &options)?;
                assert_eq!(result.nodes.len(), 2); // A + B
                assert!(result.nodes.iter().any(|n| n.memory_id == id_a));
                assert!(result.nodes.iter().any(|n| n.memory_id == id_b));

                // Traverse from A with depth 2 - should reach B and C
                let options = TraversalOptions {
                    depth: 2,
                    include_entities: false,
                    ..Default::default()
                };
                let result = get_related_multi_hop(conn, id_a, &options)?;
                assert_eq!(result.nodes.len(), 3); // A + B + C
                assert!(result.nodes.iter().any(|n| n.memory_id == id_c));

                // Traverse from A with depth 3 - should reach all
                let options = TraversalOptions {
                    depth: 3,
                    include_entities: false,
                    ..Default::default()
                };
                let result = get_related_multi_hop(conn, id_a, &options)?;
                assert_eq!(result.nodes.len(), 4); // A + B + C + D

                Ok(())
            })
            .unwrap();
    }

    // Memories linked to the same entity are reachable with no cross-reference edge.
    #[test]
    fn test_entity_based_connections() {
        let storage = Storage::open_in_memory().unwrap();
        storage
            .with_transaction(|conn| {
                // Create memories
                let id_a = create_test_memory(conn, "Memory about Rust programming");
                let id_b = create_test_memory(conn, "Another memory about Rust");
                let id_c = create_test_memory(conn, "Memory about Python");

                // Create shared entity using ExtractedEntity
                let entity = ExtractedEntity {
                    text: "Rust".to_string(),
                    normalized: "rust".to_string(),
                    entity_type: EntityType::Concept,
                    confidence: 0.9,
                    offset: 0,
                    length: 4,
                    suggested_relation: EntityRelation::Mentions,
                };
                let entity_id = upsert_entity(conn, &entity)?;
                let _ = link_entity_to_memory(
                    conn,
                    id_a,
                    entity_id,
                    EntityRelation::Mentions,
                    0.9,
                    None,
                )?;
                let _ = link_entity_to_memory(
                    conn,
                    id_b,
                    entity_id,
                    EntityRelation::Mentions,
                    0.8,
                    None,
                )?;

                // Traverse from A with entities enabled
                let options = TraversalOptions {
                    depth: 1,
                    include_entities: true,
                    ..Default::default()
                };
                let result = get_related_multi_hop(conn, id_a, &options)?;

                // Should find B through shared entity
                assert!(result.nodes.iter().any(|n| n.memory_id == id_b));
                let b_node = result.nodes.iter().find(|n| n.memory_id == id_b).unwrap();
                assert!(matches!(
                    &b_node.connection_type,
                    ConnectionType::SharedEntity { entity_name } if entity_name == "Rust"
                ));

                // Should NOT find C (no shared entity)
                assert!(!result.nodes.iter().any(|n| n.memory_id == id_c));

                Ok(())
            })
            .unwrap();
    }

    // `find_path` returns the target node with the full start-to-target path.
    #[test]
    fn test_find_path() {
        let storage = Storage::open_in_memory().unwrap();
        storage
            .with_transaction(|conn| {
                let id_a = create_test_memory(conn, "Start");
                let id_b = create_test_memory(conn, "Middle");
                let id_c = create_test_memory(conn, "End");

                create_test_crossref(conn, id_a, id_b, EdgeType::RelatedTo)?;
                create_test_crossref(conn, id_b, id_c, EdgeType::DependsOn)?;

                let path = find_path(conn, id_a, id_c, 5)?;
                assert!(path.is_some());
                let path = path.unwrap();
                assert_eq!(path.memory_id, id_c);
                assert_eq!(path.depth, 2);
                assert_eq!(path.path.len(), 3);
                assert_eq!(path.path, vec![id_a, id_b, id_c]);

                Ok(())
            })
            .unwrap();
    }

    // `Outgoing` / `Incoming` follow only edges in the requested direction.
    #[test]
    fn test_traversal_direction() {
        let storage = Storage::open_in_memory().unwrap();
        storage
            .with_transaction(|conn| {
                let id_a = create_test_memory(conn, "A");
                let id_b = create_test_memory(conn, "B");
                let id_c = create_test_memory(conn, "C");

                // A -> B and C -> B (B has incoming from both)
                create_test_crossref(conn, id_a, id_b, EdgeType::RelatedTo)?;
                create_test_crossref(conn, id_c, id_b, EdgeType::RelatedTo)?;

                // Outgoing from B - should find nothing (B has no outgoing)
                let options = TraversalOptions {
                    depth: 1,
                    direction: TraversalDirection::Outgoing,
                    include_entities: false,
                    ..Default::default()
                };
                let result = get_related_multi_hop(conn, id_b, &options)?;
                assert_eq!(result.nodes.len(), 1); // Just B itself

                // Incoming to B - should find A and C
                let options = TraversalOptions {
                    depth: 1,
                    direction: TraversalDirection::Incoming,
                    include_entities: false,
                    ..Default::default()
                };
                let result = get_related_multi_hop(conn, id_b, &options)?;
                assert_eq!(result.nodes.len(), 3); // B, A, C

                Ok(())
            })
            .unwrap();
    }

    // A non-empty `edge_types` list excludes edges of other types.
    #[test]
    fn test_edge_type_filter() {
        let storage = Storage::open_in_memory().unwrap();
        storage
            .with_transaction(|conn| {
                let id_a = create_test_memory(conn, "A");
                let id_b = create_test_memory(conn, "B");
                let id_c = create_test_memory(conn, "C");

                create_test_crossref(conn, id_a, id_b, EdgeType::RelatedTo)?;
                create_test_crossref(conn, id_a, id_c, EdgeType::DependsOn)?;

                // Filter to only RelatedTo edges
                let options = TraversalOptions {
                    depth: 1,
                    edge_types: vec![EdgeType::RelatedTo],
                    include_entities: false,
                    ..Default::default()
                };
                let result = get_related_multi_hop(conn, id_a, &options)?;
                assert_eq!(result.nodes.len(), 2); // A + B only
                assert!(result.nodes.iter().any(|n| n.memory_id == id_b));
                assert!(!result.nodes.iter().any(|n| n.memory_id == id_c));

                Ok(())
            })
            .unwrap();
    }
}