1use crate::constants::{
4 DEFAULT_K_RECALL, DEFAULT_MAX_HOPS, DEFAULT_MIN_WEIGHT, TEXT_DESCRIPTION_PREVIEW_LEN,
5};
6use crate::errors::AppError;
7use crate::i18n::errors_msg;
8use crate::output::{self, OutputFormat};
9use crate::paths::AppPaths;
10use crate::storage::connection::open_ro;
11use rusqlite::{params, Connection};
12use serde::Serialize;
13use std::collections::{HashMap, HashSet, VecDeque};
14
/// What the user-supplied name resolved to in `run`: a live memory or, failing that, an
/// entity of the same name.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum SeedKind {
    /// Row id of the non-deleted memory whose name matched.
    Memory(i64),
    /// Id of the entity whose name matched (used only when no memory matched).
    Entity(i64),
}
20
/// One edge adjacent to the entity being expanded, as produced by `fetch_neighbours`.
type Neighbour = (
    i64,    // id of the entity on the far side of the edge
    String, // name of the edge's source entity
    String, // name of the edge's target entity
    String, // relation label stored on the edge
    f64,    // edge weight
);
24
// Command-line arguments for the `related` subcommand.
//
// Plain `//` comments are used deliberately in this struct: clap's derive turns `///`
// doc comments into `--help` text, so adding them would change the CLI output.
#[derive(clap::Args)]
#[command(after_long_help = "EXAMPLES:\n \
  # List memories connected to a memory via the entity graph (default 2 hops)\n \
  sqlite-graphrag related onboarding\n\n \
  # Increase hop distance and filter by relation type\n \
  sqlite-graphrag related onboarding --max-hops 3 --relation related\n\n \
  # Cap result count and require minimum edge weight\n \
  sqlite-graphrag related onboarding --limit 5 --min-weight 0.5")]
