Skip to main content

batuta/content/
budget.rs

1//! Token budgeting (Heijunka)
2//!
3//! Token budget calculation (spec section 5.4) for load leveling.
4
5use super::ContentError;
6use serde::{Deserialize, Serialize};
7
8/// Model context window sizes
9#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
10pub enum ModelContext {
11    /// Claude Sonnet/Opus (200K) - Default
12    #[default]
13    Claude200K,
14    /// Claude Haiku (200K)
15    ClaudeHaiku,
16    /// Gemini Pro (1M)
17    GeminiPro,
18    /// Gemini Flash (1M)
19    GeminiFlash,
20    /// GPT-4 Turbo (128K)
21    Gpt4Turbo,
22    /// Custom context window
23    Custom(usize),
24}
25
26impl ModelContext {
27    /// Get the context window size in tokens
28    pub fn window_size(&self) -> usize {
29        match self {
30            ModelContext::Claude200K => 200_000,
31            ModelContext::ClaudeHaiku => 200_000,
32            ModelContext::GeminiPro => 1_000_000,
33            ModelContext::GeminiFlash => 1_000_000,
34            ModelContext::Gpt4Turbo => 128_000,
35            ModelContext::Custom(size) => *size,
36        }
37    }
38
39    /// Get the model name
40    pub fn name(&self) -> &'static str {
41        match self {
42            ModelContext::Claude200K => "claude-sonnet",
43            ModelContext::ClaudeHaiku => "claude-haiku",
44            ModelContext::GeminiPro => "gemini-pro",
45            ModelContext::GeminiFlash => "gemini-flash",
46            ModelContext::Gpt4Turbo => "gpt-4-turbo",
47            ModelContext::Custom(_) => "custom",
48        }
49    }
50}
51
52/// Token budget calculation (spec section 5.4)
53#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
54pub struct TokenBudget {
55    /// Target model context window
56    pub context_window: usize,
57    /// Reserved for system prompt
58    pub system_reserve: usize,
59    /// Reserved for source context (--source-context)
60    pub source_context: usize,
61    /// Reserved for RAG context (--rag-context)
62    pub rag_context: usize,
63    /// Reserved for few-shot examples
64    pub few_shot: usize,
65    /// Target output tokens
66    pub output_target: usize,
67}
68
69impl TokenBudget {
70    /// Create a new token budget for a model
71    pub fn new(model: ModelContext) -> Self {
72        Self {
73            context_window: model.window_size(),
74            system_reserve: 2_000,
75            source_context: 0,
76            rag_context: 0,
77            few_shot: 1_500,
78            output_target: 4_000,
79        }
80    }
81
82    /// Set source context tokens
83    pub fn with_source_context(mut self, tokens: usize) -> Self {
84        self.source_context = tokens;
85        self
86    }
87
88    /// Set RAG context tokens
89    pub fn with_rag_context(mut self, tokens: usize) -> Self {
90        self.rag_context = tokens;
91        self
92    }
93
94    /// Set output target tokens
95    pub fn with_output_target(mut self, tokens: usize) -> Self {
96        self.output_target = tokens;
97        self
98    }
99
100    /// Calculate total prompt tokens (excluding output)
101    pub fn prompt_tokens(&self) -> usize {
102        self.system_reserve + self.source_context + self.rag_context + self.few_shot
103    }
104
105    /// Calculate available margin
106    pub fn available_margin(&self) -> usize {
107        let used = self.prompt_tokens() + self.output_target;
108        self.context_window.saturating_sub(used)
109    }
110
111    /// Validate the budget fits within context window
112    pub fn validate(&self) -> Result<(), ContentError> {
113        let total = self.prompt_tokens() + self.output_target;
114        if total > self.context_window {
115            Err(ContentError::TokenBudgetExceeded { used: total, limit: self.context_window })
116        } else {
117            Ok(())
118        }
119    }
120
121    /// Estimate tokens from word count (rough: 1 word ≈ 1.3 tokens)
122    pub fn words_to_tokens(words: usize) -> usize {
123        (words as f64 * 1.3).ceil() as usize
124    }
125
126    /// Estimate words from token count
127    pub fn tokens_to_words(tokens: usize) -> usize {
128        (tokens as f64 / 1.3).floor() as usize
129    }
130
131    /// Format budget as display string
132    pub fn format_display(&self, model_name: &str) -> String {
133        let mut output = String::new();
134        output.push_str(&format!(
135            "Token Budget for {} ({}K context):\n",
136            model_name,
137            self.context_window / 1000
138        ));
139        output.push_str(&format!("├── System prompt:     {:>6} tokens\n", self.system_reserve));
140        output.push_str(&format!("├── Source context:    {:>6} tokens\n", self.source_context));
141        output.push_str(&format!("├── RAG context:       {:>6} tokens\n", self.rag_context));
142        output.push_str(&format!("├── Few-shot examples: {:>6} tokens\n", self.few_shot));
143        output.push_str(&format!(
144            "├── Output reserved:   {:>6} tokens (~{} words)\n",
145            self.output_target,
146            Self::tokens_to_words(self.output_target)
147        ));
148        let margin = self.available_margin();
149        let status = if margin > 0 { "✓" } else { "✗" };
150        output.push_str(&format!("└── Available margin:  {:>6} tokens {}\n", margin, status));
151        output
152    }
153}
154
155impl Default for TokenBudget {
156    fn default() -> Self {
157        Self::new(ModelContext::Claude200K)
158    }
159}
160
#[cfg(test)]
mod tests {
    //! Unit tests for `ModelContext` and `TokenBudget`.

