1use crate::soch_ql::SochValue;
45use std::sync::atomic::{AtomicUsize, Ordering};
46
/// Tunable parameters for the heuristic token estimator.
///
/// Each `*_factor` scales the rendered length of one kind of value before
/// that length is divided by `bytes_per_token` and rounded up.
#[derive(Debug, Clone)]
pub struct TokenEstimatorConfig {
    /// Multiplier applied to the digit count of signed and unsigned integers.
    pub int_factor: f32,
    /// Multiplier applied to the length of a float rendered with two decimals.
    pub float_factor: f32,
    /// Multiplier applied to the byte length of text values.
    pub string_factor: f32,
    /// Multiplier applied to the rendered length of binary values
    /// (two characters per byte plus two extra characters).
    pub hex_factor: f32,
    /// Divisor that converts a (scaled) byte length into a token count.
    pub bytes_per_token: f32,
    /// Tokens charged for each separator between adjacent values of a row.
    pub separator_tokens: usize,
    /// Tokens charged once per non-empty row for its line terminator.
    pub newline_tokens: usize,
    /// Fixed token overhead added to every table header estimate.
    pub header_tokens: usize,
}
71
72impl Default for TokenEstimatorConfig {
73 fn default() -> Self {
74 Self {
75 int_factor: 1.0,
76 float_factor: 1.2,
77 string_factor: 1.1,
78 hex_factor: 2.5,
79 bytes_per_token: 4.0, separator_tokens: 1,
81 newline_tokens: 1,
82 header_tokens: 10, }
84 }
85}
86
87impl TokenEstimatorConfig {
88 pub fn gpt4() -> Self {
90 Self {
91 bytes_per_token: 3.8,
92 ..Default::default()
93 }
94 }
95
96 pub fn claude() -> Self {
98 Self {
99 bytes_per_token: 4.2,
100 ..Default::default()
101 }
102 }
103
104 pub fn conservative() -> Self {
106 Self {
107 int_factor: 1.2,
108 float_factor: 1.4,
109 string_factor: 1.3,
110 hex_factor: 3.0,
111 bytes_per_token: 3.5,
112 ..Default::default()
113 }
114 }
115}
116
/// Heuristic token counter for values, rows, tables and free text.
///
/// All estimates are derived from rendered lengths and the factors held in
/// `config`; no real tokenizer is involved.
pub struct TokenEstimator {
    /// Factors and overheads used by every estimate.
    config: TokenEstimatorConfig,
}
121
122impl TokenEstimator {
123 pub fn new() -> Self {
125 Self {
126 config: TokenEstimatorConfig::default(),
127 }
128 }
129
130 pub fn with_config(config: TokenEstimatorConfig) -> Self {
132 Self { config }
133 }
134
135 pub fn estimate_value(&self, value: &SochValue) -> usize {
137 match value {
138 SochValue::Null => 1,
139 SochValue::Bool(_) => 1, SochValue::Int(n) => {
141 let digits = if *n == 0 {
143 1
144 } else {
145 ((*n).abs() as f64).log10().ceil() as usize + if *n < 0 { 1 } else { 0 }
146 };
147 ((digits as f32 * self.config.int_factor) / self.config.bytes_per_token).ceil()
148 as usize
149 }
150 SochValue::UInt(n) => {
151 let digits = if *n == 0 {
152 1
153 } else {
154 ((*n as f64).log10().ceil() as usize).max(1)
155 };
156 ((digits as f32 * self.config.int_factor) / self.config.bytes_per_token).ceil()
157 as usize
158 }
159 SochValue::Float(f) => {
160 let s = format!("{:.2}", f);
162 ((s.len() as f32 * self.config.float_factor) / self.config.bytes_per_token).ceil()
163 as usize
164 }
165 SochValue::Text(s) => {
166 ((s.len() as f32 * self.config.string_factor) / self.config.bytes_per_token).ceil()
168 as usize
169 }
170 SochValue::Binary(b) => {
171 let hex_len = 2 + b.len() * 2;
173 ((hex_len as f32 * self.config.hex_factor) / self.config.bytes_per_token).ceil()
174 as usize
175 }
176 SochValue::Array(arr) => {
177 let elem_tokens: usize = arr.iter().map(|v| self.estimate_value(v)).sum();
179 let separator_tokens = if arr.is_empty() { 0 } else { arr.len() - 1 };
180 2 + elem_tokens + separator_tokens }
182 }
183 }
184
185 pub fn estimate_row(&self, values: &[SochValue]) -> usize {
187 if values.is_empty() {
188 return 0;
189 }
190
191 let value_tokens: usize = values.iter().map(|v| self.estimate_value(v)).sum();
192 let separator_tokens = (values.len() - 1) * self.config.separator_tokens;
193 let newline = self.config.newline_tokens;
194
195 value_tokens + separator_tokens + newline
196 }
197
198 pub fn estimate_header(&self, table: &str, columns: &[String], row_count: usize) -> usize {
200 let base = self.config.header_tokens;
202 let table_tokens = ((table.len() as f32) / self.config.bytes_per_token).ceil() as usize;
203 let count_tokens = ((row_count as f64).log10().ceil() as usize).max(1);
204 let col_tokens: usize = columns
205 .iter()
206 .map(|c| ((c.len() as f32) / self.config.bytes_per_token).ceil() as usize)
207 .sum();
208
209 base + table_tokens + count_tokens + col_tokens
210 }
211
212 pub fn estimate_table(
214 &self,
215 table: &str,
216 columns: &[String],
217 rows: &[Vec<SochValue>],
218 ) -> usize {
219 let header = self.estimate_header(table, columns, rows.len());
220 let row_tokens: usize = rows.iter().map(|r| self.estimate_row(r)).sum();
221 header + row_tokens
222 }
223
224 pub fn estimate_text(&self, text: &str) -> usize {
226 ((text.len() as f32) / self.config.bytes_per_token).ceil() as usize
227 }
228
229 pub fn truncate_to_tokens(&self, text: &str, max_tokens: usize) -> String {
233 truncate_to_tokens(text, max_tokens, self, "...")
234 }
235}
236
237impl Default for TokenEstimator {
238 fn default() -> Self {
239 Self::new()
240 }
241}
242
/// Result of distributing a token budget across a set of sections.
#[derive(Debug, Clone)]
pub struct BudgetAllocation {
    /// Names of the sections that received their full estimate.
    pub full_sections: Vec<String>,
    /// Sections that received less than requested, as
    /// `(name, requested tokens, allocated tokens)`.
    pub truncated_sections: Vec<(String, usize, usize)>,
    /// Names of the sections that received nothing.
    pub dropped_sections: Vec<String>,
    /// Total tokens handed out to full and truncated sections.
    pub tokens_allocated: usize,
    /// Tokens of the usable budget left after allocation.
    pub tokens_remaining: usize,
    /// One decision record per processed section, in processing order.
    pub explain: Vec<AllocationDecision>,
}
263
/// Explanation of what happened to one section during allocation.
#[derive(Debug, Clone)]
pub struct AllocationDecision {
    /// Name of the section the decision refers to.
    pub section: String,
    /// Priority copied from the section.
    pub priority: i32,
    /// Tokens the section asked for (its estimate).
    pub requested: usize,
    /// Tokens the section actually received (0 when dropped).
    pub allocated: usize,
    /// Whether the section was kept in full, truncated or dropped.
    pub outcome: AllocationOutcome,
    /// Human-readable justification for the outcome.
    pub reason: String,
}
280
/// Outcome of an allocation decision for a single section.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum AllocationOutcome {
    /// The section received its full estimate.
    Full,
    /// The section received fewer tokens than it requested.
    Truncated,
    /// The section received no tokens.
    Dropped,
}
291
/// A candidate section competing for part of the token budget.
#[derive(Debug, Clone)]
pub struct BudgetSection {
    /// Identifier reported back in the allocation result.
    pub name: String,
    /// Ordering key: sections are sorted ascending, so lower values are
    /// served first.
    pub priority: i32,
    /// Tokens the section needs to be included in full.
    pub estimated_tokens: usize,
    /// Smallest allocation the section accepts when it must be truncated;
    /// `None` means the priority-based strategies drop it rather than
    /// truncate it.
    pub minimum_tokens: Option<usize>,
    /// Whether the strict strategy serves this section before optional ones.
    pub required: bool,
    /// Relative share used by the proportional strategy.
    pub weight: f32,
}
308
309impl Default for BudgetSection {
310 fn default() -> Self {
311 Self {
312 name: String::new(),
313 priority: 0,
314 estimated_tokens: 0,
315 minimum_tokens: None,
316 required: false,
317 weight: 1.0,
318 }
319 }
320}
321
/// Policy used to distribute the budget across sections.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum AllocationStrategy {
    /// Serve sections in ascending priority order, giving each as much as
    /// still fits. This is the default.
    #[default]
    GreedyPriority,
    /// Split the budget according to each section's `weight`.
    Proportional,
    /// Serve every `required` section before any optional one, each group in
    /// ascending priority order.
    StrictPriority,
}
333
/// Tracks a token budget and allocates it to sections.
pub struct TokenBudgetEnforcer {
    /// Total token budget.
    budget: usize,
    /// Tokens handed out so far through `try_allocate`.
    allocated: AtomicUsize,
    /// Estimator exposed through `estimator()`.
    estimator: TokenEstimator,
    /// Tokens withheld from the budget and never handed out.
    reserved: usize,
    /// Policy used by `allocate_sections`.
    strategy: AllocationStrategy,
}
350
/// Construction parameters for [`TokenBudgetEnforcer::new`].
#[derive(Debug, Clone)]
pub struct TokenBudgetConfig {
    /// Total token budget.
    pub total_budget: usize,
    /// Tokens withheld from the budget.
    pub reserved_tokens: usize,
    /// NOTE(review): not read by any code visible in this file — confirm the
    /// intended meaning with its consumers.
    pub strict: bool,
    /// NOTE(review): not read by any code visible in this file — presumably
    /// the priority for sections that do not specify one; confirm.
    pub default_priority: i32,
    /// Policy used when allocating sections.
    pub strategy: AllocationStrategy,
}
365
366impl Default for TokenBudgetConfig {
367 fn default() -> Self {
368 Self {
369 total_budget: 4096,
370 reserved_tokens: 100,
371 strict: false,
372 default_priority: 10,
373 strategy: AllocationStrategy::GreedyPriority,
374 }
375 }
376}
377
378impl TokenBudgetEnforcer {
379 pub fn new(config: TokenBudgetConfig) -> Self {
381 Self {
382 budget: config.total_budget,
383 allocated: AtomicUsize::new(0),
384 estimator: TokenEstimator::new(),
385 reserved: config.reserved_tokens,
386 strategy: config.strategy,
387 }
388 }
389
390 pub fn with_budget(budget: usize) -> Self {
392 Self {
393 budget,
394 allocated: AtomicUsize::new(0),
395 estimator: TokenEstimator::new(),
396 reserved: 0,
397 strategy: AllocationStrategy::GreedyPriority,
398 }
399 }
400
401 pub fn with_estimator(budget: usize, estimator: TokenEstimator) -> Self {
403 Self {
404 budget,
405 allocated: AtomicUsize::new(0),
406 estimator,
407 reserved: 0,
408 strategy: AllocationStrategy::GreedyPriority,
409 }
410 }
411
412 pub fn with_strategy(mut self, strategy: AllocationStrategy) -> Self {
414 self.strategy = strategy;
415 self
416 }
417
418 pub fn reserve(&mut self, tokens: usize) {
420 self.reserved = tokens;
421 }
422
423 pub fn available(&self) -> usize {
425 let allocated = self.allocated.load(Ordering::Acquire);
426 self.budget.saturating_sub(self.reserved + allocated)
427 }
428
429 pub fn total_budget(&self) -> usize {
431 self.budget
432 }
433
434 pub fn allocated(&self) -> usize {
436 self.allocated.load(Ordering::Acquire)
437 }
438
439 pub fn try_allocate(&self, tokens: usize) -> bool {
441 loop {
442 let current = self.allocated.load(Ordering::Acquire);
443 let new_total = current + tokens;
444
445 if new_total + self.reserved > self.budget {
446 return false;
447 }
448
449 if self
450 .allocated
451 .compare_exchange(current, new_total, Ordering::AcqRel, Ordering::Acquire)
452 .is_ok()
453 {
454 return true;
455 }
456 }
458 }
459
460 pub fn allocate_sections(&self, sections: &[BudgetSection]) -> BudgetAllocation {
462 match self.strategy {
463 AllocationStrategy::GreedyPriority => self.allocate_greedy(sections),
464 AllocationStrategy::Proportional => self.allocate_proportional(sections),
465 AllocationStrategy::StrictPriority => self.allocate_strict(sections),
466 }
467 }
468
469 fn allocate_greedy(&self, sections: &[BudgetSection]) -> BudgetAllocation {
471 let mut sorted: Vec<_> = sections.iter().collect();
473 sorted.sort_by_key(|s| s.priority);
474
475 let mut allocation = BudgetAllocation {
476 full_sections: Vec::new(),
477 truncated_sections: Vec::new(),
478 dropped_sections: Vec::new(),
479 tokens_allocated: 0,
480 tokens_remaining: self.budget.saturating_sub(self.reserved),
481 explain: Vec::new(),
482 };
483
484 for section in sorted {
485 let remaining = allocation.tokens_remaining;
486
487 if section.estimated_tokens <= remaining {
488 allocation.full_sections.push(section.name.clone());
490 allocation.tokens_allocated += section.estimated_tokens;
491 allocation.tokens_remaining -= section.estimated_tokens;
492 allocation.explain.push(AllocationDecision {
493 section: section.name.clone(),
494 priority: section.priority,
495 requested: section.estimated_tokens,
496 allocated: section.estimated_tokens,
497 outcome: AllocationOutcome::Full,
498 reason: format!("Fits in remaining budget ({} tokens)", remaining),
499 });
500 } else if let Some(min) = section.minimum_tokens {
501 if min <= remaining {
503 let truncated_to = remaining;
504 allocation.truncated_sections.push((
505 section.name.clone(),
506 section.estimated_tokens,
507 truncated_to,
508 ));
509 allocation.tokens_allocated += truncated_to;
510 allocation.explain.push(AllocationDecision {
511 section: section.name.clone(),
512 priority: section.priority,
513 requested: section.estimated_tokens,
514 allocated: truncated_to,
515 outcome: AllocationOutcome::Truncated,
516 reason: format!(
517 "Truncated from {} to {} tokens (min: {})",
518 section.estimated_tokens, truncated_to, min
519 ),
520 });
521 allocation.tokens_remaining = 0;
522 } else {
523 allocation.dropped_sections.push(section.name.clone());
524 allocation.explain.push(AllocationDecision {
525 section: section.name.clone(),
526 priority: section.priority,
527 requested: section.estimated_tokens,
528 allocated: 0,
529 outcome: AllocationOutcome::Dropped,
530 reason: format!(
531 "Minimum {} exceeds remaining {} tokens",
532 min, remaining
533 ),
534 });
535 }
536 } else {
537 allocation.dropped_sections.push(section.name.clone());
539 allocation.explain.push(AllocationDecision {
540 section: section.name.clone(),
541 priority: section.priority,
542 requested: section.estimated_tokens,
543 allocated: 0,
544 outcome: AllocationOutcome::Dropped,
545 reason: format!(
546 "Requested {} exceeds remaining {} (no truncation allowed)",
547 section.estimated_tokens, remaining
548 ),
549 });
550 }
551 }
552
553 allocation
554 }
555
556 fn allocate_proportional(&self, sections: &[BudgetSection]) -> BudgetAllocation {
563 let available = self.budget.saturating_sub(self.reserved);
564 let total_weight: f32 = sections.iter().map(|s| s.weight).sum();
565
566 if total_weight == 0.0 {
567 return self.allocate_greedy(sections);
568 }
569
570 let mut allocation = BudgetAllocation {
571 full_sections: Vec::new(),
572 truncated_sections: Vec::new(),
573 dropped_sections: Vec::new(),
574 tokens_allocated: 0,
575 tokens_remaining: available,
576 explain: Vec::new(),
577 };
578
579 let mut allocations: Vec<(usize, usize, bool)> = sections
581 .iter()
582 .map(|s| {
583 let proportional = ((available as f32) * s.weight / total_weight).floor() as usize;
584 let capped = proportional.min(s.estimated_tokens);
585 let min = s.minimum_tokens.unwrap_or(0);
586 (capped.max(min), s.estimated_tokens, capped < s.estimated_tokens)
587 })
588 .collect();
589
590 let mut total: usize = allocations.iter().map(|(a, _, _)| *a).sum();
592
593 while total > available {
595 let max_idx = allocations
597 .iter()
598 .enumerate()
599 .filter(|(i, (a, _, _))| {
600 *a > sections[*i].minimum_tokens.unwrap_or(0)
601 })
602 .max_by_key(|(_, (a, _, _))| *a)
603 .map(|(i, _)| i);
604
605 match max_idx {
606 Some(idx) => {
607 let reduce = (total - available).min(allocations[idx].0 - sections[idx].minimum_tokens.unwrap_or(0));
608 allocations[idx].0 -= reduce;
609 total -= reduce;
610 }
611 None => break, }
613 }
614
615 for (i, section) in sections.iter().enumerate() {
617 let (allocated, requested, truncated) = allocations[i];
618
619 if allocated == 0 {
620 allocation.dropped_sections.push(section.name.clone());
621 allocation.explain.push(AllocationDecision {
622 section: section.name.clone(),
623 priority: section.priority,
624 requested,
625 allocated: 0,
626 outcome: AllocationOutcome::Dropped,
627 reason: "No budget available after proportional allocation".to_string(),
628 });
629 } else if truncated {
630 allocation.truncated_sections.push((
631 section.name.clone(),
632 requested,
633 allocated,
634 ));
635 allocation.tokens_allocated += allocated;
636 allocation.tokens_remaining = allocation.tokens_remaining.saturating_sub(allocated);
637 allocation.explain.push(AllocationDecision {
638 section: section.name.clone(),
639 priority: section.priority,
640 requested,
641 allocated,
642 outcome: AllocationOutcome::Truncated,
643 reason: format!(
644 "Proportional allocation: {:.1}% of budget (weight {:.1})",
645 (allocated as f32 / available as f32) * 100.0,
646 section.weight
647 ),
648 });
649 } else {
650 allocation.full_sections.push(section.name.clone());
651 allocation.tokens_allocated += allocated;
652 allocation.tokens_remaining = allocation.tokens_remaining.saturating_sub(allocated);
653 allocation.explain.push(AllocationDecision {
654 section: section.name.clone(),
655 priority: section.priority,
656 requested,
657 allocated,
658 outcome: AllocationOutcome::Full,
659 reason: format!(
660 "Full allocation within proportional budget (weight {:.1})",
661 section.weight
662 ),
663 });
664 }
665 }
666
667 allocation
668 }
669
670 fn allocate_strict(&self, sections: &[BudgetSection]) -> BudgetAllocation {
672 let mut sorted: Vec<_> = sections.iter().collect();
673 sorted.sort_by_key(|s| (if s.required { 0 } else { 1 }, s.priority));
674
675 let mut allocation = BudgetAllocation {
677 full_sections: Vec::new(),
678 truncated_sections: Vec::new(),
679 dropped_sections: Vec::new(),
680 tokens_allocated: 0,
681 tokens_remaining: self.budget.saturating_sub(self.reserved),
682 explain: Vec::new(),
683 };
684
685 for section in sorted.iter().filter(|s| s.required) {
687 let remaining = allocation.tokens_remaining;
688 let min = section.minimum_tokens.unwrap_or(section.estimated_tokens);
689
690 if section.estimated_tokens <= remaining {
691 allocation.full_sections.push(section.name.clone());
692 allocation.tokens_allocated += section.estimated_tokens;
693 allocation.tokens_remaining -= section.estimated_tokens;
694 allocation.explain.push(AllocationDecision {
695 section: section.name.clone(),
696 priority: section.priority,
697 requested: section.estimated_tokens,
698 allocated: section.estimated_tokens,
699 outcome: AllocationOutcome::Full,
700 reason: "Required section - full allocation".to_string(),
701 });
702 } else if min <= remaining {
703 allocation.truncated_sections.push((
704 section.name.clone(),
705 section.estimated_tokens,
706 remaining,
707 ));
708 allocation.tokens_allocated += remaining;
709 allocation.explain.push(AllocationDecision {
710 section: section.name.clone(),
711 priority: section.priority,
712 requested: section.estimated_tokens,
713 allocated: remaining,
714 outcome: AllocationOutcome::Truncated,
715 reason: "Required section - truncated to fit".to_string(),
716 });
717 allocation.tokens_remaining = 0;
718 }
719 }
721
722 for section in sorted.iter().filter(|s| !s.required) {
724 let remaining = allocation.tokens_remaining;
725
726 if remaining == 0 {
727 allocation.dropped_sections.push(section.name.clone());
728 allocation.explain.push(AllocationDecision {
729 section: section.name.clone(),
730 priority: section.priority,
731 requested: section.estimated_tokens,
732 allocated: 0,
733 outcome: AllocationOutcome::Dropped,
734 reason: "No budget remaining after required sections".to_string(),
735 });
736 continue;
737 }
738
739 if section.estimated_tokens <= remaining {
740 allocation.full_sections.push(section.name.clone());
741 allocation.tokens_allocated += section.estimated_tokens;
742 allocation.tokens_remaining -= section.estimated_tokens;
743 allocation.explain.push(AllocationDecision {
744 section: section.name.clone(),
745 priority: section.priority,
746 requested: section.estimated_tokens,
747 allocated: section.estimated_tokens,
748 outcome: AllocationOutcome::Full,
749 reason: "Optional section - fits in remaining budget".to_string(),
750 });
751 } else if let Some(min) = section.minimum_tokens {
752 if min <= remaining {
753 allocation.truncated_sections.push((
754 section.name.clone(),
755 section.estimated_tokens,
756 remaining,
757 ));
758 allocation.tokens_allocated += remaining;
759 allocation.explain.push(AllocationDecision {
760 section: section.name.clone(),
761 priority: section.priority,
762 requested: section.estimated_tokens,
763 allocated: remaining,
764 outcome: AllocationOutcome::Truncated,
765 reason: "Optional section - truncated to fit".to_string(),
766 });
767 allocation.tokens_remaining = 0;
768 } else {
769 allocation.dropped_sections.push(section.name.clone());
770 allocation.explain.push(AllocationDecision {
771 section: section.name.clone(),
772 priority: section.priority,
773 requested: section.estimated_tokens,
774 allocated: 0,
775 outcome: AllocationOutcome::Dropped,
776 reason: format!("Minimum {} exceeds remaining {}", min, remaining),
777 });
778 }
779 } else {
780 allocation.dropped_sections.push(section.name.clone());
781 allocation.explain.push(AllocationDecision {
782 section: section.name.clone(),
783 priority: section.priority,
784 requested: section.estimated_tokens,
785 allocated: 0,
786 outcome: AllocationOutcome::Dropped,
787 reason: format!("Requested {} exceeds remaining {}", section.estimated_tokens, remaining),
788 });
789 }
790 }
791
792 allocation
793 }
794
795 pub fn reset(&self) {
797 self.allocated.store(0, Ordering::Release);
798 }
799
800 pub fn estimator(&self) -> &TokenEstimator {
802 &self.estimator
803 }
804}
805
806impl BudgetAllocation {
811 pub fn explain_text(&self) -> String {
813 let mut output = String::new();
814 output.push_str("=== CONTEXT BUDGET ALLOCATION ===\n\n");
815 output.push_str(&format!(
816 "Total Allocated: {} tokens\n",
817 self.tokens_allocated
818 ));
819 output.push_str(&format!("Remaining: {} tokens\n\n", self.tokens_remaining));
820
821 output.push_str("SECTIONS:\n");
822 for decision in &self.explain {
823 let status = match decision.outcome {
824 AllocationOutcome::Full => "✓ FULL",
825 AllocationOutcome::Truncated => "◐ TRUNCATED",
826 AllocationOutcome::Dropped => "✗ DROPPED",
827 };
828 output.push_str(&format!(
829 " [{:^12}] {} (priority {})\n",
830 status, decision.section, decision.priority
831 ));
832 output.push_str(&format!(
833 " Requested: {}, Allocated: {}\n",
834 decision.requested, decision.allocated
835 ));
836 output.push_str(&format!(" Reason: {}\n", decision.reason));
837 }
838
839 output
840 }
841
842 pub fn explain_json(&self) -> String {
844 serde_json::to_string_pretty(&ExplainOutput {
845 tokens_allocated: self.tokens_allocated,
846 tokens_remaining: self.tokens_remaining,
847 full_sections: self.full_sections.clone(),
848 truncated_sections: self.truncated_sections.clone(),
849 dropped_sections: self.dropped_sections.clone(),
850 decisions: self.explain.iter().map(|d| ExplainDecision {
851 section: d.section.clone(),
852 priority: d.priority,
853 requested: d.requested,
854 allocated: d.allocated,
855 outcome: format!("{:?}", d.outcome),
856 reason: d.reason.clone(),
857 }).collect(),
858 }).unwrap_or_else(|_| "{}".to_string())
859 }
860}
861
/// Serializable mirror of `BudgetAllocation` used by `explain_json`.
#[derive(serde::Serialize)]
struct ExplainOutput {
    /// Total tokens handed out.
    tokens_allocated: usize,
    /// Tokens left in the usable budget.
    tokens_remaining: usize,
    /// Names of the sections kept in full.
    full_sections: Vec<String>,
    /// `(name, requested, allocated)` for each truncated section.
    truncated_sections: Vec<(String, usize, usize)>,
    /// Names of the dropped sections.
    dropped_sections: Vec<String>,
    /// Per-section decisions in recorded order.
    decisions: Vec<ExplainDecision>,
}
871
/// Serializable mirror of `AllocationDecision` used by `explain_json`.
#[derive(serde::Serialize)]
struct ExplainDecision {
    /// Name of the section.
    section: String,
    /// Priority copied from the section.
    priority: i32,
    /// Tokens the section asked for.
    requested: usize,
    /// Tokens the section received.
    allocated: usize,
    /// `Debug` rendering of the `AllocationOutcome` variant.
    outcome: String,
    /// Human-readable justification.
    reason: String,
}
881
882pub fn truncate_to_tokens(
888 text: &str,
889 max_tokens: usize,
890 estimator: &TokenEstimator,
891 suffix: &str,
892) -> String {
893 let current = estimator.estimate_text(text);
894
895 if current <= max_tokens {
896 return text.to_string();
897 }
898
899 let suffix_tokens = estimator.estimate_text(suffix);
900 let target_tokens = max_tokens.saturating_sub(suffix_tokens);
901
902 if target_tokens == 0 {
903 return suffix.to_string();
904 }
905
906 let mut low = 0;
908 let mut high = text.len();
909
910 while low < high {
911 let mid = (low + high).div_ceil(2);
912
913 let boundary = text
915 .char_indices()
916 .take_while(|(i, _)| *i < mid)
917 .last()
918 .map(|(i, c)| i + c.len_utf8())
919 .unwrap_or(0);
920
921 let truncated = &text[..boundary];
922 let tokens = estimator.estimate_text(truncated);
923
924 if tokens <= target_tokens {
925 low = boundary;
926 } else {
927 high = boundary.saturating_sub(1);
928 }
929 }
930
931 let truncated = &text[..low];
933 let word_boundary = truncated.rfind(|c: char| c.is_whitespace()).unwrap_or(low);
934
935 format!("{}{}", &text[..word_boundary], suffix)
936}
937
938pub fn truncate_rows(
940 rows: &[Vec<SochValue>],
941 max_tokens: usize,
942 estimator: &TokenEstimator,
943) -> Vec<Vec<SochValue>> {
944 let mut result = Vec::new();
945 let mut used = 0;
946
947 for row in rows {
948 let row_tokens = estimator.estimate_row(row);
949
950 if used + row_tokens <= max_tokens {
951 result.push(row.clone());
952 used += row_tokens;
953 } else {
954 break; }
956 }
957
958 result
959}
960
#[cfg(test)]
mod tests {
    use super::*;

