1use crate::constants::budget as budget_consts;
47use crate::newtypes::TokenCount;
48use crate::tokenizer::{TokenModel, Tokenizer};
49use crate::types::Repository;
50
/// How the cut point is chosen when a file must be shortened to fit the budget.
#[derive(Clone, Copy, Debug, Default)]
pub enum TruncationStrategy {
    /// Cut just after the last newline before the limit (the default).
    #[default]
    Line,
    /// Cut just before the last definition-like line that follows a blank line
    /// (for example `fn `, `def `, `class `); falls back to [`Self::Line`]
    /// behavior when no such line is found.
    Semantic,
    /// Cut exactly at the computed position without looking for a nicer boundary.
    Hard,
}
62
63#[derive(Debug, Clone, Copy)]
65pub struct BudgetConfig {
66 pub budget: TokenCount,
68 pub model: TokenModel,
70 pub strategy: TruncationStrategy,
72 pub overhead_reserve: TokenCount,
74}
75
76impl Default for BudgetConfig {
77 fn default() -> Self {
78 Self {
79 budget: TokenCount::new(budget_consts::DEFAULT_BUDGET),
80 model: TokenModel::Claude,
81 strategy: TruncationStrategy::Line,
82 overhead_reserve: TokenCount::new(budget_consts::OVERHEAD_RESERVE),
83 }
84 }
85}
86
87pub struct BudgetEnforcer {
89 config: BudgetConfig,
90 tokenizer: Tokenizer,
91}
92
93impl BudgetEnforcer {
94 pub fn new(config: BudgetConfig) -> Self {
96 Self { config, tokenizer: Tokenizer::new() }
97 }
98
99 pub fn with_budget(budget: u32, model: TokenModel) -> Self {
101 Self::new(BudgetConfig { budget: TokenCount::new(budget), model, ..Default::default() })
102 }
103
104 pub fn enforce(&self, repo: &mut Repository) -> EnforcementResult {
109 let available_budget = self
110 .config
111 .budget
112 .saturating_sub(self.config.overhead_reserve);
113 let mut used_tokens = TokenCount::zero();
114 let mut truncated_count = 0usize;
115 let mut excluded_count = 0usize;
116 let min_partial = TokenCount::new(budget_consts::MIN_PARTIAL_FIT_TOKENS);
117
118 let mut file_indices: Vec<usize> = (0..repo.files.len()).collect();
120 file_indices.sort_by(|&a, &b| {
121 repo.files[b]
122 .importance
123 .partial_cmp(&repo.files[a].importance)
124 .unwrap_or(std::cmp::Ordering::Equal)
125 });
126
127 for idx in file_indices {
128 let file = &mut repo.files[idx];
129
130 if let Some(content) = file.content.as_ref() {
131 let file_tokens = TokenCount::new(self.count_tokens(content));
132
133 if used_tokens + file_tokens <= available_budget {
134 used_tokens += file_tokens;
136 } else if used_tokens + min_partial < available_budget {
137 let remaining = available_budget.saturating_sub(used_tokens);
139 let truncated = self.truncate_to_tokens(content, remaining.get());
140 let truncated_tokens = TokenCount::new(self.count_tokens(&truncated));
141
142 file.content = Some(truncated);
143 used_tokens += truncated_tokens;
144 truncated_count += 1;
145 } else {
146 file.content = None;
148 excluded_count += 1;
149 }
150 }
151 }
152
153 EnforcementResult {
154 total_tokens: used_tokens,
155 truncated_files: truncated_count,
156 excluded_files: excluded_count,
157 budget_used_pct: used_tokens.percentage_of(available_budget),
158 }
159 }
160
161 fn count_tokens(&self, text: &str) -> u32 {
163 self.tokenizer.count(text, self.config.model)
164 }
165
166 pub fn truncate_to_tokens(&self, content: &str, max_tokens: u32) -> String {
171 let total_tokens = self.count_tokens(content);
173 if total_tokens <= max_tokens {
174 return content.to_owned();
175 }
176
177 let mut low = 0usize;
179 let mut high = content.len();
180 let mut best_pos = 0usize;
181
182 while low < high {
183 let mid = (low + high).div_ceil(2);
184
185 let safe_mid = self.find_char_boundary(content, mid);
187 let slice = &content[..safe_mid];
188 let tokens = self.count_tokens(slice);
189
190 if tokens <= max_tokens {
191 best_pos = safe_mid;
192 low = mid;
193 } else {
194 high = mid - 1;
195 }
196 }
197
198 let boundary = self.find_semantic_boundary(content, best_pos);
200
201 let mut result = content[..boundary].to_owned();
203 if boundary < content.len() {
204 result.push_str("\n\n... [truncated]");
205 }
206
207 result
208 }
209
210 fn find_char_boundary(&self, s: &str, pos: usize) -> usize {
212 if pos >= s.len() {
213 return s.len();
214 }
215
216 let mut boundary = pos;
217 while boundary > 0 && !s.is_char_boundary(boundary) {
218 boundary -= 1;
219 }
220 boundary
221 }
222
223 fn find_semantic_boundary(&self, content: &str, pos: usize) -> usize {
225 if pos == 0 || pos >= content.len() {
226 return pos;
227 }
228
229 let slice = &content[..pos];
230
231 match self.config.strategy {
232 TruncationStrategy::Hard => pos,
233 TruncationStrategy::Line => {
234 slice.rfind('\n').map(|p| p + 1).unwrap_or(pos)
236 },
237 TruncationStrategy::Semantic => {
238 if let Some(boundary) = self.find_function_boundary(slice) {
240 return boundary;
241 }
242 slice.rfind('\n').map(|p| p + 1).unwrap_or(pos)
244 },
245 }
246 }
247
248 fn find_function_boundary(&self, content: &str) -> Option<usize> {
250 let patterns = [
252 "\n\nfn ", "\n\ndef ", "\n\nclass ", "\n\nfunction ", "\n\npub fn ", "\n\nasync ", "\n\nimpl ", "\n\n#[", "\n\n@", ];
262
263 let mut best_pos = None;
265 for pattern in patterns {
266 if let Some(pos) = content.rfind(pattern) {
267 if best_pos.map_or(true, |bp| pos > bp) {
269 best_pos = Some(pos);
270 }
271 }
272 }
273
274 best_pos.map(|p| {
278 let boundary = p + 2;
280 if boundary <= content.len() {
281 boundary
282 } else {
283 (p + 1).min(content.len())
285 }
286 })
287 }
288}
289
290#[derive(Debug, Clone)]
292pub struct EnforcementResult {
293 pub total_tokens: TokenCount,
295 pub truncated_files: usize,
297 pub excluded_files: usize,
299 pub budget_used_pct: f32,
301}
302
#[cfg(test)]
mod tests {
    use super::*;

