1use std::collections::HashMap;
12use std::path::{Path, PathBuf};
13
14use serde::Serialize;
15
16use crate::gather::{
17 bfs_expand, fetch_and_assemble, GatherDirection, GatherOptions, GatheredChunk,
18};
19use crate::impact::{find_affected_tests_with_chunks, TestInfo, DEFAULT_MAX_TEST_SEARCH_DEPTH};
20use crate::language::ChunkType;
21use crate::parser::TypeEdgeKind;
22use crate::store::Store;
23use crate::{AnalysisError, Embedder};
24
/// Default expansion depth for onboarding.
///
/// NOTE(review): this constant is not referenced in the visible part of this
/// file; presumably callers pass it as the `depth` argument of [`onboard`] —
/// confirm at the call sites.
pub const DEFAULT_ONBOARD_DEPTH: usize = 3;

/// Upper bound on how many callee names [`onboard`] keeps after the callee
/// expansion, before their chunks are fetched. When more are found, the ones
/// with the smallest BFS depth are kept.
const MAX_CALLEE_FETCH: usize = 30;

/// Upper bound on how many caller names [`onboard`] keeps after the caller
/// expansion, before their chunks are fetched. When more are found, the ones
/// with the highest score are kept.
const MAX_CALLER_FETCH: usize = 15;
34
35fn serialize_path_forward_slash<S>(path: &std::path::Path, serializer: S) -> Result<S::Ok, S::Error>
36where
37 S: serde::Serializer,
38{
39 serializer.serialize_str(&path.to_string_lossy().replace('\\', "/"))
40}
41
/// Everything [`onboard`] gathers for one concept: the chosen entry point,
/// the code around it, and summary counts.
#[derive(Debug, Clone, Serialize)]
pub struct OnboardResult {
    /// The concept string the caller asked about, copied verbatim.
    pub concept: String,
    /// The chunk selected as the starting point (its `depth` is 0).
    pub entry_point: OnboardEntry,
    /// Callees gathered from the entry point, ordered by depth, then file,
    /// then start line. The entry point itself is not included.
    pub call_chain: Vec<OnboardEntry>,
    /// Callers gathered by a depth-1 caller expansion, highest score first.
    /// The entry point itself is not included.
    pub callers: Vec<OnboardEntry>,
    /// Types used by the entry point, with common types filtered out.
    pub key_types: Vec<TypeInfo>,
    /// Tests reported as affected by the entry point.
    pub tests: Vec<TestEntry>,
    /// Aggregate counts over the fields above.
    pub summary: OnboardSummary,
}
55
/// One code chunk in an onboarding result (entry point, callee, or caller).
#[derive(Debug, Clone, Serialize)]
pub struct OnboardEntry {
    /// Name of the chunk.
    pub name: String,
    /// File containing the chunk; serialized with forward slashes. For the
    /// entry point this is relative to the root passed to [`onboard`] when the
    /// file lies under it.
    #[serde(serialize_with = "serialize_path_forward_slash")]
    pub file: PathBuf,
    /// First line of the chunk.
    pub line_start: u32,
    /// Last line of the chunk.
    pub line_end: u32,
    /// Language of the chunk, in its string form.
    pub language: String,
    /// Chunk type, in its string form.
    pub chunk_type: String,
    /// Signature of the chunk.
    pub signature: String,
    /// Source text of the chunk.
    pub content: String,
    /// Expansion depth: 0 for the entry point, otherwise the depth recorded on
    /// the gathered chunk.
    pub depth: usize,
}
70
/// A type used by the entry point, together with how it is used.
#[derive(Debug, Clone, Serialize)]
pub struct TypeInfo {
    /// Name of the type.
    pub type_name: String,
    /// Kind of usage edge, parsed from the store's string form.
    pub edge_kind: TypeEdgeKind,
}
77
/// A test reported as affected by the entry point.
#[derive(Debug, Clone, Serialize)]
pub struct TestEntry {
    /// Name of the test.
    pub name: String,
    /// File containing the test; serialized with forward slashes.
    #[serde(serialize_with = "serialize_path_forward_slash")]
    pub file: PathBuf,
    /// Line of the test, copied from the corresponding `TestInfo`.
    pub line: u32,
    /// Call depth copied from the corresponding `TestInfo`.
    // NOTE(review): presumably the number of call hops between the test and
    // the entry point — confirm against `find_affected_tests_with_chunks`.
    pub call_depth: usize,
}
87
/// Aggregate counts describing an [`OnboardResult`].
#[derive(Debug, Clone, Serialize)]
pub struct OnboardSummary {
    /// Entry point (1) plus the number of callees, callers, key types and tests.
    pub total_items: usize,
    /// Number of distinct files among the entry point, callees and callers
    /// (test files are not counted).
    pub files_covered: usize,
    /// Largest depth found among the callees, or 0 when there are none.
    pub callee_depth: usize,
    /// Number of tests in the result.
    pub tests_found: usize,
}
96
97pub fn onboard(
101 store: &Store,
102 embedder: &Embedder,
103 concept: &str,
104 root: &Path,
105 depth: usize,
106) -> Result<OnboardResult, AnalysisError> {
107 let _span = tracing::info_span!("onboard", concept).entered();
108 let depth = depth.min(10);
109
110 let query_embedding = embedder
112 .embed_query(concept)
113 .map_err(|e| AnalysisError::Embedder(e.to_string()))?;
114 let filter = crate::store::SearchFilter::default();
115 let results = store.search_filtered(&query_embedding, &filter, 10, 0.0)?;
116
117 if results.is_empty() {
118 return Err(AnalysisError::NotFound(format!(
119 "No relevant code found for concept: {concept}"
120 )));
121 }
122
123 let entry = results
125 .iter()
126 .find(|r| is_callable_type(r.chunk.chunk_type))
127 .or(results.first())
128 .unwrap();
129 let entry_name = entry.chunk.name.clone();
130 let entry_file = entry
131 .chunk
132 .file
133 .strip_prefix(root)
134 .unwrap_or(&entry.chunk.file)
135 .to_path_buf();
136 tracing::info!(entry_point = %entry_name, file = ?entry_file, "Selected entry point");
137
138 let graph = store.get_call_graph()?;
140
141 let test_chunks = match store.find_test_chunks() {
142 Ok(tc) => tc,
143 Err(e) => {
144 tracing::warn!(error = %e, "Test chunk loading failed, skipping tests");
145 Vec::new()
146 }
147 };
148
149 let mut callee_scores: HashMap<String, (f32, usize)> = HashMap::new();
151 callee_scores.insert(entry_name.clone(), (1.0, 0));
152 let callee_opts = GatherOptions::default()
153 .with_expand_depth(depth)
154 .with_direction(GatherDirection::Callees)
155 .with_decay_factor(0.7)
156 .with_max_expanded_nodes(100);
157 let _callee_capped = bfs_expand(&mut callee_scores, &graph, &callee_opts);
158
159 callee_scores.remove(&entry_name);
161 tracing::debug!(callee_count = callee_scores.len(), "Callee BFS complete");
162
163 let mut caller_scores: HashMap<String, (f32, usize)> = HashMap::new();
165 caller_scores.insert(entry_name.clone(), (1.0, 0));
166 let caller_opts = GatherOptions::default()
167 .with_expand_depth(1)
168 .with_direction(GatherDirection::Callers)
169 .with_decay_factor(0.8)
170 .with_max_expanded_nodes(50);
171 let _caller_capped = bfs_expand(&mut caller_scores, &graph, &caller_opts);
172
173 caller_scores.remove(&entry_name);
175 tracing::debug!(caller_count = caller_scores.len(), "Caller BFS complete");
176
177 let callee_scores = cap_scores(callee_scores, MAX_CALLEE_FETCH, |(_s, d)| *d);
180 let caller_scores = cap_scores(caller_scores, MAX_CALLER_FETCH, |(score, _)| {
181 u64::MAX - ((*score * 1e6) as u64)
183 });
184
185 let entry_point = fetch_entry_point(store, &entry_name, &entry_file, root)?;
189
190 let (mut callee_chunks, _) = fetch_and_assemble(store, &callee_scores, root);
191 callee_chunks.sort_by(|a, b| {
193 a.depth
194 .cmp(&b.depth)
195 .then_with(|| a.file.cmp(&b.file))
196 .then_with(|| a.line_start.cmp(&b.line_start))
197 });
198 let call_chain: Vec<OnboardEntry> =
199 callee_chunks.into_iter().map(gathered_to_onboard).collect();
200
201 let (mut caller_chunks, _) = fetch_and_assemble(store, &caller_scores, root);
202 caller_chunks.sort_by(|a, b| {
204 b.score
205 .partial_cmp(&a.score)
206 .unwrap_or(std::cmp::Ordering::Equal)
207 });
208 let callers: Vec<OnboardEntry> = caller_chunks.into_iter().map(gathered_to_onboard).collect();
209
210 let key_types = match store.get_types_used_by(&entry_name) {
212 Ok(types) => filter_common_types(types),
213 Err(e) => {
214 tracing::warn!(error = %e, "Type dependency lookup failed, skipping key_types");
215 Vec::new()
216 }
217 };
218
219 let tests: Vec<TestEntry> = find_affected_tests_with_chunks(
221 &graph,
222 &test_chunks,
223 &entry_name,
224 DEFAULT_MAX_TEST_SEARCH_DEPTH,
225 )
226 .into_iter()
227 .map(test_info_to_entry)
228 .collect();
229
230 let mut all_files: std::collections::HashSet<&Path> = std::collections::HashSet::new();
232 all_files.insert(&entry_point.file);
233 for c in &call_chain {
234 all_files.insert(&c.file);
235 }
236 for c in &callers {
237 all_files.insert(&c.file);
238 }
239
240 let max_callee_depth = call_chain.iter().map(|c| c.depth).max().unwrap_or(0);
241
242 let summary = OnboardSummary {
243 total_items: 1 + call_chain.len() + callers.len() + key_types.len() + tests.len(),
244 files_covered: all_files.len(),
245 callee_depth: max_callee_depth,
246 tests_found: tests.len(),
247 };
248
249 tracing::info!(
250 callees = call_chain.len(),
251 callers = callers.len(),
252 types = key_types.len(),
253 tests = tests.len(),
254 "Onboard complete"
255 );
256
257 Ok(OnboardResult {
258 concept: concept.to_string(),
259 entry_point,
260 call_chain,
261 callers,
262 key_types,
263 tests,
264 summary,
265 })
266}
267
/// Convert an [`OnboardResult`] into a `serde_json::Value`.
///
/// # Errors
///
/// Returns the `serde_json::Error` produced by `serde_json::to_value` when
/// serialization fails.
pub fn onboard_to_json(result: &OnboardResult) -> Result<serde_json::Value, serde_json::Error> {
    serde_json::to_value(result)
}
272
/// Keep at most `max` entries of `scores`, preferring the smallest keys.
///
/// `key_fn` maps each `(score, depth)` value to a sort key; entries are
/// ranked in ascending key order and everything past the first `max` is
/// dropped. When the map already holds `max` entries or fewer it is returned
/// unchanged.
///
/// Entries with equal keys are ranked by name. Without that tie-break the
/// survivors at the cut-off would depend on `HashMap` iteration order, which
/// is randomized, so the same input could produce different results from run
/// to run.
fn cap_scores<F, K>(
    scores: HashMap<String, (f32, usize)>,
    max: usize,
    key_fn: F,
) -> HashMap<String, (f32, usize)>
where
    F: Fn(&(f32, usize)) -> K,
    K: Ord,
{
    if scores.len() <= max {
        return scores;
    }
    let mut entries: Vec<_> = scores.into_iter().collect();
    // Names are unique map keys, so (key, name) is a total order and an
    // unstable sort is deterministic here.
    entries.sort_unstable_by(|a, b| {
        key_fn(&a.1)
            .cmp(&key_fn(&b.1))
            .then_with(|| a.0.cmp(&b.0))
    });
    entries.truncate(max);
    entries.into_iter().collect()
}
293
/// Report whether a chunk of type `ct` is callable, by delegating to
/// `ChunkType::is_callable`. Used by [`onboard`] to prefer callable search
/// hits when choosing the entry point.
fn is_callable_type(ct: ChunkType) -> bool {
    ct.is_callable()
}
298
299fn gathered_to_onboard(c: GatheredChunk) -> OnboardEntry {
301 OnboardEntry {
302 name: c.name,
303 file: c.file,
304 line_start: c.line_start,
305 line_end: c.line_end,
306 language: c.language.to_string(),
307 chunk_type: c.chunk_type.to_string(),
308 signature: c.signature,
309 content: c.content,
310 depth: c.depth,
311 }
312}
313
314fn fetch_entry_point(
320 store: &Store,
321 entry_name: &str,
322 entry_file: &Path,
323 root: &Path,
324) -> Result<OnboardEntry, AnalysisError> {
325 let results = store.search_by_name(entry_name, 10)?;
326
327 let best = results
329 .iter()
330 .filter(|r| r.chunk.name == entry_name)
331 .max_by(|a, b| {
332 let a_file_match = a.chunk.file.ends_with(entry_file);
334 let b_file_match = b.chunk.file.ends_with(entry_file);
335 a_file_match.cmp(&b_file_match).then_with(|| {
336 a.score
337 .partial_cmp(&b.score)
338 .unwrap_or(std::cmp::Ordering::Equal)
339 })
340 })
341 .or_else(|| {
342 results.iter().find(|r| r.chunk.file.ends_with(entry_file))
344 })
345 .or_else(|| {
346 results.first()
348 });
349
350 match best {
351 Some(r) => {
352 let rel_file = r
353 .chunk
354 .file
355 .strip_prefix(root)
356 .unwrap_or(&r.chunk.file)
357 .to_path_buf();
358 Ok(OnboardEntry {
359 name: r.chunk.name.clone(),
360 file: rel_file,
361 line_start: r.chunk.line_start,
362 line_end: r.chunk.line_end,
363 language: r.chunk.language.to_string(),
364 chunk_type: r.chunk.chunk_type.to_string(),
365 signature: r.chunk.signature.clone(),
366 content: r.chunk.content.clone(),
367 depth: 0,
368 })
369 }
370 None => Err(AnalysisError::NotFound(format!(
371 "Entry point '{entry_name}' not found in index"
372 ))),
373 }
374}
375
376fn filter_common_types(types: Vec<crate::store::TypeUsage>) -> Vec<TypeInfo> {
380 types
381 .into_iter()
382 .filter(|t| !crate::COMMON_TYPES.contains(t.type_name.as_str()))
383 .map(|t| TypeInfo {
384 type_name: t.type_name,
385 edge_kind: t
386 .edge_kind
387 .parse::<TypeEdgeKind>()
388 .unwrap_or(TypeEdgeKind::Param),
389 })
390 .collect()
391}
392
393fn test_info_to_entry(t: TestInfo) -> TestEntry {
395 TestEntry {
396 name: t.name,
397 file: t.file,
398 line: t.line,
399 call_depth: t.call_depth,
400 }
401}
402
#[cfg(test)]
mod tests {
    use super::*;
    use crate::parser::Language;
    use crate::store::TypeUsage;

