1use crate::state::OciState;
7use crate::types::{InternedString, SymbolKind};
8use anyhow::{Context as _, Result};
9use std::collections::HashSet;
10use std::path::PathBuf;
11
/// A request for context around one position in a source file.
///
/// Build one with [`ContextQuery::new`] and adjust it with the `with_*`
/// methods, each of which consumes the query and returns the updated value.
#[derive(Debug, Clone)]
pub struct ContextQuery {
    /// Path of the file the query points into.
    pub file: PathBuf,
    /// Line of interest inside `file`.
    pub line: u32,
    /// Size of the window of lines requested around `line` (5 by default).
    pub surrounding_lines: u32,
    /// Optional caller-supplied intent string (`None` by default).
    pub intent: Option<String>,
    /// Token budget for the answer (4000 by default).
    pub max_tokens: usize,
}

impl ContextQuery {
    /// Creates a query for `line` of `file` using the defaults: a window of
    /// 5 surrounding lines, no intent, and a budget of 4000 tokens.
    pub fn new(file: PathBuf, line: u32) -> Self {
        let surrounding_lines = 5;
        let max_tokens = 4000;

        Self {
            file,
            line,
            surrounding_lines,
            intent: None,
            max_tokens,
        }
    }

    /// Returns the query with `surrounding_lines` replaced by `lines`.
    pub fn with_surrounding_lines(self, lines: u32) -> Self {
        Self {
            surrounding_lines: lines,
            ..self
        }
    }

    /// Returns the query with `intent` set to `Some(intent)`.
    pub fn with_intent(self, intent: String) -> Self {
        Self {
            intent: Some(intent),
            ..self
        }
    }

