use crate::db::sqlite::graph::{GraphResult, ImportDirection};
use crate::db::traits::StoreChunks;
use crate::db::traits::StoreGraph;
use crate::db::SqliteStore;
use anyhow::Result;
/// Multiplicative relevance penalty applied once per graph hop.
///
/// A chunk found at depth `d` is scored `RELEVANCE_DECAY.powi(d)`: depth 0 scores
/// 1.0, depth 1 scores 0.7, depth 2 scores 0.49, and so on.
const RELEVANCE_DECAY: f64 = 0.7;

/// A chunk reached by walking the code graph outward from a starting chunk.
#[derive(Debug, Clone)]
pub struct RelatedChunk {
    /// Database id of the related chunk.
    pub id: i64,
    /// Path of the file containing the chunk (copied from the chunk's `file_path`).
    pub relpath: String,
    /// Name of the symbol the chunk represents, if it has one.
    pub symbol_name: Option<String>,
    /// Chunk kind string as stored with the chunk.
    pub kind: String,
    /// First line of the chunk.
    pub start_line: i32,
    /// Last line of the chunk.
    pub end_line: i32,
    /// Short preview text stored with the chunk.
    pub preview: String,
    /// Number of graph hops between the starting chunk and this one.
    pub depth: i32,
    /// `RELEVANCE_DECAY` raised to `depth`; larger means closer to the start.
    pub relevance: f64,
}
/// Kind of relationship edge in the code graph.
///
/// Each variant corresponds one-to-one with the string carried in a graph
/// result's `edge_type` field; see [`EdgeType::as_db_str`] and
/// [`EdgeType::from_db_str`].
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum EdgeType {
    Imports,
    Exports,
    Calls,
    CalledBy,
    TestOf,
    RouteOf,
}

impl EdgeType {
    /// Every variant, in declaration order. `from_db_str` searches this table so the
    /// string mapping is written only once (in `as_db_str`).
    const ALL: [EdgeType; 6] = [
        EdgeType::Imports,
        EdgeType::Exports,
        EdgeType::Calls,
        EdgeType::CalledBy,
        EdgeType::TestOf,
        EdgeType::RouteOf,
    ];

