1use std::collections::HashMap;
12use std::path::{Path, PathBuf};
13
14use serde::Serialize;
15
16use crate::gather::{
17 bfs_expand, fetch_and_assemble, GatherDirection, GatherOptions, GatheredChunk,
18};
19use crate::impact::{find_affected_tests_with_chunks, TestInfo, DEFAULT_MAX_TEST_SEARCH_DEPTH};
20use crate::language::{ChunkType, Language};
21use crate::parser::TypeEdgeKind;
22use crate::store::Store;
23use crate::{AnalysisError, Embedder};
24
/// Default callee expansion depth for [`onboard`].
///
/// NOTE(review): not referenced in this section of the file — presumably the
/// value callers pass as `depth` when the user gives none; confirm at call sites.
pub const DEFAULT_ONBOARD_DEPTH: usize = 3;

/// Upper bound on the callee entries kept by `cap_scores` in [`onboard`]
/// before their chunks are fetched from the store.
const MAX_CALLEE_FETCH: usize = 30;

/// Upper bound on the caller entries kept by `cap_scores` in [`onboard`]
/// before their chunks are fetched from the store.
const MAX_CALLER_FETCH: usize = 15;
34
/// Complete output of [`onboard`]: the chosen entry point plus the code,
/// types and tests gathered around it.
#[derive(Debug, Clone, Serialize)]
pub struct OnboardResult {
    /// The concept string that was passed to [`onboard`].
    pub concept: String,
    /// The chunk selected as the starting point of the tour (depth 0).
    pub entry_point: OnboardEntry,
    /// Callees reached from the entry point, ordered by depth, then file,
    /// then starting line.
    pub call_chain: Vec<OnboardEntry>,
    /// Callers found by a one-level caller expansion, ordered by descending score.
    pub callers: Vec<OnboardEntry>,
    /// Types used by the entry point, with entries listed in
    /// `crate::COMMON_TYPES` removed.
    pub key_types: Vec<TypeInfo>,
    /// Tests reported by `find_affected_tests_with_chunks` for the entry point.
    pub tests: Vec<TestEntry>,
    /// Aggregate counts over the fields above.
    pub summary: OnboardSummary,
}
48
/// One code chunk in an onboarding tour (entry point, callee or caller).
#[derive(Debug, Clone, Serialize)]
pub struct OnboardEntry {
    /// Name of the chunk.
    pub name: String,
    /// File containing the chunk; serialized through
    /// `crate::serialize_path_normalized`. For the entry point this is made
    /// relative to the project root when the path lies under it.
    #[serde(serialize_with = "crate::serialize_path_normalized")]
    pub file: PathBuf,
    /// First line of the chunk, as stored in the index.
    pub line_start: u32,
    /// Last line of the chunk, as stored in the index.
    pub line_end: u32,
    /// Source language of the chunk.
    pub language: Language,
    /// Kind of chunk (see [`ChunkType`]).
    pub chunk_type: ChunkType,
    /// Signature text of the chunk.
    pub signature: String,
    /// Source text of the chunk.
    pub content: String,
    /// Expansion depth: `0` for the entry point; for callees and callers, the
    /// depth recorded on the gathered chunk.
    pub depth: usize,
}
63
/// A type used by the entry point, together with how it is used.
#[derive(Debug, Clone, Serialize)]
pub struct TypeInfo {
    /// Name of the type.
    pub type_name: String,
    /// Kind of usage edge; `filter_common_types` falls back to
    /// `TypeEdgeKind::Param` when the stored string does not parse.
    pub edge_kind: TypeEdgeKind,
}
70
/// A test associated with the entry point, copied field-for-field from a
/// `TestInfo`.
#[derive(Debug, Clone, Serialize)]
pub struct TestEntry {
    /// Name of the test.
    pub name: String,
    /// File containing the test; serialized through
    /// `crate::serialize_path_normalized`.
    #[serde(serialize_with = "crate::serialize_path_normalized")]
    pub file: PathBuf,
    /// Line of the test, as reported by `TestInfo`.
    pub line: u32,
    /// Call depth reported by `TestInfo`.
    /// NOTE(review): presumably the number of call hops between the test and
    /// the entry point — confirm against `crate::impact`.
    pub call_depth: usize,
}
80
/// Aggregate counts describing an [`OnboardResult`].
#[derive(Debug, Clone, Serialize)]
pub struct OnboardSummary {
    /// `1` (the entry point) plus the number of call-chain entries, callers,
    /// key types and tests.
    pub total_items: usize,
    /// Number of distinct files among the entry point, call chain and callers.
    pub files_covered: usize,
    /// Largest `depth` found in the call chain, or `0` when it is empty.
    pub callee_depth: usize,
    /// Number of tests in the result.
    pub tests_found: usize,
}
89
90pub fn onboard(
94 store: &Store,
95 embedder: &Embedder,
96 concept: &str,
97 root: &Path,
98 depth: usize,
99) -> Result<OnboardResult, AnalysisError> {
100 let _span = tracing::info_span!("onboard", concept).entered();
101 let depth = depth.min(10);
102
103 let query_embedding = embedder.embed_query(concept)?;
105 let filter = crate::store::SearchFilter {
106 query_text: concept.to_string(),
107 enable_rrf: false, ..crate::store::SearchFilter::default()
109 };
110 let results = store.search_filtered(&query_embedding, &filter, 10, 0.0)?;
111
112 if results.is_empty() {
113 return Err(AnalysisError::NotFound(format!(
114 "No relevant code found for concept: {concept}"
115 )));
116 }
117
118 let entry = results
120 .iter()
121 .find(|r| is_callable_type(r.chunk.chunk_type))
122 .or(results.first())
123 .expect("results guaranteed non-empty by early return above");
124 let entry_name = entry.chunk.name.clone();
125 let entry_file = entry
126 .chunk
127 .file
128 .strip_prefix(root)
129 .unwrap_or(&entry.chunk.file)
130 .to_path_buf();
131 tracing::info!(entry_point = %entry_name, file = ?entry_file, "Selected entry point");
132
133 let graph = store.get_call_graph()?;
135
136 let test_chunks = match store.find_test_chunks() {
137 Ok(tc) => tc,
138 Err(e) => {
139 tracing::warn!(error = %e, "Test chunk loading failed, skipping tests");
140 std::sync::Arc::new(Vec::new())
141 }
142 };
143
144 let mut callee_scores: HashMap<String, (f32, usize)> = HashMap::new();
146 callee_scores.insert(entry_name.clone(), (1.0, 0));
147 let callee_opts = GatherOptions::default()
148 .with_expand_depth(depth)
149 .with_direction(GatherDirection::Callees)
150 .with_decay_factor(0.7)
151 .with_max_expanded_nodes(100);
152 let _callee_capped = bfs_expand(&mut callee_scores, &graph, &callee_opts);
153
154 callee_scores.remove(&entry_name);
156 tracing::debug!(callee_count = callee_scores.len(), "Callee BFS complete");
157
158 let mut caller_scores: HashMap<String, (f32, usize)> = HashMap::new();
160 caller_scores.insert(entry_name.clone(), (1.0, 0));
161 let caller_opts = GatherOptions::default()
162 .with_expand_depth(1)
163 .with_direction(GatherDirection::Callers)
164 .with_decay_factor(0.8)
165 .with_max_expanded_nodes(50);
166 let _caller_capped = bfs_expand(&mut caller_scores, &graph, &caller_opts);
167
168 caller_scores.remove(&entry_name);
170 tracing::debug!(caller_count = caller_scores.len(), "Caller BFS complete");
171
172 let callee_scores = cap_scores(callee_scores, MAX_CALLEE_FETCH, |(_s, d)| *d);
175 let caller_scores = cap_scores(caller_scores, MAX_CALLER_FETCH, |(score, _)| {
176 let safe = if score.is_finite() && *score > 0.0 {
178 *score
179 } else {
180 0.0
181 };
182 std::cmp::Reverse((safe * 1e6) as u64)
183 });
184
185 let entry_point = fetch_entry_point(store, &entry_name, &entry_file, root)?;
189
190 let (mut callee_chunks, _) = fetch_and_assemble(store, &callee_scores, root);
191 callee_chunks.sort_by(|a, b| {
193 a.depth
194 .cmp(&b.depth)
195 .then_with(|| a.file.cmp(&b.file))
196 .then_with(|| a.line_start.cmp(&b.line_start))
197 });
198 let call_chain: Vec<OnboardEntry> =
199 callee_chunks.into_iter().map(gathered_to_onboard).collect();
200
201 let (mut caller_chunks, _) = fetch_and_assemble(store, &caller_scores, root);
202 caller_chunks.sort_by(|a, b| b.score.total_cmp(&a.score));
204 let callers: Vec<OnboardEntry> = caller_chunks.into_iter().map(gathered_to_onboard).collect();
205
206 let key_types = match store.get_types_used_by(&entry_name) {
208 Ok(types) => filter_common_types(types),
209 Err(e) => {
210 tracing::warn!(error = %e, "Type dependency lookup failed, skipping key_types");
211 Vec::new()
212 }
213 };
214
215 let tests: Vec<TestEntry> = find_affected_tests_with_chunks(
217 &graph,
218 &test_chunks,
219 &entry_name,
220 DEFAULT_MAX_TEST_SEARCH_DEPTH,
221 )
222 .into_iter()
223 .map(test_info_to_entry)
224 .collect();
225
226 let mut all_files: std::collections::HashSet<&Path> = std::collections::HashSet::new();
228 all_files.insert(&entry_point.file);
229 for c in &call_chain {
230 all_files.insert(&c.file);
231 }
232 for c in &callers {
233 all_files.insert(&c.file);
234 }
235
236 let max_callee_depth = call_chain.iter().map(|c| c.depth).max().unwrap_or(0);
237
238 let summary = OnboardSummary {
239 total_items: 1 + call_chain.len() + callers.len() + key_types.len() + tests.len(),
240 files_covered: all_files.len(),
241 callee_depth: max_callee_depth,
242 tests_found: tests.len(),
243 };
244
245 tracing::info!(
246 callees = call_chain.len(),
247 callers = callers.len(),
248 types = key_types.len(),
249 tests = tests.len(),
250 "Onboard complete"
251 );
252
253 Ok(OnboardResult {
254 concept: concept.to_string(),
255 entry_point,
256 call_chain,
257 callers,
258 key_types,
259 tests,
260 summary,
261 })
262}
263
/// Converts an [`OnboardResult`] into a `serde_json::Value` using its
/// `Serialize` implementation.
///
/// # Errors
/// Returns the `serde_json::Error` produced by `serde_json::to_value`.
pub fn onboard_to_json(result: &OnboardResult) -> Result<serde_json::Value, serde_json::Error> {
    serde_json::to_value(result)
}
268
/// Keeps at most `max` entries of a score map, preferring the smallest keys.
///
/// Each value is a `(score, depth)` pair; `key_fn` maps it to the sort key,
/// and the `max` entries with the smallest keys survive. Wrap the key in
/// `std::cmp::Reverse` to keep the largest instead.
///
/// Entries whose keys compare equal are ordered by name, so the surviving set
/// is the same on every run and does not depend on `HashMap` iteration order.
///
/// Returns `scores` untouched when it already holds `max` entries or fewer.
fn cap_scores<F, K>(
    scores: HashMap<String, (f32, usize)>,
    max: usize,
    key_fn: F,
) -> HashMap<String, (f32, usize)>
where
    F: Fn(&(f32, usize)) -> K,
    K: Ord,
{
    if scores.len() <= max {
        return scores;
    }
    let mut entries: Vec<_> = scores.into_iter().collect();
    // Names are unique map keys, so (key, name) is a total order with no
    // equal elements: an unstable sort is safe and the outcome is deterministic.
    entries.sort_unstable_by(|a, b| {
        key_fn(&a.1)
            .cmp(&key_fn(&b.1))
            .then_with(|| a.0.cmp(&b.0))
    });
    entries.truncate(max);
    entries.into_iter().collect()
}
289
/// Returns whether `ct` counts as callable, by delegating to
/// `ChunkType::is_callable`. Used by `onboard` to prefer callable search hits
/// when choosing the entry point.
fn is_callable_type(ct: ChunkType) -> bool {
    ct.is_callable()
}
294
295fn gathered_to_onboard(c: GatheredChunk) -> OnboardEntry {
297 OnboardEntry {
298 name: c.name,
299 file: c.file,
300 line_start: c.line_start,
301 line_end: c.line_end,
302 language: c.language,
303 chunk_type: c.chunk_type,
304 signature: c.signature,
305 content: c.content,
306 depth: c.depth,
307 }
308}
309
310fn fetch_entry_point(
316 store: &Store,
317 entry_name: &str,
318 entry_file: &Path,
319 root: &Path,
320) -> Result<OnboardEntry, AnalysisError> {
321 let results = store.search_by_name(entry_name, 10)?;
322
323 let best = results
325 .iter()
326 .filter(|r| r.chunk.name == entry_name)
327 .max_by(|a, b| {
328 let a_file_match = a.chunk.file.ends_with(entry_file);
330 let b_file_match = b.chunk.file.ends_with(entry_file);
331 a_file_match
332 .cmp(&b_file_match)
333 .then_with(|| a.score.total_cmp(&b.score))
334 })
335 .or_else(|| {
336 results.iter().find(|r| r.chunk.file.ends_with(entry_file))
338 })
339 .or_else(|| {
340 results.first()
342 });
343
344 match best {
345 Some(r) => {
346 let rel_file = r
347 .chunk
348 .file
349 .strip_prefix(root)
350 .unwrap_or(&r.chunk.file)
351 .to_path_buf();
352 Ok(OnboardEntry {
353 name: r.chunk.name.clone(),
354 file: rel_file,
355 line_start: r.chunk.line_start,
356 line_end: r.chunk.line_end,
357 language: r.chunk.language,
358 chunk_type: r.chunk.chunk_type,
359 signature: r.chunk.signature.clone(),
360 content: r.chunk.content.clone(),
361 depth: 0,
362 })
363 }
364 None => Err(AnalysisError::NotFound(format!(
365 "Entry point '{entry_name}' not found in index"
366 ))),
367 }
368}
369
370fn filter_common_types(types: Vec<crate::store::TypeUsage>) -> Vec<TypeInfo> {
374 types
375 .into_iter()
376 .filter(|t| !crate::COMMON_TYPES.contains(t.type_name.as_str()))
377 .map(|t| TypeInfo {
378 type_name: t.type_name,
379 edge_kind: t
380 .edge_kind
381 .parse::<TypeEdgeKind>()
382 .unwrap_or(TypeEdgeKind::Param),
383 })
384 .collect()
385}
386
387fn test_info_to_entry(t: TestInfo) -> TestEntry {
389 TestEntry {
390 name: t.name,
391 file: t.file,
392 line: t.line,
393 call_depth: t.call_depth,
394 }
395}
396
/// Unit tests for the pure helpers of this module. Nothing here needs a
/// `Store` or an `Embedder`.
#[cfg(test)]
mod tests {
    use super::*;