    /// Small helper so the section fixtures below stay readable.
    fn section(
        name: &str,
        priority: i32,
        estimated_tokens: usize,
        minimum_tokens: Option<usize>,
        required: bool,
    ) -> BudgetSection {
        BudgetSection {
            name: name.to_string(),
            priority,
            estimated_tokens,
            minimum_tokens,
            required,
            weight: 1.0,
        }
    }

    /// Integer estimates are at least one token for the sampled values and do
    /// not shrink as the number grows.
    #[test]
    fn test_estimate_value_int() {
        let est = TokenEstimator::new();

        assert!(est.estimate_value(&SochValue::Int(0)) >= 1);
        assert!(est.estimate_value(&SochValue::Int(42)) >= 1);

        let small = est.estimate_value(&SochValue::Int(42));
        let large = est.estimate_value(&SochValue::Int(1_000_000_000));
        assert!(large >= small);
    }

    /// Longer text costs more than shorter text.
    #[test]
    fn test_estimate_value_text() {
        let est = TokenEstimator::new();

        let short = est.estimate_value(&SochValue::Text("hello".to_string()));
        let long = est.estimate_value(&SochValue::Text(
            "hello world this is a longer string".to_string(),
        ));

        assert!(long > short);
    }

    /// A three-value row costs at least one token per value.
    #[test]
    // 3.14 is just sample data here, not an approximation of PI.
    #[allow(clippy::approx_constant)]
    fn test_estimate_row() {
        let est = TokenEstimator::new();

        let row = vec![
            SochValue::Int(1),
            SochValue::Text("Alice".to_string()),
            SochValue::Float(3.14),
        ];

        let tokens = est.estimate_row(&row);

        assert!(tokens >= 3);
    }