    /// Returns the query with `max_tokens` replaced.
    pub fn with_max_tokens(self, max_tokens: usize) -> Self {
        Self { max_tokens, ..self }
    }
}
61
62#[derive(Debug, Clone)]
64pub struct ContextResult {
65 pub primary: Vec<ContextChunk>,
67 pub related: Vec<ContextChunk>,
69 pub total_tokens: usize,
71}
72
73impl ContextResult {
74 pub fn empty() -> Self {
76 Self {
77 primary: Vec::new(),
78 related: Vec::new(),
79 total_tokens: 0,
80 }
81 }
82
83 pub fn all_chunks(&self) -> Vec<&ContextChunk> {
85 self.primary.iter().chain(self.related.iter()).collect()
86 }
87}
88
89#[derive(Debug, Clone)]
91pub struct ContextChunk {
92 pub symbol: Option<InternedString>,
94 pub file: PathBuf,
96 pub content: String,
98 pub relevance: f64,
100 pub reason: String,
102}
103
104impl ContextChunk {
105 pub fn estimate_tokens(&self) -> usize {
108 self.content.len() / 4
109 }
110}
111
/// Stateless entry point for assembling context; all working data is passed
/// in through the `state` argument of its methods.
pub struct ContextSynthesizer;
118
119impl ContextSynthesizer {
120 pub fn new() -> Self {
122 Self
123 }
124
125 pub async fn build_context(
127 &self,
128 state: &OciState,
129 query: &ContextQuery,
130 ) -> Result<ContextResult> {
131 let symbol_at_location = self.find_symbol_at_location(state, &query.file, query.line);
133
134 let mut candidates = Vec::new();
136
137 if let Some(current_symbol) = symbol_at_location {
138 candidates.push((current_symbol, 1.0, "Current location".to_string()));
140
141 let callees = state.find_callees(current_symbol);
143 for call_edge in callees {
144 if let Some(resolved) = self.resolve_callee(state, &call_edge.callee_name) {
146 candidates.push((
147 resolved,
148 0.8,
149 format!("Called by {}", state.resolve(current_symbol)),
150 ));
151 }
152 }
153
154 let current_name = state.resolve(current_symbol);
156 let callers = state.find_callers(current_name);
157 for call_edge in callers.iter().take(5) {
158 candidates.push((call_edge.caller, 0.6, format!("Calls {}", current_name)));
160 }
161
162 if let Some(symbol_def) = state.get_symbol(current_symbol) {
164 if let Some(sig) = &symbol_def.signature {
165 let types = self.extract_types_from_signature(sig);
167 for type_name in types {
168 if let Some(type_symbol) = self.find_type_symbol(state, &type_name) {
169 candidates.push((
170 type_symbol,
171 0.5,
172 format!("Type used in signature: {}", type_name),
173 ));
174 }
175 }
176 }
177
178 if let Some(parent) = symbol_def.parent {
180 candidates.push((parent, 0.7, format!("Parent of {}", current_name)));
181 }
182 }
183 }
184
185 let file_id = state.get_or_create_file_id(&query.file);
187 if let Some(imports) = state.imports.get(&file_id) {
188 for import in imports.iter().take(5) {
189 if let Some(import_symbol) = self.find_symbol_by_name(state, &import.name) {
192 candidates.push((import_symbol, 0.4, format!("Imported: {}", import.name)));
193 }
194 }
195 }
196
197 let ranked = self.rank_symbols_with_reasons(state, candidates);
199
200 let mut primary_chunks = Vec::new();
202 let mut related_chunks = Vec::new();
203 let mut total_tokens = 0;
204 let mut seen_symbols = HashSet::new();
205
206 if let Ok(location_chunk) = self
208 .create_location_chunk(state, &query.file, query.line, query.surrounding_lines)
209 .await
210 {
211 total_tokens += location_chunk.estimate_tokens();
212 primary_chunks.push(location_chunk);
213 }
214
215 for (symbol, score, reason) in ranked {
217 if total_tokens >= query.max_tokens {
218 break;
219 }
220
221 if seen_symbols.contains(&symbol) {
223 continue;
224 }
225 seen_symbols.insert(symbol);
226
227 if let Ok(chunk) = self.create_symbol_chunk(state, symbol, score, reason).await {
229 let chunk_tokens = chunk.estimate_tokens();
230
231 if total_tokens + chunk_tokens > query.max_tokens {
232 continue;
234 }
235
236 total_tokens += chunk_tokens;
237
238 if score >= 0.6 {
240 primary_chunks.push(chunk);
241 } else {
242 related_chunks.push(chunk);
243 }
244 }
245 }
246
247 Ok(ContextResult {
248 primary: primary_chunks,
249 related: related_chunks,
250 total_tokens,
251 })
252 }
253
254 pub fn rank_symbols(
256 &self,
257 state: &OciState,
258 symbols: &[InternedString],
259 ) -> Vec<(InternedString, f64)> {
260 let candidates: Vec<_> = symbols
261 .iter()
262 .map(|s| (*s, 1.0, "Candidate".to_string()))
263 .collect();
264
265 self.rank_symbols_with_reasons(state, candidates)
266 .into_iter()
267 .map(|(sym, score, _)| (sym, score))
268 .collect()
269 }
270
271 fn find_symbol_at_location(
277 &self,
278 state: &OciState,
279 file: &PathBuf,
280 line: u32,
281 ) -> Option<InternedString> {
282 let file_id = state.file_ids.get(file)?;
284 let file_symbols = state.file_symbols.get(&file_id)?;
285
286 for scoped_name in file_symbols.iter() {
288 if let Some(symbol) = state.get_symbol(*scoped_name) {
289 if symbol.location.start_line <= line as usize
290 && symbol.location.end_line >= line as usize
291 {
292 return Some(*scoped_name);
293 }
294 }
295 }
296
297 None
298 }
299
300 fn resolve_callee(&self, state: &OciState, callee_name: &str) -> Option<InternedString> {
302 let symbols = state.find_by_name(callee_name);
304 if symbols.is_empty() {
305 return None;
306 }
307
308 for symbol in &symbols {
310 if matches!(symbol.visibility, crate::types::Visibility::Public) {
311 return Some(symbol.scoped_name);
312 }
313 }
314
315 symbols.first().map(|s| s.scoped_name)
317 }
318
319 fn extract_types_from_signature(&self, sig: &crate::types::Signature) -> Vec<String> {
321 let mut types = Vec::new();
322
323 for param in &sig.params {
325 if let Some(type_name) = self.extract_type_name(param) {
326 types.push(type_name);
327 }
328 }
329
330 if let Some(ret_type) = &sig.return_type {
332 if let Some(type_name) = self.extract_type_name(ret_type) {
333 types.push(type_name);
334 }
335 }
336
337 types
338 }
339
340 fn extract_type_name(&self, type_str: &str) -> Option<String> {
342 let trimmed = type_str.trim();
345
346 let without_refs = trimmed.trim_start_matches('&').trim_start_matches("mut ");
348
349 let main_type = without_refs.split('<').next()?.split_whitespace().next()?;
351
352 if main_type.is_empty() || main_type.starts_with(char::is_lowercase) {
353 None
355 } else {
356 Some(main_type.to_string())
357 }
358 }
359
360 fn find_type_symbol(&self, state: &OciState, type_name: &str) -> Option<InternedString> {
362 let symbols = state.find_by_name(type_name);
363
364 for symbol in symbols {
365 if matches!(
366 symbol.kind,
367 SymbolKind::Struct | SymbolKind::Enum | SymbolKind::Trait
368 ) {
369 return Some(symbol.scoped_name);
370 }
371 }
372
373 None
374 }
375
376 fn find_symbol_by_name(&self, state: &OciState, name: &str) -> Option<InternedString> {
378 let symbols = state.find_by_name(name);
379 if symbols.is_empty() {
380 return None;
381 }
382
383 let scoped_names: Vec<_> = symbols.iter().map(|s| s.scoped_name).collect();
385 let ranked = self.rank_symbols(state, &scoped_names);
386
387 ranked.first().map(|(sym, _)| *sym)
388 }
389
390 fn rank_symbols_with_reasons(
392 &self,
393 state: &OciState,
394 candidates: Vec<(InternedString, f64, String)>,
395 ) -> Vec<(InternedString, f64, String)> {
396 let mut scored: Vec<_> = candidates
397 .into_iter()
398 .map(|(symbol, base_score, reason)| {
399 let pagerank_score = self.get_pagerank_score(state, symbol);
400 let combined_score = base_score * 0.7 + pagerank_score * 0.3;
401 (symbol, combined_score, reason)
402 })
403 .collect();
404
405 scored.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
407
408 scored
409 }
410
411 fn get_pagerank_score(&self, state: &OciState, symbol: InternedString) -> f64 {
413 let symbol_def = match state.get_symbol(symbol) {
415 Some(s) => s,
416 None => return 0.0,
417 };
418
419 let node_idx = match state.path_to_node.get(&symbol_def.location.file) {
421 Some(idx) => *idx,
422 None => return 0.0,
423 };
424
425 state
427 .topology_metrics
428 .get(&node_idx)
429 .map(|m| m.relevance_score)
430 .unwrap_or(0.0)
431 }
432
433 async fn create_location_chunk(
435 &self,
436 state: &OciState,
437 file: &PathBuf,
438 line: u32,
439 surrounding_lines: u32,
440 ) -> Result<ContextChunk> {
441 let contents = state
442 .get_file_contents(file)
443 .await
444 .context("Failed to read file")?;
445
446 let lines: Vec<&str> = contents.lines().collect();
447 let total_lines = lines.len();
448
449 let start_line = (line as i32 - surrounding_lines as i32).max(0) as usize;
450 let end_line = ((line + surrounding_lines) as usize).min(total_lines);
451
452 let content = lines[start_line..end_line].join("\n");
453
454 Ok(ContextChunk {
455 symbol: None,
456 file: file.clone(),
457 content,
458 relevance: 1.0,
459 reason: format!("Query location at line {}", line),
460 })
461 }
462
463 async fn create_symbol_chunk(
465 &self,
466 state: &OciState,
467 symbol: InternedString,
468 relevance: f64,
469 reason: String,
470 ) -> Result<ContextChunk> {
471 let symbol_def = state
472 .get_symbol(symbol)
473 .context("Symbol not found in state")?;
474
475 let contents = state
476 .get_file_contents(&symbol_def.location.file)
477 .await
478 .context("Failed to read file")?;
479
480 let lines: Vec<&str> = contents.lines().collect();
481
482 let start_line = symbol_def.location.start_line.saturating_sub(1);
484 let end_line = symbol_def.location.end_line;
485
486 if start_line >= lines.len() || end_line > lines.len() {
487 anyhow::bail!("Symbol location out of bounds");
488 }
489
490 let content = lines[start_line..end_line].join("\n");
491
492 Ok(ContextChunk {
493 symbol: Some(symbol),
494 file: symbol_def.location.file.clone(),
495 content,
496 relevance,
497 reason,
498 })
499 }
500}
501
502impl Default for ContextSynthesizer {
503 fn default() -> Self {
504 Self::new()
505 }
506}
507
#[cfg(test)]
mod tests {
    use super::*;
    use crate::state::create_state;
    use tempfile::TempDir;

