1use super::ContentError;
6use serde::{Deserialize, Serialize};
7
/// Selects which model's context window a token budget is sized for.
///
/// Each variant maps to a window size via `window_size` and to a short label
/// via `name`. `Claude200K` is the default variant.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
pub enum ModelContext {
    /// Default variant: 200,000-token window, labelled "claude-sonnet".
    #[default]
    Claude200K,
    /// 200,000-token window, labelled "claude-haiku".
    ClaudeHaiku,
    /// 1,000,000-token window, labelled "gemini-pro".
    GeminiPro,
    /// 1,000,000-token window, labelled "gemini-flash".
    GeminiFlash,
    /// 128,000-token window, labelled "gpt-4-turbo".
    Gpt4Turbo,
    /// Caller-supplied window size; the payload is returned unchanged by
    /// `window_size`, and the label is always "custom".
    Custom(usize),
}
25
26impl ModelContext {
27 pub fn window_size(&self) -> usize {
29 match self {
30 ModelContext::Claude200K => 200_000,
31 ModelContext::ClaudeHaiku => 200_000,
32 ModelContext::GeminiPro => 1_000_000,
33 ModelContext::GeminiFlash => 1_000_000,
34 ModelContext::Gpt4Turbo => 128_000,
35 ModelContext::Custom(size) => *size,
36 }
37 }
38
39 pub fn name(&self) -> &'static str {
41 match self {
42 ModelContext::Claude200K => "claude-sonnet",
43 ModelContext::ClaudeHaiku => "claude-haiku",
44 ModelContext::GeminiPro => "gemini-pro",
45 ModelContext::GeminiFlash => "gemini-flash",
46 ModelContext::Gpt4Turbo => "gpt-4-turbo",
47 ModelContext::Custom(_) => "custom",
48 }
49 }
50}
51
/// Breakdown of how a model's context window is divided between the parts of
/// a request.
///
/// Every field is a token count. The prompt side is the sum of
/// `system_reserve`, `source_context`, `rag_context` and `few_shot`;
/// `output_target` is counted on top of that when checking against
/// `context_window`.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct TokenBudget {
    /// Total size of the context window; `new` takes it from
    /// `ModelContext::window_size`.
    pub context_window: usize,
    /// Tokens set aside for the system prompt (`new` starts this at 2,000).
    pub system_reserve: usize,
    /// Tokens allotted to source context (`new` starts this at 0).
    pub source_context: usize,
    /// Tokens allotted to RAG context (`new` starts this at 0).
    pub rag_context: usize,
    /// Tokens allotted to few-shot examples (`new` starts this at 1,500).
    pub few_shot: usize,
    /// Tokens reserved for the generated output (`new` starts this at 4,000).
    pub output_target: usize,
}
68
69impl TokenBudget {
70 pub fn new(model: ModelContext) -> Self {
72 Self {
73 context_window: model.window_size(),
74 system_reserve: 2_000,
75 source_context: 0,
76 rag_context: 0,
77 few_shot: 1_500,
78 output_target: 4_000,
79 }
80 }
81
82 pub fn with_source_context(mut self, tokens: usize) -> Self {
84 self.source_context = tokens;
85 self
86 }
87
88 pub fn with_rag_context(mut self, tokens: usize) -> Self {
90 self.rag_context = tokens;
91 self
92 }
93
94 pub fn with_output_target(mut self, tokens: usize) -> Self {
96 self.output_target = tokens;
97 self
98 }
99
100 pub fn prompt_tokens(&self) -> usize {
102 self.system_reserve + self.source_context + self.rag_context + self.few_shot
103 }
104
105 pub fn available_margin(&self) -> usize {
107 let used = self.prompt_tokens() + self.output_target;
108 self.context_window.saturating_sub(used)
109 }
110
111 pub fn validate(&self) -> Result<(), ContentError> {
113 let total = self.prompt_tokens() + self.output_target;
114 if total > self.context_window {
115 Err(ContentError::TokenBudgetExceeded { used: total, limit: self.context_window })
116 } else {
117 Ok(())
118 }
119 }
120
121 pub fn words_to_tokens(words: usize) -> usize {
123 (words as f64 * 1.3).ceil() as usize
124 }
125
126 pub fn tokens_to_words(tokens: usize) -> usize {
128 (tokens as f64 / 1.3).floor() as usize
129 }
130
131 pub fn format_display(&self, model_name: &str) -> String {
133 let mut output = String::new();
134 output.push_str(&format!(
135 "Token Budget for {} ({}K context):\n",
136 model_name,
137 self.context_window / 1000
138 ));
139 output.push_str(&format!("├── System prompt: {:>6} tokens\n", self.system_reserve));
140 output.push_str(&format!("├── Source context: {:>6} tokens\n", self.source_context));
141 output.push_str(&format!("├── RAG context: {:>6} tokens\n", self.rag_context));
142 output.push_str(&format!("├── Few-shot examples: {:>6} tokens\n", self.few_shot));
143 output.push_str(&format!(
144 "├── Output reserved: {:>6} tokens (~{} words)\n",
145 self.output_target,
146 Self::tokens_to_words(self.output_target)
147 ));
148 let margin = self.available_margin();
149 let status = if margin > 0 { "✓" } else { "✗" };
150 output.push_str(&format!("└── Available margin: {:>6} tokens {}\n", margin, status));
151 output
152 }
153}
154
155impl Default for TokenBudget {
156 fn default() -> Self {
157 Self::new(ModelContext::Claude200K)
158 }
159}
160
/// Unit tests for `ModelContext` and `TokenBudget`.
#[cfg(test)]
mod tests {
    use super::*;

