Skip to main content

omni_index/context/
mod.rs

1//! Context synthesis ("Ghost Docs").
2//!
3//! Auto-generates architectural context documents by intelligently assembling
4//! relevant code snippets based on call graphs, type relationships, and PageRank scores.
5
6use crate::state::OciState;
7use crate::types::{InternedString, SymbolKind};
8use anyhow::{Context as _, Result};
9use std::collections::HashSet;
10use std::path::PathBuf;
11
12// ============================================================================
13// Public Types
14// ============================================================================
15
/// Describes what context to assemble: a position in a file plus size limits.
#[derive(Debug, Clone)]
pub struct ContextQuery {
    /// Path of the file the query points into
    pub file: PathBuf,
    /// Line within `file` that the query is centered on
    pub line: u32,
    /// How many lines of source around `line` should be included
    pub surrounding_lines: u32,
    /// Free-form description of what the user is trying to do, if supplied
    pub intent: Option<String>,
    /// Upper bound on the estimated token count of the assembled result
    pub max_tokens: usize,
}

impl ContextQuery {
    /// Surrounding-line count used until the caller overrides it.
    const DEFAULT_SURROUNDING_LINES: u32 = 5;
    /// Token budget used until the caller overrides it.
    const DEFAULT_MAX_TOKENS: usize = 4000;

    /// Start a query for `line` of `file` with the defaults: 5 surrounding
    /// lines, no intent, and a budget of 4000 tokens.
    pub fn new(file: PathBuf, line: u32) -> Self {
        Self {
            intent: None,
            surrounding_lines: Self::DEFAULT_SURROUNDING_LINES,
            max_tokens: Self::DEFAULT_MAX_TOKENS,
            line,
            file,
        }
    }

    /// Builder step: replace the number of surrounding lines.
    pub fn with_surrounding_lines(self, lines: u32) -> Self {
        Self {
            surrounding_lines: lines,
            ..self
        }
    }

    /// Builder step: attach an intent description.
    pub fn with_intent(self, intent: String) -> Self {
        Self {
            intent: Some(intent),
            ..self
        }
    }

