use crate::errors::AppError;
use rusqlite::{params, Connection};
/// Expands a set of seed memories into related memories by walking the entity graph.
///
/// The walk has three phases:
///
/// 1. Collect the entities linked (through `memory_entities`) to every seed memory.
///    Seed memories are used as given; their `deleted_at` is not checked.
/// 2. Breadth-first traversal over `relationships`, following edges only in the
///    `source_id -> target_id` direction, restricted to rows whose `namespace`
///    equals `namespace` and whose `weight >= min_weight`, for at most `max_hops`
///    levels. Each entity is visited at most once, so cycles terminate.
/// 3. Collect the memories linked to entities reached in phase 2 (seed entities
///    excluded), skipping soft-deleted memories (`deleted_at IS NOT NULL`) and the
///    seed memories themselves.
///
/// NOTE(review): `namespace` filters only the relationship edges; memories found in
/// phase 3 are not filtered by `memories.namespace` — confirm this is intended.
///
/// Returns the memory ids sorted ascending with no duplicates. The result is empty
/// when `seed_memory_ids` is empty, `max_hops` is 0, the seeds have no linked
/// entities, or the traversal reaches no new entity.
///
/// # Errors
///
/// Returns an error if preparing or running any query fails, or if a returned row
/// cannot be read as an `i64`. Row errors are propagated rather than skipped,
/// because a failed step ends the row iterator and would otherwise yield a
/// silently truncated result.
pub fn traverse_from_memories(
    conn: &Connection,
    seed_memory_ids: &[i64],
    namespace: &str,
    min_weight: f64,
    max_hops: u32,
) -> Result<Vec<i64>, AppError> {
    use std::collections::HashSet;

    if seed_memory_ids.is_empty() || max_hops == 0 {
        return Ok(vec![]);
    }

    // Phase 1: entities attached to the seed memories (sorted, de-duplicated).
    let mut seed_entities: Vec<i64> = Vec::new();
    {
        let mut stmt =
            conn.prepare_cached("SELECT entity_id FROM memory_entities WHERE memory_id = ?1")?;
        for &mem_id in seed_memory_ids {
            for entity_id in stmt.query_map(params![mem_id], |r| r.get::<_, i64>(0))? {
                seed_entities.push(entity_id?);
            }
        }
    }
    seed_entities.sort_unstable();
    seed_entities.dedup();
    if seed_entities.is_empty() {
        return Ok(vec![]);
    }

    // Phase 2: breadth-first expansion over qualifying edges.
    // `visited` starts with the seeds so they are never re-discovered; every entity
    // added afterwards is also recorded in `discovered` (i.e. visited minus seeds).
    let mut visited: HashSet<i64> = seed_entities.iter().copied().collect();
    let mut discovered: Vec<i64> = Vec::new();
    let mut frontier: Vec<i64> = seed_entities;
    {
        let mut stmt = conn.prepare_cached(
            "SELECT target_id FROM relationships
WHERE source_id = ?1 AND weight >= ?2 AND namespace = ?3",
        )?;
        for _ in 0..max_hops {
            if frontier.is_empty() {
                break;
            }
            let mut next_frontier = Vec::new();
            for &entity_id in &frontier {
                let targets = stmt.query_map(params![entity_id, min_weight, namespace], |r| {
                    r.get::<_, i64>(0)
                })?;
                for target in targets {
                    let target = target?;
                    // `insert` returns false for ids already seen, which also keeps
                    // duplicate edges from queuing the same entity twice.
                    if visited.insert(target) {
                        discovered.push(target);
                        next_frontier.push(target);
                    }
                }
            }
            frontier = next_frontier;
        }
    }
    if discovered.is_empty() {
        return Ok(vec![]);
    }

    // Phase 3: live memories attached to the discovered entities, minus the seeds.
    let seed_set: HashSet<i64> = seed_memory_ids.iter().copied().collect();
    let mut result_ids: Vec<i64> = Vec::new();
    let mut stmt = conn.prepare_cached(
        "SELECT DISTINCT me.memory_id
FROM memory_entities me
JOIN memories m ON m.id = me.memory_id
WHERE me.entity_id = ?1 AND m.deleted_at IS NULL",
    )?;
    for &entity_id in &discovered {
        for mem_id in stmt.query_map(params![entity_id], |r| r.get::<_, i64>(0))? {
            let mem_id = mem_id?;
            if !seed_set.contains(&mem_id) {
                result_ids.push(mem_id);
            }
        }
    }
    // A memory can be linked to several discovered entities.
    result_ids.sort_unstable();
    result_ids.dedup();
    Ok(result_ids)
}
#[cfg(test)]
mod tests {
    use super::*;
    use rusqlite::Connection;

