//! Model configuration (`aster/model.rs`): per-model context-window limits
//! and environment-driven model settings (temperature, max tokens, toolshim).
1use once_cell::sync::Lazy;
2use serde::{Deserialize, Serialize};
3use thiserror::Error;
4use utoipa::ToSchema;
5
/// Context window (in tokens) assumed for any model that has no matching
/// entry in `MODEL_SPECIFIC_LIMITS`.
const DEFAULT_CONTEXT_LIMIT: usize = 128_000;
7
/// Errors raised while reading model settings from the environment or the
/// global config store.
#[derive(Error, Debug)]
pub enum ConfigError {
    /// A required environment variable was not set.
    /// NOTE(review): not constructed anywhere in this file — presumably used
    /// by callers elsewhere in the crate; confirm before removing.
    #[error("Environment variable '{0}' not found")]
    EnvVarMissing(String),
    /// A setting was present but unparseable: (setting name, raw value, reason).
    #[error("Invalid value for '{0}': '{1}' - {2}")]
    InvalidValue(String, String, String),
    /// A setting parsed but fell outside the accepted range: (setting name, detail).
    #[error("Value for '{0}' is out of valid range: {1}")]
    InvalidRange(String, String),
}
17
/// Substring patterns mapped to known context-window sizes (in tokens).
///
/// `get_model_specific_limit` returns the FIRST entry whose pattern is a
/// substring of the model name, so order is load-bearing: more specific
/// patterns (e.g. "o3-mini", "gemini-1.5-flash", "llama-2-1b") must precede
/// the more general ones they contain (e.g. "o3", "gemini-1", "llama").
static MODEL_SPECIFIC_LIMITS: Lazy<Vec<(&'static str, usize)>> = Lazy::new(|| {
    vec![
        // openai
        ("gpt-5.2-codex", 400_000), // auto-compacting context
        ("gpt-5.2", 400_000),       // auto-compacting context
        ("gpt-5.1-codex-max", 256_000),
        ("gpt-5.1-codex-mini", 256_000),
        ("gpt-4-turbo", 128_000),
        ("gpt-4.1", 1_000_000),
        ("gpt-4-1", 1_000_000), // alias spelling of gpt-4.1
        ("gpt-4o", 128_000),
        ("o4-mini", 200_000),
        ("o3-mini", 200_000), // must precede the bare "o3" pattern below
        ("o3", 200_000),
        // anthropic - all 200k
        ("claude", 200_000),
        // google
        ("gemini-1.5-flash", 1_000_000), // must precede "gemini-1"
        ("gemini-1", 128_000),
        ("gemini-2", 1_000_000),
        ("gemma-3-27b", 128_000),
        ("gemma-3-12b", 128_000),
        ("gemma-3-4b", 128_000),
        ("gemma-3-1b", 32_000),
        // hyphen-free spellings of the same gemma-3 sizes
        ("gemma3-27b", 128_000),
        ("gemma3-12b", 128_000),
        ("gemma3-4b", 128_000),
        ("gemma3-1b", 32_000),
        ("gemma-2-27b", 8_192),
        ("gemma-2-9b", 8_192),
        ("gemma-2-2b", 8_192),
        ("gemma2-", 8_192),
        ("gemma-7b", 8_192),
        ("gemma-2b", 8_192),
        ("gemma1", 8_192),
        ("gemma", 8_192), // catch-all for remaining gemma variants
        // facebook
        ("llama-2-1b", 32_000), // must precede the bare "llama" pattern
        ("llama", 128_000),
        // qwen
        ("qwen3-coder", 262_144),
        ("qwen2-7b", 128_000),
        ("qwen2-14b", 128_000),
        ("qwen2-32b", 131_072),
        ("qwen2-70b", 262_144),
        ("qwen2", 128_000), // catch-all for remaining qwen2 variants
        ("qwen3-32b", 131_072),
        // xai
        ("grok-4", 256_000),
        ("grok-code-fast-1", 256_000),
        ("grok", 131_072), // catch-all for remaining grok variants
        // other
        ("kimi-k2", 131_072),
    ]
});
73
/// Resolved configuration for a single model, assembled from environment
/// variables, the global config store, and per-model defaults.
#[derive(Debug, Clone, Serialize, Deserialize, ToSchema)]
pub struct ModelConfig {
    // Name of the model as passed to the provider.
    pub model_name: String,
    // Explicit context-window override; `None` means "look up the model's
    // default via MODEL_SPECIFIC_LIMITS" (see `context_limit()`).
    pub context_limit: Option<usize>,
    // Sampling temperature; `None` leaves the provider default in effect.
    pub temperature: Option<f32>,
    // Maximum completion tokens; `None` leaves the provider default in effect.
    pub max_tokens: Option<i32>,
    // Whether to route tool calls through the toolshim layer.
    pub toolshim: bool,
    // Ollama model used by the toolshim when enabled.
    pub toolshim_model: Option<String>,
    // Optional cheaper/faster companion model (see `use_fast_model()`).
    pub fast_model: Option<String>,
}
84
/// Serializable view of one `MODEL_SPECIFIC_LIMITS` entry, as returned by
/// `ModelConfig::get_all_model_limits`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModelLimitConfig {
    // Substring pattern matched against model names.
    pub pattern: String,
    // Context-window size (tokens) applied when the pattern matches.
    pub context_limit: usize,
}
90
91impl ModelConfig {
92    pub fn new(model_name: &str) -> Result<Self, ConfigError> {
93        Self::new_with_context_env(model_name.to_string(), None)
94    }
95
96    pub fn new_with_context_env(
97        model_name: String,
98        context_env_var: Option<&str>,
99    ) -> Result<Self, ConfigError> {
100        let context_limit = Self::parse_context_limit(&model_name, None, context_env_var)?;
101        let temperature = Self::parse_temperature()?;
102        let max_tokens = Self::parse_max_tokens()?;
103        let toolshim = Self::parse_toolshim()?;
104        let toolshim_model = Self::parse_toolshim_model()?;
105
106        Ok(Self {
107            model_name,
108            context_limit,
109            temperature,
110            max_tokens,
111            toolshim,
112            toolshim_model,
113            fast_model: None,
114        })
115    }
116
117    fn parse_context_limit(
118        model_name: &str,
119        fast_model: Option<&str>,
120        custom_env_var: Option<&str>,
121    ) -> Result<Option<usize>, ConfigError> {
122        // First check if there's an explicit environment variable override
123        if let Some(env_var) = custom_env_var {
124            if let Ok(val) = std::env::var(env_var) {
125                return Self::validate_context_limit(&val, env_var).map(Some);
126            }
127        }
128        if let Ok(val) = std::env::var("ASTER_CONTEXT_LIMIT") {
129            return Self::validate_context_limit(&val, "ASTER_CONTEXT_LIMIT").map(Some);
130        }
131
132        // Get the model's limit
133        let model_limit = Self::get_model_specific_limit(model_name);
134
135        // If there's a fast_model, get its limit and use the minimum
136        if let Some(fast_model_name) = fast_model {
137            let fast_model_limit = Self::get_model_specific_limit(fast_model_name);
138
139            // Return the minimum of both limits (if both exist)
140            match (model_limit, fast_model_limit) {
141                (Some(m), Some(f)) => Ok(Some(m.min(f))),
142                (Some(m), None) => Ok(Some(m)),
143                (None, Some(f)) => Ok(Some(f)),
144                (None, None) => Ok(None),
145            }
146        } else {
147            Ok(model_limit)
148        }
149    }
150
151    fn validate_context_limit(val: &str, env_var: &str) -> Result<usize, ConfigError> {
152        let limit = val.parse::<usize>().map_err(|_| {
153            ConfigError::InvalidValue(
154                env_var.to_string(),
155                val.to_string(),
156                "must be a positive integer".to_string(),
157            )
158        })?;
159
160        if limit < 4 * 1024 {
161            return Err(ConfigError::InvalidRange(
162                env_var.to_string(),
163                "must be greater than 4K".to_string(),
164            ));
165        }
166
167        Ok(limit)
168    }
169
170    fn parse_temperature() -> Result<Option<f32>, ConfigError> {
171        if let Ok(val) = std::env::var("ASTER_TEMPERATURE") {
172            let temp = val.parse::<f32>().map_err(|_| {
173                ConfigError::InvalidValue(
174                    "ASTER_TEMPERATURE".to_string(),
175                    val.clone(),
176                    "must be a valid number".to_string(),
177                )
178            })?;
179            if temp < 0.0 {
180                return Err(ConfigError::InvalidRange(
181                    "ASTER_TEMPERATURE".to_string(),
182                    val,
183                ));
184            }
185            Ok(Some(temp))
186        } else {
187            Ok(None)
188        }
189    }
190
191    fn parse_max_tokens() -> Result<Option<i32>, ConfigError> {
192        match crate::config::Config::global().get_param::<i32>("ASTER_MAX_TOKENS") {
193            Ok(tokens) => {
194                if tokens <= 0 {
195                    return Err(ConfigError::InvalidRange(
196                        "aster_max_tokens".to_string(),
197                        "must be greater than 0".to_string(),
198                    ));
199                }
200                Ok(Some(tokens))
201            }
202            Err(crate::config::ConfigError::NotFound(_)) => Ok(None),
203            Err(e) => Err(ConfigError::InvalidValue(
204                "aster_max_tokens".to_string(),
205                String::new(),
206                e.to_string(),
207            )),
208        }
209    }
210
211    fn parse_toolshim() -> Result<bool, ConfigError> {
212        if let Ok(val) = std::env::var("ASTER_TOOLSHIM") {
213            match val.to_lowercase().as_str() {
214                "1" | "true" | "yes" | "on" => Ok(true),
215                "0" | "false" | "no" | "off" => Ok(false),
216                _ => Err(ConfigError::InvalidValue(
217                    "ASTER_TOOLSHIM".to_string(),
218                    val,
219                    "must be one of: 1, true, yes, on, 0, false, no, off".to_string(),
220                )),
221            }
222        } else {
223            Ok(false)
224        }
225    }
226
227    fn parse_toolshim_model() -> Result<Option<String>, ConfigError> {
228        match std::env::var("ASTER_TOOLSHIM_OLLAMA_MODEL") {
229            Ok(val) if val.trim().is_empty() => Err(ConfigError::InvalidValue(
230                "ASTER_TOOLSHIM_OLLAMA_MODEL".to_string(),
231                val,
232                "cannot be empty if set".to_string(),
233            )),
234            Ok(val) => Ok(Some(val)),
235            Err(_) => Ok(None),
236        }
237    }
238
239    fn get_model_specific_limit(model_name: &str) -> Option<usize> {
240        MODEL_SPECIFIC_LIMITS
241            .iter()
242            .find(|(pattern, _)| model_name.contains(pattern))
243            .map(|(_, limit)| *limit)
244    }
245
246    pub fn get_all_model_limits() -> Vec<ModelLimitConfig> {
247        MODEL_SPECIFIC_LIMITS
248            .iter()
249            .map(|(pattern, context_limit)| ModelLimitConfig {
250                pattern: pattern.to_string(),
251                context_limit: *context_limit,
252            })
253            .collect()
254    }
255
256    pub fn with_context_limit(mut self, limit: Option<usize>) -> Self {
257        if limit.is_some() {
258            self.context_limit = limit;
259        }
260        self
261    }
262
263    pub fn with_temperature(mut self, temp: Option<f32>) -> Self {
264        self.temperature = temp;
265        self
266    }
267
268    pub fn with_max_tokens(mut self, tokens: Option<i32>) -> Self {
269        self.max_tokens = tokens;
270        self
271    }
272
273    pub fn with_toolshim(mut self, toolshim: bool) -> Self {
274        self.toolshim = toolshim;
275        self
276    }
277
278    pub fn with_toolshim_model(mut self, model: Option<String>) -> Self {
279        self.toolshim_model = model;
280        self
281    }
282
283    pub fn with_fast(mut self, fast_model: String) -> Self {
284        self.fast_model = Some(fast_model);
285        self
286    }
287
288    pub fn use_fast_model(&self) -> Self {
289        if let Some(fast_model) = &self.fast_model {
290            let mut config = self.clone();
291            config.model_name = fast_model.clone();
292            config
293        } else {
294            self.clone()
295        }
296    }
297
298    pub fn context_limit(&self) -> usize {
299        // If we have an explicit context limit set, use it
300        if let Some(limit) = self.context_limit {
301            return limit;
302        }
303
304        // Otherwise, get the model's default limit
305        let main_limit =
306            Self::get_model_specific_limit(&self.model_name).unwrap_or(DEFAULT_CONTEXT_LIMIT);
307
308        // If we have a fast_model, also check its limit and use the minimum
309        if let Some(fast_model) = &self.fast_model {
310            let fast_limit =
311                Self::get_model_specific_limit(fast_model).unwrap_or(DEFAULT_CONTEXT_LIMIT);
312            main_limit.min(fast_limit)
313        } else {
314            main_limit
315        }
316    }
317
318    pub fn new_or_fail(model_name: &str) -> ModelConfig {
319        ModelConfig::new(model_name)
320            .unwrap_or_else(|_| panic!("Failed to create model config for {}", model_name))
321    }
322}
323
#[cfg(test)]
mod tests {
    use super::*;

