1use serde::{Deserialize, Serialize};
7use serde_json::Value;
8use std::collections::HashMap;
9use std::path::PathBuf;
10
/// Builtin `(pattern, max_context_tokens, max_output_tokens)` limits for
/// well-known models. Patterns are tried as exact matches first, then by
/// longest bidirectional substring match; the `"default"` entry is the
/// last-resort fallback.
pub const KNOWN_MODEL_LIMITS: &[(&str, u32, u32)] = &[
    // Anthropic — current generation.
    ("claude-haiku-4.5", 160_000, 32_000),
    ("claude-opus-4.5", 160_000, 32_000),
    ("claude-opus-4.6", 192_000, 64_000),
    ("claude-sonnet-4.6", 200_000, 32_000),
    ("claude-sonnet-4-6", 200_000, 32_000),
    ("claude-sonnet-4.5", 200_000, 32_000),
    ("claude-sonnet-4-5", 200_000, 32_000),
    // Google.
    ("gemini-2.5-pro", 128_000, 16_000),
    ("gemini-3-flash-preview", 1_000_000, 8_192),
    ("gemini-3.1-pro-preview", 128_000, 64_000),
    // OpenAI — GPT-5 family and GPT-4 previews.
    ("gpt-5.4", 1_050_000, 32_768),
    ("gpt-5.3-codex", 400_000, 128_000),
    ("gpt-5.2-codex", 400_000, 128_000),
    ("gpt-5.2", 400_000, 128_000),
    ("gpt-5.1", 400_000, 128_000),
    ("gpt-5", 400_000, 128_000),
    ("gpt-5.4-mini", 128_000, 16_384),
    ("gpt-4.1", 128_000, 16_384),
    ("gpt-4-o-preview", 128_000, 16_384),
    ("gpt-4o-preview", 128_000, 16_384),
    // Mixed providers and remaining OpenAI entries.
    ("grok-code-fast-1", 128_000, 10_240),
    ("oswe-vscode-prime", 264_000, 64_000),
    ("gpt-5.4-thinking", 1_000_000, 128_000),
    ("gpt-5.2-pro", 256_000, 64_000),
    ("gpt-5-mini", 400_000, 128_000),
    ("gpt-4o", 128_000, 16_000),
    ("kimi-k2.5", 256_000, 64_000),
    ("kimi-for-coding", 256_000, 64_000),
    ("glm-5", 200_000, 128_000),
    ("gpt-4o-mini", 128_000, 16_000),
    ("gpt-4-turbo", 128_000, 16_000),
    ("gpt-4", 8_192, 4_096),
    ("gpt-3.5-turbo", 16_385, 4_096),
    // Anthropic — legacy Claude 3.x.
    ("claude-3-5-sonnet", 200_000, 8_192),
    ("claude-3-5-sonnet-20241022", 200_000, 8_192),
    ("claude-3-5-sonnet-20240620", 200_000, 8_192),
    ("claude-3-opus", 200_000, 8_192),
    ("claude-3-opus-20240229", 200_000, 8_192),
    ("claude-3-sonnet", 200_000, 8_192),
    ("claude-3-haiku", 200_000, 8_192),
    ("copilot-chat", 128_000, 16_000),
    // Fallback when nothing else matches.
    ("default", 128_000, 4_096),
];
70
/// Output-token cap used when neither the builtin table nor a user override
/// supplies one.
pub const DEFAULT_MAX_OUTPUT_TOKENS: u32 = 4096;