    /// Opens an in-memory database containing the three tables the traversal reads.
    fn new_db() -> Connection {
        let db = Connection::open_in_memory().unwrap();
        db.execute_batch(
            "CREATE TABLE memories (
id INTEGER PRIMARY KEY,
namespace TEXT NOT NULL,
deleted_at TEXT
);
CREATE TABLE memory_entities (
memory_id INTEGER NOT NULL,
entity_id INTEGER NOT NULL
);
CREATE TABLE relationships (
source_id INTEGER NOT NULL,
target_id INTEGER NOT NULL,
weight REAL NOT NULL,
namespace TEXT NOT NULL
);",
        )
        .unwrap();
        db
    }

    /// Inserts a memory row; when `deleted` is true a fixed `deleted_at` is stored.
    fn add_memory(db: &Connection, id: i64, ns: &str, deleted: bool) {
        let deleted_at: Option<&str> = if deleted { Some("2024-01-01") } else { None };
        db.execute(
            "INSERT INTO memories (id, namespace, deleted_at) VALUES (?1, ?2, ?3)",
            params![id, ns, deleted_at],
        )
        .unwrap();
    }

    /// Adds a `memory_entities` row linking `memory_id` to `entity_id`.
    fn link(db: &Connection, memory_id: i64, entity_id: i64) {
        db.execute(
            "INSERT INTO memory_entities (memory_id, entity_id) VALUES (?1, ?2)",
            params![memory_id, entity_id],
        )
        .unwrap();
    }

    /// Adds a directed `src -> tgt` relationship with the given weight and namespace.
    fn add_edge(db: &Connection, src: i64, tgt: i64, weight: f64, ns: &str) {
        db.execute(
            "INSERT INTO relationships (source_id, target_id, weight, namespace) VALUES (?1, ?2, ?3, ?4)",
            params![src, tgt, weight, ns],
        )
        .unwrap();
    }

    /// Inserts a live (not deleted) memory in `ns` and links it to one entity.
    fn add_linked_memory(db: &Connection, memory_id: i64, ns: &str, entity_id: i64) {
        add_memory(db, memory_id, ns, false);
        link(db, memory_id, entity_id);
    }

    /// No seed memories: nothing to expand.
    #[test]
    fn retorna_vazio_quando_seeds_vazio() {
        let db = new_db();
        let found = traverse_from_memories(&db, &[], "ns", 0.5, 3).unwrap();
        assert!(found.is_empty());
    }

    /// A hop budget of zero short-circuits to an empty result.
    #[test]
    fn retorna_vazio_quando_max_hops_zero() {
        let db = new_db();
        add_linked_memory(&db, 1, "ns", 10);
        let found = traverse_from_memories(&db, &[1], "ns", 0.5, 0).unwrap();
        assert!(found.is_empty());
    }

    /// A seed memory linked to no entity yields nothing.
    #[test]
    fn retorna_vazio_quando_seed_sem_entidades() {
        let db = new_db();
        add_memory(&db, 1, "ns", false);
        let found = traverse_from_memories(&db, &[1], "ns", 0.5, 3).unwrap();
        assert!(found.is_empty());
    }

    /// Entities exist but there are no edges to follow.
    #[test]
    fn retorna_vazio_quando_sem_relacionamentos() {
        let db = new_db();
        add_linked_memory(&db, 1, "ns", 10);
        let found = traverse_from_memories(&db, &[1], "ns", 0.5, 3).unwrap();
        assert!(found.is_empty());
    }

    /// One edge, one hop: the neighbour's memory is returned.
    #[test]
    fn traversal_basico_um_hop() {
        let db = new_db();
        add_linked_memory(&db, 1, "ns", 10);
        add_linked_memory(&db, 2, "ns", 20);
        add_edge(&db, 10, 20, 1.0, "ns");
        let found = traverse_from_memories(&db, &[1], "ns", 0.5, 1).unwrap();
        assert_eq!(found, vec![2]);
    }

    /// A chain 10 -> 20 -> 30 is fully reached with two hops.
    #[test]
    fn traversal_dois_hops() {
        let db = new_db();
        add_linked_memory(&db, 1, "ns", 10);
        add_linked_memory(&db, 2, "ns", 20);
        add_linked_memory(&db, 3, "ns", 30);
        add_edge(&db, 10, 20, 1.0, "ns");
        add_edge(&db, 20, 30, 1.0, "ns");
        let mut found = traverse_from_memories(&db, &[1], "ns", 0.5, 2).unwrap();
        found.sort_unstable();
        assert_eq!(found, vec![2, 3]);
    }

    /// The same chain with a single hop stops at the first neighbour.
    #[test]
    fn max_hops_limita_profundidade() {
        let db = new_db();
        add_linked_memory(&db, 1, "ns", 10);
        add_linked_memory(&db, 2, "ns", 20);
        add_linked_memory(&db, 3, "ns", 30);
        add_edge(&db, 10, 20, 1.0, "ns");
        add_edge(&db, 20, 30, 1.0, "ns");
        let found = traverse_from_memories(&db, &[1], "ns", 0.5, 1).unwrap();
        assert_eq!(found, vec![2]);
        assert!(!found.contains(&3));
    }

    /// Edges lighter than `min_weight` are not followed.
    #[test]
    fn relacionamento_com_peso_abaixo_do_minimo_ignorado() {
        let db = new_db();
        add_linked_memory(&db, 1, "ns", 10);
        add_linked_memory(&db, 2, "ns", 20);
        add_edge(&db, 10, 20, 0.3, "ns");
        let found = traverse_from_memories(&db, &[1], "ns", 0.5, 3).unwrap();
        assert!(found.is_empty());
    }

    /// The weight threshold is inclusive.
    #[test]
    fn relacionamento_com_peso_exatamente_no_minimo_incluido() {
        let db = new_db();
        add_linked_memory(&db, 1, "ns", 10);
        add_linked_memory(&db, 2, "ns", 20);
        add_edge(&db, 10, 20, 0.5, "ns");
        let found = traverse_from_memories(&db, &[1], "ns", 0.5, 1).unwrap();
        assert_eq!(found, vec![2]);
    }

    /// Edges from another namespace are ignored.
    #[test]
    fn relacionamento_de_namespace_diferente_ignorado() {
        let db = new_db();
        add_linked_memory(&db, 1, "ns_a", 10);
        add_linked_memory(&db, 2, "ns_a", 20);
        add_edge(&db, 10, 20, 1.0, "ns_b");
        let found = traverse_from_memories(&db, &[1], "ns_a", 0.5, 3).unwrap();
        assert!(found.is_empty());
    }

    /// A back-edge to the seed entity does not put the seed memory in the output.
    #[test]
    fn seeds_nao_aparecem_no_resultado() {
        let db = new_db();
        add_linked_memory(&db, 1, "ns", 10);
        add_linked_memory(&db, 2, "ns", 20);
        add_edge(&db, 10, 20, 1.0, "ns");
        add_edge(&db, 20, 10, 1.0, "ns");
        let found = traverse_from_memories(&db, &[1], "ns", 0.5, 3).unwrap();
        assert!(!found.contains(&1));
        assert_eq!(found, vec![2]);
    }

    /// Soft-deleted memories are excluded from the result.
    #[test]
    fn memorias_deletadas_nao_incluidas() {
        let db = new_db();
        add_linked_memory(&db, 1, "ns", 10);
        add_memory(&db, 2, "ns", true);
        link(&db, 2, 20);
        add_edge(&db, 10, 20, 1.0, "ns");
        let found = traverse_from_memories(&db, &[1], "ns", 0.5, 3).unwrap();
        assert!(found.is_empty());
    }

    /// Expansions from several seeds are merged.
    #[test]
    fn multiplos_seeds_unidos_no_resultado() {
        let db = new_db();
        add_linked_memory(&db, 1, "ns", 10);
        add_linked_memory(&db, 2, "ns", 20);
        add_linked_memory(&db, 3, "ns", 30);
        add_linked_memory(&db, 4, "ns", 40);
        add_edge(&db, 10, 30, 1.0, "ns");
        add_edge(&db, 20, 40, 1.0, "ns");
        let mut found = traverse_from_memories(&db, &[1, 2], "ns", 0.5, 1).unwrap();
        found.sort_unstable();
        assert_eq!(found, vec![3, 4]);
    }

    /// Two seed entities reaching the same target yield that memory only once.
    #[test]
    fn resultado_sem_duplicatas() {
        let db = new_db();
        add_linked_memory(&db, 1, "ns", 10);
        link(&db, 1, 11);
        add_linked_memory(&db, 2, "ns", 20);
        add_edge(&db, 10, 20, 1.0, "ns");
        add_edge(&db, 11, 20, 1.0, "ns");
        let found = traverse_from_memories(&db, &[1], "ns", 0.5, 1).unwrap();
        assert_eq!(found.len(), 1);
        assert_eq!(found, vec![2]);
    }

    /// An isolated node stays empty regardless of the hop budget.
    #[test]
    fn single_node_sem_vizinhos_retorna_vazio() {
        let db = new_db();
        add_linked_memory(&db, 1, "ns", 10);
        let found = traverse_from_memories(&db, &[1], "ns", 0.5, 5).unwrap();
        assert!(found.is_empty());
    }

    /// A 10 -> 20 -> 30 -> 10 cycle terminates and reports each memory once.
    #[test]
    fn ciclo_nao_causa_loop_infinito() {
        let db = new_db();
        add_linked_memory(&db, 1, "ns", 10);
        add_linked_memory(&db, 2, "ns", 20);
        add_linked_memory(&db, 3, "ns", 30);
        add_edge(&db, 10, 20, 1.0, "ns");
        add_edge(&db, 20, 30, 1.0, "ns");
        add_edge(&db, 30, 10, 1.0, "ns");
        let mut found = traverse_from_memories(&db, &[1], "ns", 0.5, 10).unwrap();
        found.sort_unstable();
        assert_eq!(found, vec![2, 3]);
    }
}