    // NOTE(review): `env_lock::lock_env` is not defined in this file —
    // presumably it sets/unsets the given env vars and holds a global lock so
    // env-dependent tests don't race; confirm against the env_lock helper.
    // These tests also assume the global config store falls back to reading
    // the ASTER_MAX_TOKENS environment variable — verify against
    // crate::config::Config::get_param.

    // A positive integer value is parsed and returned.
    #[test]
    fn test_parse_max_tokens_valid() {
        let _guard = env_lock::lock_env([("ASTER_MAX_TOKENS", Some("4096"))]);
        let result = ModelConfig::parse_max_tokens().unwrap();
        assert_eq!(result, Some(4096));
    }

    // An unset variable yields Ok(None), not an error.
    #[test]
    fn test_parse_max_tokens_not_set() {
        let _guard = env_lock::lock_env([("ASTER_MAX_TOKENS", None::<&str>)]);
        let result = ModelConfig::parse_max_tokens().unwrap();
        assert_eq!(result, None);
    }

    // A non-numeric value surfaces as InvalidValue.
    #[test]
    fn test_parse_max_tokens_invalid_string() {
        let _guard = env_lock::lock_env([("ASTER_MAX_TOKENS", Some("not_a_number"))]);
        let result = ModelConfig::parse_max_tokens();
        assert!(result.is_err());
        assert!(matches!(result.unwrap_err(), ConfigError::InvalidValue(..)));
    }