    /// Builder step: replace the maximum token budget.
    pub fn with_max_tokens(self, max_tokens: usize) -> Self {
        Self { max_tokens, ..self }
    }
}
61
62/// Result of context assembly.
63#[derive(Debug, Clone)]
64pub struct ContextResult {
65    /// Primary chunks - most relevant to the query
66    pub primary: Vec<ContextChunk>,
67    /// Related chunks - less relevant but still useful
68    pub related: Vec<ContextChunk>,
69    /// Total estimated tokens in the result
70    pub total_tokens: usize,
71}
72
73impl ContextResult {
74    /// Create an empty context result.
75    pub fn empty() -> Self {
76        Self {
77            primary: Vec::new(),
78            related: Vec::new(),
79            total_tokens: 0,
80        }
81    }
82
83    /// Get all chunks in priority order (primary first, then related).
84    pub fn all_chunks(&self) -> Vec<&ContextChunk> {
85        self.primary.iter().chain(self.related.iter()).collect()
86    }
87}
88
89/// A chunk of context with metadata.
90#[derive(Debug, Clone)]
91pub struct ContextChunk {
92    /// Optional symbol this chunk represents
93    pub symbol: Option<InternedString>,
94    /// File containing this chunk
95    pub file: PathBuf,
96    /// The actual content
97    pub content: String,
98    /// Relevance score (0.0 - 1.0)
99    pub relevance: f64,
100    /// Explanation of why this was included
101    pub reason: String,
102}
103
104impl ContextChunk {
105    /// Estimate the number of tokens in this chunk.
106    /// Uses a simple heuristic: ~4 chars per token.
107    pub fn estimate_tokens(&self) -> usize {
108        self.content.len() / 4
109    }
110}
111
112// ============================================================================
113// Context Synthesizer
114// ============================================================================
115
/// Synthesizes relevant context from the code index.
///
/// The type holds no data of its own; all index data is passed to its
/// methods as an `OciState` argument. It derives `Debug` and `Clone` like the
/// other public types in this module.
#[derive(Debug, Clone)]
pub struct ContextSynthesizer;
118
119impl ContextSynthesizer {
120    /// Create a new context synthesizer.
121    pub fn new() -> Self {
122        Self
123    }
124
125    /// Build context for a given query.
126    pub async fn build_context(
127        &self,
128        state: &OciState,
129        query: &ContextQuery,
130    ) -> Result<ContextResult> {
131        // Step 1: Find the symbol at the query location
132        let symbol_at_location = self.find_symbol_at_location(state, &query.file, query.line);
133
134        // Step 2: Collect candidate symbols
135        let mut candidates = Vec::new();
136
137        if let Some(current_symbol) = symbol_at_location {
138            // Add the current symbol with highest priority
139            candidates.push((current_symbol, 1.0, "Current location".to_string()));
140
141            // Find callees (functions this symbol calls)
142            let callees = state.find_callees(current_symbol);
143            for call_edge in callees {
144                // Try to resolve the callee to a scoped name
145                if let Some(resolved) = self.resolve_callee(state, &call_edge.callee_name) {
146                    candidates.push((
147                        resolved,
148                        0.8,
149                        format!("Called by {}", state.resolve(current_symbol)),
150                    ));
151                }
152            }
153
154            // Find callers (functions that call this symbol)
155            let current_name = state.resolve(current_symbol);
156            let callers = state.find_callers(current_name);
157            for call_edge in callers.iter().take(5) {
158                // Limit callers to avoid explosion
159                candidates.push((call_edge.caller, 0.6, format!("Calls {}", current_name)));
160            }
161
162            // Find related types (from signatures)
163            if let Some(symbol_def) = state.get_symbol(current_symbol) {
164                if let Some(sig) = &symbol_def.signature {
165                    // Extract types from parameters and return type
166                    let types = self.extract_types_from_signature(sig);
167                    for type_name in types {
168                        if let Some(type_symbol) = self.find_type_symbol(state, &type_name) {
169                            candidates.push((
170                                type_symbol,
171                                0.5,
172                                format!("Type used in signature: {}", type_name),
173                            ));
174                        }
175                    }
176                }
177
178                // Find parent symbol (e.g., impl block for methods)
179                if let Some(parent) = symbol_def.parent {
180                    candidates.push((parent, 0.7, format!("Parent of {}", current_name)));
181                }
182            }
183        }
184
185        // Find relevant imports
186        let file_id = state.get_or_create_file_id(&query.file);
187        if let Some(imports) = state.imports.get(&file_id) {
188            for import in imports.iter().take(5) {
189                // Limit imports
190                // Try to find symbols matching the imported names
191                if let Some(import_symbol) = self.find_symbol_by_name(state, &import.name) {
192                    candidates.push((import_symbol, 0.4, format!("Imported: {}", import.name)));
193                }
194            }
195        }
196
197        // Step 3: Rank all candidates
198        let ranked = self.rank_symbols_with_reasons(state, candidates);
199
200        // Step 4: Build chunks with token budget
201        let mut primary_chunks = Vec::new();
202        let mut related_chunks = Vec::new();
203        let mut total_tokens = 0;
204        let mut seen_symbols = HashSet::new();
205
206        // First, add the query location itself
207        if let Ok(location_chunk) = self
208            .create_location_chunk(state, &query.file, query.line, query.surrounding_lines)
209            .await
210        {
211            total_tokens += location_chunk.estimate_tokens();
212            primary_chunks.push(location_chunk);
213        }
214
215        // Then add ranked symbols
216        for (symbol, score, reason) in ranked {
217            if total_tokens >= query.max_tokens {
218                break;
219            }
220
221            // Avoid duplicates
222            if seen_symbols.contains(&symbol) {
223                continue;
224            }
225            seen_symbols.insert(symbol);
226
227            // Create chunk for this symbol
228            if let Ok(chunk) = self.create_symbol_chunk(state, symbol, score, reason).await {
229                let chunk_tokens = chunk.estimate_tokens();
230
231                if total_tokens + chunk_tokens > query.max_tokens {
232                    // Would exceed budget - skip
233                    continue;
234                }
235
236                total_tokens += chunk_tokens;
237
238                // Categorize as primary or related based on score
239                if score >= 0.6 {
240                    primary_chunks.push(chunk);
241                } else {
242                    related_chunks.push(chunk);
243                }
244            }
245        }
246
247        Ok(ContextResult {
248            primary: primary_chunks,
249            related: related_chunks,
250            total_tokens,
251        })
252    }
253
254    /// Rank symbols by relevance, returning (symbol, score) pairs.
255    pub fn rank_symbols(
256        &self,
257        state: &OciState,
258        symbols: &[InternedString],
259    ) -> Vec<(InternedString, f64)> {
260        let candidates: Vec<_> = symbols
261            .iter()
262            .map(|s| (*s, 1.0, "Candidate".to_string()))
263            .collect();
264
265        self.rank_symbols_with_reasons(state, candidates)
266            .into_iter()
267            .map(|(sym, score, _)| (sym, score))
268            .collect()
269    }
270
271    // ========================================================================
272    // Private Helper Methods
273    // ========================================================================
274
275    /// Find the symbol at a given file location.
276    fn find_symbol_at_location(
277        &self,
278        state: &OciState,
279        file: &PathBuf,
280        line: u32,
281    ) -> Option<InternedString> {
282        // Get symbols in this file
283        let file_id = state.file_ids.get(file)?;
284        let file_symbols = state.file_symbols.get(&file_id)?;
285
286        // Find symbol containing this line
287        for scoped_name in file_symbols.iter() {
288            if let Some(symbol) = state.get_symbol(*scoped_name) {
289                if symbol.location.start_line <= line as usize
290                    && symbol.location.end_line >= line as usize
291                {
292                    return Some(*scoped_name);
293                }
294            }
295        }
296
297        None
298    }
299
300    /// Resolve a callee name to a scoped symbol.
301    fn resolve_callee(&self, state: &OciState, callee_name: &str) -> Option<InternedString> {
302        // Find symbols with this simple name
303        let symbols = state.find_by_name(callee_name);
304        if symbols.is_empty() {
305            return None;
306        }
307
308        // Prefer public symbols
309        for symbol in &symbols {
310            if matches!(symbol.visibility, crate::types::Visibility::Public) {
311                return Some(symbol.scoped_name);
312            }
313        }
314
315        // Otherwise, return the first one
316        symbols.first().map(|s| s.scoped_name)
317    }
318
319    /// Extract type names from a signature.
320    fn extract_types_from_signature(&self, sig: &crate::types::Signature) -> Vec<String> {
321        let mut types = Vec::new();
322
323        // Extract from parameters
324        for param in &sig.params {
325            if let Some(type_name) = self.extract_type_name(param) {
326                types.push(type_name);
327            }
328        }
329
330        // Extract from return type
331        if let Some(ret_type) = &sig.return_type {
332            if let Some(type_name) = self.extract_type_name(ret_type) {
333                types.push(type_name);
334            }
335        }
336
337        types
338    }
339
340    /// Extract a type name from a parameter or return type string.
341    fn extract_type_name(&self, type_str: &str) -> Option<String> {
342        // Simple heuristic: extract the main type name
343        // e.g., "Vec<String>" -> "Vec", "&mut Foo" -> "Foo"
344        let trimmed = type_str.trim();
345
346        // Remove reference markers
347        let without_refs = trimmed.trim_start_matches('&').trim_start_matches("mut ");
348
349        // Take the part before '<' or whitespace
350        let main_type = without_refs.split('<').next()?.split_whitespace().next()?;
351
352        if main_type.is_empty() || main_type.starts_with(char::is_lowercase) {
353            // Likely a primitive type or keyword
354            None
355        } else {
356            Some(main_type.to_string())
357        }
358    }
359
360    /// Find a type symbol (struct, enum, trait) by name.
361    fn find_type_symbol(&self, state: &OciState, type_name: &str) -> Option<InternedString> {
362        let symbols = state.find_by_name(type_name);
363
364        for symbol in symbols {
365            if matches!(
366                symbol.kind,
367                SymbolKind::Struct | SymbolKind::Enum | SymbolKind::Trait
368            ) {
369                return Some(symbol.scoped_name);
370            }
371        }
372
373        None
374    }
375
376    /// Find a symbol by simple name (prefer public, prefer higher PageRank).
377    fn find_symbol_by_name(&self, state: &OciState, name: &str) -> Option<InternedString> {
378        let symbols = state.find_by_name(name);
379        if symbols.is_empty() {
380            return None;
381        }
382
383        // Rank them and pick the best
384        let scoped_names: Vec<_> = symbols.iter().map(|s| s.scoped_name).collect();
385        let ranked = self.rank_symbols(state, &scoped_names);
386
387        ranked.first().map(|(sym, _)| *sym)
388    }
389
390    /// Rank symbols with reasons, considering PageRank and other factors.
391    fn rank_symbols_with_reasons(
392        &self,
393        state: &OciState,
394        candidates: Vec<(InternedString, f64, String)>,
395    ) -> Vec<(InternedString, f64, String)> {
396        let mut scored: Vec<_> = candidates
397            .into_iter()
398            .map(|(symbol, base_score, reason)| {
399                let pagerank_score = self.get_pagerank_score(state, symbol);
400                let combined_score = base_score * 0.7 + pagerank_score * 0.3;
401                (symbol, combined_score, reason)
402            })
403            .collect();
404
405        // Sort by score descending
406        scored.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
407
408        scored
409    }
410
411    /// Get the PageRank score for a symbol.
412    fn get_pagerank_score(&self, state: &OciState, symbol: InternedString) -> f64 {
413        // Get the symbol definition
414        let symbol_def = match state.get_symbol(symbol) {
415            Some(s) => s,
416            None => return 0.0,
417        };
418
419        // Find the topology node for this file
420        let node_idx = match state.path_to_node.get(&symbol_def.location.file) {
421            Some(idx) => *idx,
422            None => return 0.0,
423        };
424
425        // Get the metrics for this node
426        state
427            .topology_metrics
428            .get(&node_idx)
429            .map(|m| m.relevance_score)
430            .unwrap_or(0.0)
431    }
432
433    /// Create a chunk for a specific file location.
434    async fn create_location_chunk(
435        &self,
436        state: &OciState,
437        file: &PathBuf,
438        line: u32,
439        surrounding_lines: u32,
440    ) -> Result<ContextChunk> {
441        let contents = state
442            .get_file_contents(file)
443            .await
444            .context("Failed to read file")?;
445
446        let lines: Vec<&str> = contents.lines().collect();
447        let total_lines = lines.len();
448
449        let start_line = (line as i32 - surrounding_lines as i32).max(0) as usize;
450        let end_line = ((line + surrounding_lines) as usize).min(total_lines);
451
452        let content = lines[start_line..end_line].join("\n");
453
454        Ok(ContextChunk {
455            symbol: None,
456            file: file.clone(),
457            content,
458            relevance: 1.0,
459            reason: format!("Query location at line {}", line),
460        })
461    }
462
463    /// Create a chunk for a symbol.
464    async fn create_symbol_chunk(
465        &self,
466        state: &OciState,
467        symbol: InternedString,
468        relevance: f64,
469        reason: String,
470    ) -> Result<ContextChunk> {
471        let symbol_def = state
472            .get_symbol(symbol)
473            .context("Symbol not found in state")?;
474
475        let contents = state
476            .get_file_contents(&symbol_def.location.file)
477            .await
478            .context("Failed to read file")?;
479
480        let lines: Vec<&str> = contents.lines().collect();
481
482        // Extract the symbol's definition
483        let start_line = symbol_def.location.start_line.saturating_sub(1);
484        let end_line = symbol_def.location.end_line;
485
486        if start_line >= lines.len() || end_line > lines.len() {
487            anyhow::bail!("Symbol location out of bounds");
488        }
489
490        let content = lines[start_line..end_line].join("\n");
491
492        Ok(ContextChunk {
493            symbol: Some(symbol),
494            file: symbol_def.location.file.clone(),
495            content,
496            relevance,
497            reason,
498        })
499    }
500}
501
502impl Default for ContextSynthesizer {
503    fn default() -> Self {
504        Self::new()
505    }
506}
507
508// ============================================================================
509// Tests
510// ============================================================================
511
#[cfg(test)]
mod tests {
    use super::*;
    use crate::state::create_state;
    use tempfile::TempDir;