    /// Build a `TypeUsage` fixture from string slices.
    fn usage(type_name: &str, edge_kind: &str) -> TypeUsage {
        TypeUsage {
            type_name: type_name.to_string(),
            edge_kind: edge_kind.to_string(),
        }
    }

    /// Build a Rust function chunk spanning lines 1-10 with score 0.5.
    fn function_chunk(name: &str, file: &str, depth: usize) -> GatheredChunk {
        GatheredChunk {
            name: name.into(),
            file: PathBuf::from(file),
            line_start: 1,
            line_end: 10,
            language: Language::Rust,
            chunk_type: ChunkType::Function,
            signature: String::new(),
            content: String::new(),
            score: 0.5,
            depth,
            source: None,
        }
    }

    /// Only the project type survives when mixed with std container types.
    #[test]
    fn test_common_types_filtered() {
        let filtered = filter_common_types(vec![
            usage("String", "Param"),
            usage("Vec", "Return"),
            usage("Store", "Param"),
            usage("Option", "Return"),
        ]);
        assert_eq!(filtered.len(), 1);
        assert_eq!(filtered[0].type_name, "Store");
    }

    /// The shared common-type set also covers `Error`, `Mutex` and `Debug`.
    #[test]
    fn test_common_types_canonical_set_filters_more() {
        let filtered = filter_common_types(vec![
            usage("Error", "Return"),
            usage("Mutex", "Field"),
            usage("Debug", "Bound"),
            usage("Store", "Param"),
        ]);
        assert_eq!(filtered.len(), 1);
        assert_eq!(filtered[0].type_name, "Store");
    }

