infiniloom_engine/
budget.rs

1//! Smart token budget enforcement with binary search truncation
2//!
3//! This module provides intelligent content truncation to fit within
4//! a token budget while preserving semantic boundaries (line ends,
5//! function boundaries, etc.).
6//!
7//! # Overview
8//!
9//! When processing repositories, the total content often exceeds the
10//! context window of target LLMs. The `BudgetEnforcer` intelligently
11//! truncates content by:
12//!
13//! 1. **Prioritizing important files** - Files with higher importance scores are kept first
14//! 2. **Binary search truncation** - Efficiently finds the optimal cut point
15//! 3. **Semantic boundaries** - Truncates at meaningful boundaries (line, function)
16//!
17//! # Example
18//!
19//! ```rust,ignore
//! use infiniloom_engine::budget::{BudgetEnforcer, BudgetConfig, TruncationStrategy};
//! use infiniloom_engine::newtypes::TokenCount;
//! use infiniloom_engine::{Repository, TokenModel};
//!
//! // Create a budget enforcer with a 50K token limit
//! let config = BudgetConfig {
//!     budget: TokenCount::new(50_000),
//!     model: TokenModel::Claude,
//!     strategy: TruncationStrategy::Semantic,
//!     overhead_reserve: TokenCount::new(1000),
//! };
30//! let enforcer = BudgetEnforcer::new(config);
31//!
32//! // Enforce budget on repository
33//! let mut repo = Repository::new("my-project", "/path");
34//! let result = enforcer.enforce(&mut repo);
35//!
36//! println!("Used {:.1}% of budget", result.budget_used_pct);
37//! println!("{} files truncated, {} excluded", result.truncated_files, result.excluded_files);
38//! ```
39//!
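//! # Truncation Strategies
//!
//! Every strategy first binary-searches the longest prefix that fits, then picks the cut point:
//!
//! - **Line**: Cuts after the last newline in that prefix (default, fast)
//! - **Semantic**: Cuts just before the last function/class start, falling back to a line boundary (preserves whole definitions)
//! - **Hard**: Cuts at the end of the prefix, snapped back to a UTF-8 character boundary (may break mid-statement)
//!
//! Truncated content always ends with a `... [truncated]` marker.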
45
46use crate::constants::budget as budget_consts;
47use crate::newtypes::TokenCount;
48use crate::tokenizer::{TokenModel, Tokenizer};
49use crate::types::Repository;
50
/// Where `BudgetEnforcer::truncate_to_tokens` places the cut once the longest
/// fitting prefix has been found.
///
/// Every strategy appends the same `"\n\n... [truncated]"` marker when content is cut.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Hash)]
pub enum TruncationStrategy {
    /// Cut just after the last newline in the fitting prefix (default).
    #[default]
    Line,
    /// Cut just before the last function/class-like start (e.g. `fn`, `def`, `class`,
    /// `impl`, attributes, decorators) that follows a blank line; falls back to `Line`
    /// behavior when none is found.
    Semantic,
    /// Cut exactly at the end of the fitting prefix (snapped back to a UTF-8
    /// character boundary); may split a statement.
    Hard,
}
62
63/// Configuration for budget enforcement
64#[derive(Debug, Clone, Copy)]
65pub struct BudgetConfig {
66    /// Total token budget
67    pub budget: TokenCount,
68    /// Target tokenizer model
69    pub model: TokenModel,
70    /// Truncation strategy
71    pub strategy: TruncationStrategy,
72    /// Reserve tokens for overhead (headers, map, etc.)
73    pub overhead_reserve: TokenCount,
74}
75
76impl Default for BudgetConfig {
77    fn default() -> Self {
78        Self {
79            budget: TokenCount::new(budget_consts::DEFAULT_BUDGET),
80            model: TokenModel::Claude,
81            strategy: TruncationStrategy::Line,
82            overhead_reserve: TokenCount::new(budget_consts::OVERHEAD_RESERVE),
83        }
84    }
85}
86
87/// Smart token budget enforcer using binary search
88pub struct BudgetEnforcer {
89    config: BudgetConfig,
90    tokenizer: Tokenizer,
91}
92
93impl BudgetEnforcer {
94    /// Create a new budget enforcer with the given configuration
95    pub fn new(config: BudgetConfig) -> Self {
96        Self { config, tokenizer: Tokenizer::new() }
97    }
98
99    /// Create with just budget and model
100    pub fn with_budget(budget: u32, model: TokenModel) -> Self {
101        Self::new(BudgetConfig { budget: TokenCount::new(budget), model, ..Default::default() })
102    }
103
104    /// Enforce budget on repository, truncating file contents as needed
105    ///
106    /// Files are processed in importance order (highest first).
107    /// Returns the number of files that were truncated.
108    pub fn enforce(&self, repo: &mut Repository) -> EnforcementResult {
109        let available_budget = self
110            .config
111            .budget
112            .saturating_sub(self.config.overhead_reserve);
113        let mut used_tokens = TokenCount::zero();
114        let mut truncated_count = 0usize;
115        let mut excluded_count = 0usize;
116        let min_partial = TokenCount::new(budget_consts::MIN_PARTIAL_FIT_TOKENS);
117
118        // Sort files by importance (descending)
119        let mut file_indices: Vec<usize> = (0..repo.files.len()).collect();
120        file_indices.sort_by(|&a, &b| {
121            repo.files[b]
122                .importance
123                .partial_cmp(&repo.files[a].importance)
124                .unwrap_or(std::cmp::Ordering::Equal)
125        });
126
127        for idx in file_indices {
128            let file = &mut repo.files[idx];
129
130            if let Some(content) = file.content.as_ref() {
131                let file_tokens = TokenCount::new(self.count_tokens(content));
132
133                if used_tokens + file_tokens <= available_budget {
134                    // File fits entirely
135                    used_tokens += file_tokens;
136                } else if used_tokens + min_partial < available_budget {
137                    // Partial fit - truncate to remaining budget
138                    let remaining = available_budget.saturating_sub(used_tokens);
139                    let truncated = self.truncate_to_tokens(content, remaining.get());
140                    let truncated_tokens = TokenCount::new(self.count_tokens(&truncated));
141
142                    file.content = Some(truncated);
143                    used_tokens += truncated_tokens;
144                    truncated_count += 1;
145                } else {
146                    // No room - exclude content
147                    file.content = None;
148                    excluded_count += 1;
149                }
150            }
151        }
152
153        EnforcementResult {
154            total_tokens: used_tokens,
155            truncated_files: truncated_count,
156            excluded_files: excluded_count,
157            budget_used_pct: used_tokens.percentage_of(available_budget),
158        }
159    }
160
161    /// Count tokens in text using configured model
162    fn count_tokens(&self, text: &str) -> u32 {
163        self.tokenizer.count(text, self.config.model)
164    }
165
166    /// Truncate content to fit within max_tokens using binary search
167    ///
168    /// Uses binary search to find the optimal cut point, then adjusts
169    /// to the nearest semantic boundary.
170    pub fn truncate_to_tokens(&self, content: &str, max_tokens: u32) -> String {
171        // Quick check if content already fits
172        let total_tokens = self.count_tokens(content);
173        if total_tokens <= max_tokens {
174            return content.to_owned();
175        }
176
177        // Binary search for optimal byte position
178        let mut low = 0usize;
179        let mut high = content.len();
180        let mut best_pos = 0usize;
181
182        while low < high {
183            let mid = (low + high).div_ceil(2);
184
185            // Ensure we don't split a UTF-8 character
186            let safe_mid = self.find_char_boundary(content, mid);
187            let slice = &content[..safe_mid];
188            let tokens = self.count_tokens(slice);
189
190            if tokens <= max_tokens {
191                best_pos = safe_mid;
192                low = mid;
193            } else {
194                high = mid - 1;
195            }
196        }
197
198        // Find semantic boundary near best_pos
199        let boundary = self.find_semantic_boundary(content, best_pos);
200
201        // Add truncation indicator
202        let mut result = content[..boundary].to_owned();
203        if boundary < content.len() {
204            result.push_str("\n\n... [truncated]");
205        }
206
207        result
208    }
209
210    /// Find a valid UTF-8 character boundary at or before position
211    fn find_char_boundary(&self, s: &str, pos: usize) -> usize {
212        if pos >= s.len() {
213            return s.len();
214        }
215
216        let mut boundary = pos;
217        while boundary > 0 && !s.is_char_boundary(boundary) {
218            boundary -= 1;
219        }
220        boundary
221    }
222
223    /// Find a semantic boundary (line end, function end, etc.) near position
224    fn find_semantic_boundary(&self, content: &str, pos: usize) -> usize {
225        if pos == 0 || pos >= content.len() {
226            return pos;
227        }
228
229        let slice = &content[..pos];
230
231        match self.config.strategy {
232            TruncationStrategy::Hard => pos,
233            TruncationStrategy::Line => {
234                // Find last newline
235                slice.rfind('\n').map(|p| p + 1).unwrap_or(pos)
236            },
237            TruncationStrategy::Semantic => {
238                // Try to find function/class boundary first
239                if let Some(boundary) = self.find_function_boundary(slice) {
240                    return boundary;
241                }
242                // Fall back to line boundary
243                slice.rfind('\n').map(|p| p + 1).unwrap_or(pos)
244            },
245        }
246    }
247
248    /// Find a function/class boundary in the content
249    fn find_function_boundary(&self, content: &str) -> Option<usize> {
250        // Look for common function/class start patterns from the end
251        let patterns = [
252            "\n\nfn ",       // Rust
253            "\n\ndef ",      // Python
254            "\n\nclass ",    // Python/JS
255            "\n\nfunction ", // JavaScript
256            "\n\npub fn ",   // Rust public
257            "\n\nasync ",    // JavaScript async
258            "\n\nimpl ",     // Rust impl
259            "\n\n#[",        // Rust attributes
260            "\n\n@",         // Decorators
261        ];
262
263        // Search from the end for function boundaries
264        let mut best_pos = None;
265        for pattern in patterns {
266            if let Some(pos) = content.rfind(pattern) {
267                // Check if this position is better (closer to end)
268                if best_pos.map_or(true, |bp| pos > bp) {
269                    best_pos = Some(pos);
270                }
271            }
272        }
273
274        // Return position after the double newline but before the pattern keyword
275        // The +2 accounts for "\n\n" - we want to include the newlines but start
276        // before the actual function/class keyword
277        best_pos.map(|p| {
278            // Validate bounds to prevent off-by-one errors
279            let boundary = p + 2;
280            if boundary <= content.len() {
281                boundary
282            } else {
283                // Fallback to just after the match position if bounds check fails
284                (p + 1).min(content.len())
285            }
286        })
287    }
288}
289
290/// Result of budget enforcement
291#[derive(Debug, Clone)]
292pub struct EnforcementResult {
293    /// Total tokens used after enforcement
294    pub total_tokens: TokenCount,
295    /// Number of files that were truncated
296    pub truncated_files: usize,
297    /// Number of files excluded entirely
298    pub excluded_files: usize,
299    /// Percentage of budget used
300    pub budget_used_pct: f32,
301}
302
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_truncate_preserves_short_content() {
        let enforcer = BudgetEnforcer::with_budget(10000, TokenModel::Claude);
        let content = "Hello, world!";
        let result = enforcer.truncate_to_tokens(content, 1000);
        assert_eq!(result, content);
    }

    #[test]
    fn test_truncate_adds_indicator() {
        let enforcer = BudgetEnforcer::with_budget(10000, TokenModel::Claude);
        let content = "line1\nline2\nline3\nline4\nline5\nline6\nline7\nline8\nline9\nline10";
        let result = enforcer.truncate_to_tokens(content, 5);
        assert!(result.contains("[truncated]"));
        assert!(result.len() < content.len());
    }

    #[test]
    fn test_find_char_boundary() {
        let enforcer = BudgetEnforcer::with_budget(10000, TokenModel::Claude);
        let content = "Hello, 世界!"; // Multi-byte UTF-8: '世' occupies bytes 7..10
        let boundary = enforcer.find_char_boundary(content, 8);
        // Byte 8 is inside '世', so the boundary must back up to its first byte
        assert_eq!(boundary, 7);
        assert!(content.is_char_boundary(boundary));
    }

    #[test]
    fn test_semantic_boundary_line() {
        let config = BudgetConfig { strategy: TruncationStrategy::Line, ..Default::default() };
        let enforcer = BudgetEnforcer::new(config);
        let content = "line1\nline2\nline3";
        let boundary = enforcer.find_semantic_boundary(content, 10);
        // Should find boundary after "line1\n"
        assert_eq!(boundary, 6);
    }

    #[test]
    fn test_semantic_boundary_function() {
        let config = BudgetConfig { strategy: TruncationStrategy::Semantic, ..Default::default() };
        let enforcer = BudgetEnforcer::new(config);
        let content = "fn foo() {}\n\ndef bar():\n    pass";
        // Use a position inside the content: `pos >= len` short-circuits and would
        // let this test pass without exercising the function-boundary search.
        let boundary = enforcer.find_semantic_boundary(content, content.len() - 1);
        // Should cut right before "def bar", i.e. after "fn foo() {}\n\n"
        assert_eq!(boundary, 13);
    }

    // =========================================================================
    // Additional edge case tests for comprehensive coverage
    // =========================================================================

    #[test]
    fn test_empty_content_truncation() {
        let enforcer = BudgetEnforcer::with_budget(10000, TokenModel::Claude);
        let result = enforcer.truncate_to_tokens("", 100);
        assert_eq!(result, "");
    }

    #[test]
    fn test_single_character_content() {
        let enforcer = BudgetEnforcer::with_budget(10000, TokenModel::Claude);
        let result = enforcer.truncate_to_tokens("x", 100);
        assert_eq!(result, "x");
    }

    #[test]
    fn test_zero_budget_truncation() {
        let enforcer = BudgetEnforcer::with_budget(10000, TokenModel::Claude);
        let content = "Some content that will be truncated";
        let result = enforcer.truncate_to_tokens(content, 0);
        // Should return minimal content or truncation indicator
        assert!(result.len() <= content.len());
    }

    #[test]
    fn test_unicode_boundary_preservation() {
        let enforcer = BudgetEnforcer::with_budget(10000, TokenModel::Claude);
        // Content with 3-byte (CJK) and 4-byte (emoji) UTF-8 characters
        let content = "Hello 世界! More text here. 🦀 Rust! Even more...";

        // Truncate to various budgets
        for budget in [1, 2, 3, 5, 10] {
            let result = enforcer.truncate_to_tokens(content, budget);
            // Verify we can still iterate over chars (valid UTF-8)
            let _ = result.chars().count();
            // Verify the string is valid
            assert!(std::str::from_utf8(result.as_bytes()).is_ok());
        }
    }

    #[test]
    fn test_content_smaller_than_indicator() {
        let enforcer = BudgetEnforcer::with_budget(10000, TokenModel::Claude);
        // Very small content
        let content = "Hi";
        let result = enforcer.truncate_to_tokens(content, 1);
        // Should handle gracefully
        assert!(!result.is_empty() || content.is_empty());
    }

    #[test]
    fn test_hard_truncation_strategy() {
        let config = BudgetConfig { strategy: TruncationStrategy::Hard, ..Default::default() };
        let enforcer = BudgetEnforcer::new(config);
        let content = "line1\nline2\nline3";
        let boundary = enforcer.find_semantic_boundary(content, 10);
        // Hard strategy should return exact position
        assert_eq!(boundary, 10);
    }

    #[test]
    fn test_boundary_at_start() {
        let enforcer = BudgetEnforcer::with_budget(10000, TokenModel::Claude);
        let content = "Some content";
        let boundary = enforcer.find_semantic_boundary(content, 0);
        assert_eq!(boundary, 0);
    }

    #[test]
    fn test_boundary_past_end() {
        let enforcer = BudgetEnforcer::with_budget(10000, TokenModel::Claude);
        let content = "Some content";
        let boundary = enforcer.find_semantic_boundary(content, content.len() + 10);
        // Positions at or past the end are returned unchanged (not clamped)
        assert_eq!(boundary, content.len() + 10);
    }

    #[test]
    fn test_function_boundary_rust_patterns() {
        let config = BudgetConfig { strategy: TruncationStrategy::Semantic, ..Default::default() };
        let enforcer = BudgetEnforcer::new(config);

        // Test various Rust function patterns
        let content = "use std::io;\n\nfn helper() {}\n\npub fn main() {}";
        let boundary = enforcer.find_function_boundary(content);
        assert!(boundary.is_some());

        // Test impl block
        let content2 = "struct Foo;\n\nimpl Foo {\n    fn new() {}\n}";
        let boundary2 = enforcer.find_function_boundary(content2);
        assert!(boundary2.is_some());
    }

    #[test]
    fn test_function_boundary_python_patterns() {
        let config = BudgetConfig { strategy: TruncationStrategy::Semantic, ..Default::default() };
        let enforcer = BudgetEnforcer::new(config);

        // Test Python function with decorator
        let content = "import os\n\n@decorator\ndef foo():\n    pass";
        let boundary = enforcer.find_function_boundary(content);
        assert!(boundary.is_some());

        // Test Python class
        let content2 = "import sys\n\nclass MyClass:\n    pass";
        let boundary2 = enforcer.find_function_boundary(content2);
        assert!(boundary2.is_some());
    }

    #[test]
    fn test_function_boundary_javascript_patterns() {
        let config = BudgetConfig { strategy: TruncationStrategy::Semantic, ..Default::default() };
        let enforcer = BudgetEnforcer::new(config);

        // Test JavaScript function
        let content = "const x = 1;\n\nfunction foo() {}\n\nasync function bar() {}";
        let boundary = enforcer.find_function_boundary(content);
        assert!(boundary.is_some());
    }

    #[test]
    fn test_no_function_boundary_found() {
        let config = BudgetConfig { strategy: TruncationStrategy::Semantic, ..Default::default() };
        let enforcer = BudgetEnforcer::new(config);

        // Content without any function boundaries
        let content = "just some text without any code patterns";
        let boundary = enforcer.find_function_boundary(content);
        assert!(boundary.is_none());
    }

    #[test]
    fn test_enforcement_result_fields() {
        let result = EnforcementResult {
            total_tokens: TokenCount::new(1000),
            truncated_files: 5,
            excluded_files: 2,
            budget_used_pct: 85.5,
        };

        assert_eq!(result.total_tokens.get(), 1000);
        assert_eq!(result.truncated_files, 5);
        assert_eq!(result.excluded_files, 2);
        assert!((result.budget_used_pct - 85.5).abs() < 0.01);
    }

    #[test]
    fn test_budget_config_default() {
        use crate::constants::budget as budget_consts;
        let config = BudgetConfig::default();
        assert_eq!(config.budget.get(), budget_consts::DEFAULT_BUDGET);
        assert!(matches!(config.strategy, TruncationStrategy::Line));
        assert_eq!(config.overhead_reserve.get(), budget_consts::OVERHEAD_RESERVE);
    }
}