    /// Building context for a file that was never indexed must still succeed.
    #[tokio::test]
    async fn test_build_empty_context() {
        let temp = TempDir::new().unwrap();
        let state = create_state(temp.path().to_path_buf());
        let synthesizer = ContextSynthesizer::new();

        let test_file = temp.path().join("test.rs");
        std::fs::write(&test_file, "fn main() {}").unwrap();

        let query = ContextQuery::new(test_file, 1);
        let result = synthesizer.build_context(&state, &query).await;

        // Should succeed even with no indexed data
        assert!(result.is_ok());
    }

    /// Ranking an empty slice yields an empty ranking.
    #[test]
    fn test_rank_symbols_empty() {
        let temp = TempDir::new().unwrap();
        let state = create_state(temp.path().to_path_buf());
        let synthesizer = ContextSynthesizer::new();

        let ranked = synthesizer.rank_symbols(&state, &[]);
        assert!(ranked.is_empty());
    }

    /// Generic arguments and reference markers are stripped; primitives and
    /// empty input are rejected.
    #[test]
    fn test_extract_type_name() {
        let synthesizer = ContextSynthesizer::new();

        assert_eq!(
            synthesizer.extract_type_name("Vec<String>"),
            Some("Vec".to_string())
        );
        assert_eq!(
            synthesizer.extract_type_name("&mut Foo"),
            Some("Foo".to_string())
        );
        assert_eq!(
            synthesizer.extract_type_name("&Bar"),
            Some("Bar".to_string())
        );
        assert_eq!(
            synthesizer.extract_type_name("  Baz  "),
            Some("Baz".to_string())
        );
        assert_eq!(synthesizer.extract_type_name("i32"), None); // Primitive
        assert_eq!(synthesizer.extract_type_name("str"), None); // Primitive
        assert_eq!(synthesizer.extract_type_name(""), None); // Nothing to extract
    }

