Skip to main content

sqlite_graphrag/commands/
related.rs

1//! Handler for the `related` CLI subcommand.
2
3use crate::constants::{
4    DEFAULT_K_RECALL, DEFAULT_MAX_HOPS, DEFAULT_MIN_WEIGHT, TEXT_DESCRIPTION_PREVIEW_LEN,
5};
6use crate::errors::AppError;
7use crate::i18n::errors_msg;
8use crate::output::{self, OutputFormat};
9use crate::paths::AppPaths;
10use crate::storage::connection::open_ro;
11use rusqlite::{params, Connection};
12use serde::Serialize;
13use std::collections::{HashMap, HashSet, VecDeque};
14
/// What the seed name given to `related` resolved to.
///
/// A live memory with the seed name takes precedence; an entity is used only when no
/// memory matches.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum SeedKind {
    /// Id of the matching row in `memories`.
    Memory(i64),
    /// Id of the matching entity, used when no memory has the seed name.
    Entity(i64),
}
20
/// One edge adjacent to the entity being expanded, as produced by `fetch_neighbours`.
///
/// The id is the entity at the far end of the edge. The two names always follow the edge's
/// stored source -> target direction, whichever side the expanded entity was on.
type Neighbour = (
    i64,    // entity id on the far end of the edge
    String, // name of the edge's source entity
    String, // name of the edge's target entity
    String, // relation label
    f64,    // edge weight
);
24
25#[derive(clap::Args)]
26#[command(after_long_help = "EXAMPLES:\n  \
27    # List memories connected to a memory via the entity graph (default 2 hops)\n  \
28    sqlite-graphrag related onboarding\n\n  \
29    # Increase hop distance and filter by relation type\n  \
30    sqlite-graphrag related onboarding --max-hops 3 --relation related\n\n  \
31    # Cap result count and require minimum edge weight\n  \
32    sqlite-graphrag related onboarding --limit 5 --min-weight 0.5")]
33pub struct RelatedArgs {
34    /// Memory name as a positional argument. Alternative to `--name`.
35    #[arg(
36        value_name = "NAME",
37        conflicts_with = "name",
38        help = "Memory name whose neighbours to traverse; alternative to --name"
39    )]
40    pub name_positional: Option<String>,
41    /// Memory name as a flag. Required when the positional form is absent. Also accepts the alias `--from`.
42    #[arg(long, alias = "from")]
43    pub name: Option<String>,
44    /// Maximum graph hop count. Also accepts the alias `--hops`.
45    #[arg(long, alias = "hops", default_value_t = DEFAULT_MAX_HOPS)]
46    pub max_hops: u32,
47    /// Filter results to a specific relation type. Canonical values:
48    /// applies-to, uses, depends-on, causes, fixes, contradicts, supports,
49    /// follows, related, mentions, replaces, tracked-in.
50    /// Any kebab-case or snake_case string is also accepted as a custom relation.
51    #[arg(long, value_parser = crate::parsers::parse_relation)]
52    pub relation: Option<String>,
53    #[arg(long, default_value_t = DEFAULT_MIN_WEIGHT)]
54    pub min_weight: f64,
55    #[arg(long, default_value_t = DEFAULT_K_RECALL)]
56    pub limit: usize,
57    #[arg(long)]
58    pub namespace: Option<String>,
59    #[arg(long, value_enum, default_value = "json")]
60    pub format: OutputFormat,
61    #[arg(long, hide = true, help = "No-op; JSON is always emitted on stdout")]
62    pub json: bool,
63    #[arg(long, env = "SQLITE_GRAPHRAG_DB_PATH")]
64    pub db: Option<String>,
65}
66
67#[derive(Serialize)]
68struct RelatedResponse {
69    /// Echo of the seed memory name resolved from `--name` or the positional argument.
70    /// Added in v1.0.35 for input transparency in JSON output.
71    name: String,
72    /// Echo of the resolved `--max-hops` value (default 2). Added in v1.0.35.
73    max_hops: u32,
74    results: Vec<RelatedMemory>,
75    elapsed_ms: u64,
76}
77
78#[derive(Serialize, Clone)]
79struct RelatedMemory {
80    memory_id: i64,
81    name: String,
82    namespace: String,
83    #[serde(rename = "type")]
84    memory_type: String,
85    description: String,
86    hop_distance: u32,
87    source_entity: Option<String>,
88    target_entity: Option<String>,
89    relation: Option<String>,
90    weight: Option<f64>,
91}
92
93pub fn run(args: RelatedArgs) -> Result<(), AppError> {
94    let inicio = std::time::Instant::now();
95    let name = args
96        .name_positional
97        .as_deref()
98        .or(args.name.as_deref())
99        .ok_or_else(|| {
100            AppError::Validation(
101                "name required: pass as positional argument or via --name".to_string(),
102            )
103        })?
104        .to_string();
105
106    if name.trim().is_empty() {
107        return Err(AppError::Validation("name must not be empty".to_string()));
108    }
109
110    let namespace = crate::namespace::resolve_namespace(args.namespace.as_deref())?;
111    let paths = AppPaths::resolve(args.db.as_deref())?;
112
113    crate::storage::connection::ensure_db_ready(&paths)?;
114
115    let conn = open_ro(&paths.db)?;
116
117    // Locate the seed: try memory first, fall back to bare entity.
118    let seed = match conn.query_row(
119        "SELECT id FROM memories WHERE namespace = ?1 AND name = ?2 AND deleted_at IS NULL",
120        params![namespace, name],
121        |r| r.get::<_, i64>(0),
122    ) {
123        Ok(id) => SeedKind::Memory(id),
124        Err(rusqlite::Error::QueryReturnedNoRows) => {
125            match crate::storage::entities::find_entity_id(&conn, &namespace, &name)? {
126                Some(id) => SeedKind::Entity(id),
127                None => {
128                    return Err(AppError::NotFound(errors_msg::memory_or_entity_not_found(
129                        &name, &namespace,
130                    )))
131                }
132            }
133        }
134        Err(e) => return Err(AppError::Database(e)),
135    };
136
137    // Collect seed entity IDs depending on seed kind.
138    let (seed_memory_id, seed_entity_ids): (i64, Vec<i64>) = match &seed {
139        SeedKind::Memory(id) => {
140            let mem_id = *id;
141            let mut stmt =
142                conn.prepare_cached("SELECT entity_id FROM memory_entities WHERE memory_id = ?1")?;
143            let rows: Vec<i64> = stmt
144                .query_map(params![mem_id], |r| r.get(0))?
145                .collect::<Result<Vec<i64>, _>>()?;
146            (mem_id, rows)
147        }
148        SeedKind::Entity(entity_id) => {
149            // For a bare entity seed there is no corresponding memory to skip.
150            // Use a sentinel -1 so dedup never matches a real memory_id.
151            (-1, vec![*entity_id])
152        }
153    };
154
155    let relation_filter = args.relation;
156    if let Some(ref r) = relation_filter {
157        crate::parsers::warn_if_non_canonical(r);
158    }
159    let results = traverse_related(
160        &conn,
161        seed_memory_id,
162        &seed_entity_ids,
163        &namespace,
164        args.max_hops,
165        args.min_weight,
166        relation_filter.as_deref(),
167        args.limit,
168    )?;
169
170    match args.format {
171        OutputFormat::Json => output::emit_json(&RelatedResponse {
172            name: name.clone(),
173            max_hops: args.max_hops,
174            results,
175            elapsed_ms: inicio.elapsed().as_millis() as u64,
176        })?,
177        OutputFormat::Text => {
178            for item in &results {
179                if item.description.is_empty() {
180                    output::emit_text(&format!(
181                        "{}. {} ({})",
182                        item.hop_distance, item.name, item.namespace
183                    ));
184                } else {
185                    let preview: String = item
186                        .description
187                        .chars()
188                        .take(TEXT_DESCRIPTION_PREVIEW_LEN)
189                        .collect();
190                    output::emit_text(&format!(
191                        "{}. {} ({}): {}",
192                        item.hop_distance, item.name, item.namespace, preview
193                    ));
194                }
195            }
196        }
197        OutputFormat::Markdown => {
198            for item in &results {
199                if item.description.is_empty() {
200                    output::emit_text(&format!(
201                        "- **{}** ({}) — hop {}",
202                        item.name, item.namespace, item.hop_distance
203                    ));
204                } else {
205                    let preview: String = item
206                        .description
207                        .chars()
208                        .take(TEXT_DESCRIPTION_PREVIEW_LEN)
209                        .collect();
210                    output::emit_text(&format!(
211                        "- **{}** ({}) — hop {}: {}",
212                        item.name, item.namespace, item.hop_distance, preview
213                    ));
214                }
215            }
216        }
217    }
218
219    Ok(())
220}
221
222#[allow(clippy::too_many_arguments)]
223fn traverse_related(
224    conn: &Connection,
225    seed_memory_id: i64,
226    seed_entity_ids: &[i64],
227    namespace: &str,
228    max_hops: u32,
229    min_weight: f64,
230    relation_filter: Option<&str>,
231    limit: usize,
232) -> Result<Vec<RelatedMemory>, AppError> {
233    if seed_entity_ids.is_empty() || max_hops == 0 {
234        return Ok(Vec::new());
235    }
236
237    // BFS over entities keeping track of hop distance and the (source, target, relation, weight)
238    // of the edge that first reached each entity.
239    let mut visited: HashSet<i64> = seed_entity_ids.iter().copied().collect();
240    let mut entity_hop: HashMap<i64, u32> = HashMap::new();
241    for &e in seed_entity_ids {
242        entity_hop.insert(e, 0);
243    }
244    // Per-entity edge info: source_name, target_name, relation, weight (captures the FIRST edge
245    // that reached this entity — equivalent to BFS shortest path recall edge).
246    let mut entity_edge: HashMap<i64, (String, String, String, f64)> = HashMap::new();
247
248    let mut queue: VecDeque<i64> = seed_entity_ids.iter().copied().collect();
249
250    while let Some(current_entity) = queue.pop_front() {
251        let current_hop = *entity_hop.get(&current_entity).unwrap_or(&0);
252        if current_hop >= max_hops {
253            continue;
254        }
255
256        let neighbours =
257            fetch_neighbours(conn, current_entity, namespace, min_weight, relation_filter)?;
258
259        for (neighbour_id, source_name, target_name, relation, weight) in neighbours {
260            if visited.insert(neighbour_id) {
261                entity_hop.insert(neighbour_id, current_hop + 1);
262                entity_edge.insert(neighbour_id, (source_name, target_name, relation, weight));
263                queue.push_back(neighbour_id);
264            }
265        }
266    }
267
268    // For each discovered entity (hop >= 1) find its memories, skipping the seed memory.
269    let mut out: Vec<RelatedMemory> = Vec::new();
270    let mut dedup_ids: HashSet<i64> = HashSet::new();
271    dedup_ids.insert(seed_memory_id);
272
273    // Sort entities by hop ASC, weight DESC so we emit closer entities first.
274    let mut ordered_entities: Vec<(i64, u32)> = entity_hop
275        .iter()
276        .filter(|(id, _)| !seed_entity_ids.contains(id))
277        .map(|(id, hop)| (*id, *hop))
278        .collect();
279    ordered_entities.sort_by(|a, b| {
280        let weight_a = entity_edge.get(&a.0).map(|e| e.3).unwrap_or(0.0);
281        let weight_b = entity_edge.get(&b.0).map(|e| e.3).unwrap_or(0.0);
282        a.1.cmp(&b.1).then_with(|| {
283            weight_b
284                .partial_cmp(&weight_a)
285                .unwrap_or(std::cmp::Ordering::Equal)
286        })
287    });
288
289    for (entity_id, hop) in ordered_entities {
290        let mut stmt = conn.prepare_cached(
291            "SELECT m.id, m.name, m.namespace, m.type, m.description
292             FROM memory_entities me
293             JOIN memories m ON m.id = me.memory_id
294             WHERE me.entity_id = ?1 AND m.deleted_at IS NULL",
295        )?;
296        let rows = stmt
297            .query_map(params![entity_id], |r| {
298                Ok((
299                    r.get::<_, i64>(0)?,
300                    r.get::<_, String>(1)?,
301                    r.get::<_, String>(2)?,
302                    r.get::<_, String>(3)?,
303                    r.get::<_, String>(4)?,
304                ))
305            })?
306            .collect::<Result<Vec<_>, _>>()?;
307
308        for (mid, name, ns, mtype, desc) in rows {
309            if !dedup_ids.insert(mid) {
310                continue;
311            }
312            let edge = entity_edge.get(&entity_id);
313            out.push(RelatedMemory {
314                memory_id: mid,
315                name,
316                namespace: ns,
317                memory_type: mtype,
318                description: desc,
319                hop_distance: hop,
320                source_entity: edge.map(|e| e.0.clone()),
321                target_entity: edge.map(|e| e.1.clone()),
322                relation: edge.map(|e| e.2.clone()),
323                weight: edge.map(|e| e.3),
324            });
325            if out.len() >= limit {
326                return Ok(out);
327            }
328        }
329    }
330
331    Ok(out)
332}
333
334fn fetch_neighbours(
335    conn: &Connection,
336    entity_id: i64,
337    namespace: &str,
338    min_weight: f64,
339    relation_filter: Option<&str>,
340) -> Result<Vec<Neighbour>, AppError> {
341    // Follow edges in both directions: source -> target and target -> source so traversal is
342    // undirected, which is how users typically reason about "related" memories.
343    let base_sql = "\
344        SELECT r.target_id, se.name, te.name, r.relation, r.weight
345        FROM relationships r
346        JOIN entities se ON se.id = r.source_id
347        JOIN entities te ON te.id = r.target_id
348        WHERE r.source_id = ?1 AND r.weight >= ?2 AND r.namespace = ?3";
349
350    let reverse_sql = "\
351        SELECT r.source_id, se.name, te.name, r.relation, r.weight
352        FROM relationships r
353        JOIN entities se ON se.id = r.source_id
354        JOIN entities te ON te.id = r.target_id
355        WHERE r.target_id = ?1 AND r.weight >= ?2 AND r.namespace = ?3";
356
357    let mut results: Vec<Neighbour> = Vec::new();
358
359    let forward_sql = match relation_filter {
360        Some(_) => format!("{base_sql} AND r.relation = ?4"),
361        None => base_sql.to_string(),
362    };
363    let rev_sql = match relation_filter {
364        Some(_) => format!("{reverse_sql} AND r.relation = ?4"),
365        None => reverse_sql.to_string(),
366    };
367
368    let mut stmt = conn.prepare_cached(&forward_sql)?;
369    let rows: Vec<_> = if let Some(rel) = relation_filter {
370        stmt.query_map(params![entity_id, min_weight, namespace, rel], |r| {
371            Ok((
372                r.get::<_, i64>(0)?,
373                r.get::<_, String>(1)?,
374                r.get::<_, String>(2)?,
375                r.get::<_, String>(3)?,
376                r.get::<_, f64>(4)?,
377            ))
378        })?
379        .collect::<Result<Vec<_>, _>>()?
380    } else {
381        stmt.query_map(params![entity_id, min_weight, namespace], |r| {
382            Ok((
383                r.get::<_, i64>(0)?,
384                r.get::<_, String>(1)?,
385                r.get::<_, String>(2)?,
386                r.get::<_, String>(3)?,
387                r.get::<_, f64>(4)?,
388            ))
389        })?
390        .collect::<Result<Vec<_>, _>>()?
391    };
392    results.extend(rows);
393
394    let mut stmt = conn.prepare_cached(&rev_sql)?;
395    let rows: Vec<_> = if let Some(rel) = relation_filter {
396        stmt.query_map(params![entity_id, min_weight, namespace, rel], |r| {
397            Ok((
398                r.get::<_, i64>(0)?,
399                r.get::<_, String>(1)?,
400                r.get::<_, String>(2)?,
401                r.get::<_, String>(3)?,
402                r.get::<_, f64>(4)?,
403            ))
404        })?
405        .collect::<Result<Vec<_>, _>>()?
406    } else {
407        stmt.query_map(params![entity_id, min_weight, namespace], |r| {
408            Ok((
409                r.get::<_, i64>(0)?,
410                r.get::<_, String>(1)?,
411                r.get::<_, String>(2)?,
412                r.get::<_, String>(3)?,
413                r.get::<_, f64>(4)?,
414            ))
415        })?
416        .collect::<Result<Vec<_>, _>>()?
417    };
418    results.extend(rows);
419
420    Ok(results)
421}
422
#[cfg(test)]
mod tests {
    use super::*;

