1use serde::{Deserialize, Serialize};
4use std::collections::HashMap;
5
6#[derive(Debug, Clone, Serialize, Deserialize, Default)]
7pub struct CliMetadata {
8 #[serde(default, skip_serializing_if = "Option::is_none")]
9 pub welcome: Option<String>,
10
11 #[serde(default, skip_serializing_if = "Vec::is_empty")]
12 pub hints: Vec<String>,
13
14 #[serde(default, skip_serializing_if = "Option::is_none")]
15 pub show_tools: Option<bool>,
16
17 #[serde(default, skip_serializing_if = "Option::is_none")]
18 pub show_state: Option<bool>,
19
20 #[serde(default, skip_serializing_if = "Option::is_none")]
21 pub show_timing: Option<bool>,
22
23 #[serde(default, skip_serializing_if = "Option::is_none")]
24 pub streaming: Option<bool>,
25
26 #[serde(default, skip_serializing_if = "Option::is_none")]
27 pub prompt_style: Option<CliPromptStyle>,
28
29 #[serde(default, skip_serializing_if = "Option::is_none")]
30 pub disable_builtin_commands: Option<bool>,
31
32 #[serde(default, skip_serializing_if = "Option::is_none")]
33 pub hitl: Option<CliHitlMetadata>,
34
35 #[serde(default, skip_serializing_if = "Option::is_none")]
36 pub theme: Option<String>,
37}
38
39#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
40#[serde(rename_all = "snake_case")]
41pub enum CliPromptStyle {
42 Simple,
43 WithState,
44}
45
46#[derive(Debug, Clone, Serialize, Deserialize, Default, PartialEq, Eq)]
48#[serde(rename_all = "snake_case")]
49pub enum CliHitlStyle {
50 #[default]
52 Prompt,
53 AutoApprove,
55 AutoReject,
57}
58
59#[derive(Debug, Clone, Serialize, Deserialize, Default)]
61pub struct CliHitlMetadata {
62 #[serde(default, skip_serializing_if = "Option::is_none")]
63 pub style: Option<CliHitlStyle>,
64
65 #[serde(default, skip_serializing_if = "Option::is_none")]
66 pub show_context: Option<bool>,
67}
68
69#[derive(Debug, Clone, Serialize, Deserialize)]
71pub struct LLMConfig {
72 pub provider: String,
73
74 pub model: String,
75
76 #[serde(default = "default_temperature")]
77 pub temperature: f32,
78
79 #[serde(default = "default_max_tokens")]
80 pub max_tokens: u32,
81
82 #[serde(skip_serializing_if = "Option::is_none")]
83 pub top_p: Option<f32>,
84
85 #[serde(default, skip_serializing_if = "Option::is_none")]
88 pub base_url: Option<String>,
89
90 #[serde(default, skip_serializing_if = "Option::is_none")]
93 pub api_key_env: Option<String>,
94
95 #[serde(default, skip_serializing_if = "Option::is_none")]
97 pub timeout_seconds: Option<u64>,
98
99 #[serde(default, skip_serializing_if = "Option::is_none")]
101 pub reasoning: Option<bool>,
102
103 #[serde(default, skip_serializing_if = "Option::is_none")]
105 pub reasoning_effort: Option<String>,
106
107 #[serde(default, skip_serializing_if = "Option::is_none")]
109 pub reasoning_budget_tokens: Option<u32>,
110
111 #[serde(flatten)]
113 pub extra: HashMap<String, serde_json::Value>,
114}
115
/// Fallback sampling temperature used by serde for a missing `temperature`
/// key, and by `LLMConfig::default`.
fn default_temperature() -> f32 {
    const FALLBACK_TEMPERATURE: f32 = 0.7;
    FALLBACK_TEMPERATURE
}
119
/// Fallback token limit used by serde for a missing `max_tokens` key, and by
/// `LLMConfig::default`.
fn default_max_tokens() -> u32 {
    const FALLBACK_MAX_TOKENS: u32 = 2_000;
    FALLBACK_MAX_TOKENS
}
123
124impl Default for LLMConfig {
125 fn default() -> Self {
126 Self {
127 provider: "openai".to_string(),
128 model: "gpt-4".to_string(),
129 temperature: default_temperature(),
130 max_tokens: default_max_tokens(),
131 top_p: None,
132 base_url: None,
133 api_key_env: None,
134 timeout_seconds: None,
135 reasoning: None,
136 reasoning_effort: None,
137 reasoning_budget_tokens: None,
138 extra: HashMap::new(),
139 }
140 }
141}
142
143#[derive(Debug, Clone, Serialize, Deserialize)]
144pub struct LLMSelector {
145 #[serde(default = "default_alias")]
146 pub default: String,
147 #[serde(default)]
148 pub router: Option<String>,
149}
150
/// Alias chosen by `LLMSelector` when no `default` key is provided.
fn default_alias() -> String {
    String::from("default")
}
154
155impl Default for LLMSelector {
156 fn default() -> Self {
157 Self {
158 default: default_alias(),
159 router: None,
160 }
161 }
162}
163
164impl LLMSelector {
165 pub fn new(default: impl Into<String>) -> Self {
166 Self {
167 default: default.into(),
168 router: None,
169 }
170 }
171
172 pub fn with_router(mut self, router: impl Into<String>) -> Self {
173 self.router = Some(router.into());
174 self
175 }
176}
177
#[cfg(test)]
mod tests {
    // Unit tests for the configuration types' serde behavior and constructors.
    use super::*;

