1use crate::constants::{
4 DEFAULT_K_RECALL, DEFAULT_MAX_HOPS, DEFAULT_MIN_WEIGHT, TEXT_DESCRIPTION_PREVIEW_LEN,
5};
6use crate::errors::AppError;
7use crate::i18n::errors_msg;
8use crate::output::{self, OutputFormat};
9use crate::paths::AppPaths;
10use crate::storage::connection::open_ro;
11use rusqlite::{params, Connection};
12use serde::Serialize;
13use std::collections::{HashSet, VecDeque};
14
/// What the user-supplied name resolved to when looking up the traversal seed.
enum SeedKind {
    /// The name matched a memory row; the payload is that memory's id.
    Memory(i64),
    /// No memory matched but an entity did; the payload is the entity's id.
    Entity(i64),
}
20
/// One edge leading away from an entity, as a flat tuple:
/// `(neighbour entity id, source entity name, target entity name, relation, weight)`.
type Neighbour = (i64, String, String, String, f64);
24
25#[derive(clap::Args)]
26#[command(after_long_help = "EXAMPLES:\n \
27 # List memories connected to a memory via the entity graph (default 2 hops)\n \
28 sqlite-graphrag related onboarding\n\n \
29 # Increase hop distance and filter by relation type\n \
30 sqlite-graphrag related onboarding --max-hops 3 --relation related\n\n \
31 # Cap result count and require minimum edge weight\n \
32 sqlite-graphrag related onboarding --limit 5 --min-weight 0.5")]
33pub struct RelatedArgs {
34 #[arg(
36 value_name = "NAME",
37 conflicts_with = "name",
38 help = "Memory name whose neighbours to traverse; alternative to --name"
39 )]
40 pub name_positional: Option<String>,
41 #[arg(long, alias = "from")]
43 pub name: Option<String>,
44 #[arg(long, alias = "hops", default_value_t = DEFAULT_MAX_HOPS)]
46 pub max_hops: u32,
47 #[arg(long, value_parser = crate::parsers::parse_relation)]
52 pub relation: Option<String>,
53 #[arg(long, default_value_t = DEFAULT_MIN_WEIGHT)]
54 pub min_weight: f64,
55 #[arg(long, default_value_t = DEFAULT_K_RECALL)]
56 pub limit: usize,
57 #[arg(long)]
58 pub namespace: Option<String>,
59 #[arg(long, value_enum, default_value = "json")]
60 pub format: OutputFormat,
61 #[arg(long, hide = true, help = "No-op; JSON is always emitted on stdout")]
62 pub json: bool,
63 #[arg(long, env = "SQLITE_GRAPHRAG_DB_PATH")]
64 pub db: Option<String>,
65}
66
67#[derive(Serialize)]
68struct RelatedResponse {
69 name: String,
72 max_hops: u32,
74 results: Vec<RelatedMemory>,
75 related_memories: Vec<RelatedMemory>,
77 elapsed_ms: u64,
78}
79
80#[derive(Serialize, Clone)]
81struct RelatedMemory {
82 memory_id: i64,
83 name: String,
84 namespace: String,
85 #[serde(rename = "type")]
86 memory_type: String,
87 description: String,
88 hop_distance: u32,
89 source_entity: Option<String>,
90 target_entity: Option<String>,
91 #[serde(skip_serializing_if = "Option::is_none")]
93 from: Option<String>,
94 #[serde(skip_serializing_if = "Option::is_none")]
96 to: Option<String>,
97 relation: Option<String>,
98 weight: Option<f64>,
99}
100
101pub fn run(args: RelatedArgs) -> Result<(), AppError> {
102 let inicio = std::time::Instant::now();
103 let name = args
104 .name_positional
105 .as_deref()
106 .or(args.name.as_deref())
107 .ok_or_else(|| {
108 AppError::Validation(
109 "name required: pass as positional argument or via --name".to_string(),
110 )
111 })?
112 .to_string();
113
114 if name.trim().is_empty() {
115 return Err(AppError::Validation("name must not be empty".to_string()));
116 }
117
118 let namespace = crate::namespace::resolve_namespace(args.namespace.as_deref())?;
119 let paths = AppPaths::resolve(args.db.as_deref())?;
120
121 crate::storage::connection::ensure_db_ready(&paths)?;
122
123 let conn = open_ro(&paths.db)?;
124
125 let seed = match conn.query_row(
127 "SELECT id FROM memories WHERE namespace = ?1 AND name = ?2 AND deleted_at IS NULL",
128 params![namespace, name],
129 |r| r.get::<_, i64>(0),
130 ) {
131 Ok(id) => SeedKind::Memory(id),
132 Err(rusqlite::Error::QueryReturnedNoRows) => {
133 match crate::storage::entities::find_entity_id(&conn, &namespace, &name)? {
134 Some(id) => SeedKind::Entity(id),
135 None => {
136 return Err(AppError::NotFound(errors_msg::memory_or_entity_not_found(
137 &name, &namespace,
138 )))
139 }
140 }
141 }
142 Err(e) => return Err(AppError::Database(e)),
143 };
144
145 let (seed_memory_id, seed_entity_ids): (i64, Vec<i64>) = match &seed {
147 SeedKind::Memory(id) => {
148 let mem_id = *id;
149 let mut stmt =
150 conn.prepare_cached("SELECT entity_id FROM memory_entities WHERE memory_id = ?1")?;
151 let rows: Vec<i64> = stmt
152 .query_map(params![mem_id], |r| r.get(0))?
153 .collect::<Result<Vec<i64>, _>>()?;
154 (mem_id, rows)
155 }
156 SeedKind::Entity(entity_id) => {
157 (-1, vec![*entity_id])
160 }
161 };
162
163 let relation_filter = args.relation;
164 if let Some(ref r) = relation_filter {
165 crate::parsers::warn_if_non_canonical(r);
166 }
167 let results = traverse_related(
168 &conn,
169 seed_memory_id,
170 &seed_entity_ids,
171 &namespace,
172 args.max_hops,
173 args.min_weight,
174 relation_filter.as_deref(),
175 args.limit,
176 )?;
177
178 match args.format {
179 OutputFormat::Json => {
180 let related_memories = results.clone();
181 output::emit_json(&RelatedResponse {
182 name: name.clone(),
183 max_hops: args.max_hops,
184 results,
185 related_memories,
186 elapsed_ms: inicio.elapsed().as_millis() as u64,
187 })?;
188 }
189 OutputFormat::Text => {
190 for item in &results {
191 if item.description.is_empty() {
192 output::emit_text(&format!(
193 "{}. {} ({})",
194 item.hop_distance, item.name, item.namespace
195 ));
196 } else {
197 let preview: String = item
198 .description
199 .chars()
200 .take(TEXT_DESCRIPTION_PREVIEW_LEN)
201 .collect();
202 output::emit_text(&format!(
203 "{}. {} ({}): {}",
204 item.hop_distance, item.name, item.namespace, preview
205 ));
206 }
207 }
208 }
209 OutputFormat::Markdown => {
210 for item in &results {
211 if item.description.is_empty() {
212 output::emit_text(&format!(
213 "- **{}** ({}) — hop {}",
214 item.name, item.namespace, item.hop_distance
215 ));
216 } else {
217 let preview: String = item
218 .description
219 .chars()
220 .take(TEXT_DESCRIPTION_PREVIEW_LEN)
221 .collect();
222 output::emit_text(&format!(
223 "- **{}** ({}) — hop {}: {}",
224 item.name, item.namespace, item.hop_distance, preview
225 ));
226 }
227 }
228 }
229 }
230
231 Ok(())
232}
233
234#[allow(clippy::too_many_arguments)]
235fn traverse_related(
236 conn: &Connection,
237 seed_memory_id: i64,
238 seed_entity_ids: &[i64],
239 namespace: &str,
240 max_hops: u32,
241 min_weight: f64,
242 relation_filter: Option<&str>,
243 limit: usize,
244) -> Result<Vec<RelatedMemory>, AppError> {
245 if seed_entity_ids.is_empty() || max_hops == 0 {
246 return Ok(Vec::new());
247 }
248
249 let mut visited: HashSet<i64> = seed_entity_ids.iter().copied().collect();
252 let mut entity_hop: crate::hash::AHashMap<i64, u32> =
253 crate::hash::AHashMap::with_capacity_and_hasher(max_hops as usize * 10, Default::default());
254 for &e in seed_entity_ids {
255 entity_hop.insert(e, 0);
256 }
257 let mut entity_edge: crate::hash::AHashMap<i64, (String, String, String, f64)> =
260 crate::hash::AHashMap::with_capacity_and_hasher(max_hops as usize * 10, Default::default());
261
262 let mut queue: VecDeque<i64> = seed_entity_ids.iter().copied().collect();
263
264 while let Some(current_entity) = queue.pop_front() {
265 let current_hop = *entity_hop.get(¤t_entity).unwrap_or(&0);
266 if current_hop >= max_hops {
267 continue;
268 }
269
270 let neighbours =
271 fetch_neighbours(conn, current_entity, namespace, min_weight, relation_filter)?;
272
273 for (neighbour_id, source_name, target_name, relation, weight) in neighbours {
274 if visited.insert(neighbour_id) {
275 entity_hop.insert(neighbour_id, current_hop + 1);
276 entity_edge.insert(neighbour_id, (source_name, target_name, relation, weight));
277 queue.push_back(neighbour_id);
278 }
279 }
280 }
281
282 let mut out: Vec<RelatedMemory> = Vec::with_capacity(limit);
284 let mut dedup_ids: crate::hash::AHashSet<i64> =
285 crate::hash::AHashSet::with_capacity_and_hasher(limit, Default::default());
286 dedup_ids.insert(seed_memory_id);
287
288 let mut ordered_entities: Vec<(i64, u32)> = entity_hop
290 .iter()
291 .filter(|(id, _)| !seed_entity_ids.contains(id))
292 .map(|(id, hop)| (*id, *hop))
293 .collect();
294 ordered_entities.sort_by(|a, b| {
295 let weight_a = entity_edge.get(&a.0).map(|e| e.3).unwrap_or(0.0);
296 let weight_b = entity_edge.get(&b.0).map(|e| e.3).unwrap_or(0.0);
297 a.1.cmp(&b.1).then_with(|| {
298 weight_b
299 .partial_cmp(&weight_a)
300 .unwrap_or(std::cmp::Ordering::Equal)
301 })
302 });
303
304 for (entity_id, hop) in ordered_entities {
305 let mut stmt = conn.prepare_cached(
306 "SELECT m.id, m.name, m.namespace, m.type, m.description
307 FROM memory_entities me
308 JOIN memories m ON m.id = me.memory_id
309 WHERE me.entity_id = ?1 AND m.deleted_at IS NULL",
310 )?;
311 let rows = stmt
312 .query_map(params![entity_id], |r| {
313 Ok((
314 r.get::<_, i64>(0)?,
315 r.get::<_, String>(1)?,
316 r.get::<_, String>(2)?,
317 r.get::<_, String>(3)?,
318 r.get::<_, String>(4)?,
319 ))
320 })?
321 .collect::<Result<Vec<_>, _>>()?;
322
323 for (mid, name, ns, mtype, desc) in rows {
324 if !dedup_ids.insert(mid) {
325 continue;
326 }
327 let edge = entity_edge.get(&entity_id);
328 let src = edge.map(|e| e.0.clone());
329 let tgt = edge.map(|e| e.1.clone());
330 out.push(RelatedMemory {
331 memory_id: mid,
332 name,
333 namespace: ns,
334 memory_type: mtype,
335 description: desc,
336 hop_distance: hop,
337 source_entity: src.clone(),
338 target_entity: tgt.clone(),
339 from: src,
340 to: tgt,
341 relation: edge.map(|e| e.2.clone()),
342 weight: edge.map(|e| e.3),
343 });
344 if out.len() >= limit {
345 return Ok(out);
346 }
347 }
348 }
349 Ok(out)
350}
351
352fn fetch_neighbours(
353 conn: &Connection,
354 entity_id: i64,
355 namespace: &str,
356 min_weight: f64,
357 relation_filter: Option<&str>,
358) -> Result<Vec<Neighbour>, AppError> {
359 let base_sql = "\
362 SELECT r.target_id, se.name, te.name, r.relation, r.weight
363 FROM relationships r
364 JOIN entities se ON se.id = r.source_id
365 JOIN entities te ON te.id = r.target_id
366 WHERE r.source_id = ?1 AND r.weight >= ?2 AND r.namespace = ?3";
367
368 let reverse_sql = "\
369 SELECT r.source_id, se.name, te.name, r.relation, r.weight
370 FROM relationships r
371 JOIN entities se ON se.id = r.source_id
372 JOIN entities te ON te.id = r.target_id
373 WHERE r.target_id = ?1 AND r.weight >= ?2 AND r.namespace = ?3";
374
375 let mut results: Vec<Neighbour> = Vec::with_capacity(16);
376
377 let forward_sql = match relation_filter {
378 Some(_) => format!("{base_sql} AND r.relation = ?4"),
379 None => base_sql.to_string(),
380 };
381 let rev_sql = match relation_filter {
382 Some(_) => format!("{reverse_sql} AND r.relation = ?4"),
383 None => reverse_sql.to_string(),
384 };
385
386 let mut stmt = conn.prepare_cached(&forward_sql)?;
387 let rows: Vec<_> = if let Some(rel) = relation_filter {
388 stmt.query_map(params![entity_id, min_weight, namespace, rel], |r| {
389 Ok((
390 r.get::<_, i64>(0)?,
391 r.get::<_, String>(1)?,
392 r.get::<_, String>(2)?,
393 r.get::<_, String>(3)?,
394 r.get::<_, f64>(4)?,
395 ))
396 })?
397 .collect::<Result<Vec<_>, _>>()?
398 } else {
399 stmt.query_map(params![entity_id, min_weight, namespace], |r| {
400 Ok((
401 r.get::<_, i64>(0)?,
402 r.get::<_, String>(1)?,
403 r.get::<_, String>(2)?,
404 r.get::<_, String>(3)?,
405 r.get::<_, f64>(4)?,
406 ))
407 })?
408 .collect::<Result<Vec<_>, _>>()?
409 };
410 results.extend(rows);
411
412 let mut stmt = conn.prepare_cached(&rev_sql)?;
413 let rows: Vec<_> = if let Some(rel) = relation_filter {
414 stmt.query_map(params![entity_id, min_weight, namespace, rel], |r| {
415 Ok((
416 r.get::<_, i64>(0)?,
417 r.get::<_, String>(1)?,
418 r.get::<_, String>(2)?,
419 r.get::<_, String>(3)?,
420 r.get::<_, f64>(4)?,
421 ))
422 })?
423 .collect::<Result<Vec<_>, _>>()?
424 } else {
425 stmt.query_map(params![entity_id, min_weight, namespace], |r| {
426 Ok((
427 r.get::<_, i64>(0)?,
428 r.get::<_, String>(1)?,
429 r.get::<_, String>(2)?,
430 r.get::<_, String>(3)?,
431 r.get::<_, f64>(4)?,
432 ))
433 })?
434 .collect::<Result<Vec<_>, _>>()?
435 };
436 results.extend(rows);
437
438 Ok(results)
439}
440
#[cfg(test)]
mod tests {
    use super::*;

