Skip to main content

sqlite_graphrag/commands/
related.rs

1use crate::cli::RelationKind;
2use crate::constants::{
3    DEFAULT_K_RECALL, DEFAULT_MAX_HOPS, DEFAULT_MIN_WEIGHT, TEXT_DESCRIPTION_PREVIEW_LEN,
4};
5use crate::errors::AppError;
6use crate::i18n::erros;
7use crate::output::{self, OutputFormat};
8use crate::paths::AppPaths;
9use crate::storage::connection::open_ro;
10use rusqlite::{params, Connection};
11use serde::Serialize;
12use std::collections::{HashMap, HashSet, VecDeque};
13
/// One row produced by `fetch_neighbours`, in column order:
/// `(neighbour_entity_id, source_name, target_name, relation, weight)`.
///
/// `source_name` / `target_name` are the names of the edge's own endpoints as stored in
/// `relationships`, so for an edge followed backwards the neighbour is the *source*.
type Neighbour = (i64, String, String, String, f64);
17
// Command-line arguments of the `related` command.
//
// NOTE: with clap's derive API the `///` comments on fields are surfaced as `--help` text,
// so the existing ones are user-facing strings and are intentionally left as written.
// Maintainer notes below use plain `//` comments so they never leak into the CLI help.
#[derive(clap::Args)]
pub struct RelatedArgs {
    /// Nome da memória (posicional). Alternativa a `--name` para compatibilidade com doc bilíngue.
    // Seed memory name given positionally; mutually exclusive with `--name`.
    #[arg(value_name = "NAME", conflicts_with = "name")]
    pub name_positional: Option<String>,
    /// Nome da memória (flag). Obrigatório se posicional não fornecido.
    // Seed memory name given as `--name`; `run` rejects the call when both forms are absent.
    #[arg(long)]
    pub name: Option<String>,
    /// Número máximo de hops no grafo. Aceita alias `--hops` para compatibilidade com doc bilíngue.
    // Upper bound on graph distance from the seed entities; `--hops` is accepted as an alias.
    #[arg(long, alias = "hops", default_value_t = DEFAULT_MAX_HOPS)]
    pub max_hops: u32,
    // Optional relation kind; when set, only edges with that relation are followed.
    #[arg(long, value_enum)]
    pub relation: Option<RelationKind>,
    // Edges whose weight is below this value are not followed (`r.weight >= ?` in the SQL).
    #[arg(long, default_value_t = DEFAULT_MIN_WEIGHT)]
    pub min_weight: f64,
    // Maximum number of related memories returned.
    #[arg(long, default_value_t = DEFAULT_K_RECALL)]
    pub limit: usize,
    // Namespace override; passed to `crate::namespace::resolve_namespace` by `run`.
    #[arg(long)]
    pub namespace: Option<String>,
    // Output format; defaults to JSON.
    #[arg(long, value_enum, default_value = "json")]
    pub format: OutputFormat,
    // Hidden flag that `run` never reads; the output format is chosen by `--format` alone.
    #[arg(long, hide = true, help = "No-op; JSON is always emitted on stdout")]
    pub json: bool,
    // Database path override, also read from the `SQLITE_GRAPHRAG_DB_PATH` environment variable.
    #[arg(long, env = "SQLITE_GRAPHRAG_DB_PATH")]
    pub db: Option<String>,
}
44
/// JSON envelope emitted by `run` when the output format is JSON.
#[derive(Serialize)]
struct RelatedResponse {
    /// Related memories in the order produced by `traverse_related`.
    results: Vec<RelatedMemory>,
    /// Wall-clock time measured from the start of `run`, in milliseconds.
    elapsed_ms: u64,
}
50
/// One memory reached through the entity graph, together with the edge that first led to
/// the entity it is attached to.
#[derive(Serialize, Clone)]
struct RelatedMemory {
    /// Row id of the memory in the `memories` table.
    memory_id: i64,
    name: String,
    namespace: String,
    /// Value of the `type` column; serialized under the key `"type"`.
    #[serde(rename = "type")]
    memory_type: String,
    description: String,
    /// Number of edges between a seed entity and the entity this memory was found through.
    hop_distance: u32,
    // The four fields below describe the edge recorded for that entity during traversal;
    // they serialize as `null` when no edge was recorded.
    source_entity: Option<String>,
    target_entity: Option<String>,
    relation: Option<String>,
    weight: Option<f64>,
}
65
66pub fn run(args: RelatedArgs) -> Result<(), AppError> {
67    let inicio = std::time::Instant::now();
68    let name = args
69        .name_positional
70        .as_deref()
71        .or(args.name.as_deref())
72        .ok_or_else(|| {
73            AppError::Validation(
74                "name required: pass as positional argument or via --name".to_string(),
75            )
76        })?
77        .to_string();
78
79    let namespace = crate::namespace::resolve_namespace(args.namespace.as_deref())?;
80    let paths = AppPaths::resolve(args.db.as_deref())?;
81
82    if !paths.db.exists() {
83        return Err(AppError::NotFound(erros::banco_nao_encontrado(
84            &paths.db.display().to_string(),
85        )));
86    }
87
88    let conn = open_ro(&paths.db)?;
89
90    // Locate the seed memory.
91    let seed_id: i64 = match conn.query_row(
92        "SELECT id FROM memories
93         WHERE namespace = ?1 AND name = ?2 AND deleted_at IS NULL",
94        params![namespace, name],
95        |r| r.get(0),
96    ) {
97        Ok(id) => id,
98        Err(rusqlite::Error::QueryReturnedNoRows) => {
99            return Err(AppError::NotFound(erros::memoria_nao_encontrada(
100                &name, &namespace,
101            )));
102        }
103        Err(e) => return Err(AppError::Database(e)),
104    };
105
106    // Collect seed entity IDs from seed memory.
107    let seed_entity_ids: Vec<i64> = {
108        let mut stmt =
109            conn.prepare_cached("SELECT entity_id FROM memory_entities WHERE memory_id = ?1")?;
110        let rows: Vec<i64> = stmt
111            .query_map(params![seed_id], |r| r.get(0))?
112            .collect::<Result<Vec<i64>, _>>()?;
113        rows
114    };
115
116    let relation_filter = args.relation.map(|r| r.as_str().to_string());
117    let results = traverse_related(
118        &conn,
119        seed_id,
120        &seed_entity_ids,
121        &namespace,
122        args.max_hops,
123        args.min_weight,
124        relation_filter.as_deref(),
125        args.limit,
126    )?;
127
128    match args.format {
129        OutputFormat::Json => output::emit_json(&RelatedResponse {
130            results,
131            elapsed_ms: inicio.elapsed().as_millis() as u64,
132        })?,
133        OutputFormat::Text => {
134            for item in &results {
135                if item.description.is_empty() {
136                    output::emit_text(&format!(
137                        "{}. {} ({})",
138                        item.hop_distance, item.name, item.namespace
139                    ));
140                } else {
141                    let preview: String = item
142                        .description
143                        .chars()
144                        .take(TEXT_DESCRIPTION_PREVIEW_LEN)
145                        .collect();
146                    output::emit_text(&format!(
147                        "{}. {} ({}): {}",
148                        item.hop_distance, item.name, item.namespace, preview
149                    ));
150                }
151            }
152        }
153        OutputFormat::Markdown => {
154            for item in &results {
155                if item.description.is_empty() {
156                    output::emit_text(&format!(
157                        "- **{}** ({}) — hop {}",
158                        item.name, item.namespace, item.hop_distance
159                    ));
160                } else {
161                    let preview: String = item
162                        .description
163                        .chars()
164                        .take(TEXT_DESCRIPTION_PREVIEW_LEN)
165                        .collect();
166                    output::emit_text(&format!(
167                        "- **{}** ({}) — hop {}: {}",
168                        item.name, item.namespace, item.hop_distance, preview
169                    ));
170                }
171            }
172        }
173    }
174
175    Ok(())
176}
177
178#[allow(clippy::too_many_arguments)]
179fn traverse_related(
180    conn: &Connection,
181    seed_memory_id: i64,
182    seed_entity_ids: &[i64],
183    namespace: &str,
184    max_hops: u32,
185    min_weight: f64,
186    relation_filter: Option<&str>,
187    limit: usize,
188) -> Result<Vec<RelatedMemory>, AppError> {
189    if seed_entity_ids.is_empty() || max_hops == 0 {
190        return Ok(Vec::new());
191    }
192
193    // BFS over entities keeping track of hop distance and the (source, target, relation, weight)
194    // of the edge that first reached each entity.
195    let mut visited: HashSet<i64> = seed_entity_ids.iter().copied().collect();
196    let mut entity_hop: HashMap<i64, u32> = HashMap::new();
197    for &e in seed_entity_ids {
198        entity_hop.insert(e, 0);
199    }
200    // Per-entity edge info: source_name, target_name, relation, weight (captures the FIRST edge
201    // that reached this entity — equivalent to BFS shortest path recall edge).
202    let mut entity_edge: HashMap<i64, (String, String, String, f64)> = HashMap::new();
203
204    let mut queue: VecDeque<i64> = seed_entity_ids.iter().copied().collect();
205
206    while let Some(current_entity) = queue.pop_front() {
207        let current_hop = *entity_hop.get(&current_entity).unwrap_or(&0);
208        if current_hop >= max_hops {
209            continue;
210        }
211
212        let neighbours =
213            fetch_neighbours(conn, current_entity, namespace, min_weight, relation_filter)?;
214
215        for (neighbour_id, source_name, target_name, relation, weight) in neighbours {
216            if visited.insert(neighbour_id) {
217                entity_hop.insert(neighbour_id, current_hop + 1);
218                entity_edge.insert(neighbour_id, (source_name, target_name, relation, weight));
219                queue.push_back(neighbour_id);
220            }
221        }
222    }
223
224    // For each discovered entity (hop >= 1) find its memories, skipping the seed memory.
225    let mut out: Vec<RelatedMemory> = Vec::new();
226    let mut dedup_ids: HashSet<i64> = HashSet::new();
227    dedup_ids.insert(seed_memory_id);
228
229    // Sort entities by hop ASC, weight DESC so we emit closer entities first.
230    let mut ordered_entities: Vec<(i64, u32)> = entity_hop
231        .iter()
232        .filter(|(id, _)| !seed_entity_ids.contains(id))
233        .map(|(id, hop)| (*id, *hop))
234        .collect();
235    ordered_entities.sort_by(|a, b| {
236        let weight_a = entity_edge.get(&a.0).map(|e| e.3).unwrap_or(0.0);
237        let weight_b = entity_edge.get(&b.0).map(|e| e.3).unwrap_or(0.0);
238        a.1.cmp(&b.1).then_with(|| {
239            weight_b
240                .partial_cmp(&weight_a)
241                .unwrap_or(std::cmp::Ordering::Equal)
242        })
243    });
244
245    for (entity_id, hop) in ordered_entities {
246        let mut stmt = conn.prepare_cached(
247            "SELECT m.id, m.name, m.namespace, m.type, m.description
248             FROM memory_entities me
249             JOIN memories m ON m.id = me.memory_id
250             WHERE me.entity_id = ?1 AND m.deleted_at IS NULL",
251        )?;
252        let rows = stmt
253            .query_map(params![entity_id], |r| {
254                Ok((
255                    r.get::<_, i64>(0)?,
256                    r.get::<_, String>(1)?,
257                    r.get::<_, String>(2)?,
258                    r.get::<_, String>(3)?,
259                    r.get::<_, String>(4)?,
260                ))
261            })?
262            .collect::<Result<Vec<_>, _>>()?;
263
264        for (mid, name, ns, mtype, desc) in rows {
265            if !dedup_ids.insert(mid) {
266                continue;
267            }
268            let edge = entity_edge.get(&entity_id);
269            out.push(RelatedMemory {
270                memory_id: mid,
271                name,
272                namespace: ns,
273                memory_type: mtype,
274                description: desc,
275                hop_distance: hop,
276                source_entity: edge.map(|e| e.0.clone()),
277                target_entity: edge.map(|e| e.1.clone()),
278                relation: edge.map(|e| e.2.clone()),
279                weight: edge.map(|e| e.3),
280            });
281            if out.len() >= limit {
282                return Ok(out);
283            }
284        }
285    }
286
287    Ok(out)
288}
289
290fn fetch_neighbours(
291    conn: &Connection,
292    entity_id: i64,
293    namespace: &str,
294    min_weight: f64,
295    relation_filter: Option<&str>,
296) -> Result<Vec<Neighbour>, AppError> {
297    // Follow edges in both directions: source -> target and target -> source so traversal is
298    // undirected, which is how users typically reason about "related" memories.
299    let base_sql = "\
300        SELECT r.target_id, se.name, te.name, r.relation, r.weight
301        FROM relationships r
302        JOIN entities se ON se.id = r.source_id
303        JOIN entities te ON te.id = r.target_id
304        WHERE r.source_id = ?1 AND r.weight >= ?2 AND r.namespace = ?3";
305
306    let reverse_sql = "\
307        SELECT r.source_id, se.name, te.name, r.relation, r.weight
308        FROM relationships r
309        JOIN entities se ON se.id = r.source_id
310        JOIN entities te ON te.id = r.target_id
311        WHERE r.target_id = ?1 AND r.weight >= ?2 AND r.namespace = ?3";
312
313    let mut results: Vec<Neighbour> = Vec::new();
314
315    let forward_sql = match relation_filter {
316        Some(_) => format!("{base_sql} AND r.relation = ?4"),
317        None => base_sql.to_string(),
318    };
319    let rev_sql = match relation_filter {
320        Some(_) => format!("{reverse_sql} AND r.relation = ?4"),
321        None => reverse_sql.to_string(),
322    };
323
324    let mut stmt = conn.prepare_cached(&forward_sql)?;
325    let rows: Vec<_> = if let Some(rel) = relation_filter {
326        stmt.query_map(params![entity_id, min_weight, namespace, rel], |r| {
327            Ok((
328                r.get::<_, i64>(0)?,
329                r.get::<_, String>(1)?,
330                r.get::<_, String>(2)?,
331                r.get::<_, String>(3)?,
332                r.get::<_, f64>(4)?,
333            ))
334        })?
335        .collect::<Result<Vec<_>, _>>()?
336    } else {
337        stmt.query_map(params![entity_id, min_weight, namespace], |r| {
338            Ok((
339                r.get::<_, i64>(0)?,
340                r.get::<_, String>(1)?,
341                r.get::<_, String>(2)?,
342                r.get::<_, String>(3)?,
343                r.get::<_, f64>(4)?,
344            ))
345        })?
346        .collect::<Result<Vec<_>, _>>()?
347    };
348    results.extend(rows);
349
350    let mut stmt = conn.prepare_cached(&rev_sql)?;
351    let rows: Vec<_> = if let Some(rel) = relation_filter {
352        stmt.query_map(params![entity_id, min_weight, namespace, rel], |r| {
353            Ok((
354                r.get::<_, i64>(0)?,
355                r.get::<_, String>(1)?,
356                r.get::<_, String>(2)?,
357                r.get::<_, String>(3)?,
358                r.get::<_, f64>(4)?,
359            ))
360        })?
361        .collect::<Result<Vec<_>, _>>()?
362    } else {
363        stmt.query_map(params![entity_id, min_weight, namespace], |r| {
364            Ok((
365                r.get::<_, i64>(0)?,
366                r.get::<_, String>(1)?,
367                r.get::<_, String>(2)?,
368                r.get::<_, String>(3)?,
369                r.get::<_, f64>(4)?,
370            ))
371        })?
372        .collect::<Result<Vec<_>, _>>()?
373    };
374    results.extend(rows);
375
376    Ok(results)
377}
378
/// Unit tests for response serialization and for the graph traversal, run against an
/// in-memory SQLite database containing only the tables the traversal reads.
#[cfg(test)]
mod testes {
    use super::*;