    /// A table costs more than its rows alone because of the header.
    #[test]
    fn test_estimate_table() {
        let est = TokenEstimator::new();

        let columns = vec!["id".to_string(), "name".to_string()];
        let rows = vec![
            vec![SochValue::Int(1), SochValue::Text("Alice".to_string())],
            vec![SochValue::Int(2), SochValue::Text("Bob".to_string())],
        ];

        let tokens = est.estimate_table("users", &columns, &rows);

        assert!(tokens > est.estimate_row(&rows[0]) * 2);
    }

    /// Claims succeed while they fit and leave the counter untouched when
    /// they do not.
    #[test]
    fn test_budget_enforcer_allocation() {
        let enforcer = TokenBudgetEnforcer::with_budget(1000);

        assert!(enforcer.try_allocate(500));
        assert_eq!(enforcer.allocated(), 500);
        assert_eq!(enforcer.available(), 500);

        assert!(enforcer.try_allocate(400));
        assert_eq!(enforcer.allocated(), 900);

        assert!(!enforcer.try_allocate(200));
        assert_eq!(enforcer.allocated(), 900);
    }

    /// `reset` clears every earlier claim.
    #[test]
    fn test_budget_enforcer_reset() {
        let enforcer = TokenBudgetEnforcer::with_budget(1000);

        assert!(enforcer.try_allocate(800));
        assert_eq!(enforcer.allocated(), 800);

        enforcer.reset();
        assert_eq!(enforcer.allocated(), 0);
    }

