Skip to main content

sochdb_query/
token_budget.rs

1// SPDX-License-Identifier: AGPL-3.0-or-later
2// SochDB - LLM-Optimized Embedded Database
3// Copyright (C) 2026 Sushanth Reddy Vanagala (https://github.com/sushanthpy)
4//
5// This program is free software: you can redistribute it and/or modify
6// it under the terms of the GNU Affero General Public License as published by
7// the Free Software Foundation, either version 3 of the License, or
8// (at your option) any later version.
9//
10// This program is distributed in the hope that it will be useful,
11// but WITHOUT ANY WARRANTY; without even the implied warranty of
12// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
13// GNU Affero General Public License for more details.
14//
15// You should have received a copy of the GNU Affero General Public License
16// along with this program. If not, see <https://www.gnu.org/licenses/>.
17
18//! Token Budget Enforcement
19//!
20//! This module implements token estimation and budget tracking for
21//! CONTEXT SELECT queries.
22//!
23//! ## Token Estimation Model
24//!
25//! The token count for a row is estimated as:
26//!
//! $$T_{row} = \sum_{i=1}^{C} T(v_i) + (C - 1) \times T_{sep} + T_{nl}$$
//!
//! Where:
//! - $T(v_i)$ = tokens for value $i$
//! - $C$ = number of columns
//! - $T_{sep}$ = separator token cost (~1 token per separator)
//! - $T_{nl}$ = newline token cost for the row terminator (an empty row costs 0)
33//!
34//! ## Type-Specific Estimation
35//!
36//! Different data types have different token characteristics:
37//!
38//! | Type | Factor | Notes |
39//! |------|--------|-------|
40//! | Integer | 1.0 | ~1 token per 3-4 digits |
41//! | Float | 1.2 | Decimal point adds overhead |
42//! | String | 1.1 | Potential subword splits |
43//! | Binary (hex) | 2.5 | 0x prefix + hex expansion |
44//! | Boolean | 1.0 | "true"/"false" are single tokens |
45//! | Null | 1.0 | "null" is a single token |
46
47use crate::soch_ql::SochValue;
48use std::sync::atomic::{AtomicUsize, Ordering};
49
50// ============================================================================
51// Token Estimator
52// ============================================================================
53
/// Tunable parameters for the heuristic token estimator.
///
/// **Important**: estimates come from a linear bytes-per-token model, not
/// from real BPE tokenization, so accuracy depends on the content:
///
/// - English prose: ~5% error (well-calibrated)
/// - CJK / non-Latin: up to 30% *under*-estimate
/// - Code / URLs / special chars: up to 20% error
///
/// Estimates are scaled by `safety_margin` (1.15 in the default config) to
/// lower the risk of overrunning the LLM's context window.  When exact
/// counts are required, use `ExactTokenCounter` with a BPE vocabulary.
#[derive(Debug, Clone)]
pub struct TokenEstimatorConfig {
    /// Weight applied to the rendered length of integer values.
    pub int_factor: f32,
    /// Weight applied to the rendered length of float values.
    pub float_factor: f32,
    /// Weight applied to the byte length of string values.
    pub string_factor: f32,
    /// Weight applied to the hex-encoded length of binary values.
    pub hex_factor: f32,
    /// Approximate number of bytes that make up one token.
    pub bytes_per_token: f32,
    /// Multiplier applied on top of raw estimates to guard against
    /// context window overflow.  1.15 = 15% headroom.
    pub safety_margin: f32,
    /// Token cost of one separator between values.
    pub separator_tokens: usize,
    /// Token cost of the newline that ends a row.
    pub newline_tokens: usize,
    /// Fixed token overhead charged for a table header.
    pub header_tokens: usize,
}
88
89impl Default for TokenEstimatorConfig {
90    fn default() -> Self {
91        Self {
92            int_factor: 1.0,
93            float_factor: 1.2,
94            string_factor: 1.1,
95            hex_factor: 2.5,
96            bytes_per_token: 4.0, // ~4 chars per token for English
97            safety_margin: 1.15,  // 15% headroom for non-Latin and special chars
98            separator_tokens: 1,
99            newline_tokens: 1,
100            header_tokens: 10, // table[N]{cols}: header
101        }
102    }
103}
104
105impl TokenEstimatorConfig {
106    /// Create config tuned for GPT-4 tokenizer
107    pub fn gpt4() -> Self {
108        Self {
109            bytes_per_token: 3.8,
110            safety_margin: 1.15,
111            ..Default::default()
112        }
113    }
114
115    /// Create config tuned for Claude tokenizer
116    pub fn claude() -> Self {
117        Self {
118            bytes_per_token: 4.2,
119            safety_margin: 1.15,
120            ..Default::default()
121        }
122    }
123
124    /// Create config with high precision (conservative)
125    pub fn conservative() -> Self {
126        Self {
127            int_factor: 1.2,
128            float_factor: 1.4,
129            string_factor: 1.3,
130            hex_factor: 3.0,
131            bytes_per_token: 3.5,
132            safety_margin: 1.25, // 25% headroom for safety
133            ..Default::default()
134        }
135    }
136}
137
138/// Token estimator
139pub struct TokenEstimator {
140    config: TokenEstimatorConfig,
141}
142
143impl TokenEstimator {
144    /// Create a new estimator with default config
145    pub fn new() -> Self {
146        Self {
147            config: TokenEstimatorConfig::default(),
148        }
149    }
150
151    /// Create with custom config
152    pub fn with_config(config: TokenEstimatorConfig) -> Self {
153        Self { config }
154    }
155
156    /// Estimate tokens for a single value.
157    ///
158    /// Applies `safety_margin` to the raw estimate to reduce the risk
159    /// of exceeding the LLM context window with non-Latin or structured text.
160    pub fn estimate_value(&self, value: &SochValue) -> usize {
161        let raw = self.estimate_value_raw(value);
162        ((raw as f32) * self.config.safety_margin).ceil() as usize
163    }
164
165    /// Raw token estimate without safety margin (for internal use / testing).
166    fn estimate_value_raw(&self, value: &SochValue) -> usize {
167        match value {
168            SochValue::Null => 1,
169            SochValue::Bool(_) => 1, // "true" or "false" is typically 1 token
170            SochValue::Int(n) => {
171                // Count digits + sign
172                let digits = if *n == 0 {
173                    1
174                } else {
175                    ((*n).abs() as f64).log10().ceil() as usize + if *n < 0 { 1 } else { 0 }
176                };
177                ((digits as f32 * self.config.int_factor) / self.config.bytes_per_token).ceil()
178                    as usize
179            }
180            SochValue::UInt(n) => {
181                let digits = if *n == 0 {
182                    1
183                } else {
184                    ((*n as f64).log10().ceil() as usize).max(1)
185                };
186                ((digits as f32 * self.config.int_factor) / self.config.bytes_per_token).ceil()
187                    as usize
188            }
189            SochValue::Float(f) => {
190                // Format to 2 decimal places
191                let s = format!("{:.2}", f);
192                ((s.len() as f32 * self.config.float_factor) / self.config.bytes_per_token).ceil()
193                    as usize
194            }
195            SochValue::Text(s) => {
196                // Account for potential subword splitting
197                ((s.len() as f32 * self.config.string_factor) / self.config.bytes_per_token).ceil()
198                    as usize
199            }
200            SochValue::Binary(b) => {
201                // Hex encoding: 0x + 2 chars per byte
202                let hex_len = 2 + b.len() * 2;
203                ((hex_len as f32 * self.config.hex_factor) / self.config.bytes_per_token).ceil()
204                    as usize
205            }
206            SochValue::Array(arr) => {
207                // Sum tokens for array elements plus brackets and separators
208                let elem_tokens: usize = arr.iter().map(|v| self.estimate_value(v)).sum();
209                let separator_tokens = if arr.is_empty() { 0 } else { arr.len() - 1 };
210                2 + elem_tokens + separator_tokens // 2 for [ and ]
211            }
212        }
213    }
214
215    /// Estimate tokens for a row (multiple values)
216    pub fn estimate_row(&self, values: &[SochValue]) -> usize {
217        if values.is_empty() {
218            return 0;
219        }
220
221        let value_tokens: usize = values.iter().map(|v| self.estimate_value(v)).sum();
222        let separator_tokens = (values.len() - 1) * self.config.separator_tokens;
223        let newline = self.config.newline_tokens;
224
225        value_tokens + separator_tokens + newline
226    }
227
228    /// Estimate tokens for a table header
229    pub fn estimate_header(&self, table: &str, columns: &[String], row_count: usize) -> usize {
230        // Format: table[N]{col1,col2,...}:
231        let base = self.config.header_tokens;
232        let table_tokens = ((table.len() as f32) / self.config.bytes_per_token).ceil() as usize;
233        let count_tokens = ((row_count as f64).log10().ceil() as usize).max(1);
234        let col_tokens: usize = columns
235            .iter()
236            .map(|c| ((c.len() as f32) / self.config.bytes_per_token).ceil() as usize)
237            .sum();
238
239        base + table_tokens + count_tokens + col_tokens
240    }
241
242    /// Estimate tokens for a complete TOON table
243    pub fn estimate_table(
244        &self,
245        table: &str,
246        columns: &[String],
247        rows: &[Vec<SochValue>],
248    ) -> usize {
249        let header = self.estimate_header(table, columns, rows.len());
250        let row_tokens: usize = rows.iter().map(|r| self.estimate_row(r)).sum();
251        header + row_tokens
252    }
253
254    /// Estimate tokens for plain text
255    pub fn estimate_text(&self, text: &str) -> usize {
256        let raw = ((text.len() as f32) / self.config.bytes_per_token).ceil() as usize;
257        ((raw as f32) * self.config.safety_margin).ceil() as usize
258    }
259
260    /// Truncate text to fit within token budget
261    ///
262    /// Uses binary search to find the optimal truncation point.
263    pub fn truncate_to_tokens(&self, text: &str, max_tokens: usize) -> String {
264        truncate_to_tokens(text, max_tokens, self, "...")
265    }
266}
267
268impl Default for TokenEstimator {
269    fn default() -> Self {
270        Self::new()
271    }
272}
273
274// ============================================================================
275// Token Budget Enforcer
276// ============================================================================
277
278/// Token budget enforcement result
279#[derive(Debug, Clone)]
280pub struct BudgetAllocation {
281    /// Sections that fit fully
282    pub full_sections: Vec<String>,
283    /// Sections that were truncated (name, original_tokens, allocated_tokens)
284    pub truncated_sections: Vec<(String, usize, usize)>,
285    /// Sections that were dropped
286    pub dropped_sections: Vec<String>,
287    /// Total tokens allocated
288    pub tokens_allocated: usize,
289    /// Remaining budget
290    pub tokens_remaining: usize,
291    /// Detailed allocation decisions for EXPLAIN CONTEXT
292    pub explain: Vec<AllocationDecision>,
293}
294
295/// Detailed explanation of a single allocation decision
296#[derive(Debug, Clone)]
297pub struct AllocationDecision {
298    /// Section name
299    pub section: String,
300    /// Priority value
301    pub priority: i32,
302    /// Requested tokens
303    pub requested: usize,
304    /// Allocated tokens
305    pub allocated: usize,
306    /// Decision outcome
307    pub outcome: AllocationOutcome,
308    /// Human-readable reason
309    pub reason: String,
310}
311
/// How a single section fared during allocation.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum AllocationOutcome {
    /// The section was included in full
    Full,
    /// The section was cut down to fit
    Truncated,
    /// The section was left out entirely
    Dropped,
}
322
/// A candidate section competing for a share of the token budget.
#[derive(Debug, Clone)]
pub struct BudgetSection {
    /// Name used to identify the section in the allocation result
    pub name: String,
    /// Priority (lower = higher priority)
    pub priority: i32,
    /// Estimated number of tokens the full section needs
    pub estimated_tokens: usize,
    /// Smallest useful size when truncated; `None` means the section
    /// cannot be truncated
    pub minimum_tokens: Option<usize>,
    /// Whether the section is required
    pub required: bool,
    /// Weight for proportional allocation (default: 1.0)
    pub weight: f32,
}
339
340impl Default for BudgetSection {
341    fn default() -> Self {
342        Self {
343            name: String::new(),
344            priority: 0,
345            estimated_tokens: 0,
346            minimum_tokens: None,
347            required: false,
348            weight: 1.0,
349        }
350    }
351}
352
/// Policy used to divide the budget between sections.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum AllocationStrategy {
    /// Greedy by priority (the default): sections are served in priority order
    #[default]
    GreedyPriority,
    /// Proportional / water-filling: budget is split according to section weights
    Proportional,
    /// Strict priority with minimum guarantees
    StrictPriority,
}
364
365/// Token budget enforcer
366///
367/// Implements greedy token allocation by priority with optional
368/// truncation support.
369pub struct TokenBudgetEnforcer {
370    /// Total budget
371    budget: usize,
372    /// Current allocation
373    allocated: AtomicUsize,
374    /// Estimator for token counting
375    estimator: TokenEstimator,
376    /// Reserved tokens (for overhead, etc.)
377    reserved: usize,
378    /// Allocation strategy
379    strategy: AllocationStrategy,
380}
381
382/// Configuration for TokenBudgetEnforcer
383#[derive(Debug, Clone)]
384pub struct TokenBudgetConfig {
385    /// Total token budget
386    pub total_budget: usize,
387    /// Reserved tokens for overhead
388    pub reserved_tokens: usize,
389    /// Enable strict budget enforcement
390    pub strict: bool,
391    /// Default priority for unspecified sections
392    pub default_priority: i32,
393    /// Allocation strategy
394    pub strategy: AllocationStrategy,
395}
396
397impl Default for TokenBudgetConfig {
398    fn default() -> Self {
399        Self {
400            total_budget: 4096,
401            reserved_tokens: 100,
402            strict: false,
403            default_priority: 10,
404            strategy: AllocationStrategy::GreedyPriority,
405        }
406    }
407}
408
409impl TokenBudgetEnforcer {
410    /// Create a new budget enforcer
411    pub fn new(config: TokenBudgetConfig) -> Self {
412        Self {
413            budget: config.total_budget,
414            allocated: AtomicUsize::new(0),
415            estimator: TokenEstimator::new(),
416            reserved: config.reserved_tokens,
417            strategy: config.strategy,
418        }
419    }
420
421    /// Create with simple budget (for backwards compatibility)
422    pub fn with_budget(budget: usize) -> Self {
423        Self {
424            budget,
425            allocated: AtomicUsize::new(0),
426            estimator: TokenEstimator::new(),
427            reserved: 0,
428            strategy: AllocationStrategy::GreedyPriority,
429        }
430    }
431
432    /// Create with custom estimator
433    pub fn with_estimator(budget: usize, estimator: TokenEstimator) -> Self {
434        Self {
435            budget,
436            allocated: AtomicUsize::new(0),
437            estimator,
438            reserved: 0,
439            strategy: AllocationStrategy::GreedyPriority,
440        }
441    }
442
443    /// Set allocation strategy
444    pub fn with_strategy(mut self, strategy: AllocationStrategy) -> Self {
445        self.strategy = strategy;
446        self
447    }
448
449    /// Reserve tokens for overhead (headers, separators, etc.)
450    pub fn reserve(&mut self, tokens: usize) {
451        self.reserved = tokens;
452    }
453
454    /// Get available budget (total - reserved - allocated)
455    pub fn available(&self) -> usize {
456        let allocated = self.allocated.load(Ordering::Acquire);
457        self.budget.saturating_sub(self.reserved + allocated)
458    }
459
460    /// Get total budget
461    pub fn total_budget(&self) -> usize {
462        self.budget
463    }
464
465    /// Get allocated tokens
466    pub fn allocated(&self) -> usize {
467        self.allocated.load(Ordering::Acquire)
468    }
469
470    /// Try to allocate tokens (returns true if successful)
471    pub fn try_allocate(&self, tokens: usize) -> bool {
472        loop {
473            let current = self.allocated.load(Ordering::Acquire);
474            let new_total = current + tokens;
475
476            if new_total + self.reserved > self.budget {
477                return false;
478            }
479
480            if self
481                .allocated
482                .compare_exchange(current, new_total, Ordering::AcqRel, Ordering::Acquire)
483                .is_ok()
484            {
485                return true;
486            }
487            // Retry on contention
488        }
489    }
490
491    /// Allocate sections by priority (dispatches to strategy-specific method)
492    pub fn allocate_sections(&self, sections: &[BudgetSection]) -> BudgetAllocation {
493        match self.strategy {
494            AllocationStrategy::GreedyPriority => self.allocate_greedy(sections),
495            AllocationStrategy::Proportional => self.allocate_proportional(sections),
496            AllocationStrategy::StrictPriority => self.allocate_strict(sections),
497        }
498    }
499
500    /// Greedy allocation by priority order
501    fn allocate_greedy(&self, sections: &[BudgetSection]) -> BudgetAllocation {
502        // Sort by priority (lower = higher priority)
503        let mut sorted: Vec<_> = sections.iter().collect();
504        sorted.sort_by_key(|s| s.priority);
505
506        let mut allocation = BudgetAllocation {
507            full_sections: Vec::new(),
508            truncated_sections: Vec::new(),
509            dropped_sections: Vec::new(),
510            tokens_allocated: 0,
511            tokens_remaining: self.budget.saturating_sub(self.reserved),
512            explain: Vec::new(),
513        };
514
515        for section in sorted {
516            let remaining = allocation.tokens_remaining;
517
518            if section.estimated_tokens <= remaining {
519                // Section fits fully
520                allocation.full_sections.push(section.name.clone());
521                allocation.tokens_allocated += section.estimated_tokens;
522                allocation.tokens_remaining -= section.estimated_tokens;
523                allocation.explain.push(AllocationDecision {
524                    section: section.name.clone(),
525                    priority: section.priority,
526                    requested: section.estimated_tokens,
527                    allocated: section.estimated_tokens,
528                    outcome: AllocationOutcome::Full,
529                    reason: format!("Fits in remaining budget ({} tokens)", remaining),
530                });
531            } else if let Some(min) = section.minimum_tokens {
532                // Try truncated version
533                if min <= remaining {
534                    let truncated_to = remaining;
535                    allocation.truncated_sections.push((
536                        section.name.clone(),
537                        section.estimated_tokens,
538                        truncated_to,
539                    ));
540                    allocation.tokens_allocated += truncated_to;
541                    allocation.explain.push(AllocationDecision {
542                        section: section.name.clone(),
543                        priority: section.priority,
544                        requested: section.estimated_tokens,
545                        allocated: truncated_to,
546                        outcome: AllocationOutcome::Truncated,
547                        reason: format!(
548                            "Truncated from {} to {} tokens (min: {})",
549                            section.estimated_tokens, truncated_to, min
550                        ),
551                    });
552                    allocation.tokens_remaining = 0;
553                } else {
554                    allocation.dropped_sections.push(section.name.clone());
555                    allocation.explain.push(AllocationDecision {
556                        section: section.name.clone(),
557                        priority: section.priority,
558                        requested: section.estimated_tokens,
559                        allocated: 0,
560                        outcome: AllocationOutcome::Dropped,
561                        reason: format!("Minimum {} exceeds remaining {} tokens", min, remaining),
562                    });
563                }
564            } else {
565                // No truncation, must drop
566                allocation.dropped_sections.push(section.name.clone());
567                allocation.explain.push(AllocationDecision {
568                    section: section.name.clone(),
569                    priority: section.priority,
570                    requested: section.estimated_tokens,
571                    allocated: 0,
572                    outcome: AllocationOutcome::Dropped,
573                    reason: format!(
574                        "Requested {} exceeds remaining {} (no truncation allowed)",
575                        section.estimated_tokens, remaining
576                    ),
577                });
578            }
579        }
580
581        allocation
582    }
583
584    /// Proportional / water-filling allocation
585    ///
586    /// Allocates tokens proportionally by weight:
587    /// $$b_i = \lfloor B \cdot w_i / \sum w \rfloor$$
588    ///
589    /// With minimum guarantees and iterative redistribution.
590    fn allocate_proportional(&self, sections: &[BudgetSection]) -> BudgetAllocation {
591        let available = self.budget.saturating_sub(self.reserved);
592        let total_weight: f32 = sections.iter().map(|s| s.weight).sum();
593
594        if total_weight == 0.0 {
595            return self.allocate_greedy(sections);
596        }
597
598        let mut allocation = BudgetAllocation {
599            full_sections: Vec::new(),
600            truncated_sections: Vec::new(),
601            dropped_sections: Vec::new(),
602            tokens_allocated: 0,
603            tokens_remaining: available,
604            explain: Vec::new(),
605        };
606
607        // Phase 1: Calculate proportional allocations
608        let mut allocations: Vec<(usize, usize, bool)> = sections
609            .iter()
610            .map(|s| {
611                let proportional = ((available as f32) * s.weight / total_weight).floor() as usize;
612                let capped = proportional.min(s.estimated_tokens);
613                let min = s.minimum_tokens.unwrap_or(0);
614                (
615                    capped.max(min),
616                    s.estimated_tokens,
617                    capped < s.estimated_tokens,
618                )
619            })
620            .collect();
621
622        // Phase 2: Adjust to fit budget (water-filling)
623        let mut total: usize = allocations.iter().map(|(a, _, _)| *a).sum();
624
625        // If over budget, reduce proportionally from largest allocations
626        while total > available {
627            // Find the section with largest allocation that can be reduced
628            let max_idx = allocations
629                .iter()
630                .enumerate()
631                .filter(|(i, (a, _, _))| *a > sections[*i].minimum_tokens.unwrap_or(0))
632                .max_by_key(|(_, (a, _, _))| *a)
633                .map(|(i, _)| i);
634
635            match max_idx {
636                Some(idx) => {
637                    let reduce = (total - available)
638                        .min(allocations[idx].0 - sections[idx].minimum_tokens.unwrap_or(0));
639                    allocations[idx].0 -= reduce;
640                    total -= reduce;
641                }
642                None => break, // Can't reduce further
643            }
644        }
645
646        // Phase 3: Record results
647        for (i, section) in sections.iter().enumerate() {
648            let (allocated, requested, truncated) = allocations[i];
649
650            if allocated == 0 {
651                allocation.dropped_sections.push(section.name.clone());
652                allocation.explain.push(AllocationDecision {
653                    section: section.name.clone(),
654                    priority: section.priority,
655                    requested,
656                    allocated: 0,
657                    outcome: AllocationOutcome::Dropped,
658                    reason: "No budget available after proportional allocation".to_string(),
659                });
660            } else if truncated {
661                allocation
662                    .truncated_sections
663                    .push((section.name.clone(), requested, allocated));
664                allocation.tokens_allocated += allocated;
665                allocation.tokens_remaining = allocation.tokens_remaining.saturating_sub(allocated);
666                allocation.explain.push(AllocationDecision {
667                    section: section.name.clone(),
668                    priority: section.priority,
669                    requested,
670                    allocated,
671                    outcome: AllocationOutcome::Truncated,
672                    reason: format!(
673                        "Proportional allocation: {:.1}% of budget (weight {:.1})",
674                        (allocated as f32 / available as f32) * 100.0,
675                        section.weight
676                    ),
677                });
678            } else {
679                allocation.full_sections.push(section.name.clone());
680                allocation.tokens_allocated += allocated;
681                allocation.tokens_remaining = allocation.tokens_remaining.saturating_sub(allocated);
682                allocation.explain.push(AllocationDecision {
683                    section: section.name.clone(),
684                    priority: section.priority,
685                    requested,
686                    allocated,
687                    outcome: AllocationOutcome::Full,
688                    reason: format!(
689                        "Full allocation within proportional budget (weight {:.1})",
690                        section.weight
691                    ),
692                });
693            }
694        }
695
696        allocation
697    }
698
699    /// Strict priority with guaranteed minimums for required sections
700    fn allocate_strict(&self, sections: &[BudgetSection]) -> BudgetAllocation {
701        let mut sorted: Vec<_> = sections.iter().collect();
702        sorted.sort_by_key(|s| (if s.required { 0 } else { 1 }, s.priority));
703
704        // First pass: allocate minimums for required sections
705        let mut allocation = BudgetAllocation {
706            full_sections: Vec::new(),
707            truncated_sections: Vec::new(),
708            dropped_sections: Vec::new(),
709            tokens_allocated: 0,
710            tokens_remaining: self.budget.saturating_sub(self.reserved),
711            explain: Vec::new(),
712        };
713
714        // Allocate required sections first (at minimum or full)
715        for section in sorted.iter().filter(|s| s.required) {
716            let remaining = allocation.tokens_remaining;
717            let min = section.minimum_tokens.unwrap_or(section.estimated_tokens);
718
719            if section.estimated_tokens <= remaining {
720                allocation.full_sections.push(section.name.clone());
721                allocation.tokens_allocated += section.estimated_tokens;
722                allocation.tokens_remaining -= section.estimated_tokens;
723                allocation.explain.push(AllocationDecision {
724                    section: section.name.clone(),
725                    priority: section.priority,
726                    requested: section.estimated_tokens,
727                    allocated: section.estimated_tokens,
728                    outcome: AllocationOutcome::Full,
729                    reason: "Required section - full allocation".to_string(),
730                });
731            } else if min <= remaining {
732                allocation.truncated_sections.push((
733                    section.name.clone(),
734                    section.estimated_tokens,
735                    remaining,
736                ));
737                allocation.tokens_allocated += remaining;
738                allocation.explain.push(AllocationDecision {
739                    section: section.name.clone(),
740                    priority: section.priority,
741                    requested: section.estimated_tokens,
742                    allocated: remaining,
743                    outcome: AllocationOutcome::Truncated,
744                    reason: "Required section - truncated to fit".to_string(),
745                });
746                allocation.tokens_remaining = 0;
747            }
748            // Required sections can't be dropped - would be an error condition
749        }
750
751        // Then allocate optional sections
752        for section in sorted.iter().filter(|s| !s.required) {
753            let remaining = allocation.tokens_remaining;
754
755            if remaining == 0 {
756                allocation.dropped_sections.push(section.name.clone());
757                allocation.explain.push(AllocationDecision {
758                    section: section.name.clone(),
759                    priority: section.priority,
760                    requested: section.estimated_tokens,
761                    allocated: 0,
762                    outcome: AllocationOutcome::Dropped,
763                    reason: "No budget remaining after required sections".to_string(),
764                });
765                continue;
766            }
767
768            if section.estimated_tokens <= remaining {
769                allocation.full_sections.push(section.name.clone());
770                allocation.tokens_allocated += section.estimated_tokens;
771                allocation.tokens_remaining -= section.estimated_tokens;
772                allocation.explain.push(AllocationDecision {
773                    section: section.name.clone(),
774                    priority: section.priority,
775                    requested: section.estimated_tokens,
776                    allocated: section.estimated_tokens,
777                    outcome: AllocationOutcome::Full,
778                    reason: "Optional section - fits in remaining budget".to_string(),
779                });
780            } else if let Some(min) = section.minimum_tokens {
781                if min <= remaining {
782                    allocation.truncated_sections.push((
783                        section.name.clone(),
784                        section.estimated_tokens,
785                        remaining,
786                    ));
787                    allocation.tokens_allocated += remaining;
788                    allocation.explain.push(AllocationDecision {
789                        section: section.name.clone(),
790                        priority: section.priority,
791                        requested: section.estimated_tokens,
792                        allocated: remaining,
793                        outcome: AllocationOutcome::Truncated,
794                        reason: "Optional section - truncated to fit".to_string(),
795                    });
796                    allocation.tokens_remaining = 0;
797                } else {
798                    allocation.dropped_sections.push(section.name.clone());
799                    allocation.explain.push(AllocationDecision {
800                        section: section.name.clone(),
801                        priority: section.priority,
802                        requested: section.estimated_tokens,
803                        allocated: 0,
804                        outcome: AllocationOutcome::Dropped,
805                        reason: format!("Minimum {} exceeds remaining {}", min, remaining),
806                    });
807                }
808            } else {
809                allocation.dropped_sections.push(section.name.clone());
810                allocation.explain.push(AllocationDecision {
811                    section: section.name.clone(),
812                    priority: section.priority,
813                    requested: section.estimated_tokens,
814                    allocated: 0,
815                    outcome: AllocationOutcome::Dropped,
816                    reason: format!(
817                        "Requested {} exceeds remaining {}",
818                        section.estimated_tokens, remaining
819                    ),
820                });
821            }
822        }
823
824        allocation
825    }
826
    /// Reset allocation
    ///
    /// Stores `0` into the `allocated` counter using [`Ordering::Release`].
    /// This method writes no other field of the enforcer.
    pub fn reset(&self) {
        self.allocated.store(0, Ordering::Release);
    }
831
    /// Get the estimator
    ///
    /// Returns a shared borrow of the [`TokenEstimator`] stored in this
    /// enforcer; nothing is cloned.
    pub fn estimator(&self) -> &TokenEstimator {
        &self.estimator
    }
836}
837
838// ============================================================================
839// EXPLAIN CONTEXT Output
840// ============================================================================
841
842impl BudgetAllocation {
843    /// Generate human-readable explanation of budget allocation
844    pub fn explain_text(&self) -> String {
845        let mut output = String::new();
846        output.push_str("=== CONTEXT BUDGET ALLOCATION ===\n\n");
847        output.push_str(&format!(
848            "Total Allocated: {} tokens\n",
849            self.tokens_allocated
850        ));
851        output.push_str(&format!("Remaining: {} tokens\n\n", self.tokens_remaining));
852
853        output.push_str("SECTIONS:\n");
854        for decision in &self.explain {
855            let status = match decision.outcome {
856                AllocationOutcome::Full => "✓ FULL",
857                AllocationOutcome::Truncated => "◐ TRUNCATED",
858                AllocationOutcome::Dropped => "✗ DROPPED",
859            };
860            output.push_str(&format!(
861                "  [{:^12}] {} (priority {})\n",
862                status, decision.section, decision.priority
863            ));
864            output.push_str(&format!(
865                "               Requested: {}, Allocated: {}\n",
866                decision.requested, decision.allocated
867            ));
868            output.push_str(&format!("               Reason: {}\n", decision.reason));
869        }
870
871        output
872    }
873
874    /// Generate JSON explanation for programmatic use
875    pub fn explain_json(&self) -> String {
876        serde_json::to_string_pretty(&ExplainOutput {
877            tokens_allocated: self.tokens_allocated,
878            tokens_remaining: self.tokens_remaining,
879            full_sections: self.full_sections.clone(),
880            truncated_sections: self.truncated_sections.clone(),
881            dropped_sections: self.dropped_sections.clone(),
882            decisions: self
883                .explain
884                .iter()
885                .map(|d| ExplainDecision {
886                    section: d.section.clone(),
887                    priority: d.priority,
888                    requested: d.requested,
889                    allocated: d.allocated,
890                    outcome: format!("{:?}", d.outcome),
891                    reason: d.reason.clone(),
892                })
893                .collect(),
894        })
895        .unwrap_or_else(|_| "{}".to_string())
896    }
897}
898
/// Serializable snapshot of a `BudgetAllocation`, built by
/// `BudgetAllocation::explain_json` and emitted as pretty-printed JSON.
#[derive(serde::Serialize)]
struct ExplainOutput {
    /// Copied from `BudgetAllocation::tokens_allocated`.
    tokens_allocated: usize,
    /// Copied from `BudgetAllocation::tokens_remaining`.
    tokens_remaining: usize,
    /// Names of sections that received their full requested size.
    full_sections: Vec<String>,
    /// Truncated sections; the allocator pushes these as
    /// `(name, estimated tokens, tokens actually granted)`.
    truncated_sections: Vec<(String, usize, usize)>,
    /// Names of sections that received nothing.
    dropped_sections: Vec<String>,
    /// One entry per element of `BudgetAllocation::explain`, in the same order.
    decisions: Vec<ExplainDecision>,
}
908
/// Serializable form of a single allocation decision, produced by
/// `BudgetAllocation::explain_json` from one entry of `explain`.
#[derive(serde::Serialize)]
struct ExplainDecision {
    /// Section name the decision applies to.
    section: String,
    /// Priority copied from the decision.
    priority: i32,
    /// Tokens the section asked for.
    requested: usize,
    /// Tokens the section was granted.
    allocated: usize,
    /// `Debug` rendering of the decision's `AllocationOutcome` variant.
    outcome: String,
    /// Explanation recorded when the decision was made.
    reason: String,
}
918
919// ============================================================================
920// Token-Aware Truncation
921// ============================================================================
922
923/// Truncate a string to fit within a token budget
924pub fn truncate_to_tokens(
925    text: &str,
926    max_tokens: usize,
927    estimator: &TokenEstimator,
928    suffix: &str,
929) -> String {
930    let current = estimator.estimate_text(text);
931
932    if current <= max_tokens {
933        return text.to_string();
934    }
935
936    let suffix_tokens = estimator.estimate_text(suffix);
937    let target_tokens = max_tokens.saturating_sub(suffix_tokens);
938
939    if target_tokens == 0 {
940        return suffix.to_string();
941    }
942
943    // Binary search for the right truncation point
944    let mut low = 0;
945    let mut high = text.len();
946
947    while low < high {
948        let mid = (low + high).div_ceil(2);
949
950        // Find character boundary
951        let boundary = text
952            .char_indices()
953            .take_while(|(i, _)| *i < mid)
954            .last()
955            .map(|(i, c)| i + c.len_utf8())
956            .unwrap_or(0);
957
958        let truncated = &text[..boundary];
959        let tokens = estimator.estimate_text(truncated);
960
961        if tokens <= target_tokens {
962            low = boundary;
963        } else {
964            high = boundary.saturating_sub(1);
965        }
966    }
967
968    // Find word boundary
969    let truncated = &text[..low];
970    let word_boundary = truncated.rfind(|c: char| c.is_whitespace()).unwrap_or(low);
971
972    format!("{}{}", &text[..word_boundary], suffix)
973}
974
975/// Truncate rows to fit within token budget
976pub fn truncate_rows(
977    rows: &[Vec<SochValue>],
978    max_tokens: usize,
979    estimator: &TokenEstimator,
980) -> Vec<Vec<SochValue>> {
981    let mut result = Vec::new();
982    let mut used = 0;
983
984    for row in rows {
985        let row_tokens = estimator.estimate_row(row);
986
987        if used + row_tokens <= max_tokens {
988            result.push(row.clone());
989            used += row_tokens;
990        } else {
991            break; // No more room
992        }
993    }
994
995    result
996}
997
998// ============================================================================
999// Tests
1000// ============================================================================
1001
#[cfg(test)]
mod tests {
    use super::*;