/// Floor for the safety margin: the minimum number of tokens kept free below
/// the context window (see `ModelLimit::get_safety_margin`).
pub const DEFAULT_SAFETY_MARGIN: u32 = 1000;
76
/// Token-limit configuration for one model name or model-name pattern.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModelLimit {
    /// Model name, or substring pattern, this limit applies to.
    pub model_pattern: String,
    /// Total context-window size in tokens.
    pub max_context_tokens: u32,
    /// Explicit output-token cap; when `None`, derived from the context size.
    #[serde(default)]
    pub max_output_tokens: Option<u32>,
    /// Explicit safety margin; when `None`, derived from the context size.
    #[serde(default)]
    pub safety_margin: Option<u32>,
}
91
92impl ModelLimit {
93 pub fn new(model_pattern: impl Into<String>, max_context_tokens: u32) -> Self {
95 Self {
96 model_pattern: model_pattern.into(),
97 max_context_tokens,
98 max_output_tokens: None,
99 safety_margin: None,
100 }
101 }
102
103 pub fn get_max_output_tokens(&self) -> u32 {
105 self.max_output_tokens
106 .unwrap_or_else(|| (self.max_context_tokens / 4).min(4096))
107 }
108
109 pub fn get_safety_margin(&self) -> u32 {
111 self.safety_margin
112 .unwrap_or_else(|| (self.max_context_tokens / 100).max(DEFAULT_SAFETY_MARGIN))
113 }
114}
115
116fn builtin_limit(pattern: &str, max_context_tokens: u32, max_output_tokens: u32) -> ModelLimit {
117 let mut limit = ModelLimit::new(pattern.to_string(), max_context_tokens);
118 limit.max_output_tokens = Some(max_output_tokens);
119 limit
120}
121
/// Resolves model names to token limits: user-supplied overrides take
/// precedence, falling back to the builtin `KNOWN_MODEL_LIMITS` table.
#[derive(Debug, Clone)]
pub struct ModelLimitsRegistry {
    /// User overrides keyed by their `model_pattern`.
    user_limits: HashMap<String, ModelLimit>,
    /// Explicit config-file path; `None` means use the default path.
    config_path: Option<PathBuf>,
}
130
131impl ModelLimitsRegistry {
132 pub fn new() -> Self {
134 Self {
135 user_limits: HashMap::new(),
136 config_path: None,
137 }
138 }
139
140 pub fn with_config_path(path: impl Into<PathBuf>) -> Self {
142 Self {
143 user_limits: HashMap::new(),
144 config_path: Some(path.into()),
145 }
146 }
147
148 pub async fn load_user_config(&mut self) -> std::io::Result<()> {
152 let path = self
153 .config_path
154 .clone()
155 .unwrap_or_else(get_default_config_path);
156
157 if !path.exists() {
158 return Ok(());
159 }
160
161 let content = tokio::fs::read_to_string(&path).await?;
162 let limits: Vec<ModelLimit> = serde_json::from_str(&content)
163 .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e.to_string()))?;
164
165 for limit in limits {
166 self.user_limits.insert(limit.model_pattern.clone(), limit);
167 }
168
169 tracing::info!(
170 "Loaded {} user model limits from {:?}",
171 self.user_limits.len(),
172 path
173 );
174 Ok(())
175 }
176
177 pub fn add_limit(&mut self, limit: ModelLimit) {
179 self.user_limits.insert(limit.model_pattern.clone(), limit);
180 }
181
182 pub fn get(&self, model: &str) -> Option<ModelLimit> {
193 if let Some(limit) = self.user_limits.get(model) {
195 return Some(limit.clone());
196 }
197
198 for (pattern, max_context_tokens, max_output_tokens) in KNOWN_MODEL_LIMITS {
200 if *pattern == model {
201 return Some(builtin_limit(
202 model,
203 *max_context_tokens,
204 *max_output_tokens,
205 ));
206 }
207 }
208
209 let best_user_match = self
212 .user_limits
213 .iter()
214 .filter(|(pattern, _)| model.contains(*pattern) || pattern.contains(model))
215 .max_by_key(|(pattern, _)| pattern.len())
216 .map(|(_, limit)| limit.clone());
217
218 if let Some(limit) = best_user_match {
219 return Some(limit);
220 }
221
222 let best_builtin_match = KNOWN_MODEL_LIMITS
224 .iter()
225 .filter(|(pattern, _, _)| model.contains(*pattern) || pattern.contains(model))
226 .max_by_key(|(pattern, _, _)| pattern.len());
227
228 if let Some((pattern, max_context_tokens, max_output_tokens)) = best_builtin_match {
229 return Some(builtin_limit(
230 pattern,
231 *max_context_tokens,
232 *max_output_tokens,
233 ));
234 }
235
236 None
237 }
238
239 pub fn get_or_default(&self, model: &str) -> ModelLimit {
241 self.get(model).unwrap_or_else(|| {
242 let default = KNOWN_MODEL_LIMITS
243 .iter()
244 .find(|(k, _, _)| *k == "default")
245 .map(|(_, max_context_tokens, max_output_tokens)| {
246 (*max_context_tokens, *max_output_tokens)
247 })
248 .unwrap_or((128_000, DEFAULT_MAX_OUTPUT_TOKENS));
249 let mut limit = ModelLimit::new("default", default.0);
250 limit.max_output_tokens = Some(default.1);
251 limit
252 })
253 }
254
255 pub async fn save_user_config(&self) -> std::io::Result<()> {
257 let path = self
258 .config_path
259 .clone()
260 .unwrap_or_else(get_default_config_path);
261
262 if let Some(parent) = path.parent() {
264 tokio::fs::create_dir_all(parent).await?;
265 }
266
267 let limits: Vec<&ModelLimit> = self.user_limits.values().collect();
268 let content = serde_json::to_string_pretty(&limits)
269 .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e.to_string()))?;
270 tokio::fs::write(&path, content).await?;
271
272 Ok(())
273 }
274
275 pub fn list_user_limits(&self) -> Vec<&ModelLimit> {
277 self.user_limits.values().collect()
278 }
279}
280
impl Default for ModelLimitsRegistry {
    /// Returns an empty registry (no user overrides, default config path).
    fn default() -> Self {
        Self::new()
    }
}
286
287pub fn get_default_config_path() -> PathBuf {
291 bamboo_infrastructure::paths::bamboo_dir().join("model_limits.json")
292}
293
294pub fn load_model_limits_from_unified_config(
301 config: &bamboo_infrastructure::Config,
302) -> Result<Option<Vec<ModelLimit>>, String> {
303 let Some(raw_limits) = config.extra.get("model_limits") else {
304 return Ok(None);
305 };
306
307 if raw_limits.is_null() {
308 return Ok(Some(Vec::new()));
309 }
310
311 match raw_limits {
312 Value::Array(_) => serde_json::from_value::<Vec<ModelLimit>>(raw_limits.clone())
313 .map(Some)
314 .map_err(|error| format!("invalid config.model_limits format: {error}")),
315 _ => Err("invalid config.model_limits format: expected array".to_string()),
316 }
317}
318
319pub fn create_budget_for_model(model: &str, strategy: crate::BudgetStrategy) -> crate::TokenBudget {
323 let registry = ModelLimitsRegistry::default();
324 let limit = registry.get_or_default(model);
325
326 crate::TokenBudget {
327 max_context_tokens: limit.max_context_tokens,
328 max_output_tokens: limit.get_max_output_tokens(),
329 strategy,
330 safety_margin: limit.get_safety_margin(),
331 compression_trigger_percent: 85, compression_target_percent: 45,
333 working_reserve_tokens: 50_000,
334 fallback_trigger_percent: 75,
335 prompt_cache_min_tool_output_chars: 1_200,
336 prompt_cache_head_chars: 280,
337 prompt_cache_tail_chars: 180,
338 prompt_cache_recent_user_turns: 2,
339 prompt_cache_recent_tool_chains: 2,
340 max_tool_output_tokens: 0,
341 }
342}
343
#[cfg(test)]
mod tests {
    use super::*;

