1use crate::soch_ql::SochValue;
48use std::sync::atomic::{AtomicUsize, Ordering};
49
/// Tuning knobs that turn byte and digit counts into token estimates.
///
/// The `*_factor` fields scale the textual length of a value before it is
/// divided by `bytes_per_token`; the `*_tokens` fields are fixed costs added
/// per separator, per row and per table header.
#[derive(Clone, Debug)]
pub struct TokenEstimatorConfig {
    /// Multiplier applied to the character count of integer values.
    pub int_factor: f32,
    /// Multiplier applied to the length of a float rendered with two decimals.
    pub float_factor: f32,
    /// Multiplier applied to the byte length of text values.
    pub string_factor: f32,
    /// Multiplier applied to the hex-rendered length of binary values.
    pub hex_factor: f32,
    /// Average number of bytes assumed to make up one token.
    pub bytes_per_token: f32,
    /// Multiplier applied to raw estimates so they err on the high side.
    pub safety_margin: f32,
    /// Tokens charged for each separator between the values of a row.
    pub separator_tokens: usize,
    /// Tokens charged once per row for the line terminator.
    pub newline_tokens: usize,
    /// Fixed token cost charged for every table header.
    pub header_tokens: usize,
}

impl Default for TokenEstimatorConfig {
    /// Returns the baseline configuration (4 bytes per token, 15% margin).
    fn default() -> Self {
        Self::BASELINE
    }
}

impl TokenEstimatorConfig {
    /// Baseline values shared by [`Default`] and used as the starting point of
    /// every preset below.
    const BASELINE: Self = Self {
        int_factor: 1.0,
        float_factor: 1.2,
        string_factor: 1.1,
        hex_factor: 2.5,
        bytes_per_token: 4.0,
        safety_margin: 1.15,
        separator_tokens: 1,
        newline_tokens: 1,
        header_tokens: 10,
    };

    /// Preset named for GPT-4: baseline values with 3.8 bytes per token.
    pub fn gpt4() -> Self {
        Self {
            bytes_per_token: 3.8,
            safety_margin: 1.15,
            ..Self::BASELINE
        }
    }

    /// Preset named for Claude: baseline values with 4.2 bytes per token.
    pub fn claude() -> Self {
        Self {
            bytes_per_token: 4.2,
            safety_margin: 1.15,
            ..Self::BASELINE
        }
    }