    /// Build a `BudgetSection` with weight `1.0`; only the fields the tests
    /// vary are parameters.
    fn budget_section(
        name: &str,
        priority: i32,
        estimated_tokens: usize,
        minimum_tokens: Option<usize>,
        required: bool,
    ) -> BudgetSection {
        BudgetSection {
            name: name.to_string(),
            priority,
            estimated_tokens,
            minimum_tokens,
            required,
            weight: 1.0,
        }
    }

    #[test]
    fn test_estimate_value_int() {
        let estimator = TokenEstimator::new();

        // Every integer costs at least one token.
        for n in [0, 42] {
            assert!(estimator.estimate_value(&SochValue::Int(n)) >= 1);
        }

        // A longer integer must never be cheaper than a short one.
        let short_cost = estimator.estimate_value(&SochValue::Int(42));
        let long_cost = estimator.estimate_value(&SochValue::Int(1_000_000_000));
        assert!(long_cost >= short_cost);
    }

    #[test]
    fn test_estimate_value_text() {
        let estimator = TokenEstimator::new();
        let cost_of = |s: &str| estimator.estimate_value(&SochValue::Text(s.to_string()));

        let short_cost = cost_of("hello");
        let long_cost = cost_of("hello world this is a longer string");

        assert!(long_cost > short_cost);
    }