    /// Enforcer with a generous budget and default (line-based) truncation.
    fn default_enforcer() -> BudgetEnforcer {
        BudgetEnforcer::with_budget(10000, TokenModel::Claude)
    }

    /// Enforcer with default settings except for the truncation strategy.
    fn enforcer_with_strategy(strategy: TruncationStrategy) -> BudgetEnforcer {
        BudgetEnforcer::new(BudgetConfig { strategy, ..Default::default() })
    }

    #[test]
    fn test_truncate_preserves_short_content() {
        let enforcer = default_enforcer();
        let content = "Hello, world!";
        let result = enforcer.truncate_to_tokens(content, 1000);
        assert_eq!(result, content);
    }

    #[test]
    fn test_truncate_adds_indicator() {
        let enforcer = default_enforcer();
        let content = "line1\nline2\nline3\nline4\nline5\nline6\nline7\nline8\nline9\nline10";
        let result = enforcer.truncate_to_tokens(content, 5);
        assert!(result.contains("[truncated]"));
        assert!(result.len() < content.len());
    }

    #[test]
    fn test_find_char_boundary() {
        let enforcer = default_enforcer();
        // Byte 8 falls inside the multi-byte character that starts at byte 7.
        let content = "Hello, 世界!";
        let boundary = enforcer.find_char_boundary(content, 8);
        assert!(content.is_char_boundary(boundary));
        assert_eq!(boundary, 7);
    }

    #[test]
    fn test_semantic_boundary_line() {
        let enforcer = enforcer_with_strategy(TruncationStrategy::Line);
        let content = "line1\nline2\nline3";
        let boundary = enforcer.find_semantic_boundary(content, 10);
        // Just after the newline that ends "line1".
        assert_eq!(boundary, 6);
    }

    #[test]
    fn test_semantic_boundary_function() {
        let enforcer = enforcer_with_strategy(TruncationStrategy::Semantic);
        let content = "fn foo() {}\n\ndef bar():\n pass";
        let boundary = enforcer.find_semantic_boundary(content, content.len());
        assert!(boundary > 10);
    }

    #[test]
    fn test_empty_content_truncation() {
        let enforcer = default_enforcer();
        let result = enforcer.truncate_to_tokens("", 100);
        assert_eq!(result, "");
    }