    /// Pessimistic preset: larger factors, fewer bytes per token and a 25%
    /// safety margin, so every estimate comes out higher than the baseline.
    pub fn conservative() -> Self {
        Self {
            int_factor: 1.2,
            float_factor: 1.4,
            string_factor: 1.3,
            hex_factor: 3.0,
            bytes_per_token: 3.5,
            safety_margin: 1.25,
            ..Self::BASELINE
        }
    }
}
137
138pub struct TokenEstimator {
140 config: TokenEstimatorConfig,
141}
142
143impl TokenEstimator {
144 pub fn new() -> Self {
146 Self {
147 config: TokenEstimatorConfig::default(),
148 }
149 }
150
151 pub fn with_config(config: TokenEstimatorConfig) -> Self {
153 Self { config }
154 }
155
156 pub fn estimate_value(&self, value: &SochValue) -> usize {
161 let raw = self.estimate_value_raw(value);
162 ((raw as f32) * self.config.safety_margin).ceil() as usize
163 }
164
165 fn estimate_value_raw(&self, value: &SochValue) -> usize {
167 match value {
168 SochValue::Null => 1,
169 SochValue::Bool(_) => 1, SochValue::Int(n) => {
171 let digits = if *n == 0 {
173 1
174 } else {
175 ((*n).abs() as f64).log10().ceil() as usize + if *n < 0 { 1 } else { 0 }
176 };
177 ((digits as f32 * self.config.int_factor) / self.config.bytes_per_token).ceil()
178 as usize
179 }
180 SochValue::UInt(n) => {
181 let digits = if *n == 0 {
182 1
183 } else {
184 ((*n as f64).log10().ceil() as usize).max(1)
185 };
186 ((digits as f32 * self.config.int_factor) / self.config.bytes_per_token).ceil()
187 as usize
188 }
189 SochValue::Float(f) => {
190 let s = format!("{:.2}", f);
192 ((s.len() as f32 * self.config.float_factor) / self.config.bytes_per_token).ceil()
193 as usize
194 }
195 SochValue::Text(s) => {
196 ((s.len() as f32 * self.config.string_factor) / self.config.bytes_per_token).ceil()
198 as usize
199 }
200 SochValue::Binary(b) => {
201 let hex_len = 2 + b.len() * 2;
203 ((hex_len as f32 * self.config.hex_factor) / self.config.bytes_per_token).ceil()
204 as usize
205 }
206 SochValue::Array(arr) => {
207 let elem_tokens: usize = arr.iter().map(|v| self.estimate_value(v)).sum();
209 let separator_tokens = if arr.is_empty() { 0 } else { arr.len() - 1 };
210 2 + elem_tokens + separator_tokens }
212 }
213 }
214
215 pub fn estimate_row(&self, values: &[SochValue]) -> usize {
217 if values.is_empty() {
218 return 0;
219 }
220
221 let value_tokens: usize = values.iter().map(|v| self.estimate_value(v)).sum();
222 let separator_tokens = (values.len() - 1) * self.config.separator_tokens;
223 let newline = self.config.newline_tokens;
224
225 value_tokens + separator_tokens + newline
226 }
227
228 pub fn estimate_header(&self, table: &str, columns: &[String], row_count: usize) -> usize {
230 let base = self.config.header_tokens;
232 let table_tokens = ((table.len() as f32) / self.config.bytes_per_token).ceil() as usize;
233 let count_tokens = ((row_count as f64).log10().ceil() as usize).max(1);
234 let col_tokens: usize = columns
235 .iter()
236 .map(|c| ((c.len() as f32) / self.config.bytes_per_token).ceil() as usize)
237 .sum();
238
239 base + table_tokens + count_tokens + col_tokens
240 }
241
242 pub fn estimate_table(
244 &self,
245 table: &str,
246 columns: &[String],
247 rows: &[Vec<SochValue>],
248 ) -> usize {
249 let header = self.estimate_header(table, columns, rows.len());
250 let row_tokens: usize = rows.iter().map(|r| self.estimate_row(r)).sum();
251 header + row_tokens
252 }
253
254 pub fn estimate_text(&self, text: &str) -> usize {
256 let raw = ((text.len() as f32) / self.config.bytes_per_token).ceil() as usize;
257 ((raw as f32) * self.config.safety_margin).ceil() as usize
258 }
259
260 pub fn truncate_to_tokens(&self, text: &str, max_tokens: usize) -> String {
264 truncate_to_tokens(text, max_tokens, self, "...")
265 }
266}
267
268impl Default for TokenEstimator {
269 fn default() -> Self {
270 Self::new()
271 }
272}
273
274#[derive(Debug, Clone)]
280pub struct BudgetAllocation {
281 pub full_sections: Vec<String>,
283 pub truncated_sections: Vec<(String, usize, usize)>,
285 pub dropped_sections: Vec<String>,
287 pub tokens_allocated: usize,
289 pub tokens_remaining: usize,
291 pub explain: Vec<AllocationDecision>,
293}
294
295#[derive(Debug, Clone)]
297pub struct AllocationDecision {
298 pub section: String,
300 pub priority: i32,
302 pub requested: usize,
304 pub allocated: usize,
306 pub outcome: AllocationOutcome,
308 pub reason: String,
310}
311
/// Final state of a section after budget allocation.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum AllocationOutcome {
    /// The section received everything it requested.
    Full,
    /// The section received fewer tokens than it requested.
    Truncated,
    /// The section received no tokens.
    Dropped,
}
322
/// A named piece of content competing for a share of the token budget.
#[derive(Clone, Debug)]
pub struct BudgetSection {
    /// Name used to identify the section in allocation results.
    pub name: String,
    /// Ordering key; sections with a lower value are served first.
    pub priority: i32,
    /// Tokens the section would need to be included in full.
    pub estimated_tokens: usize,
    /// Smallest allocation the section accepts when truncated; `None` means
    /// no minimum was given.
    pub minimum_tokens: Option<usize>,
    /// Marks the section as required; the strict strategy serves required
    /// sections before all others.
    pub required: bool,
    /// Relative share used by the proportional strategy.
    pub weight: f32,
}