    #[test]
    #[allow(clippy::approx_constant)]
    fn test_estimate_row() {
        let estimator = TokenEstimator::new();

        let values = vec![
            SochValue::Int(1),
            SochValue::Text("Alice".to_string()),
            SochValue::Float(3.14),
        ];

        // Three values, so at least one token each.
        assert!(estimator.estimate_row(&values) >= 3);
    }

    #[test]
    fn test_estimate_table() {
        let estimator = TokenEstimator::new();

        let header = vec!["id".to_string(), "name".to_string()];
        let body = vec![
            vec![SochValue::Int(1), SochValue::Text("Alice".to_string())],
            vec![SochValue::Int(2), SochValue::Text("Bob".to_string())],
        ];

        let table_cost = estimator.estimate_table("users", &header, &body);

        // The table (header plus two rows) must cost more than twice the first row.
        assert!(table_cost > estimator.estimate_row(&body[0]) * 2);
    }

    #[test]
    fn test_budget_enforcer_allocation() {
        let enforcer = TokenBudgetEnforcer::with_budget(1000);

        assert!(enforcer.try_allocate(500));
        assert_eq!(enforcer.allocated(), 500);
        assert_eq!(enforcer.available(), 500);

        assert!(enforcer.try_allocate(400));
        assert_eq!(enforcer.allocated(), 900);

        // Only 100 tokens are left, so asking for 200 must fail and change nothing.
        assert!(!enforcer.try_allocate(200));
        assert_eq!(enforcer.allocated(), 900);
    }

