infiniloom_engine/
budget.rs

1//! Smart token budget enforcement with binary search truncation
2//!
3//! This module provides intelligent content truncation to fit within
4//! a token budget while preserving semantic boundaries (line ends,
5//! function boundaries, etc.).
6//!
7//! # Overview
8//!
9//! When processing repositories, the total content often exceeds the
10//! context window of target LLMs. The `BudgetEnforcer` intelligently
11//! truncates content by:
12//!
13//! 1. **Prioritizing important files** - Files with higher importance scores are kept first
14//! 2. **Binary search truncation** - Efficiently finds the optimal cut point
15//! 3. **Semantic boundaries** - Truncates at meaningful boundaries (line, function)
16//!
17//! # Example
18//!
19//! ```rust,ignore
20//! use infiniloom_engine::budget::{BudgetEnforcer, BudgetConfig, TruncationStrategy};
//! use infiniloom_engine::{Repository, TokenModel};
//! use infiniloom_engine::newtypes::TokenCount;
//!
//! // Create a budget enforcer with a 50K token limit
//! let config = BudgetConfig {
//!     budget: TokenCount::new(50_000),
//!     model: TokenModel::Claude,
//!     strategy: TruncationStrategy::Semantic,
//!     overhead_reserve: TokenCount::new(1_000),
//! };
30//! let enforcer = BudgetEnforcer::new(config);
31//!
32//! // Enforce budget on repository
33//! let mut repo = Repository::new("my-project", "/path");
34//! let result = enforcer.enforce(&mut repo);
35//!
36//! println!("Used {:.1}% of budget", result.budget_used_pct);
37//! println!("{} files truncated, {} excluded", result.truncated_files, result.excluded_files);
38//! ```
39//!
40//! # Truncation Strategies
41//!
42//! - **Line**: Truncates at newline boundaries (default, fast)
43//! - **Semantic**: Truncates at function/class boundaries (slower, preserves context)
44//! - **Hard**: Truncates at exact byte position (fastest, may break mid-statement)
45
46use crate::constants::budget as budget_consts;
47use crate::newtypes::TokenCount;
48use crate::tokenizer::{TokenModel, Tokenizer};
49use crate::types::Repository;
50
/// Budget enforcement strategies
///
/// Selects where [`BudgetEnforcer`] cuts content once the longest fitting
/// prefix has been found. Whenever content is actually shortened, the same
/// truncation marker is appended regardless of the strategy.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Hash)]
pub enum TruncationStrategy {
    /// Cut just after the last newline before the limit (default). If there is
    /// no newline, the raw cut point is kept.
    #[default]
    Line,
    /// Cut just before the last blank-line-separated function/class/impl
    /// start, attribute or decorator. Falls back to the `Line` rule when no
    /// such start is found.
    Semantic,
    /// Cut at the exact (UTF-8 safe) position found by the search. This may
    /// split a line or statement.
    Hard,
}
62
63/// Configuration for budget enforcement
64#[derive(Debug, Clone, Copy)]
65pub struct BudgetConfig {
66    /// Total token budget
67    pub budget: TokenCount,
68    /// Target tokenizer model
69    pub model: TokenModel,
70    /// Truncation strategy
71    pub strategy: TruncationStrategy,
72    /// Reserve tokens for overhead (headers, map, etc.)
73    pub overhead_reserve: TokenCount,
74}
75
76impl Default for BudgetConfig {
77    fn default() -> Self {
78        Self {
79            budget: TokenCount::new(budget_consts::DEFAULT_BUDGET),
80            model: TokenModel::Claude,
81            strategy: TruncationStrategy::Line,
82            overhead_reserve: TokenCount::new(budget_consts::OVERHEAD_RESERVE),
83        }
84    }
85}
86
87/// Smart token budget enforcer using binary search
88pub struct BudgetEnforcer {
89    config: BudgetConfig,
90    tokenizer: Tokenizer,
91}
92
93impl BudgetEnforcer {
94    /// Create a new budget enforcer with the given configuration
95    pub fn new(config: BudgetConfig) -> Self {
96        Self { config, tokenizer: Tokenizer::new() }
97    }
98
99    /// Create with just budget and model
100    pub fn with_budget(budget: u32, model: TokenModel) -> Self {
101        Self::new(BudgetConfig { budget: TokenCount::new(budget), model, ..Default::default() })
102    }
103
104    /// Enforce budget on repository, truncating file contents as needed
105    ///
106    /// Files are processed in importance order (highest first).
107    /// Returns the number of files that were truncated.
108    pub fn enforce(&self, repo: &mut Repository) -> EnforcementResult {
109        let available_budget = self.config.budget.saturating_sub(self.config.overhead_reserve);
110        let mut used_tokens = TokenCount::zero();
111        let mut truncated_count = 0usize;
112        let mut excluded_count = 0usize;
113        let min_partial = TokenCount::new(budget_consts::MIN_PARTIAL_FIT_TOKENS);
114
115        // Sort files by importance (descending)
116        let mut file_indices: Vec<usize> = (0..repo.files.len()).collect();
117        file_indices.sort_by(|&a, &b| {
118            repo.files[b]
119                .importance
120                .partial_cmp(&repo.files[a].importance)
121                .unwrap_or(std::cmp::Ordering::Equal)
122        });
123
124        for idx in file_indices {
125            let file = &mut repo.files[idx];
126
127            if let Some(content) = file.content.as_ref() {
128                let file_tokens = TokenCount::new(self.count_tokens(content));
129
130                if used_tokens + file_tokens <= available_budget {
131                    // File fits entirely
132                    used_tokens += file_tokens;
133                } else if used_tokens + min_partial < available_budget {
134                    // Partial fit - truncate to remaining budget
135                    let remaining = available_budget.saturating_sub(used_tokens);
136                    let truncated = self.truncate_to_tokens(content, remaining.get());
137                    let truncated_tokens = TokenCount::new(self.count_tokens(&truncated));
138
139                    file.content = Some(truncated);
140                    used_tokens += truncated_tokens;
141                    truncated_count += 1;
142                } else {
143                    // No room - exclude content
144                    file.content = None;
145                    excluded_count += 1;
146                }
147            }
148        }
149
150        EnforcementResult {
151            total_tokens: used_tokens,
152            truncated_files: truncated_count,
153            excluded_files: excluded_count,
154            budget_used_pct: used_tokens.percentage_of(available_budget),
155        }
156    }
157
158    /// Count tokens in text using configured model
159    fn count_tokens(&self, text: &str) -> u32 {
160        self.tokenizer.count(text, self.config.model)
161    }
162
163    /// Truncate content to fit within max_tokens using binary search
164    ///
165    /// Uses binary search to find the optimal cut point, then adjusts
166    /// to the nearest semantic boundary.
167    pub fn truncate_to_tokens(&self, content: &str, max_tokens: u32) -> String {
168        // Quick check if content already fits
169        let total_tokens = self.count_tokens(content);
170        if total_tokens <= max_tokens {
171            return content.to_owned();
172        }
173
174        // Binary search for optimal byte position
175        let mut low = 0usize;
176        let mut high = content.len();
177        let mut best_pos = 0usize;
178
179        while low < high {
180            let mid = (low + high).div_ceil(2);
181
182            // Ensure we don't split a UTF-8 character
183            let safe_mid = self.find_char_boundary(content, mid);
184            let slice = &content[..safe_mid];
185            let tokens = self.count_tokens(slice);
186
187            if tokens <= max_tokens {
188                best_pos = safe_mid;
189                low = mid;
190            } else {
191                high = mid - 1;
192            }
193        }
194
195        // Find semantic boundary near best_pos
196        let boundary = self.find_semantic_boundary(content, best_pos);
197
198        // Add truncation indicator
199        let mut result = content[..boundary].to_owned();
200        if boundary < content.len() {
201            result.push_str("\n\n... [truncated]");
202        }
203
204        result
205    }
206
207    /// Find a valid UTF-8 character boundary at or before position
208    fn find_char_boundary(&self, s: &str, pos: usize) -> usize {
209        if pos >= s.len() {
210            return s.len();
211        }
212
213        let mut boundary = pos;
214        while boundary > 0 && !s.is_char_boundary(boundary) {
215            boundary -= 1;
216        }
217        boundary
218    }
219
220    /// Find a semantic boundary (line end, function end, etc.) near position
221    fn find_semantic_boundary(&self, content: &str, pos: usize) -> usize {
222        if pos == 0 || pos >= content.len() {
223            return pos;
224        }
225
226        let slice = &content[..pos];
227
228        match self.config.strategy {
229            TruncationStrategy::Hard => pos,
230            TruncationStrategy::Line => {
231                // Find last newline
232                slice.rfind('\n').map(|p| p + 1).unwrap_or(pos)
233            },
234            TruncationStrategy::Semantic => {
235                // Try to find function/class boundary first
236                if let Some(boundary) = self.find_function_boundary(slice) {
237                    return boundary;
238                }
239                // Fall back to line boundary
240                slice.rfind('\n').map(|p| p + 1).unwrap_or(pos)
241            },
242        }
243    }
244
245    /// Find a function/class boundary in the content
246    fn find_function_boundary(&self, content: &str) -> Option<usize> {
247        // Look for common function/class start patterns from the end
248        let patterns = [
249            "\n\nfn ",       // Rust
250            "\n\ndef ",      // Python
251            "\n\nclass ",    // Python/JS
252            "\n\nfunction ", // JavaScript
253            "\n\npub fn ",   // Rust public
254            "\n\nasync ",    // JavaScript async
255            "\n\nimpl ",     // Rust impl
256            "\n\n#[",        // Rust attributes
257            "\n\n@",         // Decorators
258        ];
259
260        // Search from the end for function boundaries
261        let mut best_pos = None;
262        for pattern in patterns {
263            if let Some(pos) = content.rfind(pattern) {
264                // Check if this position is better (closer to end)
265                if best_pos.map_or(true, |bp| pos > bp) {
266                    best_pos = Some(pos);
267                }
268            }
269        }
270
271        // Return position after the double newline but before the pattern keyword
272        // The +2 accounts for "\n\n" - we want to include the newlines but start
273        // before the actual function/class keyword
274        best_pos.map(|p| {
275            // Validate bounds to prevent off-by-one errors
276            let boundary = p + 2;
277            if boundary <= content.len() {
278                boundary
279            } else {
280                // Fallback to just after the match position if bounds check fails
281                (p + 1).min(content.len())
282            }
283        })
284    }
285}
286
287/// Result of budget enforcement
288#[derive(Debug, Clone)]
289pub struct EnforcementResult {
290    /// Total tokens used after enforcement
291    pub total_tokens: TokenCount,
292    /// Number of files that were truncated
293    pub truncated_files: usize,
294    /// Number of files excluded entirely
295    pub excluded_files: usize,
296    /// Percentage of budget used
297    pub budget_used_pct: f32,
298}
299
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_truncate_preserves_short_content() {
        let enforcer = BudgetEnforcer::with_budget(10000, TokenModel::Claude);
        let content = "Hello, world!";
        let result = enforcer.truncate_to_tokens(content, 1000);
        assert_eq!(result, content);
    }

    #[test]
    fn test_truncate_adds_indicator() {
        let enforcer = BudgetEnforcer::with_budget(10000, TokenModel::Claude);
        let content = "line1\nline2\nline3\nline4\nline5\nline6\nline7\nline8\nline9\nline10";
        let result = enforcer.truncate_to_tokens(content, 5);
        assert!(result.contains("[truncated]"));
        assert!(result.len() < content.len());
    }

    #[test]
    fn test_find_char_boundary() {
        let enforcer = BudgetEnforcer::with_budget(10000, TokenModel::Claude);
        let content = "Hello, 世界!"; // Multi-byte UTF-8: '世' occupies bytes 7..10
        let boundary = enforcer.find_char_boundary(content, 8);
        // Byte 8 is inside '世', so the boundary steps back to its first byte
        assert!(content.is_char_boundary(boundary));
        assert_eq!(boundary, 7);
    }

    #[test]
    fn test_semantic_boundary_line() {
        let config = BudgetConfig { strategy: TruncationStrategy::Line, ..Default::default() };
        let enforcer = BudgetEnforcer::new(config);
        let content = "line1\nline2\nline3";
        let boundary = enforcer.find_semantic_boundary(content, 10);
        // Should find boundary after "line1\n"
        assert_eq!(boundary, 6);
    }

    #[test]
    fn test_semantic_boundary_function() {
        let config = BudgetConfig { strategy: TruncationStrategy::Semantic, ..Default::default() };
        let enforcer = BudgetEnforcer::new(config);
        let content = "fn foo() {}\n\ndef bar():\n    pass";
        // Use a position inside the content: `content.len()` would hit the early
        // return and never reach the function-boundary search.
        let boundary = enforcer.find_semantic_boundary(content, content.len() - 1);
        // Should cut right before "def bar", i.e. after "fn foo() {}\n\n"
        assert_eq!(boundary, 13);
    }

    // =========================================================================
    // Additional edge case tests for comprehensive coverage
    // =========================================================================

    #[test]
    fn test_empty_content_truncation() {
        let enforcer = BudgetEnforcer::with_budget(10000, TokenModel::Claude);
        let result = enforcer.truncate_to_tokens("", 100);
        assert_eq!(result, "");
    }

    #[test]
    fn test_single_character_content() {
        let enforcer = BudgetEnforcer::with_budget(10000, TokenModel::Claude);
        let result = enforcer.truncate_to_tokens("x", 100);
        assert_eq!(result, "x");
    }

    #[test]
    fn test_zero_budget_truncation() {
        let enforcer = BudgetEnforcer::with_budget(10000, TokenModel::Claude);
        let content = "Some content that will be truncated";
        let result = enforcer.truncate_to_tokens(content, 0);
        // Should return minimal content or truncation indicator
        assert!(result.len() <= content.len());
    }

    #[test]
    fn test_unicode_boundary_preservation() {
        let enforcer = BudgetEnforcer::with_budget(10000, TokenModel::Claude);
        // Content with multi-byte UTF-8 characters (3-byte CJK and a 4-byte emoji)
        let content = "Hello 世界! More text here. 🦀 Rust! Even more...";

        // Truncate to various budgets
        for budget in [1, 2, 3, 5, 10] {
            let result = enforcer.truncate_to_tokens(content, budget);
            // Verify we can still iterate over chars (valid UTF-8)
            let _ = result.chars().count();
            // Verify the string is valid
            assert!(std::str::from_utf8(result.as_bytes()).is_ok());
        }
    }

    #[test]
    fn test_content_smaller_than_indicator() {
        let enforcer = BudgetEnforcer::with_budget(10000, TokenModel::Claude);
        // Very small content
        let content = "Hi";
        let result = enforcer.truncate_to_tokens(content, 1);
        // Should handle gracefully
        assert!(!result.is_empty() || content.is_empty());
    }

    #[test]
    fn test_hard_truncation_strategy() {
        let config = BudgetConfig { strategy: TruncationStrategy::Hard, ..Default::default() };
        let enforcer = BudgetEnforcer::new(config);
        let content = "line1\nline2\nline3";
        let boundary = enforcer.find_semantic_boundary(content, 10);
        // Hard strategy should return exact position
        assert_eq!(boundary, 10);
    }

    #[test]
    fn test_boundary_at_start() {
        let enforcer = BudgetEnforcer::with_budget(10000, TokenModel::Claude);
        let content = "Some content";
        let boundary = enforcer.find_semantic_boundary(content, 0);
        assert_eq!(boundary, 0);
    }

    #[test]
    fn test_boundary_past_end() {
        let enforcer = BudgetEnforcer::with_budget(10000, TokenModel::Claude);
        let content = "Some content";
        let boundary = enforcer.find_semantic_boundary(content, content.len() + 10);
        // Positions at or past the end are returned unchanged (not clamped)
        assert_eq!(boundary, content.len() + 10);
    }

    #[test]
    fn test_function_boundary_rust_patterns() {
        let config = BudgetConfig { strategy: TruncationStrategy::Semantic, ..Default::default() };
        let enforcer = BudgetEnforcer::new(config);

        // Test various Rust function patterns
        let content = "use std::io;\n\nfn helper() {}\n\npub fn main() {}";
        let boundary = enforcer.find_function_boundary(content);
        assert!(boundary.is_some());

        // Test impl block
        let content2 = "struct Foo;\n\nimpl Foo {\n    fn new() {}\n}";
        let boundary2 = enforcer.find_function_boundary(content2);
        assert!(boundary2.is_some());
    }

    #[test]
    fn test_function_boundary_python_patterns() {
        let config = BudgetConfig { strategy: TruncationStrategy::Semantic, ..Default::default() };
        let enforcer = BudgetEnforcer::new(config);

        // Test Python function with decorator
        let content = "import os\n\n@decorator\ndef foo():\n    pass";
        let boundary = enforcer.find_function_boundary(content);
        assert!(boundary.is_some());

        // Test Python class
        let content2 = "import sys\n\nclass MyClass:\n    pass";
        let boundary2 = enforcer.find_function_boundary(content2);
        assert!(boundary2.is_some());
    }

    #[test]
    fn test_function_boundary_javascript_patterns() {
        let config = BudgetConfig { strategy: TruncationStrategy::Semantic, ..Default::default() };
        let enforcer = BudgetEnforcer::new(config);

        // Test JavaScript function
        let content = "const x = 1;\n\nfunction foo() {}\n\nasync function bar() {}";
        let boundary = enforcer.find_function_boundary(content);
        assert!(boundary.is_some());
    }

    #[test]
    fn test_no_function_boundary_found() {
        let config = BudgetConfig { strategy: TruncationStrategy::Semantic, ..Default::default() };
        let enforcer = BudgetEnforcer::new(config);

        // Content without any function boundaries
        let content = "just some text without any code patterns";
        let boundary = enforcer.find_function_boundary(content);
        assert!(boundary.is_none());
    }

    #[test]
    fn test_enforcement_result_fields() {
        let result = EnforcementResult {
            total_tokens: TokenCount::new(1000),
            truncated_files: 5,
            excluded_files: 2,
            budget_used_pct: 85.5,
        };

        assert_eq!(result.total_tokens.get(), 1000);
        assert_eq!(result.truncated_files, 5);
        assert_eq!(result.excluded_files, 2);
        assert!((result.budget_used_pct - 85.5).abs() < 0.01);
    }

    #[test]
    fn test_budget_config_default() {
        use crate::constants::budget as budget_consts;
        let config = BudgetConfig::default();
        assert_eq!(config.budget.get(), budget_consts::DEFAULT_BUDGET);
        assert!(matches!(config.strategy, TruncationStrategy::Line));
        assert_eq!(config.overhead_reserve.get(), budget_consts::OVERHEAD_RESERVE);
    }
}