1use serde::{Deserialize, Serialize};
2use std::collections::HashMap;
3
4#[derive(Debug, Clone, Serialize, Deserialize)]
6pub struct PunchConfig {
7 pub api_listen: String,
11 #[serde(default)]
13 pub api_key: String,
14 #[serde(default = "default_rate_limit_rpm")]
16 pub rate_limit_rpm: u32,
17 pub default_model: ModelConfig,
19 pub memory: MemoryConfig,
21 #[serde(default)]
23 pub tunnel: Option<TunnelConfig>,
24 #[serde(default)]
26 pub channels: HashMap<String, ChannelConfig>,
27 #[serde(default)]
29 pub mcp_servers: HashMap<String, McpServerConfig>,
30 #[serde(default)]
33 pub model_routing: ModelRoutingConfig,
34 #[serde(default)]
37 pub budget: BudgetConfig,
38}
39
/// A provider plus model name, with optional per-model request settings.
///
/// Every `Option` field deserializes to `None` when absent.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModelConfig {
    /// Backend that serves this model.
    pub provider: Provider,
    /// Model identifier, e.g. `claude-sonnet-4-20250514` or `llama3` in this file's tests.
    pub model: String,
    /// Name of an environment variable expected to hold the API key
    /// (e.g. `ANTHROPIC_API_KEY`). NOTE(review): the variable is read elsewhere, not here.
    pub api_key_env: Option<String>,
    /// Optional endpoint override (e.g. `http://localhost:11434` for Ollama in tests).
    pub base_url: Option<String>,
    /// Optional output token cap; how `None` is resolved is decided outside this file.
    pub max_tokens: Option<u32>,
    /// Optional sampling temperature; how `None` is resolved is decided outside this file.
    pub temperature: Option<f32>,
}
56
57#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
59#[serde(rename_all = "snake_case")]
60pub enum Provider {
61 Anthropic,
62 Google,
63 #[serde(rename = "openai")]
64 OpenAI,
65 Groq,
66 DeepSeek,
67 Ollama,
68 Mistral,
69 Together,
70 Fireworks,
71 Cerebras,
72 #[serde(rename = "xai")]
73 XAI,
74 Cohere,
75 Bedrock,
76 #[serde(rename = "azure_openai")]
77 AzureOpenAi,
78 Custom(String),
79}
80
81impl std::fmt::Display for Provider {
82 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
83 match self {
84 Self::Anthropic => write!(f, "anthropic"),
85 Self::Google => write!(f, "google"),
86 Self::OpenAI => write!(f, "openai"),
87 Self::Groq => write!(f, "groq"),
88 Self::DeepSeek => write!(f, "deepseek"),
89 Self::Ollama => write!(f, "ollama"),
90 Self::Mistral => write!(f, "mistral"),
91 Self::Together => write!(f, "together"),
92 Self::Fireworks => write!(f, "fireworks"),
93 Self::Cerebras => write!(f, "cerebras"),
94 Self::XAI => write!(f, "xai"),
95 Self::Cohere => write!(f, "cohere"),
96 Self::Bedrock => write!(f, "bedrock"),
97 Self::AzureOpenAi => write!(f, "azure_openai"),
98 Self::Custom(name) => write!(f, "custom({})", name),
99 }
100 }
101}
102
/// Memory / persistence settings.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MemoryConfig {
    /// Path to the database file (e.g. `/tmp/punch.db` in this file's tests).
    pub db_path: String,
    /// Knowledge-graph toggle; defaults to `true` when absent (see `default_true`).
    #[serde(default = "default_true")]
    pub knowledge_graph_enabled: bool,
    /// Optional entry cap; `None` when absent.
    /// NOTE(review): presumably `None` means "no cap" — confirm with the consumer.
    pub max_entries: Option<u64>,
}
114
/// Tunnel settings.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TunnelConfig {
    /// Base URL for the tunnel (e.g. `https://abc.trycloudflare.com` in tests).
    pub base_url: String,
    /// Tunnel mode; defaults to `"manual"` (see `default_tunnel_mode`).
    /// Free-form string: tests also use `"quick"`, and no value is validated here.
    #[serde(default = "default_tunnel_mode")]
    pub mode: String,
}
125
/// Serde default for `TunnelConfig::mode`.
fn default_tunnel_mode() -> String {
    String::from("manual")
}
129
/// Configuration for one messaging channel.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ChannelConfig {
    /// Free-form channel kind (tests use `"slack"` and `"webhook"`).
    pub channel_type: String,
    /// Name of an environment variable expected to hold the channel token
    /// (e.g. `SLACK_TOKEN`); `None` when absent.
    pub token_env: Option<String>,
    /// Name of an environment variable expected to hold the webhook signing
    /// secret (e.g. `SLACK_SIGNING_SECRET`); `None` when absent.
    pub webhook_secret_env: Option<String>,
    /// User IDs allowed on this channel; empty when absent.
    /// NOTE(review): what an empty list means (allow all vs. deny all) is
    /// decided by the consumer — confirm.
    #[serde(default)]
    pub allowed_user_ids: Vec<String>,
    /// Per-user rate limit; defaults to 20 (see `default_channel_rate_limit`).
    /// NOTE(review): the time window is not defined in this file.
    #[serde(default = "default_channel_rate_limit")]
    pub rate_limit_per_user: u32,
    /// Channel-specific extra settings as arbitrary JSON values; empty when absent.
    #[serde(default)]
    pub settings: HashMap<String, serde_json::Value>,
}
151
/// Serde default for `ChannelConfig::rate_limit_per_user`.
fn default_channel_rate_limit() -> u32 {
    const PER_USER_LIMIT: u32 = 20;
    PER_USER_LIMIT
}
155
/// Definition of one MCP server.
/// NOTE(review): presumably launched as a subprocess from `command`/`args`/`env`;
/// the launching code is not in this file.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct McpServerConfig {
    /// Program to run (e.g. `npx` or `python` in tests).
    pub command: String,
    /// Arguments for `command`; empty when absent.
    #[serde(default)]
    pub args: Vec<String>,
    /// Extra environment variables as name -> value; empty when absent.
    #[serde(default)]
    pub env: HashMap<String, String>,
}
168
/// Serde helper for boolean fields that should default to enabled.
fn default_true() -> bool {
    const ENABLED: bool = true;
    ENABLED
}
172
/// Serde default for `PunchConfig::rate_limit_rpm`.
fn default_rate_limit_rpm() -> u32 {
    const RPM: u32 = 60;
    RPM
}
176
/// Optional tiered model routing.
///
/// `Default` gives `enabled = false` and all tiers `None`. The same values
/// result from deserializing an empty object.
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct ModelRoutingConfig {
    /// Routing toggle; `false` when absent.
    #[serde(default)]
    pub enabled: bool,
    /// Model for the "cheap" tier; `None` when absent.
    pub cheap: Option<ModelConfig>,
    /// Model for the "mid" tier; `None` when absent.
    pub mid: Option<ModelConfig>,
    /// Model for the "expensive" tier; `None` when absent.
    /// NOTE(review): how requests are assigned to tiers is decided elsewhere.
    pub expensive: Option<ModelConfig>,
}
195
/// Spending limits. Deserializing `{}` yields the same values as `BudgetConfig::default()`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BudgetConfig {
    /// Optional daily cost cap (in USD, as the field name states); `None` when absent.
    pub daily_cost_limit_usd: Option<f64>,
    /// Optional monthly cost cap (in USD, as the field name states); `None` when absent.
    pub monthly_cost_limit_usd: Option<f64>,
    /// Percentage threshold for eco mode; defaults to 80 (see `default_eco_threshold`).
    /// Stored as `u8`, so values above 100 are not rejected here.
    #[serde(default = "default_eco_threshold")]
    pub eco_mode_threshold_percent: u8,
}
212
/// Default for `BudgetConfig::eco_mode_threshold_percent`, used both by serde
/// and by `BudgetConfig::default`.
fn default_eco_threshold() -> u8 {
    const ECO_PERCENT: u8 = 80;
    ECO_PERCENT
}
216
217impl Default for BudgetConfig {
218 fn default() -> Self {
219 Self {
220 daily_cost_limit_usd: None,
221 monthly_cost_limit_usd: None,
222 eco_mode_threshold_percent: default_eco_threshold(),
223 }
224 }
225}
226
227impl BudgetConfig {
228 pub fn has_any_limit(&self) -> bool {
230 self.daily_cost_limit_usd.is_some() || self.monthly_cost_limit_usd.is_some()
231 }
232}
233
#[cfg(test)]
mod tests {
    //! Unit tests for config (de)serialization, serde defaults and `Display`.