    /// Looks up `(max_context_tokens, max_output_tokens)` for an exact
    /// builtin pattern, panicking if the table has no such entry.
    fn builtin_entry(pattern: &str) -> (u32, u32) {
        KNOWN_MODEL_LIMITS
            .iter()
            .find(|(k, _, _)| *k == pattern)
            .map(|(_, ctx, out)| (*ctx, *out))
            .unwrap_or_else(|| panic!("Should have {pattern}"))
    }

    #[test]
    fn builtin_limits_contain_common_models() {
        assert_eq!(builtin_entry("gpt-5.4"), (1_050_000, 32_768));
        assert_eq!(builtin_entry("gpt-5.3-codex"), (400_000, 128_000));
        assert_eq!(builtin_entry("gpt-5.2-codex"), (400_000, 128_000));
        assert_eq!(builtin_entry("gemini-3.1-pro-preview"), (128_000, 64_000));
    }

    #[test]
    fn registry_finds_builtin_by_exact_match() {
        let limit = ModelLimitsRegistry::new()
            .get("gpt-5.2-codex")
            .expect("Should find gpt-5.2-codex");
        assert_eq!(limit.max_context_tokens, 400_000);
        assert_eq!(limit.get_max_output_tokens(), 128_000);
    }

    #[test]
    fn registry_finds_builtin_by_partial_match() {
        // No exact entry exists for this name; the longest substring
        // match ("gpt-5.2-codex") should win over "gpt-5.2" and "gpt-5".
        let limit = ModelLimitsRegistry::new()
            .get("gpt-5.2-codex-preview")
            .expect("Should find gpt-5.2-codex");
        assert_eq!(limit.max_context_tokens, 400_000);
        assert_eq!(limit.get_max_output_tokens(), 128_000);
    }

