Skip to main content

ai_agents_runtime/spec/
llm.rs

1//! LLM configuration types
2
3use serde::{Deserialize, Serialize};
4use std::collections::HashMap;
5
/// CLI presentation settings — presumably read from `metadata.cli` in an
/// agent spec (the nested [`CliHitlMetadata`] docs say `metadata.cli.hitl`;
/// confirm against the loader).
///
/// Every field is optional: missing keys deserialize to `None` / empty, and
/// unset fields are skipped on serialization so round-trips stay minimal.
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct CliMetadata {
    /// Banner text shown when the CLI starts.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub welcome: Option<String>,

    /// Suggested example inputs to display to the user.
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub hints: Vec<String>,

    /// Whether to display tool activity.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub show_tools: Option<bool>,

    /// Whether to display agent state.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub show_state: Option<bool>,

    /// Whether to display timing information.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub show_timing: Option<bool>,

    /// Whether to stream responses as they are produced.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub streaming: Option<bool>,

    /// Prompt rendering style; see [`CliPromptStyle`].
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub prompt_style: Option<CliPromptStyle>,

    /// Suppress the CLI's built-in commands when `Some(true)`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub disable_builtin_commands: Option<bool>,

    /// Human-in-the-loop display/approval settings; see [`CliHitlMetadata`].
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub hitl: Option<CliHitlMetadata>,

    /// Name of a color/display theme.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub theme: Option<String>,
}
38
/// How the CLI prompt is rendered.
///
/// Serialized in `snake_case`, e.g. `prompt_style: with_state` in YAML.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
#[serde(rename_all = "snake_case")]
pub enum CliPromptStyle {
    /// Plain prompt with no extra context.
    Simple,
    /// Prompt that also shows agent state.
    WithState,
}
45
/// Controls how the CLI handles HITL approval requests at runtime.
///
/// Serialized in `snake_case` (`prompt`, `auto_approve`, `auto_reject`).
#[derive(Debug, Clone, Serialize, Deserialize, Default, PartialEq, Eq)]
#[serde(rename_all = "snake_case")]
pub enum CliHitlStyle {
    /// Interactive y/N prompt in the terminal (default).
    #[default]
    Prompt,
    /// Silently approve all requests.
    AutoApprove,
    /// Silently reject all requests.
    AutoReject,
}
58
/// CLI-specific HITL display settings from `metadata.cli.hitl`.
///
/// Both fields are optional; absent keys deserialize to `None` and are
/// skipped when serializing.
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct CliHitlMetadata {
    /// Approval handling mode; see [`CliHitlStyle`].
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub style: Option<CliHitlStyle>,

    /// Whether to display request context alongside the approval prompt.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub show_context: Option<bool>,
}
68
69/// Configuration for LLM provider
70#[derive(Debug, Clone, Serialize, Deserialize)]
71pub struct LLMConfig {
72    pub provider: String,
73
74    pub model: String,
75
76    #[serde(default = "default_temperature")]
77    pub temperature: f32,
78
79    #[serde(default = "default_max_tokens")]
80    pub max_tokens: u32,
81
82    #[serde(skip_serializing_if = "Option::is_none")]
83    pub top_p: Option<f32>,
84
85    /// Base URL for the LLM provider API.
86    /// Required for `openai-compatible`; optional override for other providers.
87    #[serde(default, skip_serializing_if = "Option::is_none")]
88    pub base_url: Option<String>,
89
90    /// Environment variable name containing the API key.
91    /// Overrides the provider's default env var (e.g. OPENAI_API_KEY).
92    #[serde(default, skip_serializing_if = "Option::is_none")]
93    pub api_key_env: Option<String>,
94
95    /// Request timeout in seconds.
96    #[serde(default, skip_serializing_if = "Option::is_none")]
97    pub timeout_seconds: Option<u64>,
98
99    /// Enable extended thinking / reasoning mode.
100    #[serde(default, skip_serializing_if = "Option::is_none")]
101    pub reasoning: Option<bool>,
102
103    /// Reasoning effort level: "low", "medium", or "high".
104    #[serde(default, skip_serializing_if = "Option::is_none")]
105    pub reasoning_effort: Option<String>,
106
107    /// Maximum token budget for reasoning.
108    #[serde(default, skip_serializing_if = "Option::is_none")]
109    pub reasoning_budget_tokens: Option<u32>,
110
111    /// Additional provider-specific configuration
112    #[serde(flatten)]
113    pub extra: HashMap<String, serde_json::Value>,
114}
115
/// Serde default for [`LLMConfig::temperature`].
fn default_temperature() -> f32 {
    0.7
}
119
/// Serde default for [`LLMConfig::max_tokens`].
fn default_max_tokens() -> u32 {
    2000
}
123
124impl Default for LLMConfig {
125    fn default() -> Self {
126        Self {
127            provider: "openai".to_string(),
128            model: "gpt-4".to_string(),
129            temperature: default_temperature(),
130            max_tokens: default_max_tokens(),
131            top_p: None,
132            base_url: None,
133            api_key_env: None,
134            timeout_seconds: None,
135            reasoning: None,
136            reasoning_effort: None,
137            reasoning_budget_tokens: None,
138            extra: HashMap::new(),
139        }
140    }
141}
142
/// Selects which configured LLM alias an agent uses.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LLMSelector {
    /// Alias of the primary LLM; falls back to `"default"` when absent.
    #[serde(default = "default_alias")]
    pub default: String,
    /// Optional alias of a router LLM; `None` when absent.
    #[serde(default)]
    pub router: Option<String>,
}
150
/// Serde default for [`LLMSelector::default`].
fn default_alias() -> String {
    String::from("default")
}
154
155impl Default for LLMSelector {
156    fn default() -> Self {
157        Self {
158            default: default_alias(),
159            router: None,
160        }
161    }
162}
163
164impl LLMSelector {
165    pub fn new(default: impl Into<String>) -> Self {
166        Self {
167            default: default.into(),
168            router: None,
169        }
170    }
171
172    pub fn with_router(mut self, router: impl Into<String>) -> Self {
173        self.router = Some(router.into());
174        self
175    }
176}
177
#[cfg(test)]
mod tests {
    use super::*;