    /// Greedy allocation keeps the first section, and drops the one that
    /// neither fits nor allows truncation.
    #[test]
    fn test_allocate_sections() {
        let enforcer = TokenBudgetEnforcer::with_budget(1000);

        let sections = vec![
            section("A", 0, 300, None, true),
            section("B", 1, 400, Some(200), false),
            section("C", 2, 500, None, false),
        ];

        let allocation = enforcer.allocate_sections(&sections);

        assert!(allocation.full_sections.contains(&"A".to_string()));

        assert!(allocation.dropped_sections.contains(&"C".to_string()));

        assert!(allocation.tokens_allocated <= 1000);
    }

    /// Sections are served by priority, not by input order.
    #[test]
    fn test_allocate_by_priority() {
        let enforcer = TokenBudgetEnforcer::with_budget(500);

        let sections = vec![
            section("LowPriority", 10, 200, None, false),
            section("HighPriority", 0, 400, None, true),
        ];

        let allocation = enforcer.allocate_sections(&sections);

        assert!(allocation
            .full_sections
            .contains(&"HighPriority".to_string()));

        assert!(allocation
            .dropped_sections
            .contains(&"LowPriority".to_string()));
    }

    /// Truncated text is shorter, carries the suffix and fits the budget.
    #[test]
    fn test_truncate_to_tokens() {
        let est = TokenEstimator::new();

        let text = "This is a long text that needs to be truncated to fit within the token budget";
        let truncated = truncate_to_tokens(text, 10, &est, "...");

        assert!(truncated.len() < text.len());

        assert!(truncated.ends_with("..."));

        assert!(est.estimate_text(&truncated) <= 10);
    }

