use sqlx::Row;
use super::FunctionCallStats;
use crate::store::helpers::StoreError;
use crate::store::Store;
impl Store {
async fn batch_count_query(
&self,
filter_column: &str,
group_column: &str,
count_expr: &str,
names: &[&str],
) -> Result<std::collections::HashMap<String, u64>, StoreError> {
let mut result = std::collections::HashMap::new();
const BATCH_SIZE: usize = 500;
for batch in names.chunks(BATCH_SIZE) {
let placeholders = super::super::helpers::make_placeholders(batch.len());
let sql = format!(
"SELECT {group_column}, {count_expr} FROM function_calls WHERE {filter_column} IN ({placeholders}) GROUP BY {group_column}",
);
let mut q = sqlx::query(&sql);
for name in batch {
q = q.bind(name);
}
let rows: Vec<_> = q.fetch_all(&self.pool).await?;
for row in rows {
let name: String = row.get(0);
let count: i64 = row.get(1);
result.insert(name, count as u64);
}
}
Ok(result)
}
pub fn get_caller_counts_batch(
&self,
names: &[&str],
) -> Result<std::collections::HashMap<String, u64>, StoreError> {
let _span = tracing::info_span!("get_caller_counts_batch", count = names.len()).entered();
if names.is_empty() {
return Ok(std::collections::HashMap::new());
}
self.rt
.block_on(self.batch_count_query("callee_name", "callee_name", "COUNT(*)", names))
}
pub fn get_callee_counts_batch(
&self,
names: &[&str],
) -> Result<std::collections::HashMap<String, u64>, StoreError> {
let _span = tracing::info_span!("get_callee_counts_batch", count = names.len()).entered();
if names.is_empty() {
return Ok(std::collections::HashMap::new());
}
self.rt.block_on(self.batch_count_query(
"caller_name",
"caller_name",
"COUNT(DISTINCT callee_name)",
names,
))
}
pub fn find_shared_callers(
&self,
target: &str,
limit: usize,
) -> Result<Vec<(String, u32)>, StoreError> {
let _span = tracing::debug_span!("find_shared_callers", function = %target).entered();
self.rt.block_on(async {
let rows: Vec<(String, i64)> = sqlx::query_as(
"SELECT fc2.callee_name, COUNT(DISTINCT fc2.caller_name) AS overlap
FROM function_calls fc1
JOIN function_calls fc2 ON fc1.caller_name = fc2.caller_name
WHERE fc1.callee_name = ?1 AND fc2.callee_name != ?1
GROUP BY fc2.callee_name
ORDER BY overlap DESC
LIMIT ?2",
)
.bind(target)
.bind(limit as i64)
.fetch_all(&self.pool)
.await?;
Ok(rows
.into_iter()
.map(|(name, count)| (name, count as u32))
.collect())
})
}
pub fn find_shared_callees(
&self,
target: &str,
limit: usize,
) -> Result<Vec<(String, u32)>, StoreError> {
let _span = tracing::debug_span!("find_shared_callees", function = %target).entered();
self.rt.block_on(async {
let rows: Vec<(String, i64)> = sqlx::query_as(
"SELECT fc2.caller_name, COUNT(DISTINCT fc2.callee_name) AS overlap
FROM function_calls fc1
JOIN function_calls fc2 ON fc1.callee_name = fc2.callee_name
WHERE fc1.caller_name = ?1 AND fc2.caller_name != ?1
GROUP BY fc2.caller_name
ORDER BY overlap DESC
LIMIT ?2",
)
.bind(target)
.bind(limit as i64)
.fetch_all(&self.pool)
.await?;
Ok(rows
.into_iter()
.map(|(name, count)| (name, count as u32))
.collect())
})
}
pub fn function_call_stats(&self) -> Result<FunctionCallStats, StoreError> {
let _span = tracing::debug_span!("function_call_stats").entered();
self.rt.block_on(async {
let (total_calls, unique_callers, unique_callees): (i64, i64, i64) = sqlx::query_as(
"SELECT COUNT(*), COUNT(DISTINCT caller_name), COUNT(DISTINCT callee_name) FROM function_calls",
)
.fetch_one(&self.pool)
.await?;
Ok(FunctionCallStats {
total_calls: total_calls as u64,
unique_callers: unique_callers as u64,
unique_callees: unique_callees as u64,
})
})
}
pub fn callee_caller_counts(&self) -> Result<Vec<(String, usize)>, StoreError> {
let _span = tracing::debug_span!("callee_caller_counts").entered();
self.rt.block_on(async {
let rows: Vec<_> = sqlx::query(
"SELECT callee_name, COUNT(DISTINCT caller_name) as caller_count \
FROM function_calls GROUP BY callee_name",
)
.fetch_all(&self.pool)
.await?;
Ok(rows
.iter()
.map(|row| {
let name: String = row.get("callee_name");
let count: i64 = row.get("caller_count");
(name, count as usize)
})
.collect())
})
}
}
#[cfg(test)]
mod tests {
use std::path::Path;
use crate::test_helpers::setup_store;
use super::*;
fn seed_call_graph(store: &Store) {
let calls = vec![
crate::parser::FunctionCalls {
name: "func_a".to_string(),
line_start: 1,
calls: vec![
crate::parser::CallSite {
callee_name: "func_b".to_string(),
line_number: 2,
},
crate::parser::CallSite {
callee_name: "func_c".to_string(),
line_number: 3,
},
],
},
crate::parser::FunctionCalls {
name: "func_b".to_string(),
line_start: 10,
calls: vec![crate::parser::CallSite {
callee_name: "func_c".to_string(),
line_number: 11,
}],
},
crate::parser::FunctionCalls {
name: "func_d".to_string(),
line_start: 20,
calls: vec![crate::parser::CallSite {
callee_name: "func_b".to_string(),
line_number: 21,
}],
},
];
store
.upsert_function_calls(Path::new("src/test.rs"), &calls)
.unwrap();
}
#[test]
fn test_get_caller_counts_batch() {
let (store, _dir) = setup_store();
seed_call_graph(&store);
let counts = store
.get_caller_counts_batch(&["func_b", "func_c"])
.unwrap();
assert_eq!(counts.get("func_b").copied().unwrap_or(0), 2);
assert_eq!(counts.get("func_c").copied().unwrap_or(0), 2);
}
#[test]
fn test_get_callee_counts_batch() {
let (store, _dir) = setup_store();
seed_call_graph(&store);
let counts = store
.get_callee_counts_batch(&["func_a", "func_b", "func_d"])
.unwrap();
assert_eq!(counts.get("func_a").copied().unwrap_or(0), 2);
assert_eq!(counts.get("func_b").copied().unwrap_or(0), 1);
assert_eq!(counts.get("func_d").copied().unwrap_or(0), 1);
}
#[test]
fn test_get_caller_counts_batch_empty() {
let (store, _dir) = setup_store();
let counts = store.get_caller_counts_batch(&[]).unwrap();
assert!(counts.is_empty());
}
#[test]
fn test_get_callee_counts_batch_empty() {
let (store, _dir) = setup_store();
let counts = store.get_callee_counts_batch(&[]).unwrap();
assert!(counts.is_empty());
}
#[test]
fn test_get_caller_counts_batch_unknown_names() {
let (store, _dir) = setup_store();
seed_call_graph(&store);
let counts = store
.get_caller_counts_batch(&["nonexistent_func", "also_missing"])
.unwrap();
assert!(counts.is_empty());
}
#[test]
fn test_get_callee_counts_batch_unknown_names() {
let (store, _dir) = setup_store();
seed_call_graph(&store);
let counts = store
.get_callee_counts_batch(&["nonexistent_func"])
.unwrap();
assert!(counts.is_empty());
}
#[test]
fn test_find_shared_callers() {
let (store, _dir) = setup_store();
seed_call_graph(&store);
let shared = store.find_shared_callers("func_b", 10).unwrap();
let names: Vec<&str> = shared.iter().map(|(n, _)| n.as_str()).collect();
assert!(
names.contains(&"func_c"),
"func_c should share caller func_a with func_b"
);
}
#[test]
fn test_find_shared_callees() {
let (store, _dir) = setup_store();
seed_call_graph(&store);
let shared = store.find_shared_callees("func_a", 10).unwrap();
let names: Vec<&str> = shared.iter().map(|(n, _)| n.as_str()).collect();
assert!(
names.contains(&"func_b"),
"func_b should share callee func_c with func_a"
);
}
#[test]
fn test_find_shared_callers_no_callers() {
let (store, _dir) = setup_store();
seed_call_graph(&store);
let shared = store.find_shared_callers("func_a", 10).unwrap();
assert!(shared.is_empty());
}
#[test]
fn test_find_shared_callees_no_callees() {
let (store, _dir) = setup_store();
seed_call_graph(&store);
let shared = store.find_shared_callees("func_c", 10).unwrap();
assert!(shared.is_empty());
}
#[test]
fn test_find_shared_callers_limit() {
let (store, _dir) = setup_store();
seed_call_graph(&store);
let shared = store.find_shared_callers("func_b", 1).unwrap();
assert!(shared.len() <= 1);
}
#[test]
fn test_find_shared_callers_unknown() {
let (store, _dir) = setup_store();
seed_call_graph(&store);
let shared = store.find_shared_callers("nonexistent", 10).unwrap();
assert!(shared.is_empty());
}
}