    // ---- ModelContext ----

    #[test]
    fn test_model_context_default() {
        let ctx = ModelContext::default();
        assert_eq!(ctx, ModelContext::Claude200K);
    }

    #[test]
    fn test_model_context_window_sizes() {
        assert_eq!(ModelContext::Claude200K.window_size(), 200_000);
        assert_eq!(ModelContext::ClaudeHaiku.window_size(), 200_000);
        assert_eq!(ModelContext::GeminiPro.window_size(), 1_000_000);
        assert_eq!(ModelContext::GeminiFlash.window_size(), 1_000_000);
        assert_eq!(ModelContext::Gpt4Turbo.window_size(), 128_000);
        assert_eq!(ModelContext::Custom(50_000).window_size(), 50_000);
    }

    #[test]
    fn test_model_context_names() {
        assert_eq!(ModelContext::Claude200K.name(), "claude-sonnet");
        assert_eq!(ModelContext::ClaudeHaiku.name(), "claude-haiku");
        assert_eq!(ModelContext::GeminiPro.name(), "gemini-pro");
        assert_eq!(ModelContext::GeminiFlash.name(), "gemini-flash");
        assert_eq!(ModelContext::Gpt4Turbo.name(), "gpt-4-turbo");
        assert_eq!(ModelContext::Custom(1000).name(), "custom");
    }

    #[test]
    fn test_model_context_serialization() {
        let ctx = ModelContext::GeminiPro;
        let json = serde_json::to_string(&ctx).expect("json serialize failed");
        let deserialized: ModelContext =
            serde_json::from_str(&json).expect("json deserialize failed");
        assert_eq!(deserialized, ctx);
    }

    #[test]
    fn test_model_context_custom_serialization() {
        let ctx = ModelContext::Custom(75_000);
        let json = serde_json::to_string(&ctx).expect("json serialize failed");
        let deserialized: ModelContext =
            serde_json::from_str(&json).expect("json deserialize failed");
        assert_eq!(deserialized, ctx);
        assert_eq!(deserialized.window_size(), 75_000);
    }

    // ---- TokenBudget ----

    #[test]
    fn test_token_budget_new() {
        let budget = TokenBudget::new(ModelContext::Claude200K);
        assert_eq!(budget.context_window, 200_000);
        assert_eq!(budget.system_reserve, 2_000);
        assert_eq!(budget.source_context, 0);
        assert_eq!(budget.rag_context, 0);
        assert_eq!(budget.few_shot, 1_500);
        assert_eq!(budget.output_target, 4_000);
    }

    #[test]
    fn test_token_budget_default() {
        let budget = TokenBudget::default();
        assert_eq!(budget.context_window, 200_000);
    }