pub struct RelatedArgs {
    // Positional seed name; clap rejects it together with `--name`.
    #[arg(
        value_name = "NAME",
        conflicts_with = "name",
        help = "Memory name whose neighbours to traverse; alternative to --name"
    )]
    pub name_positional: Option<String>,
    // Seed name given as `--name` (alias `--from`); `run` checks the positional form first.
    #[arg(long, alias = "from")]
    pub name: Option<String>,
    // Maximum traversal depth over the entity graph (alias `--hops`); 0 yields no results.
    #[arg(long, alias = "hops", default_value_t = DEFAULT_MAX_HOPS)]
    pub max_hops: u32,
    // When set, only relationships with this relation label are followed.
    #[arg(long, value_parser = crate::parsers::parse_relation)]
    pub relation: Option<String>,
    // Only relationships whose weight is at least this value are followed.
    #[arg(long, default_value_t = DEFAULT_MIN_WEIGHT)]
    pub min_weight: f64,
    // Maximum number of related memories returned.
    #[arg(long, default_value_t = DEFAULT_K_RECALL)]
    pub limit: usize,
    // Namespace to search; resolved by `crate::namespace::resolve_namespace`.
    #[arg(long)]
    pub namespace: Option<String>,
    // Output format handled by `run`: json (default), text or markdown.
    #[arg(long, value_enum, default_value = "json")]
    pub format: OutputFormat,
    // Hidden flag that `run` never reads.
    // NOTE(review): the help says JSON is always emitted, but `--format text` and
    // `--format markdown` print non-JSON output — confirm and reword.
    #[arg(long, hide = true, help = "No-op; JSON is always emitted on stdout")]
    pub json: bool,
    // Explicit database path; may also come from SQLITE_GRAPHRAG_DB_PATH.
    #[arg(long, env = "SQLITE_GRAPHRAG_DB_PATH")]
    pub db: Option<String>,
}
66
67#[derive(Serialize)]
68struct RelatedResponse {
69 name: String,
72 max_hops: u32,
74 results: Vec<RelatedMemory>,
75 elapsed_ms: u64,
76}
77
78#[derive(Serialize, Clone)]
79struct RelatedMemory {
80 memory_id: i64,
81 name: String,
82 namespace: String,
83 #[serde(rename = "type")]
84 memory_type: String,
85 description: String,
86 hop_distance: u32,
87 source_entity: Option<String>,
88 target_entity: Option<String>,
89 relation: Option<String>,
90 weight: Option<f64>,
91}
92
93pub fn run(args: RelatedArgs) -> Result<(), AppError> {
94 let inicio = std::time::Instant::now();
95 let name = args
96 .name_positional
97 .as_deref()
98 .or(args.name.as_deref())
99 .ok_or_else(|| {
100 AppError::Validation(
101 "name required: pass as positional argument or via --name".to_string(),
102 )
103 })?
104 .to_string();
105
106 if name.trim().is_empty() {
107 return Err(AppError::Validation("name must not be empty".to_string()));
108 }
109
110 let namespace = crate::namespace::resolve_namespace(args.namespace.as_deref())?;
111 let paths = AppPaths::resolve(args.db.as_deref())?;
112
113 crate::storage::connection::ensure_db_ready(&paths)?;
114
115 let conn = open_ro(&paths.db)?;
116
117 let seed = match conn.query_row(
119 "SELECT id FROM memories WHERE namespace = ?1 AND name = ?2 AND deleted_at IS NULL",
120 params![namespace, name],
121 |r| r.get::<_, i64>(0),
122 ) {
123 Ok(id) => SeedKind::Memory(id),
124 Err(rusqlite::Error::QueryReturnedNoRows) => {
125 match crate::storage::entities::find_entity_id(&conn, &namespace, &name)? {
126 Some(id) => SeedKind::Entity(id),
127 None => {
128 return Err(AppError::NotFound(errors_msg::memory_or_entity_not_found(
129 &name, &namespace,
130 )))
131 }
132 }
133 }
134 Err(e) => return Err(AppError::Database(e)),
135 };
136
137 let (seed_memory_id, seed_entity_ids): (i64, Vec<i64>) = match &seed {
139 SeedKind::Memory(id) => {
140 let mem_id = *id;
141 let mut stmt =
142 conn.prepare_cached("SELECT entity_id FROM memory_entities WHERE memory_id = ?1")?;
143 let rows: Vec<i64> = stmt
144 .query_map(params![mem_id], |r| r.get(0))?
145 .collect::<Result<Vec<i64>, _>>()?;
146 (mem_id, rows)
147 }
148 SeedKind::Entity(entity_id) => {
149 (-1, vec![*entity_id])
152 }
153 };
154
155 let relation_filter = args.relation;
156 if let Some(ref r) = relation_filter {
157 crate::parsers::warn_if_non_canonical(r);
158 }
159 let results = traverse_related(
160 &conn,
161 seed_memory_id,
162 &seed_entity_ids,
163 &namespace,
164 args.max_hops,
165 args.min_weight,
166 relation_filter.as_deref(),
167 args.limit,
168 )?;
169
170 match args.format {
171 OutputFormat::Json => output::emit_json(&RelatedResponse {
172 name: name.clone(),
173 max_hops: args.max_hops,
174 results,
175 elapsed_ms: inicio.elapsed().as_millis() as u64,
176 })?,
177 OutputFormat::Text => {
178 for item in &results {
179 if item.description.is_empty() {
180 output::emit_text(&format!(
181 "{}. {} ({})",
182 item.hop_distance, item.name, item.namespace
183 ));
184 } else {
185 let preview: String = item
186 .description
187 .chars()
188 .take(TEXT_DESCRIPTION_PREVIEW_LEN)
189 .collect();
190 output::emit_text(&format!(
191 "{}. {} ({}): {}",
192 item.hop_distance, item.name, item.namespace, preview
193 ));
194 }
195 }
196 }
197 OutputFormat::Markdown => {
198 for item in &results {
199 if item.description.is_empty() {
200 output::emit_text(&format!(
201 "- **{}** ({}) — hop {}",
202 item.name, item.namespace, item.hop_distance
203 ));
204 } else {
205 let preview: String = item
206 .description
207 .chars()
208 .take(TEXT_DESCRIPTION_PREVIEW_LEN)
209 .collect();
210 output::emit_text(&format!(
211 "- **{}** ({}) — hop {}: {}",
212 item.name, item.namespace, item.hop_distance, preview
213 ));
214 }
215 }
216 }
217 }
218
219 Ok(())
220}
221
222#[allow(clippy::too_many_arguments)]
223fn traverse_related(
224 conn: &Connection,
225 seed_memory_id: i64,
226 seed_entity_ids: &[i64],
227 namespace: &str,
228 max_hops: u32,
229 min_weight: f64,
230 relation_filter: Option<&str>,
231 limit: usize,
232) -> Result<Vec<RelatedMemory>, AppError> {
233 if seed_entity_ids.is_empty() || max_hops == 0 {
234 return Ok(Vec::new());
235 }
236
237 let mut visited: HashSet<i64> = seed_entity_ids.iter().copied().collect();
240 let mut entity_hop: HashMap<i64, u32> = HashMap::new();
241 for &e in seed_entity_ids {
242 entity_hop.insert(e, 0);
243 }
244 let mut entity_edge: HashMap<i64, (String, String, String, f64)> = HashMap::new();
247
248 let mut queue: VecDeque<i64> = seed_entity_ids.iter().copied().collect();
249
250 while let Some(current_entity) = queue.pop_front() {
251 let current_hop = *entity_hop.get(¤t_entity).unwrap_or(&0);
252 if current_hop >= max_hops {
253 continue;
254 }
255
256 let neighbours =
257 fetch_neighbours(conn, current_entity, namespace, min_weight, relation_filter)?;
258
259 for (neighbour_id, source_name, target_name, relation, weight) in neighbours {
260 if visited.insert(neighbour_id) {
261 entity_hop.insert(neighbour_id, current_hop + 1);
262 entity_edge.insert(neighbour_id, (source_name, target_name, relation, weight));
263 queue.push_back(neighbour_id);
264 }
265 }
266 }
267
268 let mut out: Vec<RelatedMemory> = Vec::with_capacity(limit);
270 let mut dedup_ids: HashSet<i64> = HashSet::new();
271 dedup_ids.insert(seed_memory_id);
272
273 let mut ordered_entities: Vec<(i64, u32)> = entity_hop
275 .iter()
276 .filter(|(id, _)| !seed_entity_ids.contains(id))
277 .map(|(id, hop)| (*id, *hop))
278 .collect();
279 ordered_entities.sort_by(|a, b| {
280 let weight_a = entity_edge.get(&a.0).map(|e| e.3).unwrap_or(0.0);
281 let weight_b = entity_edge.get(&b.0).map(|e| e.3).unwrap_or(0.0);
282 a.1.cmp(&b.1).then_with(|| {
283 weight_b
284 .partial_cmp(&weight_a)
285 .unwrap_or(std::cmp::Ordering::Equal)
286 })
287 });
288
289 for (entity_id, hop) in ordered_entities {
290 let mut stmt = conn.prepare_cached(
291 "SELECT m.id, m.name, m.namespace, m.type, m.description
292 FROM memory_entities me
293 JOIN memories m ON m.id = me.memory_id
294 WHERE me.entity_id = ?1 AND m.deleted_at IS NULL",
295 )?;
296 let rows = stmt
297 .query_map(params![entity_id], |r| {
298 Ok((
299 r.get::<_, i64>(0)?,
300 r.get::<_, String>(1)?,
301 r.get::<_, String>(2)?,
302 r.get::<_, String>(3)?,
303 r.get::<_, String>(4)?,
304 ))
305 })?
306 .collect::<Result<Vec<_>, _>>()?;
307
308 for (mid, name, ns, mtype, desc) in rows {
309 if !dedup_ids.insert(mid) {
310 continue;
311 }
312 let edge = entity_edge.get(&entity_id);
313 out.push(RelatedMemory {
314 memory_id: mid,
315 name,
316 namespace: ns,
317 memory_type: mtype,
318 description: desc,
319 hop_distance: hop,
320 source_entity: edge.map(|e| e.0.clone()),
321 target_entity: edge.map(|e| e.1.clone()),
322 relation: edge.map(|e| e.2.clone()),
323 weight: edge.map(|e| e.3),
324 });
325 if out.len() >= limit {
326 return Ok(out);
327 }
328 }
329 }
330
331 Ok(out)
332}
333
334fn fetch_neighbours(
335 conn: &Connection,
336 entity_id: i64,
337 namespace: &str,
338 min_weight: f64,
339 relation_filter: Option<&str>,
340) -> Result<Vec<Neighbour>, AppError> {
341 let base_sql = "\
344 SELECT r.target_id, se.name, te.name, r.relation, r.weight
345 FROM relationships r
346 JOIN entities se ON se.id = r.source_id
347 JOIN entities te ON te.id = r.target_id
348 WHERE r.source_id = ?1 AND r.weight >= ?2 AND r.namespace = ?3";
349
350 let reverse_sql = "\
351 SELECT r.source_id, se.name, te.name, r.relation, r.weight
352 FROM relationships r
353 JOIN entities se ON se.id = r.source_id
354 JOIN entities te ON te.id = r.target_id
355 WHERE r.target_id = ?1 AND r.weight >= ?2 AND r.namespace = ?3";
356
357 let mut results: Vec<Neighbour> = Vec::new();
358
359 let forward_sql = match relation_filter {
360 Some(_) => format!("{base_sql} AND r.relation = ?4"),
361 None => base_sql.to_string(),
362 };
363 let rev_sql = match relation_filter {
364 Some(_) => format!("{reverse_sql} AND r.relation = ?4"),
365 None => reverse_sql.to_string(),
366 };
367
368 let mut stmt = conn.prepare_cached(&forward_sql)?;
369 let rows: Vec<_> = if let Some(rel) = relation_filter {
370 stmt.query_map(params![entity_id, min_weight, namespace, rel], |r| {
371 Ok((
372 r.get::<_, i64>(0)?,
373 r.get::<_, String>(1)?,
374 r.get::<_, String>(2)?,
375 r.get::<_, String>(3)?,
376 r.get::<_, f64>(4)?,
377 ))
378 })?
379 .collect::<Result<Vec<_>, _>>()?
380 } else {
381 stmt.query_map(params![entity_id, min_weight, namespace], |r| {
382 Ok((
383 r.get::<_, i64>(0)?,
384 r.get::<_, String>(1)?,
385 r.get::<_, String>(2)?,
386 r.get::<_, String>(3)?,
387 r.get::<_, f64>(4)?,
388 ))
389 })?
390 .collect::<Result<Vec<_>, _>>()?
391 };
392 results.extend(rows);
393
394 let mut stmt = conn.prepare_cached(&rev_sql)?;
395 let rows: Vec<_> = if let Some(rel) = relation_filter {
396 stmt.query_map(params![entity_id, min_weight, namespace, rel], |r| {
397 Ok((
398 r.get::<_, i64>(0)?,
399 r.get::<_, String>(1)?,
400 r.get::<_, String>(2)?,
401 r.get::<_, String>(3)?,
402 r.get::<_, f64>(4)?,
403 ))
404 })?
405 .collect::<Result<Vec<_>, _>>()?
406 } else {
407 stmt.query_map(params![entity_id, min_weight, namespace], |r| {
408 Ok((
409 r.get::<_, i64>(0)?,
410 r.get::<_, String>(1)?,
411 r.get::<_, String>(2)?,
412 r.get::<_, String>(3)?,
413 r.get::<_, f64>(4)?,
414 ))
415 })?
416 .collect::<Result<Vec<_>, _>>()?
417 };
418 results.extend(rows);
419
420 Ok(results)
421}
422
#[cfg(test)]
mod tests {
    use super::*;

