1use once_cell::sync::Lazy;
2use serde::{Deserialize, Serialize};
3use thiserror::Error;
4use utoipa::ToSchema;
5
/// Fallback context limit returned by `ModelConfig::context_limit` when no explicit
/// limit is set and the model name matches no entry in the known-model table.
const DEFAULT_CONTEXT_LIMIT: usize = 128_000;
7
8#[derive(Error, Debug)]
9pub enum ConfigError {
10 #[error("Environment variable '{0}' not found")]
11 EnvVarMissing(String),
12 #[error("Invalid value for '{0}': '{1}' - {2}")]
13 InvalidValue(String, String, String),
14 #[error("Value for '{0}' is out of valid range: {1}")]
15 InvalidRange(String, String),
16}
17
18static MODEL_SPECIFIC_LIMITS: Lazy<Vec<(&'static str, usize)>> = Lazy::new(|| {
19 vec![
20 ("gpt-5.2-codex", 400_000), ("gpt-5.2", 400_000), ("gpt-5.1-codex-max", 256_000),
24 ("gpt-5.1-codex-mini", 256_000),
25 ("gpt-4-turbo", 128_000),
26 ("gpt-4.1", 1_000_000),
27 ("gpt-4-1", 1_000_000),
28 ("gpt-4o", 128_000),
29 ("o4-mini", 200_000),
30 ("o3-mini", 200_000),
31 ("o3", 200_000),
32 ("claude", 200_000),
34 ("gemini-1.5-flash", 1_000_000),
36 ("gemini-1", 128_000),
37 ("gemini-2", 1_000_000),
38 ("gemma-3-27b", 128_000),
39 ("gemma-3-12b", 128_000),
40 ("gemma-3-4b", 128_000),
41 ("gemma-3-1b", 32_000),
42 ("gemma3-27b", 128_000),
43 ("gemma3-12b", 128_000),
44 ("gemma3-4b", 128_000),
45 ("gemma3-1b", 32_000),
46 ("gemma-2-27b", 8_192),
47 ("gemma-2-9b", 8_192),
48 ("gemma-2-2b", 8_192),
49 ("gemma2-", 8_192),
50 ("gemma-7b", 8_192),
51 ("gemma-2b", 8_192),
52 ("gemma1", 8_192),
53 ("gemma", 8_192),
54 ("llama-2-1b", 32_000),
56 ("llama", 128_000),
57 ("qwen3-coder", 262_144),
59 ("qwen2-7b", 128_000),
60 ("qwen2-14b", 128_000),
61 ("qwen2-32b", 131_072),
62 ("qwen2-70b", 262_144),
63 ("qwen2", 128_000),
64 ("qwen3-32b", 131_072),
65 ("grok-4", 256_000),
67 ("grok-code-fast-1", 256_000),
68 ("grok", 131_072),
69 ("kimi-k2", 131_072),
71 ]
72});
73
74#[derive(Debug, Clone, Serialize, Deserialize, ToSchema)]
75pub struct ModelConfig {
76 pub model_name: String,
77 pub context_limit: Option<usize>,
78 pub temperature: Option<f32>,
79 pub max_tokens: Option<i32>,
80 pub toolshim: bool,
81 pub toolshim_model: Option<String>,
82 pub fast_model: Option<String>,
83}
84
85#[derive(Debug, Clone, Serialize, Deserialize)]
86pub struct ModelLimitConfig {
87 pub pattern: String,
88 pub context_limit: usize,
89}
90
91impl ModelConfig {
92 pub fn new(model_name: &str) -> Result<Self, ConfigError> {
93 Self::new_with_context_env(model_name.to_string(), None)
94 }
95
96 pub fn new_with_context_env(
97 model_name: String,
98 context_env_var: Option<&str>,
99 ) -> Result<Self, ConfigError> {
100 let context_limit = Self::parse_context_limit(&model_name, None, context_env_var)?;
101 let temperature = Self::parse_temperature()?;
102 let max_tokens = Self::parse_max_tokens()?;
103 let toolshim = Self::parse_toolshim()?;
104 let toolshim_model = Self::parse_toolshim_model()?;
105
106 Ok(Self {
107 model_name,
108 context_limit,
109 temperature,
110 max_tokens,
111 toolshim,
112 toolshim_model,
113 fast_model: None,
114 })
115 }
116
117 fn parse_context_limit(
118 model_name: &str,
119 fast_model: Option<&str>,
120 custom_env_var: Option<&str>,
121 ) -> Result<Option<usize>, ConfigError> {
122 if let Some(env_var) = custom_env_var {
124 if let Ok(val) = std::env::var(env_var) {
125 return Self::validate_context_limit(&val, env_var).map(Some);
126 }
127 }
128 if let Ok(val) = std::env::var("ASTER_CONTEXT_LIMIT") {
129 return Self::validate_context_limit(&val, "ASTER_CONTEXT_LIMIT").map(Some);
130 }
131
132 let model_limit = Self::get_model_specific_limit(model_name);
134
135 if let Some(fast_model_name) = fast_model {
137 let fast_model_limit = Self::get_model_specific_limit(fast_model_name);
138
139 match (model_limit, fast_model_limit) {
141 (Some(m), Some(f)) => Ok(Some(m.min(f))),
142 (Some(m), None) => Ok(Some(m)),
143 (None, Some(f)) => Ok(Some(f)),
144 (None, None) => Ok(None),
145 }
146 } else {
147 Ok(model_limit)
148 }
149 }
150
151 fn validate_context_limit(val: &str, env_var: &str) -> Result<usize, ConfigError> {
152 let limit = val.parse::<usize>().map_err(|_| {
153 ConfigError::InvalidValue(
154 env_var.to_string(),
155 val.to_string(),
156 "must be a positive integer".to_string(),
157 )
158 })?;
159
160 if limit < 4 * 1024 {
161 return Err(ConfigError::InvalidRange(
162 env_var.to_string(),
163 "must be greater than 4K".to_string(),
164 ));
165 }
166
167 Ok(limit)
168 }
169
170 fn parse_temperature() -> Result<Option<f32>, ConfigError> {
171 if let Ok(val) = std::env::var("ASTER_TEMPERATURE") {
172 let temp = val.parse::<f32>().map_err(|_| {
173 ConfigError::InvalidValue(
174 "ASTER_TEMPERATURE".to_string(),
175 val.clone(),
176 "must be a valid number".to_string(),
177 )
178 })?;
179 if temp < 0.0 {
180 return Err(ConfigError::InvalidRange(
181 "ASTER_TEMPERATURE".to_string(),
182 val,
183 ));
184 }
185 Ok(Some(temp))
186 } else {
187 Ok(None)
188 }
189 }
190
191 fn parse_max_tokens() -> Result<Option<i32>, ConfigError> {
192 match crate::config::Config::global().get_param::<i32>("ASTER_MAX_TOKENS") {
193 Ok(tokens) => {
194 if tokens <= 0 {
195 return Err(ConfigError::InvalidRange(
196 "aster_max_tokens".to_string(),
197 "must be greater than 0".to_string(),
198 ));
199 }
200 Ok(Some(tokens))
201 }
202 Err(crate::config::ConfigError::NotFound(_)) => Ok(None),
203 Err(e) => Err(ConfigError::InvalidValue(
204 "aster_max_tokens".to_string(),
205 String::new(),
206 e.to_string(),
207 )),
208 }
209 }
210
211 fn parse_toolshim() -> Result<bool, ConfigError> {
212 if let Ok(val) = std::env::var("ASTER_TOOLSHIM") {
213 match val.to_lowercase().as_str() {
214 "1" | "true" | "yes" | "on" => Ok(true),
215 "0" | "false" | "no" | "off" => Ok(false),
216 _ => Err(ConfigError::InvalidValue(
217 "ASTER_TOOLSHIM".to_string(),
218 val,
219 "must be one of: 1, true, yes, on, 0, false, no, off".to_string(),
220 )),
221 }
222 } else {
223 Ok(false)
224 }
225 }
226
227 fn parse_toolshim_model() -> Result<Option<String>, ConfigError> {
228 match std::env::var("ASTER_TOOLSHIM_OLLAMA_MODEL") {
229 Ok(val) if val.trim().is_empty() => Err(ConfigError::InvalidValue(
230 "ASTER_TOOLSHIM_OLLAMA_MODEL".to_string(),
231 val,
232 "cannot be empty if set".to_string(),
233 )),
234 Ok(val) => Ok(Some(val)),
235 Err(_) => Ok(None),
236 }
237 }
238
239 fn get_model_specific_limit(model_name: &str) -> Option<usize> {
240 MODEL_SPECIFIC_LIMITS
241 .iter()
242 .find(|(pattern, _)| model_name.contains(pattern))
243 .map(|(_, limit)| *limit)
244 }
245
246 pub fn get_all_model_limits() -> Vec<ModelLimitConfig> {
247 MODEL_SPECIFIC_LIMITS
248 .iter()
249 .map(|(pattern, context_limit)| ModelLimitConfig {
250 pattern: pattern.to_string(),
251 context_limit: *context_limit,
252 })
253 .collect()
254 }
255
256 pub fn with_context_limit(mut self, limit: Option<usize>) -> Self {
257 if limit.is_some() {
258 self.context_limit = limit;
259 }
260 self
261 }
262
263 pub fn with_temperature(mut self, temp: Option<f32>) -> Self {
264 self.temperature = temp;
265 self
266 }
267
268 pub fn with_max_tokens(mut self, tokens: Option<i32>) -> Self {
269 self.max_tokens = tokens;
270 self
271 }
272
273 pub fn with_toolshim(mut self, toolshim: bool) -> Self {
274 self.toolshim = toolshim;
275 self
276 }
277
278 pub fn with_toolshim_model(mut self, model: Option<String>) -> Self {
279 self.toolshim_model = model;
280 self
281 }
282
283 pub fn with_fast(mut self, fast_model: String) -> Self {
284 self.fast_model = Some(fast_model);
285 self
286 }
287
288 pub fn use_fast_model(&self) -> Self {
289 if let Some(fast_model) = &self.fast_model {
290 let mut config = self.clone();
291 config.model_name = fast_model.clone();
292 config
293 } else {
294 self.clone()
295 }
296 }
297
298 pub fn context_limit(&self) -> usize {
299 if let Some(limit) = self.context_limit {
301 return limit;
302 }
303
304 let main_limit =
306 Self::get_model_specific_limit(&self.model_name).unwrap_or(DEFAULT_CONTEXT_LIMIT);
307
308 if let Some(fast_model) = &self.fast_model {
310 let fast_limit =
311 Self::get_model_specific_limit(fast_model).unwrap_or(DEFAULT_CONTEXT_LIMIT);
312 main_limit.min(fast_limit)
313 } else {
314 main_limit
315 }
316 }
317
318 pub fn new_or_fail(model_name: &str) -> ModelConfig {
319 ModelConfig::new(model_name)
320 .unwrap_or_else(|_| panic!("Failed to create model config for {}", model_name))
321 }
322}
323
#[cfg(test)]
mod tests {
    use super::*;

