use crate::context::graph::find_related_chunks;
use crate::db::traits::StoreChunks;
use crate::db::SqliteStore;
use crate::search::results::RelatedChunkResult;
use anyhow::Result;
/// Neutral multiplier used by `compute_edge_weight` when no special rule matches.
const EDGE_WEIGHT_DEFAULT: f32 = 1.0;
/// Multiplier applied when the target chunk's kind contains the substring `"test"`.
const EDGE_WEIGHT_TEST_PENALTY: f32 = 0.5;
/// Multiplier applied when the edge type is `"extends"` or `"implements"`.
const EDGE_WEIGHT_INHERITANCE_BOOST: f32 = 1.1;
/// Multiplier applied when a related chunk shares a non-empty parent directory with the source chunk.
const MODULE_PROXIMITY_BOOST: f32 = 1.2;
/// Maximum preview length, in bytes, passed to `truncate_preview` before an ellipsis is appended.
const PREVIEW_MAX_LENGTH: usize = 100;
/// Returns up to `limit` chunks related to `source_chunk_id`, most relevant first.
///
/// Candidates are obtained from `find_related_chunks`, called with `2` and `None`
/// (presumably the maximum traversal depth and an optional filter — confirm against
/// its signature). Each candidate's relevance is then multiplied by:
///
/// - the weight returned by `compute_edge_weight` for the candidate's kind, and
/// - `MODULE_PROXIMITY_BOOST` when the candidate's parent directory equals the
///   (non-empty) parent directory of the source chunk.
///
/// Candidates with equal scores keep the order in which `find_related_chunks`
/// produced them, because the sort is stable.
///
/// # Errors
///
/// Propagates any error returned while loading the source chunk or while
/// collecting the related chunks.
pub async fn find_top_related_chunks(
    store: &SqliteStore,
    source_chunk_id: i64,
    limit: usize,
) -> Result<Vec<RelatedChunkResult>> {
    // A missing source chunk is not an error here: the directory stays empty,
    // which simply disables the proximity boost below.
    let source_chunk = store.get_chunk_by_id(source_chunk_id).await?;
    let source_dir = source_chunk
        .as_ref()
        .map(|chunk| extract_parent_dir(&chunk.file_path))
        .unwrap_or_default();

    let related = find_related_chunks(store, source_chunk_id, 2, None).await?;

    let mut results: Vec<RelatedChunkResult> = related
        .into_iter()
        .map(|chunk| {
            let base_relevance = chunk.relevance as f32;
            // NOTE(review): the edge type is always passed as "", so the
            // inheritance boost can never trigger from this call site; only the
            // test penalty or the default weight is reachable. Confirm whether
            // the edge type should be threaded through from the graph query.
            let edge_weight = compute_edge_weight("", &chunk.kind);
            let same_module =
                !source_dir.is_empty() && extract_parent_dir(&chunk.relpath) == source_dir;
            let module_boost = if same_module {
                MODULE_PROXIMITY_BOOST
            } else {
                1.0
            };
            let relevance = base_relevance * edge_weight * module_boost;

            RelatedChunkResult {
                chunk_id: chunk.id,
                relpath: chunk.relpath,
                symbol_name: chunk.symbol_name,
                kind: chunk.kind,
                start_line: chunk.start_line,
                end_line: chunk.end_line,
                preview: truncate_preview(&chunk.preview, PREVIEW_MAX_LENGTH),
                depth: chunk.depth,
                relevance,
                relationship_type: infer_relationship_type(chunk.depth),
            }
        })
        .collect();

    // `total_cmp` is a genuine total order even when a score is NaN. The previous
    // `partial_cmp(..).unwrap_or(Equal)` comparator was not, which leaves the
    // resulting order unspecified and may make the sort panic.
    results.sort_by(|a, b| b.relevance.total_cmp(&a.relevance));
    results.truncate(limit);
    Ok(results)
}
/// Picks the relevance multiplier for an edge of `edge_type` pointing at a chunk of `target_kind`.
///
/// Rules are checked in priority order:
/// 1. `"extends"` / `"implements"` edges get the inheritance boost, whatever the target kind is;
/// 2. otherwise, a target kind containing `"test"` gets the test penalty;
/// 3. everything else gets the default weight.
fn compute_edge_weight(edge_type: &str, target_kind: &str) -> f32 {
    // Inheritance edges win over the test penalty.
    if matches!(edge_type, "extends" | "implements") {
        return EDGE_WEIGHT_INHERITANCE_BOOST;
    }
    if target_kind.contains("test") {
        return EDGE_WEIGHT_TEST_PENALTY;
    }
    EDGE_WEIGHT_DEFAULT
}
/// Returns the parent directory of `path` as an owned string.
///
/// Yields an empty string when the path has no parent component
/// (for example `""` or `"/"`) or when the parent is itself empty
/// (a bare file name such as `"file.rs"`).
fn extract_parent_dir(path: &str) -> String {
    let parent = std::path::Path::new(path).parent();
    match parent.and_then(std::path::Path::to_str) {
        Some(dir) => dir.to_owned(),
        None => String::new(),
    }
}
/// Shortens `content` to at most `max_length` bytes and appends `"..."` when anything was cut.
///
/// Content that already fits is returned unchanged, with no ellipsis. When
/// `max_length` falls inside a multi-byte UTF-8 character, the cut moves back to
/// the start of that character, so the kept prefix may be slightly shorter than
/// `max_length` but is always valid UTF-8.
fn truncate_preview(content: &str, max_length: usize) -> String {
    if content.len() <= max_length {
        return content.to_string();
    }

    // Slicing a `str` at a non-boundary index panics, so back off to the nearest
    // boundary. Index 0 is always a boundary, which guarantees termination.
    let mut cut = max_length;
    while !content.is_char_boundary(cut) {
        cut -= 1;
    }
    format!("{}...", &content[..cut])
}
/// Maps a traversal depth to a relationship label.
///
/// Depth `1` is `"direct"`, depth `2` is `"indirect"`, and every other
/// value (including zero and negatives) is `"unknown"`.
fn infer_relationship_type(depth: i32) -> String {
    let label = match depth {
        1 => "direct",
        2 => "indirect",
        _ => "unknown",
    };
    label.to_owned()
}
/// Unit tests for the scoring helpers and the ranking/truncation steps.
#[cfg(test)]
mod tests {
    use super::*;