    /// Row truncation keeps a prefix whose total estimate fits the budget.
    #[test]
    fn test_truncate_rows() {
        let est = TokenEstimator::new();

        let rows: Vec<Vec<SochValue>> = (0..100)
            .map(|i| vec![SochValue::Int(i), SochValue::Text(format!("row{}", i))])
            .collect();

        let truncated = truncate_rows(&rows, 50, &est);

        assert!(truncated.len() < rows.len());

        let total: usize = truncated.iter().map(|r| est.estimate_row(r)).sum();
        assert!(total <= 50);
    }

    /// Reserved tokens are excluded from what can be claimed.
    #[test]
    fn test_reserved_budget() {
        let mut enforcer = TokenBudgetEnforcer::with_budget(1000);
        enforcer.reserve(200);

        assert_eq!(enforcer.available(), 800);

        assert!(enforcer.try_allocate(700));
        assert_eq!(enforcer.available(), 100);

        assert!(!enforcer.try_allocate(200));
    }

    /// The conservative preset never estimates below the default one.
    #[test]
    fn test_estimator_configs() {
        let default = TokenEstimator::new();
        let gpt4 = TokenEstimator::with_config(TokenEstimatorConfig::gpt4());
        let conservative = TokenEstimator::with_config(TokenEstimatorConfig::conservative());

        let text = "Hello, this is a test string for comparing token estimation across different configurations.";

        let default_est = default.estimate_text(text);
        let gpt4_est = gpt4.estimate_text(text);
        let conservative_est = conservative.estimate_text(text);

        assert!(conservative_est >= default_est);

        assert!(default_est > 0);
        assert!(gpt4_est > 0);
        assert!(conservative_est > 0);
    }

    /// A section with a minimum is truncated into the leftover budget.
    #[test]
    fn test_section_with_truncation() {
        let enforcer = TokenBudgetEnforcer::with_budget(600);

        let sections = vec![
            section("Required", 0, 500, None, true),
            section("Optional", 1, 300, Some(50), false),
        ];

        let allocation = enforcer.allocate_sections(&sections);

        assert!(allocation.full_sections.contains(&"Required".to_string()));

        assert!(allocation
            .truncated_sections
            .iter()
            .any(|(n, _, _)| n == "Optional"));
    }
}