    /// Returns the stored string form of this edge type.
    pub fn as_db_str(&self) -> &'static str {
        match self {
            Self::Imports => "imports",
            Self::Exports => "exports",
            Self::Calls => "calls",
            Self::CalledBy => "called_by",
            Self::TestOf => "test_of",
            Self::RouteOf => "route_of",
        }
    }

    /// Parses a stored edge-type string, returning `None` for anything unrecognised.
    ///
    /// Matching is exact and case-sensitive.
    pub fn from_db_str(s: &str) -> Option<Self> {
        Self::ALL
            .iter()
            .find(|candidate| candidate.as_db_str() == s)
            .cloned()
    }
}
/// Resolves one graph hit into a [`RelatedChunk`] by loading the chunk it points at.
///
/// The hit's depth is copied through, and relevance is `RELEVANCE_DECAY` raised to
/// that depth. Returns `Ok(None)` when `store.get_chunk_by_id` finds no chunk for
/// `result.chunk_id`.
///
/// # Errors
/// Propagates any error from `store.get_chunk_by_id`.
async fn graph_result_to_related_chunk(
    store: &SqliteStore,
    result: GraphResult,
) -> Result<Option<RelatedChunk>> {
    let hops = result.depth as i32;
    let maybe_chunk = store.get_chunk_by_id(result.chunk_id).await?;
    Ok(maybe_chunk.map(|c| RelatedChunk {
        id: c.id,
        relpath: c.file_path,
        symbol_name: c.symbol_name,
        kind: c.kind,
        start_line: c.start_line,
        end_line: c.end_line,
        preview: c.preview,
        depth: hops,
        relevance: RELEVANCE_DECAY.powi(hops),
    }))
}
/// Converts raw graph hits into [`RelatedChunk`]s, sorted from most to least relevant.
///
/// * `edge_types` — when `Some`, a hit is kept only if its `edge_type` string parses
///   to one of the listed variants (hits with unrecognised strings are dropped).
///   When `None`, every hit is kept.
///
/// Hits whose chunk cannot be found are skipped. The sort is stable, so hits with
/// equal relevance keep the order in which they appear in `results`.
///
/// # Errors
/// Propagates the first error returned while loading a chunk.
async fn graph_results_to_related_chunks(
    store: &SqliteStore,
    results: Vec<GraphResult>,
    edge_types: Option<&[EdgeType]>,
) -> Result<Vec<RelatedChunk>> {
    let mut related_chunks = Vec::with_capacity(results.len());
    for result in results {
        let wanted = edge_types.is_none_or(|types| {
            EdgeType::from_db_str(&result.edge_type).is_some_and(|t| types.contains(&t))
        });
        if !wanted {
            continue;
        }
        if let Some(chunk) = graph_result_to_related_chunk(store, result).await? {
            related_chunks.push(chunk);
        }
    }
    // `total_cmp` is a total order even in the presence of NaN, unlike
    // `partial_cmp(..).unwrap_or(Equal)`; std's sort may panic on a comparator
    // that is not a total order.
    related_chunks.sort_by(|a, b| b.relevance.total_cmp(&a.relevance));
    Ok(related_chunks)
}
/// Finds chunks related to `chunk_id` through call and import edges in both
/// directions, up to `max_depth` hops away.
///
/// * `max_depth` — hop limit forwarded to the graph queries. Negative values are
///   clamped to `0` instead of wrapping to an enormous `usize`.
/// * `edge_types` — when `None`, callers, callees and incoming/outgoing imports are
///   all queried and every hit is returned. When `Some`, only the queries that can
///   produce the requested edge types are run, and hits are filtered to those types.
///
/// Results are sorted by descending relevance (fewest hops first).
///
/// # Errors
/// Propagates any error from the underlying store queries or chunk lookups.
pub async fn find_related_chunks(
    store: &SqliteStore,
    chunk_id: i64,
    max_depth: i32,
    edge_types: Option<Vec<EdgeType>>,
) -> Result<Vec<RelatedChunk>> {
    // A plain `as usize` would turn e.g. -1 into usize::MAX.
    let depth = Some(max_depth.max(0) as usize);
    let edge_types_ref = edge_types.as_deref();
    // True when no filter was given, or the filter names any of `candidates`.
    let wants = |candidates: &[EdgeType]| {
        edge_types_ref.is_none_or(|types| types.iter().any(|t| candidates.contains(t)))
    };

    let mut all_results = Vec::new();
    if wants(&[EdgeType::CalledBy]) {
        all_results.extend(store.find_callers(chunk_id, depth).await?);
    }
    if wants(&[EdgeType::Calls]) {
        all_results.extend(store.find_callees(chunk_id, depth).await?);
    }
    // Import queries can also yield `test_of` edges (tests are read from incoming
    // import hits), so a TestOf filter has to run them as well.
    if wants(&[EdgeType::Imports, EdgeType::Exports, EdgeType::TestOf]) {
        let incoming = store
            .find_imports(chunk_id, ImportDirection::Incoming, depth)
            .await?;
        let outgoing = store
            .find_imports(chunk_id, ImportDirection::Outgoing, depth)
            .await?;
        all_results.extend(incoming);
        all_results.extend(outgoing);
    }
    graph_results_to_related_chunks(store, all_results, edge_types_ref).await
}
/// Finds chunks related to `chunk_id` in a single direction, up to `max_depth` hops.
///
/// * `forward == true` — follows outgoing edges: callees and outgoing imports.
/// * `forward == false` — follows incoming edges: callers and incoming imports.
/// * `max_depth` — negative values are clamped to `0` instead of wrapping to an
///   enormous `usize`.
/// * `edge_types` — when `None`, every query for the chosen direction runs and all
///   hits are returned. When `Some`, only queries that can produce the requested
///   edge types run, and hits are filtered to those types.
///
/// Results are sorted by descending relevance (fewest hops first).
///
/// # Errors
/// Propagates any error from the underlying store queries or chunk lookups.
pub async fn find_related_chunks_directional(
    store: &SqliteStore,
    chunk_id: i64,
    max_depth: i32,
    edge_types: Option<Vec<EdgeType>>,
    forward: bool,
) -> Result<Vec<RelatedChunk>> {
    // A plain `as usize` would turn e.g. -1 into usize::MAX.
    let depth = Some(max_depth.max(0) as usize);
    let edge_types_ref = edge_types.as_deref();
    // True when no filter was given, or the filter names any of `candidates`.
    let wants = |candidates: &[EdgeType]| {
        edge_types_ref.is_none_or(|types| types.iter().any(|t| candidates.contains(t)))
    };

    let mut all_results = Vec::new();
    if forward {
        if wants(&[EdgeType::Calls]) {
            all_results.extend(store.find_callees(chunk_id, depth).await?);
        }
        // Import hits can carry `test_of` edges, so a TestOf filter queries them too;
        // anything that is not TestOf is removed by the final filter.
        if wants(&[EdgeType::Imports, EdgeType::TestOf]) {
            let imports = store
                .find_imports(chunk_id, ImportDirection::Outgoing, depth)
                .await?;
            all_results.extend(imports);
        }
    } else {
        if wants(&[EdgeType::CalledBy]) {
            all_results.extend(store.find_callers(chunk_id, depth).await?);
        }
        // Incoming import hits are where `test_of` edges are read from elsewhere
        // in this module.
        if wants(&[EdgeType::Exports, EdgeType::TestOf]) {
            let imports = store
                .find_imports(chunk_id, ImportDirection::Incoming, depth)
                .await?;
            all_results.extend(imports);
        }
    }
    graph_results_to_related_chunks(store, all_results, edge_types_ref).await
}
pub async fn load_relationships_parallel(
store: &SqliteStore,
chunk_id: i64,
max_depth: i32,
) -> (Vec<RelatedChunk>, Vec<RelatedChunk>, Vec<RelatedChunk>) {
let depth = Some(max_depth as usize);
let (callers_result, callees_result, tests_result) = tokio::join!(
async {
match store.find_callers(chunk_id, depth).await {
Ok(results) => graph_results_to_related_chunks(store, results, None)
.await
.unwrap_or_default(),
Err(_) => Vec::new(),
}
},
async {
match store.find_callees(chunk_id, depth).await {
Ok(results) => graph_results_to_related_chunks(store, results, None)
.await
.unwrap_or_default(),
Err(_) => Vec::new(),
}
},
async {
match store
.find_imports(chunk_id, ImportDirection::Incoming, depth)
.await
{
Ok(results) => {
let test_results: Vec<_> = results
.into_iter()
.filter(|r| r.edge_type == "test_of")
.collect();
graph_results_to_related_chunks(store, test_results, None)
.await
.unwrap_or_default()
}
Err(_) => Vec::new(),
}
}
);
(callers_result, callees_result, tests_result)
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::db::traits::StoreMigration;
    use std::sync::atomic::{AtomicUsize, Ordering};

    /// Every `EdgeType` variant, for exhaustive round-trip checks.
    const ALL_EDGE_TYPES: [EdgeType; 6] = [
        EdgeType::Imports,
        EdgeType::Exports,
        EdgeType::Calls,
        EdgeType::CalledBy,
        EdgeType::TestOf,
        EdgeType::RouteOf,
    ];

    #[test]
    fn test_edge_type_conversion() {
        assert_eq!(EdgeType::Imports.as_db_str(), "imports");
        assert_eq!(EdgeType::Exports.as_db_str(), "exports");
        assert_eq!(EdgeType::Calls.as_db_str(), "calls");
        assert_eq!(EdgeType::CalledBy.as_db_str(), "called_by");
        assert_eq!(EdgeType::TestOf.as_db_str(), "test_of");
        assert_eq!(EdgeType::RouteOf.as_db_str(), "route_of");
    }

    #[test]
    fn test_edge_type_equality() {
        assert_eq!(EdgeType::Calls, EdgeType::Calls);
        assert_ne!(EdgeType::Calls, EdgeType::Imports);
    }

    #[test]
    fn test_edge_type_from_db_str() {
        assert_eq!(EdgeType::from_db_str("imports"), Some(EdgeType::Imports));
        assert_eq!(EdgeType::from_db_str("exports"), Some(EdgeType::Exports));
        assert_eq!(EdgeType::from_db_str("calls"), Some(EdgeType::Calls));
        assert_eq!(EdgeType::from_db_str("called_by"), Some(EdgeType::CalledBy));
        assert_eq!(EdgeType::from_db_str("test_of"), Some(EdgeType::TestOf));
        assert_eq!(EdgeType::from_db_str("route_of"), Some(EdgeType::RouteOf));
        assert_eq!(EdgeType::from_db_str("unknown"), None);
    }

    /// `from_db_str` must invert `as_db_str` for every variant.
    #[test]
    fn test_edge_type_round_trip() {
        for edge_type in ALL_EDGE_TYPES {
            let parsed = EdgeType::from_db_str(edge_type.as_db_str());
            assert_eq!(parsed, Some(edge_type));
        }
    }

    #[test]
    fn test_relevance_decay() {
        let depth_0_relevance = RELEVANCE_DECAY.powi(0);
        let depth_1_relevance = RELEVANCE_DECAY.powi(1);
        let depth_2_relevance = RELEVANCE_DECAY.powi(2);
        let depth_3_relevance = RELEVANCE_DECAY.powi(3);
        assert!((depth_0_relevance - 1.0).abs() < 0.001);
        assert!((depth_1_relevance - 0.7).abs() < 0.001);
        assert!((depth_2_relevance - 0.49).abs() < 0.001);
        assert!((depth_3_relevance - 0.343).abs() < 0.001);
    }

    #[test]
    fn test_related_chunk_structure() {
        let chunk = RelatedChunk {
            id: 123,
            relpath: "src/main.rs".to_string(),
            symbol_name: Some("main".to_string()),
            kind: "function".to_string(),
            start_line: 1,
            end_line: 10,
            preview: "fn main() {}".to_string(),
            depth: 1,
            relevance: 0.7,
        };
        assert_eq!(chunk.id, 123);
        assert_eq!(chunk.relpath, "src/main.rs");
        assert_eq!(chunk.symbol_name, Some("main".to_string()));
        assert_eq!(chunk.kind, "function");
        assert_eq!(chunk.depth, 1);
        assert!((chunk.relevance - 0.7).abs() < 0.001);
    }

    /// Gives each test its own shared-cache in-memory database name so tests
    /// running concurrently do not see each other's data.
    static TEST_DB_COUNTER: AtomicUsize = AtomicUsize::new(0);

    /// Opens a fresh, migrated in-memory store for a single test.
    async fn setup_test_store() -> SqliteStore {
        let counter = TEST_DB_COUNTER.fetch_add(1, Ordering::SeqCst);
        let db_name = format!("file:memdb_graph_test_{}?mode=memory&cache=shared", counter);
        let store = SqliteStore::connect(&db_name)
            .await
            .expect("Failed to create test store");
        store.migrate().await.expect("Failed to run migrations");
        store
    }

    #[tokio::test]
    async fn test_find_related_chunks_empty_result() {
        let store = setup_test_store().await;
        let results = find_related_chunks(&store, 99999, 3, None).await.unwrap();
        assert!(results.is_empty());
    }

    #[tokio::test]
    async fn test_find_related_chunks_directional_empty() {
        let store = setup_test_store().await;
        let forward = find_related_chunks_directional(&store, 99999, 3, None, true)
            .await
            .unwrap();
        assert!(forward.is_empty());
        let backward = find_related_chunks_directional(&store, 99999, 3, None, false)
            .await
            .unwrap();
        assert!(backward.is_empty());
    }

    #[tokio::test]
    async fn test_load_relationships_parallel_empty() {
        let store = setup_test_store().await;
        let (callers, callees, tests) = load_relationships_parallel(&store, 99999, 3).await;
        assert!(callers.is_empty());
        assert!(callees.is_empty());
        assert!(tests.is_empty());
    }

    #[tokio::test]
    async fn test_find_related_chunks_with_edge_filter() {
        let store = setup_test_store().await;
        let results = find_related_chunks(&store, 99999, 3, Some(vec![EdgeType::Calls]))
            .await
            .unwrap();
        assert!(results.is_empty());
    }

    /// A TestOf-only filter must succeed (and, on an empty store, find nothing).
    #[tokio::test]
    async fn test_find_related_chunks_test_of_filter_empty() {
        let store = setup_test_store().await;
        let results = find_related_chunks(&store, 99999, 3, Some(vec![EdgeType::TestOf]))
            .await
            .unwrap();
        assert!(results.is_empty());
    }

    /// Directional queries with an edge filter must succeed in both directions.
    #[tokio::test]
    async fn test_find_related_chunks_directional_with_edge_filter() {
        let store = setup_test_store().await;
        let forward = find_related_chunks_directional(
            &store,
            99999,
            3,
            Some(vec![EdgeType::Imports]),
            true,
        )
        .await
        .unwrap();
        assert!(forward.is_empty());
        let backward = find_related_chunks_directional(
            &store,
            99999,
            3,
            Some(vec![EdgeType::Exports]),
            false,
        )
        .await
        .unwrap();
        assert!(backward.is_empty());
    }
}