    #[test]
    fn test_budget_enforcer_reset() {
        let enforcer = TokenBudgetEnforcer::with_budget(1000);

        enforcer.try_allocate(800);
        assert_eq!(enforcer.allocated(), 800);

        enforcer.reset();
        assert_eq!(enforcer.allocated(), 0);
    }

    #[test]
    fn test_allocate_sections() {
        let enforcer = TokenBudgetEnforcer::with_budget(1000);

        let sections = vec![
            budget_section("A", 0, 300, None, true),
            budget_section("B", 1, 400, Some(200), false),
            budget_section("C", 2, 500, None, false),
        ];

        let allocation = enforcer.allocate_sections(&sections);

        // A is required and fits in full.
        assert!(allocation.full_sections.contains(&"A".to_string()));

        // C (500, no minimum) is expected to be dropped: A takes 300 and
        // B takes 400, which leaves only 300.
        assert!(allocation.dropped_sections.contains(&"C".to_string()));

        assert!(allocation.tokens_allocated <= 1000);
    }

    #[test]
    fn test_allocate_by_priority() {
        let enforcer = TokenBudgetEnforcer::with_budget(500);

        // Deliberately listed with the low-priority section first.
        let sections = vec![
            budget_section("LowPriority", 10, 200, None, false),
            budget_section("HighPriority", 0, 400, None, true),
        ];

        let allocation = enforcer.allocate_sections(&sections);

        // The high-priority section is served first.
        let high = "HighPriority".to_string();
        assert!(allocation.full_sections.contains(&high));

        // Only 100 tokens remain, so the low-priority section is dropped.
        let low = "LowPriority".to_string();
        assert!(allocation.dropped_sections.contains(&low));
    }