    use super::*;

    // =========================================================================
    // ModelContext Tests
    // =========================================================================

    #[test]
    fn test_model_context_default() {
        let ctx = ModelContext::default();
        assert_eq!(ctx, ModelContext::Claude200K);
    }

    #[test]
    fn test_model_context_window_sizes() {
        assert_eq!(ModelContext::Claude200K.window_size(), 200_000);
        assert_eq!(ModelContext::ClaudeHaiku.window_size(), 200_000);
        assert_eq!(ModelContext::GeminiPro.window_size(), 1_000_000);
        assert_eq!(ModelContext::GeminiFlash.window_size(), 1_000_000);
        assert_eq!(ModelContext::Gpt4Turbo.window_size(), 128_000);
        assert_eq!(ModelContext::Custom(50_000).window_size(), 50_000);
    }

    #[test]
    fn test_model_context_names() {
        assert_eq!(ModelContext::Claude200K.name(), "claude-sonnet");
        assert_eq!(ModelContext::ClaudeHaiku.name(), "claude-haiku");
        assert_eq!(ModelContext::GeminiPro.name(), "gemini-pro");
        assert_eq!(ModelContext::GeminiFlash.name(), "gemini-flash");
        assert_eq!(ModelContext::Gpt4Turbo.name(), "gpt-4-turbo");
        assert_eq!(ModelContext::Custom(1000).name(), "custom");
    }

    #[test]
    fn test_model_context_serialization() {
        let ctx = ModelContext::GeminiPro;
        let json = serde_json::to_string(&ctx).expect("json serialize failed");
        let deserialized: ModelContext =
            serde_json::from_str(&json).expect("json deserialize failed");
        assert_eq!(deserialized, ctx);
    }

    #[test]
    fn test_model_context_custom_serialization() {
        let ctx = ModelContext::Custom(75_000);
        let json = serde_json::to_string(&ctx).expect("json serialize failed");
        let deserialized: ModelContext =
            serde_json::from_str(&json).expect("json deserialize failed");
        assert_eq!(deserialized, ctx);
        assert_eq!(deserialized.window_size(), 75_000);
    }

    // =========================================================================
    // TokenBudget Tests
    // =========================================================================

    #[test]
    fn test_token_budget_new() {
        let budget = TokenBudget::new(ModelContext::Claude200K);
        assert_eq!(budget.context_window, 200_000);
        assert_eq!(budget.system_reserve, 2_000);
        assert_eq!(budget.source_context, 0);
        assert_eq!(budget.rag_context, 0);
        assert_eq!(budget.few_shot, 1_500);
        assert_eq!(budget.output_target, 4_000);
    }

    #[test]
    fn test_token_budget_default() {
        let budget = TokenBudget::default();
        assert_eq!(budget.context_window, 200_000);
    }

    #[test]
    fn test_token_budget_with_source_context() {
        let budget = TokenBudget::new(ModelContext::Claude200K).with_source_context(10_000);
        assert_eq!(budget.source_context, 10_000);
    }

