use std::collections::HashMap;
use std::path::{Path, PathBuf};
use serde::Serialize;
use crate::gather::{
bfs_expand, fetch_and_assemble, GatherDirection, GatherOptions, GatheredChunk,
};
use crate::impact::{find_affected_tests_with_chunks, TestInfo, DEFAULT_MAX_TEST_SEARCH_DEPTH};
use crate::language::{ChunkType, Language};
use crate::parser::TypeEdgeKind;
use crate::store::Store;
use crate::{AnalysisError, Embedder};
pub const DEFAULT_ONBOARD_DEPTH: usize = 3;
const MAX_CALLEE_FETCH: usize = 30;
const MAX_CALLER_FETCH: usize = 15;
#[derive(Debug, Clone, Serialize)]
pub struct OnboardResult {
pub concept: String,
pub entry_point: OnboardEntry,
pub call_chain: Vec<OnboardEntry>,
pub callers: Vec<OnboardEntry>,
pub key_types: Vec<TypeInfo>,
pub tests: Vec<TestEntry>,
pub summary: OnboardSummary,
}
#[derive(Debug, Clone, Serialize)]
pub struct OnboardEntry {
pub name: String,
#[serde(serialize_with = "crate::serialize_path_normalized")]
pub file: PathBuf,
pub line_start: u32,
pub line_end: u32,
pub language: Language,
pub chunk_type: ChunkType,
pub signature: String,
pub content: String,
pub depth: usize,
}
#[derive(Debug, Clone, Serialize)]
pub struct TypeInfo {
pub type_name: String,
pub edge_kind: TypeEdgeKind,
}
#[derive(Debug, Clone, Serialize)]
pub struct TestEntry {
pub name: String,
#[serde(serialize_with = "crate::serialize_path_normalized")]
pub file: PathBuf,
pub line: u32,
pub call_depth: usize,
}
#[derive(Debug, Clone, Serialize)]
pub struct OnboardSummary {
pub total_items: usize,
pub files_covered: usize,
pub callee_depth: usize,
pub tests_found: usize,
}
pub fn onboard(
store: &Store,
embedder: &Embedder,
concept: &str,
root: &Path,
depth: usize,
) -> Result<OnboardResult, AnalysisError> {
let _span = tracing::info_span!("onboard", concept).entered();
let depth = depth.min(10);
let query_embedding = embedder.embed_query(concept)?;
let filter = crate::store::SearchFilter {
query_text: concept.to_string(),
enable_rrf: false, ..crate::store::SearchFilter::default()
};
let results = store.search_filtered(&query_embedding, &filter, 10, 0.0)?;
if results.is_empty() {
return Err(AnalysisError::NotFound(format!(
"No relevant code found for concept: {concept}"
)));
}
let entry = results
.iter()
.find(|r| is_callable_type(r.chunk.chunk_type))
.or(results.first())
.expect("results guaranteed non-empty by early return above");
let entry_name = entry.chunk.name.clone();
let entry_file = entry
.chunk
.file
.strip_prefix(root)
.unwrap_or(&entry.chunk.file)
.to_path_buf();
tracing::info!(entry_point = %entry_name, file = ?entry_file, "Selected entry point");
let graph = store.get_call_graph()?;
let test_chunks = match store.find_test_chunks() {
Ok(tc) => tc,
Err(e) => {
tracing::warn!(error = %e, "Test chunk loading failed, skipping tests");
std::sync::Arc::new(Vec::new())
}
};
let mut callee_scores: HashMap<String, (f32, usize)> = HashMap::new();
callee_scores.insert(entry_name.clone(), (1.0, 0));
let callee_opts = GatherOptions::default()
.with_expand_depth(depth)
.with_direction(GatherDirection::Callees)
.with_decay_factor(0.7)
.with_max_expanded_nodes(100);
let _callee_capped = bfs_expand(&mut callee_scores, &graph, &callee_opts);
callee_scores.remove(&entry_name);
tracing::debug!(callee_count = callee_scores.len(), "Callee BFS complete");
let mut caller_scores: HashMap<String, (f32, usize)> = HashMap::new();
caller_scores.insert(entry_name.clone(), (1.0, 0));
let caller_opts = GatherOptions::default()
.with_expand_depth(1)
.with_direction(GatherDirection::Callers)
.with_decay_factor(0.8)
.with_max_expanded_nodes(50);
let _caller_capped = bfs_expand(&mut caller_scores, &graph, &caller_opts);
caller_scores.remove(&entry_name);
tracing::debug!(caller_count = caller_scores.len(), "Caller BFS complete");
let callee_scores = cap_scores(callee_scores, MAX_CALLEE_FETCH, |(_s, d)| *d);
let caller_scores = cap_scores(caller_scores, MAX_CALLER_FETCH, |(score, _)| {
let safe = if score.is_finite() && *score > 0.0 {
*score
} else {
0.0
};
std::cmp::Reverse((safe * 1e6) as u64)
});
let entry_point = fetch_entry_point(store, &entry_name, &entry_file, root)?;
let (mut callee_chunks, _) = fetch_and_assemble(store, &callee_scores, root);
callee_chunks.sort_by(|a, b| {
a.depth
.cmp(&b.depth)
.then_with(|| a.file.cmp(&b.file))
.then_with(|| a.line_start.cmp(&b.line_start))
});
let call_chain: Vec<OnboardEntry> =
callee_chunks.into_iter().map(gathered_to_onboard).collect();
let (mut caller_chunks, _) = fetch_and_assemble(store, &caller_scores, root);
caller_chunks.sort_by(|a, b| b.score.total_cmp(&a.score));
let callers: Vec<OnboardEntry> = caller_chunks.into_iter().map(gathered_to_onboard).collect();
let key_types = match store.get_types_used_by(&entry_name) {
Ok(types) => filter_common_types(types),
Err(e) => {
tracing::warn!(error = %e, "Type dependency lookup failed, skipping key_types");
Vec::new()
}
};
let tests: Vec<TestEntry> = find_affected_tests_with_chunks(
&graph,
&test_chunks,
&entry_name,
DEFAULT_MAX_TEST_SEARCH_DEPTH,
)
.into_iter()
.map(test_info_to_entry)
.collect();
let mut all_files: std::collections::HashSet<&Path> = std::collections::HashSet::new();
all_files.insert(&entry_point.file);
for c in &call_chain {
all_files.insert(&c.file);
}
for c in &callers {
all_files.insert(&c.file);
}
let max_callee_depth = call_chain.iter().map(|c| c.depth).max().unwrap_or(0);
let summary = OnboardSummary {
total_items: 1 + call_chain.len() + callers.len() + key_types.len() + tests.len(),
files_covered: all_files.len(),
callee_depth: max_callee_depth,
tests_found: tests.len(),
};
tracing::info!(
callees = call_chain.len(),
callers = callers.len(),
types = key_types.len(),
tests = tests.len(),
"Onboard complete"
);
Ok(OnboardResult {
concept: concept.to_string(),
entry_point,
call_chain,
callers,
key_types,
tests,
summary,
})
}
fn cap_scores<F, K>(
scores: HashMap<String, (f32, usize)>,
max: usize,
key_fn: F,
) -> HashMap<String, (f32, usize)>
where
F: Fn(&(f32, usize)) -> K,
K: Ord,
{
if scores.len() <= max {
return scores;
}
let mut entries: Vec<_> = scores.into_iter().collect();
entries.sort_by(|a, b| key_fn(&a.1).cmp(&key_fn(&b.1)));
entries.truncate(max);
entries.into_iter().collect()
}
fn is_callable_type(ct: ChunkType) -> bool {
ct.is_callable()
}
fn gathered_to_onboard(c: GatheredChunk) -> OnboardEntry {
OnboardEntry {
name: c.name,
file: c.file,
line_start: c.line_start,
line_end: c.line_end,
language: c.language,
chunk_type: c.chunk_type,
signature: c.signature,
content: c.content,
depth: c.depth,
}
}
fn fetch_entry_point(
store: &Store,
entry_name: &str,
entry_file: &Path,
root: &Path,
) -> Result<OnboardEntry, AnalysisError> {
let results = store.search_by_name(entry_name, 10)?;
let best = results
.iter()
.filter(|r| r.chunk.name == entry_name)
.max_by(|a, b| {
let a_file_match = a.chunk.file.ends_with(entry_file);
let b_file_match = b.chunk.file.ends_with(entry_file);
a_file_match
.cmp(&b_file_match)
.then_with(|| a.score.total_cmp(&b.score))
})
.or_else(|| {
results.iter().find(|r| r.chunk.file.ends_with(entry_file))
})
.or_else(|| {
results.first()
});
match best {
Some(r) => {
let rel_file = r
.chunk
.file
.strip_prefix(root)
.unwrap_or(&r.chunk.file)
.to_path_buf();
Ok(OnboardEntry {
name: r.chunk.name.clone(),
file: rel_file,
line_start: r.chunk.line_start,
line_end: r.chunk.line_end,
language: r.chunk.language,
chunk_type: r.chunk.chunk_type,
signature: r.chunk.signature.clone(),
content: r.chunk.content.clone(),
depth: 0,
})
}
None => Err(AnalysisError::NotFound(format!(
"Entry point '{entry_name}' not found in index"
))),
}
}
fn filter_common_types(types: Vec<crate::store::TypeUsage>) -> Vec<TypeInfo> {
types
.into_iter()
.filter(|t| !crate::COMMON_TYPES.contains(t.type_name.as_str()))
.map(|t| TypeInfo {
type_name: t.type_name,
edge_kind: t
.edge_kind
.parse::<TypeEdgeKind>()
.unwrap_or(TypeEdgeKind::Param),
})
.collect()
}
fn test_info_to_entry(t: TestInfo) -> TestEntry {
TestEntry {
name: t.name,
file: t.file,
line: t.line,
call_depth: t.call_depth,
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_common_types_filtered() {
use crate::store::TypeUsage;
let types = vec![
TypeUsage {
type_name: "String".to_string(),
edge_kind: "Param".to_string(),
},
TypeUsage {
type_name: "Vec".to_string(),
edge_kind: "Return".to_string(),
},
TypeUsage {
type_name: "Store".to_string(),
edge_kind: "Param".to_string(),
},
TypeUsage {
type_name: "Option".to_string(),
edge_kind: "Return".to_string(),
},
];
let filtered = filter_common_types(types);
assert_eq!(filtered.len(), 1);
assert_eq!(filtered[0].type_name, "Store");
}
#[test]
fn test_common_types_canonical_set_filters_more() {
use crate::store::TypeUsage;
let types = vec![
TypeUsage {
type_name: "Error".to_string(),
edge_kind: "Return".to_string(),
},
TypeUsage {
type_name: "Mutex".to_string(),
edge_kind: "Field".to_string(),
},
TypeUsage {
type_name: "Debug".to_string(),
edge_kind: "Bound".to_string(),
},
TypeUsage {
type_name: "Store".to_string(),
edge_kind: "Param".to_string(),
},
];
let filtered = filter_common_types(types);
assert_eq!(filtered.len(), 1);
assert_eq!(filtered[0].type_name, "Store");
}
#[test]
fn test_uncommon_types_kept() {
use crate::store::TypeUsage;
let types = vec![
TypeUsage {
type_name: "Embedder".to_string(),
edge_kind: "Param".to_string(),
},
TypeUsage {
type_name: "CallGraph".to_string(),
edge_kind: "Field".to_string(),
},
TypeUsage {
type_name: "SearchFilter".to_string(),
edge_kind: "Param".to_string(),
},
];
let filtered = filter_common_types(types);
assert_eq!(filtered.len(), 3);
}
#[test]
fn test_callee_ordering_by_depth() {
use crate::parser::Language;
let mut chunks = vec![
GatheredChunk {
name: "deep".into(),
file: PathBuf::from("a.rs"),
line_start: 1,
line_end: 10,
language: Language::Rust,
chunk_type: ChunkType::Function,
signature: String::new(),
content: String::new(),
score: 0.5,
depth: 2,
source: None,
},
GatheredChunk {
name: "shallow".into(),
file: PathBuf::from("b.rs"),
line_start: 1,
line_end: 10,
language: Language::Rust,
chunk_type: ChunkType::Function,
signature: String::new(),
content: String::new(),
score: 0.3,
depth: 1,
source: None,
},
];
chunks.sort_by(|a, b| {
a.depth
.cmp(&b.depth)
.then_with(|| a.file.cmp(&b.file))
.then_with(|| a.line_start.cmp(&b.line_start))
});
assert_eq!(chunks[0].name, "shallow"); assert_eq!(chunks[1].name, "deep");
}
#[test]
fn test_entry_point_excluded_from_call_chain() {
let mut scores: HashMap<String, (f32, usize)> = HashMap::new();
scores.insert("entry".into(), (1.0, 0));
scores.insert("callee_a".into(), (0.7, 1));
scores.insert("callee_b".into(), (0.5, 2));
scores.remove("entry");
assert_eq!(scores.len(), 2);
assert!(!scores.contains_key("entry"));
}
#[test]
fn test_test_info_to_entry() {
let info = TestInfo {
name: "test_foo".into(),
file: PathBuf::from("tests/foo.rs"),
line: 10,
call_depth: 2,
};
let entry = test_info_to_entry(info);
assert_eq!(entry.name, "test_foo");
assert_eq!(entry.call_depth, 2);
}
}