1use crate::tokenizer::{TokenModel, Tokenizer};
47use crate::types::Repository;
48
/// Where a cut is placed when content has to be shortened to fit a budget.
///
/// The cut position found by the token search is moved backwards to the
/// nearest boundary of the selected kind.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Hash)]
pub enum TruncationStrategy {
    /// Cut just after the last newline before the limit (the default).
    #[default]
    Line,
    /// Cut at the start of the last blank-line-separated definition before
    /// the limit, falling back to the `Line` behavior when none is found.
    Semantic,
    /// Cut exactly at the limit without looking for a nicer boundary.
    Hard,
}
60
61#[derive(Debug, Clone, Copy)]
63pub struct BudgetConfig {
64 pub budget: u32,
66 pub model: TokenModel,
68 pub strategy: TruncationStrategy,
70 pub overhead_reserve: u32,
72}
73
74impl Default for BudgetConfig {
75 fn default() -> Self {
76 Self {
77 budget: 100_000,
78 model: TokenModel::Claude,
79 strategy: TruncationStrategy::Line,
80 overhead_reserve: 1000,
81 }
82 }
83}
84
85pub struct BudgetEnforcer {
87 config: BudgetConfig,
88 tokenizer: Tokenizer,
89}
90
91impl BudgetEnforcer {
92 pub fn new(config: BudgetConfig) -> Self {
94 Self { config, tokenizer: Tokenizer::new() }
95 }
96
97 pub fn with_budget(budget: u32, model: TokenModel) -> Self {
99 Self::new(BudgetConfig { budget, model, ..Default::default() })
100 }
101
102 pub fn enforce(&self, repo: &mut Repository) -> EnforcementResult {
107 let available_budget = self
108 .config
109 .budget
110 .saturating_sub(self.config.overhead_reserve);
111 let mut used_tokens = 0u32;
112 let mut truncated_count = 0usize;
113 let mut excluded_count = 0usize;
114
115 let mut file_indices: Vec<usize> = (0..repo.files.len()).collect();
117 file_indices.sort_by(|&a, &b| {
118 repo.files[b]
119 .importance
120 .partial_cmp(&repo.files[a].importance)
121 .unwrap_or(std::cmp::Ordering::Equal)
122 });
123
124 for idx in file_indices {
125 let file = &mut repo.files[idx];
126
127 if let Some(content) = file.content.as_ref() {
128 let file_tokens = self.count_tokens(content);
129
130 if used_tokens + file_tokens <= available_budget {
131 used_tokens += file_tokens;
133 } else if used_tokens + 100 < available_budget {
134 let remaining = available_budget - used_tokens;
136 let truncated = self.truncate_to_tokens(content, remaining);
137 let truncated_tokens = self.count_tokens(&truncated);
138
139 file.content = Some(truncated);
140 used_tokens += truncated_tokens;
141 truncated_count += 1;
142 } else {
143 file.content = None;
145 excluded_count += 1;
146 }
147 }
148 }
149
150 EnforcementResult {
151 total_tokens: used_tokens,
152 truncated_files: truncated_count,
153 excluded_files: excluded_count,
154 budget_used_pct: (used_tokens as f32 / available_budget as f32) * 100.0,
155 }
156 }
157
158 fn count_tokens(&self, text: &str) -> u32 {
160 self.tokenizer.count(text, self.config.model)
161 }
162
163 pub fn truncate_to_tokens(&self, content: &str, max_tokens: u32) -> String {
168 let total_tokens = self.count_tokens(content);
170 if total_tokens <= max_tokens {
171 return content.to_owned();
172 }
173
174 let mut low = 0usize;
176 let mut high = content.len();
177 let mut best_pos = 0usize;
178
179 while low < high {
180 let mid = (low + high).div_ceil(2);
181
182 let safe_mid = self.find_char_boundary(content, mid);
184 let slice = &content[..safe_mid];
185 let tokens = self.count_tokens(slice);
186
187 if tokens <= max_tokens {
188 best_pos = safe_mid;
189 low = mid;
190 } else {
191 high = mid - 1;
192 }
193 }
194
195 let boundary = self.find_semantic_boundary(content, best_pos);
197
198 let mut result = content[..boundary].to_owned();
200 if boundary < content.len() {
201 result.push_str("\n\n... [truncated]");
202 }
203
204 result
205 }
206
207 fn find_char_boundary(&self, s: &str, pos: usize) -> usize {
209 if pos >= s.len() {
210 return s.len();
211 }
212
213 let mut boundary = pos;
214 while boundary > 0 && !s.is_char_boundary(boundary) {
215 boundary -= 1;
216 }
217 boundary
218 }
219
220 fn find_semantic_boundary(&self, content: &str, pos: usize) -> usize {
222 if pos == 0 || pos >= content.len() {
223 return pos;
224 }
225
226 let slice = &content[..pos];
227
228 match self.config.strategy {
229 TruncationStrategy::Hard => pos,
230 TruncationStrategy::Line => {
231 slice.rfind('\n').map(|p| p + 1).unwrap_or(pos)
233 },
234 TruncationStrategy::Semantic => {
235 if let Some(boundary) = self.find_function_boundary(slice) {
237 return boundary;
238 }
239 slice.rfind('\n').map(|p| p + 1).unwrap_or(pos)
241 },
242 }
243 }
244
245 fn find_function_boundary(&self, content: &str) -> Option<usize> {
247 let patterns = [
249 "\n\nfn ", "\n\ndef ", "\n\nclass ", "\n\nfunction ", "\n\npub fn ", "\n\nasync ", "\n\nimpl ", "\n\n#[", "\n\n@", ];
259
260 let mut best_pos = None;
262 for pattern in patterns {
263 if let Some(pos) = content.rfind(pattern) {
264 if best_pos.is_none() || pos > best_pos.unwrap() {
266 best_pos = Some(pos);
267 }
268 }
269 }
270
271 best_pos.map(|p| {
275 let boundary = p + 2;
277 if boundary <= content.len() {
278 boundary
279 } else {
280 (p + 1).min(content.len())
282 }
283 })
284 }
285}
286
/// Summary of one budget-enforcement pass.
#[derive(Debug, Clone, PartialEq)]
pub struct EnforcementResult {
    /// Tokens counted for all content that was kept, whole or truncated.
    pub total_tokens: u32,
    /// Number of files whose content was replaced by a truncated version.
    pub truncated_files: usize,
    /// Number of files whose content was removed entirely.
    pub excluded_files: usize,
    /// `total_tokens` as a percentage of the budget left after the overhead
    /// reserve.
    pub budget_used_pct: f32,
}
299
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_truncate_preserves_short_content() {
        let enforcer = BudgetEnforcer::with_budget(10000, TokenModel::Claude);
        let content = "Hello, world!";
        let result = enforcer.truncate_to_tokens(content, 1000);
        assert_eq!(result, content);
    }

    #[test]
    fn test_truncate_adds_indicator() {
        let enforcer = BudgetEnforcer::with_budget(10000, TokenModel::Claude);
        let content = "line1\nline2\nline3\nline4\nline5\nline6\nline7\nline8\nline9\nline10";
        let result = enforcer.truncate_to_tokens(content, 5);
        assert!(result.contains("[truncated]"));
        assert!(result.len() < content.len());
    }

    #[test]
    fn test_find_char_boundary() {
        let enforcer = BudgetEnforcer::with_budget(10000, TokenModel::Claude);
        // Byte 8 falls inside the first multi-byte character, which starts
        // at byte 7, so the position must be rounded down to 7.
        let content = "Hello, 世界!";
        let boundary = enforcer.find_char_boundary(content, 8);
        assert!(content.is_char_boundary(boundary));
        assert_eq!(boundary, 7);
    }

    #[test]
    fn test_semantic_boundary_line() {
        let config = BudgetConfig { strategy: TruncationStrategy::Line, ..Default::default() };
        let enforcer = BudgetEnforcer::new(config);
        let content = "line1\nline2\nline3";
        let boundary = enforcer.find_semantic_boundary(content, 10);
        assert_eq!(boundary, 6);
    }

    #[test]
    fn test_semantic_boundary_function() {
        let config = BudgetConfig { strategy: TruncationStrategy::Semantic, ..Default::default() };
        let enforcer = BudgetEnforcer::new(config);
        let content = "fn foo() {}\n\ndef bar():\n pass";
        // A position inside the content is required: `content.len()` would
        // take the early return and never reach the Semantic strategy.
        let boundary = enforcer.find_semantic_boundary(content, content.len() - 1);
        assert!(boundary > 10);
        // The cut lands on the start of `def bar`, just past the blank line.
        assert_eq!(boundary, 13);
    }

    #[test]
    fn test_empty_content_truncation() {
        let enforcer = BudgetEnforcer::with_budget(10000, TokenModel::Claude);
        let result = enforcer.truncate_to_tokens("", 100);
        assert_eq!(result, "");
    }

    #[test]
    fn test_single_character_content() {
        let enforcer = BudgetEnforcer::with_budget(10000, TokenModel::Claude);
        let result = enforcer.truncate_to_tokens("x", 100);
        assert_eq!(result, "x");
    }

    #[test]
    fn test_zero_budget_truncation() {
        let enforcer = BudgetEnforcer::with_budget(10000, TokenModel::Claude);
        let content = "Some content that will be truncated";
        let result = enforcer.truncate_to_tokens(content, 0);
        assert!(result.len() <= content.len());
    }

    #[test]
    fn test_unicode_boundary_preservation() {
        let enforcer = BudgetEnforcer::with_budget(10000, TokenModel::Claude);
        // Mixes 1-, 3- and 4-byte UTF-8 sequences.
        let content = "Hello 世界! More text here. 🦀 Rust! Even more...";

        for budget in [1, 2, 3, 5, 10] {
            // Slicing inside a character would panic before we get here.
            let result = enforcer.truncate_to_tokens(content, budget);
            let _ = result.chars().count();
            assert!(std::str::from_utf8(result.as_bytes()).is_ok());
        }
    }

    #[test]
    fn test_content_smaller_than_indicator() {
        let enforcer = BudgetEnforcer::with_budget(10000, TokenModel::Claude);
        let content = "Hi";
        let result = enforcer.truncate_to_tokens(content, 1);
        assert!(!result.is_empty() || content.is_empty());
    }

    #[test]
    fn test_hard_truncation_strategy() {
        let config = BudgetConfig { strategy: TruncationStrategy::Hard, ..Default::default() };
        let enforcer = BudgetEnforcer::new(config);
        let content = "line1\nline2\nline3";
        let boundary = enforcer.find_semantic_boundary(content, 10);
        assert_eq!(boundary, 10);
    }

    #[test]
    fn test_boundary_at_start() {
        let enforcer = BudgetEnforcer::with_budget(10000, TokenModel::Claude);
        let content = "Some content";
        let boundary = enforcer.find_semantic_boundary(content, 0);
        assert_eq!(boundary, 0);
    }

    #[test]
    fn test_boundary_past_end() {
        let enforcer = BudgetEnforcer::with_budget(10000, TokenModel::Claude);
        let content = "Some content";
        // Positions past the end are handed back unchanged, not clamped.
        let boundary = enforcer.find_semantic_boundary(content, content.len() + 10);
        assert_eq!(boundary, content.len() + 10);
    }

    #[test]
    fn test_function_boundary_rust_patterns() {
        let config = BudgetConfig { strategy: TruncationStrategy::Semantic, ..Default::default() };
        let enforcer = BudgetEnforcer::new(config);

        let content = "use std::io;\n\nfn helper() {}\n\npub fn main() {}";
        let boundary = enforcer.find_function_boundary(content);
        assert!(boundary.is_some());

        let content2 = "struct Foo;\n\nimpl Foo {\n fn new() {}\n}";
        let boundary2 = enforcer.find_function_boundary(content2);
        assert!(boundary2.is_some());
    }

    #[test]
    fn test_function_boundary_python_patterns() {
        let config = BudgetConfig { strategy: TruncationStrategy::Semantic, ..Default::default() };
        let enforcer = BudgetEnforcer::new(config);

        let content = "import os\n\n@decorator\ndef foo():\n pass";
        let boundary = enforcer.find_function_boundary(content);
        assert!(boundary.is_some());

        let content2 = "import sys\n\nclass MyClass:\n pass";
        let boundary2 = enforcer.find_function_boundary(content2);
        assert!(boundary2.is_some());
    }

    #[test]
    fn test_function_boundary_javascript_patterns() {
        let config = BudgetConfig { strategy: TruncationStrategy::Semantic, ..Default::default() };
        let enforcer = BudgetEnforcer::new(config);

        let content = "const x = 1;\n\nfunction foo() {}\n\nasync function bar() {}";
        let boundary = enforcer.find_function_boundary(content);
        assert!(boundary.is_some());
    }

    #[test]
    fn test_no_function_boundary_found() {
        let config = BudgetConfig { strategy: TruncationStrategy::Semantic, ..Default::default() };
        let enforcer = BudgetEnforcer::new(config);

        let content = "just some text without any code patterns";
        let boundary = enforcer.find_function_boundary(content);
        assert!(boundary.is_none());
    }

    #[test]
    fn test_enforcement_result_fields() {
        let result = EnforcementResult {
            total_tokens: 1000,
            truncated_files: 5,
            excluded_files: 2,
            budget_used_pct: 85.5,
        };

        assert_eq!(result.total_tokens, 1000);
        assert_eq!(result.truncated_files, 5);
        assert_eq!(result.excluded_files, 2);
        assert!((result.budget_used_pct - 85.5).abs() < 0.01);
    }

    #[test]
    fn test_budget_config_default() {
        let config = BudgetConfig::default();
        assert_eq!(config.budget, 100_000);
        assert!(matches!(config.strategy, TruncationStrategy::Line));
        assert_eq!(config.overhead_reserve, 1000);
    }
}