    /// Standard-library container and string types are dropped.
    #[test]
    fn test_common_types_filtered() {
        use crate::store::TypeUsage;
        let types = vec![
            TypeUsage {
                type_name: "String".to_string(),
                edge_kind: "Param".to_string(),
            },
            TypeUsage {
                type_name: "Vec".to_string(),
                edge_kind: "Return".to_string(),
            },
            TypeUsage {
                type_name: "Store".to_string(),
                edge_kind: "Param".to_string(),
            },
            TypeUsage {
                type_name: "Option".to_string(),
                edge_kind: "Return".to_string(),
            },
        ];
        let filtered = filter_common_types(types);
        assert_eq!(filtered.len(), 1);
        assert_eq!(filtered[0].type_name, "Store");
    }

    /// The shared `COMMON_TYPES` set also covers `Error`, `Mutex` and `Debug`.
    #[test]
    fn test_common_types_canonical_set_filters_more() {
        use crate::store::TypeUsage;
        let types = vec![
            TypeUsage {
                type_name: "Error".to_string(),
                edge_kind: "Return".to_string(),
            },
            TypeUsage {
                type_name: "Mutex".to_string(),
                edge_kind: "Field".to_string(),
            },
            TypeUsage {
                type_name: "Debug".to_string(),
                edge_kind: "Bound".to_string(),
            },
            TypeUsage {
                type_name: "Store".to_string(),
                edge_kind: "Param".to_string(),
            },
        ];
        let filtered = filter_common_types(types);
        assert_eq!(filtered.len(), 1);
        assert_eq!(filtered[0].type_name, "Store");
    }

