// sqlite_graphrag/commands/related.rs

1use crate::cli::RelationKind;
2use crate::constants::{
3    DEFAULT_K_RECALL, DEFAULT_MAX_HOPS, DEFAULT_MIN_WEIGHT, TEXT_DESCRIPTION_PREVIEW_LEN,
4};
5use crate::errors::AppError;
6use crate::i18n::erros;
7use crate::output::{self, OutputFormat};
8use crate::paths::AppPaths;
9use crate::storage::connection::open_ro;
10use rusqlite::{params, Connection};
11use serde::Serialize;
12use std::collections::{HashMap, HashSet, VecDeque};
13
/// One edge produced by the adjacency fetch, laid out as
/// `(neighbour_entity_id, source_name, target_name, relation, weight)`.
type Neighbour = (i64, String, String, String, f64);
17
18#[derive(clap::Args)]
19pub struct RelatedArgs {
20    /// Memory name as a positional argument. Alternative to `--name`.
21    #[arg(value_name = "NAME", conflicts_with = "name")]
22    pub name_positional: Option<String>,
23    /// Memory name as a flag. Required when the positional form is absent.
24    #[arg(long)]
25    pub name: Option<String>,
26    /// Maximum graph hop count. Also accepts the alias `--hops`.
27    #[arg(long, alias = "hops", default_value_t = DEFAULT_MAX_HOPS)]
28    pub max_hops: u32,
29    #[arg(long, value_enum)]
30    pub relation: Option<RelationKind>,
31    #[arg(long, default_value_t = DEFAULT_MIN_WEIGHT)]
32    pub min_weight: f64,
33    #[arg(long, default_value_t = DEFAULT_K_RECALL)]
34    pub limit: usize,
35    #[arg(long)]
36    pub namespace: Option<String>,
37    #[arg(long, value_enum, default_value = "json")]
38    pub format: OutputFormat,
39    #[arg(long, help = "No-op; JSON is always emitted on stdout")]
40    pub json: bool,
41    #[arg(long, env = "SQLITE_GRAPHRAG_DB_PATH")]
42    pub db: Option<String>,
43}
44
45#[derive(Serialize)]
46struct RelatedResponse {
47    results: Vec<RelatedMemory>,
48    elapsed_ms: u64,
49}
50
51#[derive(Serialize, Clone)]
52struct RelatedMemory {
53    memory_id: i64,
54    name: String,
55    namespace: String,
56    #[serde(rename = "type")]
57    memory_type: String,
58    description: String,
59    hop_distance: u32,
60    source_entity: Option<String>,
61    target_entity: Option<String>,
62    relation: Option<String>,
63    weight: Option<f64>,
64}
65
66pub fn run(args: RelatedArgs) -> Result<(), AppError> {
67    let inicio = std::time::Instant::now();
68    let name = args
69        .name_positional
70        .as_deref()
71        .or(args.name.as_deref())
72        .ok_or_else(|| {
73            AppError::Validation(
74                "name required: pass as positional argument or via --name".to_string(),
75            )
76        })?
77        .to_string();
78
79    if name.trim().is_empty() {
80        return Err(AppError::Validation("name must not be empty".to_string()));
81    }
82
83    let namespace = crate::namespace::resolve_namespace(args.namespace.as_deref())?;
84    let paths = AppPaths::resolve(args.db.as_deref())?;
85
86    if !paths.db.exists() {
87        return Err(AppError::NotFound(erros::banco_nao_encontrado(
88            &paths.db.display().to_string(),
89        )));
90    }
91
92    let conn = open_ro(&paths.db)?;
93
94    // Locate the seed memory.
95    let seed_id: i64 = match conn.query_row(
96        "SELECT id FROM memories
97         WHERE namespace = ?1 AND name = ?2 AND deleted_at IS NULL",
98        params![namespace, name],
99        |r| r.get(0),
100    ) {
101        Ok(id) => id,
102        Err(rusqlite::Error::QueryReturnedNoRows) => {
103            return Err(AppError::NotFound(erros::memoria_nao_encontrada(
104                &name, &namespace,
105            )));
106        }
107        Err(e) => return Err(AppError::Database(e)),
108    };
109
110    // Collect seed entity IDs from seed memory.
111    let seed_entity_ids: Vec<i64> = {
112        let mut stmt =
113            conn.prepare_cached("SELECT entity_id FROM memory_entities WHERE memory_id = ?1")?;
114        let rows: Vec<i64> = stmt
115            .query_map(params![seed_id], |r| r.get(0))?
116            .collect::<Result<Vec<i64>, _>>()?;
117        rows
118    };
119
120    let relation_filter = args.relation.map(|r| r.as_str().to_string());
121    let results = traverse_related(
122        &conn,
123        seed_id,
124        &seed_entity_ids,
125        &namespace,
126        args.max_hops,
127        args.min_weight,
128        relation_filter.as_deref(),
129        args.limit,
130    )?;
131
132    match args.format {
133        OutputFormat::Json => output::emit_json(&RelatedResponse {
134            results,
135            elapsed_ms: inicio.elapsed().as_millis() as u64,
136        })?,
137        OutputFormat::Text => {
138            for item in &results {
139                if item.description.is_empty() {
140                    output::emit_text(&format!(
141                        "{}. {} ({})",
142                        item.hop_distance, item.name, item.namespace
143                    ));
144                } else {
145                    let preview: String = item
146                        .description
147                        .chars()
148                        .take(TEXT_DESCRIPTION_PREVIEW_LEN)
149                        .collect();
150                    output::emit_text(&format!(
151                        "{}. {} ({}): {}",
152                        item.hop_distance, item.name, item.namespace, preview
153                    ));
154                }
155            }
156        }
157        OutputFormat::Markdown => {
158            for item in &results {
159                if item.description.is_empty() {
160                    output::emit_text(&format!(
161                        "- **{}** ({}) — hop {}",
162                        item.name, item.namespace, item.hop_distance
163                    ));
164                } else {
165                    let preview: String = item
166                        .description
167                        .chars()
168                        .take(TEXT_DESCRIPTION_PREVIEW_LEN)
169                        .collect();
170                    output::emit_text(&format!(
171                        "- **{}** ({}) — hop {}: {}",
172                        item.name, item.namespace, item.hop_distance, preview
173                    ));
174                }
175            }
176        }
177    }
178
179    Ok(())
180}
181
182#[allow(clippy::too_many_arguments)]
183fn traverse_related(
184    conn: &Connection,
185    seed_memory_id: i64,
186    seed_entity_ids: &[i64],
187    namespace: &str,
188    max_hops: u32,
189    min_weight: f64,
190    relation_filter: Option<&str>,
191    limit: usize,
192) -> Result<Vec<RelatedMemory>, AppError> {
193    if seed_entity_ids.is_empty() || max_hops == 0 {
194        return Ok(Vec::new());
195    }
196
197    // BFS over entities keeping track of hop distance and the (source, target, relation, weight)
198    // of the edge that first reached each entity.
199    let mut visited: HashSet<i64> = seed_entity_ids.iter().copied().collect();
200    let mut entity_hop: HashMap<i64, u32> = HashMap::new();
201    for &e in seed_entity_ids {
202        entity_hop.insert(e, 0);
203    }
204    // Per-entity edge info: source_name, target_name, relation, weight (captures the FIRST edge
205    // that reached this entity — equivalent to BFS shortest path recall edge).
206    let mut entity_edge: HashMap<i64, (String, String, String, f64)> = HashMap::new();
207
208    let mut queue: VecDeque<i64> = seed_entity_ids.iter().copied().collect();
209
210    while let Some(current_entity) = queue.pop_front() {
211        let current_hop = *entity_hop.get(&current_entity).unwrap_or(&0);
212        if current_hop >= max_hops {
213            continue;
214        }
215
216        let neighbours =
217            fetch_neighbours(conn, current_entity, namespace, min_weight, relation_filter)?;
218
219        for (neighbour_id, source_name, target_name, relation, weight) in neighbours {
220            if visited.insert(neighbour_id) {
221                entity_hop.insert(neighbour_id, current_hop + 1);
222                entity_edge.insert(neighbour_id, (source_name, target_name, relation, weight));
223                queue.push_back(neighbour_id);
224            }
225        }
226    }
227
228    // For each discovered entity (hop >= 1) find its memories, skipping the seed memory.
229    let mut out: Vec<RelatedMemory> = Vec::new();
230    let mut dedup_ids: HashSet<i64> = HashSet::new();
231    dedup_ids.insert(seed_memory_id);
232
233    // Sort entities by hop ASC, weight DESC so we emit closer entities first.
234    let mut ordered_entities: Vec<(i64, u32)> = entity_hop
235        .iter()
236        .filter(|(id, _)| !seed_entity_ids.contains(id))
237        .map(|(id, hop)| (*id, *hop))
238        .collect();
239    ordered_entities.sort_by(|a, b| {
240        let weight_a = entity_edge.get(&a.0).map(|e| e.3).unwrap_or(0.0);
241        let weight_b = entity_edge.get(&b.0).map(|e| e.3).unwrap_or(0.0);
242        a.1.cmp(&b.1).then_with(|| {
243            weight_b
244                .partial_cmp(&weight_a)
245                .unwrap_or(std::cmp::Ordering::Equal)
246        })
247    });
248
249    for (entity_id, hop) in ordered_entities {
250        let mut stmt = conn.prepare_cached(
251            "SELECT m.id, m.name, m.namespace, m.type, m.description
252             FROM memory_entities me
253             JOIN memories m ON m.id = me.memory_id
254             WHERE me.entity_id = ?1 AND m.deleted_at IS NULL",
255        )?;
256        let rows = stmt
257            .query_map(params![entity_id], |r| {
258                Ok((
259                    r.get::<_, i64>(0)?,
260                    r.get::<_, String>(1)?,
261                    r.get::<_, String>(2)?,
262                    r.get::<_, String>(3)?,
263                    r.get::<_, String>(4)?,
264                ))
265            })?
266            .collect::<Result<Vec<_>, _>>()?;
267
268        for (mid, name, ns, mtype, desc) in rows {
269            if !dedup_ids.insert(mid) {
270                continue;
271            }
272            let edge = entity_edge.get(&entity_id);
273            out.push(RelatedMemory {
274                memory_id: mid,
275                name,
276                namespace: ns,
277                memory_type: mtype,
278                description: desc,
279                hop_distance: hop,
280                source_entity: edge.map(|e| e.0.clone()),
281                target_entity: edge.map(|e| e.1.clone()),
282                relation: edge.map(|e| e.2.clone()),
283                weight: edge.map(|e| e.3),
284            });
285            if out.len() >= limit {
286                return Ok(out);
287            }
288        }
289    }
290
291    Ok(out)
292}
293
294fn fetch_neighbours(
295    conn: &Connection,
296    entity_id: i64,
297    namespace: &str,
298    min_weight: f64,
299    relation_filter: Option<&str>,
300) -> Result<Vec<Neighbour>, AppError> {
301    // Follow edges in both directions: source -> target and target -> source so traversal is
302    // undirected, which is how users typically reason about "related" memories.
303    let base_sql = "\
304        SELECT r.target_id, se.name, te.name, r.relation, r.weight
305        FROM relationships r
306        JOIN entities se ON se.id = r.source_id
307        JOIN entities te ON te.id = r.target_id
308        WHERE r.source_id = ?1 AND r.weight >= ?2 AND r.namespace = ?3";
309
310    let reverse_sql = "\
311        SELECT r.source_id, se.name, te.name, r.relation, r.weight
312        FROM relationships r
313        JOIN entities se ON se.id = r.source_id
314        JOIN entities te ON te.id = r.target_id
315        WHERE r.target_id = ?1 AND r.weight >= ?2 AND r.namespace = ?3";
316
317    let mut results: Vec<Neighbour> = Vec::new();
318
319    let forward_sql = match relation_filter {
320        Some(_) => format!("{base_sql} AND r.relation = ?4"),
321        None => base_sql.to_string(),
322    };
323    let rev_sql = match relation_filter {
324        Some(_) => format!("{reverse_sql} AND r.relation = ?4"),
325        None => reverse_sql.to_string(),
326    };
327
328    let mut stmt = conn.prepare_cached(&forward_sql)?;
329    let rows: Vec<_> = if let Some(rel) = relation_filter {
330        stmt.query_map(params![entity_id, min_weight, namespace, rel], |r| {
331            Ok((
332                r.get::<_, i64>(0)?,
333                r.get::<_, String>(1)?,
334                r.get::<_, String>(2)?,
335                r.get::<_, String>(3)?,
336                r.get::<_, f64>(4)?,
337            ))
338        })?
339        .collect::<Result<Vec<_>, _>>()?
340    } else {
341        stmt.query_map(params![entity_id, min_weight, namespace], |r| {
342            Ok((
343                r.get::<_, i64>(0)?,
344                r.get::<_, String>(1)?,
345                r.get::<_, String>(2)?,
346                r.get::<_, String>(3)?,
347                r.get::<_, f64>(4)?,
348            ))
349        })?
350        .collect::<Result<Vec<_>, _>>()?
351    };
352    results.extend(rows);
353
354    let mut stmt = conn.prepare_cached(&rev_sql)?;
355    let rows: Vec<_> = if let Some(rel) = relation_filter {
356        stmt.query_map(params![entity_id, min_weight, namespace, rel], |r| {
357            Ok((
358                r.get::<_, i64>(0)?,
359                r.get::<_, String>(1)?,
360                r.get::<_, String>(2)?,
361                r.get::<_, String>(3)?,
362                r.get::<_, f64>(4)?,
363            ))
364        })?
365        .collect::<Result<Vec<_>, _>>()?
366    } else {
367        stmt.query_map(params![entity_id, min_weight, namespace], |r| {
368            Ok((
369                r.get::<_, i64>(0)?,
370                r.get::<_, String>(1)?,
371                r.get::<_, String>(2)?,
372                r.get::<_, String>(3)?,
373                r.get::<_, f64>(4)?,
374            ))
375        })?
376        .collect::<Result<Vec<_>, _>>()?
377    };
378    results.extend(rows);
379
380    Ok(results)
381}
382
/// Unit tests for response serialization and for the graph traversal, run against an in-memory
/// database that contains only the tables and columns the traversal queries touch.
#[cfg(test)]
mod testes {
    use super::*;