    use super::*;

    #[test]
    fn test_provider_display_all_variants() {
        assert_eq!(Provider::Anthropic.to_string(), "anthropic");
        assert_eq!(Provider::Google.to_string(), "google");
        assert_eq!(Provider::OpenAI.to_string(), "openai");
        assert_eq!(Provider::Groq.to_string(), "groq");
        assert_eq!(Provider::DeepSeek.to_string(), "deepseek");
        assert_eq!(Provider::Ollama.to_string(), "ollama");
        assert_eq!(Provider::Mistral.to_string(), "mistral");
        assert_eq!(Provider::Together.to_string(), "together");
        assert_eq!(Provider::Fireworks.to_string(), "fireworks");
        assert_eq!(Provider::Cerebras.to_string(), "cerebras");
        assert_eq!(Provider::XAI.to_string(), "xai");
        assert_eq!(Provider::Cohere.to_string(), "cohere");
        assert_eq!(Provider::Bedrock.to_string(), "bedrock");
        assert_eq!(Provider::AzureOpenAi.to_string(), "azure_openai");
        assert_eq!(
            Provider::Custom("my_provider".to_string()).to_string(),
            "custom(my_provider)"
        );
    }

    #[test]
    fn test_provider_serde_roundtrip() {
        let providers = vec![
            Provider::Anthropic,
            Provider::Google,
            Provider::OpenAI,
            Provider::Groq,
            Provider::DeepSeek,
            Provider::Ollama,
            Provider::Mistral,
            Provider::Together,
            Provider::Fireworks,
            Provider::Cerebras,
            Provider::XAI,
            Provider::Cohere,
            Provider::Bedrock,
            Provider::AzureOpenAi,
            Provider::Custom("test".to_string()),
        ];
        for provider in &providers {
            let json = serde_json::to_string(provider).expect("serialize provider");
            let deser: Provider = serde_json::from_str(&json).expect("deserialize provider");
            assert_eq!(&deser, provider);
        }
    }