    #[test]
    fn test_truncate_to_tokens() {
        let estimator = TokenEstimator::new();

        let text = "This is a long text that needs to be truncated to fit within the token budget";
        let shortened = truncate_to_tokens(text, 10, &estimator, "...");

        // Shorter than the input.
        assert!(shortened.len() < text.len());

        // Carries the suffix.
        assert!(shortened.ends_with("..."));

        // Within the budget.
        assert!(estimator.estimate_text(&shortened) <= 10);
    }

    #[test]
    fn test_truncate_rows() {
        let estimator = TokenEstimator::new();

        let mut rows: Vec<Vec<SochValue>> = Vec::new();
        for i in 0..100 {
            rows.push(vec![SochValue::Int(i), SochValue::Text(format!("row{}", i))]);
        }

        let kept = truncate_rows(&rows, 50, &estimator);

        // Some rows must have been cut.
        assert!(kept.len() < rows.len());

        // What is left fits the budget.
        let mut total: usize = 0;
        for row in &kept {
            total += estimator.estimate_row(row);
        }
        assert!(total <= 50);
    }

    #[test]
    fn test_reserved_budget() {
        let mut enforcer = TokenBudgetEnforcer::with_budget(1000);
        enforcer.reserve(200);

        assert_eq!(enforcer.available(), 800);

        assert!(enforcer.try_allocate(700));
        assert_eq!(enforcer.available(), 100);

        // The reserved tokens are off limits: 200 exceeds the 100 available.
        assert!(!enforcer.try_allocate(200));
    }