    // Zero is rejected: the limit must be strictly positive.
    #[test]
    fn test_parse_max_tokens_zero() {
        let _guard = env_lock::lock_env([("ASTER_MAX_TOKENS", Some("0"))]);
        let result = ModelConfig::parse_max_tokens();
        assert!(result.is_err());
        assert!(matches!(result.unwrap_err(), ConfigError::InvalidRange(..)));
    }

    // Negative values are rejected the same way as zero.
    #[test]
    fn test_parse_max_tokens_negative() {
        let _guard = env_lock::lock_env([("ASTER_MAX_TOKENS", Some("-100"))]);
        let result = ModelConfig::parse_max_tokens();
        assert!(result.is_err());
        assert!(matches!(result.unwrap_err(), ConfigError::InvalidRange(..)));
    }

    // End-to-end: constructing a ModelConfig picks up ASTER_MAX_TOKENS;
    // all other settings are explicitly cleared to isolate the test.
    #[test]
    fn test_model_config_with_max_tokens_env() {
        let _guard = env_lock::lock_env([
            ("ASTER_MAX_TOKENS", Some("8192")),
            ("ASTER_TEMPERATURE", None::<&str>),
            ("ASTER_CONTEXT_LIMIT", None::<&str>),
            ("ASTER_TOOLSHIM", None::<&str>),
            ("ASTER_TOOLSHIM_OLLAMA_MODEL", None::<&str>),
        ]);
        let config = ModelConfig::new("test-model").unwrap();
        assert_eq!(config.max_tokens, Some(8192));
    }

    // End-to-end: with everything unset, max_tokens stays None.
    #[test]
    fn test_model_config_without_max_tokens_env() {
        let _guard = env_lock::lock_env([
            ("ASTER_MAX_TOKENS", None::<&str>),
            ("ASTER_TEMPERATURE", None::<&str>),
            ("ASTER_CONTEXT_LIMIT", None::<&str>),
            ("ASTER_TOOLSHIM", None::<&str>),
            ("ASTER_TOOLSHIM_OLLAMA_MODEL", None::<&str>),
        ]);
        let config = ModelConfig::new("test-model").unwrap();
        assert_eq!(config.max_tokens, None);
    }
}