    #[test]
    fn test_provider_serde_rename() {
        assert_eq!(
            serde_json::to_string(&Provider::OpenAI).unwrap(),
            "\"openai\""
        );
        assert_eq!(serde_json::to_string(&Provider::XAI).unwrap(), "\"xai\"");
        assert_eq!(
            serde_json::to_string(&Provider::AzureOpenAi).unwrap(),
            "\"azure_openai\""
        );
        assert_eq!(
            serde_json::to_string(&Provider::Anthropic).unwrap(),
            "\"anthropic\""
        );
    }

    #[test]
    fn test_provider_deserialize_renamed_and_custom() {
        // Config files spell providers the way `Display` does; make sure the
        // explicitly renamed variants parse from those spellings.
        let openai: Provider = serde_json::from_str("\"openai\"").expect("deserialize openai");
        assert_eq!(openai, Provider::OpenAI);
        let xai: Provider = serde_json::from_str("\"xai\"").expect("deserialize xai");
        assert_eq!(xai, Provider::XAI);
        let azure: Provider =
            serde_json::from_str("\"azure_openai\"").expect("deserialize azure_openai");
        assert_eq!(azure, Provider::AzureOpenAi);
        // `Custom` is a newtype variant, so it uses serde's externally tagged form.
        let custom: Provider =
            serde_json::from_str(r#"{"custom": "acme"}"#).expect("deserialize custom");
        assert_eq!(custom, Provider::Custom("acme".to_string()));
    }

    #[test]
    fn test_provider_equality() {
        assert_eq!(Provider::Anthropic, Provider::Anthropic);
        assert_ne!(Provider::Anthropic, Provider::Google);
        assert_eq!(
            Provider::Custom("x".to_string()),
            Provider::Custom("x".to_string())
        );
        assert_ne!(
            Provider::Custom("x".to_string()),
            Provider::Custom("y".to_string())
        );
    }

    #[test]
    fn test_model_config_serde_roundtrip() {
        let config = ModelConfig {
            provider: Provider::Anthropic,
            model: "claude-sonnet-4-20250514".to_string(),
            api_key_env: Some("ANTHROPIC_API_KEY".to_string()),
            base_url: None,
            max_tokens: Some(4096),
            temperature: Some(0.7),
        };
        let json = serde_json::to_string(&config).expect("serialize");
        let deser: ModelConfig = serde_json::from_str(&json).expect("deserialize");
        assert_eq!(deser.model, "claude-sonnet-4-20250514");
        assert_eq!(deser.provider, Provider::Anthropic);
        assert_eq!(deser.max_tokens, Some(4096));
        assert_eq!(deser.temperature, Some(0.7));
    }

    #[test]
    fn test_model_config_optional_fields() {
        let config = ModelConfig {
            provider: Provider::Ollama,
            model: "llama3".to_string(),
            api_key_env: None,
            base_url: Some("http://localhost:11434".to_string()),
            max_tokens: None,
            temperature: None,
        };
        let json = serde_json::to_string(&config).expect("serialize");
        let deser: ModelConfig = serde_json::from_str(&json).expect("deserialize");
        assert!(deser.api_key_env.is_none());
        assert_eq!(deser.base_url, Some("http://localhost:11434".to_string()));
        assert!(deser.max_tokens.is_none());
        assert!(deser.temperature.is_none());
    }

    #[test]
    fn test_memory_config_defaults() {
        let json = r#"{"db_path": "/tmp/punch.db"}"#;
        let config: MemoryConfig = serde_json::from_str(json).expect("deserialize");
        assert_eq!(config.db_path, "/tmp/punch.db");
        assert!(config.knowledge_graph_enabled);
        assert!(config.max_entries.is_none());
    }

    #[test]
    fn test_memory_config_roundtrip() {
        let config = MemoryConfig {
            db_path: "/data/punch.db".to_string(),
            knowledge_graph_enabled: false,
            max_entries: Some(10000),
        };
        let json = serde_json::to_string(&config).expect("serialize");
        let deser: MemoryConfig = serde_json::from_str(&json).expect("deserialize");
        assert_eq!(deser.db_path, "/data/punch.db");
        assert!(!deser.knowledge_graph_enabled);
        assert_eq!(deser.max_entries, Some(10000));
    }

    #[test]
    fn test_channel_config_serde() {
        let config = ChannelConfig {
            channel_type: "slack".to_string(),
            token_env: Some("SLACK_TOKEN".to_string()),
            webhook_secret_env: Some("SLACK_SIGNING_SECRET".to_string()),
            allowed_user_ids: vec!["U123".to_string()],
            rate_limit_per_user: 20,
            settings: HashMap::from([("channel_id".to_string(), serde_json::json!("C123456"))]),
        };
        let json = serde_json::to_string(&config).expect("serialize");
        let deser: ChannelConfig = serde_json::from_str(&json).expect("deserialize");
        assert_eq!(deser.channel_type, "slack");
        assert_eq!(deser.token_env, Some("SLACK_TOKEN".to_string()));
        assert_eq!(deser.settings["channel_id"], "C123456");
    }

    #[test]
    fn test_channel_config_defaults() {
        let json = r#"{"channel_type": "webhook"}"#;
        let config: ChannelConfig = serde_json::from_str(json).expect("deserialize");
        assert_eq!(config.channel_type, "webhook");
        assert!(config.token_env.is_none());
        assert!(config.webhook_secret_env.is_none());
        assert!(config.allowed_user_ids.is_empty());
        // Pins the `default_channel_rate_limit` serde default.
        assert_eq!(config.rate_limit_per_user, 20);
        assert!(config.settings.is_empty());
    }

    #[test]
    fn test_mcp_server_config_serde() {
        let config = McpServerConfig {
            command: "npx".to_string(),
            args: vec!["-y".to_string(), "@modelcontextprotocol/server".to_string()],
            env: HashMap::from([("NODE_ENV".to_string(), "production".to_string())]),
        };
        let json = serde_json::to_string(&config).expect("serialize");
        let deser: McpServerConfig = serde_json::from_str(&json).expect("deserialize");
        assert_eq!(deser.command, "npx");
        assert_eq!(deser.args.len(), 2);
        assert_eq!(deser.env["NODE_ENV"], "production");
    }

    #[test]
    fn test_mcp_server_config_defaults() {
        let json = r#"{"command": "python"}"#;
        let config: McpServerConfig = serde_json::from_str(json).expect("deserialize");
        assert_eq!(config.command, "python");
        assert!(config.args.is_empty());
        assert!(config.env.is_empty());
    }

    #[test]
    fn test_tunnel_config_serde() {
        let config = TunnelConfig {
            base_url: "https://abc.trycloudflare.com".to_string(),
            mode: "quick".to_string(),
        };
        let json = serde_json::to_string(&config).expect("serialize");
        let deser: TunnelConfig = serde_json::from_str(&json).expect("deserialize");
        assert_eq!(deser.base_url, "https://abc.trycloudflare.com");
        assert_eq!(deser.mode, "quick");
    }

    #[test]
    fn test_tunnel_config_default_mode() {
        let json = r#"{"base_url": "https://example.com"}"#;
        let config: TunnelConfig = serde_json::from_str(json).expect("deserialize");
        assert_eq!(config.mode, "manual");
    }

    #[test]
    fn test_provider_hash() {
        let mut set = std::collections::HashSet::new();
        set.insert(Provider::Anthropic);
        set.insert(Provider::Google);
        // Re-inserting an equal value must not grow the set.
        set.insert(Provider::Anthropic);
        assert_eq!(set.len(), 2);
    }

    #[test]
    fn test_model_routing_config_defaults() {
        let config = ModelRoutingConfig::default();
        assert!(!config.enabled);
        assert!(config.cheap.is_none());
        assert!(config.mid.is_none());
        assert!(config.expensive.is_none());

        // An empty object must deserialize to the same values as `Default`.
        let from_json: ModelRoutingConfig = serde_json::from_str("{}").expect("deserialize");
        assert!(!from_json.enabled);
        assert!(from_json.cheap.is_none());
        assert!(from_json.mid.is_none());
        assert!(from_json.expensive.is_none());
    }

    #[test]
    fn test_punch_config_minimal_uses_defaults() {
        // Only the required fields are present; everything else must default.
        let json = r#"{
            "api_listen": "127.0.0.1:6660",
            "default_model": {"provider": "anthropic", "model": "claude-sonnet-4-20250514"},
            "memory": {"db_path": "/tmp/punch.db"}
        }"#;
        let config: PunchConfig = serde_json::from_str(json).expect("deserialize");
        assert_eq!(config.api_listen, "127.0.0.1:6660");
        assert!(config.api_key.is_empty());
        assert_eq!(config.rate_limit_rpm, 60);
        assert_eq!(config.default_model.provider, Provider::Anthropic);
        assert!(config.memory.knowledge_graph_enabled);
        assert!(config.tunnel.is_none());
        assert!(config.channels.is_empty());
        assert!(config.mcp_servers.is_empty());
        assert!(!config.model_routing.enabled);
        assert!(!config.budget.has_any_limit());
        assert_eq!(config.budget.eco_mode_threshold_percent, 80);
    }