    #[test]
    fn registry_returns_default_for_unknown() {
        let limit = ModelLimitsRegistry::new().get_or_default("unknown-model-xyz");
        assert_eq!(limit.model_pattern, "default");
    }

    #[test]
    fn user_override_takes_precedence() {
        let mut registry = ModelLimitsRegistry::new();
        registry.add_limit(ModelLimit::new("gpt-5.2-codex", 64_000));

        let limit = registry
            .get("gpt-5.2-codex")
            .expect("Should find overridden limit");
        assert_eq!(limit.max_context_tokens, 64_000);
    }

    #[test]
    fn model_limit_calculates_default_output_tokens() {
        // 100_000 / 4 = 25_000, capped at 4096.
        assert_eq!(ModelLimit::new("test", 100_000).get_max_output_tokens(), 4096);
    }

    #[test]
    fn model_limit_uses_custom_output_tokens() {
        let mut limit = ModelLimit::new("test", 100_000);
        limit.max_output_tokens = Some(8192);
        assert_eq!(limit.get_max_output_tokens(), 8192);
    }

    #[test]
    fn model_limit_calculates_small_context_output() {
        // 8_192 / 4 = 2_048, already below the 4096 cap.
        assert_eq!(ModelLimit::new("test", 8_192).get_max_output_tokens(), 2048);
    }

    #[test]
    fn unified_config_loader_returns_none_when_absent() {
        let temp_dir = tempfile::tempdir().expect("tempdir");
        let config =
            bamboo_infrastructure::Config::from_data_dir(Some(temp_dir.path().to_path_buf()));

        let loaded = load_model_limits_from_unified_config(&config).expect("should parse");
        assert!(loaded.is_none());
    }

    #[test]
    fn unified_config_loader_reads_valid_model_limits() {
        let temp_dir = tempfile::tempdir().expect("tempdir");
        let mut config =
            bamboo_infrastructure::Config::from_data_dir(Some(temp_dir.path().to_path_buf()));
        config.extra.insert(
            "model_limits".to_string(),
            serde_json::json!([
                {
                    "model_pattern": "gpt-5.2-codex",
                    "max_context_tokens": 64000,
                    "max_output_tokens": 2048,
                    "safety_margin": 512
                }
            ]),
        );

        let loaded = load_model_limits_from_unified_config(&config)
            .expect("should parse")
            .expect("should exist");
        assert_eq!(loaded.len(), 1);

        let entry = &loaded[0];
        assert_eq!(entry.model_pattern, "gpt-5.2-codex");
        assert_eq!(entry.max_context_tokens, 64_000);
        assert_eq!(entry.max_output_tokens, Some(2048));
        assert_eq!(entry.safety_margin, Some(512));
    }

    #[test]
    fn unified_config_loader_errors_on_invalid_shape() {
        let temp_dir = tempfile::tempdir().expect("tempdir");
        let mut config =
            bamboo_infrastructure::Config::from_data_dir(Some(temp_dir.path().to_path_buf()));
        config.extra.insert(
            "model_limits".to_string(),
            serde_json::json!({"unexpected": true}),
        );

        let error = load_model_limits_from_unified_config(&config).expect_err("should error");
        assert!(error.contains("expected array"));
    }

    #[test]
    fn safety_margin_scales_with_context_window() {
        // Small windows hit the 1000-token floor.
        assert_eq!(ModelLimit::new("test", 8_192).get_safety_margin(), 1000);

        // Larger windows scale at 1% of the context size.
        assert_eq!(ModelLimit::new("test", 200_000).get_safety_margin(), 2000);
        assert_eq!(ModelLimit::new("test", 1_050_000).get_safety_margin(), 10_500);

        // Explicit overrides win, even below the floor.
        let mut custom = ModelLimit::new("test", 200_000);
        custom.safety_margin = Some(500);
        assert_eq!(custom.get_safety_margin(), 500);
    }
}