    /// Creates the minimal schema `traverse_related` queries against.
    fn setup_related_db() -> rusqlite::Connection {
        let conn = rusqlite::Connection::open_in_memory().expect("failed to open in-memory db");
        conn.execute_batch(
            "CREATE TABLE memories (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                name TEXT NOT NULL,
                namespace TEXT NOT NULL DEFAULT 'global',
                type TEXT NOT NULL DEFAULT 'fact',
                description TEXT NOT NULL DEFAULT '',
                deleted_at INTEGER
            );
            CREATE TABLE entities (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                namespace TEXT NOT NULL,
                name TEXT NOT NULL
            );
            CREATE TABLE relationships (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                namespace TEXT NOT NULL,
                source_id INTEGER NOT NULL,
                target_id INTEGER NOT NULL,
                relation TEXT NOT NULL DEFAULT 'related_to',
                weight REAL NOT NULL DEFAULT 1.0
            );
            CREATE TABLE memory_entities (
                memory_id INTEGER NOT NULL,
                entity_id INTEGER NOT NULL
            );",
        )
        .expect("failed to create test tables");
        conn
    }

    fn insert_memory(conn: &rusqlite::Connection, name: &str, namespace: &str) -> i64 {
        conn.execute(
            "INSERT INTO memories (name, namespace) VALUES (?1, ?2)",
            rusqlite::params![name, namespace],
        )
        .expect("failed to insert memory");
        conn.last_insert_rowid()
    }

    fn insert_entity(conn: &rusqlite::Connection, name: &str, namespace: &str) -> i64 {
        conn.execute(
            "INSERT INTO entities (name, namespace) VALUES (?1, ?2)",
            rusqlite::params![name, namespace],
        )
        .expect("failed to insert entity");
        conn.last_insert_rowid()
    }

    fn link_memory_entity(conn: &rusqlite::Connection, memory_id: i64, entity_id: i64) {
        conn.execute(
            "INSERT INTO memory_entities (memory_id, entity_id) VALUES (?1, ?2)",
            rusqlite::params![memory_id, entity_id],
        )
        .expect("failed to link memory-entity");
    }

    fn insert_relationship(
        conn: &rusqlite::Connection,
        namespace: &str,
        source_id: i64,
        target_id: i64,
        relation: &str,
        weight: f64,
    ) {
        conn.execute(
            "INSERT INTO relationships (namespace, source_id, target_id, relation, weight)
             VALUES (?1, ?2, ?3, ?4, ?5)",
            rusqlite::params![namespace, source_id, target_id, relation, weight],
        )
        .expect("failed to insert relationship");
    }

    #[test]
    fn related_response_serializes_results_and_elapsed_ms() {
        let resp = RelatedResponse {
            name: "seed-mem".to_string(),
            max_hops: 2,
            results: vec![RelatedMemory {
                memory_id: 1,
                name: "neighbor-mem".to_string(),
                namespace: "global".to_string(),
                memory_type: "document".to_string(),
                description: "desc".to_string(),
                hop_distance: 1,
                source_entity: Some("entity-a".to_string()),
                target_entity: Some("entity-b".to_string()),
                relation: Some("related_to".to_string()),
                weight: Some(0.9),
            }],
            elapsed_ms: 7,
        };

        let json = serde_json::to_value(&resp).expect("serialization failed");
        assert!(json["results"].is_array());
        assert_eq!(json["results"].as_array().unwrap().len(), 1);
        assert_eq!(json["elapsed_ms"], 7u64);
        assert_eq!(json["results"][0]["type"], "document");
        assert_eq!(json["results"][0]["hop_distance"], 1);
    }

    #[test]
    fn traverse_related_returns_empty_without_seed_entities() {
        let conn = setup_related_db();
        let result = traverse_related(&conn, 1, &[], "global", 2, 0.0, None, 10)
            .expect("traverse_related failed");
        assert!(
            result.is_empty(),
            "no seed entities must yield empty results"
        );
    }

    #[test]
    fn traverse_related_returns_empty_with_max_hops_zero() {
        let conn = setup_related_db();
        let mem_id = insert_memory(&conn, "seed-mem", "global");
        let ent_id = insert_entity(&conn, "ent-a", "global");
        link_memory_entity(&conn, mem_id, ent_id);

        let result = traverse_related(&conn, mem_id, &[ent_id], "global", 0, 0.0, None, 10)
            .expect("traverse_related failed");
        assert!(result.is_empty(), "max_hops=0 must return empty");
    }

    #[test]
    fn traverse_related_discovers_neighbor_memory_via_graph() {
        let conn = setup_related_db();

        let seed_id = insert_memory(&conn, "seed-mem", "global");
        let neighbor_id = insert_memory(&conn, "neighbor-mem", "global");
        let ent_a = insert_entity(&conn, "ent-a", "global");
        let ent_b = insert_entity(&conn, "ent-b", "global");

        link_memory_entity(&conn, seed_id, ent_a);
        link_memory_entity(&conn, neighbor_id, ent_b);
        insert_relationship(&conn, "global", ent_a, ent_b, "related_to", 1.0);

        let result = traverse_related(&conn, seed_id, &[ent_a], "global", 2, 0.0, None, 10)
            .expect("traverse_related failed");

        assert_eq!(result.len(), 1, "must find 1 neighboring memory");
        assert_eq!(result[0].name, "neighbor-mem");
        assert_eq!(result[0].hop_distance, 1);
    }

    #[test]
    fn traverse_related_respects_limit() {
        let conn = setup_related_db();

        let seed_id = insert_memory(&conn, "seed", "global");
        let ent_seed = insert_entity(&conn, "ent-seed", "global");
        link_memory_entity(&conn, seed_id, ent_seed);

        for i in 0..5 {
            let mem_id = insert_memory(&conn, &format!("neighbor-{i}"), "global");
            let ent_id = insert_entity(&conn, &format!("ent-{i}"), "global");
            link_memory_entity(&conn, mem_id, ent_id);
            insert_relationship(&conn, "global", ent_seed, ent_id, "related_to", 1.0);
        }

        let result = traverse_related(&conn, seed_id, &[ent_seed], "global", 1, 0.0, None, 3)
            .expect("traverse_related failed");

        assert_eq!(
            result.len(),
            3,
            "limit=3 must cap the five reachable neighbours at exactly 3"
        );
    }

    #[test]
    fn traverse_related_follows_edges_in_reverse_direction() {
        let conn = setup_related_db();

        let seed_id = insert_memory(&conn, "seed-mem", "global");
        let neighbor_id = insert_memory(&conn, "neighbor-mem", "global");
        let ent_a = insert_entity(&conn, "ent-a", "global");
        let ent_b = insert_entity(&conn, "ent-b", "global");

        link_memory_entity(&conn, seed_id, ent_a);
        link_memory_entity(&conn, neighbor_id, ent_b);
        // The edge points *into* the seed entity; traversal must still reach ent-b.
        insert_relationship(&conn, "global", ent_b, ent_a, "depends-on", 0.8);

        let result = traverse_related(&conn, seed_id, &[ent_a], "global", 1, 0.0, None, 10)
            .expect("traverse_related failed");

        assert_eq!(result.len(), 1);
        assert_eq!(result[0].name, "neighbor-mem");
        assert_eq!(result[0].hop_distance, 1);
        // Names keep the stored edge direction even when walked backwards.
        assert_eq!(result[0].source_entity.as_deref(), Some("ent-b"));
        assert_eq!(result[0].target_entity.as_deref(), Some("ent-a"));
        assert_eq!(result[0].relation.as_deref(), Some("depends-on"));
        assert_eq!(result[0].weight, Some(0.8));
    }

    #[test]
    fn traverse_related_reports_multi_hop_distance_and_honours_max_hops() {
        let conn = setup_related_db();

        let seed_id = insert_memory(&conn, "seed", "global");
        let one_hop_id = insert_memory(&conn, "one-hop", "global");
        let two_hop_id = insert_memory(&conn, "two-hop", "global");
        let ent_a = insert_entity(&conn, "ent-a", "global");
        let ent_b = insert_entity(&conn, "ent-b", "global");
        let ent_c = insert_entity(&conn, "ent-c", "global");

        link_memory_entity(&conn, seed_id, ent_a);
        link_memory_entity(&conn, one_hop_id, ent_b);
        link_memory_entity(&conn, two_hop_id, ent_c);
        insert_relationship(&conn, "global", ent_a, ent_b, "uses", 1.0);
        insert_relationship(&conn, "global", ent_b, ent_c, "uses", 1.0);

        let both = traverse_related(&conn, seed_id, &[ent_a], "global", 2, 0.0, None, 10)
            .expect("traverse_related failed");
        let names_and_hops: Vec<(&str, u32)> = both
            .iter()
            .map(|m| (m.name.as_str(), m.hop_distance))
            .collect();
        assert_eq!(names_and_hops, vec![("one-hop", 1), ("two-hop", 2)]);

        let near = traverse_related(&conn, seed_id, &[ent_a], "global", 1, 0.0, None, 10)
            .expect("traverse_related failed");
        assert_eq!(near.len(), 1, "max_hops=1 must stop before ent-c");
        assert_eq!(near[0].name, "one-hop");
    }

    #[test]
    fn traverse_related_applies_relation_filter_and_min_weight() {
        let conn = setup_related_db();

        let seed_id = insert_memory(&conn, "seed", "global");
        let uses_id = insert_memory(&conn, "uses-mem", "global");
        let fixes_id = insert_memory(&conn, "fixes-mem", "global");
        let weak_id = insert_memory(&conn, "weak-mem", "global");
        let ent_seed = insert_entity(&conn, "ent-seed", "global");
        let ent_uses = insert_entity(&conn, "ent-uses", "global");
        let ent_fixes = insert_entity(&conn, "ent-fixes", "global");
        let ent_weak = insert_entity(&conn, "ent-weak", "global");

        link_memory_entity(&conn, seed_id, ent_seed);
        link_memory_entity(&conn, uses_id, ent_uses);
        link_memory_entity(&conn, fixes_id, ent_fixes);
        link_memory_entity(&conn, weak_id, ent_weak);
        insert_relationship(&conn, "global", ent_seed, ent_uses, "uses", 0.9);
        insert_relationship(&conn, "global", ent_seed, ent_fixes, "fixes", 0.9);
        insert_relationship(&conn, "global", ent_seed, ent_weak, "uses", 0.1);

        let filtered =
            traverse_related(&conn, seed_id, &[ent_seed], "global", 1, 0.5, Some("uses"), 10)
                .expect("traverse_related failed");
        assert_eq!(filtered.len(), 1);
        assert_eq!(filtered[0].name, "uses-mem");

        let by_weight = traverse_related(&conn, seed_id, &[ent_seed], "global", 1, 0.5, None, 10)
            .expect("traverse_related failed");
        let mut names: Vec<&str> = by_weight.iter().map(|m| m.name.as_str()).collect();
        names.sort_unstable();
        assert_eq!(names, vec!["fixes-mem", "uses-mem"]);
    }

    #[test]
    fn traverse_related_ignores_relationships_from_other_namespaces() {
        let conn = setup_related_db();

        let seed_id = insert_memory(&conn, "seed", "global");
        let neighbor_id = insert_memory(&conn, "neighbor", "global");
        let ent_a = insert_entity(&conn, "ent-a", "global");
        let ent_b = insert_entity(&conn, "ent-b", "global");

        link_memory_entity(&conn, seed_id, ent_a);
        link_memory_entity(&conn, neighbor_id, ent_b);
        insert_relationship(&conn, "other", ent_a, ent_b, "related", 1.0);

        let result = traverse_related(&conn, seed_id, &[ent_a], "global", 2, 0.0, None, 10)
            .expect("traverse_related failed");
        assert!(result.is_empty(), "edges in another namespace must not be followed");
    }

    #[test]
    fn traverse_related_excludes_seed_and_deduplicates_memories() {
        let conn = setup_related_db();

        let seed_id = insert_memory(&conn, "seed", "global");
        let shared_id = insert_memory(&conn, "shared", "global");
        let ent_a = insert_entity(&conn, "ent-a", "global");
        let ent_b = insert_entity(&conn, "ent-b", "global");
        let ent_c = insert_entity(&conn, "ent-c", "global");

        link_memory_entity(&conn, seed_id, ent_a);
        // The seed is also attached to a neighbour entity; it must never be reported.
        link_memory_entity(&conn, seed_id, ent_b);
        // `shared` hangs off two neighbour entities but must be reported once.
        link_memory_entity(&conn, shared_id, ent_b);
        link_memory_entity(&conn, shared_id, ent_c);
        insert_relationship(&conn, "global", ent_a, ent_b, "related", 1.0);
        insert_relationship(&conn, "global", ent_a, ent_c, "related", 1.0);

        let result = traverse_related(&conn, seed_id, &[ent_a], "global", 1, 0.0, None, 10)
            .expect("traverse_related failed");
        assert_eq!(result.len(), 1);
        assert_eq!(result[0].memory_id, shared_id);
        assert_eq!(result[0].name, "shared");
    }

    #[test]
    fn traverse_related_skips_soft_deleted_memories() {
        let conn = setup_related_db();

        let seed_id = insert_memory(&conn, "seed", "global");
        let neighbor_id = insert_memory(&conn, "neighbor", "global");
        let ent_a = insert_entity(&conn, "ent-a", "global");
        let ent_b = insert_entity(&conn, "ent-b", "global");

        link_memory_entity(&conn, seed_id, ent_a);
        link_memory_entity(&conn, neighbor_id, ent_b);
        insert_relationship(&conn, "global", ent_a, ent_b, "related", 1.0);
        conn.execute(
            "UPDATE memories SET deleted_at = 1 WHERE id = ?1",
            rusqlite::params![neighbor_id],
        )
        .expect("failed to soft-delete memory");

        let result = traverse_related(&conn, seed_id, &[ent_a], "global", 1, 0.0, None, 10)
            .expect("traverse_related failed");
        assert!(result.is_empty(), "soft-deleted memories must be skipped");
    }

    #[test]
    fn related_memory_optional_null_fields_serialized() {
        let mem = RelatedMemory {
            memory_id: 99,
            name: "no-relation".to_string(),
            namespace: "ns".to_string(),
            memory_type: "concept".to_string(),
            description: "".to_string(),
            hop_distance: 2,
            source_entity: None,
            target_entity: None,
            relation: None,
            weight: None,
        };

        let json = serde_json::to_value(&mem).expect("serialization failed");
        assert!(json["source_entity"].is_null());
        assert!(json["target_entity"].is_null());
        assert!(json["relation"].is_null());
        assert!(json["weight"].is_null());
        assert_eq!(json["hop_distance"], 2);
    }
}