Skip to main content

sqlite_graphrag/
graph.rs

1// src/graph.rs
2
3use crate::errors::AppError;
4use rusqlite::{params, Connection};
5
6/// Percorre o grafo de entidades por BFS a partir de memórias-semente.
7///
8/// Retorna `memory_id`s alcançáveis pelo grafo de entidades e relacionamentos,
9/// excluindo as próprias sementes. O algoritmo:
10/// 1. Coleta entidades associadas às sementes via `memory_entities`.
11/// 2. Executa BFS sobre `relationships` filtrando por `weight >= min_weight` e `namespace`.
12/// 3. Retorna memórias ligadas às entidades descobertas (excluindo soft-deleted).
13///
14/// # Errors
15///
16/// Propaga [`AppError::Database`] (exit 10) em falhas de consulta SQLite.
17///
18/// # Examples
19///
20/// ```
21/// use rusqlite::Connection;
22/// use sqlite_graphrag::graph::traverse_from_memories;
23///
24/// // Lista de sementes vazia retorna imediatamente sem consultar o banco.
25/// let conn = Connection::open_in_memory().unwrap();
26/// let ids = traverse_from_memories(&conn, &[], "global", 0.5, 3).unwrap();
27/// assert!(ids.is_empty());
28/// ```
29///
30/// ```
31/// use rusqlite::Connection;
32/// use sqlite_graphrag::graph::traverse_from_memories;
33///
34/// // max_hops == 0 retorna imediatamente sem traversal.
35/// let conn = Connection::open_in_memory().unwrap();
36/// let ids = traverse_from_memories(&conn, &[1, 2], "global", 0.5, 0).unwrap();
37/// assert!(ids.is_empty());
38/// ```
39pub fn traverse_from_memories(
40    conn: &Connection,
41    seed_memory_ids: &[i64],
42    namespace: &str,
43    min_weight: f64,
44    max_hops: u32,
45) -> Result<Vec<i64>, AppError> {
46    if seed_memory_ids.is_empty() || max_hops == 0 {
47        return Ok(vec![]);
48    }
49
50    // Step 1: collect seed entity IDs from seed memories
51    let mut seed_entities: Vec<i64> = Vec::new();
52    for &mem_id in seed_memory_ids {
53        let mut stmt =
54            conn.prepare_cached("SELECT entity_id FROM memory_entities WHERE memory_id = ?1")?;
55        let ids: Vec<i64> = stmt
56            .query_map(params![mem_id], |r| r.get(0))?
57            .filter_map(|r| r.ok())
58            .collect();
59        seed_entities.extend(ids);
60    }
61    seed_entities.sort_unstable();
62    seed_entities.dedup();
63
64    if seed_entities.is_empty() {
65        return Ok(vec![]);
66    }
67
68    // Step 2: BFS over relationships
69    use std::collections::HashSet;
70    let mut visited: HashSet<i64> = seed_entities.iter().cloned().collect();
71    let mut frontier = seed_entities.clone();
72
73    for _ in 0..max_hops {
74        if frontier.is_empty() {
75            break;
76        }
77        let mut next_frontier = Vec::new();
78
79        for &entity_id in &frontier {
80            let mut stmt = conn.prepare_cached(
81                "SELECT target_id FROM relationships
82                 WHERE source_id = ?1 AND weight >= ?2 AND namespace = ?3",
83            )?;
84            let neighbors: Vec<i64> = stmt
85                .query_map(params![entity_id, min_weight, namespace], |r| r.get(0))?
86                .filter_map(|r| r.ok())
87                .filter(|id| !visited.contains(id))
88                .collect();
89
90            for id in neighbors {
91                visited.insert(id);
92                next_frontier.push(id);
93            }
94        }
95        frontier = next_frontier;
96    }
97
98    // Step 3: find memories connected to traversed entities (excluding seeds)
99    let seed_set: HashSet<i64> = seed_memory_ids.iter().cloned().collect();
100    let graph_only_entities: Vec<i64> = visited
101        .into_iter()
102        .filter(|id| !seed_entities.contains(id))
103        .collect();
104
105    let mut result_ids: Vec<i64> = Vec::new();
106    for &entity_id in &graph_only_entities {
107        let mut stmt = conn.prepare_cached(
108            "SELECT DISTINCT me.memory_id
109             FROM memory_entities me
110             JOIN memories m ON m.id = me.memory_id
111             WHERE me.entity_id = ?1 AND m.deleted_at IS NULL",
112        )?;
113        let mem_ids: Vec<i64> = stmt
114            .query_map(params![entity_id], |r| r.get(0))?
115            .filter_map(|r| r.ok())
116            .filter(|id| !seed_set.contains(id))
117            .collect();
118        result_ids.extend(mem_ids);
119    }
120
121    result_ids.sort_unstable();
122    result_ids.dedup();
123    Ok(result_ids)
124}
125
#[cfg(test)]
mod tests {
    use super::*;
    use rusqlite::Connection;

