Skip to main content

sqlite_graphrag/commands/
related.rs

1//! Handler for the `related` CLI subcommand.
2
3use crate::constants::{
4    DEFAULT_K_RECALL, DEFAULT_MAX_HOPS, DEFAULT_MIN_WEIGHT, TEXT_DESCRIPTION_PREVIEW_LEN,
5};
6use crate::errors::AppError;
7use crate::i18n::errors_msg;
8use crate::output::{self, OutputFormat};
9use crate::paths::AppPaths;
10use crate::storage::connection::open_ro;
11use rusqlite::{params, Connection};
12use serde::Serialize;
13use std::collections::{HashSet, VecDeque};
14
/// Result of resolving the user-supplied seed name for the `related` command.
///
/// Resolution prefers a live memory; only when no memory matches is the name looked up
/// as an entity.
enum SeedKind {
    /// A non-deleted memory, carrying its `memories.id`.
    Memory(i64),
    /// A bare entity id, used when no memory with the seed name exists.
    Entity(i64),
}
20
/// One adjacency row produced by `fetch_neighbours`, in tuple order:
///
/// 0. id of the entity on the far side of the edge (the neighbour);
/// 1. name of the edge's stored source entity;
/// 2. name of the edge's stored target entity;
/// 3. relation label;
/// 4. edge weight.
///
/// Slots 1 and 2 keep the edge's stored orientation even when the edge was walked from
/// target to source.
type Neighbour = (i64, String, String, String, f64);
24
25#[derive(clap::Args)]
26#[command(after_long_help = "EXAMPLES:\n  \
27    # List memories connected to a memory via the entity graph (default 2 hops)\n  \
28    sqlite-graphrag related onboarding\n\n  \
29    # Increase hop distance and filter by relation type\n  \
30    sqlite-graphrag related onboarding --max-hops 3 --relation related\n\n  \
31    # Cap result count and require minimum edge weight\n  \
32    sqlite-graphrag related onboarding --limit 5 --min-weight 0.5")]
33pub struct RelatedArgs {
34    /// Memory name as a positional argument. Alternative to `--name`.
35    #[arg(
36        value_name = "NAME",
37        conflicts_with = "name",
38        help = "Memory name whose neighbours to traverse; alternative to --name"
39    )]
40    pub name_positional: Option<String>,
41    /// Memory name as a flag. Required when the positional form is absent. Also accepts the alias `--from`.
42    #[arg(long, alias = "from")]
43    pub name: Option<String>,
44    /// Maximum graph hop count. Also accepts the alias `--hops`.
45    #[arg(long, alias = "hops", default_value_t = DEFAULT_MAX_HOPS)]
46    pub max_hops: u32,
47    /// Filter results to a specific relation type. Canonical values:
48    /// applies-to, uses, depends-on, causes, fixes, contradicts, supports,
49    /// follows, related, mentions, replaces, tracked-in.
50    /// Any kebab-case or snake_case string is also accepted as a custom relation.
51    #[arg(long, value_parser = crate::parsers::parse_relation)]
52    pub relation: Option<String>,
53    #[arg(long, default_value_t = DEFAULT_MIN_WEIGHT)]
54    pub min_weight: f64,
55    #[arg(long, default_value_t = DEFAULT_K_RECALL)]
56    pub limit: usize,
57    #[arg(long)]
58    pub namespace: Option<String>,
59    #[arg(long, value_enum, default_value = "json")]
60    pub format: OutputFormat,
61    #[arg(long, hide = true, help = "No-op; JSON is always emitted on stdout")]
62    pub json: bool,
63    #[arg(long, env = "SQLITE_GRAPHRAG_DB_PATH")]
64    pub db: Option<String>,
65}
66
67#[derive(Serialize)]
68struct RelatedResponse {
69    /// Echo of the seed memory name resolved from `--name` or the positional argument.
70    /// Added in v1.0.35 for input transparency in JSON output.
71    name: String,
72    /// Echo of the resolved `--max-hops` value (default 2). Added in v1.0.35.
73    max_hops: u32,
74    results: Vec<RelatedMemory>,
75    /// Semantic alias of `results` following the v1.0.66 alias pattern (list has items/memories).
76    related_memories: Vec<RelatedMemory>,
77    elapsed_ms: u64,
78}
79
80#[derive(Serialize, Clone)]
81struct RelatedMemory {
82    memory_id: i64,
83    name: String,
84    namespace: String,
85    #[serde(rename = "type")]
86    memory_type: String,
87    description: String,
88    hop_distance: u32,
89    source_entity: Option<String>,
90    target_entity: Option<String>,
91    /// Alias of `source_entity` for cross-command consistency (graph, link, deep-research use from/to).
92    #[serde(skip_serializing_if = "Option::is_none")]
93    from: Option<String>,
94    /// Alias of `target_entity` for cross-command consistency.
95    #[serde(skip_serializing_if = "Option::is_none")]
96    to: Option<String>,
97    relation: Option<String>,
98    weight: Option<f64>,
99}
100
101pub fn run(args: RelatedArgs) -> Result<(), AppError> {
102    let inicio = std::time::Instant::now();
103    let name = args
104        .name_positional
105        .as_deref()
106        .or(args.name.as_deref())
107        .ok_or_else(|| {
108            AppError::Validation(
109                "name required: pass as positional argument or via --name".to_string(),
110            )
111        })?
112        .to_string();
113
114    if name.trim().is_empty() {
115        return Err(AppError::Validation("name must not be empty".to_string()));
116    }
117
118    let namespace = crate::namespace::resolve_namespace(args.namespace.as_deref())?;
119    let paths = AppPaths::resolve(args.db.as_deref())?;
120
121    crate::storage::connection::ensure_db_ready(&paths)?;
122
123    let conn = open_ro(&paths.db)?;
124
125    // Locate the seed: try memory first, fall back to bare entity.
126    let seed = match conn.query_row(
127        "SELECT id FROM memories WHERE namespace = ?1 AND name = ?2 AND deleted_at IS NULL",
128        params![namespace, name],
129        |r| r.get::<_, i64>(0),
130    ) {
131        Ok(id) => SeedKind::Memory(id),
132        Err(rusqlite::Error::QueryReturnedNoRows) => {
133            match crate::storage::entities::find_entity_id(&conn, &namespace, &name)? {
134                Some(id) => SeedKind::Entity(id),
135                None => {
136                    return Err(AppError::NotFound(errors_msg::memory_or_entity_not_found(
137                        &name, &namespace,
138                    )))
139                }
140            }
141        }
142        Err(e) => return Err(AppError::Database(e)),
143    };
144
145    // Collect seed entity IDs depending on seed kind.
146    let (seed_memory_id, seed_entity_ids): (i64, Vec<i64>) = match &seed {
147        SeedKind::Memory(id) => {
148            let mem_id = *id;
149            let mut stmt =
150                conn.prepare_cached("SELECT entity_id FROM memory_entities WHERE memory_id = ?1")?;
151            let rows: Vec<i64> = stmt
152                .query_map(params![mem_id], |r| r.get(0))?
153                .collect::<Result<Vec<i64>, _>>()?;
154            (mem_id, rows)
155        }
156        SeedKind::Entity(entity_id) => {
157            // For a bare entity seed there is no corresponding memory to skip.
158            // Use a sentinel -1 so dedup never matches a real memory_id.
159            (-1, vec![*entity_id])
160        }
161    };
162
163    let relation_filter = args.relation;
164    if let Some(ref r) = relation_filter {
165        crate::parsers::warn_if_non_canonical(r);
166    }
167    let results = traverse_related(
168        &conn,
169        seed_memory_id,
170        &seed_entity_ids,
171        &namespace,
172        args.max_hops,
173        args.min_weight,
174        relation_filter.as_deref(),
175        args.limit,
176    )?;
177
178    match args.format {
179        OutputFormat::Json => {
180            let related_memories = results.clone();
181            output::emit_json(&RelatedResponse {
182                name: name.clone(),
183                max_hops: args.max_hops,
184                results,
185                related_memories,
186                elapsed_ms: inicio.elapsed().as_millis() as u64,
187            })?;
188        }
189        OutputFormat::Text => {
190            for item in &results {
191                if item.description.is_empty() {
192                    output::emit_text(&format!(
193                        "{}. {} ({})",
194                        item.hop_distance, item.name, item.namespace
195                    ));
196                } else {
197                    let preview: String = item
198                        .description
199                        .chars()
200                        .take(TEXT_DESCRIPTION_PREVIEW_LEN)
201                        .collect();
202                    output::emit_text(&format!(
203                        "{}. {} ({}): {}",
204                        item.hop_distance, item.name, item.namespace, preview
205                    ));
206                }
207            }
208        }
209        OutputFormat::Markdown => {
210            for item in &results {
211                if item.description.is_empty() {
212                    output::emit_text(&format!(
213                        "- **{}** ({}) — hop {}",
214                        item.name, item.namespace, item.hop_distance
215                    ));
216                } else {
217                    let preview: String = item
218                        .description
219                        .chars()
220                        .take(TEXT_DESCRIPTION_PREVIEW_LEN)
221                        .collect();
222                    output::emit_text(&format!(
223                        "- **{}** ({}) — hop {}: {}",
224                        item.name, item.namespace, item.hop_distance, preview
225                    ));
226                }
227            }
228        }
229    }
230
231    Ok(())
232}
233
234#[allow(clippy::too_many_arguments)]
235fn traverse_related(
236    conn: &Connection,
237    seed_memory_id: i64,
238    seed_entity_ids: &[i64],
239    namespace: &str,
240    max_hops: u32,
241    min_weight: f64,
242    relation_filter: Option<&str>,
243    limit: usize,
244) -> Result<Vec<RelatedMemory>, AppError> {
245    if seed_entity_ids.is_empty() || max_hops == 0 {
246        return Ok(Vec::new());
247    }
248
249    // BFS over entities keeping track of hop distance and the (source, target, relation, weight)
250    // of the edge that first reached each entity.
251    let mut visited: HashSet<i64> = seed_entity_ids.iter().copied().collect();
252    let mut entity_hop: crate::hash::AHashMap<i64, u32> =
253        crate::hash::AHashMap::with_capacity_and_hasher(max_hops as usize * 10, Default::default());
254    for &e in seed_entity_ids {
255        entity_hop.insert(e, 0);
256    }
257    // Per-entity edge info: source_name, target_name, relation, weight (captures the FIRST edge
258    // that reached this entity — equivalent to BFS shortest path recall edge).
259    let mut entity_edge: crate::hash::AHashMap<i64, (String, String, String, f64)> =
260        crate::hash::AHashMap::with_capacity_and_hasher(max_hops as usize * 10, Default::default());
261
262    let mut queue: VecDeque<i64> = seed_entity_ids.iter().copied().collect();
263
264    while let Some(current_entity) = queue.pop_front() {
265        let current_hop = *entity_hop.get(&current_entity).unwrap_or(&0);
266        if current_hop >= max_hops {
267            continue;
268        }
269
270        let neighbours =
271            fetch_neighbours(conn, current_entity, namespace, min_weight, relation_filter)?;
272
273        for (neighbour_id, source_name, target_name, relation, weight) in neighbours {
274            if visited.insert(neighbour_id) {
275                entity_hop.insert(neighbour_id, current_hop + 1);
276                entity_edge.insert(neighbour_id, (source_name, target_name, relation, weight));
277                queue.push_back(neighbour_id);
278            }
279        }
280    }
281
282    // For each discovered entity (hop >= 1) find its memories, skipping the seed memory.
283    let mut out: Vec<RelatedMemory> = Vec::with_capacity(limit);
284    let mut dedup_ids: crate::hash::AHashSet<i64> =
285        crate::hash::AHashSet::with_capacity_and_hasher(limit, Default::default());
286    dedup_ids.insert(seed_memory_id);
287
288    // Sort entities by hop ASC, weight DESC so we emit closer entities first.
289    let mut ordered_entities: Vec<(i64, u32)> = entity_hop
290        .iter()
291        .filter(|(id, _)| !seed_entity_ids.contains(id))
292        .map(|(id, hop)| (*id, *hop))
293        .collect();
294    ordered_entities.sort_by(|a, b| {
295        let weight_a = entity_edge.get(&a.0).map(|e| e.3).unwrap_or(0.0);
296        let weight_b = entity_edge.get(&b.0).map(|e| e.3).unwrap_or(0.0);
297        a.1.cmp(&b.1).then_with(|| {
298            weight_b
299                .partial_cmp(&weight_a)
300                .unwrap_or(std::cmp::Ordering::Equal)
301        })
302    });
303
304    for (entity_id, hop) in ordered_entities {
305        let mut stmt = conn.prepare_cached(
306            "SELECT m.id, m.name, m.namespace, m.type, m.description
307             FROM memory_entities me
308             JOIN memories m ON m.id = me.memory_id
309             WHERE me.entity_id = ?1 AND m.deleted_at IS NULL",
310        )?;
311        let rows = stmt
312            .query_map(params![entity_id], |r| {
313                Ok((
314                    r.get::<_, i64>(0)?,
315                    r.get::<_, String>(1)?,
316                    r.get::<_, String>(2)?,
317                    r.get::<_, String>(3)?,
318                    r.get::<_, String>(4)?,
319                ))
320            })?
321            .collect::<Result<Vec<_>, _>>()?;
322
323        for (mid, name, ns, mtype, desc) in rows {
324            if !dedup_ids.insert(mid) {
325                continue;
326            }
327            let edge = entity_edge.get(&entity_id);
328            let src = edge.map(|e| e.0.clone());
329            let tgt = edge.map(|e| e.1.clone());
330            out.push(RelatedMemory {
331                memory_id: mid,
332                name,
333                namespace: ns,
334                memory_type: mtype,
335                description: desc,
336                hop_distance: hop,
337                source_entity: src.clone(),
338                target_entity: tgt.clone(),
339                from: src,
340                to: tgt,
341                relation: edge.map(|e| e.2.clone()),
342                weight: edge.map(|e| e.3),
343            });
344            if out.len() >= limit {
345                return Ok(out);
346            }
347        }
348    }
349    Ok(out)
350}
351
352fn fetch_neighbours(
353    conn: &Connection,
354    entity_id: i64,
355    namespace: &str,
356    min_weight: f64,
357    relation_filter: Option<&str>,
358) -> Result<Vec<Neighbour>, AppError> {
359    // Follow edges in both directions: source -> target and target -> source so traversal is
360    // undirected, which is how users typically reason about "related" memories.
361    let base_sql = "\
362        SELECT r.target_id, se.name, te.name, r.relation, r.weight
363        FROM relationships r
364        JOIN entities se ON se.id = r.source_id
365        JOIN entities te ON te.id = r.target_id
366        WHERE r.source_id = ?1 AND r.weight >= ?2 AND r.namespace = ?3";
367
368    let reverse_sql = "\
369        SELECT r.source_id, se.name, te.name, r.relation, r.weight
370        FROM relationships r
371        JOIN entities se ON se.id = r.source_id
372        JOIN entities te ON te.id = r.target_id
373        WHERE r.target_id = ?1 AND r.weight >= ?2 AND r.namespace = ?3";
374
375    let mut results: Vec<Neighbour> = Vec::with_capacity(16);
376
377    let forward_sql = match relation_filter {
378        Some(_) => format!("{base_sql} AND r.relation = ?4"),
379        None => base_sql.to_string(),
380    };
381    let rev_sql = match relation_filter {
382        Some(_) => format!("{reverse_sql} AND r.relation = ?4"),
383        None => reverse_sql.to_string(),
384    };
385
386    let mut stmt = conn.prepare_cached(&forward_sql)?;
387    let rows: Vec<_> = if let Some(rel) = relation_filter {
388        stmt.query_map(params![entity_id, min_weight, namespace, rel], |r| {
389            Ok((
390                r.get::<_, i64>(0)?,
391                r.get::<_, String>(1)?,
392                r.get::<_, String>(2)?,
393                r.get::<_, String>(3)?,
394                r.get::<_, f64>(4)?,
395            ))
396        })?
397        .collect::<Result<Vec<_>, _>>()?
398    } else {
399        stmt.query_map(params![entity_id, min_weight, namespace], |r| {
400            Ok((
401                r.get::<_, i64>(0)?,
402                r.get::<_, String>(1)?,
403                r.get::<_, String>(2)?,
404                r.get::<_, String>(3)?,
405                r.get::<_, f64>(4)?,
406            ))
407        })?
408        .collect::<Result<Vec<_>, _>>()?
409    };
410    results.extend(rows);
411
412    let mut stmt = conn.prepare_cached(&rev_sql)?;
413    let rows: Vec<_> = if let Some(rel) = relation_filter {
414        stmt.query_map(params![entity_id, min_weight, namespace, rel], |r| {
415            Ok((
416                r.get::<_, i64>(0)?,
417                r.get::<_, String>(1)?,
418                r.get::<_, String>(2)?,
419                r.get::<_, String>(3)?,
420                r.get::<_, f64>(4)?,
421            ))
422        })?
423        .collect::<Result<Vec<_>, _>>()?
424    } else {
425        stmt.query_map(params![entity_id, min_weight, namespace], |r| {
426            Ok((
427                r.get::<_, i64>(0)?,
428                r.get::<_, String>(1)?,
429                r.get::<_, String>(2)?,
430                r.get::<_, String>(3)?,
431                r.get::<_, f64>(4)?,
432            ))
433        })?
434        .collect::<Result<Vec<_>, _>>()?
435    };
436    results.extend(rows);
437
438    Ok(results)
439}
440
#[cfg(test)]
mod tests {
    use super::*;

