use anyhow::Result;
use rusqlite::{params, Connection};
/// Traversal depth used by the graph queries in this module when the caller
/// passes `max_depth: None`.
pub const DEFAULT_MAX_DEPTH: usize = 3;
/// Upper bound applied to every requested traversal depth; larger requests
/// are clamped down to this value.
pub const HARD_MAX_DEPTH: usize = 10;
pub use crate::db::types::GraphResult;
pub use crate::db::types::ImportDirection;
/// Decodes a slash-delimited path string such as `"/1/2/3"` into chunk ids.
///
/// Leading, trailing, and repeated slashes are tolerated. Any segment that
/// does not parse as an `i64` (including empty segments) is skipped rather
/// than reported as an error.
fn parse_path(path_str: &str) -> Vec<i64> {
    let mut ids = Vec::new();
    for segment in path_str.split('/') {
        // Empty segments (from leading/trailing/double slashes) fail to
        // parse and are dropped here together with non-numeric ones.
        if let Ok(id) = segment.parse::<i64>() {
            ids.push(id);
        }
    }
    ids
}
pub fn find_callers(
conn: &Connection,
target_chunk_id: i64,
max_depth: Option<usize>,
) -> Result<Vec<GraphResult>> {
let depth = max_depth.unwrap_or(DEFAULT_MAX_DEPTH).min(HARD_MAX_DEPTH);
let sql = r#"
WITH RECURSIVE callers(chunk_id, depth, path) AS (
-- Base case: direct callers
SELECT src_chunk_id, 1, '/' || src_chunk_id
FROM chunk_edges
WHERE dst_chunk_id = ?1 AND type = 'calls'
UNION ALL
-- Recursive case: callers of callers
SELECT e.src_chunk_id, c.depth + 1,
c.path || '/' || e.src_chunk_id
FROM chunk_edges e
JOIN callers c ON e.dst_chunk_id = c.chunk_id
WHERE c.depth < ?2
AND e.type = 'calls'
-- Cycle detection: don't revisit chunks in path
AND c.path NOT LIKE '%/' || e.src_chunk_id || '/%'
AND c.path NOT LIKE '%/' || e.src_chunk_id
)
SELECT DISTINCT chunk_id, depth, path
FROM callers
ORDER BY depth, chunk_id
"#;
let mut stmt = conn.prepare(sql)?;
let results = stmt.query_map(params![target_chunk_id, depth], |row| {
let chunk_id: i64 = row.get(0)?;
let depth: i64 = row.get(1)?;
let path_str: String = row.get(2)?;
Ok(GraphResult {
chunk_id,
depth: depth as usize,
path: parse_path(&path_str),
edge_type: "calls".to_string(),
})
})?;
results
.collect::<Result<Vec<_>, _>>()
.map_err(|e| anyhow::anyhow!("{}", e))
}
pub fn find_callees(
conn: &Connection,
source_chunk_id: i64,
max_depth: Option<usize>,
) -> Result<Vec<GraphResult>> {
let depth = max_depth.unwrap_or(DEFAULT_MAX_DEPTH).min(HARD_MAX_DEPTH);
let sql = r#"
WITH RECURSIVE callees(chunk_id, depth, path) AS (
-- Base case: direct callees
SELECT dst_chunk_id, 1, '/' || dst_chunk_id
FROM chunk_edges
WHERE src_chunk_id = ?1 AND type = 'calls'
UNION ALL
-- Recursive case: callees of callees
SELECT e.dst_chunk_id, c.depth + 1,
c.path || '/' || e.dst_chunk_id
FROM chunk_edges e
JOIN callees c ON e.src_chunk_id = c.chunk_id
WHERE c.depth < ?2
AND e.type = 'calls'
-- Cycle detection: don't revisit chunks in path
AND c.path NOT LIKE '%/' || e.dst_chunk_id || '/%'
AND c.path NOT LIKE '%/' || e.dst_chunk_id
)
SELECT DISTINCT chunk_id, depth, path
FROM callees
ORDER BY depth, chunk_id
"#;
let mut stmt = conn.prepare(sql)?;
let results = stmt.query_map(params![source_chunk_id, depth], |row| {
let chunk_id: i64 = row.get(0)?;
let depth: i64 = row.get(1)?;
let path_str: String = row.get(2)?;
Ok(GraphResult {
chunk_id,
depth: depth as usize,
path: parse_path(&path_str),
edge_type: "calls".to_string(),
})
})?;
results
.collect::<Result<Vec<_>, _>>()
.map_err(|e| anyhow::anyhow!("{}", e))
}
pub fn find_imports(
conn: &Connection,
chunk_id: i64,
direction: ImportDirection,
max_depth: Option<usize>,
) -> Result<Vec<GraphResult>> {
let depth = max_depth.unwrap_or(DEFAULT_MAX_DEPTH).min(HARD_MAX_DEPTH);
let sql = match direction {
ImportDirection::Incoming => {
r#"
WITH RECURSIVE importers(chunk_id, depth, path) AS (
-- Base case: direct importers
SELECT src_chunk_id, 1, '/' || src_chunk_id
FROM chunk_edges
WHERE dst_chunk_id = ?1 AND type = 'imports'
UNION ALL
-- Recursive case: importers of importers
SELECT e.src_chunk_id, i.depth + 1,
i.path || '/' || e.src_chunk_id
FROM chunk_edges e
JOIN importers i ON e.dst_chunk_id = i.chunk_id
WHERE i.depth < ?2
AND e.type = 'imports'
-- Cycle detection
AND i.path NOT LIKE '%/' || e.src_chunk_id || '/%'
AND i.path NOT LIKE '%/' || e.src_chunk_id
)
SELECT DISTINCT chunk_id, depth, path
FROM importers
ORDER BY depth, chunk_id
"#
}
ImportDirection::Outgoing => {
r#"
WITH RECURSIVE imported(chunk_id, depth, path) AS (
-- Base case: direct imports
SELECT dst_chunk_id, 1, '/' || dst_chunk_id
FROM chunk_edges
WHERE src_chunk_id = ?1 AND type = 'imports'
UNION ALL
-- Recursive case: imports of imports
SELECT e.dst_chunk_id, i.depth + 1,
i.path || '/' || e.dst_chunk_id
FROM chunk_edges e
JOIN imported i ON e.src_chunk_id = i.chunk_id
WHERE i.depth < ?2
AND e.type = 'imports'
-- Cycle detection
AND i.path NOT LIKE '%/' || e.dst_chunk_id || '/%'
AND i.path NOT LIKE '%/' || e.dst_chunk_id
)
SELECT DISTINCT chunk_id, depth, path
FROM imported
ORDER BY depth, chunk_id
"#
}
};
let mut stmt = conn.prepare(sql)?;
let results = stmt.query_map(params![chunk_id, depth], |row| {
let chunk_id: i64 = row.get(0)?;
let depth: i64 = row.get(1)?;
let path_str: String = row.get(2)?;
Ok(GraphResult {
chunk_id,
depth: depth as usize,
path: parse_path(&path_str),
edge_type: "imports".to_string(),
})
})?;
results
.collect::<Result<Vec<_>, _>>()
.map_err(|e| anyhow::anyhow!("{}", e))
}
pub fn find_extensions(
conn: &Connection,
chunk_id: i64,
direction: ImportDirection, max_depth: Option<usize>,
) -> Result<Vec<GraphResult>> {
let depth = max_depth.unwrap_or(DEFAULT_MAX_DEPTH).min(HARD_MAX_DEPTH);
let sql = match direction {
ImportDirection::Incoming => {
r#"
WITH RECURSIVE subclasses(chunk_id, depth, path) AS (
-- Base case: direct subclasses
SELECT src_chunk_id, 1, '/' || src_chunk_id
FROM chunk_edges
WHERE dst_chunk_id = ?1 AND type = 'extends'
UNION ALL
-- Recursive case: subclasses of subclasses
SELECT e.src_chunk_id, s.depth + 1,
s.path || '/' || e.src_chunk_id
FROM chunk_edges e
JOIN subclasses s ON e.dst_chunk_id = s.chunk_id
WHERE s.depth < ?2
AND e.type = 'extends'
-- Cycle detection
AND s.path NOT LIKE '%/' || e.src_chunk_id || '/%'
AND s.path NOT LIKE '%/' || e.src_chunk_id
)
SELECT DISTINCT chunk_id, depth, path
FROM subclasses
ORDER BY depth, chunk_id
"#
}
ImportDirection::Outgoing => {
r#"
WITH RECURSIVE superclasses(chunk_id, depth, path) AS (
-- Base case: direct superclasses
SELECT dst_chunk_id, 1, '/' || dst_chunk_id
FROM chunk_edges
WHERE src_chunk_id = ?1 AND type = 'extends'
UNION ALL
-- Recursive case: superclasses of superclasses
SELECT e.dst_chunk_id, s.depth + 1,
s.path || '/' || e.dst_chunk_id
FROM chunk_edges e
JOIN superclasses s ON e.src_chunk_id = s.chunk_id
WHERE s.depth < ?2
AND e.type = 'extends'
-- Cycle detection
AND s.path NOT LIKE '%/' || e.dst_chunk_id || '/%'
AND s.path NOT LIKE '%/' || e.dst_chunk_id
)
SELECT DISTINCT chunk_id, depth, path
FROM superclasses
ORDER BY depth, chunk_id
"#
}
};
let mut stmt = conn.prepare(sql)?;
let results = stmt.query_map(params![chunk_id, depth], |row| {
let chunk_id: i64 = row.get(0)?;
let depth: i64 = row.get(1)?;
let path_str: String = row.get(2)?;
Ok(GraphResult {
chunk_id,
depth: depth as usize,
path: parse_path(&path_str),
edge_type: "extends".to_string(),
})
})?;
results
.collect::<Result<Vec<_>, _>>()
.map_err(|e| anyhow::anyhow!("{}", e))
}
pub fn get_direct_edges(
conn: &Connection,
chunk_id: i64,
direction: ImportDirection,
) -> Result<Vec<GraphResult>> {
let sql = match direction {
ImportDirection::Incoming => {
"SELECT src_chunk_id, type FROM chunk_edges WHERE dst_chunk_id = ?1"
}
ImportDirection::Outgoing => {
"SELECT dst_chunk_id, type FROM chunk_edges WHERE src_chunk_id = ?1"
}
};
let mut stmt = conn.prepare(sql)?;
let results = stmt.query_map(params![chunk_id], |row| {
let related_chunk_id: i64 = row.get(0)?;
let edge_type: String = row.get(1)?;
Ok(GraphResult {
chunk_id: related_chunk_id,
depth: 1,
path: vec![related_chunk_id],
edge_type,
})
})?;
results
.collect::<Result<Vec<_>, _>>()
.map_err(|e| anyhow::anyhow!("{}", e))
}
#[cfg(test)]
mod tests {
    use super::*;