    #[test]
    fn test_estimator_configs() {
        let text = "Hello, this is a test string for comparing token estimation across different configurations.";

        let default_est = TokenEstimator::new().estimate_text(text);
        let gpt4_est =
            TokenEstimator::with_config(TokenEstimatorConfig::gpt4()).estimate_text(text);
        let conservative_est =
            TokenEstimator::with_config(TokenEstimatorConfig::conservative()).estimate_text(text);

        // The conservative profile must not estimate lower than the default.
        assert!(conservative_est >= default_est);

        // Every profile yields a positive estimate.
        assert!(default_est > 0);
        assert!(gpt4_est > 0);
        assert!(conservative_est > 0);
    }

    #[test]
    fn test_section_with_truncation() {
        let enforcer = TokenBudgetEnforcer::with_budget(600);

        let sections = vec![
            budget_section("Required", 0, 500, None, true),
            // Declares a minimum, so it may be truncated.
            budget_section("Optional", 1, 300, Some(50), false),
        ];

        let allocation = enforcer.allocate_sections(&sections);

        // The required section fits in full.
        assert!(allocation.full_sections.contains(&"Required".to_string()));

        // 100 tokens remain and the minimum is 50, so the optional section is truncated.
        let optional_truncated = allocation
            .truncated_sections
            .iter()
            .any(|(name, _, _)| name == "Optional");
        assert!(optional_truncated);
    }
}