    // Each test holds an `env_lock` guard for the variables it touches so that
    // environment changes do not leak between tests.

    /// A well-formed positive value is returned as-is.
    #[test]
    fn test_parse_max_tokens_valid() {
        let _guard = env_lock::lock_env([("ASTER_MAX_TOKENS", Some("4096"))]);
        let parsed = ModelConfig::parse_max_tokens().expect("4096 should be accepted");
        assert_eq!(parsed, Some(4096));
    }

    /// A missing setting is not an error; it simply yields `None`.
    #[test]
    fn test_parse_max_tokens_not_set() {
        let _guard = env_lock::lock_env([("ASTER_MAX_TOKENS", None::<&str>)]);
        let parsed = ModelConfig::parse_max_tokens().expect("unset should be accepted");
        assert_eq!(parsed, None);
    }

    /// Non-numeric text is reported as an invalid value.
    #[test]
    fn test_parse_max_tokens_invalid_string() {
        let _guard = env_lock::lock_env([("ASTER_MAX_TOKENS", Some("not_a_number"))]);
        assert!(matches!(
            ModelConfig::parse_max_tokens(),
            Err(ConfigError::InvalidValue(..))
        ));
    }

    /// Zero is outside the accepted range.
    #[test]
    fn test_parse_max_tokens_zero() {
        let _guard = env_lock::lock_env([("ASTER_MAX_TOKENS", Some("0"))]);
        assert!(matches!(
            ModelConfig::parse_max_tokens(),
            Err(ConfigError::InvalidRange(..))
        ));
    }