    /// Project-specific types pass through untouched.
    #[test]
    fn test_uncommon_types_kept() {
        use crate::store::TypeUsage;
        let types = vec![
            TypeUsage {
                type_name: "Embedder".to_string(),
                edge_kind: "Param".to_string(),
            },
            TypeUsage {
                type_name: "CallGraph".to_string(),
                edge_kind: "Field".to_string(),
            },
            TypeUsage {
                type_name: "SearchFilter".to_string(),
                edge_kind: "Param".to_string(),
            },
        ];
        let filtered = filter_common_types(types);
        assert_eq!(filtered.len(), 3);
    }

    /// Mirrors the callee sort used in `onboard`: shallower chunks come first.
    #[test]
    fn test_callee_ordering_by_depth() {
        use crate::parser::Language;

        let mut chunks = vec![
            GatheredChunk {
                name: "deep".into(),
                file: PathBuf::from("a.rs"),
                line_start: 1,
                line_end: 10,
                language: Language::Rust,
                chunk_type: ChunkType::Function,
                signature: String::new(),
                content: String::new(),
                score: 0.5,
                depth: 2,
                source: None,
            },
            GatheredChunk {
                name: "shallow".into(),
                file: PathBuf::from("b.rs"),
                line_start: 1,
                line_end: 10,
                language: Language::Rust,
                chunk_type: ChunkType::Function,
                signature: String::new(),
                content: String::new(),
                score: 0.3,
                depth: 1,
                source: None,
            },
        ];
        chunks.sort_by(|a, b| {
            a.depth
                .cmp(&b.depth)
                .then_with(|| a.file.cmp(&b.file))
                .then_with(|| a.line_start.cmp(&b.line_start))
        });
        assert_eq!(chunks[0].name, "shallow");
        assert_eq!(chunks[1].name, "deep");
    }