    /// Opens an in-memory database with the minimal schema used by `traverse_related`
    /// and `fetch_neighbours`.
    fn setup_related_db() -> rusqlite::Connection {
        let conn = rusqlite::Connection::open_in_memory().expect("falha ao abrir banco em memória");
        conn.execute_batch(
            "CREATE TABLE memories (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                name TEXT NOT NULL,
                namespace TEXT NOT NULL DEFAULT 'global',
                type TEXT NOT NULL DEFAULT 'fact',
                description TEXT NOT NULL DEFAULT '',
                deleted_at INTEGER
            );
            CREATE TABLE entities (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                namespace TEXT NOT NULL,
                name TEXT NOT NULL
            );
            CREATE TABLE relationships (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                namespace TEXT NOT NULL,
                source_id INTEGER NOT NULL,
                target_id INTEGER NOT NULL,
                relation TEXT NOT NULL DEFAULT 'related_to',
                weight REAL NOT NULL DEFAULT 1.0
            );
            CREATE TABLE memory_entities (
                memory_id INTEGER NOT NULL,
                entity_id INTEGER NOT NULL
            );",
        )
        .expect("falha ao criar tabelas de teste");
        conn
    }

    /// Inserts a memory with default type/description and returns its row id.
    fn insert_memory(conn: &rusqlite::Connection, name: &str, namespace: &str) -> i64 {
        conn.execute(
            "INSERT INTO memories (name, namespace) VALUES (?1, ?2)",
            rusqlite::params![name, namespace],
        )
        .expect("falha ao inserir memória");
        conn.last_insert_rowid()
    }

    /// Inserts an entity and returns its row id.
    fn insert_entity(conn: &rusqlite::Connection, name: &str, namespace: &str) -> i64 {
        conn.execute(
            "INSERT INTO entities (name, namespace) VALUES (?1, ?2)",
            rusqlite::params![name, namespace],
        )
        .expect("falha ao inserir entidade");
        conn.last_insert_rowid()
    }

    /// Attaches an entity to a memory.
    fn link_memory_entity(conn: &rusqlite::Connection, memory_id: i64, entity_id: i64) {
        conn.execute(
            "INSERT INTO memory_entities (memory_id, entity_id) VALUES (?1, ?2)",
            rusqlite::params![memory_id, entity_id],
        )
        .expect("falha ao vincular memória-entidade");
    }

    /// Inserts a directed edge `source_id -> target_id` with the given relation and weight.
    fn insert_relationship(
        conn: &rusqlite::Connection,
        namespace: &str,
        source_id: i64,
        target_id: i64,
        relation: &str,
        weight: f64,
    ) {
        conn.execute(
            "INSERT INTO relationships (namespace, source_id, target_id, relation, weight)
             VALUES (?1, ?2, ?3, ?4, ?5)",
            rusqlite::params![namespace, source_id, target_id, relation, weight],
        )
        .expect("falha ao inserir relacionamento");
    }

    /// The JSON envelope exposes `results` and `elapsed_ms`, and `memory_type` is renamed
    /// to `type`.
    #[test]
    fn related_response_serializa_results_e_elapsed_ms() {
        let resp = RelatedResponse {
            results: vec![RelatedMemory {
                memory_id: 1,
                name: "mem-vizinha".to_string(),
                namespace: "global".to_string(),
                memory_type: "fact".to_string(),
                description: "desc".to_string(),
                hop_distance: 1,
                source_entity: Some("entidade-a".to_string()),
                target_entity: Some("entidade-b".to_string()),
                relation: Some("related_to".to_string()),
                weight: Some(0.9),
            }],
            elapsed_ms: 7,
        };

        let json = serde_json::to_value(&resp).expect("serialização falhou");
        assert!(json["results"].is_array());
        assert_eq!(json["results"].as_array().unwrap().len(), 1);
        assert_eq!(json["elapsed_ms"], 7u64);
        assert_eq!(json["results"][0]["type"], "fact");
        assert_eq!(json["results"][0]["hop_distance"], 1);
    }

    /// No seed entities means there is nothing to traverse from.
    #[test]
    fn traverse_related_retorna_vazio_sem_entidades_seed() {
        let conn = setup_related_db();
        let resultado = traverse_related(&conn, 1, &[], "global", 2, 0.0, None, 10)
            .expect("traverse_related falhou");
        assert!(
            resultado.is_empty(),
            "sem entidades seed deve retornar vazio"
        );
    }

    /// A hop budget of zero never leaves the seed entities.
    #[test]
    fn traverse_related_retorna_vazio_com_max_hops_zero() {
        let conn = setup_related_db();
        let mem_id = insert_memory(&conn, "seed-mem", "global");
        let ent_id = insert_entity(&conn, "ent-a", "global");
        link_memory_entity(&conn, mem_id, ent_id);

        let resultado = traverse_related(&conn, mem_id, &[ent_id], "global", 0, 0.0, None, 10)
            .expect("traverse_related falhou");
        assert!(resultado.is_empty(), "max_hops=0 deve retornar vazio");
    }

    /// A memory attached to a directly connected entity is found at hop 1, carrying the
    /// metadata of the edge that reached it.
    #[test]
    fn traverse_related_descobre_memoria_vizinha_por_grafo() {
        let conn = setup_related_db();

        let seed_id = insert_memory(&conn, "seed-mem", "global");
        let vizinha_id = insert_memory(&conn, "vizinha-mem", "global");
        let ent_a = insert_entity(&conn, "ent-a", "global");
        let ent_b = insert_entity(&conn, "ent-b", "global");

        link_memory_entity(&conn, seed_id, ent_a);
        link_memory_entity(&conn, vizinha_id, ent_b);
        insert_relationship(&conn, "global", ent_a, ent_b, "related_to", 1.0);

        let resultado = traverse_related(&conn, seed_id, &[ent_a], "global", 2, 0.0, None, 10)
            .expect("traverse_related falhou");

        assert_eq!(resultado.len(), 1, "deve encontrar 1 memória vizinha");
        assert_eq!(resultado[0].memory_id, vizinha_id);
        assert_eq!(resultado[0].name, "vizinha-mem");
        assert_eq!(resultado[0].hop_distance, 1);
        // The edge that reached `ent-b` must be reported with its own endpoints.
        assert_eq!(resultado[0].source_entity.as_deref(), Some("ent-a"));
        assert_eq!(resultado[0].target_entity.as_deref(), Some("ent-b"));
        assert_eq!(resultado[0].relation.as_deref(), Some("related_to"));
        assert_eq!(resultado[0].weight, Some(1.0));
    }

    /// With five reachable memories and `limit = 3`, exactly three are returned.
    #[test]
    fn traverse_related_respeita_limite() {
        let conn = setup_related_db();

        let seed_id = insert_memory(&conn, "seed", "global");
        let ent_seed = insert_entity(&conn, "ent-seed", "global");
        link_memory_entity(&conn, seed_id, ent_seed);

        for i in 0..5 {
            let mem_id = insert_memory(&conn, &format!("vizinha-{i}"), "global");
            let ent_id = insert_entity(&conn, &format!("ent-{i}"), "global");
            link_memory_entity(&conn, mem_id, ent_id);
            insert_relationship(&conn, "global", ent_seed, ent_id, "related_to", 1.0);
        }

        let resultado = traverse_related(&conn, seed_id, &[ent_seed], "global", 1, 0.0, None, 3)
            .expect("traverse_related falhou");

        // An upper-bound check alone would also pass for an empty result, so pin the
        // exact count: five candidates are available and the limit is three.
        assert_eq!(
            resultado.len(),
            3,
            "limite=3 deve retornar exatamente 3 dos 5 vizinhos"
        );
        assert!(
            resultado.iter().all(|m| m.hop_distance == 1),
            "todos os vizinhos estão a 1 hop do seed"
        );
        assert!(
            resultado.iter().all(|m| m.memory_id != seed_id),
            "a memória seed nunca deve aparecer no resultado"
        );
    }

    /// Absent edge metadata serializes as JSON `null` rather than being omitted.
    #[test]
    fn related_memory_campos_opcionais_nulos_serializados() {
        let mem = RelatedMemory {
            memory_id: 99,
            name: "sem-relacao".to_string(),
            namespace: "ns".to_string(),
            memory_type: "concept".to_string(),
            description: "".to_string(),
            hop_distance: 2,
            source_entity: None,
            target_entity: None,
            relation: None,
            weight: None,
        };

        let json = serde_json::to_value(&mem).expect("serialização falhou");
        assert!(json["source_entity"].is_null());
        assert!(json["target_entity"].is_null());
        assert!(json["relation"].is_null());
        assert!(json["weight"].is_null());
        assert_eq!(json["hop_distance"], 2);
    }
}
570        assert!(json["relation"].is_null());
571        assert!(json["weight"].is_null());
572        assert_eq!(json["hop_distance"], 2);
573    }
574}