    // ---- parse_path ----

    /// A well-formed path yields its ids in order.
    #[test]
    fn test_parse_path_basic() {
        assert_eq!(parse_path("/1/2/3"), [1, 2, 3]);
    }

    /// A path with one segment yields one id.
    #[test]
    fn test_parse_path_single() {
        assert_eq!(parse_path("/42"), [42]);
    }

    /// The empty string yields no ids.
    #[test]
    fn test_parse_path_empty() {
        assert_eq!(parse_path(""), Vec::<i64>::new());
    }

    /// A bare number without any slash is still accepted.
    #[test]
    fn test_parse_path_no_slashes() {
        assert_eq!(parse_path("123"), [123]);
    }

    /// A trailing slash does not add an element.
    #[test]
    fn test_parse_path_trailing_slash() {
        assert_eq!(parse_path("/1/2/3/"), [1, 2, 3]);
    }

    /// Non-numeric segments are skipped rather than rejected.
    #[test]
    fn test_parse_path_invalid_elements() {
        assert_eq!(parse_path("/1/abc/2"), [1, 2]);
    }

    // ---- constants ----

    #[test]
    fn test_default_max_depth() {
        assert_eq!(3, DEFAULT_MAX_DEPTH);
    }

    #[test]
    fn test_hard_max_depth() {
        assert_eq!(10, HARD_MAX_DEPTH);
    }

    // ---- re-exported types ----

    /// The two direction variants compare as different values.
    #[test]
    fn test_import_direction_variants() {
        assert_ne!(ImportDirection::Incoming, ImportDirection::Outgoing);
    }

    /// A `GraphResult` keeps the field values it was built with.
    #[test]
    fn test_graph_result_construction() {
        let built = GraphResult {
            chunk_id: 42,
            depth: 2,
            path: vec![1, 2, 42],
            edge_type: "calls".to_string(),
        };

        assert_eq!(built.chunk_id, 42);
        assert_eq!(built.depth, 2);
        assert_eq!(built.path.len(), 3);
        assert_eq!(built.edge_type, "calls");
    }

    // ---- depth handling ----

    /// Mirrors the clamp expression used by the traversal functions.
    #[test]
    fn test_depth_clamping() {
        let effective_depth = |requested: Option<usize>| {
            requested.unwrap_or(DEFAULT_MAX_DEPTH).min(HARD_MAX_DEPTH)
        };

        // Oversized requests are capped at the hard limit.
        assert_eq!(effective_depth(Some(100)), HARD_MAX_DEPTH);
        // A missing request falls back to the default.
        assert_eq!(effective_depth(None), DEFAULT_MAX_DEPTH);
    }
}