impl Default for BudgetSection {
    /// An unnamed, optional section with priority 0, no tokens requested, no
    /// minimum and a weight of 1.0.
    fn default() -> Self {
        BudgetSection {
            name: String::new(),
            priority: 0,
            estimated_tokens: 0,
            minimum_tokens: None,
            required: false,
            weight: 1.0,
        }
    }
}
352
/// Algorithm used to divide the budget between sections.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq)]
pub enum AllocationStrategy {
    /// Serve sections in priority order, giving each as much as still fits.
    /// This is the default strategy.
    #[default]
    GreedyPriority,
    /// Split the budget between sections according to their weights.
    Proportional,
    /// Serve every required section first, then the optional ones by priority.
    StrictPriority,
}
364
365pub struct TokenBudgetEnforcer {
370 budget: usize,
372 allocated: AtomicUsize,
374 estimator: TokenEstimator,
376 reserved: usize,
378 strategy: AllocationStrategy,
380}
381
382#[derive(Debug, Clone)]
384pub struct TokenBudgetConfig {
385 pub total_budget: usize,
387 pub reserved_tokens: usize,
389 pub strict: bool,
391 pub default_priority: i32,
393 pub strategy: AllocationStrategy,
395}
396
397impl Default for TokenBudgetConfig {
398 fn default() -> Self {
399 Self {
400 total_budget: 4096,
401 reserved_tokens: 100,
402 strict: false,
403 default_priority: 10,
404 strategy: AllocationStrategy::GreedyPriority,
405 }
406 }
407}
408
409impl TokenBudgetEnforcer {
410 pub fn new(config: TokenBudgetConfig) -> Self {
412 Self {
413 budget: config.total_budget,
414 allocated: AtomicUsize::new(0),
415 estimator: TokenEstimator::new(),
416 reserved: config.reserved_tokens,
417 strategy: config.strategy,
418 }
419 }
420
421 pub fn with_budget(budget: usize) -> Self {
423 Self {
424 budget,
425 allocated: AtomicUsize::new(0),
426 estimator: TokenEstimator::new(),
427 reserved: 0,
428 strategy: AllocationStrategy::GreedyPriority,
429 }
430 }
431
432 pub fn with_estimator(budget: usize, estimator: TokenEstimator) -> Self {
434 Self {
435 budget,
436 allocated: AtomicUsize::new(0),
437 estimator,
438 reserved: 0,
439 strategy: AllocationStrategy::GreedyPriority,
440 }
441 }
442
443 pub fn with_strategy(mut self, strategy: AllocationStrategy) -> Self {
445 self.strategy = strategy;
446 self
447 }
448
449 pub fn reserve(&mut self, tokens: usize) {
451 self.reserved = tokens;
452 }
453
454 pub fn available(&self) -> usize {
456 let allocated = self.allocated.load(Ordering::Acquire);
457 self.budget.saturating_sub(self.reserved + allocated)
458 }
459
460 pub fn total_budget(&self) -> usize {
462 self.budget
463 }
464
465 pub fn allocated(&self) -> usize {
467 self.allocated.load(Ordering::Acquire)
468 }
469
470 pub fn try_allocate(&self, tokens: usize) -> bool {
472 loop {
473 let current = self.allocated.load(Ordering::Acquire);
474 let new_total = current + tokens;
475
476 if new_total + self.reserved > self.budget {
477 return false;
478 }
479
480 if self
481 .allocated
482 .compare_exchange(current, new_total, Ordering::AcqRel, Ordering::Acquire)
483 .is_ok()
484 {
485 return true;
486 }
487 }
489 }
490
491 pub fn allocate_sections(&self, sections: &[BudgetSection]) -> BudgetAllocation {
493 match self.strategy {
494 AllocationStrategy::GreedyPriority => self.allocate_greedy(sections),
495 AllocationStrategy::Proportional => self.allocate_proportional(sections),
496 AllocationStrategy::StrictPriority => self.allocate_strict(sections),
497 }
498 }
499
500 fn allocate_greedy(&self, sections: &[BudgetSection]) -> BudgetAllocation {
502 let mut sorted: Vec<_> = sections.iter().collect();
504 sorted.sort_by_key(|s| s.priority);
505
506 let mut allocation = BudgetAllocation {
507 full_sections: Vec::new(),
508 truncated_sections: Vec::new(),
509 dropped_sections: Vec::new(),
510 tokens_allocated: 0,
511 tokens_remaining: self.budget.saturating_sub(self.reserved),
512 explain: Vec::new(),
513 };
514
515 for section in sorted {
516 let remaining = allocation.tokens_remaining;
517
518 if section.estimated_tokens <= remaining {
519 allocation.full_sections.push(section.name.clone());
521 allocation.tokens_allocated += section.estimated_tokens;
522 allocation.tokens_remaining -= section.estimated_tokens;
523 allocation.explain.push(AllocationDecision {
524 section: section.name.clone(),
525 priority: section.priority,
526 requested: section.estimated_tokens,
527 allocated: section.estimated_tokens,
528 outcome: AllocationOutcome::Full,
529 reason: format!("Fits in remaining budget ({} tokens)", remaining),
530 });
531 } else if let Some(min) = section.minimum_tokens {
532 if min <= remaining {
534 let truncated_to = remaining;
535 allocation.truncated_sections.push((
536 section.name.clone(),
537 section.estimated_tokens,
538 truncated_to,
539 ));
540 allocation.tokens_allocated += truncated_to;
541 allocation.explain.push(AllocationDecision {
542 section: section.name.clone(),
543 priority: section.priority,
544 requested: section.estimated_tokens,
545 allocated: truncated_to,
546 outcome: AllocationOutcome::Truncated,
547 reason: format!(
548 "Truncated from {} to {} tokens (min: {})",
549 section.estimated_tokens, truncated_to, min
550 ),
551 });
552 allocation.tokens_remaining = 0;
553 } else {
554 allocation.dropped_sections.push(section.name.clone());
555 allocation.explain.push(AllocationDecision {
556 section: section.name.clone(),
557 priority: section.priority,
558 requested: section.estimated_tokens,
559 allocated: 0,
560 outcome: AllocationOutcome::Dropped,
561 reason: format!("Minimum {} exceeds remaining {} tokens", min, remaining),
562 });
563 }
564 } else {
565 allocation.dropped_sections.push(section.name.clone());
567 allocation.explain.push(AllocationDecision {
568 section: section.name.clone(),
569 priority: section.priority,
570 requested: section.estimated_tokens,
571 allocated: 0,
572 outcome: AllocationOutcome::Dropped,
573 reason: format!(
574 "Requested {} exceeds remaining {} (no truncation allowed)",
575 section.estimated_tokens, remaining
576 ),
577 });
578 }
579 }
580
581 allocation
582 }
583
584 fn allocate_proportional(&self, sections: &[BudgetSection]) -> BudgetAllocation {
591 let available = self.budget.saturating_sub(self.reserved);
592 let total_weight: f32 = sections.iter().map(|s| s.weight).sum();
593
594 if total_weight == 0.0 {
595 return self.allocate_greedy(sections);
596 }
597
598 let mut allocation = BudgetAllocation {
599 full_sections: Vec::new(),
600 truncated_sections: Vec::new(),
601 dropped_sections: Vec::new(),
602 tokens_allocated: 0,
603 tokens_remaining: available,
604 explain: Vec::new(),
605 };
606
607 let mut allocations: Vec<(usize, usize, bool)> = sections
609 .iter()
610 .map(|s| {
611 let proportional = ((available as f32) * s.weight / total_weight).floor() as usize;
612 let capped = proportional.min(s.estimated_tokens);
613 let min = s.minimum_tokens.unwrap_or(0);
614 (
615 capped.max(min),
616 s.estimated_tokens,
617 capped < s.estimated_tokens,
618 )
619 })
620 .collect();
621
622 let mut total: usize = allocations.iter().map(|(a, _, _)| *a).sum();
624
625 while total > available {
627 let max_idx = allocations
629 .iter()
630 .enumerate()
631 .filter(|(i, (a, _, _))| *a > sections[*i].minimum_tokens.unwrap_or(0))
632 .max_by_key(|(_, (a, _, _))| *a)
633 .map(|(i, _)| i);
634
635 match max_idx {
636 Some(idx) => {
637 let reduce = (total - available)
638 .min(allocations[idx].0 - sections[idx].minimum_tokens.unwrap_or(0));
639 allocations[idx].0 -= reduce;
640 total -= reduce;
641 }
642 None => break, }
644 }
645
646 for (i, section) in sections.iter().enumerate() {
648 let (allocated, requested, truncated) = allocations[i];
649
650 if allocated == 0 {
651 allocation.dropped_sections.push(section.name.clone());
652 allocation.explain.push(AllocationDecision {
653 section: section.name.clone(),
654 priority: section.priority,
655 requested,
656 allocated: 0,
657 outcome: AllocationOutcome::Dropped,
658 reason: "No budget available after proportional allocation".to_string(),
659 });
660 } else if truncated {
661 allocation
662 .truncated_sections
663 .push((section.name.clone(), requested, allocated));
664 allocation.tokens_allocated += allocated;
665 allocation.tokens_remaining = allocation.tokens_remaining.saturating_sub(allocated);
666 allocation.explain.push(AllocationDecision {
667 section: section.name.clone(),
668 priority: section.priority,
669 requested,
670 allocated,
671 outcome: AllocationOutcome::Truncated,
672 reason: format!(
673 "Proportional allocation: {:.1}% of budget (weight {:.1})",
674 (allocated as f32 / available as f32) * 100.0,
675 section.weight
676 ),
677 });
678 } else {
679 allocation.full_sections.push(section.name.clone());
680 allocation.tokens_allocated += allocated;
681 allocation.tokens_remaining = allocation.tokens_remaining.saturating_sub(allocated);
682 allocation.explain.push(AllocationDecision {
683 section: section.name.clone(),
684 priority: section.priority,
685 requested,
686 allocated,
687 outcome: AllocationOutcome::Full,
688 reason: format!(
689 "Full allocation within proportional budget (weight {:.1})",
690 section.weight
691 ),
692 });
693 }
694 }
695
696 allocation
697 }
698
699 fn allocate_strict(&self, sections: &[BudgetSection]) -> BudgetAllocation {
701 let mut sorted: Vec<_> = sections.iter().collect();
702 sorted.sort_by_key(|s| (if s.required { 0 } else { 1 }, s.priority));
703
704 let mut allocation = BudgetAllocation {
706 full_sections: Vec::new(),
707 truncated_sections: Vec::new(),
708 dropped_sections: Vec::new(),
709 tokens_allocated: 0,
710 tokens_remaining: self.budget.saturating_sub(self.reserved),
711 explain: Vec::new(),
712 };
713
714 for section in sorted.iter().filter(|s| s.required) {
716 let remaining = allocation.tokens_remaining;
717 let min = section.minimum_tokens.unwrap_or(section.estimated_tokens);
718
719 if section.estimated_tokens <= remaining {
720 allocation.full_sections.push(section.name.clone());
721 allocation.tokens_allocated += section.estimated_tokens;
722 allocation.tokens_remaining -= section.estimated_tokens;
723 allocation.explain.push(AllocationDecision {
724 section: section.name.clone(),
725 priority: section.priority,
726 requested: section.estimated_tokens,
727 allocated: section.estimated_tokens,
728 outcome: AllocationOutcome::Full,
729 reason: "Required section - full allocation".to_string(),
730 });
731 } else if min <= remaining {
732 allocation.truncated_sections.push((
733 section.name.clone(),
734 section.estimated_tokens,
735 remaining,
736 ));
737 allocation.tokens_allocated += remaining;
738 allocation.explain.push(AllocationDecision {
739 section: section.name.clone(),
740 priority: section.priority,
741 requested: section.estimated_tokens,
742 allocated: remaining,
743 outcome: AllocationOutcome::Truncated,
744 reason: "Required section - truncated to fit".to_string(),
745 });
746 allocation.tokens_remaining = 0;
747 }
748 }
750
751 for section in sorted.iter().filter(|s| !s.required) {
753 let remaining = allocation.tokens_remaining;
754
755 if remaining == 0 {
756 allocation.dropped_sections.push(section.name.clone());
757 allocation.explain.push(AllocationDecision {
758 section: section.name.clone(),
759 priority: section.priority,
760 requested: section.estimated_tokens,
761 allocated: 0,
762 outcome: AllocationOutcome::Dropped,
763 reason: "No budget remaining after required sections".to_string(),
764 });
765 continue;
766 }
767
768 if section.estimated_tokens <= remaining {
769 allocation.full_sections.push(section.name.clone());
770 allocation.tokens_allocated += section.estimated_tokens;
771 allocation.tokens_remaining -= section.estimated_tokens;
772 allocation.explain.push(AllocationDecision {
773 section: section.name.clone(),
774 priority: section.priority,
775 requested: section.estimated_tokens,
776 allocated: section.estimated_tokens,
777 outcome: AllocationOutcome::Full,
778 reason: "Optional section - fits in remaining budget".to_string(),
779 });
780 } else if let Some(min) = section.minimum_tokens {
781 if min <= remaining {
782 allocation.truncated_sections.push((
783 section.name.clone(),
784 section.estimated_tokens,
785 remaining,
786 ));
787 allocation.tokens_allocated += remaining;
788 allocation.explain.push(AllocationDecision {
789 section: section.name.clone(),
790 priority: section.priority,
791 requested: section.estimated_tokens,
792 allocated: remaining,
793 outcome: AllocationOutcome::Truncated,
794 reason: "Optional section - truncated to fit".to_string(),
795 });
796 allocation.tokens_remaining = 0;
797 } else {
798 allocation.dropped_sections.push(section.name.clone());
799 allocation.explain.push(AllocationDecision {
800 section: section.name.clone(),
801 priority: section.priority,
802 requested: section.estimated_tokens,
803 allocated: 0,
804 outcome: AllocationOutcome::Dropped,
805 reason: format!("Minimum {} exceeds remaining {}", min, remaining),
806 });
807 }
808 } else {
809 allocation.dropped_sections.push(section.name.clone());
810 allocation.explain.push(AllocationDecision {
811 section: section.name.clone(),
812 priority: section.priority,
813 requested: section.estimated_tokens,
814 allocated: 0,
815 outcome: AllocationOutcome::Dropped,
816 reason: format!(
817 "Requested {} exceeds remaining {}",
818 section.estimated_tokens, remaining
819 ),
820 });
821 }
822 }
823
824 allocation
825 }
826
827 pub fn reset(&self) {
829 self.allocated.store(0, Ordering::Release);
830 }
831
832 pub fn estimator(&self) -> &TokenEstimator {
834 &self.estimator
835 }
836}
837
838impl BudgetAllocation {
843 pub fn explain_text(&self) -> String {
845 let mut output = String::new();
846 output.push_str("=== CONTEXT BUDGET ALLOCATION ===\n\n");
847 output.push_str(&format!(
848 "Total Allocated: {} tokens\n",
849 self.tokens_allocated
850 ));
851 output.push_str(&format!("Remaining: {} tokens\n\n", self.tokens_remaining));
852
853 output.push_str("SECTIONS:\n");
854 for decision in &self.explain {
855 let status = match decision.outcome {
856 AllocationOutcome::Full => "✓ FULL",
857 AllocationOutcome::Truncated => "◐ TRUNCATED",
858 AllocationOutcome::Dropped => "✗ DROPPED",
859 };
860 output.push_str(&format!(
861 " [{:^12}] {} (priority {})\n",
862 status, decision.section, decision.priority
863 ));
864 output.push_str(&format!(
865 " Requested: {}, Allocated: {}\n",
866 decision.requested, decision.allocated
867 ));
868 output.push_str(&format!(" Reason: {}\n", decision.reason));
869 }
870
871 output
872 }
873
874 pub fn explain_json(&self) -> String {
876 serde_json::to_string_pretty(&ExplainOutput {
877 tokens_allocated: self.tokens_allocated,
878 tokens_remaining: self.tokens_remaining,
879 full_sections: self.full_sections.clone(),
880 truncated_sections: self.truncated_sections.clone(),
881 dropped_sections: self.dropped_sections.clone(),
882 decisions: self
883 .explain
884 .iter()
885 .map(|d| ExplainDecision {
886 section: d.section.clone(),
887 priority: d.priority,
888 requested: d.requested,
889 allocated: d.allocated,
890 outcome: format!("{:?}", d.outcome),
891 reason: d.reason.clone(),
892 })
893 .collect(),
894 })
895 .unwrap_or_else(|_| "{}".to_string())
896 }
897}
898
899#[derive(serde::Serialize)]
900struct ExplainOutput {
901 tokens_allocated: usize,
902 tokens_remaining: usize,
903 full_sections: Vec<String>,
904 truncated_sections: Vec<(String, usize, usize)>,
905 dropped_sections: Vec<String>,
906 decisions: Vec<ExplainDecision>,
907}
908
909#[derive(serde::Serialize)]
910struct ExplainDecision {
911 section: String,
912 priority: i32,
913 requested: usize,
914 allocated: usize,
915 outcome: String,
916 reason: String,
917}
918
919pub fn truncate_to_tokens(
925 text: &str,
926 max_tokens: usize,
927 estimator: &TokenEstimator,
928 suffix: &str,
929) -> String {
930 let current = estimator.estimate_text(text);
931
932 if current <= max_tokens {
933 return text.to_string();
934 }
935
936 let suffix_tokens = estimator.estimate_text(suffix);
937 let target_tokens = max_tokens.saturating_sub(suffix_tokens);
938
939 if target_tokens == 0 {
940 return suffix.to_string();
941 }
942
943 let mut low = 0;
945 let mut high = text.len();
946
947 while low < high {
948 let mid = (low + high).div_ceil(2);
949
950 let boundary = text
952 .char_indices()
953 .take_while(|(i, _)| *i < mid)
954 .last()
955 .map(|(i, c)| i + c.len_utf8())
956 .unwrap_or(0);
957
958 let truncated = &text[..boundary];
959 let tokens = estimator.estimate_text(truncated);
960
961 if tokens <= target_tokens {
962 low = boundary;
963 } else {
964 high = boundary.saturating_sub(1);
965 }
966 }
967
968 let truncated = &text[..low];
970 let word_boundary = truncated.rfind(|c: char| c.is_whitespace()).unwrap_or(low);
971
972 format!("{}{}", &text[..word_boundary], suffix)
973}
974
975pub fn truncate_rows(
977 rows: &[Vec<SochValue>],
978 max_tokens: usize,
979 estimator: &TokenEstimator,
980) -> Vec<Vec<SochValue>> {
981 let mut result = Vec::new();
982 let mut used = 0;
983
984 for row in rows {
985 let row_tokens = estimator.estimate_row(row);
986
987 if used + row_tokens <= max_tokens {
988 result.push(row.clone());
989 used += row_tokens;
990 } else {
991 break; }
993 }
994
995 result
996}
997
/// Unit tests for the token estimator, the budget enforcer and the truncation
/// helpers.
#[cfg(test)]
mod tests {
    use super::*;