    /// Opens an in-memory database containing the three tables the traversal
    /// queries: `memories`, `memory_entities` and `relationships`.
    fn setup_db() -> Connection {
        let db = Connection::open_in_memory().unwrap();
        db.execute_batch(
            "CREATE TABLE memories (
                id INTEGER PRIMARY KEY,
                namespace TEXT NOT NULL,
                deleted_at TEXT
            );
            CREATE TABLE memory_entities (
                memory_id INTEGER NOT NULL,
                entity_id INTEGER NOT NULL
            );
            CREATE TABLE relationships (
                source_id INTEGER NOT NULL,
                target_id INTEGER NOT NULL,
                weight REAL NOT NULL,
                namespace TEXT NOT NULL
            );",
        )
        .unwrap();
        db
    }

    /// Inserts a memory row; `deleted` controls whether `deleted_at` is set
    /// (soft-deleted) or left NULL.
    fn insert_memory(conn: &Connection, id: i64, namespace: &str, deleted: bool) {
        let deleted_at: Option<&str> = if deleted { Some("2024-01-01") } else { None };
        conn.execute(
            "INSERT INTO memories (id, namespace, deleted_at) VALUES (?1, ?2, ?3)",
            params![id, namespace, deleted_at],
        )
        .unwrap();
    }

    /// Attaches `entity_id` to `memory_id` in `memory_entities`.
    fn link_memory_entity(conn: &Connection, memory_id: i64, entity_id: i64) {
        let bound = params![memory_id, entity_id];
        conn.execute(
            "INSERT INTO memory_entities (memory_id, entity_id) VALUES (?1, ?2)",
            bound,
        )
        .unwrap();
    }

    /// Inserts a directed edge `src -> tgt` with the given weight and namespace.
    fn insert_relationship(conn: &Connection, src: i64, tgt: i64, weight: f64, ns: &str) {
        let bound = params![src, tgt, weight, ns];
        conn.execute(
            "INSERT INTO relationships (source_id, target_id, weight, namespace) VALUES (?1, ?2, ?3, ?4)",
            bound,
        )
        .unwrap();
    }

    /// Inserts a non-deleted memory and attaches it to every entity listed.
    fn live_memory(conn: &Connection, id: i64, namespace: &str, entity_ids: &[i64]) {
        insert_memory(conn, id, namespace, false);
        for &entity_id in entity_ids {
            link_memory_entity(conn, id, entity_id);
        }
    }

    /// Inserts weight-1.0 edges in namespace "ns" for each `(source, target)` pair.
    fn full_weight_edges(conn: &Connection, pairs: &[(i64, i64)]) {
        for &(src, tgt) in pairs {
            insert_relationship(conn, src, tgt, 1.0, "ns");
        }
    }

    // --- edge cases that yield an empty result ---

    #[test]
    fn retorna_vazio_quando_seeds_vazio() {
        let db = setup_db();
        let reached = traverse_from_memories(&db, &[], "ns", 0.5, 3).unwrap();
        assert!(reached.is_empty());
    }

    #[test]
    fn retorna_vazio_quando_max_hops_zero() {
        let db = setup_db();
        live_memory(&db, 1, "ns", &[10]);
        let reached = traverse_from_memories(&db, &[1], "ns", 0.5, 0).unwrap();
        assert!(reached.is_empty());
    }

    #[test]
    fn retorna_vazio_quando_seed_sem_entidades() {
        let db = setup_db();
        // The memory exists but has no entities attached.
        insert_memory(&db, 1, "ns", false);
        let reached = traverse_from_memories(&db, &[1], "ns", 0.5, 3).unwrap();
        assert!(reached.is_empty());
    }

    #[test]
    fn retorna_vazio_quando_sem_relacionamentos() {
        let db = setup_db();
        // Entity 10 has no relationships at all.
        live_memory(&db, 1, "ns", &[10]);
        let reached = traverse_from_memories(&db, &[1], "ns", 0.5, 3).unwrap();
        assert!(reached.is_empty());
    }

    // --- basic happy path ---

    #[test]
    fn traversal_basico_um_hop() {
        let db = setup_db();

        // Seed: memory 1 with entity 10; neighbour: entity 20 attached to memory 2.
        live_memory(&db, 1, "ns", &[10]);
        live_memory(&db, 2, "ns", &[20]);

        // Edge 10 -> 20.
        full_weight_edges(&db, &[(10, 20)]);

        let reached = traverse_from_memories(&db, &[1], "ns", 0.5, 1).unwrap();
        assert_eq!(reached, vec![2]);
    }

    #[test]
    fn traversal_dois_hops() {
        let db = setup_db();

        live_memory(&db, 1, "ns", &[10]);
        live_memory(&db, 2, "ns", &[20]);
        live_memory(&db, 3, "ns", &[30]);

        // Chain 10 -> 20 -> 30.
        full_weight_edges(&db, &[(10, 20), (20, 30)]);

        let mut reached = traverse_from_memories(&db, &[1], "ns", 0.5, 2).unwrap();
        reached.sort_unstable();
        assert_eq!(reached, vec![2, 3]);
    }

    #[test]
    fn max_hops_limita_profundidade() {
        let db = setup_db();

        live_memory(&db, 1, "ns", &[10]);
        live_memory(&db, 2, "ns", &[20]);
        live_memory(&db, 3, "ns", &[30]);

        full_weight_edges(&db, &[(10, 20), (20, 30)]);

        // With a single hop, memory 3 must not show up.
        let reached = traverse_from_memories(&db, &[1], "ns", 0.5, 1).unwrap();
        assert_eq!(reached, vec![2]);
        assert!(!reached.contains(&3));
    }

    // --- weight filter ---

    #[test]
    fn relacionamento_com_peso_abaixo_do_minimo_ignorado() {
        let db = setup_db();

        live_memory(&db, 1, "ns", &[10]);
        live_memory(&db, 2, "ns", &[20]);

        // Weight 0.3 is below min_weight 0.5.
        insert_relationship(&db, 10, 20, 0.3, "ns");

        let reached = traverse_from_memories(&db, &[1], "ns", 0.5, 3).unwrap();
        assert!(reached.is_empty());
    }

    #[test]
    fn relacionamento_com_peso_exatamente_no_minimo_incluido() {
        let db = setup_db();

        live_memory(&db, 1, "ns", &[10]);
        live_memory(&db, 2, "ns", &[20]);

        // Weight equal to the threshold is accepted.
        insert_relationship(&db, 10, 20, 0.5, "ns");

        let reached = traverse_from_memories(&db, &[1], "ns", 0.5, 1).unwrap();
        assert_eq!(reached, vec![2]);
    }

    // --- namespace isolation ---

    #[test]
    fn relacionamento_de_namespace_diferente_ignorado() {
        let db = setup_db();

        live_memory(&db, 1, "ns_a", &[10]);
        live_memory(&db, 2, "ns_a", &[20]);

        // The edge lives in the wrong namespace.
        insert_relationship(&db, 10, 20, 1.0, "ns_b");

        let reached = traverse_from_memories(&db, &[1], "ns_a", 0.5, 3).unwrap();
        assert!(reached.is_empty());
    }

    // --- seeds are excluded from the result ---

    #[test]
    fn seeds_nao_aparecem_no_resultado() {
        let db = setup_db();

        live_memory(&db, 1, "ns", &[10]);
        live_memory(&db, 2, "ns", &[20]);

        // Edge from 20 back to 10 forms a cycle.
        full_weight_edges(&db, &[(10, 20), (20, 10)]);

        let reached = traverse_from_memories(&db, &[1], "ns", 0.5, 3).unwrap();
        // Memory 1 must not appear even though the cycle leads back to it.
        assert!(!reached.contains(&1));
        assert_eq!(reached, vec![2]);
    }

    // --- soft-deleted memories are excluded ---

    #[test]
    fn memorias_deletadas_nao_incluidas() {
        let db = setup_db();

        live_memory(&db, 1, "ns", &[10]);

        // Memory 2 has been soft-deleted.
        insert_memory(&db, 2, "ns", true);
        link_memory_entity(&db, 2, 20);

        full_weight_edges(&db, &[(10, 20)]);

        let reached = traverse_from_memories(&db, &[1], "ns", 0.5, 3).unwrap();
        assert!(reached.is_empty());
    }

    // --- multiple seeds ---

    #[test]
    fn multiplos_seeds_unidos_no_resultado() {
        let db = setup_db();

        for (memory_id, entity_id) in [(1, 10), (2, 20), (3, 30), (4, 40)] {
            live_memory(&db, memory_id, "ns", &[entity_id]);
        }

        full_weight_edges(&db, &[(10, 30), (20, 40)]);

        let mut reached = traverse_from_memories(&db, &[1, 2], "ns", 0.5, 1).unwrap();
        reached.sort_unstable();
        assert_eq!(reached, vec![3, 4]);
    }

    // --- result de-duplication ---

    #[test]
    fn resultado_sem_duplicatas() {
        let db = setup_db();

        // Two seed entities hang off the same memory.
        live_memory(&db, 1, "ns", &[10, 11]);
        live_memory(&db, 2, "ns", &[20]);

        // Both seed entities point at the same entity 20.
        full_weight_edges(&db, &[(10, 20), (11, 20)]);

        let reached = traverse_from_memories(&db, &[1], "ns", 0.5, 1).unwrap();
        // Memory 2 must be reported exactly once.
        assert_eq!(reached.len(), 1);
        assert_eq!(reached, vec![2]);
    }

    // --- single node ---

    #[test]
    fn single_node_sem_vizinhos_retorna_vazio() {
        let db = setup_db();

        // Entity 10 has no outgoing relationships.
        live_memory(&db, 1, "ns", &[10]);

        let reached = traverse_from_memories(&db, &[1], "ns", 0.5, 5).unwrap();
        assert!(reached.is_empty());
    }

    // --- cycles in the graph ---

    #[test]
    fn ciclo_nao_causa_loop_infinito() {
        let db = setup_db();

        live_memory(&db, 1, "ns", &[10]);
        live_memory(&db, 2, "ns", &[20]);
        live_memory(&db, 3, "ns", &[30]);

        // Triangle 10 -> 20 -> 30 -> 10.
        full_weight_edges(&db, &[(10, 20), (20, 30), (30, 10)]);

        let mut reached = traverse_from_memories(&db, &[1], "ns", 0.5, 10).unwrap();
        reached.sort_unstable();
        // Must return 2 and 3 and terminate despite the cycle.
        assert_eq!(reached, vec![2, 3]);
    }
}