1use serde::{Deserialize, Serialize};
2use std::collections::HashMap;
3
4#[derive(Debug, Clone, Serialize, Deserialize)]
6pub struct PunchConfig {
7 pub api_listen: String,
11 #[serde(default)]
13 pub api_key: String,
14 #[serde(default = "default_rate_limit_rpm")]
16 pub rate_limit_rpm: u32,
17 pub default_model: ModelConfig,
19 pub memory: MemoryConfig,
21 #[serde(default)]
23 pub tunnel: Option<TunnelConfig>,
24 #[serde(default)]
26 pub channels: HashMap<String, ChannelConfig>,
27 #[serde(default)]
29 pub mcp_servers: HashMap<String, McpServerConfig>,
30}
31
32#[derive(Debug, Clone, Serialize, Deserialize)]
34pub struct ModelConfig {
35 pub provider: Provider,
37 pub model: String,
39 pub api_key_env: Option<String>,
41 pub base_url: Option<String>,
43 pub max_tokens: Option<u32>,
45 pub temperature: Option<f32>,
47}
48
49#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
51#[serde(rename_all = "snake_case")]
52pub enum Provider {
53 Anthropic,
54 Google,
55 #[serde(rename = "openai")]
56 OpenAI,
57 Groq,
58 DeepSeek,
59 Ollama,
60 Mistral,
61 Together,
62 Fireworks,
63 Cerebras,
64 #[serde(rename = "xai")]
65 XAI,
66 Cohere,
67 Bedrock,
68 #[serde(rename = "azure_openai")]
69 AzureOpenAi,
70 Custom(String),
71}
72
73impl std::fmt::Display for Provider {
74 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
75 match self {
76 Self::Anthropic => write!(f, "anthropic"),
77 Self::Google => write!(f, "google"),
78 Self::OpenAI => write!(f, "openai"),
79 Self::Groq => write!(f, "groq"),
80 Self::DeepSeek => write!(f, "deepseek"),
81 Self::Ollama => write!(f, "ollama"),
82 Self::Mistral => write!(f, "mistral"),
83 Self::Together => write!(f, "together"),
84 Self::Fireworks => write!(f, "fireworks"),
85 Self::Cerebras => write!(f, "cerebras"),
86 Self::XAI => write!(f, "xai"),
87 Self::Cohere => write!(f, "cohere"),
88 Self::Bedrock => write!(f, "bedrock"),
89 Self::AzureOpenAi => write!(f, "azure_openai"),
90 Self::Custom(name) => write!(f, "custom({})", name),
91 }
92 }
93}
94
95#[derive(Debug, Clone, Serialize, Deserialize)]
97pub struct MemoryConfig {
98 pub db_path: String,
100 #[serde(default = "default_true")]
102 pub knowledge_graph_enabled: bool,
103 pub max_entries: Option<u64>,
105}
106
107#[derive(Debug, Clone, Serialize, Deserialize)]
109pub struct TunnelConfig {
110 pub base_url: String,
113 #[serde(default = "default_tunnel_mode")]
115 pub mode: String,
116}
117
/// Serde default for `TunnelConfig::mode`.
fn default_tunnel_mode() -> String {
    String::from("manual")
}
121
122#[derive(Debug, Clone, Serialize, Deserialize)]
124pub struct ChannelConfig {
125 pub channel_type: String,
127 pub token_env: Option<String>,
129 pub webhook_secret_env: Option<String>,
132 #[serde(default)]
135 pub allowed_user_ids: Vec<String>,
136 #[serde(default = "default_channel_rate_limit")]
138 pub rate_limit_per_user: u32,
139 #[serde(default)]
141 pub settings: HashMap<String, serde_json::Value>,
142}
143
/// Serde default for `ChannelConfig::rate_limit_per_user`.
fn default_channel_rate_limit() -> u32 {
    const PER_USER_DEFAULT: u32 = 20;
    PER_USER_DEFAULT
}
147
148#[derive(Debug, Clone, Serialize, Deserialize)]
150pub struct McpServerConfig {
151 pub command: String,
153 #[serde(default)]
155 pub args: Vec<String>,
156 #[serde(default)]
158 pub env: HashMap<String, String>,
159}
160
/// Serde default helper returning `true`.
///
/// `#[serde(default = "...")]` takes a function path, hence a named function
/// rather than a literal.
fn default_true() -> bool {
    let enabled = true;
    enabled
}
164
/// Serde default for `PunchConfig::rate_limit_rpm`.
fn default_rate_limit_rpm() -> u32 {
    const REQUESTS_PER_MINUTE: u32 = 60;
    REQUESTS_PER_MINUTE
}
168
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_provider_display_all_variants() {
        assert_eq!(Provider::Anthropic.to_string(), "anthropic");
        assert_eq!(Provider::Google.to_string(), "google");
        assert_eq!(Provider::OpenAI.to_string(), "openai");
        assert_eq!(Provider::Groq.to_string(), "groq");
        assert_eq!(Provider::DeepSeek.to_string(), "deepseek");
        assert_eq!(Provider::Ollama.to_string(), "ollama");
        assert_eq!(Provider::Mistral.to_string(), "mistral");
        assert_eq!(Provider::Together.to_string(), "together");
        assert_eq!(Provider::Fireworks.to_string(), "fireworks");
        assert_eq!(Provider::Cerebras.to_string(), "cerebras");
        assert_eq!(Provider::XAI.to_string(), "xai");
        assert_eq!(Provider::Cohere.to_string(), "cohere");
        assert_eq!(Provider::Bedrock.to_string(), "bedrock");
        assert_eq!(Provider::AzureOpenAi.to_string(), "azure_openai");
        assert_eq!(
            Provider::Custom("my_provider".to_string()).to_string(),
            "custom(my_provider)"
        );
    }

    #[test]
    fn test_provider_serde_roundtrip() {
        let providers = vec![
            Provider::Anthropic,
            Provider::Google,
            Provider::OpenAI,
            Provider::Groq,
            Provider::DeepSeek,
            Provider::Ollama,
            Provider::Mistral,
            Provider::Together,
            Provider::Fireworks,
            Provider::Cerebras,
            Provider::XAI,
            Provider::Cohere,
            Provider::Bedrock,
            Provider::AzureOpenAi,
            Provider::Custom("test".to_string()),
        ];
        for provider in &providers {
            let json = serde_json::to_string(provider).expect("serialize provider");
            let deser: Provider = serde_json::from_str(&json).expect("deserialize provider");
            assert_eq!(&deser, provider);
        }
    }

    #[test]
    fn test_provider_serde_rename() {
        assert_eq!(
            serde_json::to_string(&Provider::OpenAI).unwrap(),
            "\"openai\""
        );
        assert_eq!(serde_json::to_string(&Provider::XAI).unwrap(), "\"xai\"");
        assert_eq!(
            serde_json::to_string(&Provider::AzureOpenAi).unwrap(),
            "\"azure_openai\""
        );
        assert_eq!(
            serde_json::to_string(&Provider::Anthropic).unwrap(),
            "\"anthropic\""
        );
    }

    #[test]
    fn test_provider_deserialize_renamed_and_custom() {
        // Hand-written config values must map onto the explicitly renamed variants.
        let cases = [
            ("\"openai\"", Provider::OpenAI),
            ("\"xai\"", Provider::XAI),
            ("\"azure_openai\"", Provider::AzureOpenAi),
            ("{\"custom\":\"local\"}", Provider::Custom("local".to_string())),
        ];
        for (raw, expected) in cases {
            let provider: Provider = serde_json::from_str(raw).expect("deserialize provider");
            assert_eq!(provider, expected);
        }
    }

    #[test]
    fn test_provider_equality() {
        assert_eq!(Provider::Anthropic, Provider::Anthropic);
        assert_ne!(Provider::Anthropic, Provider::Google);
        assert_eq!(
            Provider::Custom("x".to_string()),
            Provider::Custom("x".to_string())
        );
        assert_ne!(
            Provider::Custom("x".to_string()),
            Provider::Custom("y".to_string())
        );
    }

    #[test]
    fn test_punch_config_defaults() {
        let json = r#"{
            "api_listen": "127.0.0.1:8080",
            "default_model": {"provider": "anthropic", "model": "claude-sonnet-4-20250514"},
            "memory": {"db_path": "/tmp/punch.db"}
        }"#;
        let config: PunchConfig = serde_json::from_str(json).expect("deserialize");
        assert_eq!(config.api_listen, "127.0.0.1:8080");
        assert!(config.api_key.is_empty());
        assert_eq!(config.rate_limit_rpm, 60);
        assert_eq!(config.default_model.provider, Provider::Anthropic);
        assert!(config.default_model.api_key_env.is_none());
        assert!(config.memory.knowledge_graph_enabled);
        assert!(config.tunnel.is_none());
        assert!(config.channels.is_empty());
        assert!(config.mcp_servers.is_empty());
    }

    #[test]
    fn test_model_config_serde_roundtrip() {
        let config = ModelConfig {
            provider: Provider::Anthropic,
            model: "claude-sonnet-4-20250514".to_string(),
            api_key_env: Some("ANTHROPIC_API_KEY".to_string()),
            base_url: None,
            max_tokens: Some(4096),
            temperature: Some(0.7),
        };
        let json = serde_json::to_string(&config).expect("serialize");
        let deser: ModelConfig = serde_json::from_str(&json).expect("deserialize");
        assert_eq!(deser.model, "claude-sonnet-4-20250514");
        assert_eq!(deser.provider, Provider::Anthropic);
        assert_eq!(deser.max_tokens, Some(4096));
        assert_eq!(deser.temperature, Some(0.7));
    }

    #[test]
    fn test_model_config_optional_fields() {
        let config = ModelConfig {
            provider: Provider::Ollama,
            model: "llama3".to_string(),
            api_key_env: None,
            base_url: Some("http://localhost:11434".to_string()),
            max_tokens: None,
            temperature: None,
        };
        let json = serde_json::to_string(&config).expect("serialize");
        let deser: ModelConfig = serde_json::from_str(&json).expect("deserialize");
        assert!(deser.api_key_env.is_none());
        assert_eq!(deser.base_url, Some("http://localhost:11434".to_string()));
        assert!(deser.max_tokens.is_none());
        assert!(deser.temperature.is_none());
    }

    #[test]
    fn test_memory_config_defaults() {
        let json = r#"{"db_path": "/tmp/punch.db"}"#;
        let config: MemoryConfig = serde_json::from_str(json).expect("deserialize");
        assert_eq!(config.db_path, "/tmp/punch.db");
        assert!(config.knowledge_graph_enabled);
        assert!(config.max_entries.is_none());
    }

    #[test]
    fn test_memory_config_roundtrip() {
        let config = MemoryConfig {
            db_path: "/data/punch.db".to_string(),
            knowledge_graph_enabled: false,
            max_entries: Some(10000),
        };
        let json = serde_json::to_string(&config).expect("serialize");
        let deser: MemoryConfig = serde_json::from_str(&json).expect("deserialize");
        assert_eq!(deser.db_path, "/data/punch.db");
        assert!(!deser.knowledge_graph_enabled);
        assert_eq!(deser.max_entries, Some(10000));
    }

    #[test]
    fn test_channel_config_serde() {
        let config = ChannelConfig {
            channel_type: "slack".to_string(),
            token_env: Some("SLACK_TOKEN".to_string()),
            webhook_secret_env: Some("SLACK_SIGNING_SECRET".to_string()),
            allowed_user_ids: vec!["U123".to_string()],
            rate_limit_per_user: 20,
            settings: HashMap::from([("channel_id".to_string(), serde_json::json!("C123456"))]),
        };
        let json = serde_json::to_string(&config).expect("serialize");
        let deser: ChannelConfig = serde_json::from_str(&json).expect("deserialize");
        assert_eq!(deser.channel_type, "slack");
        assert_eq!(deser.token_env, Some("SLACK_TOKEN".to_string()));
        assert_eq!(deser.settings["channel_id"], "C123456");
    }

    #[test]
    fn test_channel_config_defaults() {
        let json = r#"{"channel_type": "webhook"}"#;
        let config: ChannelConfig = serde_json::from_str(json).expect("deserialize");
        assert_eq!(config.channel_type, "webhook");
        assert!(config.token_env.is_none());
        assert!(config.webhook_secret_env.is_none());
        assert!(config.allowed_user_ids.is_empty());
        assert_eq!(config.rate_limit_per_user, 20);
        assert!(config.settings.is_empty());
    }

    #[test]
    fn test_mcp_server_config_serde() {
        let config = McpServerConfig {
            command: "npx".to_string(),
            args: vec!["-y".to_string(), "@modelcontextprotocol/server".to_string()],
            env: HashMap::from([("NODE_ENV".to_string(), "production".to_string())]),
        };
        let json = serde_json::to_string(&config).expect("serialize");
        let deser: McpServerConfig = serde_json::from_str(&json).expect("deserialize");
        assert_eq!(deser.command, "npx");
        assert_eq!(deser.args.len(), 2);
        assert_eq!(deser.env["NODE_ENV"], "production");
    }

    #[test]
    fn test_mcp_server_config_defaults() {
        let json = r#"{"command": "python"}"#;
        let config: McpServerConfig = serde_json::from_str(json).expect("deserialize");
        assert_eq!(config.command, "python");
        assert!(config.args.is_empty());
        assert!(config.env.is_empty());
    }

    #[test]
    fn test_tunnel_config_serde() {
        let config = TunnelConfig {
            base_url: "https://abc.trycloudflare.com".to_string(),
            mode: "quick".to_string(),
        };
        let json = serde_json::to_string(&config).expect("serialize");
        let deser: TunnelConfig = serde_json::from_str(&json).expect("deserialize");
        assert_eq!(deser.base_url, "https://abc.trycloudflare.com");
        assert_eq!(deser.mode, "quick");
    }

    #[test]
    fn test_tunnel_config_default_mode() {
        let json = r#"{"base_url": "https://example.com"}"#;
        let config: TunnelConfig = serde_json::from_str(json).expect("deserialize");
        assert_eq!(config.mode, "manual");
    }

    #[test]
    fn test_provider_hash() {
        let mut set = std::collections::HashSet::new();
        set.insert(Provider::Anthropic);
        set.insert(Provider::Google);
        set.insert(Provider::Anthropic);
        assert_eq!(set.len(), 2);
    }
}