    /// Integer estimates are never zero and do not shrink as numbers grow.
    #[test]
    fn test_estimate_value_int() {
        let est = TokenEstimator::new();

        assert!(est.estimate_value(&SochValue::Int(0)) >= 1);
        assert!(est.estimate_value(&SochValue::Int(42)) >= 1);

        let small = est.estimate_value(&SochValue::Int(42));
        let large = est.estimate_value(&SochValue::Int(1_000_000_000));
        assert!(large >= small);
    }

    /// Longer text yields a larger estimate.
    #[test]
    fn test_estimate_value_text() {
        let est = TokenEstimator::new();

        let short = est.estimate_value(&SochValue::Text("hello".to_string()));
        let long = est.estimate_value(&SochValue::Text(
            "hello world this is a longer string".to_string(),
        ));

        assert!(long > short);
    }

    /// A three-value row costs at least one token per value.
    #[test]
    // 3.14 is just sample data here, not an approximation of pi.
    #[allow(clippy::approx_constant)]
    fn test_estimate_row() {
        let est = TokenEstimator::new();

        let row = vec![
            SochValue::Int(1),
            SochValue::Text("Alice".to_string()),
            SochValue::Float(3.14),
        ];

        let tokens = est.estimate_row(&row);

        assert!(tokens >= 3);
    }