    /// Building context against a freshly created state must succeed.
    #[tokio::test]
    async fn test_build_empty_context() {
        let temp = TempDir::new().unwrap();
        let state = create_state(temp.path().to_path_buf());
        let synthesizer = ContextSynthesizer::new();

        let test_file = temp.path().join("test.rs");
        std::fs::write(&test_file, "fn main() {}").unwrap();

        let query = ContextQuery::new(test_file, 1);
        let result = synthesizer.build_context(&state, &query).await;

        assert!(result.is_ok());
    }

    /// Ranking an empty list yields an empty list.
    #[test]
    fn test_rank_symbols_empty() {
        let temp = TempDir::new().unwrap();
        let state = create_state(temp.path().to_path_buf());
        let synthesizer = ContextSynthesizer::new();

        let ranked = synthesizer.rank_symbols(&state, &[]);
        assert!(ranked.is_empty());
    }

    /// Generic arguments and reference markers are stripped; lowercase-led
    /// names are rejected.
    #[test]
    fn test_extract_type_name() {
        let synthesizer = ContextSynthesizer::new();

        assert_eq!(
            synthesizer.extract_type_name("Vec<String>"),
            Some("Vec".to_string())
        );
        assert_eq!(
            synthesizer.extract_type_name("&mut Foo"),
            Some("Foo".to_string())
        );
        assert_eq!(
            synthesizer.extract_type_name("&Bar"),
            Some("Bar".to_string())
        );
        assert_eq!(synthesizer.extract_type_name("i32"), None);
    }