    // --- CliMetadata ---

    // Full round-trip of every top-level CLI key; `hitl` absent => None.
    #[test]
    fn test_cli_metadata_deserialize() {
        let yaml = r#"
welcome: "=== Demo ==="
hints:
  - "Try: hello"
  - "Try: help"
show_tools: true
show_state: false
show_timing: true
streaming: true
prompt_style: with_state
disable_builtin_commands: false
"#;
        let metadata: CliMetadata = serde_yaml::from_str(yaml).unwrap();
        assert_eq!(metadata.welcome.as_deref(), Some("=== Demo ==="));
        assert_eq!(metadata.hints.len(), 2);
        assert_eq!(metadata.show_tools, Some(true));
        assert_eq!(metadata.show_state, Some(false));
        assert_eq!(metadata.show_timing, Some(true));
        assert_eq!(metadata.streaming, Some(true));
        assert_eq!(metadata.prompt_style, Some(CliPromptStyle::WithState));
        assert_eq!(metadata.disable_builtin_commands, Some(false));
        assert!(metadata.hitl.is_none());
    }

    // --- LLMConfig ---

    #[test]
    fn test_llm_config_default() {
        let config = LLMConfig::default();
        assert_eq!(config.provider, "openai");
        assert_eq!(config.model, "gpt-4");
        assert_eq!(config.temperature, 0.7);
        assert_eq!(config.max_tokens, 2000);
        assert_eq!(config.base_url, None);
        assert_eq!(config.api_key_env, None);
    }

    #[test]
    fn test_llm_config_with_base_url() {
        let yaml = r#"
provider: openai-compatible
model: llama3.2
base_url: http://localhost:1234/v1
"#;
        let config: LLMConfig = serde_yaml::from_str(yaml).unwrap();
        assert_eq!(config.provider, "openai-compatible");
        assert_eq!(
            config.base_url,
            Some("http://localhost:1234/v1".to_string())
        );
    }

    #[test]
    fn test_llm_config_with_api_key_env() {
        let yaml = r#"
provider: openai-compatible
model: my-model
base_url: http://my-server:8080/v1
api_key_env: MY_SERVER_KEY
"#;
        let config: LLMConfig = serde_yaml::from_str(yaml).unwrap();
        assert_eq!(config.api_key_env, Some("MY_SERVER_KEY".to_string()));
    }