    #[test]
    fn test_token_budget_with_rag_context() {
        let budget = TokenBudget::new(ModelContext::Claude200K).with_rag_context(5_000);
        assert_eq!(budget.rag_context, 5_000);
    }

    #[test]
    fn test_token_budget_with_output_target() {
        let budget = TokenBudget::new(ModelContext::Claude200K).with_output_target(8_000);
        assert_eq!(budget.output_target, 8_000);
    }

    #[test]
    fn test_token_budget_prompt_tokens() {
        let budget = TokenBudget::new(ModelContext::Claude200K)
            .with_source_context(10_000)
            .with_rag_context(5_000);
        // system_reserve(2000) + source(10000) + rag(5000) + few_shot(1500)
        assert_eq!(budget.prompt_tokens(), 18_500);
    }

    #[test]
    fn test_token_budget_available_margin() {
        let budget = TokenBudget::new(ModelContext::Claude200K);
        // context(200000) - prompt(3500) - output(4000)
        let margin = budget.available_margin();
        assert_eq!(margin, 200_000 - 3_500 - 4_000);
    }

    #[test]
    fn test_token_budget_validate_ok() {
        let budget = TokenBudget::new(ModelContext::Claude200K);
        assert!(budget.validate().is_ok());
    }

    #[test]
    fn test_token_budget_validate_exact_fit() {
        // prompt(3500) + output(4000) fills a 7500-token window exactly:
        // validation passes, but no margin is left.
        let budget = TokenBudget::new(ModelContext::Custom(7_500));
        assert!(budget.validate().is_ok());
        assert_eq!(budget.available_margin(), 0);
    }

    #[test]
    fn test_token_budget_validate_exceeded() {
        let budget = TokenBudget::new(ModelContext::Custom(1_000)).with_output_target(2_000);
        let result = budget.validate();
        assert!(result.is_err());
        // system_reserve(2000) + few_shot(1500) + output(2000) against a 1000-token window
        assert!(matches!(
            result,
            Err(ContentError::TokenBudgetExceeded { used: 5_500, limit: 1_000 })
        ));
    }

    #[test]
    fn test_words_to_tokens() {
        // 100 words * 1.3 = 130 tokens
        assert_eq!(TokenBudget::words_to_tokens(100), 130);
        assert_eq!(TokenBudget::words_to_tokens(0), 0);
    }

    #[test]
    fn test_tokens_to_words() {
        // 130 tokens / 1.3 = 100 words
        assert_eq!(TokenBudget::tokens_to_words(130), 100);
        assert_eq!(TokenBudget::tokens_to_words(0), 0);
    }

    #[test]
    fn test_token_budget_format_display() {
        let budget = TokenBudget::new(ModelContext::Claude200K);
        let output = budget.format_display("claude-sonnet");
        assert!(output.contains("Token Budget for claude-sonnet"));
        assert!(output.contains("200K context"));
        assert!(output.contains("System prompt"));
        assert!(output.contains("Available margin"));
        assert!(output.contains("✓")); // Should have margin available
    }

    #[test]
    fn test_token_budget_format_display_exceeded() {
        let budget = TokenBudget::new(ModelContext::Custom(1_000)).with_output_target(2_000);
        let output = budget.format_display("custom");
        // When the budget exceeds the window the margin clamps to 0 and ✗ is shown
        assert!(output.contains("Available margin"));
        assert!(output.contains("✗"));
        assert!(!output.contains("✓"));
    }

    #[test]
    fn test_token_budget_serialization() {
        let budget = TokenBudget::new(ModelContext::GeminiPro)
            .with_source_context(5_000)
            .with_rag_context(3_000);
        let json = serde_json::to_string(&budget).expect("json serialize failed");
        let deserialized: TokenBudget =
            serde_json::from_str(&json).expect("json deserialize failed");
        assert_eq!(deserialized, budget);
    }

    #[test]
    fn test_token_budget_builder_chain() {
        let budget = TokenBudget::new(ModelContext::Gpt4Turbo)
            .with_source_context(10_000)
            .with_rag_context(8_000)
            .with_output_target(6_000);

        assert_eq!(budget.context_window, 128_000);
        assert_eq!(budget.source_context, 10_000);
        assert_eq!(budget.rag_context, 8_000);
        assert_eq!(budget.output_target, 6_000);
    }
}