    /// Token estimation is a pure computation, so no async runtime is needed.
    #[test]
    fn test_context_chunk_token_estimation() {
        let chunk = ContextChunk {
            symbol: None,
            file: PathBuf::from("test.rs"),
            content: "a".repeat(400), // 400 chars
            relevance: 1.0,
            reason: "Test".to_string(),
        };

        // Should be ~100 tokens (400 / 4)
        assert_eq!(chunk.estimate_tokens(), 100);
    }

    /// The builder is synchronous, so a plain `#[test]` suffices.
    #[test]
    fn test_context_query_builder() {
        let query = ContextQuery::new(PathBuf::from("test.rs"), 10)
            .with_surrounding_lines(3)
            .with_max_tokens(2000)
            .with_intent("Testing".to_string());

        assert_eq!(query.line, 10);
        assert_eq!(query.surrounding_lines, 3);
        assert_eq!(query.max_tokens, 2000);
        assert_eq!(query.intent, Some("Testing".to_string()));
    }

    /// `all_chunks` lists primary chunks before related ones.
    #[test]
    fn test_context_result_all_chunks() {
        let mut result = ContextResult::empty();

        result.primary.push(ContextChunk {
            symbol: None,
            file: PathBuf::from("a.rs"),
            content: "primary".to_string(),
            relevance: 1.0,
            reason: "test".to_string(),
        });

        result.related.push(ContextChunk {
            symbol: None,
            file: PathBuf::from("b.rs"),
            content: "related".to_string(),
            relevance: 0.5,
            reason: "test".to_string(),
        });

        let all = result.all_chunks();
        assert_eq!(all.len(), 2);
        assert_eq!(all[0].content, "primary");
        assert_eq!(all[1].content, "related");
    }
}
615}