Skip to main content

engram/storage/
graph_queries.rs

//! Multi-hop graph traversal queries
//!
//! Provides graph traversal capabilities for exploring memory relationships
//! at various depths, with support for filtering by edge type and combining
//! with entity-based connections.
6
7use rusqlite::Connection;
8use serde::{Deserialize, Serialize};
9use std::collections::{HashMap, HashSet, VecDeque};
10
11use crate::error::Result;
12use crate::types::{CrossReference, EdgeType, MemoryId, RelationSource};
13use chrono::{DateTime, Utc};
14
15/// Options for multi-hop graph traversal
16#[derive(Debug, Clone, Serialize, Deserialize)]
17pub struct TraversalOptions {
18    /// Maximum traversal depth (1 = direct relations only)
19    #[serde(default = "default_depth")]
20    pub depth: usize,
21    /// Filter by edge types (empty = all types)
22    #[serde(default)]
23    pub edge_types: Vec<EdgeType>,
24    /// Minimum score threshold
25    #[serde(default)]
26    pub min_score: f32,
27    /// Minimum confidence threshold
28    #[serde(default)]
29    pub min_confidence: f32,
30    /// Maximum number of results per hop
31    #[serde(default = "default_limit_per_hop")]
32    pub limit_per_hop: usize,
33    /// Include entity-based connections
34    #[serde(default = "default_include_entities")]
35    pub include_entities: bool,
36    /// Direction of traversal
37    #[serde(default)]
38    pub direction: TraversalDirection,
39}
40
41fn default_depth() -> usize {
42    2
43}
44
45fn default_limit_per_hop() -> usize {
46    50
47}
48
49fn default_include_entities() -> bool {
50    true
51}
52
53impl Default for TraversalOptions {
54    fn default() -> Self {
55        Self {
56            depth: 2,
57            edge_types: vec![],
58            min_score: 0.0,
59            min_confidence: 0.0,
60            limit_per_hop: 50,
61            include_entities: true,
62            direction: TraversalDirection::Both,
63        }
64    }
65}
66
67/// Direction of graph traversal
68#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
69#[serde(rename_all = "lowercase")]
70pub enum TraversalDirection {
71    /// Follow outgoing edges only (from -> to)
72    Outgoing,
73    /// Follow incoming edges only (to -> from)
74    Incoming,
75    /// Follow edges in both directions
76    #[default]
77    Both,
78}
79
80/// A node in the traversal result with path information
81#[derive(Debug, Clone, Serialize, Deserialize)]
82pub struct TraversalNode {
83    /// Memory ID
84    pub memory_id: MemoryId,
85    /// Depth from the starting node (0 = start node)
86    pub depth: usize,
87    /// Path of memory IDs from start to this node
88    pub path: Vec<MemoryId>,
89    /// Edge types along the path
90    pub edge_path: Vec<String>,
91    /// Cumulative score (product of edge scores)
92    pub cumulative_score: f32,
93    /// How this node was reached
94    pub connection_type: ConnectionType,
95}
96
97/// How a node was reached during traversal
98#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
99#[serde(rename_all = "snake_case")]
100pub enum ConnectionType {
101    /// Starting node
102    Origin,
103    /// Connected via cross-reference edge
104    CrossReference,
105    /// Connected via shared entity
106    SharedEntity { entity_name: String },
107}
108
109/// Result of a multi-hop traversal
110#[derive(Debug, Clone, Serialize, Deserialize)]
111pub struct TraversalResult {
112    /// Starting memory ID
113    pub start_id: MemoryId,
114    /// All nodes found during traversal
115    pub nodes: Vec<TraversalNode>,
116    /// Edges that led to newly discovered nodes
117    pub discovery_edges: Vec<CrossReference>,
118    /// Statistics about the traversal
119    pub stats: TraversalStats,
120}
121
122#[derive(Debug, Clone, Serialize, Deserialize, Default)]
123pub struct TraversalStats {
124    /// Total nodes visited
125    pub nodes_visited: usize,
126    /// Nodes at each depth level
127    pub nodes_per_depth: HashMap<usize, usize>,
128    /// Count by connection type
129    pub connection_type_counts: HashMap<String, usize>,
130    /// Maximum depth reached
131    pub max_depth_reached: usize,
132}
133
134/// Get related memories with multi-hop traversal
135pub fn get_related_multi_hop(
136    conn: &Connection,
137    start_id: MemoryId,
138    options: &TraversalOptions,
139) -> Result<TraversalResult> {
140    let mut visited: HashSet<MemoryId> = HashSet::new();
141    let mut nodes: Vec<TraversalNode> = Vec::new();
142    let mut discovery_edges: Vec<CrossReference> = Vec::new();
143    let mut stats = TraversalStats::default();
144
145    // Queue: (memory_id, depth, path, edge_path, cumulative_score)
146    let mut queue: VecDeque<(MemoryId, usize, Vec<MemoryId>, Vec<String>, f32)> = VecDeque::new();
147
148    // Start with the origin node
149    visited.insert(start_id);
150    nodes.push(TraversalNode {
151        memory_id: start_id,
152        depth: 0,
153        path: vec![start_id],
154        edge_path: vec![],
155        cumulative_score: 1.0,
156        connection_type: ConnectionType::Origin,
157    });
158    queue.push_back((start_id, 0, vec![start_id], vec![], 1.0));
159
160    *stats.nodes_per_depth.entry(0).or_insert(0) += 1;
161    *stats
162        .connection_type_counts
163        .entry("origin".to_string())
164        .or_insert(0) += 1;
165
166    // Level-based BFS traversal
167    while !queue.is_empty() {
168        let level_size = queue.len();
169        let mut current_batch = Vec::with_capacity(level_size);
170        for _ in 0..level_size {
171            if let Some(item) = queue.pop_front() {
172                current_batch.push(item);
173            }
174        }
175
176        if current_batch.is_empty() {
177            break;
178        }
179
180        // All nodes in this batch should be at the same depth
181        let current_depth = current_batch[0].1;
182
183        if current_depth >= options.depth {
184            continue;
185        }
186
187        let node_ids: Vec<MemoryId> = current_batch.iter().map(|(id, _, _, _, _)| *id).collect();
188
189        // Batch fetch cross-reference edges (with SQL-level per-node limiting)
190        let crossrefs_map = get_edges_for_traversal_batch(
191            conn,
192            &node_ids,
193            &options.edge_types,
194            options.min_score,
195            options.min_confidence,
196            options.direction,
197            options.limit_per_hop,
198        )?;
199
200        // Batch fetch entity-based connections if enabled
201        let entity_connections_map = if options.include_entities {
202            get_entity_connections_batch(conn, &node_ids, options.limit_per_hop)?
203        } else {
204            HashMap::new()
205        };
206
207        // Process each node in the batch
208        for (current_id, _current_depth, current_path, current_edge_path, current_score) in
209            current_batch
210        {
211            // Process cross-references (already limited per-node in SQL)
212            if let Some(crossrefs) = crossrefs_map.get(&current_id) {
213                for crossref in crossrefs.iter() {
214                    // Determine the neighbor ID based on direction
215                    let neighbor_id = if crossref.from_id == current_id {
216                        crossref.to_id
217                    } else {
218                        crossref.from_id
219                    };
220
221                    if visited.contains(&neighbor_id) {
222                        continue;
223                    }
224
225                    visited.insert(neighbor_id);
226
227                    let mut new_path = current_path.clone();
228                    new_path.push(neighbor_id);
229
230                    let mut new_edge_path = current_edge_path.clone();
231                    new_edge_path.push(crossref.edge_type.as_str().to_string());
232
233                    let new_score = current_score * crossref.score * crossref.confidence;
234                    let new_depth = current_depth + 1;
235
236                    nodes.push(TraversalNode {
237                        memory_id: neighbor_id,
238                        depth: new_depth,
239                        path: new_path.clone(),
240                        edge_path: new_edge_path.clone(),
241                        cumulative_score: new_score,
242                        connection_type: ConnectionType::CrossReference,
243                    });
244
245                    discovery_edges.push(crossref.clone());
246
247                    *stats.nodes_per_depth.entry(new_depth).or_insert(0) += 1;
248                    *stats
249                        .connection_type_counts
250                        .entry("cross_reference".to_string())
251                        .or_insert(0) += 1;
252
253                    if new_depth < options.depth {
254                        queue.push_back((
255                            neighbor_id,
256                            new_depth,
257                            new_path,
258                            new_edge_path,
259                            new_score,
260                        ));
261                    }
262
263                    stats.max_depth_reached = stats.max_depth_reached.max(new_depth);
264                }
265            }
266
267            // Process entity connections
268            if let Some(entity_connections) = entity_connections_map.get(&current_id) {
269                for (neighbor_id, entity_name) in
270                    entity_connections.iter().take(options.limit_per_hop)
271                {
272                    let neighbor_id = *neighbor_id;
273                    if visited.contains(&neighbor_id) {
274                        continue;
275                    }
276
277                    visited.insert(neighbor_id);
278
279                    let mut new_path = current_path.clone();
280                    new_path.push(neighbor_id);
281
282                    let mut new_edge_path = current_edge_path.clone();
283                    new_edge_path.push(format!("entity:{}", entity_name));
284
285                    let new_depth = current_depth + 1;
286                    // Entity connections get a base score of 0.5
287                    let new_score = current_score * 0.5;
288
289                    nodes.push(TraversalNode {
290                        memory_id: neighbor_id,
291                        depth: new_depth,
292                        path: new_path.clone(),
293                        edge_path: new_edge_path.clone(),
294                        cumulative_score: new_score,
295                        connection_type: ConnectionType::SharedEntity {
296                            entity_name: entity_name.clone(),
297                        },
298                    });
299
300                    *stats.nodes_per_depth.entry(new_depth).or_insert(0) += 1;
301                    *stats
302                        .connection_type_counts
303                        .entry("shared_entity".to_string())
304                        .or_insert(0) += 1;
305
306                    if new_depth < options.depth {
307                        queue.push_back((
308                            neighbor_id,
309                            new_depth,
310                            new_path,
311                            new_edge_path,
312                            new_score,
313                        ));
314                    }
315
316                    stats.max_depth_reached = stats.max_depth_reached.max(new_depth);
317                }
318            }
319        }
320    }
321
322    stats.nodes_visited = nodes.len();
323
324    Ok(TraversalResult {
325        start_id,
326        nodes,
327        discovery_edges,
328        stats,
329    })
330}
331
332/// Get edges for multiple memory IDs with per-node SQL limiting
333///
334/// Uses ROW_NUMBER() window function to limit results per source node in SQL,
335/// preventing memory/time blowup on high-degree nodes.
336fn get_edges_for_traversal_batch(
337    conn: &Connection,
338    memory_ids: &[MemoryId],
339    edge_types: &[EdgeType],
340    min_score: f32,
341    min_confidence: f32,
342    direction: TraversalDirection,
343    limit_per_node: usize,
344) -> Result<HashMap<MemoryId, Vec<CrossReference>>> {
345    if memory_ids.is_empty() {
346        return Ok(HashMap::new());
347    }
348
349    let mut result: HashMap<MemoryId, Vec<CrossReference>> = HashMap::new();
350    let id_set: HashSet<MemoryId> = memory_ids.iter().cloned().collect();
351
352    // SQLite limit safety: chunk the IDs
353    for chunk in memory_ids.chunks(100) {
354        let placeholders = chunk.iter().map(|_| "?").collect::<Vec<_>>().join(", ");
355
356        let edge_type_clause = if edge_types.is_empty() {
357            String::new()
358        } else {
359            let types: Vec<String> = edge_types
360                .iter()
361                .map(|e| format!("'{}'", e.as_str()))
362                .collect();
363            format!(" AND edge_type IN ({})", types.join(", "))
364        };
365
366        // Build query based on direction, using ROW_NUMBER() to limit per source node
367        let (partition_col, filter_clause) = match direction {
368            TraversalDirection::Outgoing => ("from_id", format!("from_id IN ({})", placeholders)),
369            TraversalDirection::Incoming => ("to_id", format!("to_id IN ({})", placeholders)),
370            TraversalDirection::Both => {
371                // For Both direction, we need a UNION approach to properly partition
372                // by source node from both directions
373                let query = format!(
374                    r#"
375                    WITH ranked_edges AS (
376                        SELECT *, ROW_NUMBER() OVER (
377                            PARTITION BY from_id ORDER BY score * confidence DESC
378                        ) as rn
379                        FROM crossrefs
380                        WHERE from_id IN ({placeholders}) AND valid_to IS NULL
381                          AND score >= ? AND confidence >= ?
382                          {edge_type_clause}
383                        UNION ALL
384                        SELECT *, ROW_NUMBER() OVER (
385                            PARTITION BY to_id ORDER BY score * confidence DESC
386                        ) as rn
387                        FROM crossrefs
388                        WHERE to_id IN ({placeholders}) AND from_id NOT IN ({placeholders}) AND valid_to IS NULL
389                          AND score >= ? AND confidence >= ?
390                          {edge_type_clause}
391                    )
392                    SELECT from_id, to_id, edge_type, score, confidence, strength, source,
393                           source_context, created_at, valid_from, valid_to, pinned, metadata
394                    FROM ranked_edges
395                    WHERE rn <= ?
396                    "#,
397                    placeholders = placeholders,
398                    edge_type_clause = edge_type_clause,
399                );
400
401                let mut stmt = conn.prepare(&query)?;
402                let mut params: Vec<Box<dyn rusqlite::ToSql>> = Vec::new();
403
404                // First subquery params: from_id IN, min_score, min_confidence
405                for id in chunk {
406                    params.push(Box::new(*id));
407                }
408                params.push(Box::new(min_score));
409                params.push(Box::new(min_confidence));
410
411                // Second subquery params: to_id IN, from_id NOT IN, min_score, min_confidence
412                for id in chunk {
413                    params.push(Box::new(*id));
414                }
415                for id in chunk {
416                    params.push(Box::new(*id));
417                }
418                params.push(Box::new(min_score));
419                params.push(Box::new(min_confidence));
420
421                // Limit param
422                params.push(Box::new(limit_per_node as i64));
423
424                let param_refs: Vec<&dyn rusqlite::ToSql> =
425                    params.iter().map(|p| p.as_ref()).collect();
426
427                let crossrefs = stmt
428                    .query_map(param_refs.as_slice(), crossref_from_row)?
429                    .filter_map(|r| r.ok());
430
431                for crossref in crossrefs {
432                    if id_set.contains(&crossref.from_id) {
433                        result
434                            .entry(crossref.from_id)
435                            .or_default()
436                            .push(crossref.clone());
437                    }
438                    if id_set.contains(&crossref.to_id) && crossref.from_id != crossref.to_id {
439                        result.entry(crossref.to_id).or_default().push(crossref);
440                    }
441                }
442
443                continue; // Skip the common path below for Both direction
444            }
445        };
446
447        // Common path for Outgoing and Incoming directions
448        let query = format!(
449            r#"
450            WITH ranked_edges AS (
451                SELECT *, ROW_NUMBER() OVER (
452                    PARTITION BY {partition_col} ORDER BY score * confidence DESC
453                ) as rn
454                FROM crossrefs
455                WHERE {filter_clause} AND valid_to IS NULL
456                  AND score >= ? AND confidence >= ?
457                  {edge_type_clause}
458            )
459            SELECT from_id, to_id, edge_type, score, confidence, strength, source,
460                   source_context, created_at, valid_from, valid_to, pinned, metadata
461            FROM ranked_edges
462            WHERE rn <= ?
463            "#,
464            partition_col = partition_col,
465            filter_clause = filter_clause,
466            edge_type_clause = edge_type_clause,
467        );
468
469        let mut stmt = conn.prepare(&query)?;
470
471        let mut params: Vec<Box<dyn rusqlite::ToSql>> = Vec::new();
472        for id in chunk {
473            params.push(Box::new(*id));
474        }
475        params.push(Box::new(min_score));
476        params.push(Box::new(min_confidence));
477        params.push(Box::new(limit_per_node as i64));
478
479        let param_refs: Vec<&dyn rusqlite::ToSql> = params.iter().map(|p| p.as_ref()).collect();
480
481        let crossrefs = stmt
482            .query_map(param_refs.as_slice(), crossref_from_row)?
483            .filter_map(|r| r.ok());
484
485        for crossref in crossrefs {
486            match direction {
487                TraversalDirection::Outgoing => {
488                    if id_set.contains(&crossref.from_id) {
489                        result.entry(crossref.from_id).or_default().push(crossref);
490                    }
491                }
492                TraversalDirection::Incoming => {
493                    if id_set.contains(&crossref.to_id) {
494                        result.entry(crossref.to_id).or_default().push(crossref);
495                    }
496                }
497                TraversalDirection::Both => unreachable!(), // Handled above with continue
498            }
499        }
500    }
501
502    Ok(result)
503}
504
505/// Get memories connected through shared entities for multiple memory IDs
506fn get_entity_connections_batch(
507    conn: &Connection,
508    memory_ids: &[MemoryId],
509    _limit: usize,
510) -> Result<HashMap<MemoryId, Vec<(MemoryId, String)>>> {
511    if memory_ids.is_empty() {
512        return Ok(HashMap::new());
513    }
514
515    let mut result: HashMap<MemoryId, Vec<(MemoryId, String)>> = HashMap::new();
516    let id_set: HashSet<MemoryId> = memory_ids.iter().cloned().collect();
517
518    for chunk in memory_ids.chunks(100) {
519        let placeholders = chunk.iter().map(|_| "?").collect::<Vec<_>>().join(", ");
520
521        let query = format!(
522            r#"
523            SELECT DISTINCT me1.memory_id, me2.memory_id, e.name
524            FROM memory_entities me1
525            JOIN memory_entities me2 ON me1.entity_id = me2.entity_id
526            JOIN entities e ON me1.entity_id = e.id
527            WHERE me1.memory_id IN ({}) AND me2.memory_id != me1.memory_id
528            ORDER BY e.mention_count DESC
529            "#,
530            placeholders
531        );
532
533        let mut stmt = conn.prepare(&query)?;
534
535        let mut params: Vec<Box<dyn rusqlite::ToSql>> = Vec::new();
536        for id in chunk {
537            params.push(Box::new(*id));
538        }
539
540        let param_refs: Vec<&dyn rusqlite::ToSql> = params.iter().map(|p| p.as_ref()).collect();
541
542        let rows = stmt
543            .query_map(param_refs.as_slice(), |row| {
544                Ok((
545                    row.get::<_, i64>(0)?,
546                    row.get::<_, i64>(1)?,
547                    row.get::<_, String>(2)?,
548                ))
549            })?
550            .filter_map(|r| r.ok());
551
552        for (source_id, target_id, entity_name) in rows {
553            if id_set.contains(&source_id) {
554                result
555                    .entry(source_id)
556                    .or_default()
557                    .push((target_id, entity_name));
558            }
559        }
560    }
561
562    Ok(result)
563}
564
565/// Helper to parse CrossReference from row
566fn crossref_from_row(row: &rusqlite::Row) -> rusqlite::Result<CrossReference> {
567    let edge_type_str: String = row.get("edge_type")?;
568    let source_str: String = row.get("source")?;
569    let created_at_str: String = row.get("created_at")?;
570    let valid_from_str: String = row.get("valid_from")?;
571    let valid_to_str: Option<String> = row.get("valid_to")?;
572    let metadata_str: String = row.get("metadata")?;
573
574    Ok(CrossReference {
575        from_id: row.get("from_id")?,
576        to_id: row.get("to_id")?,
577        edge_type: edge_type_str.parse().unwrap_or(EdgeType::RelatedTo),
578        score: row.get("score")?,
579        confidence: row.get("confidence")?,
580        strength: row.get("strength")?,
581        source: match source_str.as_str() {
582            "manual" => RelationSource::Manual,
583            "llm" => RelationSource::Llm,
584            _ => RelationSource::Auto,
585        },
586        source_context: row.get("source_context")?,
587        created_at: DateTime::parse_from_rfc3339(&created_at_str)
588            .map(|dt| dt.with_timezone(&Utc))
589            .unwrap_or_else(|_| Utc::now()),
590        valid_from: DateTime::parse_from_rfc3339(&valid_from_str)
591            .map(|dt| dt.with_timezone(&Utc))
592            .unwrap_or_else(|_| Utc::now()),
593        valid_to: valid_to_str.and_then(|s| {
594            DateTime::parse_from_rfc3339(&s)
595                .map(|dt| dt.with_timezone(&Utc))
596                .ok()
597        }),
598        pinned: row.get::<_, i32>("pinned")? != 0,
599        metadata: serde_json::from_str(&metadata_str).unwrap_or_default(),
600    })
601}
602
603/// Find shortest path between two memories
604pub fn find_path(
605    conn: &Connection,
606    from_id: MemoryId,
607    to_id: MemoryId,
608    max_depth: usize,
609) -> Result<Option<TraversalNode>> {
610    let options = TraversalOptions {
611        depth: max_depth,
612        include_entities: true,
613        ..Default::default()
614    };
615
616    let result = get_related_multi_hop(conn, from_id, &options)?;
617
618    // Find the target node in results
619    Ok(result.nodes.into_iter().find(|n| n.memory_id == to_id))
620}
621
622/// Get all memories within a certain graph distance
623pub fn get_neighborhood(
624    conn: &Connection,
625    center_id: MemoryId,
626    radius: usize,
627) -> Result<Vec<MemoryId>> {
628    let options = TraversalOptions {
629        depth: radius,
630        include_entities: true,
631        ..Default::default()
632    };
633
634    let result = get_related_multi_hop(conn, center_id, &options)?;
635
636    Ok(result.nodes.into_iter().map(|n| n.memory_id).collect())
637}
638
639#[cfg(test)]
640mod tests {
641    use super::*;
642    use crate::intelligence::entities::{EntityRelation, EntityType, ExtractedEntity};
643    use crate::storage::entity_queries::{link_entity_to_memory, upsert_entity};
644    use crate::storage::queries::{create_crossref, create_memory};
645    use crate::storage::Storage;
646    use crate::types::{CreateCrossRefInput, CreateMemoryInput, MemoryType};
647
648    fn create_test_memory(conn: &Connection, content: &str) -> MemoryId {
649        let input = CreateMemoryInput {
650            content: content.to_string(),
651            memory_type: MemoryType::Note,
652            tags: vec![],
653            importance: None,
654            metadata: Default::default(),
655            scope: Default::default(),
656            workspace: None,
657            tier: Default::default(),
658            defer_embedding: false,
659            ttl_seconds: None,
660            dedup_mode: Default::default(),
661            dedup_threshold: None,
662            event_time: None,
663            event_duration_seconds: None,
664            trigger_pattern: None,
665            summary_of_id: None,
666                media_url: None,
667        };
668        create_memory(conn, &input).unwrap().id
669    }
670
671    fn create_test_crossref(
672        conn: &Connection,
673        from_id: MemoryId,
674        to_id: MemoryId,
675        edge_type: EdgeType,
676    ) -> crate::error::Result<()> {
677        let input = CreateCrossRefInput {
678            from_id,
679            to_id,
680            edge_type,
681            strength: None,
682            source_context: None,
683            pinned: false,
684        };
685        create_crossref(conn, &input)?;
686        Ok(())
687    }
688
689    #[test]
690    fn test_multi_hop_traversal() {
691        let storage = Storage::open_in_memory().unwrap();
692        storage
693            .with_transaction(|conn| {
694                // Create a chain: A -> B -> C -> D
695                let id_a = create_test_memory(conn, "Memory A");
696                let id_b = create_test_memory(conn, "Memory B");
697                let id_c = create_test_memory(conn, "Memory C");
698                let id_d = create_test_memory(conn, "Memory D");
699
700                // Create edges
701                create_test_crossref(conn, id_a, id_b, EdgeType::RelatedTo)?;
702                create_test_crossref(conn, id_b, id_c, EdgeType::RelatedTo)?;
703                create_test_crossref(conn, id_c, id_d, EdgeType::RelatedTo)?;
704
705                // Traverse from A with depth 1 - should only reach B
706                let options = TraversalOptions {
707                    depth: 1,
708                    include_entities: false,
709                    ..Default::default()
710                };
711                let result = get_related_multi_hop(conn, id_a, &options)?;
712                assert_eq!(result.nodes.len(), 2); // A + B
713                assert!(result.nodes.iter().any(|n| n.memory_id == id_a));
714                assert!(result.nodes.iter().any(|n| n.memory_id == id_b));
715
716                // Traverse from A with depth 2 - should reach B and C
717                let options = TraversalOptions {
718                    depth: 2,
719                    include_entities: false,
720                    ..Default::default()
721                };
722                let result = get_related_multi_hop(conn, id_a, &options)?;
723                assert_eq!(result.nodes.len(), 3); // A + B + C
724                assert!(result.nodes.iter().any(|n| n.memory_id == id_c));
725
726                // Traverse from A with depth 3 - should reach all
727                let options = TraversalOptions {
728                    depth: 3,
729                    include_entities: false,
730                    ..Default::default()
731                };
732                let result = get_related_multi_hop(conn, id_a, &options)?;
733                assert_eq!(result.nodes.len(), 4); // A + B + C + D
734
735                Ok(())
736            })
737            .unwrap();
738    }
739
740    #[test]
741    fn test_entity_based_connections() {
742        let storage = Storage::open_in_memory().unwrap();
743        storage
744            .with_transaction(|conn| {
745                // Create memories
746                let id_a = create_test_memory(conn, "Memory about Rust programming");
747                let id_b = create_test_memory(conn, "Another memory about Rust");
748                let id_c = create_test_memory(conn, "Memory about Python");
749
750                // Create shared entity using ExtractedEntity
751                let entity = ExtractedEntity {
752                    text: "Rust".to_string(),
753                    normalized: "rust".to_string(),
754                    entity_type: EntityType::Concept,
755                    confidence: 0.9,
756                    offset: 0,
757                    length: 4,
758                    suggested_relation: EntityRelation::Mentions,
759                };
760                let entity_id = upsert_entity(conn, &entity)?;
761                let _ = link_entity_to_memory(
762                    conn,
763                    id_a,
764                    entity_id,
765                    EntityRelation::Mentions,
766                    0.9,
767                    None,
768                )?;
769                let _ = link_entity_to_memory(
770                    conn,
771                    id_b,
772                    entity_id,
773                    EntityRelation::Mentions,
774                    0.8,
775                    None,
776                )?;
777
778                // Traverse from A with entities enabled
779                let options = TraversalOptions {
780                    depth: 1,
781                    include_entities: true,
782                    ..Default::default()
783                };
784                let result = get_related_multi_hop(conn, id_a, &options)?;
785
786                // Should find B through shared entity
787                assert!(result.nodes.iter().any(|n| n.memory_id == id_b));
788                let b_node = result.nodes.iter().find(|n| n.memory_id == id_b).unwrap();
789                assert!(matches!(
790                    &b_node.connection_type,
791                    ConnectionType::SharedEntity { entity_name } if entity_name == "Rust"
792                ));
793
794                // Should NOT find C (no shared entity)
795                assert!(!result.nodes.iter().any(|n| n.memory_id == id_c));
796
797                Ok(())
798            })
799            .unwrap();
800    }
801
802    #[test]
803    fn test_find_path() {
804        let storage = Storage::open_in_memory().unwrap();
805        storage
806            .with_transaction(|conn| {
807                let id_a = create_test_memory(conn, "Start");
808                let id_b = create_test_memory(conn, "Middle");
809                let id_c = create_test_memory(conn, "End");
810
811                create_test_crossref(conn, id_a, id_b, EdgeType::RelatedTo)?;
812                create_test_crossref(conn, id_b, id_c, EdgeType::DependsOn)?;
813
814                let path = find_path(conn, id_a, id_c, 5)?;
815                assert!(path.is_some());
816                let path = path.unwrap();
817                assert_eq!(path.memory_id, id_c);
818                assert_eq!(path.depth, 2);
819                assert_eq!(path.path.len(), 3);
820                assert_eq!(path.path, vec![id_a, id_b, id_c]);
821
822                Ok(())
823            })
824            .unwrap();
825    }
826
827    #[test]
828    fn test_traversal_direction() {
829        let storage = Storage::open_in_memory().unwrap();
830        storage
831            .with_transaction(|conn| {
832                let id_a = create_test_memory(conn, "A");
833                let id_b = create_test_memory(conn, "B");
834                let id_c = create_test_memory(conn, "C");
835
836                // A -> B and C -> B (B has incoming from both)
837                create_test_crossref(conn, id_a, id_b, EdgeType::RelatedTo)?;
838                create_test_crossref(conn, id_c, id_b, EdgeType::RelatedTo)?;
839
840                // Outgoing from B - should find nothing (B has no outgoing)
841                let options = TraversalOptions {
842                    depth: 1,
843                    direction: TraversalDirection::Outgoing,
844                    include_entities: false,
845                    ..Default::default()
846                };
847                let result = get_related_multi_hop(conn, id_b, &options)?;
848                assert_eq!(result.nodes.len(), 1); // Just B itself
849
850                // Incoming to B - should find A and C
851                let options = TraversalOptions {
852                    depth: 1,
853                    direction: TraversalDirection::Incoming,
854                    include_entities: false,
855                    ..Default::default()
856                };
857                let result = get_related_multi_hop(conn, id_b, &options)?;
858                assert_eq!(result.nodes.len(), 3); // B, A, C
859
860                Ok(())
861            })
862            .unwrap();
863    }
864
865    #[test]
866    fn test_edge_type_filter() {
867        let storage = Storage::open_in_memory().unwrap();
868        storage
869            .with_transaction(|conn| {
870                let id_a = create_test_memory(conn, "A");
871                let id_b = create_test_memory(conn, "B");
872                let id_c = create_test_memory(conn, "C");
873
874                create_test_crossref(conn, id_a, id_b, EdgeType::RelatedTo)?;
875                create_test_crossref(conn, id_a, id_c, EdgeType::DependsOn)?;
876
877                // Filter to only RelatedTo edges
878                let options = TraversalOptions {
879                    depth: 1,
880                    edge_types: vec![EdgeType::RelatedTo],
881                    include_entities: false,
882                    ..Default::default()
883                };
884                let result = get_related_multi_hop(conn, id_a, &options)?;
885                assert_eq!(result.nodes.len(), 2); // A + B only
886                assert!(result.nodes.iter().any(|n| n.memory_id == id_b));
887                assert!(!result.nodes.iter().any(|n| n.memory_id == id_c));
888
889                Ok(())
890            })
891            .unwrap();
892    }
893}