    #[test]
    fn test_single_character_content() {
        let enforcer = default_enforcer();
        let result = enforcer.truncate_to_tokens("x", 100);
        assert_eq!(result, "x");
    }

    #[test]
    fn test_zero_budget_truncation() {
        let enforcer = default_enforcer();
        let content = "Some content that will be truncated";
        let result = enforcer.truncate_to_tokens(content, 0);
        assert!(result.len() <= content.len());
    }

    #[test]
    fn test_unicode_boundary_preservation() {
        let enforcer = default_enforcer();
        let content = "Hello 世界! More text here. 🦀 Rust! Even more...";

        for budget in [1, 2, 3, 5, 10] {
            // Must not panic by slicing inside a multi-byte character.
            let result = enforcer.truncate_to_tokens(content, budget);
            // Either the content fit as-is or it was cut and marked.
            assert!(result == content || result.ends_with("[truncated]"));
        }
    }

    #[test]
    fn test_content_smaller_than_indicator() {
        let enforcer = default_enforcer();
        let content = "Hi";
        let result = enforcer.truncate_to_tokens(content, 1);
        assert!(!result.is_empty() || content.is_empty());
    }

    #[test]
    fn test_hard_truncation_strategy() {
        let enforcer = enforcer_with_strategy(TruncationStrategy::Hard);
        let content = "line1\nline2\nline3";
        let boundary = enforcer.find_semantic_boundary(content, 10);
        assert_eq!(boundary, 10);
    }

    #[test]
    fn test_boundary_at_start() {
        let enforcer = default_enforcer();
        let content = "Some content";
        let boundary = enforcer.find_semantic_boundary(content, 0);
        assert_eq!(boundary, 0);
    }

    #[test]
    fn test_boundary_past_end() {
        let enforcer = default_enforcer();
        let content = "Some content";
        let boundary = enforcer.find_semantic_boundary(content, content.len() + 10);
        // Positions at or past the end are passed through unchanged.
        assert_eq!(boundary, content.len() + 10);
    }

    #[test]
    fn test_function_boundary_rust_patterns() {
        let enforcer = enforcer_with_strategy(TruncationStrategy::Semantic);

        let content = "use std::io;\n\nfn helper() {}\n\npub fn main() {}";
        let boundary = enforcer.find_function_boundary(content);
        assert!(boundary.is_some());

        let content2 = "struct Foo;\n\nimpl Foo {\n fn new() {}\n}";
        let boundary2 = enforcer.find_function_boundary(content2);
        assert!(boundary2.is_some());
    }

    #[test]
    fn test_function_boundary_python_patterns() {
        let enforcer = enforcer_with_strategy(TruncationStrategy::Semantic);

        let content = "import os\n\n@decorator\ndef foo():\n pass";
        let boundary = enforcer.find_function_boundary(content);
        assert!(boundary.is_some());

        let content2 = "import sys\n\nclass MyClass:\n pass";
        let boundary2 = enforcer.find_function_boundary(content2);
        assert!(boundary2.is_some());
    }

    #[test]
    fn test_function_boundary_javascript_patterns() {
        let enforcer = enforcer_with_strategy(TruncationStrategy::Semantic);

        let content = "const x = 1;\n\nfunction foo() {}\n\nasync function bar() {}";
        let boundary = enforcer.find_function_boundary(content);
        assert!(boundary.is_some());
    }

    #[test]
    fn test_no_function_boundary_found() {
        let enforcer = enforcer_with_strategy(TruncationStrategy::Semantic);

        let content = "just some text without any code patterns";
        let boundary = enforcer.find_function_boundary(content);
        assert!(boundary.is_none());
    }

    #[test]
    fn test_enforcement_result_fields() {
        let result = EnforcementResult {
            total_tokens: TokenCount::new(1000),
            truncated_files: 5,
            excluded_files: 2,
            budget_used_pct: 85.5,
        };

        assert_eq!(result.total_tokens.get(), 1000);
        assert_eq!(result.truncated_files, 5);
        assert_eq!(result.excluded_files, 2);
        assert!((result.budget_used_pct - 85.5).abs() < 0.01);
    }

    #[test]
    fn test_budget_config_default() {
        use crate::constants::budget as budget_consts;
        let config = BudgetConfig::default();
        assert_eq!(config.budget.get(), budget_consts::DEFAULT_BUDGET);
        assert!(matches!(config.strategy, TruncationStrategy::Line));
        assert_eq!(config.overhead_reserve.get(), budget_consts::OVERHEAD_RESERVE);
    }
}