    /// Restates the same-directory rule so it can be checked without a store:
    /// the boost applies only when both directories match and are non-empty.
    fn proximity_boost(chunk_dir: &str, source_dir: &str) -> f32 {
        if chunk_dir == source_dir && !source_dir.is_empty() {
            MODULE_PROXIMITY_BOOST
        } else {
            1.0
        }
    }

    #[test]
    fn test_edge_weight_computation() {
        // Inheritance edges are boosted.
        assert_eq!(
            compute_edge_weight("extends", "class"),
            EDGE_WEIGHT_INHERITANCE_BOOST
        );
        assert_eq!(
            compute_edge_weight("implements", "interface"),
            EDGE_WEIGHT_INHERITANCE_BOOST
        );

        // Any kind containing "test" is penalised.
        assert_eq!(
            compute_edge_weight("calls", "test"),
            EDGE_WEIGHT_TEST_PENALTY
        );
        assert_eq!(
            compute_edge_weight("", "test_function"),
            EDGE_WEIGHT_TEST_PENALTY
        );
        assert_eq!(
            compute_edge_weight("calls", "unit_test"),
            EDGE_WEIGHT_TEST_PENALTY
        );

        // Everything else falls back to the default weight.
        assert_eq!(
            compute_edge_weight("calls", "function"),
            EDGE_WEIGHT_DEFAULT
        );
        assert_eq!(
            compute_edge_weight("imports", "module"),
            EDGE_WEIGHT_DEFAULT
        );
        assert_eq!(compute_edge_weight("", "class"), EDGE_WEIGHT_DEFAULT);
    }