    /// A table costs more than its rows alone because of the header.
    #[test]
    fn test_estimate_table() {
        let est = TokenEstimator::new();

        let columns = vec!["id".to_string(), "name".to_string()];
        let rows = vec![
            vec![SochValue::Int(1), SochValue::Text("Alice".to_string())],
            vec![SochValue::Int(2), SochValue::Text("Bob".to_string())],
        ];

        let tokens = est.estimate_table("users", &columns, &rows);

        assert!(tokens > est.estimate_row(&rows[0]) * 2);
    }

    /// Claims succeed until the budget is exhausted; a rejected claim leaves
    /// the counter untouched.
    #[test]
    fn test_budget_enforcer_allocation() {
        let enforcer = TokenBudgetEnforcer::with_budget(1000);

        assert!(enforcer.try_allocate(500));
        assert_eq!(enforcer.allocated(), 500);
        assert_eq!(enforcer.available(), 500);

        assert!(enforcer.try_allocate(400));
        assert_eq!(enforcer.allocated(), 900);

        assert!(!enforcer.try_allocate(200));
        assert_eq!(enforcer.allocated(), 900);
    }

    /// `reset` clears everything claimed so far.
    #[test]
    fn test_budget_enforcer_reset() {
        let enforcer = TokenBudgetEnforcer::with_budget(1000);

        enforcer.try_allocate(800);
        assert_eq!(enforcer.allocated(), 800);

        enforcer.reset();
        assert_eq!(enforcer.allocated(), 0);
    }