    // Guards against `#[serde(flatten)]` capturing a key that has a
    // dedicated struct field.
    #[test]
    fn test_llm_config_base_url_does_not_leak_to_extra() {
        let yaml = r#"
provider: openai-compatible
model: my-model
base_url: http://localhost:1234/v1
"#;
        let config: LLMConfig = serde_yaml::from_str(yaml).unwrap();
        assert!(!config.extra.contains_key("base_url"));
    }

    #[test]
    fn test_llm_config_deserialize() {
        let yaml = r#"
provider: openai
model: gpt-3.5-turbo
temperature: 0.5
max_tokens: 1000
"#;
        let config: LLMConfig = serde_yaml::from_str(yaml).unwrap();
        assert_eq!(config.provider, "openai");
        assert_eq!(config.model, "gpt-3.5-turbo");
        assert_eq!(config.temperature, 0.5);
        assert_eq!(config.max_tokens, 1000);
    }

    // Missing temperature/max_tokens must fall back to the serde defaults.
    #[test]
    fn test_llm_config_with_defaults() {
        let yaml = r#"
provider: openai
model: gpt-4
"#;
        let config: LLMConfig = serde_yaml::from_str(yaml).unwrap();
        assert_eq!(config.temperature, 0.7); // default
        assert_eq!(config.max_tokens, 2000); // default
    }

    // Unknown keys are collected into `extra` rather than rejected.
    #[test]
    fn test_llm_config_extra_fields() {
        let yaml = r#"
provider: openai
model: gpt-4
custom_field: "value"
another_field: 123
"#;
        let config: LLMConfig = serde_yaml::from_str(yaml).unwrap();
        assert!(config.extra.contains_key("custom_field"));
        assert!(config.extra.contains_key("another_field"));
    }

    // --- LLMSelector ---

    #[test]
    fn test_llm_selector_default() {
        let selector = LLMSelector::default();
        assert_eq!(selector.default, "default");
        assert!(selector.router.is_none());
    }

    #[test]
    fn test_llm_selector_with_router() {
        let selector = LLMSelector::new("main").with_router("cheap");
        assert_eq!(selector.default, "main");
        assert_eq!(selector.router, Some("cheap".to_string()));
    }

    // --- HITL metadata ---

    #[test]
    fn test_cli_hitl_metadata_deserialize() {
        let yaml = r#"
style: auto_approve
show_context: false
"#;
        let meta: CliHitlMetadata = serde_yaml::from_str(yaml).unwrap();
        assert_eq!(meta.style, Some(CliHitlStyle::AutoApprove));
        assert_eq!(meta.show_context, Some(false));
    }

    #[test]
    fn test_cli_hitl_style_default() {
        assert_eq!(CliHitlStyle::default(), CliHitlStyle::Prompt);
    }

    // HITL table nested under CliMetadata.
    #[test]
    fn test_cli_metadata_with_hitl() {
        let yaml = r#"
welcome: "Hello"
hints: []
hitl:
  style: prompt
  show_context: true
"#;
        let meta: CliMetadata = serde_yaml::from_str(yaml).unwrap();
        let hitl = meta.hitl.unwrap();
        assert_eq!(hitl.style, Some(CliHitlStyle::Prompt));
        assert_eq!(hitl.show_context, Some(true));
    }

    #[test]
    fn test_llm_selector_deserialize() {
        let yaml = r#"
default: main
router: router_llm
"#;
        let selector: LLMSelector = serde_yaml::from_str(yaml).unwrap();
        assert_eq!(selector.default, "main");
        assert_eq!(selector.router, Some("router_llm".to_string()));
    }

    // Reasoning-related fields deserialize into their typed fields and
    // must not leak into the flattened `extra` map.
    #[test]
    fn test_llm_config_reasoning_fields_deser() {
        let yaml = r#"
provider: openai
model: o3
timeout_seconds: 120
reasoning: true
reasoning_effort: high
reasoning_budget_tokens: 16384
"#;
        let config: LLMConfig = serde_yaml::from_str(yaml).unwrap();
        assert_eq!(config.timeout_seconds, Some(120));
        assert_eq!(config.reasoning, Some(true));
        assert_eq!(config.reasoning_effort.as_deref(), Some("high"));
        assert_eq!(config.reasoning_budget_tokens, Some(16384));
        // Must NOT leak into extra
        assert!(!config.extra.contains_key("timeout_seconds"));
        assert!(!config.extra.contains_key("reasoning"));
        assert!(!config.extra.contains_key("reasoning_effort"));
        assert!(!config.extra.contains_key("reasoning_budget_tokens"));
    }
}
373}