    #[test]
    fn test_cli_metadata_deserialize() {
        let yaml = r#"
welcome: "=== Demo ==="
hints:
  - "Try: hello"
  - "Try: help"
show_tools: true
show_state: false
show_timing: true
streaming: true
prompt_style: with_state
disable_builtin_commands: false
"#;
        let metadata: CliMetadata = serde_yaml::from_str(yaml).unwrap();
        assert_eq!(metadata.welcome.as_deref(), Some("=== Demo ==="));
        assert_eq!(metadata.hints.len(), 2);
        assert_eq!(metadata.show_tools, Some(true));
        assert_eq!(metadata.show_state, Some(false));
        assert_eq!(metadata.show_timing, Some(true));
        assert_eq!(metadata.streaming, Some(true));
        assert_eq!(metadata.prompt_style, Some(CliPromptStyle::WithState));
        assert_eq!(metadata.disable_builtin_commands, Some(false));
        assert!(metadata.hitl.is_none());
    }

    #[test]
    fn test_cli_metadata_empty_document() {
        let meta: CliMetadata = serde_yaml::from_str("{}").unwrap();
        assert!(meta.welcome.is_none());
        assert!(meta.hints.is_empty());
        assert!(meta.show_tools.is_none());
        assert!(meta.prompt_style.is_none());
        assert!(meta.hitl.is_none());
        assert!(meta.theme.is_none());
    }

    #[test]
    fn test_cli_metadata_default_serializes_to_empty_map() {
        // Every field is skipped when unset, so the default value is an empty map.
        let value = serde_json::to_value(CliMetadata::default()).unwrap();
        assert_eq!(value, serde_json::json!({}));
    }

    #[test]
    fn test_llm_config_default() {
        let config = LLMConfig::default();
        assert_eq!(config.provider, "openai");
        assert_eq!(config.model, "gpt-4");
        assert_eq!(config.temperature, 0.7);
        assert_eq!(config.max_tokens, 2000);
        assert_eq!(config.base_url, None);
        assert_eq!(config.api_key_env, None);
    }

    #[test]
    fn test_llm_config_with_base_url() {
        let yaml = r#"
provider: openai-compatible
model: llama3.2
base_url: http://localhost:1234/v1
"#;
        let config: LLMConfig = serde_yaml::from_str(yaml).unwrap();
        assert_eq!(config.provider, "openai-compatible");
        assert_eq!(
            config.base_url,
            Some("http://localhost:1234/v1".to_string())
        );
    }

    #[test]
    fn test_llm_config_with_api_key_env() {
        let yaml = r#"
provider: openai-compatible
model: my-model
base_url: http://my-server:8080/v1
api_key_env: MY_SERVER_KEY
"#;
        let config: LLMConfig = serde_yaml::from_str(yaml).unwrap();
        assert_eq!(config.api_key_env, Some("MY_SERVER_KEY".to_string()));
    }

    #[test]
    fn test_llm_config_base_url_does_not_leak_to_extra() {
        let yaml = r#"
provider: openai-compatible
model: my-model
base_url: http://localhost:1234/v1
"#;
        let config: LLMConfig = serde_yaml::from_str(yaml).unwrap();
        assert!(!config.extra.contains_key("base_url"));
    }

    #[test]
    fn test_llm_config_deserialize() {
        let yaml = r#"
provider: openai
model: gpt-3.5-turbo
temperature: 0.5
max_tokens: 1000
"#;
        let config: LLMConfig = serde_yaml::from_str(yaml).unwrap();
        assert_eq!(config.provider, "openai");
        assert_eq!(config.model, "gpt-3.5-turbo");
        assert_eq!(config.temperature, 0.5);
        assert_eq!(config.max_tokens, 1000);
    }

    #[test]
    fn test_llm_config_with_defaults() {
        let yaml = r#"
provider: openai
model: gpt-4
"#;
        let config: LLMConfig = serde_yaml::from_str(yaml).unwrap();
        assert_eq!(config.temperature, 0.7);
        assert_eq!(config.max_tokens, 2000);
    }

