use crate::errors::AppError;
use crate::graph::{
bfs_with_predecessors, traverse_from_memories_with_hops_capped, PredecessorMap,
};
use crate::output;
use crate::paths::AppPaths;
use crate::storage::connection::open_ro;
use crate::storage::fusion::{rrf_fuse, rrf_max_possible};
use crate::storage::{entities, memories};
use serde::Serialize;
use std::collections::{HashMap, HashSet};
use std::sync::Arc;
use tokio::sync::Semaphore;
use tokio::task::JoinSet;
// Command-line arguments for the `deep-research` subcommand.
//
// Plain `//` comments are used on purpose: clap's derive macro turns `///` doc
// comments into help text, which would alter the `--help` output defined by the
// explicit `about` / `help` attributes below.
#[derive(clap::Args)]
#[command(
about = "Deep parallel multi-hop GraphRAG research via query decomposition",
after_long_help = "EXAMPLES:\n \
# Basic deep research\n \
sqlite-graphrag deep-research \"auth architecture decisions\"\n\n \
# With custom parameters\n \
sqlite-graphrag deep-research \"auth\" --k 20 --max-hops 3 --max-sub-queries 7\n\n \
# Include full memory bodies in output\n \
sqlite-graphrag deep-research \"auth\" --with-bodies\n\n \
# Tune RRF and graph scoring\n \
sqlite-graphrag deep-research \"auth and deployment\" --rrf-k 60 --graph-decay 0.7"
)]
pub struct DeepResearchArgs {
// Free-text research query; rejected by `run_async` when blank after trimming.
#[arg(value_name = "QUERY", help = "Research query to decompose and search")]
pub query: String,
// Limit passed to both the KNN and the FTS search of every sub-query; the
// fused list of a sub-query is truncated to `2 * k`.
#[arg(
long,
short,
default_value_t = 20,
help = "Results per sub-query (Recall@20 captures 95%+ relevant hits)"
)]
pub k: usize,
// Upper bound handed to `decompose_query`; extra parts are dropped.
#[arg(
long,
default_value_t = 7,
help = "Maximum sub-queries (covers complex multi-hop queries)"
)]
pub max_sub_queries: usize,
// Hop limit for graph traversal; `0` skips graph expansion and evidence
// chains entirely (see the `max_hops > 0` guard in `execute_sub_query`).
#[arg(
long,
default_value_t = 3,
help = "Multi-hop graph traversal depth (sweet spot: 2-3 hops)"
)]
pub max_hops: usize,
// Forwarded unchanged to both graph traversal helpers.
#[arg(
long,
default_value_t = 0.3,
help = "Minimum edge weight for graph traversal"
)]
pub min_weight: f64,
// Semaphore size; when absent it is `min(available CPUs, 8)`, and in both
// cases it is clamped to the range `1..=number of sub-queries`.
#[arg(long, help = "Maximum concurrent sub-queries (default: min(cpus, 8))")]
pub max_concurrency: Option<usize>,
// Converted with `Duration::from_secs` and applied to each sub-query.
#[arg(long, default_value_t = 30, help = "Timeout per sub-query in seconds")]
pub timeout: u64,
// When false the `body` field is omitted from every result.
#[arg(
long,
default_value_t = false,
help = "Include full memory bodies in results"
)]
pub with_bodies: bool,
// Cap applied to the merged, score-ranked result list.
#[arg(
long,
default_value_t = 50,
help = "Maximum results after deduplication"
)]
pub max_results: usize,
// Passed to `rrf_fuse` and `rrf_max_possible`.
#[arg(
long,
default_value_t = 60.0,
help = "RRF k parameter (higher = less weight on top ranks)"
)]
pub rrf_k: f64,
// Graph-expanded hits score `avg_seed_score * graph_decay^hop`.
#[arg(
long,
default_value_t = 0.7,
help = "Graph score decay factor per hop (0.0-1.0)"
)]
pub graph_decay: f64,
// Graph-expanded hits scoring below this value are discarded.
#[arg(
long,
default_value_t = 0.2,
help = "Minimum score threshold for graph-expanded results"
)]
pub graph_min_score: f64,
// Forwarded unchanged to both graph traversal helpers; `None` means no cap.
#[arg(
long,
help = "Limit neighbours per entity per hop for graph traversal (default: unlimited)"
)]
pub max_neighbors_per_hop: Option<usize>,
// Resolved through `crate::namespace::resolve_namespace`.
#[arg(
long,
help = "Namespace (env: SQLITE_GRAPHRAG_NAMESPACE, default: global)"
)]
pub namespace: Option<String>,
// Hidden from help output; not read anywhere in the code visible in this file.
#[arg(long, hide = true)]
pub json: bool,
// Database path override, also settable through `SQLITE_GRAPHRAG_DB_PATH`;
// resolved through `AppPaths::resolve`.
#[arg(long, env = "SQLITE_GRAPHRAG_DB_PATH")]
pub db: Option<String>,
// Shared daemon options; `autostart_daemon` is consulted when embedding
// each sub-query.
#[command(flatten)]
pub daemon: crate::cli::DaemonOpts,
}
/// One sub-query produced by decomposing the user's research query, as echoed
/// back in the JSON response.
#[derive(Serialize)]
struct SubQuery {
/// Zero-based position of the sub-query; results and evidence chains refer
/// to it through their `sub_query_ids`.
id: usize,
/// Text that was embedded and searched for this sub-query.
text: String,
/// `"original"` when decomposition yielded a single part, `"decomposed"`
/// otherwise.
source: &'static str,
}
/// A memory returned by the research run, after merging the hits of all
/// sub-queries.
#[derive(Serialize)]
struct DeepResult {
/// Memory name, read back from storage by id when the response is built.
name: String,
/// Highest score the memory obtained across the sub-queries that found it.
score: f64,
/// Origin of the best-scoring hit: `"knn"`, `"fts"` or `"graph"`.
source: String,
/// Ids of every sub-query whose hits contained this memory.
sub_query_ids: Vec<usize>,
/// First 300 characters of the memory body.
snippet: String,
/// Full memory body; populated only with `--with-bodies` and left out of
/// the JSON when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
body: Option<String>,
/// Hop count for graph-expanded hits, `None` for direct KNN/FTS hits.
/// There is no skip attribute, so `None` is emitted as JSON `null`.
hop_distance: Option<usize>,
}
/// One step of an evidence chain: an entity plus the edge that led to it.
#[derive(Serialize, Clone)]
struct EvidenceNode {
/// Entity name, or `entity-<id>` when the name could not be looked up.
entity: String,
/// Relation label of the edge reaching this entity; `None` (and omitted
/// from the JSON) for the first node of a path.
#[serde(skip_serializing_if = "Option::is_none")]
relation: Option<String>,
/// Weight of the edge reaching this entity; `None` (and omitted from the
/// JSON) for the first node of a path.
#[serde(skip_serializing_if = "Option::is_none")]
weight: Option<f64>,
}
/// A path through the entity graph from a seed entity to a reached entity.
#[derive(Serialize)]
struct EvidenceChain {
/// Name of the first entity on the path (the seed).
from: String,
/// Name of the last entity on the path (the target).
to: String,
/// Nodes ordered from seed to target.
path: Vec<EvidenceNode>,
/// Product of the edge weights along the path.
total_weight: f64,
/// Number of nodes in `path` (not the number of edges).
depth: usize,
/// Ids of the sub-queries that produced this chain.
sub_query_ids: Vec<usize>,
}
/// Run-level counters reported alongside the results.
#[derive(Serialize)]
struct ResearchStats {
/// Number of sub-queries produced by decomposition.
sub_queries_total: usize,
/// Sub-queries that returned a result set.
sub_queries_completed: usize,
/// Sub-queries that returned an error, panicked or were cancelled.
sub_queries_failed: usize,
/// Sub-queries reported as timed out.
sub_queries_timed_out: usize,
/// Number of entries in the final `results` list.
unique_memories_found: usize,
/// Number of entries in the final `evidence_chains` list.
evidence_chains_found: usize,
/// Wall-clock duration of the run in milliseconds.
elapsed_ms: u64,
}
/// Top-level JSON document emitted by the `deep-research` command.
#[derive(Serialize)]
struct DeepResearchResponse {
/// The query exactly as given on the command line.
query: String,
/// Sub-queries derived from `query`.
sub_queries: Vec<SubQuery>,
/// Merged memories, ranked by descending score.
results: Vec<DeepResult>,
/// De-duplicated evidence chains, ranked by descending total weight.
evidence_chains: Vec<EvidenceChain>,
/// Counters and timing for the run.
stats: ResearchStats,
}
/// Best hit kept per memory while merging sub-query results, in field order:
/// `(score, source, snippet, body, hop_distance, sub_query_ids)`.
type MergedHit = (f64, String, String, String, Option<usize>, Vec<usize>);
/// Everything a single sub-query found.
struct SubQueryResult {
/// Id of the sub-query that produced these results.
sub_query_id: usize,
/// One tuple per memory, in field order:
/// `(memory_id, score, source, snippet, body, hop_distance)`.
hits: Vec<(i64, f64, String, String, String, Option<usize>)>,
/// Evidence chains reconstructed from the entity graph for this sub-query.
chains: Vec<EvidenceChain>,
}
/// Synchronous entry point for the `deep-research` command.
///
/// Builds a multi-threaded tokio runtime with two worker threads and drives
/// [`run_async`] to completion on it.
///
/// # Errors
///
/// Returns `AppError::Internal` when the runtime cannot be built; otherwise
/// returns whatever `run_async` returns.
pub fn run(args: DeepResearchArgs) -> Result<(), AppError> {
    let mut builder = tokio::runtime::Builder::new_multi_thread();
    builder.worker_threads(2).enable_all();
    match builder.build() {
        Ok(runtime) => runtime.block_on(run_async(args)),
        Err(e) => Err(AppError::Internal(anyhow::anyhow!(
            "failed to build tokio runtime: {e}"
        ))),
    }
}
/// Runs the whole deep-research pipeline and prints the JSON response.
///
/// Steps, in order:
/// 1. validate the query, resolve namespace and paths, make sure the database
///    is ready;
/// 2. split the query into sub-queries and embed each one;
/// 3. run every sub-query concurrently (bounded by a semaphore), each with
///    its own timeout;
/// 4. merge hits per memory, keeping the best-scoring hit and recording every
///    sub-query that found the memory;
/// 5. de-duplicate evidence chains by `from->to`, drop chains with fewer than
///    two nodes, and rank them by total weight;
/// 6. emit a [`DeepResearchResponse`] through `output::emit_json`.
///
/// `execute_sub_query` is synchronous, so it is run with
/// `tokio::task::spawn_blocking` and the timeout is applied to the join
/// handle. This lets the timeout actually fire and keeps the runtime's worker
/// threads free, so `--max-concurrency` is honoured. A blocking closure
/// cannot be cancelled: a timed-out sub-query is reported as such and its
/// result discarded, but the closure itself keeps running until it returns.
///
/// # Errors
///
/// Fails on an empty query, on namespace/path resolution or database
/// readiness errors, when an embedding cannot be computed, when the database
/// cannot be reopened for name lookup, or when the response cannot be
/// emitted. Failures of individual sub-queries are counted in the stats and
/// logged, not propagated.
async fn run_async(args: DeepResearchArgs) -> Result<(), AppError> {
    /// Why a sub-query produced no result. A dedicated type keeps a timeout
    /// distinct from an error whose message happens to read "timeout".
    enum SubQueryFailure {
        TimedOut,
        Failed(String),
    }

    let start = std::time::Instant::now();
    if args.query.trim().is_empty() {
        return Err(AppError::Validation(crate::i18n::validation::empty_query()));
    }
    let namespace = crate::namespace::resolve_namespace(args.namespace.as_deref())?;
    let paths = AppPaths::resolve(args.db.as_deref())?;
    crate::storage::connection::ensure_db_ready(&paths)?;

    let sub_query_texts = decompose_query(&args.query, args.max_sub_queries);
    let sub_query_source = if sub_query_texts.len() == 1 {
        "original"
    } else {
        "decomposed"
    };
    let sub_queries: Vec<SubQuery> = sub_query_texts
        .iter()
        .enumerate()
        .map(|(i, text)| SubQuery {
            id: i,
            text: text.clone(),
            source: sub_query_source,
        })
        .collect();

    output::emit_progress_i18n(
        "Computing per-sub-query embeddings...",
        "Calculando embeddings por sub-consulta...",
    );
    let mut sub_embeddings: Vec<Arc<Vec<f32>>> = Vec::with_capacity(sub_query_texts.len());
    for sq_text in &sub_query_texts {
        let emb = crate::daemon::embed_query_or_local(
            &paths.models,
            sq_text,
            args.daemon.autostart_daemon,
        )?;
        sub_embeddings.push(Arc::new(emb));
    }

    // Concurrency: explicit flag or min(cpus, 8), never more than the number
    // of sub-queries and never less than one.
    let cpu_count = std::thread::available_parallelism()
        .map(|n| n.get())
        .unwrap_or(4);
    let permits = args
        .max_concurrency
        .unwrap_or_else(|| cpu_count.min(8))
        .min(sub_queries.len())
        .max(1);
    let semaphore = Arc::new(Semaphore::new(permits));
    let timeout_dur = std::time::Duration::from_secs(args.timeout);

    let mut join_set: JoinSet<Result<SubQueryResult, (usize, SubQueryFailure)>> = JoinSet::new();
    for (idx, (sq_text, sq_emb)) in sub_query_texts.iter().zip(&sub_embeddings).enumerate() {
        let sem = Arc::clone(&semaphore);
        let emb = Arc::clone(sq_emb);
        let ns = namespace.clone();
        let db_path = paths.db.clone();
        let query_text = sq_text.clone();
        let k = args.k;
        let max_hops = args.max_hops;
        let min_weight = args.min_weight;
        let rrf_k = args.rrf_k;
        let graph_decay = args.graph_decay;
        let graph_min_score = args.graph_min_score;
        let max_neighbors_per_hop = args.max_neighbors_per_hop;
        join_set.spawn(async move {
            let _permit = sem.acquire_owned().await.map_err(|e| {
                (
                    idx,
                    SubQueryFailure::Failed(format!("semaphore closed: {e}")),
                )
            })?;
            // The search is blocking SQLite work: run it on the blocking pool
            // so the timer below can fire while it is still in progress.
            let worker = tokio::task::spawn_blocking(move || {
                execute_sub_query(
                    idx,
                    &query_text,
                    emb.as_slice(),
                    &ns,
                    &db_path,
                    k,
                    max_hops,
                    min_weight,
                    rrf_k,
                    graph_decay,
                    graph_min_score,
                    max_neighbors_per_hop,
                )
            });
            match tokio::time::timeout(timeout_dur, worker).await {
                Ok(Ok(Ok(found))) => Ok(found),
                Ok(Ok(Err(reason))) => Err((idx, SubQueryFailure::Failed(reason))),
                Ok(Err(join_err)) => {
                    if join_err.is_panic() {
                        // Re-raise so the collector below sees a panicked
                        // task and logs it at error level.
                        std::panic::resume_unwind(join_err.into_panic());
                    }
                    Err((
                        idx,
                        SubQueryFailure::Failed(format!("blocking task cancelled: {join_err}")),
                    ))
                }
                Err(_) => Err((idx, SubQueryFailure::TimedOut)),
            }
        });
    }

    let mut sub_query_results: Vec<SubQueryResult> = Vec::with_capacity(sub_queries.len());
    let mut failed_count = 0usize;
    let mut timed_out_count = 0usize;
    while let Some(joined) = join_set.join_next().await {
        match joined {
            Ok(Ok(sqr)) => sub_query_results.push(sqr),
            Ok(Err((idx, SubQueryFailure::TimedOut))) => {
                timed_out_count += 1;
                tracing::warn!(sub_query_id = idx, reason = "timeout", "sub-query failed");
            }
            Ok(Err((idx, SubQueryFailure::Failed(reason)))) => {
                failed_count += 1;
                tracing::warn!(sub_query_id = idx, reason = %reason, "sub-query failed");
            }
            Err(join_err) => {
                failed_count += 1;
                if join_err.is_panic() {
                    tracing::error!("sub-query task panicked: {join_err}");
                } else {
                    tracing::warn!("sub-query task cancelled: {join_err}");
                }
            }
        }
    }
    // Tasks finish in arbitrary order; merge in sub-query order so that score
    // ties and `sub_query_ids` come out the same on every run.
    sub_query_results.sort_by_key(|sqr| sqr.sub_query_id);

    // Keep the best-scoring hit per memory and remember who found it.
    let mut merged: HashMap<i64, MergedHit> = HashMap::new();
    for sqr in &sub_query_results {
        for (mem_id, score, source, snippet, body, hop) in &sqr.hits {
            let entry = merged.entry(*mem_id).or_insert_with(|| {
                (
                    *score,
                    source.clone(),
                    snippet.clone(),
                    body.clone(),
                    *hop,
                    Vec::new(),
                )
            });
            if *score > entry.0 {
                entry.0 = *score;
                entry.1 = source.clone();
                entry.2 = snippet.clone();
                entry.3 = body.clone();
                entry.4 = *hop;
            }
            if !entry.5.contains(&sqr.sub_query_id) {
                entry.5.push(sqr.sub_query_id);
            }
        }
    }

    let conn = open_ro(&paths.db)?;
    let mut ranked: Vec<(i64, MergedHit)> = merged.into_iter().collect();
    // Descending score; memory id breaks ties because HashMap order is random.
    ranked.sort_by(|(id_a, hit_a), (id_b, hit_b)| {
        hit_b
            .0
            .partial_cmp(&hit_a.0)
            .unwrap_or(std::cmp::Ordering::Equal)
            .then_with(|| id_a.cmp(id_b))
    });
    let mut results: Vec<DeepResult> = Vec::with_capacity(ranked.len().min(args.max_results));
    for (mem_id, (score, source, snippet, body, hop, sq_ids)) in ranked {
        // The cap is applied here rather than by truncating `ranked`, so a
        // memory whose row has vanished does not shrink the result list.
        if results.len() >= args.max_results {
            break;
        }
        let name = match memories::read_full(&conn, mem_id)? {
            Some(row) => row.name,
            None => continue,
        };
        results.push(DeepResult {
            name,
            score,
            source,
            sub_query_ids: sq_ids,
            snippet,
            body: if args.with_bodies { Some(body) } else { None },
            hop_distance: hop,
        });
    }

    let completed_count = sub_query_results.len();
    let mut evidence_chains: Vec<EvidenceChain> = Vec::new();
    // Maps a `from->to` key to the chain's position in `evidence_chains`.
    let mut chain_positions: HashMap<String, usize> = HashMap::new();
    for sqr in sub_query_results {
        for chain in sqr.chains {
            let key = format!("{}->{}", chain.from, chain.to);
            match chain_positions.get(&key) {
                Some(&pos) => {
                    // Same endpoints seen before: keep the first path but
                    // record that this sub-query reached it too.
                    let known = &mut evidence_chains[pos].sub_query_ids;
                    for id in chain.sub_query_ids {
                        if !known.contains(&id) {
                            known.push(id);
                        }
                    }
                }
                None => {
                    chain_positions.insert(key, evidence_chains.len());
                    evidence_chains.push(chain);
                }
            }
        }
    }
    evidence_chains.retain(|c| c.depth >= 2);
    evidence_chains.sort_by(|a, b| {
        b.total_weight
            .partial_cmp(&a.total_weight)
            .unwrap_or(std::cmp::Ordering::Equal)
            .then_with(|| a.from.cmp(&b.from))
            .then_with(|| a.to.cmp(&b.to))
    });

    let unique_memories = results.len();
    let evidence_count = evidence_chains.len();
    output::emit_json(&DeepResearchResponse {
        query: args.query,
        sub_queries,
        results,
        evidence_chains,
        stats: ResearchStats {
            sub_queries_total: sub_query_texts.len(),
            sub_queries_completed: completed_count,
            sub_queries_failed: failed_count,
            sub_queries_timed_out: timed_out_count,
            unique_memories_found: unique_memories,
            evidence_chains_found: evidence_count,
            elapsed_ms: start.elapsed().as_millis() as u64,
        },
    })?;
    Ok(())
}
/// Splits a research query into at most `max` sub-queries.
///
/// Strategies are tried in order and the first one that yields parts wins:
/// 1. relational phrases (`" that caused "`, `" related to "`, ...), matched
///    ASCII-case-insensitively and applied in the order they are listed: the
///    text before a phrase becomes a part and the search continues in the
///    text after it;
/// 2. semicolons, when the query contains one;
/// 3. commas, after rewriting `" and "`, `" AND "`, `" e "` and `" E "` to
///    commas.
///
/// Parts are trimmed and empty parts are dropped. When nothing can be split
/// (including the empty query) the original query is returned unchanged as
/// the only element. Parts beyond `max` are discarded.
fn decompose_query(query: &str, max: usize) -> Vec<String> {
    const RELATIONAL_PHRASES: [&str; 7] = [
        " that caused ",
        " depending on ",
        " related to ",
        " connected to ",
        " linked to ",
        " caused by ",
        " followed by ",
    ];

    if query.is_empty() {
        return vec![query.to_string()];
    }

    let mut parts: Vec<String> = Vec::new();
    let mut rest: &str = query;
    let mut did_relational_split = false;
    for phrase in &RELATIONAL_PHRASES {
        // ASCII folding keeps every byte offset identical to `rest`, so the
        // position found below is always a valid slice index. Full Unicode
        // lowercasing can change byte lengths (e.g. 'İ') and would not be.
        let folded = rest.to_ascii_lowercase();
        if let Some(pos) = folded.find(*phrase) {
            let left = rest[..pos].trim();
            if !left.is_empty() {
                parts.push(left.to_string());
            }
            // Whatever follows the phrase is what remains to be split, even
            // when that is nothing at all.
            rest = rest[pos + phrase.len()..].trim();
            did_relational_split = true;
        }
    }
    if did_relational_split && !rest.is_empty() {
        parts.push(rest.to_string());
    }

    if parts.is_empty() {
        let normalized;
        let (haystack, separator) = if query.contains(';') {
            (query, ';')
        } else {
            normalized = query
                .replace(" and ", ", ")
                .replace(" AND ", ", ")
                .replace(" e ", ", ")
                .replace(" E ", ", ");
            (normalized.as_str(), ',')
        };
        // Only split when the separator is really present; otherwise the
        // untouched query is returned below.
        if haystack.contains(separator) {
            parts.extend(
                haystack
                    .split(separator)
                    .map(str::trim)
                    .filter(|piece| !piece.is_empty())
                    .map(String::from),
            );
        }
    }

    if parts.is_empty() {
        return vec![query.to_string()];
    }
    parts.truncate(max);
    parts
}
/// Rebuilds the path from a seed entity to `target_id` by walking the
/// predecessor map backwards.
///
/// * `target_id` - entity the path must end at.
/// * `seed_entity_ids` - entities at which the backward walk stops.
/// * `predecessor` - for each reached entity, `(parent, relation, weight)` of
///   the edge it was reached through.
/// * `entity_names` - display names; entities missing from the map are
///   rendered as `entity-<id>`.
///
/// Returns the nodes ordered seed-first together with the product of the edge
/// weights. The seed node carries no relation or weight. If `target_id` is
/// itself a seed the result is a single node with weight `1.0`.
///
/// Returns `None` when the walk reaches an entity that is neither a seed nor
/// present in `predecessor`, or when the predecessor links form a cycle.
fn reconstruct_path(
    target_id: i64,
    seed_entity_ids: &HashSet<i64>,
    predecessor: &PredecessorMap,
    entity_names: &HashMap<i64, String>,
) -> Option<(Vec<EvidenceNode>, f64)> {
    let mut hops: Vec<(i64, Option<String>, Option<f64>)> = Vec::new();
    let mut total_weight = 1.0_f64;
    let mut current = target_id;
    while !seed_entity_ids.contains(&current) {
        // A cycle-free walk consumes each predecessor entry at most once, so
        // needing more steps than there are entries means the links loop.
        // Without this bound such a map would never terminate.
        if hops.len() >= predecessor.len() {
            return None;
        }
        let (parent, relation, weight) = predecessor.get(&current)?;
        total_weight *= weight;
        hops.push((current, Some(relation.clone()), Some(*weight)));
        current = *parent;
    }
    // `current` is now the seed the walk ended on.
    hops.push((current, None, None));

    let nodes: Vec<EvidenceNode> = hops
        .into_iter()
        .rev()
        .map(|(id, relation, weight)| EvidenceNode {
            entity: entity_names
                .get(&id)
                .cloned()
                .unwrap_or_else(|| format!("entity-{id}")),
            relation,
            weight,
        })
        .collect();
    Some((nodes, total_weight))
}
/// Runs one sub-query against the database and returns its hits and evidence
/// chains.
///
/// The work is synchronous and happens in three stages:
/// 1. **Hybrid retrieval** - KNN over the embedding and FTS over the text are
///    fused with reciprocal-rank fusion; the best `2 * k` memories become
///    direct hits scored by `rrf / rrf_max_possible`, clamped to `[0, 1]`.
///    A failing FTS search is logged and the KNN results are used alone.
/// 2. **Graph expansion** (only when there are hits and `max_hops > 0`) -
///    memories reached through the graph that were not direct hits are added
///    with score `avg_seed_score * graph_decay^hop`, unless that falls below
///    `graph_min_score`.
/// 3. **Evidence chains** - entities reachable from the seed entities are
///    turned into seed-to-target paths; the 20 heaviest paths are kept.
///
/// Stages 2 and 3 are best effort: a failure in entity search, graph
/// traversal or the BFS is logged and the stage contributes nothing.
///
/// # Errors
///
/// Returns a message when the database cannot be opened, when the KNN search
/// fails, or when the memory-to-entity lookup cannot be prepared or executed.
// The parameters mirror individual CLI flags forwarded by the caller.
#[allow(clippy::too_many_arguments)]
fn execute_sub_query(
    sub_query_id: usize,
    query_text: &str,
    embedding: &[f32],
    namespace: &str,
    db_path: &std::path::Path,
    k: usize,
    max_hops: usize,
    min_weight: f64,
    rrf_k: f64,
    graph_decay: f64,
    graph_min_score: f64,
    max_neighbors_per_hop: Option<usize>,
) -> Result<SubQueryResult, String> {
    let conn = open_ro(db_path).map_err(|e| format!("failed to open db: {e}"))?;

    // ---- Stage 1: hybrid KNN + FTS retrieval --------------------------------
    let knn_results = memories::knn_search(&conn, embedding, &[namespace.to_string()], None, k)
        .map_err(|e| format!("knn_search failed: {e}"))?;
    let knn_ids: Vec<i64> = knn_results.iter().map(|(id, _)| *id).collect();
    let knn_id_set: HashSet<i64> = knn_ids.iter().copied().collect();

    let fts_results = match memories::fts_search(&conn, query_text, namespace, None, k) {
        Ok(rows) => rows,
        Err(e) => {
            tracing::warn!(
                sub_query_id,
                "FTS5 search failed, continuing with KNN only: {e}"
            );
            vec![]
        }
    };
    let fts_ids: Vec<i64> = fts_results.iter().map(|r| r.id).collect();

    let rrf_scores = rrf_fuse(&[(1.0, &knn_ids), (1.0, &fts_ids)], rrf_k);
    let max_possible = rrf_max_possible(&[1.0, 1.0], rrf_k);
    // Single definition of how a raw RRF score becomes a [0, 1] score.
    let normalize = |raw: f64| -> f64 {
        if max_possible > 0.0 {
            (raw / max_possible).clamp(0.0, 1.0)
        } else {
            0.0
        }
    };

    let mut fused: Vec<(i64, f64)> = rrf_scores.into_iter().collect();
    // Descending score; the id breaks ties so that the cut-off below does not
    // depend on hash-map iteration order.
    fused.sort_by(|a, b| {
        b.1.partial_cmp(&a.1)
            .unwrap_or(std::cmp::Ordering::Equal)
            .then_with(|| a.0.cmp(&b.0))
    });
    fused.truncate(k.saturating_mul(2));

    // Sized from what was actually found rather than from `k`, which is
    // user-supplied and may be arbitrarily large.
    let mut hits: Vec<(i64, f64, String, String, String, Option<usize>)> =
        Vec::with_capacity(fused.len());
    let mut seen_ids: HashSet<i64> = HashSet::with_capacity(fused.len());
    for &(memory_id, combined_score) in &fused {
        if !seen_ids.insert(memory_id) {
            continue;
        }
        // A memory found by both searches is labelled "knn".
        let source = if knn_id_set.contains(&memory_id) {
            "knn"
        } else {
            "fts"
        };
        if let Ok(Some(row)) = memories::read_full(&conn, memory_id) {
            let snippet: String = row.body.chars().take(300).collect();
            hits.push((
                memory_id,
                normalize(combined_score),
                source.to_string(),
                snippet,
                row.body,
                None,
            ));
        }
    }

    let memory_ids: Vec<i64> = hits.iter().map(|(id, ..)| *id).collect();
    let mut chains: Vec<EvidenceChain> = Vec::new();
    if !memory_ids.is_empty() && max_hops > 0 {
        // Clamp instead of `as u32`, which would silently wrap huge values.
        let hop_limit = max_hops.min(u32::MAX as usize) as u32;

        // ---- Seed entities: nearest entities + entities linked to the hits --
        let entity_ids: Vec<i64> = match entities::knn_search(&conn, embedding, namespace, 5) {
            Ok(found) => found.iter().map(|(id, _)| *id).collect(),
            Err(e) => {
                tracing::warn!(
                    sub_query_id,
                    "entity KNN search failed, continuing without entity seeds: {e}"
                );
                Vec::new()
            }
        };
        let mut seed_entity_ids: Vec<i64> = entity_ids.clone();
        {
            let mut stmt = conn
                .prepare_cached("SELECT entity_id FROM memory_entities WHERE memory_id = ?1")
                .map_err(|e| format!("prepare failed: {e}"))?;
            for &mem_id in &memory_ids {
                let linked = stmt
                    .query_map(rusqlite::params![mem_id], |r| r.get::<_, i64>(0))
                    .map_err(|e| format!("query failed: {e}"))?;
                // Rows that fail to decode are skipped (best effort).
                seed_entity_ids.extend(linked.filter_map(|r| r.ok()));
            }
        }
        seed_entity_ids.sort_unstable();
        seed_entity_ids.dedup();

        // ---- Stage 2: graph expansion ---------------------------------------
        // NOTE(review): this list mixes memory ids with entity ids before it
        // is handed to a function named "traverse_from_memories"; confirm
        // that the callee expects both kinds of id.
        let all_seed_ids: Vec<i64> = memory_ids
            .iter()
            .chain(entity_ids.iter())
            .copied()
            .collect();
        match traverse_from_memories_with_hops_capped(
            &conn,
            &all_seed_ids,
            namespace,
            min_weight,
            hop_limit,
            max_neighbors_per_hop,
        ) {
            Ok(graph_results) => {
                // Mean normalised score of the fused hits. It is the same for
                // every graph hit, so it is computed once, outside the loop.
                let avg_seed_score: f64 = if fused.is_empty() {
                    0.5
                } else {
                    let sum: f64 = fused.iter().map(|&(_, raw)| normalize(raw)).sum();
                    sum / fused.len() as f64
                };
                for (graph_mem_id, hop) in graph_results {
                    if !seen_ids.insert(graph_mem_id) {
                        continue;
                    }
                    let graph_score =
                        (avg_seed_score * graph_decay.powi(hop as i32)).clamp(0.0, 1.0);
                    if graph_score < graph_min_score {
                        continue;
                    }
                    if let Ok(Some(row)) = memories::read_full(&conn, graph_mem_id) {
                        let snippet: String = row.body.chars().take(300).collect();
                        hits.push((
                            graph_mem_id,
                            graph_score,
                            "graph".to_string(),
                            snippet,
                            row.body,
                            Some(hop as usize),
                        ));
                    }
                }
            }
            Err(e) => {
                tracing::warn!(
                    sub_query_id,
                    "graph traversal failed, continuing without graph expansion: {e}"
                );
            }
        }

        // ---- Stage 3: evidence chains ---------------------------------------
        if !seed_entity_ids.is_empty() {
            let (entity_depth, predecessor) = match bfs_with_predecessors(
                &conn,
                &seed_entity_ids,
                namespace,
                min_weight,
                hop_limit,
                max_neighbors_per_hop,
            ) {
                Ok(found) => found,
                Err(e) => {
                    tracing::warn!(
                        sub_query_id,
                        "entity BFS failed, continuing without evidence chains: {e}"
                    );
                    Default::default()
                }
            };
            let seed_entity_set: HashSet<i64> = seed_entity_ids.iter().copied().collect();

            // Resolve names with one cached statement; entities whose name
            // cannot be read fall back to `entity-<id>` in the path.
            let mut entity_names: HashMap<i64, String> = HashMap::new();
            if let Ok(mut name_stmt) =
                conn.prepare_cached("SELECT name FROM entities WHERE id = ?1")
            {
                for &eid in entity_depth.keys() {
                    let looked_up =
                        name_stmt.query_row(rusqlite::params![eid], |r| r.get::<_, String>(0));
                    if let Ok(name) = looked_up {
                        entity_names.insert(eid, name);
                    }
                }
            }

            for &target_id in entity_depth.keys() {
                if seed_entity_set.contains(&target_id) || !predecessor.contains_key(&target_id) {
                    continue;
                }
                if let Some((path_nodes, total_weight)) =
                    reconstruct_path(target_id, &seed_entity_set, &predecessor, &entity_names)
                {
                    if path_nodes.len() < 2 {
                        continue;
                    }
                    let from = path_nodes
                        .first()
                        .map(|n| n.entity.clone())
                        .unwrap_or_default();
                    let to = path_nodes
                        .last()
                        .map(|n| n.entity.clone())
                        .unwrap_or_default();
                    let depth = path_nodes.len();
                    chains.push(EvidenceChain {
                        from,
                        to,
                        path: path_nodes,
                        total_weight,
                        depth,
                        sub_query_ids: vec![sub_query_id],
                    });
                }
            }
            // Heaviest first; depth and endpoint names break ties so the
            // 20-chain cut-off is reproducible.
            chains.sort_by(|a, b| {
                b.total_weight
                    .partial_cmp(&a.total_weight)
                    .unwrap_or(std::cmp::Ordering::Equal)
                    .then_with(|| a.depth.cmp(&b.depth))
                    .then_with(|| a.from.cmp(&b.from))
                    .then_with(|| a.to.cmp(&b.to))
            });
            chains.truncate(20);
        }
    }

    Ok(SubQueryResult {
        sub_query_id,
        hits,
        chains,
    })
}
// Unit tests for query decomposition, JSON serialization of the response
// types, RRF fusion, path reconstruction and the BFS neighbour cap.
#[cfg(test)]
mod tests {
use super::*;
// " and " is treated as a separator.
#[test]
fn test_decompose_and_conjunction() {
let result = decompose_query("A and B", 7);
assert_eq!(result, vec!["A", "B"]);
}
// A query without any separator comes back unchanged.
#[test]
fn test_decompose_no_split() {
let result = decompose_query("simple query", 7);
assert_eq!(result, vec!["simple query"]);
}
// Commas and " and " can be mixed in one query.
#[test]
fn test_decompose_three_parts() {
let result = decompose_query("A, B and C", 7);
assert_eq!(result, vec!["A", "B", "C"]);
}
// The Portuguese conjunction " e " is treated like " and ".
#[test]
fn test_decompose_portuguese_conjunctions() {
let result = decompose_query("A e B", 7);
assert_eq!(result, vec!["A", "B"]);
}
// Ten comma-separated parts are cut down to the requested maximum.
#[test]
fn test_decompose_max_cap() {
let parts: Vec<String> = (0..10).map(|i| format!("part{i}")).collect();
let query = parts.join(", ");
let result = decompose_query(&query, 7);
assert!(
result.len() <= 7,
"expected at most 7 sub-queries, got {}",
result.len()
);
}
// The empty query is returned as a single empty sub-query.
#[test]
fn test_decompose_empty_preserves_original() {
let result = decompose_query("", 7);
assert_eq!(result, vec![""]);
}
// Semicolons split the query and each part is trimmed.
#[test]
fn test_decompose_semicolons() {
let result = decompose_query("auth design; deployment config; logging", 7);
assert_eq!(result, vec!["auth design", "deployment config", "logging"]);
}
// A relational phrase splits the query into its left and right sides.
#[test]
fn test_decompose_relational_phrase() {
let result = decompose_query("auth that caused deployment failure", 7);
assert_eq!(result, vec!["auth", "deployment failure"]);
}
// SubQuery serializes all three fields under their own names.
#[test]
fn test_sub_query_serialization() {
let sq = SubQuery {
id: 0,
text: "test query".to_string(),
source: "original",
};
let json = serde_json::to_value(&sq).expect("serialization failed");
assert_eq!(json["id"], 0);
assert_eq!(json["text"], "test query");
assert_eq!(json["source"], "original");
}
// `body: None` must not appear in the JSON at all.
#[test]
fn test_deep_result_omits_body_when_none() {
let result = DeepResult {
name: "test".to_string(),
score: 0.9,
source: "knn".to_string(),
sub_query_ids: vec![0],
snippet: "snippet".to_string(),
body: None,
hop_distance: None,
};
let json = serde_json::to_string(&result).expect("serialization failed");
assert!(!json.contains("\"body\""), "body must be omitted when None");
}
// `body: Some(..)` is serialized with its content.
#[test]
fn test_deep_result_includes_body_when_some() {
let result = DeepResult {
name: "test".to_string(),
score: 0.9,
source: "knn".to_string(),
sub_query_ids: vec![0, 1],
snippet: "snippet".to_string(),
body: Some("full body content".to_string()),
hop_distance: Some(2),
};
let json = serde_json::to_string(&result).expect("serialization failed");
assert!(json.contains("\"body\""), "body must be present when Some");
assert!(json.contains("full body content"));
}
// `relation` and `weight` are skipped when they are None.
#[test]
fn test_evidence_node_omits_none_fields() {
let node = EvidenceNode {
entity: "auth-module".to_string(),
relation: None,
weight: None,
};
let json = serde_json::to_string(&node).expect("serialization failed");
assert!(
!json.contains("\"relation\""),
"relation must be omitted when None"
);
assert!(
!json.contains("\"weight\""),
"weight must be omitted when None"
);
}
// ResearchStats counters keep their field names and values in JSON.
#[test]
fn test_research_stats_serialization() {
let stats = ResearchStats {
sub_queries_total: 3,
sub_queries_completed: 2,
sub_queries_failed: 1,
sub_queries_timed_out: 0,
unique_memories_found: 10,
evidence_chains_found: 2,
elapsed_ms: 1234,
};
let json = serde_json::to_value(&stats).expect("serialization failed");
assert_eq!(json["sub_queries_total"], 3);
assert_eq!(json["sub_queries_completed"], 2);
assert_eq!(json["sub_queries_failed"], 1);
assert_eq!(json["elapsed_ms"], 1234);
}
// The top-level response exposes arrays even when they are empty.
#[test]
fn test_deep_research_response_serialization() {
let resp = DeepResearchResponse {
query: "test query".to_string(),
sub_queries: vec![SubQuery {
id: 0,
text: "test query".to_string(),
source: "original",
}],
results: vec![],
evidence_chains: vec![],
stats: ResearchStats {
sub_queries_total: 1,
sub_queries_completed: 1,
sub_queries_failed: 0,
sub_queries_timed_out: 0,
unique_memories_found: 0,
evidence_chains_found: 0,
elapsed_ms: 42,
},
};
let json = serde_json::to_value(&resp).expect("serialization failed");
assert_eq!(json["query"], "test query");
assert!(json["sub_queries"].is_array());
assert!(json["results"].is_array());
assert!(json["evidence_chains"].is_array());
assert_eq!(json["stats"]["elapsed_ms"], 42);
}
// Semicolon splitting takes precedence: the " and " inside the second part
// does not produce a third sub-query.
#[test]
fn test_distinct_sub_queries_produce_distinct_texts() {
let queries = [
"authentication design decisions",
"deployment configuration and infrastructure",
];
assert_ne!(queries[0], queries[1]);
let decomposed = decompose_query(
"authentication design decisions; deployment configuration and infrastructure",
7,
);
assert_eq!(decomposed.len(), 2);
assert_ne!(decomposed[0], decomposed[1]);
}
// Ids present in both ranked lists must outscore ids present in only one.
#[test]
fn test_rrf_fuse_via_fusion_module() {
use crate::storage::fusion::rrf_fuse;
let knn_ids: Vec<i64> = vec![1, 2, 3];
let fts_ids: Vec<i64> = vec![2, 1, 4];
let scores = rrf_fuse(&[(1.0, &knn_ids), (1.0, &fts_ids)], 60.0);
let score_1 = scores[&1];
let score_2 = scores[&2];
let score_3 = scores[&3]; let score_4 = scores[&4];
assert!(
score_1 > score_3,
"id 1 (both lists) must beat id 3 (knn-only rank 3)"
);
assert!(
score_2 > score_4,
"id 2 (both lists) must beat id 4 (fts-only rank 3)"
);
}
// An evidence chain serializes its endpoints, path, weight and depth.
#[test]
fn test_evidence_chain_has_from_to_and_path() {
let chain = EvidenceChain {
from: "auth-module".to_string(),
to: "jwt-service".to_string(),
path: vec![
EvidenceNode {
entity: "auth-module".to_string(),
relation: None,
weight: None,
},
EvidenceNode {
entity: "token-validator".to_string(),
relation: Some("depends-on".to_string()),
weight: Some(0.9),
},
EvidenceNode {
entity: "jwt-service".to_string(),
relation: Some("uses".to_string()),
weight: Some(0.8),
},
],
total_weight: 0.72,
depth: 3,
sub_query_ids: vec![0],
};
let json = serde_json::to_value(&chain).expect("serialization failed");
assert!(
json["from"].is_string(),
"evidence chain must have 'from' field"
);
assert!(
json["to"].is_string(),
"evidence chain must have 'to' field"
);
assert!(
json["path"].is_array(),
"evidence chain must have 'path' array"
);
assert_eq!(json["path"].as_array().unwrap().len(), 3);
assert!(json["total_weight"].is_number(), "must have total_weight");
assert_eq!(json["depth"], 3);
}
// Walking 30 -> 20 -> 10 backwards must yield the path in seed-first order
// with the product of the two edge weights.
#[test]
fn test_reconstruct_path_root_to_target_order() {
let seed_set: HashSet<i64> = [10i64].into_iter().collect();
let mut predecessor: HashMap<i64, (i64, String, f64)> = HashMap::new();
predecessor.insert(20, (10, "depends-on".to_string(), 0.9));
predecessor.insert(30, (20, "uses".to_string(), 0.8));
let mut entity_names: HashMap<i64, String> = HashMap::new();
entity_names.insert(10, "seed-entity".to_string());
entity_names.insert(20, "middle-entity".to_string());
entity_names.insert(30, "target-entity".to_string());
let result = reconstruct_path(30, &seed_set, &predecessor, &entity_names);
assert!(result.is_some(), "path must be reconstructed");
let (nodes, weight) = result.unwrap();
assert_eq!(nodes.len(), 3);
assert_eq!(nodes[0].entity, "seed-entity");
assert_eq!(nodes[1].entity, "middle-entity");
assert_eq!(nodes[2].entity, "target-entity");
assert!((weight - 0.72).abs() < 1e-6);
}
// Applies the same `depth >= 2` predicate that `run_async` uses and checks
// that a one-node chain does not survive it.
#[test]
fn test_evidence_chains_single_hop_filtered_out() {
let chain = EvidenceChain {
from: "a".to_string(),
to: "a".to_string(),
path: vec![EvidenceNode {
entity: "a".to_string(),
relation: None,
weight: None,
}],
total_weight: 1.0,
depth: 1,
sub_query_ids: vec![0],
};
let chains = vec![chain];
let retained: Vec<_> = chains.into_iter().filter(|c| c.depth >= 2).collect();
assert!(retained.is_empty(), "depth-1 chains must be filtered out");
}
// Builds a star graph (entity 1 linked to entities 2..=6) in an in-memory
// database and checks that the per-hop neighbour cap limits discovery.
#[test]
fn test_bfs_with_predecessors_respects_neighbor_cap() {
use crate::graph::bfs_with_predecessors;
use rusqlite::Connection;
let conn = Connection::open_in_memory().unwrap();
conn.execute_batch(
"CREATE TABLE relationships (
source_id INTEGER NOT NULL,
target_id INTEGER NOT NULL,
weight REAL NOT NULL,
namespace TEXT NOT NULL,
relation TEXT NOT NULL DEFAULT 'related'
);",
)
.unwrap();
for target in 2i64..=6 {
conn.execute(
"INSERT INTO relationships (source_id, target_id, weight, namespace) VALUES (?1, ?2, ?3, 'ns')",
rusqlite::params![1i64, target, 1.0f64],
)
.unwrap();
}
let (depth_uncapped, _) = bfs_with_predecessors(&conn, &[1], "ns", 0.0, 1, None).unwrap();
assert_eq!(
depth_uncapped.len() - 1,
5,
"uncapped must discover all 5 neighbours (plus seed)"
);
let (depth_capped, _) = bfs_with_predecessors(&conn, &[1], "ns", 0.0, 1, Some(2)).unwrap();
assert_eq!(
depth_capped.len(),
3,
"capped to 2 must yield seed + 2 neighbours"
);
}
}