    /// Project-specific types are all kept.
    #[test]
    fn test_uncommon_types_kept() {
        let filtered = filter_common_types(vec![
            usage("Embedder", "Param"),
            usage("CallGraph", "Field"),
            usage("SearchFilter", "Param"),
        ]);
        assert_eq!(filtered.len(), 3);
    }

    /// Sorting by (depth, file, start line) puts shallower chunks first.
    #[test]
    fn test_callee_ordering_by_depth() {
        let deep = function_chunk("deep", "a.rs", 2);
        let mut shallow = function_chunk("shallow", "b.rs", 1);
        shallow.score = 0.3;

        let mut chunks = vec![deep, shallow];
        chunks.sort_by(|a, b| {
            (a.depth, &a.file, a.line_start).cmp(&(b.depth, &b.file, b.line_start))
        });

        assert_eq!(chunks[0].name, "shallow");
        assert_eq!(chunks[1].name, "deep");
    }

    /// Removing the seed name leaves only the callees in the score map.
    #[test]
    fn test_entry_point_excluded_from_call_chain() {
        let mut scores: HashMap<String, (f32, usize)> = vec![
            ("entry", (1.0, 0)),
            ("callee_a", (0.7, 1)),
            ("callee_b", (0.5, 2)),
        ]
        .into_iter()
        .map(|(name, value)| (name.to_string(), value))
        .collect();

        scores.remove("entry");
        assert_eq!(scores.len(), 2);
        assert!(!scores.contains_key("entry"));
    }

    /// Name and call depth survive the `TestInfo` -> `TestEntry` conversion.
    #[test]
    fn test_test_info_to_entry() {
        let entry = test_info_to_entry(TestInfo {
            name: "test_foo".into(),
            file: PathBuf::from("tests/foo.rs"),
            line: 10,
            call_depth: 2,
        });
        assert_eq!(entry.name, "test_foo");
        assert_eq!(entry.call_depth, 2);
    }
}