    /// With the default strategy, sections that fit are kept and a section
    /// without a minimum that no longer fits is dropped.
    #[test]
    fn test_allocate_sections() {
        let enforcer = TokenBudgetEnforcer::with_budget(1000);

        let sections = vec![
            BudgetSection {
                name: "A".to_string(),
                priority: 0,
                estimated_tokens: 300,
                minimum_tokens: None,
                required: true,
                weight: 1.0,
            },
            BudgetSection {
                name: "B".to_string(),
                priority: 1,
                estimated_tokens: 400,
                minimum_tokens: Some(200),
                required: false,
                weight: 1.0,
            },
            BudgetSection {
                name: "C".to_string(),
                priority: 2,
                estimated_tokens: 500,
                minimum_tokens: None,
                required: false,
                weight: 1.0,
            },
        ];

        let allocation = enforcer.allocate_sections(&sections);

        assert!(allocation.full_sections.contains(&"A".to_string()));

        assert!(allocation.dropped_sections.contains(&"C".to_string()));

        assert!(allocation.tokens_allocated <= 1000);
    }

    /// Sections are served by priority value, not by input order.
    #[test]
    fn test_allocate_by_priority() {
        let enforcer = TokenBudgetEnforcer::with_budget(500);

        let sections = vec![
            BudgetSection {
                name: "LowPriority".to_string(),
                priority: 10,
                estimated_tokens: 200,
                minimum_tokens: None,
                required: false,
                weight: 1.0,
            },
            BudgetSection {
                name: "HighPriority".to_string(),
                priority: 0,
                estimated_tokens: 400,
                minimum_tokens: None,
                required: true,
                weight: 1.0,
            },
        ];

        let allocation = enforcer.allocate_sections(&sections);

        assert!(
            allocation
                .full_sections
                .contains(&"HighPriority".to_string())
        );

        assert!(
            allocation
                .dropped_sections
                .contains(&"LowPriority".to_string())
        );
    }