    /// 400 bytes of content are estimated at 100 tokens.
    // Nothing here awaits, so a plain synchronous test is enough.
    #[test]
    fn test_context_chunk_token_estimation() {
        let chunk = ContextChunk {
            symbol: None,
            file: PathBuf::from("test.rs"),
            content: "a".repeat(400),
            relevance: 1.0,
            reason: "Test".to_string(),
        };

        assert_eq!(chunk.estimate_tokens(), 100);
    }

    /// Each builder method overrides exactly the field it names.
    // Nothing here awaits, so a plain synchronous test is enough.
    #[test]
    fn test_context_query_builder() {
        let query = ContextQuery::new(PathBuf::from("test.rs"), 10)
            .with_surrounding_lines(3)
            .with_max_tokens(2000)
            .with_intent("Testing".to_string());

        assert_eq!(query.line, 10);
        assert_eq!(query.surrounding_lines, 3);
        assert_eq!(query.max_tokens, 2000);
        assert_eq!(query.intent, Some("Testing".to_string()));
    }

    /// `all_chunks` lists primary chunks before related ones.
    #[test]
    fn test_context_result_all_chunks() {
        let mut result = ContextResult::empty();

        result.primary.push(ContextChunk {
            symbol: None,
            file: PathBuf::from("a.rs"),
            content: "primary".to_string(),
            relevance: 1.0,
            reason: "test".to_string(),
        });

        result.related.push(ContextChunk {
            symbol: None,
            file: PathBuf::from("b.rs"),
            content: "related".to_string(),
            relevance: 0.5,
            reason: "test".to_string(),
        });

        let all = result.all_chunks();
        assert_eq!(all.len(), 2);
        assert_eq!(all[0].content, "primary");
        assert_eq!(all[1].content, "related");
    }
}