    #[test]
    fn test_token_budget_with_source_context() {
        let budget = TokenBudget::new(ModelContext::Claude200K).with_source_context(10_000);
        assert_eq!(budget.source_context, 10_000);
    }

    #[test]
    fn test_token_budget_with_rag_context() {
        let budget = TokenBudget::new(ModelContext::Claude200K).with_rag_context(5_000);
        assert_eq!(budget.rag_context, 5_000);
    }

    #[test]
    fn test_token_budget_with_output_target() {
        let budget = TokenBudget::new(ModelContext::Claude200K).with_output_target(8_000);
        assert_eq!(budget.output_target, 8_000);
    }

    #[test]
    fn test_token_budget_prompt_tokens() {
        let budget = TokenBudget::new(ModelContext::Claude200K)
            .with_source_context(10_000)
            .with_rag_context(5_000);
        // 2_000 system + 10_000 source + 5_000 RAG + 1_500 few-shot
        assert_eq!(budget.prompt_tokens(), 18_500);
    }

    #[test]
    fn test_token_budget_available_margin() {
        let budget = TokenBudget::new(ModelContext::Claude200K);
        let margin = budget.available_margin();
        // window - (2_000 system + 1_500 few-shot) - 4_000 output
        assert_eq!(margin, 200_000 - 3_500 - 4_000);
    }

    #[test]
    fn test_token_budget_validate_ok() {
        let budget = TokenBudget::new(ModelContext::Claude200K);
        assert!(budget.validate().is_ok());
    }

    #[test]
    fn test_token_budget_validate_exceeded() {
        let budget = TokenBudget::new(ModelContext::Custom(1_000)).with_output_target(2_000);
        let result = budget.validate();
        assert!(result.is_err());
        // 2_000 system + 1_500 few-shot + 2_000 output against a 1_000 window.
        assert!(matches!(
            result,
            Err(ContentError::TokenBudgetExceeded {
                used: 5_500,
                limit: 1_000
            })
        ));
    }

    #[test]
    fn test_words_to_tokens() {
        assert_eq!(TokenBudget::words_to_tokens(100), 130);
        assert_eq!(TokenBudget::words_to_tokens(0), 0);
    }

    #[test]
    fn test_tokens_to_words() {
        assert_eq!(TokenBudget::tokens_to_words(130), 100);
        assert_eq!(TokenBudget::tokens_to_words(0), 0);
    }

    #[test]
    fn test_token_budget_format_display() {
        let budget = TokenBudget::new(ModelContext::Claude200K);
        let output = budget.format_display("claude-sonnet");
        assert!(output.contains("Token Budget for claude-sonnet"));
        assert!(output.contains("200K context"));
        assert!(output.contains("System prompt"));
        assert!(output.contains("Available margin"));
        assert!(output.contains("✓"));
    }

    #[test]
    fn test_token_budget_format_display_exceeded() {
        let budget = TokenBudget::new(ModelContext::Custom(1_000)).with_output_target(2_000);
        let output = budget.format_display("custom");
        assert!(output.contains("Available margin"));
        // The label alone appears for every budget; the marker is what
        // distinguishes an exceeded budget from a healthy one.
        assert!(output.contains("✗"));
        assert!(!output.contains("✓"));
    }

    #[test]
    fn test_token_budget_serialization() {
        let budget = TokenBudget::new(ModelContext::GeminiPro)
            .with_source_context(5_000)
            .with_rag_context(3_000);
        let json = serde_json::to_string(&budget).expect("json serialize failed");
        let deserialized: TokenBudget =
            serde_json::from_str(&json).expect("json deserialize failed");
        assert_eq!(deserialized, budget);
    }

    #[test]
    fn test_token_budget_builder_chain() {
        let budget = TokenBudget::new(ModelContext::Gpt4Turbo)
            .with_source_context(10_000)
            .with_rag_context(8_000)
            .with_output_target(6_000);

        assert_eq!(budget.context_window, 128_000);
        assert_eq!(budget.source_context, 10_000);
        assert_eq!(budget.rag_context, 8_000);
        assert_eq!(budget.output_target, 6_000);
    }
}