    /// Truncated text is shorter, carries the suffix and fits the budget.
    #[test]
    fn test_truncate_to_tokens() {
        let est = TokenEstimator::new();

        let text = "This is a long text that needs to be truncated to fit within the token budget";
        let truncated = truncate_to_tokens(text, 10, &est, "...");

        assert!(truncated.len() < text.len());

        assert!(truncated.ends_with("..."));

        assert!(est.estimate_text(&truncated) <= 10);
    }

    /// Only as many leading rows as fit in the budget are kept.
    #[test]
    fn test_truncate_rows() {
        let est = TokenEstimator::new();

        let rows: Vec<Vec<SochValue>> = (0..100)
            .map(|i| vec![SochValue::Int(i), SochValue::Text(format!("row{}", i))])
            .collect();

        let truncated = truncate_rows(&rows, 50, &est);

        assert!(truncated.len() < rows.len());

        let total: usize = truncated.iter().map(|r| est.estimate_row(r)).sum();
        assert!(total <= 50);
    }

    /// Reserved tokens are excluded from what can be claimed.
    #[test]
    fn test_reserved_budget() {
        let mut enforcer = TokenBudgetEnforcer::with_budget(1000);
        enforcer.reserve(200);

        assert_eq!(enforcer.available(), 800);

        assert!(enforcer.try_allocate(700));
        assert_eq!(enforcer.available(), 100);

        assert!(!enforcer.try_allocate(200));
    }