    /// Opens an in-memory database with the minimal schema used by `traverse_related` and
    /// `fetch_neighbours`.
    fn setup_related_db() -> rusqlite::Connection {
        let conn = rusqlite::Connection::open_in_memory().expect("falha ao abrir banco em memória");
        conn.execute_batch(
            "CREATE TABLE memories (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                name TEXT NOT NULL,
                namespace TEXT NOT NULL DEFAULT 'global',
                type TEXT NOT NULL DEFAULT 'fact',
                description TEXT NOT NULL DEFAULT '',
                deleted_at INTEGER
            );
            CREATE TABLE entities (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                namespace TEXT NOT NULL,
                name TEXT NOT NULL
            );
            CREATE TABLE relationships (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                namespace TEXT NOT NULL,
                source_id INTEGER NOT NULL,
                target_id INTEGER NOT NULL,
                relation TEXT NOT NULL DEFAULT 'related_to',
                weight REAL NOT NULL DEFAULT 1.0
            );
            CREATE TABLE memory_entities (
                memory_id INTEGER NOT NULL,
                entity_id INTEGER NOT NULL
            );",
        )
        .expect("falha ao criar tabelas de teste");
        conn
    }

    /// Inserts a memory and returns its row id.
    fn insert_memory(conn: &rusqlite::Connection, name: &str, namespace: &str) -> i64 {
        conn.execute(
            "INSERT INTO memories (name, namespace) VALUES (?1, ?2)",
            rusqlite::params![name, namespace],
        )
        .expect("falha ao inserir memória");
        conn.last_insert_rowid()
    }

    /// Inserts an entity and returns its row id.
    fn insert_entity(conn: &rusqlite::Connection, name: &str, namespace: &str) -> i64 {
        conn.execute(
            "INSERT INTO entities (name, namespace) VALUES (?1, ?2)",
            rusqlite::params![name, namespace],
        )
        .expect("falha ao inserir entidade");
        conn.last_insert_rowid()
    }

    /// Attaches an entity to a memory.
    fn link_memory_entity(conn: &rusqlite::Connection, memory_id: i64, entity_id: i64) {
        conn.execute(
            "INSERT INTO memory_entities (memory_id, entity_id) VALUES (?1, ?2)",
            rusqlite::params![memory_id, entity_id],
        )
        .expect("falha ao vincular memória-entidade");
    }

    /// Inserts a directed relationship `source_id -> target_id`.
    fn insert_relationship(
        conn: &rusqlite::Connection,
        namespace: &str,
        source_id: i64,
        target_id: i64,
        relation: &str,
        weight: f64,
    ) {
        conn.execute(
            "INSERT INTO relationships (namespace, source_id, target_id, relation, weight)
             VALUES (?1, ?2, ?3, ?4, ?5)",
            rusqlite::params![namespace, source_id, target_id, relation, weight],
        )
        .expect("falha ao inserir relacionamento");
    }

    /// The response serializes `results` as an array plus `elapsed_ms`, and the memory type is
    /// exposed under the key `type`.
    #[test]
    fn related_response_serializa_results_e_elapsed_ms() {
        let resp = RelatedResponse {
            results: vec![RelatedMemory {
                memory_id: 1,
                name: "mem-vizinha".to_string(),
                namespace: "global".to_string(),
                memory_type: "fact".to_string(),
                description: "desc".to_string(),
                hop_distance: 1,
                source_entity: Some("entidade-a".to_string()),
                target_entity: Some("entidade-b".to_string()),
                relation: Some("related_to".to_string()),
                weight: Some(0.9),
            }],
            elapsed_ms: 7,
        };

        let json = serde_json::to_value(&resp).expect("serialização falhou");
        assert!(json["results"].is_array());
        assert_eq!(json["results"].as_array().unwrap().len(), 1);
        assert_eq!(json["elapsed_ms"], 7u64);
        assert_eq!(json["results"][0]["type"], "fact");
        assert_eq!(json["results"][0]["hop_distance"], 1);
    }

    /// No seed entities means there is nothing to traverse from.
    #[test]
    fn traverse_related_retorna_vazio_sem_entidades_seed() {
        let conn = setup_related_db();
        let resultado = traverse_related(&conn, 1, &[], "global", 2, 0.0, None, 10)
            .expect("traverse_related falhou");
        assert!(
            resultado.is_empty(),
            "sem entidades seed deve retornar vazio"
        );
    }

    /// `max_hops = 0` disables traversal entirely.
    #[test]
    fn traverse_related_retorna_vazio_com_max_hops_zero() {
        let conn = setup_related_db();
        let mem_id = insert_memory(&conn, "seed-mem", "global");
        let ent_id = insert_entity(&conn, "ent-a", "global");
        link_memory_entity(&conn, mem_id, ent_id);

        let resultado = traverse_related(&conn, mem_id, &[ent_id], "global", 0, 0.0, None, 10)
            .expect("traverse_related falhou");
        assert!(resultado.is_empty(), "max_hops=0 deve retornar vazio");
    }

    /// A memory attached to an entity one outgoing edge away from the seed entity is found at
    /// hop 1.
    #[test]
    fn traverse_related_descobre_memoria_vizinha_por_grafo() {
        let conn = setup_related_db();

        let seed_id = insert_memory(&conn, "seed-mem", "global");
        let vizinha_id = insert_memory(&conn, "vizinha-mem", "global");
        let ent_a = insert_entity(&conn, "ent-a", "global");
        let ent_b = insert_entity(&conn, "ent-b", "global");

        link_memory_entity(&conn, seed_id, ent_a);
        link_memory_entity(&conn, vizinha_id, ent_b);
        insert_relationship(&conn, "global", ent_a, ent_b, "related_to", 1.0);

        let resultado = traverse_related(&conn, seed_id, &[ent_a], "global", 2, 0.0, None, 10)
            .expect("traverse_related falhou");

        assert_eq!(resultado.len(), 1, "deve encontrar 1 memória vizinha");
        assert_eq!(resultado[0].name, "vizinha-mem");
        assert_eq!(resultado[0].hop_distance, 1);
    }

    /// Traversal is undirected: an edge pointing AT the seed entity is followed as well, and the
    /// reported source/target keep the stored direction of the edge.
    #[test]
    fn traverse_related_segue_aresta_no_sentido_inverso() {
        let conn = setup_related_db();

        let seed_id = insert_memory(&conn, "seed-mem", "global");
        let vizinha_id = insert_memory(&conn, "vizinha-mem", "global");
        let ent_a = insert_entity(&conn, "ent-a", "global");
        let ent_b = insert_entity(&conn, "ent-b", "global");

        link_memory_entity(&conn, seed_id, ent_a);
        link_memory_entity(&conn, vizinha_id, ent_b);
        insert_relationship(&conn, "global", ent_b, ent_a, "related_to", 1.0);

        let resultado = traverse_related(&conn, seed_id, &[ent_a], "global", 1, 0.0, None, 10)
            .expect("traverse_related falhou");

        assert_eq!(
            resultado.len(),
            1,
            "aresta reversa deve alcançar a memória vizinha"
        );
        assert_eq!(resultado[0].memory_id, vizinha_id);
        assert_eq!(resultado[0].hop_distance, 1);
        assert_eq!(resultado[0].source_entity.as_deref(), Some("ent-b"));
        assert_eq!(resultado[0].target_entity.as_deref(), Some("ent-a"));
    }

    /// With five neighbours available and `limit = 3`, exactly three memories come back. An
    /// upper-bound-only check would also pass on an empty result, so the count is pinned.
    #[test]
    fn traverse_related_respeita_limite() {
        let conn = setup_related_db();

        let seed_id = insert_memory(&conn, "seed", "global");
        let ent_seed = insert_entity(&conn, "ent-seed", "global");
        link_memory_entity(&conn, seed_id, ent_seed);

        for i in 0..5 {
            let mem_id = insert_memory(&conn, &format!("vizinha-{i}"), "global");
            let ent_id = insert_entity(&conn, &format!("ent-{i}"), "global");
            link_memory_entity(&conn, mem_id, ent_id);
            insert_relationship(&conn, "global", ent_seed, ent_id, "related_to", 1.0);
        }

        let resultado = traverse_related(&conn, seed_id, &[ent_seed], "global", 1, 0.0, None, 3)
            .expect("traverse_related falhou");

        assert_eq!(
            resultado.len(),
            3,
            "limite=3 com 5 vizinhas deve retornar exatamente 3 resultados"
        );
        assert!(
            resultado.iter().all(|m| m.hop_distance == 1),
            "todas as vizinhas estão a 1 salto da seed"
        );
    }

    /// Absent edge information serializes as JSON `null` rather than being omitted.
    #[test]
    fn related_memory_campos_opcionais_nulos_serializados() {
        let mem = RelatedMemory {
            memory_id: 99,
            name: "sem-relacao".to_string(),
            namespace: "ns".to_string(),
            memory_type: "concept".to_string(),
            description: "".to_string(),
            hop_distance: 2,
            source_entity: None,
            target_entity: None,
            relation: None,
            weight: None,
        };

        let json = serde_json::to_value(&mem).expect("serialização falhou");
        assert!(json["source_entity"].is_null());
        assert!(json["target_entity"].is_null());
        assert!(json["relation"].is_null());
        assert!(json["weight"].is_null());
        assert_eq!(json["hop_distance"], 2);
    }
}