    #[test]
    fn test_llm_config_top_p_optional() {
        let without_top_p = r#"
provider: openai
model: gpt-4
"#;
        let config: LLMConfig = serde_yaml::from_str(without_top_p).unwrap();
        assert_eq!(config.top_p, None);

        // 0.25 is exactly representable, so the f32 comparison is exact.
        let with_top_p = r#"
provider: openai
model: gpt-4
top_p: 0.25
"#;
        let config: LLMConfig = serde_yaml::from_str(with_top_p).unwrap();
        assert_eq!(config.top_p, Some(0.25));
        assert!(!config.extra.contains_key("top_p"));
    }

    #[test]
    fn test_llm_config_extra_fields() {
        let yaml = r#"
provider: openai
model: gpt-4
custom_field: "value"
another_field: 123
"#;
        let config: LLMConfig = serde_yaml::from_str(yaml).unwrap();
        assert!(config.extra.contains_key("custom_field"));
        assert!(config.extra.contains_key("another_field"));
    }

    #[test]
    fn test_llm_config_extra_field_values() {
        let yaml = r#"
provider: openai
model: gpt-4
custom_field: "value"
another_field: 123
"#;
        let config: LLMConfig = serde_yaml::from_str(yaml).unwrap();
        assert_eq!(
            config.extra.get("custom_field").and_then(|v| v.as_str()),
            Some("value")
        );
        assert_eq!(
            config.extra.get("another_field").and_then(|v| v.as_i64()),
            Some(123)
        );
        // Named fields must never be duplicated into the flattened map.
        assert!(!config.extra.contains_key("provider"));
        assert!(!config.extra.contains_key("model"));
    }

    #[test]
    fn test_llm_selector_default() {
        let selector = LLMSelector::default();
        assert_eq!(selector.default, "default");
        assert!(selector.router.is_none());
    }

    #[test]
    fn test_llm_selector_with_router() {
        let selector = LLMSelector::new("main").with_router("cheap");
        assert_eq!(selector.default, "main");
        assert_eq!(selector.router, Some("cheap".to_string()));
    }

    #[test]
    fn test_cli_hitl_metadata_deserialize() {
        let yaml = r#"
style: auto_approve
show_context: false
"#;
        let meta: CliHitlMetadata = serde_yaml::from_str(yaml).unwrap();
        assert_eq!(meta.style, Some(CliHitlStyle::AutoApprove));
        assert_eq!(meta.show_context, Some(false));
    }

    #[test]
    fn test_cli_hitl_style_default() {
        assert_eq!(CliHitlStyle::default(), CliHitlStyle::Prompt);
    }

    #[test]
    fn test_cli_enums_remaining_variants() {
        let meta: CliMetadata = serde_yaml::from_str("prompt_style: simple\n").unwrap();
        assert_eq!(meta.prompt_style, Some(CliPromptStyle::Simple));

        let hitl: CliHitlMetadata = serde_yaml::from_str("style: auto_reject\n").unwrap();
        assert_eq!(hitl.style, Some(CliHitlStyle::AutoReject));
        assert!(hitl.show_context.is_none());
    }

    #[test]
    fn test_cli_metadata_with_hitl() {
        let yaml = r#"
welcome: "Hello"
hints: []
hitl:
  style: prompt
  show_context: true
"#;
        let meta: CliMetadata = serde_yaml::from_str(yaml).unwrap();
        let hitl = meta.hitl.unwrap();
        assert_eq!(hitl.style, Some(CliHitlStyle::Prompt));
        assert_eq!(hitl.show_context, Some(true));
    }

    #[test]
    fn test_llm_selector_deserialize() {
        let yaml = r#"
default: main
router: router_llm
"#;
        let selector: LLMSelector = serde_yaml::from_str(yaml).unwrap();
        assert_eq!(selector.default, "main");
        assert_eq!(selector.router, Some("router_llm".to_string()));
    }

    #[test]
    fn test_llm_selector_empty_document() {
        let selector: LLMSelector = serde_yaml::from_str("{}").unwrap();
        assert_eq!(selector.default, "default");
        assert!(selector.router.is_none());
    }

    #[test]
    fn test_llm_config_reasoning_fields_deser() {
        let yaml = r#"
provider: openai
model: o3
timeout_seconds: 120
reasoning: true
reasoning_effort: high
reasoning_budget_tokens: 16384
"#;
        let config: LLMConfig = serde_yaml::from_str(yaml).unwrap();
        assert_eq!(config.timeout_seconds, Some(120));
        assert_eq!(config.reasoning, Some(true));
        assert_eq!(config.reasoning_effort.as_deref(), Some("high"));
        assert_eq!(config.reasoning_budget_tokens, Some(16384));
        assert!(!config.extra.contains_key("timeout_seconds"));
        assert!(!config.extra.contains_key("reasoning"));
        assert!(!config.extra.contains_key("reasoning_effort"));
        assert!(!config.extra.contains_key("reasoning_budget_tokens"));
    }
}