// bamboo_compression/limits.rs

//! Model context window limits registry.
//!
//! Provides known context window sizes for common models, together with
//! user-configurable overrides supplied via a config file or per-session settings.
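//!
//! # Example
//!
//! A minimal usage sketch (illustrative; the `bamboo_compression::limits` path is
//! an assumption, adjust it to the crate's actual module layout):
//!
//! ```ignore
//! use bamboo_compression::limits::{ModelLimit, ModelLimitsRegistry};
//!
//! let mut registry = ModelLimitsRegistry::new();
//! // Session override: takes priority over the built-in KNOWN_MODEL_LIMITS table.
//! registry.add_limit(ModelLimit::new("my-local-model", 32_000));
//!
//! let limit = registry.get_or_default("my-local-model");
//! assert_eq!(limit.max_context_tokens, 32_000);
//! ```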

use serde::{Deserialize, Serialize};
use serde_json::Value;
use std::collections::HashMap;
use std::path::PathBuf;

/// Known model defaults: `(pattern, max_context_tokens, max_output_tokens)`.
///
/// These are the built-in defaults used when there is no user override in
/// `config.json:model_limits` and no legacy `model_limits.json`.
pub const KNOWN_MODEL_LIMITS: &[(&str, u32, u32)] = &[
    // GitHub Copilot model profiles (2026-03)
    // Anthropic
    ("claude-haiku-4.5", 160_000, 32_000),
    ("claude-opus-4.5", 160_000, 32_000),
    ("claude-opus-4.6", 192_000, 64_000),
    ("claude-sonnet-4.6", 200_000, 32_000),
    ("claude-sonnet-4-6", 200_000, 32_000), // Alternate wire format
    ("claude-sonnet-4.5", 200_000, 32_000),
    ("claude-sonnet-4-5", 200_000, 32_000), // Alternate wire format
    // Google
    ("gemini-2.5-pro", 128_000, 16_000),
    ("gemini-3-flash-preview", 1_000_000, 8_192),
    ("gemini-3.1-pro-preview", 128_000, 64_000),
    // OpenAI
    ("gpt-5.4", 1_050_000, 32_768),
    ("gpt-5.3-codex", 400_000, 128_000),
    ("gpt-5.2-codex", 400_000, 128_000),
    ("gpt-5.2", 400_000, 128_000),
    ("gpt-5.1", 400_000, 128_000),
    ("gpt-5", 400_000, 128_000),
    ("gpt-5.4-mini", 128_000, 16_384),
    ("gpt-4.1", 128_000, 16_384),
    ("gpt-4-o-preview", 128_000, 16_384),
    ("gpt-4o-preview", 128_000, 16_384), // Alternate spelling
    // xAI
    ("grok-code-fast-1", 128_000, 10_240),
    // GitHub specialized
    ("oswe-vscode-prime", 264_000, 64_000),
    // Backward-compatible aliases and non-Copilot profiles
    ("gpt-5.4-thinking", 1_000_000, 128_000),
    ("gpt-5.2-pro", 256_000, 64_000),
    ("gpt-5-mini", 400_000, 128_000),
    ("gpt-4o", 128_000, 16_000),
    // Moonshot
    ("kimi-k2.5", 256_000, 64_000),
    ("kimi-for-coding", 256_000, 64_000),
    // Zhipu
    ("glm-5", 200_000, 128_000),
    // Compatibility fallbacks
    ("gpt-4o-mini", 128_000, 16_000),
    ("gpt-4-turbo", 128_000, 16_000),
    ("gpt-4", 8_192, 4_096),
    ("gpt-3.5-turbo", 16_385, 4_096),
    ("claude-3-5-sonnet", 200_000, 8_192),
    ("claude-3-5-sonnet-20241022", 200_000, 8_192),
    ("claude-3-5-sonnet-20240620", 200_000, 8_192),
    ("claude-3-opus", 200_000, 8_192),
    ("claude-3-opus-20240229", 200_000, 8_192),
    ("claude-3-sonnet", 200_000, 8_192),
    ("claude-3-haiku", 200_000, 8_192),
    ("copilot-chat", 128_000, 16_000),
    // Default fallback
    ("default", 128_000, 4_096),
];

/// Default maximum output tokens (cap applied to the ~25%-of-context reservation for the response).
pub const DEFAULT_MAX_OUTPUT_TOKENS: u32 = 4096;

/// Default safety margin for token counting errors.
pub const DEFAULT_SAFETY_MARGIN: u32 = 1000;

/// Model limit configuration (user-overridable).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModelLimit {
    /// Model identifier (partial match supported, e.g., "gpt-4" matches "gpt-4o")
    pub model_pattern: String,
    /// Maximum context window size in tokens
    pub max_context_tokens: u32,
    /// Maximum output tokens (defaults to min(4096, max_context / 4))
    #[serde(default)]
    pub max_output_tokens: Option<u32>,
    /// Safety margin for token counting (defaults to max(max_context / 100, 1000))
    #[serde(default)]
    pub safety_margin: Option<u32>,
}

impl ModelLimit {
    /// Create a new model limit with defaults.
    pub fn new(model_pattern: impl Into<String>, max_context_tokens: u32) -> Self {
        Self {
            model_pattern: model_pattern.into(),
            max_context_tokens,
            max_output_tokens: None,
            safety_margin: None,
        }
    }

    /// Get max output tokens with default calculation.
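    ///
    /// Illustrative sketch of the default calculation (crate path assumed, hence `ignore`):
    ///
    /// ```ignore
    /// use bamboo_compression::limits::ModelLimit;
    ///
    /// let limit = ModelLimit::new("example-model", 100_000);
    /// // min(100_000 / 4, 4096) = 4096
    /// assert_eq!(limit.get_max_output_tokens(), 4_096);
    /// ```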
    pub fn get_max_output_tokens(&self) -> u32 {
        self.max_output_tokens
            .unwrap_or_else(|| (self.max_context_tokens / 4).min(4096))
    }

    /// Get safety margin, scaling proportionally with context window.
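    ///
    /// Illustrative sketch (crate path assumed): 1% of the context window, floored
    /// at `DEFAULT_SAFETY_MARGIN`.
    ///
    /// ```ignore
    /// use bamboo_compression::limits::ModelLimit;
    ///
    /// assert_eq!(ModelLimit::new("small", 8_192).get_safety_margin(), 1_000);
    /// assert_eq!(ModelLimit::new("large", 1_000_000).get_safety_margin(), 10_000);
    /// ```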
    pub fn get_safety_margin(&self) -> u32 {
        self.safety_margin
            .unwrap_or_else(|| (self.max_context_tokens / 100).max(DEFAULT_SAFETY_MARGIN))
    }
}

fn builtin_limit(pattern: &str, max_context_tokens: u32, max_output_tokens: u32) -> ModelLimit {
    let mut limit = ModelLimit::new(pattern.to_string(), max_context_tokens);
    limit.max_output_tokens = Some(max_output_tokens);
    limit
}

/// Registry for model limits with built-in defaults and user overrides.
#[derive(Debug, Clone)]
pub struct ModelLimitsRegistry {
    /// User-provided overrides (higher priority than built-in)
    user_limits: HashMap<String, ModelLimit>,
    /// Default path for user configuration file
    config_path: Option<PathBuf>,
}

impl ModelLimitsRegistry {
    /// Create a new registry with built-in defaults only.
    pub fn new() -> Self {
        Self {
            user_limits: HashMap::new(),
            config_path: None,
        }
    }

    /// Create a registry with a specific config file path.
    pub fn with_config_path(path: impl Into<PathBuf>) -> Self {
        Self {
            user_limits: HashMap::new(),
            config_path: Some(path.into()),
        }
    }

    /// Load user overrides from the default configuration path.
    ///
    /// Default path: `{bamboo_data_dir}/model_limits.json`
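    ///
    /// Illustrative sketch (crate path and a Tokio runtime assumed); loading is a
    /// silent no-op when the file does not exist:
    ///
    /// ```ignore
    /// use bamboo_compression::limits::ModelLimitsRegistry;
    ///
    /// # async fn demo() -> std::io::Result<()> {
    /// let mut registry = ModelLimitsRegistry::with_config_path("/tmp/model_limits.json");
    /// registry.load_user_config().await?;
    /// # Ok(())
    /// # }
    /// ```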
    pub async fn load_user_config(&mut self) -> std::io::Result<()> {
        let path = self
            .config_path
            .clone()
            .unwrap_or_else(get_default_config_path);

        if !path.exists() {
            return Ok(());
        }

        let content = tokio::fs::read_to_string(&path).await?;
        let limits: Vec<ModelLimit> = serde_json::from_str(&content)
            .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e.to_string()))?;

        for limit in limits {
            self.user_limits.insert(limit.model_pattern.clone(), limit);
        }

        tracing::info!(
            "Loaded {} user model limits from {:?}",
            self.user_limits.len(),
            path
        );
        Ok(())
    }

    /// Add a user limit override.
    pub fn add_limit(&mut self, limit: ModelLimit) {
        self.user_limits.insert(limit.model_pattern.clone(), limit);
    }

    /// Get limit for a model, with user overrides taking priority.
    ///
    /// Returns `None` if no matching limit is found.
    ///
    /// # Matching Strategy
    /// 1. Exact match (highest priority)
    /// 2. Model contains pattern (e.g., "gpt-4o-mini" contains "gpt-4o")
    /// 3. Pattern contains model (e.g., "gpt-4" contains "gpt")
    ///
    /// For partial matches, the longest (most specific) pattern wins.
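    ///
    /// Illustrative sketch of partial matching (crate path assumed):
    ///
    /// ```ignore
    /// use bamboo_compression::limits::ModelLimitsRegistry;
    ///
    /// let registry = ModelLimitsRegistry::new();
    /// // Not an exact entry, but it contains the built-in pattern "gpt-5.2-codex",
    /// // which is the longest matching pattern, so that limit is returned.
    /// let limit = registry.get("gpt-5.2-codex-preview").unwrap();
    /// assert_eq!(limit.max_context_tokens, 400_000);
    /// ```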
    pub fn get(&self, model: &str) -> Option<ModelLimit> {
        // First check user limits for exact match
        if let Some(limit) = self.user_limits.get(model) {
            return Some(limit.clone());
        }

        // Check built-in limits for exact match
        for (pattern, max_context_tokens, max_output_tokens) in KNOWN_MODEL_LIMITS {
            if *pattern == model {
                return Some(builtin_limit(
                    model,
                    *max_context_tokens,
                    *max_output_tokens,
                ));
            }
        }

        // Find the best partial match from user limits
        // Sort by pattern length (longer = more specific) for deterministic selection
        let best_user_match = self
            .user_limits
            .iter()
            .filter(|(pattern, _)| model.contains(*pattern) || pattern.contains(model))
            .max_by_key(|(pattern, _)| pattern.len())
            .map(|(_, limit)| limit.clone());

        if let Some(limit) = best_user_match {
            return Some(limit);
        }

        // Find the best partial match from built-in limits
        let best_builtin_match = KNOWN_MODEL_LIMITS
            .iter()
            .filter(|(pattern, _, _)| model.contains(*pattern) || pattern.contains(model))
            .max_by_key(|(pattern, _, _)| pattern.len());

        if let Some((pattern, max_context_tokens, max_output_tokens)) = best_builtin_match {
            return Some(builtin_limit(
                pattern,
                *max_context_tokens,
                *max_output_tokens,
            ));
        }

        None
    }

    /// Get limit for a model with fallback to default.
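    ///
    /// Illustrative sketch (crate path assumed):
    ///
    /// ```ignore
    /// use bamboo_compression::limits::ModelLimitsRegistry;
    ///
    /// let registry = ModelLimitsRegistry::new();
    /// // No entry matches, so the built-in "default" limit (128k context) is returned.
    /// let limit = registry.get_or_default("completely-unknown-model");
    /// assert_eq!(limit.model_pattern, "default");
    /// assert_eq!(limit.max_context_tokens, 128_000);
    /// ```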
    pub fn get_or_default(&self, model: &str) -> ModelLimit {
        self.get(model).unwrap_or_else(|| {
            let default = KNOWN_MODEL_LIMITS
                .iter()
                .find(|(k, _, _)| *k == "default")
                .map(|(_, max_context_tokens, max_output_tokens)| {
                    (*max_context_tokens, *max_output_tokens)
                })
                .unwrap_or((128_000, DEFAULT_MAX_OUTPUT_TOKENS));
            let mut limit = ModelLimit::new("default", default.0);
            limit.max_output_tokens = Some(default.1);
            limit
        })
    }

    /// Save current user limits to the configuration file.
    pub async fn save_user_config(&self) -> std::io::Result<()> {
        let path = self
            .config_path
            .clone()
            .unwrap_or_else(get_default_config_path);

        // Ensure parent directory exists
        if let Some(parent) = path.parent() {
            tokio::fs::create_dir_all(parent).await?;
        }

        let limits: Vec<&ModelLimit> = self.user_limits.values().collect();
        let content = serde_json::to_string_pretty(&limits)
            .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e.to_string()))?;
        tokio::fs::write(&path, content).await?;

        Ok(())
    }

    /// List all user-defined limits.
    pub fn list_user_limits(&self) -> Vec<&ModelLimit> {
        self.user_limits.values().collect()
    }
}

impl Default for ModelLimitsRegistry {
    fn default() -> Self {
        Self::new()
    }
}

/// Get the default configuration file path.
///
/// Returns `{bamboo_data_dir}/model_limits.json`.
pub fn get_default_config_path() -> PathBuf {
    bamboo_infrastructure::paths::bamboo_dir().join("model_limits.json")
}

/// Load user model limits from unified `config.json` root key `model_limits`.
///
/// Returns:
/// - `Ok(None)` when `model_limits` key is absent.
/// - `Ok(Some(vec))` when key exists and is valid (including empty array).
/// - `Err(...)` when key exists but is not a valid `Vec<ModelLimit>`.
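///
/// Illustrative sketch mirroring the unit tests below (the `bamboo_infrastructure`
/// names and the crate path are assumptions taken from the rest of this file):
///
/// ```ignore
/// use bamboo_compression::limits::load_model_limits_from_unified_config;
///
/// let mut config = bamboo_infrastructure::Config::from_data_dir(None);
/// config.extra.insert(
///     "model_limits".to_string(),
///     serde_json::json!([{ "model_pattern": "gpt-5.2-codex", "max_context_tokens": 64000 }]),
/// );
/// let limits = load_model_limits_from_unified_config(&config)
///     .expect("valid shape")
///     .expect("key present");
/// assert_eq!(limits[0].max_context_tokens, 64_000);
/// ```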
pub fn load_model_limits_from_unified_config(
    config: &bamboo_infrastructure::Config,
) -> Result<Option<Vec<ModelLimit>>, String> {
    let Some(raw_limits) = config.extra.get("model_limits") else {
        return Ok(None);
    };

    if raw_limits.is_null() {
        return Ok(Some(Vec::new()));
    }

    match raw_limits {
        Value::Array(_) => serde_json::from_value::<Vec<ModelLimit>>(raw_limits.clone())
            .map(Some)
            .map_err(|error| format!("invalid config.model_limits format: {error}")),
        _ => Err("invalid config.model_limits format: expected array".to_string()),
    }
}

/// Create a token budget for a specific model.
///
/// This is a convenience function that creates a budget with appropriate defaults.
pub fn create_budget_for_model(model: &str, strategy: crate::BudgetStrategy) -> crate::TokenBudget {
    let registry = ModelLimitsRegistry::default();
    let limit = registry.get_or_default(model);

    crate::TokenBudget {
        max_context_tokens: limit.max_context_tokens,
        max_output_tokens: limit.get_max_output_tokens(),
        strategy,
        safety_margin: limit.get_safety_margin(),
        compression_trigger_percent: 85, // legacy — only used when working_reserve_tokens == 0
        compression_target_percent: 45,
        working_reserve_tokens: 50_000,
        fallback_trigger_percent: 75,
        prompt_cache_min_tool_output_chars: 1_200,
        prompt_cache_head_chars: 280,
        prompt_cache_tail_chars: 180,
        prompt_cache_recent_user_turns: 2,
        prompt_cache_recent_tool_chains: 2,
        max_tool_output_tokens: 0,
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn builtin_limits_contain_common_models() {
        let gpt54 = KNOWN_MODEL_LIMITS
            .iter()
            .find(|(k, _, _)| *k == "gpt-5.4")
            .expect("Should have gpt-5.4");
        assert_eq!(gpt54.1, 1_050_000);
        assert_eq!(gpt54.2, 32_768);

        let gpt53_codex = KNOWN_MODEL_LIMITS
            .iter()
            .find(|(k, _, _)| *k == "gpt-5.3-codex")
            .expect("Should have gpt-5.3-codex");
        assert_eq!(gpt53_codex.1, 400_000);
        assert_eq!(gpt53_codex.2, 128_000);

        let gpt52_codex = KNOWN_MODEL_LIMITS
            .iter()
            .find(|(k, _, _)| *k == "gpt-5.2-codex")
            .expect("Should have gpt-5.2-codex");
        assert_eq!(gpt52_codex.1, 400_000);
        assert_eq!(gpt52_codex.2, 128_000);

        let gemini31_pro_preview = KNOWN_MODEL_LIMITS
            .iter()
            .find(|(k, _, _)| *k == "gemini-3.1-pro-preview")
            .expect("Should have gemini-3.1-pro-preview");
        assert_eq!(gemini31_pro_preview.1, 128_000);
        assert_eq!(gemini31_pro_preview.2, 64_000);
    }

    #[test]
    fn registry_finds_builtin_by_exact_match() {
        let registry = ModelLimitsRegistry::new();
        let limit = registry
            .get("gpt-5.2-codex")
            .expect("Should find gpt-5.2-codex");
        assert_eq!(limit.max_context_tokens, 400_000);
        assert_eq!(limit.get_max_output_tokens(), 128_000);
    }

    #[test]
    fn registry_finds_builtin_by_partial_match() {
        let registry = ModelLimitsRegistry::new();
        // "gpt-5.2-codex-preview" contains "gpt-5.2-codex"
        let limit = registry
            .get("gpt-5.2-codex-preview")
            .expect("Should find gpt-5.2-codex");
        assert_eq!(limit.max_context_tokens, 400_000);
        assert_eq!(limit.get_max_output_tokens(), 128_000);
    }

    #[test]
    fn registry_returns_default_for_unknown() {
        let registry = ModelLimitsRegistry::new();
        let limit = registry.get_or_default("unknown-model-xyz");
        assert_eq!(limit.model_pattern, "default");
    }

    #[test]
    fn user_override_takes_precedence() {
        let mut registry = ModelLimitsRegistry::new();
        registry.add_limit(ModelLimit::new("gpt-5.2-codex", 64_000)); // Override with smaller limit

        let limit = registry
            .get("gpt-5.2-codex")
            .expect("Should find overridden limit");
        assert_eq!(limit.max_context_tokens, 64_000);
    }

    #[test]
    fn model_limit_calculates_default_output_tokens() {
        let limit = ModelLimit::new("test", 100_000);
        // Default is min(max_context / 4, 4096) = min(25000, 4096) = 4096
        assert_eq!(limit.get_max_output_tokens(), 4096);
    }

    #[test]
    fn model_limit_uses_custom_output_tokens() {
        let mut limit = ModelLimit::new("test", 100_000);
        limit.max_output_tokens = Some(8192);
        assert_eq!(limit.get_max_output_tokens(), 8192);
    }

    #[test]
    fn model_limit_calculates_small_context_output() {
        let limit = ModelLimit::new("test", 8_192);
        // Default is min(8192 / 4, 4096) = 2048
        assert_eq!(limit.get_max_output_tokens(), 2048);
    }

    #[test]
    fn unified_config_loader_returns_none_when_absent() {
        let temp_dir = tempfile::tempdir().expect("tempdir");
        let config =
            bamboo_infrastructure::Config::from_data_dir(Some(temp_dir.path().to_path_buf()));
        let loaded = load_model_limits_from_unified_config(&config).expect("should parse");
        assert!(loaded.is_none());
    }

    #[test]
    fn unified_config_loader_reads_valid_model_limits() {
        let temp_dir = tempfile::tempdir().expect("tempdir");
        let mut config =
            bamboo_infrastructure::Config::from_data_dir(Some(temp_dir.path().to_path_buf()));
        config.extra.insert(
            "model_limits".to_string(),
            serde_json::json!([
                {
                    "model_pattern": "gpt-5.2-codex",
                    "max_context_tokens": 64000,
                    "max_output_tokens": 2048,
                    "safety_margin": 512
                }
            ]),
        );

        let loaded = load_model_limits_from_unified_config(&config)
            .expect("should parse")
            .expect("should exist");
        assert_eq!(loaded.len(), 1);
        assert_eq!(loaded[0].model_pattern, "gpt-5.2-codex");
        assert_eq!(loaded[0].max_context_tokens, 64_000);
        assert_eq!(loaded[0].max_output_tokens, Some(2048));
        assert_eq!(loaded[0].safety_margin, Some(512));
    }

    #[test]
    fn unified_config_loader_errors_on_invalid_shape() {
        let temp_dir = tempfile::tempdir().expect("tempdir");
        let mut config =
            bamboo_infrastructure::Config::from_data_dir(Some(temp_dir.path().to_path_buf()));
        config.extra.insert(
            "model_limits".to_string(),
            serde_json::json!({"unexpected": true}),
        );

        let error = load_model_limits_from_unified_config(&config).expect_err("should error");
        assert!(error.contains("expected array"));
    }

    #[test]
    fn safety_margin_scales_with_context_window() {
        // Small context → floor at DEFAULT_SAFETY_MARGIN (1000)
        let small = ModelLimit::new("test", 8_192);
        assert_eq!(small.get_safety_margin(), 1000);

        // Medium context → proportional
        let medium = ModelLimit::new("test", 200_000);
        assert_eq!(medium.get_safety_margin(), 2000);

        // Large context → proportional
        let large = ModelLimit::new("test", 1_050_000);
        assert_eq!(large.get_safety_margin(), 10_500);

        // Explicit override takes precedence
        let mut custom = ModelLimit::new("test", 200_000);
        custom.safety_margin = Some(500);
        assert_eq!(custom.get_safety_margin(), 500);
    }
}