    /// Negative numbers are outside the accepted range.
    #[test]
    fn test_parse_max_tokens_negative() {
        let _guard = env_lock::lock_env([("ASTER_MAX_TOKENS", Some("-100"))]);
        assert!(matches!(
            ModelConfig::parse_max_tokens(),
            Err(ConfigError::InvalidRange(..))
        ));
    }

    /// `ModelConfig::new` picks the max-tokens setting up when it is present.
    #[test]
    fn test_model_config_with_max_tokens_env() {
        let _guard = env_lock::lock_env([
            ("ASTER_MAX_TOKENS", Some("8192")),
            ("ASTER_TEMPERATURE", None::<&str>),
            ("ASTER_CONTEXT_LIMIT", None::<&str>),
            ("ASTER_TOOLSHIM", None::<&str>),
            ("ASTER_TOOLSHIM_OLLAMA_MODEL", None::<&str>),
        ]);
        let built = ModelConfig::new("test-model").expect("config should build");
        assert_eq!(built.max_tokens, Some(8192));
    }

    /// `ModelConfig::new` leaves max-tokens empty when the setting is absent.
    #[test]
    fn test_model_config_without_max_tokens_env() {
        let _guard = env_lock::lock_env([
            ("ASTER_MAX_TOKENS", None::<&str>),
            ("ASTER_TEMPERATURE", None::<&str>),
            ("ASTER_CONTEXT_LIMIT", None::<&str>),
            ("ASTER_TOOLSHIM", None::<&str>),
            ("ASTER_TOOLSHIM_OLLAMA_MODEL", None::<&str>),
        ]);
        let built = ModelConfig::new("test-model").expect("config should build");
        assert_eq!(built.max_tokens, None);
    }
}