    /// Creates an in-memory database with the subset of the schema that
    /// `traverse_related` and `fetch_neighbours` query.
    fn setup_related_db() -> rusqlite::Connection {
        let conn = rusqlite::Connection::open_in_memory().expect("failed to open in-memory db");
        conn.execute_batch(
            "CREATE TABLE memories (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                name TEXT NOT NULL,
                namespace TEXT NOT NULL DEFAULT 'global',
                type TEXT NOT NULL DEFAULT 'fact',
                description TEXT NOT NULL DEFAULT '',
                deleted_at INTEGER
            );
            CREATE TABLE entities (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                namespace TEXT NOT NULL,
                name TEXT NOT NULL
            );
            CREATE TABLE relationships (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                namespace TEXT NOT NULL,
                source_id INTEGER NOT NULL,
                target_id INTEGER NOT NULL,
                relation TEXT NOT NULL DEFAULT 'related_to',
                weight REAL NOT NULL DEFAULT 1.0
            );
            CREATE TABLE memory_entities (
                memory_id INTEGER NOT NULL,
                entity_id INTEGER NOT NULL
            );",
        )
        .expect("failed to create test tables");
        conn
    }

    fn insert_memory(conn: &rusqlite::Connection, name: &str, namespace: &str) -> i64 {
        conn.execute(
            "INSERT INTO memories (name, namespace) VALUES (?1, ?2)",
            rusqlite::params![name, namespace],
        )
        .expect("failed to insert memory");
        conn.last_insert_rowid()
    }

    /// Inserts an entity. Argument order is `(name, namespace)`.
    fn insert_entity(conn: &rusqlite::Connection, name: &str, namespace: &str) -> i64 {
        conn.execute(
            "INSERT INTO entities (name, namespace) VALUES (?1, ?2)",
            rusqlite::params![name, namespace],
        )
        .expect("failed to insert entity");
        conn.last_insert_rowid()
    }

    fn link_memory_entity(conn: &rusqlite::Connection, memory_id: i64, entity_id: i64) {
        conn.execute(
            "INSERT INTO memory_entities (memory_id, entity_id) VALUES (?1, ?2)",
            rusqlite::params![memory_id, entity_id],
        )
        .expect("failed to link memory-entity");
    }

    fn insert_relationship(
        conn: &rusqlite::Connection,
        namespace: &str,
        source_id: i64,
        target_id: i64,
        relation: &str,
        weight: f64,
    ) {
        conn.execute(
            "INSERT INTO relationships (namespace, source_id, target_id, relation, weight)
             VALUES (?1, ?2, ?3, ?4, ?5)",
            rusqlite::params![namespace, source_id, target_id, relation, weight],
        )
        .expect("failed to insert relationship");
    }

    /// Result names in output order, for concise assertions.
    fn names(results: &[RelatedMemory]) -> Vec<&str> {
        results.iter().map(|m| m.name.as_str()).collect()
    }

    #[test]
    fn related_response_serializes_results_and_elapsed_ms() {
        let mem = RelatedMemory {
            memory_id: 1,
            name: "neighbor-mem".to_string(),
            namespace: "global".to_string(),
            memory_type: "document".to_string(),
            description: "desc".to_string(),
            hop_distance: 1,
            source_entity: Some("entity-a".to_string()),
            target_entity: Some("entity-b".to_string()),
            from: Some("entity-a".to_string()),
            to: Some("entity-b".to_string()),
            relation: Some("related_to".to_string()),
            weight: Some(0.9),
        };
        let resp = RelatedResponse {
            name: "seed-mem".to_string(),
            max_hops: 2,
            related_memories: vec![mem.clone()],
            results: vec![mem],
            elapsed_ms: 7,
        };
        let json = serde_json::to_value(&resp).expect("serialization failed");
        assert!(json["results"].is_array());
        assert_eq!(json["results"].as_array().unwrap().len(), 1);
        assert_eq!(json["related_memories"].as_array().unwrap().len(), 1);
        assert_eq!(json["name"], "seed-mem");
        assert_eq!(json["max_hops"], 2);
        assert_eq!(json["elapsed_ms"], 7u64);
        assert_eq!(json["results"][0]["type"], "document");
        assert_eq!(json["results"][0]["hop_distance"], 1);
        assert_eq!(json["results"][0]["from"], "entity-a");
        assert_eq!(json["results"][0]["to"], "entity-b");
    }

    #[test]
    fn traverse_related_returns_empty_without_seed_entities() {
        let conn = setup_related_db();
        let result = traverse_related(&conn, 1, &[], "global", 2, 0.0, None, 10)
            .expect("traverse_related failed");
        assert!(result.is_empty());
    }

    #[test]
    fn traverse_related_returns_empty_with_max_hops_zero() {
        let conn = setup_related_db();
        let mem_id = insert_memory(&conn, "seed", "global");
        let ent_id = insert_entity(&conn, "ent", "global");
        let result = traverse_related(&conn, mem_id, &[ent_id], "global", 0, 0.0, None, 10)
            .expect("traverse_related failed");
        assert!(result.is_empty());
    }

    #[test]
    fn traverse_related_discovers_neighbor_memory_via_graph() {
        let conn = setup_related_db();
        let seed_id = insert_memory(&conn, "seed", "global");
        let ent_a = insert_entity(&conn, "ent-a", "global");
        let ent_b = insert_entity(&conn, "ent-b", "global");
        let neighbor_id = insert_memory(&conn, "neighbor", "global");
        link_memory_entity(&conn, seed_id, ent_a);
        link_memory_entity(&conn, neighbor_id, ent_b);
        insert_relationship(&conn, "global", ent_a, ent_b, "related_to", 1.0);
        let result = traverse_related(&conn, seed_id, &[ent_a], "global", 2, 0.0, None, 10)
            .expect("traverse_related failed");
        assert_eq!(result.len(), 1);
        assert_eq!(result[0].name, "neighbor");
        assert_eq!(result[0].source_entity.as_deref(), Some("ent-a"));
        assert_eq!(result[0].target_entity.as_deref(), Some("ent-b"));
    }

    #[test]
    fn traverse_related_follows_edges_in_reverse_direction() {
        let conn = setup_related_db();
        let seed_id = insert_memory(&conn, "seed", "global");
        let ent_a = insert_entity(&conn, "ent-a", "global");
        let ent_b = insert_entity(&conn, "ent-b", "global");
        let neighbor_id = insert_memory(&conn, "neighbor", "global");
        link_memory_entity(&conn, seed_id, ent_a);
        link_memory_entity(&conn, neighbor_id, ent_b);
        // The edge points *into* the seed entity; the traversal is undirected.
        insert_relationship(&conn, "global", ent_b, ent_a, "related_to", 1.0);
        let result = traverse_related(&conn, seed_id, &[ent_a], "global", 1, 0.0, None, 10)
            .expect("traverse_related failed");
        assert_eq!(names(&result), vec!["neighbor"]);
        // Edge names keep their stored orientation.
        assert_eq!(result[0].source_entity.as_deref(), Some("ent-b"));
        assert_eq!(result[0].target_entity.as_deref(), Some("ent-a"));
    }

    #[test]
    fn traverse_related_reports_hop_distance_and_honours_max_hops() {
        let conn = setup_related_db();
        let seed_id = insert_memory(&conn, "seed", "global");
        let ent_a = insert_entity(&conn, "ent-a", "global");
        let ent_b = insert_entity(&conn, "ent-b", "global");
        let ent_c = insert_entity(&conn, "ent-c", "global");
        let far_id = insert_memory(&conn, "far", "global");
        link_memory_entity(&conn, seed_id, ent_a);
        link_memory_entity(&conn, far_id, ent_c);
        insert_relationship(&conn, "global", ent_a, ent_b, "related_to", 1.0);
        insert_relationship(&conn, "global", ent_b, ent_c, "related_to", 1.0);

        let two_hops = traverse_related(&conn, seed_id, &[ent_a], "global", 2, 0.0, None, 10)
            .expect("traverse_related failed");
        assert_eq!(names(&two_hops), vec!["far"]);
        assert_eq!(two_hops[0].hop_distance, 2);

        let one_hop = traverse_related(&conn, seed_id, &[ent_a], "global", 1, 0.0, None, 10)
            .expect("traverse_related failed");
        assert!(one_hop.is_empty(), "got {:?}", names(&one_hop));
    }

    #[test]
    fn traverse_related_applies_min_weight_and_relation_filters() {
        let conn = setup_related_db();
        let seed_id = insert_memory(&conn, "seed", "global");
        let ent_seed = insert_entity(&conn, "ent-seed", "global");
        link_memory_entity(&conn, seed_id, ent_seed);

        let ent_weak = insert_entity(&conn, "ent-weak", "global");
        let weak_id = insert_memory(&conn, "weak", "global");
        link_memory_entity(&conn, weak_id, ent_weak);
        insert_relationship(&conn, "global", ent_seed, ent_weak, "uses", 0.2);

        let ent_strong = insert_entity(&conn, "ent-strong", "global");
        let strong_id = insert_memory(&conn, "strong", "global");
        link_memory_entity(&conn, strong_id, ent_strong);
        insert_relationship(&conn, "global", ent_seed, ent_strong, "related_to", 0.9);

        let heavy = traverse_related(&conn, seed_id, &[ent_seed], "global", 1, 0.5, None, 10)
            .expect("traverse_related failed");
        assert_eq!(names(&heavy), vec!["strong"]);

        let uses_only =
            traverse_related(&conn, seed_id, &[ent_seed], "global", 1, 0.0, Some("uses"), 10)
                .expect("traverse_related failed");
        assert_eq!(names(&uses_only), vec!["weak"]);
        assert_eq!(uses_only[0].relation.as_deref(), Some("uses"));
        assert_eq!(uses_only[0].weight, Some(0.2));
    }

    #[test]
    fn traverse_related_skips_deleted_memories_and_foreign_namespace_edges() {
        let conn = setup_related_db();
        let seed_id = insert_memory(&conn, "seed", "global");
        let ent_seed = insert_entity(&conn, "ent-seed", "global");
        link_memory_entity(&conn, seed_id, ent_seed);

        let ent_deleted = insert_entity(&conn, "ent-deleted", "global");
        let deleted_id = insert_memory(&conn, "deleted", "global");
        link_memory_entity(&conn, deleted_id, ent_deleted);
        insert_relationship(&conn, "global", ent_seed, ent_deleted, "related_to", 1.0);
        conn.execute(
            "UPDATE memories SET deleted_at = 1 WHERE id = ?1",
            rusqlite::params![deleted_id],
        )
        .expect("failed to soft-delete memory");

        let ent_other = insert_entity(&conn, "ent-other", "global");
        let other_id = insert_memory(&conn, "other", "global");
        link_memory_entity(&conn, other_id, ent_other);
        insert_relationship(&conn, "other-ns", ent_seed, ent_other, "related_to", 1.0);

        let result = traverse_related(&conn, seed_id, &[ent_seed], "global", 2, 0.0, None, 10)
            .expect("traverse_related failed");
        assert!(result.is_empty(), "got {:?}", names(&result));
    }

    #[test]
    fn traverse_related_orders_same_hop_results_by_edge_weight() {
        let conn = setup_related_db();
        let seed_id = insert_memory(&conn, "seed", "global");
        let ent_seed = insert_entity(&conn, "ent-seed", "global");
        link_memory_entity(&conn, seed_id, ent_seed);
        for (label, weight) in [("light", 0.3), ("heavy", 0.8), ("medium", 0.5)] {
            let ent_id = insert_entity(&conn, &format!("ent-{label}"), "global");
            let mem_id = insert_memory(&conn, label, "global");
            link_memory_entity(&conn, mem_id, ent_id);
            insert_relationship(&conn, "global", ent_seed, ent_id, "related_to", weight);
        }
        let result = traverse_related(&conn, seed_id, &[ent_seed], "global", 1, 0.0, None, 10)
            .expect("traverse_related failed");
        assert_eq!(names(&result), vec!["heavy", "medium", "light"]);
    }

    #[test]
    fn traverse_related_respects_limit() {
        let conn = setup_related_db();
        let seed_id = insert_memory(&conn, "seed", "global");
        let ent_seed = insert_entity(&conn, "ent-seed", "global");
        link_memory_entity(&conn, seed_id, ent_seed);
        for i in 0..5 {
            let ent_id = insert_entity(&conn, &format!("ent-{i}"), "global");
            let mem_id = insert_memory(&conn, &format!("mem-{i}"), "global");
            link_memory_entity(&conn, mem_id, ent_id);
            insert_relationship(&conn, "global", ent_seed, ent_id, "related_to", 1.0);
        }
        let result = traverse_related(&conn, seed_id, &[ent_seed], "global", 1, 0.0, None, 3)
            .expect("traverse_related failed");
        assert_eq!(
            result.len(),
            3,
            "limit=3 must constrain to at most 3 results"
        );
    }

    #[test]
    fn related_memory_optional_null_fields_serialized() {
        let mem = RelatedMemory {
            memory_id: 99,
            name: "no-relation".to_string(),
            namespace: "ns".to_string(),
            memory_type: "concept".to_string(),
            description: "".to_string(),
            hop_distance: 2,
            source_entity: None,
            target_entity: None,
            from: None,
            to: None,
            relation: None,
            weight: None,
        };
        let json = serde_json::to_value(&mem).expect("serialization failed");
        assert!(json["source_entity"].is_null());
        assert!(json["target_entity"].is_null());
        assert!(json["relation"].is_null());
        assert!(json["weight"].is_null());
        // The from/to aliases are omitted entirely rather than serialized as null.
        assert!(json.get("from").is_none());
        assert!(json.get("to").is_none());
        assert_eq!(json["hop_distance"], 2);
    }
}