Skip to main content

bamboo_compression/
limits.rs

1//! Model context window limits registry.
2//!
3//! There is intentionally **no** built-in per-model table. Real per-model
4//! values come from two sources, in priority order:
5//!   1. provider runtime metadata (e.g. Copilot reports real context/output),
6//!   2. user overrides persisted in `model_limits.json`.
7//! Anything without a match falls back to a single global default
8//! (`DEFAULT_MAX_CONTEXT_TOKENS` / `DEFAULT_MAX_OUTPUT_TOKENS`). This keeps the
9//! registry from going stale as models churn — see `token_budget.rs`.
10
11use serde::{Deserialize, Serialize};
12use serde_json::Value;
13use std::collections::HashMap;
14use std::path::PathBuf;
15
/// Pattern string carried by the synthesized global fallback limit.
pub const DEFAULT_MODEL_PATTERN: &str = "default";

/// Context window (in tokens) assumed for any model that has neither provider
/// metadata nor a user override. 200K sits in the mainstream range of current
/// frontier models.
pub const DEFAULT_MAX_CONTEXT_TOKENS: u32 = 200_000;

/// Output-token cap used by the global fallback limit.
pub const DEFAULT_MAX_OUTPUT_TOKENS: u32 = 64_000;

/// Lower bound for the token-counting safety margin; larger context windows
/// scale above it (see [`ModelLimit::get_safety_margin`]).
pub const DEFAULT_SAFETY_MARGIN: u32 = 1_000;
30
31/// Build the single global default limit (`200K` context / `64K` output).
32pub fn default_model_limit() -> ModelLimit {
33    builtin_limit(
34        DEFAULT_MODEL_PATTERN,
35        DEFAULT_MAX_CONTEXT_TOKENS,
36        DEFAULT_MAX_OUTPUT_TOKENS,
37    )
38}
39
40/// Whether a user override is a no-op — identical to the global default, so it
41/// carries no information and need not be persisted (diff-only storage).
42///
43/// The `model_pattern` is irrelevant: any model pinned to exactly the default
44/// context/output with no explicit safety margin resolves to the same budget
45/// as having no override at all.
46pub fn is_default_limit(limit: &ModelLimit) -> bool {
47    limit.max_context_tokens == DEFAULT_MAX_CONTEXT_TOKENS
48        && limit.max_output_tokens == Some(DEFAULT_MAX_OUTPUT_TOKENS)
49        && limit.safety_margin.is_none()
50}
51
52/// Model limit configuration (user-overridable).
53#[derive(Debug, Clone, Serialize, Deserialize)]
54pub struct ModelLimit {
55    /// Model identifier (partial match supported, e.g., "gpt-4" matches "gpt-4o")
56    pub model_pattern: String,
57    /// Maximum context window size in tokens
58    pub max_context_tokens: u32,
59    /// Maximum output tokens (defaults to min(4096, max_context / 4))
60    #[serde(default)]
61    pub max_output_tokens: Option<u32>,
62    /// Safety margin for token counting (defaults to 1000)
63    #[serde(default)]
64    pub safety_margin: Option<u32>,
65}
66
67impl ModelLimit {
68    /// Create a new model limit with defaults.
69    pub fn new(model_pattern: impl Into<String>, max_context_tokens: u32) -> Self {
70        Self {
71            model_pattern: model_pattern.into(),
72            max_context_tokens,
73            max_output_tokens: None,
74            safety_margin: None,
75        }
76    }
77
78    /// Get max output tokens with default calculation.
79    pub fn get_max_output_tokens(&self) -> u32 {
80        self.max_output_tokens
81            .unwrap_or_else(|| (self.max_context_tokens / 4).min(4096))
82    }
83
84    /// Get safety margin, scaling proportionally with context window.
85    pub fn get_safety_margin(&self) -> u32 {
86        self.safety_margin
87            .unwrap_or_else(|| (self.max_context_tokens / 100).max(DEFAULT_SAFETY_MARGIN))
88    }
89}
90
91fn builtin_limit(pattern: &str, max_context_tokens: u32, max_output_tokens: u32) -> ModelLimit {
92    let mut limit = ModelLimit::new(pattern.to_string(), max_context_tokens);
93    limit.max_output_tokens = Some(max_output_tokens);
94    limit
95}
96
97/// Registry for model limits with built-in defaults and user overrides.
98#[derive(Debug, Clone)]
99pub struct ModelLimitsRegistry {
100    /// User-provided overrides (higher priority than built-in)
101    user_limits: HashMap<String, ModelLimit>,
102    /// Default path for user configuration file
103    config_path: Option<PathBuf>,
104}
105
106impl ModelLimitsRegistry {
107    /// Create a new registry with built-in defaults only.
108    pub fn new() -> Self {
109        Self {
110            user_limits: HashMap::new(),
111            config_path: None,
112        }
113    }
114
115    /// Create a registry with a specific config file path.
116    pub fn with_config_path(path: impl Into<PathBuf>) -> Self {
117        Self {
118            user_limits: HashMap::new(),
119            config_path: Some(path.into()),
120        }
121    }
122
123    /// Load user overrides from the default configuration path.
124    ///
125    /// Default path: `{bamboo_data_dir}/model_limits.json`
126    pub async fn load_user_config(&mut self) -> std::io::Result<()> {
127        let path = self
128            .config_path
129            .clone()
130            .unwrap_or_else(get_default_config_path);
131
132        if !path.exists() {
133            return Ok(());
134        }
135
136        let content = tokio::fs::read_to_string(&path).await?;
137        let limits: Vec<ModelLimit> = serde_json::from_str(&content)
138            .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e.to_string()))?;
139
140        for limit in limits {
141            self.user_limits.insert(limit.model_pattern.clone(), limit);
142        }
143
144        tracing::info!(
145            "Loaded {} user model limits from {:?}",
146            self.user_limits.len(),
147            path
148        );
149        Ok(())
150    }
151
152    /// Add a user limit override.
153    pub fn add_limit(&mut self, limit: ModelLimit) {
154        self.user_limits.insert(limit.model_pattern.clone(), limit);
155    }
156
157    /// Get limit for a model, with user overrides taking priority.
158    ///
159    /// Returns `None` if no matching limit is found.
160    ///
161    /// # Matching Strategy
162    /// 1. Exact match (highest priority)
163    /// 2. Model contains pattern (e.g., "gpt-4o-mini" contains "gpt-4o")
164    /// 3. Pattern contains model (e.g., "gpt-4" contains "gpt")
165    ///
166    /// For partial matches, the longest (most specific) pattern wins.
167    pub fn get(&self, model: &str) -> Option<ModelLimit> {
168        // Exact user override match (highest priority).
169        if let Some(limit) = self.user_limits.get(model) {
170            return Some(limit.clone());
171        }
172
173        // Best partial match among user overrides. Longer (more specific)
174        // patterns win for deterministic selection. There is no built-in table;
175        // a miss returns None and the caller falls back to the global default.
176        self.user_limits
177            .iter()
178            .filter(|(pattern, _)| model.contains(pattern.as_str()) || pattern.contains(model))
179            .max_by_key(|(pattern, _)| pattern.len())
180            .map(|(_, limit)| limit.clone())
181    }
182
183    /// Get limit for a model with fallback to default.
184    pub fn get_or_default(&self, model: &str) -> ModelLimit {
185        self.get(model).unwrap_or_else(default_model_limit)
186    }
187
188    /// Save current user limits to the configuration file.
189    pub async fn save_user_config(&self) -> std::io::Result<()> {
190        let path = self
191            .config_path
192            .clone()
193            .unwrap_or_else(get_default_config_path);
194
195        // Ensure parent directory exists
196        if let Some(parent) = path.parent() {
197            tokio::fs::create_dir_all(parent).await?;
198        }
199
200        let limits: Vec<&ModelLimit> = self.user_limits.values().collect();
201        let content = serde_json::to_string_pretty(&limits)
202            .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e.to_string()))?;
203        tokio::fs::write(&path, content).await?;
204
205        Ok(())
206    }
207
208    /// List all user-defined limits.
209    pub fn list_user_limits(&self) -> Vec<&ModelLimit> {
210        self.user_limits.values().collect()
211    }
212}
213
214impl Default for ModelLimitsRegistry {
215    fn default() -> Self {
216        Self::new()
217    }
218}
219
220/// Get the default configuration file path.
221///
222/// Returns `{bamboo_data_dir}/model_limits.json`.
223pub fn get_default_config_path() -> PathBuf {
224    bamboo_infrastructure::paths::bamboo_dir().join("model_limits.json")
225}
226
227/// Load user model limits from unified `config.json` root key `model_limits`.
228///
229/// Returns:
230/// - `Ok(None)` when `model_limits` key is absent.
231/// - `Ok(Some(vec))` when key exists and is valid (including empty array).
232/// - `Err(...)` when key exists but is not a valid `Vec<ModelLimit>`.
233pub fn load_model_limits_from_unified_config(
234    config: &bamboo_infrastructure::Config,
235) -> Result<Option<Vec<ModelLimit>>, String> {
236    let Some(raw_limits) = config.extra.get("model_limits") else {
237        return Ok(None);
238    };
239
240    if raw_limits.is_null() {
241        return Ok(Some(Vec::new()));
242    }
243
244    match raw_limits {
245        Value::Array(_) => serde_json::from_value::<Vec<ModelLimit>>(raw_limits.clone())
246            .map(Some)
247            .map_err(|error| format!("invalid config.model_limits format: {error}")),
248        _ => Err("invalid config.model_limits format: expected array".to_string()),
249    }
250}
251
252/// Create a token budget for a specific model.
253///
254/// This is a convenience function that creates a budget with appropriate defaults.
255pub fn create_budget_for_model(model: &str, strategy: crate::BudgetStrategy) -> crate::TokenBudget {
256    let registry = ModelLimitsRegistry::default();
257    let limit = registry.get_or_default(model);
258
259    crate::TokenBudget {
260        max_context_tokens: limit.max_context_tokens,
261        max_output_tokens: limit.get_max_output_tokens(),
262        strategy,
263        safety_margin: limit.get_safety_margin(),
264        compression_trigger_percent: 85, // legacy — only used when working_reserve_tokens == 0
265        compression_target_percent: 45,
266        working_reserve_tokens: 50_000,
267        fallback_trigger_percent: 75,
268        prompt_cache_min_tool_output_chars: 1_200,
269        prompt_cache_head_chars: 280,
270        prompt_cache_tail_chars: 180,
271        prompt_cache_recent_user_turns: 2,
272        prompt_cache_recent_tool_chains: 2,
273        max_tool_output_tokens: 0,
274    }
275}
276
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn default_limit_is_200k_64k() {
        let limit = default_model_limit();
        assert_eq!(limit.model_pattern, DEFAULT_MODEL_PATTERN);
        assert_eq!(limit.max_context_tokens, 200_000);
        assert_eq!(limit.get_max_output_tokens(), 64_000);
    }

    #[test]
    fn is_default_limit_detects_no_op_overrides() {
        // A row pinned to exactly the default values (any pattern) is a no-op.
        let mut noop = ModelLimit::new("gpt-4o", DEFAULT_MAX_CONTEXT_TOKENS);
        noop.max_output_tokens = Some(DEFAULT_MAX_OUTPUT_TOKENS);
        assert!(is_default_limit(&noop));

        // The synthesized global default is itself a no-op override.
        assert!(is_default_limit(&default_model_limit()));

        // A different context window is a real override.
        let mut smaller = ModelLimit::new("gpt-4o", 128_000);
        smaller.max_output_tokens = Some(DEFAULT_MAX_OUTPUT_TOKENS);
        assert!(!is_default_limit(&smaller));

        // An explicit safety margin is a real override even at default size.
        let mut custom_margin = ModelLimit::new("gpt-4o", DEFAULT_MAX_CONTEXT_TOKENS);
        custom_margin.max_output_tokens = Some(DEFAULT_MAX_OUTPUT_TOKENS);
        custom_margin.safety_margin = Some(500);
        assert!(!is_default_limit(&custom_margin));
    }

    #[test]
    fn registry_returns_none_for_unknown_without_overrides() {
        // No built-in table: an unknown model with no user override has no match.
        let registry = ModelLimitsRegistry::new();
        assert!(registry.get("gpt-5.2-codex").is_none());
        assert!(registry.get("some-brand-new-model").is_none());
    }

    #[test]
    fn registry_returns_default_for_unknown() {
        let registry = ModelLimitsRegistry::new();
        let limit = registry.get_or_default("unknown-model-xyz");
        assert_eq!(limit.model_pattern, DEFAULT_MODEL_PATTERN);
        assert_eq!(limit.max_context_tokens, 200_000);
        assert_eq!(limit.get_max_output_tokens(), 64_000);
    }

    #[test]
    fn user_override_exact_match_wins() {
        let mut registry = ModelLimitsRegistry::new();
        registry.add_limit(ModelLimit::new("gpt-5.2-codex", 64_000)); // Override with smaller limit

        let limit = registry
            .get("gpt-5.2-codex")
            .expect("Should find overridden limit");
        assert_eq!(limit.max_context_tokens, 64_000);
    }

    #[test]
    fn exact_match_beats_longer_partial_match() {
        // "gpt-4o-mini" contains "gpt-4o" and is longer, but an exact key wins.
        let mut registry = ModelLimitsRegistry::new();
        registry.add_limit(ModelLimit::new("gpt-4o", 128_000));
        registry.add_limit(ModelLimit::new("gpt-4o-mini", 64_000));

        let limit = registry.get("gpt-4o").expect("exact override present");
        assert_eq!(limit.max_context_tokens, 128_000);
    }

    #[test]
    fn user_override_partial_match_longest_wins() {
        let mut registry = ModelLimitsRegistry::new();
        registry.add_limit(ModelLimit::new("gpt-5", 111_000));
        registry.add_limit(ModelLimit::new("gpt-5.2-codex", 222_000));

        // "gpt-5.2-codex-preview" contains both patterns; the longest wins.
        let limit = registry
            .get("gpt-5.2-codex-preview")
            .expect("Should partial-match a user override");
        assert_eq!(limit.max_context_tokens, 222_000);
    }

    #[test]
    fn model_limit_calculates_default_output_tokens() {
        let limit = ModelLimit::new("test", 100_000);
        // Default is min(max_context / 4, 4096) = min(25000, 4096) = 4096
        assert_eq!(limit.get_max_output_tokens(), 4096);
    }

    #[test]
    fn model_limit_uses_custom_output_tokens() {
        let mut limit = ModelLimit::new("test", 100_000);
        limit.max_output_tokens = Some(8192);
        assert_eq!(limit.get_max_output_tokens(), 8192);
    }

    #[test]
    fn model_limit_calculates_small_context_output() {
        let limit = ModelLimit::new("test", 8_192);
        // Default is min(8192 / 4, 4096) = 2048
        assert_eq!(limit.get_max_output_tokens(), 2048);
    }

    #[test]
    fn unified_config_loader_returns_none_when_absent() {
        let temp_dir = tempfile::tempdir().expect("tempdir");
        let config =
            bamboo_infrastructure::Config::from_data_dir(Some(temp_dir.path().to_path_buf()));
        let loaded = load_model_limits_from_unified_config(&config).expect("should parse");
        assert!(loaded.is_none());
    }

    #[test]
    fn unified_config_loader_treats_null_as_empty() {
        let temp_dir = tempfile::tempdir().expect("tempdir");
        let mut config =
            bamboo_infrastructure::Config::from_data_dir(Some(temp_dir.path().to_path_buf()));
        config
            .extra
            .insert("model_limits".to_string(), serde_json::Value::Null);

        let loaded = load_model_limits_from_unified_config(&config)
            .expect("should parse")
            .expect("should exist");
        assert!(loaded.is_empty());
    }

    #[test]
    fn unified_config_loader_reads_valid_model_limits() {
        let temp_dir = tempfile::tempdir().expect("tempdir");
        let mut config =
            bamboo_infrastructure::Config::from_data_dir(Some(temp_dir.path().to_path_buf()));
        config.extra.insert(
            "model_limits".to_string(),
            serde_json::json!([
                {
                    "model_pattern": "gpt-5.2-codex",
                    "max_context_tokens": 64000,
                    "max_output_tokens": 2048,
                    "safety_margin": 512
                }
            ]),
        );

        let loaded = load_model_limits_from_unified_config(&config)
            .expect("should parse")
            .expect("should exist");
        assert_eq!(loaded.len(), 1);
        assert_eq!(loaded[0].model_pattern, "gpt-5.2-codex");
        assert_eq!(loaded[0].max_context_tokens, 64_000);
        assert_eq!(loaded[0].max_output_tokens, Some(2048));
        assert_eq!(loaded[0].safety_margin, Some(512));
    }

    #[test]
    fn unified_config_loader_errors_on_invalid_shape() {
        let temp_dir = tempfile::tempdir().expect("tempdir");
        let mut config =
            bamboo_infrastructure::Config::from_data_dir(Some(temp_dir.path().to_path_buf()));
        config.extra.insert(
            "model_limits".to_string(),
            serde_json::json!({"unexpected": true}),
        );

        let error = load_model_limits_from_unified_config(&config).expect_err("should error");
        assert!(error.contains("expected array"));
    }

    #[test]
    fn unified_config_loader_errors_on_invalid_entry() {
        let temp_dir = tempfile::tempdir().expect("tempdir");
        let mut config =
            bamboo_infrastructure::Config::from_data_dir(Some(temp_dir.path().to_path_buf()));
        // Array shape is right, but the entry lacks `max_context_tokens`.
        config.extra.insert(
            "model_limits".to_string(),
            serde_json::json!([{ "model_pattern": "gpt-4o" }]),
        );

        let error = load_model_limits_from_unified_config(&config).expect_err("should error");
        assert!(error.contains("invalid config.model_limits format"));
    }

    #[test]
    fn safety_margin_scales_with_context_window() {
        // Small context → floor at DEFAULT_SAFETY_MARGIN (1000)
        let small = ModelLimit::new("test", 8_192);
        assert_eq!(small.get_safety_margin(), 1000);

        // Medium context → proportional
        let medium = ModelLimit::new("test", 200_000);
        assert_eq!(medium.get_safety_margin(), 2000);

        // Large context → proportional
        let large = ModelLimit::new("test", 1_050_000);
        assert_eq!(large.get_safety_margin(), 10_500);

        // Explicit override takes precedence
        let mut custom = ModelLimit::new("test", 200_000);
        custom.safety_margin = Some(500);
        assert_eq!(custom.get_safety_margin(), 500);
    }

    #[tokio::test]
    async fn persisted_overrides_drive_runtime_resolution() {
        // Integration: a `model_limits.json` on disk → registry load → resolve.
        let dir = tempfile::tempdir().expect("tempdir");
        let path = dir.path().join("model_limits.json");
        tokio::fs::write(
            &path,
            r#"[{"model_pattern":"gpt-4o","max_context_tokens":128000,"max_output_tokens":16384}]"#,
        )
        .await
        .expect("seed overrides");

        let mut registry = ModelLimitsRegistry::with_config_path(path);
        registry.load_user_config().await.expect("load user config");

        // Persisted override is applied at runtime.
        let gpt4o = registry.get("gpt-4o").expect("override present");
        assert_eq!(gpt4o.max_context_tokens, 128_000);
        assert_eq!(gpt4o.get_max_output_tokens(), 16_384);

        // Unknown model with no override falls back to the single global default.
        let unknown = registry.get_or_default("brand-new-frontier-model");
        assert_eq!(unknown.model_pattern, DEFAULT_MODEL_PATTERN);
        assert_eq!(unknown.max_context_tokens, 200_000);
        assert_eq!(unknown.get_max_output_tokens(), 64_000);
    }

    #[tokio::test]
    async fn load_user_config_ignores_missing_file() {
        let dir = tempfile::tempdir().expect("tempdir");
        let mut registry = ModelLimitsRegistry::with_config_path(dir.path().join("missing.json"));

        registry
            .load_user_config()
            .await
            .expect("a missing file is not an error");
        assert!(registry.list_user_limits().is_empty());
    }

    #[tokio::test]
    async fn save_then_load_round_trips_overrides() {
        // The parent directory does not exist yet; saving must create it.
        let dir = tempfile::tempdir().expect("tempdir");
        let path = dir.path().join("nested").join("model_limits.json");

        let mut original = ModelLimitsRegistry::with_config_path(path.clone());
        let mut limit = ModelLimit::new("claude", 100_000);
        limit.max_output_tokens = Some(8_192);
        limit.safety_margin = Some(750);
        original.add_limit(limit);
        original.save_user_config().await.expect("save user config");

        let mut reloaded = ModelLimitsRegistry::with_config_path(path);
        reloaded.load_user_config().await.expect("load user config");
        let limit = reloaded.get("claude").expect("override persisted");
        assert_eq!(limit.max_context_tokens, 100_000);
        assert_eq!(limit.max_output_tokens, Some(8_192));
        assert_eq!(limit.safety_margin, Some(750));
    }

    #[test]
    fn create_budget_for_model_uses_global_default_for_any_model() {
        let budget = create_budget_for_model("anything-at-all", crate::BudgetStrategy::default());
        assert_eq!(budget.max_context_tokens, 200_000);
        assert_eq!(budget.max_output_tokens, 64_000);
    }
}