    #[test]
    fn test_module_proximity_boost() {
        // Directories are derived from real paths rather than compared to
        // themselves, so the assertions are not tautological.
        let source_dir = extract_parent_dir("src/module/source.rs");
        let sibling_dir = extract_parent_dir("src/module/sibling.rs");
        let other_dir = extract_parent_dir("src/other/file.rs");

        assert_eq!(
            proximity_boost(&sibling_dir, &source_dir),
            MODULE_PROXIMITY_BOOST
        );
        assert_eq!(proximity_boost(&other_dir, &source_dir), 1.0);

        // Two root-level files share an empty directory, which must not be boosted.
        let root_source = extract_parent_dir("main.rs");
        let root_chunk = extract_parent_dir("lib.rs");
        assert_eq!(proximity_boost(&root_chunk, &root_source), 1.0);
    }

    #[test]
    fn test_relevance_sorting() {
        let mut chunks = vec![
            (0.5, "chunk_a"),
            (0.9, "chunk_b"),
            (0.3, "chunk_c"),
            (0.7, "chunk_d"),
        ];

        // Highest score first.
        chunks.sort_by(|a, b| b.0.partial_cmp(&a.0).unwrap_or(std::cmp::Ordering::Equal));

        assert_eq!(chunks[0].1, "chunk_b");
        assert_eq!(chunks[1].1, "chunk_d");
        assert_eq!(chunks[2].1, "chunk_a");
        assert_eq!(chunks[3].1, "chunk_c");
    }

    #[test]
    fn test_preview_truncation() {
        // Shorter than the limit: unchanged.
        assert_eq!(truncate_preview("short", 100), "short");

        // Exactly at the limit: unchanged, no ellipsis.
        let exact = "a".repeat(100);
        assert_eq!(truncate_preview(&exact, 100), exact);

        // Longer than the limit: cut and suffixed with an ellipsis.
        let long_content = "a".repeat(150);
        let expected = format!("{}...", "a".repeat(100));
        assert_eq!(truncate_preview(&long_content, 100), expected);

        // Degenerate inputs.
        assert_eq!(truncate_preview("", 100), "");
        assert_eq!(truncate_preview("content", 0), "...");
    }

    #[test]
    fn test_extract_parent_dir() {
        assert_eq!(extract_parent_dir("src/module/file.rs"), "src/module");
        assert_eq!(extract_parent_dir("a/b/c/d.txt"), "a/b/c");
        assert_eq!(extract_parent_dir("src/file.rs"), "src");
        assert_eq!(extract_parent_dir("file.rs"), "");
        assert_eq!(extract_parent_dir(""), "");
        assert_eq!(extract_parent_dir("/"), "");
        assert_eq!(extract_parent_dir("a/b/c/d/e/f.txt"), "a/b/c/d/e");
    }

    #[test]
    fn test_empty_related_chunks() {
        let scored_chunks: Vec<(f32, &str)> = vec![];
        let top_results: Vec<&str> = scored_chunks
            .into_iter()
            .take(5)
            .map(|(_, result)| result)
            .collect();
        assert_eq!(top_results.len(), 0);
    }

    #[test]
    fn test_fewer_than_limit() {
        let scored_chunks = vec![(0.9, "chunk_a"), (0.5, "chunk_b")];
        // The limit exceeds the number of candidates, so all of them are kept.
        let top_results: Vec<&str> = scored_chunks
            .into_iter()
            .take(5)
            .map(|(_, result)| result)
            .collect();
        assert_eq!(top_results.len(), 2);
        assert_eq!(top_results[0], "chunk_a");
        assert_eq!(top_results[1], "chunk_b");
    }

    #[test]
    fn test_infer_relationship_type() {
        assert_eq!(infer_relationship_type(1), "direct");
        assert_eq!(infer_relationship_type(2), "indirect");
        assert_eq!(infer_relationship_type(3), "unknown");
        assert_eq!(infer_relationship_type(0), "unknown");
    }
}