    #[test]
    fn test_budget_config_defaults() {
        let config = BudgetConfig::default();
        assert!(!config.has_any_limit());
        assert!(config.daily_cost_limit_usd.is_none());
        assert!(config.monthly_cost_limit_usd.is_none());
        assert_eq!(config.eco_mode_threshold_percent, 80);
    }

    #[test]
    fn test_budget_config_has_limit() {
        let config = BudgetConfig {
            daily_cost_limit_usd: Some(1.0),
            ..Default::default()
        };
        assert!(config.has_any_limit());

        // A monthly-only limit must count as well.
        let monthly_only = BudgetConfig {
            monthly_cost_limit_usd: Some(10.0),
            ..Default::default()
        };
        assert!(monthly_only.has_any_limit());
    }

    #[test]
    fn test_budget_config_serde() {
        let json = r#"{"daily_cost_limit_usd": 5.0, "monthly_cost_limit_usd": 15.0, "eco_mode_threshold_percent": 70}"#;
        let config: BudgetConfig = serde_json::from_str(json).expect("deserialize");
        assert_eq!(config.daily_cost_limit_usd, Some(5.0));
        assert_eq!(config.monthly_cost_limit_usd, Some(15.0));
        assert_eq!(config.eco_mode_threshold_percent, 70);
        assert!(config.has_any_limit());
    }

    #[test]
    fn test_budget_config_serde_empty() {
        let json = "{}";
        let config: BudgetConfig = serde_json::from_str(json).expect("deserialize");
        assert!(!config.has_any_limit());
        assert_eq!(config.eco_mode_threshold_percent, 80);
    }
}