    /// Every preset yields a positive estimate and the conservative preset is
    /// never lower than the default.
    #[test]
    fn test_estimator_configs() {
        let default = TokenEstimator::new();
        let gpt4 = TokenEstimator::with_config(TokenEstimatorConfig::gpt4());
        let conservative = TokenEstimator::with_config(TokenEstimatorConfig::conservative());

        let text = "Hello, this is a test string for comparing token estimation across different configurations.";

        let default_est = default.estimate_text(text);
        let gpt4_est = gpt4.estimate_text(text);
        let conservative_est = conservative.estimate_text(text);

        assert!(conservative_est >= default_est);

        assert!(default_est > 0);
        assert!(gpt4_est > 0);
        assert!(conservative_est > 0);
    }

    /// A section with a minimum is truncated into the tokens that remain.
    #[test]
    fn test_section_with_truncation() {
        let enforcer = TokenBudgetEnforcer::with_budget(600);

        let sections = vec![
            BudgetSection {
                name: "Required".to_string(),
                priority: 0,
                estimated_tokens: 500,
                minimum_tokens: None,
                required: true,
                weight: 1.0,
            },
            BudgetSection {
                name: "Optional".to_string(),
                priority: 1,
                estimated_tokens: 300,
                minimum_tokens: Some(50),
                required: false,
                weight: 1.0,
            },
        ];

        let allocation = enforcer.allocate_sections(&sections);

        assert!(allocation.full_sections.contains(&"Required".to_string()));

        assert!(
            allocation
                .truncated_sections
                .iter()
                .any(|(n, _, _)| n == "Optional")
        );
    }
}