    /// Creates an in-memory database with the minimal schema `traverse_related` queries.
    fn setup_related_db() -> rusqlite::Connection {
        let conn = rusqlite::Connection::open_in_memory().expect("failed to open in-memory db");
        conn.execute_batch(
            "CREATE TABLE memories (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                name TEXT NOT NULL,
                namespace TEXT NOT NULL DEFAULT 'global',
                type TEXT NOT NULL DEFAULT 'fact',
                description TEXT NOT NULL DEFAULT '',
                deleted_at INTEGER
            );
            CREATE TABLE entities (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                namespace TEXT NOT NULL,
                name TEXT NOT NULL
            );
            CREATE TABLE relationships (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                namespace TEXT NOT NULL,
                source_id INTEGER NOT NULL,
                target_id INTEGER NOT NULL,
                relation TEXT NOT NULL DEFAULT 'related_to',
                weight REAL NOT NULL DEFAULT 1.0
            );
            CREATE TABLE memory_entities (
                memory_id INTEGER NOT NULL,
                entity_id INTEGER NOT NULL
            );",
        )
        .expect("failed to create test tables");
        conn
    }

    fn insert_memory(conn: &rusqlite::Connection, name: &str, namespace: &str) -> i64 {
        conn.execute(
            "INSERT INTO memories (name, namespace) VALUES (?1, ?2)",
            rusqlite::params![name, namespace],
        )
        .expect("failed to insert memory");
        conn.last_insert_rowid()
    }

    fn soft_delete_memory(conn: &rusqlite::Connection, memory_id: i64) {
        conn.execute(
            "UPDATE memories SET deleted_at = 1 WHERE id = ?1",
            rusqlite::params![memory_id],
        )
        .expect("failed to soft-delete memory");
    }

    fn insert_entity(conn: &rusqlite::Connection, name: &str, namespace: &str) -> i64 {
        conn.execute(
            "INSERT INTO entities (name, namespace) VALUES (?1, ?2)",
            rusqlite::params![name, namespace],
        )
        .expect("failed to insert entity");
        conn.last_insert_rowid()
    }

    fn link_memory_entity(conn: &rusqlite::Connection, memory_id: i64, entity_id: i64) {
        conn.execute(
            "INSERT INTO memory_entities (memory_id, entity_id) VALUES (?1, ?2)",
            rusqlite::params![memory_id, entity_id],
        )
        .expect("failed to link memory-entity");
    }

    fn insert_relationship(
        conn: &rusqlite::Connection,
        namespace: &str,
        source_id: i64,
        target_id: i64,
        relation: &str,
        weight: f64,
    ) {
        conn.execute(
            "INSERT INTO relationships (namespace, source_id, target_id, relation, weight)
             VALUES (?1, ?2, ?3, ?4, ?5)",
            rusqlite::params![namespace, source_id, target_id, relation, weight],
        )
        .expect("failed to insert relationship");
    }

    #[test]
    fn related_response_serializes_results_and_elapsed_ms() {
        let resp = RelatedResponse {
            name: "seed-mem".to_string(),
            max_hops: 2,
            results: vec![RelatedMemory {
                memory_id: 1,
                name: "neighbor-mem".to_string(),
                namespace: "global".to_string(),
                memory_type: "document".to_string(),
                description: "desc".to_string(),
                hop_distance: 1,
                source_entity: Some("entity-a".to_string()),
                target_entity: Some("entity-b".to_string()),
                relation: Some("related_to".to_string()),
                weight: Some(0.9),
            }],
            elapsed_ms: 7,
        };

        let json = serde_json::to_value(&resp).expect("serialization failed");
        assert_eq!(json["name"], "seed-mem");
        assert_eq!(json["max_hops"], 2);
        assert!(json["results"].is_array());
        assert_eq!(json["results"].as_array().unwrap().len(), 1);
        assert_eq!(json["elapsed_ms"], 7u64);
        assert_eq!(json["results"][0]["type"], "document");
        assert_eq!(json["results"][0]["hop_distance"], 1);
        assert_eq!(json["results"][0]["weight"], 0.9);
    }

    #[test]
    fn traverse_related_returns_empty_without_seed_entities() {
        let conn = setup_related_db();
        let result = traverse_related(&conn, 1, &[], "global", 2, 0.0, None, 10)
            .expect("traverse_related failed");
        assert!(
            result.is_empty(),
            "no seed entities must yield empty results"
        );
    }

    #[test]
    fn traverse_related_returns_empty_with_max_hops_zero() {
        let conn = setup_related_db();
        let mem_id = insert_memory(&conn, "seed-mem", "global");
        let ent_id = insert_entity(&conn, "ent-a", "global");
        link_memory_entity(&conn, mem_id, ent_id);

        let result = traverse_related(&conn, mem_id, &[ent_id], "global", 0, 0.0, None, 10)
            .expect("traverse_related failed");
        assert!(result.is_empty(), "max_hops=0 must return empty");
    }

    #[test]
    fn traverse_related_discovers_neighbor_memory_via_graph() {
        let conn = setup_related_db();

        let seed_id = insert_memory(&conn, "seed-mem", "global");
        let neighbor_id = insert_memory(&conn, "neighbor-mem", "global");
        let ent_a = insert_entity(&conn, "ent-a", "global");
        let ent_b = insert_entity(&conn, "ent-b", "global");

        link_memory_entity(&conn, seed_id, ent_a);
        link_memory_entity(&conn, neighbor_id, ent_b);
        insert_relationship(&conn, "global", ent_a, ent_b, "related_to", 1.0);

        let result = traverse_related(&conn, seed_id, &[ent_a], "global", 2, 0.0, None, 10)
            .expect("traverse_related failed");

        assert_eq!(result.len(), 1, "must find 1 neighboring memory");
        assert_eq!(result[0].name, "neighbor-mem");
        assert_eq!(result[0].hop_distance, 1);
    }

    #[test]
    fn traverse_related_follows_incoming_edges_and_keeps_edge_direction() {
        let conn = setup_related_db();

        let seed_id = insert_memory(&conn, "seed-mem", "global");
        let neighbor_id = insert_memory(&conn, "neighbor-mem", "global");
        let ent_a = insert_entity(&conn, "ent-a", "global");
        let ent_b = insert_entity(&conn, "ent-b", "global");

        link_memory_entity(&conn, seed_id, ent_a);
        link_memory_entity(&conn, neighbor_id, ent_b);
        // The edge points *into* the seed entity.
        insert_relationship(&conn, "global", ent_b, ent_a, "depends_on", 0.7);

        let result = traverse_related(&conn, seed_id, &[ent_a], "global", 1, 0.0, None, 10)
            .expect("traverse_related failed");

        assert_eq!(result.len(), 1);
        assert_eq!(result[0].name, "neighbor-mem");
        assert_eq!(result[0].source_entity.as_deref(), Some("ent-b"));
        assert_eq!(result[0].target_entity.as_deref(), Some("ent-a"));
        assert_eq!(result[0].relation.as_deref(), Some("depends_on"));
        assert_eq!(result[0].weight, Some(0.7));
    }

    #[test]
    fn traverse_related_orders_by_hop_and_stops_at_max_hops() {
        let conn = setup_related_db();

        let seed_id = insert_memory(&conn, "seed", "global");
        let one_hop = insert_memory(&conn, "one-hop", "global");
        let two_hop = insert_memory(&conn, "two-hop", "global");
        let ent_a = insert_entity(&conn, "ent-a", "global");
        let ent_b = insert_entity(&conn, "ent-b", "global");
        let ent_c = insert_entity(&conn, "ent-c", "global");

        link_memory_entity(&conn, seed_id, ent_a);
        link_memory_entity(&conn, one_hop, ent_b);
        link_memory_entity(&conn, two_hop, ent_c);
        insert_relationship(&conn, "global", ent_a, ent_b, "related_to", 1.0);
        insert_relationship(&conn, "global", ent_b, ent_c, "related_to", 1.0);

        let two = traverse_related(&conn, seed_id, &[ent_a], "global", 2, 0.0, None, 10)
            .expect("traverse_related failed");
        let found: Vec<(&str, u32)> = two
            .iter()
            .map(|m| (m.name.as_str(), m.hop_distance))
            .collect();
        assert_eq!(found, vec![("one-hop", 1), ("two-hop", 2)]);

        let one = traverse_related(&conn, seed_id, &[ent_a], "global", 1, 0.0, None, 10)
            .expect("traverse_related failed");
        assert_eq!(one.len(), 1, "max_hops=1 must not reach the second hop");
        assert_eq!(one[0].name, "one-hop");
    }

    #[test]
    fn traverse_related_applies_weight_relation_and_namespace_filters() {
        let conn = setup_related_db();

        let seed_id = insert_memory(&conn, "seed", "global");
        let ent_seed = insert_entity(&conn, "ent-seed", "global");
        link_memory_entity(&conn, seed_id, ent_seed);

        let specs = [
            ("weak", "related_to", 0.2, "global"),
            ("uses", "uses", 0.9, "global"),
            ("other-ns", "related_to", 0.9, "other"),
        ];
        for (label, relation, weight, ns) in specs {
            let mem_id = insert_memory(&conn, &format!("mem-{label}"), "global");
            let ent_id = insert_entity(&conn, &format!("ent-{label}"), "global");
            link_memory_entity(&conn, mem_id, ent_id);
            insert_relationship(&conn, ns, ent_seed, ent_id, relation, weight);
        }

        let by_weight = traverse_related(&conn, seed_id, &[ent_seed], "global", 1, 0.5, None, 10)
            .expect("traverse_related failed");
        let names: Vec<&str> = by_weight.iter().map(|m| m.name.as_str()).collect();
        assert_eq!(names, vec!["mem-uses"]);

        let by_relation = traverse_related(
            &conn,
            seed_id,
            &[ent_seed],
            "global",
            1,
            0.0,
            Some("related_to"),
            10,
        )
        .expect("traverse_related failed");
        let names: Vec<&str> = by_relation.iter().map(|m| m.name.as_str()).collect();
        assert_eq!(names, vec!["mem-weak"]);
    }

    #[test]
    fn traverse_related_skips_soft_deleted_memories() {
        let conn = setup_related_db();

        let seed_id = insert_memory(&conn, "seed-mem", "global");
        let neighbor_id = insert_memory(&conn, "neighbor-mem", "global");
        let ent_a = insert_entity(&conn, "ent-a", "global");
        let ent_b = insert_entity(&conn, "ent-b", "global");

        link_memory_entity(&conn, seed_id, ent_a);
        link_memory_entity(&conn, neighbor_id, ent_b);
        insert_relationship(&conn, "global", ent_a, ent_b, "related_to", 1.0);
        soft_delete_memory(&conn, neighbor_id);

        let result = traverse_related(&conn, seed_id, &[ent_a], "global", 2, 0.0, None, 10)
            .expect("traverse_related failed");
        assert!(result.is_empty(), "soft-deleted memories must not be returned");
    }

    #[test]
    fn traverse_related_respects_limit() {
        let conn = setup_related_db();

        let seed_id = insert_memory(&conn, "seed", "global");
        let ent_seed = insert_entity(&conn, "ent-seed", "global");
        link_memory_entity(&conn, seed_id, ent_seed);

        for i in 0..5 {
            let mem_id = insert_memory(&conn, &format!("neighbor-{i}"), "global");
            let ent_id = insert_entity(&conn, &format!("ent-{i}"), "global");
            link_memory_entity(&conn, mem_id, ent_id);
            insert_relationship(&conn, "global", ent_seed, ent_id, "related_to", 1.0);
        }

        let result = traverse_related(&conn, seed_id, &[ent_seed], "global", 1, 0.0, None, 3)
            .expect("traverse_related failed");

        // Five candidates exist, so the limit must be hit exactly, not merely respected.
        assert_eq!(result.len(), 3, "limit=3 must return exactly 3 of 5 candidates");
        assert!(result.iter().all(|m| m.hop_distance == 1));
    }

    #[test]
    fn related_memory_optional_null_fields_serialized() {
        let mem = RelatedMemory {
            memory_id: 99,
            name: "no-relation".to_string(),
            namespace: "ns".to_string(),
            memory_type: "concept".to_string(),
            description: "".to_string(),
            hop_distance: 2,
            source_entity: None,
            target_entity: None,
            relation: None,
            weight: None,
        };

        let json = serde_json::to_value(&mem).expect("serialization failed");
        assert!(json["source_entity"].is_null());
        assert!(json["target_entity"].is_null());
        assert!(json["relation"].is_null());
        assert!(json["weight"].is_null());
        assert_eq!(json["hop_distance"], 2);
    }
}