    /// Creates an in-memory database containing only the tables and columns
    /// that the traversal code reads.
    fn setup_related_db() -> rusqlite::Connection {
        let conn = rusqlite::Connection::open_in_memory().expect("failed to open in-memory db");
        conn.execute_batch(
            "CREATE TABLE memories (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                name TEXT NOT NULL,
                namespace TEXT NOT NULL DEFAULT 'global',
                type TEXT NOT NULL DEFAULT 'fact',
                description TEXT NOT NULL DEFAULT '',
                deleted_at INTEGER
            );
            CREATE TABLE entities (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                namespace TEXT NOT NULL,
                name TEXT NOT NULL
            );
            CREATE TABLE relationships (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                namespace TEXT NOT NULL,
                source_id INTEGER NOT NULL,
                target_id INTEGER NOT NULL,
                relation TEXT NOT NULL DEFAULT 'related_to',
                weight REAL NOT NULL DEFAULT 1.0
            );
            CREATE TABLE memory_entities (
                memory_id INTEGER NOT NULL,
                entity_id INTEGER NOT NULL
            );",
        )
        .expect("failed to create test tables");
        conn
    }

    /// Inserts a memory row and returns its id.
    fn insert_memory(conn: &rusqlite::Connection, name: &str, namespace: &str) -> i64 {
        conn.execute(
            "INSERT INTO memories (name, namespace) VALUES (?1, ?2)",
            rusqlite::params![name, namespace],
        )
        .expect("failed to insert memory");
        conn.last_insert_rowid()
    }

    /// Inserts an entity row and returns its id. Argument order is
    /// `(name, namespace)`, matching `insert_memory`.
    fn insert_entity(conn: &rusqlite::Connection, name: &str, namespace: &str) -> i64 {
        conn.execute(
            "INSERT INTO entities (name, namespace) VALUES (?1, ?2)",
            rusqlite::params![name, namespace],
        )
        .expect("failed to insert entity");
        conn.last_insert_rowid()
    }

    /// Attaches an entity to a memory.
    fn link_memory_entity(conn: &rusqlite::Connection, memory_id: i64, entity_id: i64) {
        conn.execute(
            "INSERT INTO memory_entities (memory_id, entity_id) VALUES (?1, ?2)",
            rusqlite::params![memory_id, entity_id],
        )
        .expect("failed to link memory-entity");
    }

    /// Inserts a directed relationship `source_id -> target_id`.
    fn insert_relationship(
        conn: &rusqlite::Connection,
        namespace: &str,
        source_id: i64,
        target_id: i64,
        relation: &str,
        weight: f64,
    ) {
        conn.execute(
            "INSERT INTO relationships (namespace, source_id, target_id, relation, weight)
             VALUES (?1, ?2, ?3, ?4, ?5)",
            rusqlite::params![namespace, source_id, target_id, relation, weight],
        )
        .expect("failed to insert relationship");
    }

    #[test]
    fn related_response_serializes_results_and_elapsed_ms() {
        let mem = RelatedMemory {
            memory_id: 1,
            name: "neighbor-mem".to_string(),
            namespace: "global".to_string(),
            memory_type: "document".to_string(),
            description: "desc".to_string(),
            hop_distance: 1,
            source_entity: Some("entity-a".to_string()),
            target_entity: Some("entity-b".to_string()),
            from: Some("entity-a".to_string()),
            to: Some("entity-b".to_string()),
            relation: Some("related_to".to_string()),
            weight: Some(0.9),
        };
        let resp = RelatedResponse {
            name: "seed-mem".to_string(),
            max_hops: 2,
            related_memories: vec![mem.clone()],
            results: vec![mem],
            elapsed_ms: 7,
        };
        let json = serde_json::to_value(&resp).expect("serialization failed");
        assert!(json["results"].is_array());
        assert_eq!(json["results"].as_array().unwrap().len(), 1);
        assert_eq!(json["elapsed_ms"], 7u64);
        assert_eq!(json["results"][0]["type"], "document");
        assert_eq!(json["results"][0]["hop_distance"], 1);
    }

    #[test]
    fn traverse_related_returns_empty_without_seed_entities() {
        let conn = setup_related_db();
        let result = traverse_related(&conn, 1, &[], "global", 2, 0.0, None, 10)
            .expect("traverse_related failed");
        assert!(result.is_empty());
    }

    #[test]
    fn traverse_related_returns_empty_with_max_hops_zero() {
        let conn = setup_related_db();
        let mem_id = insert_memory(&conn, "seed", "global");
        let ent_id = insert_entity(&conn, "ent", "global");
        let result = traverse_related(&conn, mem_id, &[ent_id], "global", 0, 0.0, None, 10)
            .expect("traverse_related failed");
        assert!(result.is_empty());
    }

    #[test]
    fn traverse_related_discovers_neighbor_memory_via_graph() {
        let conn = setup_related_db();
        let seed_id = insert_memory(&conn, "seed", "global");
        let ent_a = insert_entity(&conn, "ent-a", "global");
        let ent_b = insert_entity(&conn, "ent-b", "global");
        let neighbor_id = insert_memory(&conn, "neighbor", "global");
        link_memory_entity(&conn, seed_id, ent_a);
        link_memory_entity(&conn, neighbor_id, ent_b);
        insert_relationship(&conn, "global", ent_a, ent_b, "related_to", 1.0);
        let result = traverse_related(&conn, seed_id, &[ent_a], "global", 2, 0.0, None, 10)
            .expect("traverse_related failed");
        assert_eq!(result.len(), 1);
        assert_eq!(result[0].name, "neighbor");
    }

    #[test]
    fn traverse_related_follows_reverse_edges() {
        let conn = setup_related_db();
        let seed_id = insert_memory(&conn, "seed", "global");
        let ent_a = insert_entity(&conn, "ent-a", "global");
        let ent_b = insert_entity(&conn, "ent-b", "global");
        let neighbor_id = insert_memory(&conn, "neighbor", "global");
        link_memory_entity(&conn, seed_id, ent_a);
        link_memory_entity(&conn, neighbor_id, ent_b);
        // The edge points *into* the seed entity, so it is only reachable via
        // the incoming-edge query.
        insert_relationship(&conn, "global", ent_b, ent_a, "related_to", 1.0);
        let result = traverse_related(&conn, seed_id, &[ent_a], "global", 1, 0.0, None, 10)
            .expect("traverse_related failed");
        assert_eq!(result.len(), 1);
        assert_eq!(result[0].name, "neighbor");
        assert_eq!(result[0].hop_distance, 1);
    }

    #[test]
    fn traverse_related_applies_weight_and_relation_filters() {
        let conn = setup_related_db();
        let seed_id = insert_memory(&conn, "seed", "global");
        let ent_a = insert_entity(&conn, "ent-a", "global");
        let ent_b = insert_entity(&conn, "ent-b", "global");
        let neighbor_id = insert_memory(&conn, "neighbor", "global");
        link_memory_entity(&conn, seed_id, ent_a);
        link_memory_entity(&conn, neighbor_id, ent_b);
        insert_relationship(&conn, "global", ent_a, ent_b, "related_to", 0.2);

        let below_weight = traverse_related(&conn, seed_id, &[ent_a], "global", 2, 0.5, None, 10)
            .expect("traverse_related failed");
        assert!(
            below_weight.is_empty(),
            "edge below min_weight must be ignored"
        );

        let other_relation = traverse_related(
            &conn,
            seed_id,
            &[ent_a],
            "global",
            2,
            0.0,
            Some("depends_on"),
            10,
        )
        .expect("traverse_related failed");
        assert!(
            other_relation.is_empty(),
            "edge with a different relation must be ignored"
        );

        let same_relation = traverse_related(
            &conn,
            seed_id,
            &[ent_a],
            "global",
            2,
            0.0,
            Some("related_to"),
            10,
        )
        .expect("traverse_related failed");
        assert_eq!(same_relation.len(), 1);
        assert_eq!(same_relation[0].relation.as_deref(), Some("related_to"));
    }

    #[test]
    fn traverse_related_respects_limit() {
        let conn = setup_related_db();
        let seed_id = insert_memory(&conn, "seed", "global");
        let ent_seed = insert_entity(&conn, "ent-seed", "global");
        link_memory_entity(&conn, seed_id, ent_seed);
        for i in 0..5 {
            let ent_id = insert_entity(&conn, &format!("ent-{i}"), "global");
            let mem_id = insert_memory(&conn, &format!("mem-{i}"), "global");
            link_memory_entity(&conn, mem_id, ent_id);
            insert_relationship(&conn, "global", ent_seed, ent_id, "related_to", 1.0);
        }
        let result = traverse_related(&conn, seed_id, &[ent_seed], "global", 1, 0.0, None, 3)
            .expect("traverse_related failed");
        assert_eq!(
            result.len(),
            3,
            "limit=3 must constrain to at most 3 results"
        );
    }

    #[test]
    fn related_memory_optional_null_fields_serialized() {
        let mem = RelatedMemory {
            memory_id: 99,
            name: "no-relation".to_string(),
            namespace: "ns".to_string(),
            memory_type: "concept".to_string(),
            description: "".to_string(),
            hop_distance: 2,
            source_entity: None,
            target_entity: None,
            from: None,
            to: None,
            relation: None,
            weight: None,
        };
        let json = serde_json::to_value(&mem).expect("serialization failed");
        assert!(json["source_entity"].is_null());
        assert!(json["target_entity"].is_null());
        assert!(json["relation"].is_null());
        assert!(json["weight"].is_null());
        assert_eq!(json["hop_distance"], 2);
        // `from` / `to` carry `skip_serializing_if`, so the keys must be absent
        // rather than present with a null value.
        assert!(json.get("from").is_none());
        assert!(json.get("to").is_none());
    }
}