    /// Mirrors the step in `onboard` that removes the seed from the BFS scores.
    #[test]
    fn test_entry_point_excluded_from_call_chain() {
        let mut scores: HashMap<String, (f32, usize)> = HashMap::new();
        scores.insert("entry".into(), (1.0, 0));
        scores.insert("callee_a".into(), (0.7, 1));
        scores.insert("callee_b".into(), (0.5, 2));

        scores.remove("entry");
        assert_eq!(scores.len(), 2);
        assert!(!scores.contains_key("entry"));
    }

    /// Every `TestInfo` field is carried over to the `TestEntry`.
    #[test]
    fn test_test_info_to_entry() {
        let info = TestInfo {
            name: "test_foo".into(),
            file: PathBuf::from("tests/foo.rs"),
            line: 10,
            call_depth: 2,
        };
        let entry = test_info_to_entry(info);
        assert_eq!(entry.name, "test_foo");
        assert_eq!(entry.file, PathBuf::from("tests/foo.rs"));
        assert_eq!(entry.line, 10);
        assert_eq!(entry.call_depth, 2);
    }

    /// A map that already fits within the cap is returned as-is.
    #[test]
    fn test_cap_scores_under_limit_unchanged() {
        let mut scores: HashMap<String, (f32, usize)> = HashMap::new();
        scores.insert("a".into(), (0.7, 1));
        scores.insert("b".into(), (0.5, 2));

        let capped = cap_scores(scores.clone(), 2, |(_s, d)| *d);
        assert_eq!(capped, scores);
    }

    /// With a depth key, the deepest entries are the ones dropped.
    #[test]
    fn test_cap_scores_keeps_smallest_keys() {
        let mut scores: HashMap<String, (f32, usize)> = HashMap::new();
        scores.insert("near".into(), (0.7, 1));
        scores.insert("mid".into(), (0.5, 2));
        scores.insert("far".into(), (0.3, 3));

        let capped = cap_scores(scores, 2, |(_s, d)| *d);
        assert_eq!(capped.len(), 2);
        assert!(capped.contains_key("near"));
        assert!(capped.contains_key("mid"));
        assert!(!capped.contains_key("far"));
    }

    /// A cap of zero empties the map.
    #[test]
    fn test_cap_scores_zero_cap() {
        let mut scores: HashMap<String, (f32, usize)> = HashMap::new();
        scores.insert("a".into(), (0.7, 1));

        let capped = cap_scores(scores, 0, |(_s, d)| *d);
        assert!(capped.is_empty());
    }
}