1use serde::{Deserialize, Serialize};
2use std::collections::HashMap;
3
/// Top-level application configuration, deserialized from the config file.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PunchConfig {
    /// Address the HTTP API binds to (e.g. "127.0.0.1:8080"). Required.
    pub api_listen: String,
    /// API key clients must present; defaults to "".
    /// NOTE(review): presumably an empty key disables auth — confirm against
    /// the API layer, which is not visible in this file.
    #[serde(default)]
    pub api_key: String,
    /// Requests-per-minute rate limit; defaults to 60
    /// (see `default_rate_limit_rpm`).
    #[serde(default = "default_rate_limit_rpm")]
    pub rate_limit_rpm: u32,
    /// Model used when no per-channel override applies. Required.
    pub default_model: ModelConfig,
    /// Persistent memory / knowledge-graph settings. Required.
    pub memory: MemoryConfig,
    /// Channel configurations keyed by channel name; empty map when omitted.
    #[serde(default)]
    pub channels: HashMap<String, ChannelConfig>,
    /// MCP server launch specs keyed by server name; empty map when omitted.
    #[serde(default)]
    pub mcp_servers: HashMap<String, McpServerConfig>,
}
26
/// Configuration for a single LLM endpoint.
///
/// All `Option` fields deserialize to `None` when absent (serde's built-in
/// handling for `Option`), so only `provider` and `model` are required.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModelConfig {
    /// Backend provider to route requests to.
    pub provider: Provider,
    /// Provider-specific model identifier (e.g. "claude-sonnet-4-20250514").
    pub model: String,
    /// Name of the environment variable holding the API key, if any.
    pub api_key_env: Option<String>,
    /// Override for the provider's base URL (e.g. a local Ollama endpoint).
    pub base_url: Option<String>,
    /// Maximum tokens to generate; provider default when `None`.
    pub max_tokens: Option<u32>,
    /// Sampling temperature; provider default when `None`.
    pub temperature: Option<f32>,
}
43
/// Supported LLM providers.
///
/// Serialized in snake_case via `rename_all`, with explicit renames where
/// serde's automatic conversion would split acronyms (`OpenAI` -> "open_a_i",
/// `XAI` -> "x_a_i", `AzureOpenAi` -> "azure_open_ai").
///
/// NOTE(review): `DeepSeek` carries no explicit rename, so serde emits
/// "deep_seek" on the wire while `Display` prints "deepseek" — confirm whether
/// this divergence is intended before changing either side (both are pinned
/// by tests below).
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum Provider {
    Anthropic,
    Google,
    #[serde(rename = "openai")]
    OpenAI,
    Groq,
    DeepSeek,
    Ollama,
    Mistral,
    Together,
    Fireworks,
    Cerebras,
    #[serde(rename = "xai")]
    XAI,
    Cohere,
    Bedrock,
    #[serde(rename = "azure_openai")]
    AzureOpenAi,
    /// Escape hatch for providers not listed above; serializes as
    /// externally tagged JSON: `{"custom": "<name>"}`.
    Custom(String),
}
67
68impl std::fmt::Display for Provider {
69 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
70 match self {
71 Self::Anthropic => write!(f, "anthropic"),
72 Self::Google => write!(f, "google"),
73 Self::OpenAI => write!(f, "openai"),
74 Self::Groq => write!(f, "groq"),
75 Self::DeepSeek => write!(f, "deepseek"),
76 Self::Ollama => write!(f, "ollama"),
77 Self::Mistral => write!(f, "mistral"),
78 Self::Together => write!(f, "together"),
79 Self::Fireworks => write!(f, "fireworks"),
80 Self::Cerebras => write!(f, "cerebras"),
81 Self::XAI => write!(f, "xai"),
82 Self::Cohere => write!(f, "cohere"),
83 Self::Bedrock => write!(f, "bedrock"),
84 Self::AzureOpenAi => write!(f, "azure_openai"),
85 Self::Custom(name) => write!(f, "custom({})", name),
86 }
87 }
88}
89
/// Persistent memory subsystem settings.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MemoryConfig {
    /// Filesystem path of the memory database. Required.
    pub db_path: String,
    /// Whether the knowledge graph is active; defaults to true
    /// (see `default_true`).
    #[serde(default = "default_true")]
    pub knowledge_graph_enabled: bool,
    /// Optional cap on stored entries; unlimited when `None`.
    pub max_entries: Option<u64>,
}
101
/// Configuration for one communication channel (e.g. "slack", "webhook").
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ChannelConfig {
    /// Channel kind as a free-form string.
    /// NOTE(review): stringly-typed — presumably matched against known kinds
    /// at channel construction; confirm where validation happens.
    pub channel_type: String,
    /// Name of the environment variable holding the channel token, if any.
    pub token_env: Option<String>,
    /// Channel-specific settings as arbitrary JSON; empty map when omitted.
    #[serde(default)]
    pub settings: HashMap<String, serde_json::Value>,
}
113
/// Launch specification for an MCP (Model Context Protocol) server process.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct McpServerConfig {
    /// Executable to spawn (e.g. "npx", "python"). Required.
    pub command: String,
    /// Command-line arguments; empty when omitted.
    #[serde(default)]
    pub args: Vec<String>,
    /// Extra environment variables for the child process; empty when omitted.
    #[serde(default)]
    pub env: HashMap<String, String>,
}
126
/// Serde default helper: `MemoryConfig::knowledge_graph_enabled` is on
/// unless the config explicitly disables it.
fn default_true() -> bool {
    true
}
130
/// Serde default helper for `PunchConfig::rate_limit_rpm`.
fn default_rate_limit_rpm() -> u32 {
    // Named so the magic number is self-describing at the definition site.
    const DEFAULT_RPM: u32 = 60;
    DEFAULT_RPM
}
134
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_provider_display_all_variants() {
        assert_eq!(Provider::Anthropic.to_string(), "anthropic");
        assert_eq!(Provider::Google.to_string(), "google");
        assert_eq!(Provider::OpenAI.to_string(), "openai");
        assert_eq!(Provider::Groq.to_string(), "groq");
        assert_eq!(Provider::DeepSeek.to_string(), "deepseek");
        assert_eq!(Provider::Ollama.to_string(), "ollama");
        assert_eq!(Provider::Mistral.to_string(), "mistral");
        assert_eq!(Provider::Together.to_string(), "together");
        assert_eq!(Provider::Fireworks.to_string(), "fireworks");
        assert_eq!(Provider::Cerebras.to_string(), "cerebras");
        assert_eq!(Provider::XAI.to_string(), "xai");
        assert_eq!(Provider::Cohere.to_string(), "cohere");
        assert_eq!(Provider::Bedrock.to_string(), "bedrock");
        assert_eq!(Provider::AzureOpenAi.to_string(), "azure_openai");
        assert_eq!(
            Provider::Custom("my_provider".to_string()).to_string(),
            "custom(my_provider)"
        );
    }

    #[test]
    fn test_provider_serde_roundtrip() {
        // An array suffices here; no heap-allocated Vec needed (clippy: useless_vec).
        let providers = [
            Provider::Anthropic,
            Provider::Google,
            Provider::OpenAI,
            Provider::Groq,
            Provider::DeepSeek,
            Provider::Ollama,
            Provider::Mistral,
            Provider::Together,
            Provider::Fireworks,
            Provider::Cerebras,
            Provider::XAI,
            Provider::Cohere,
            Provider::Bedrock,
            Provider::AzureOpenAi,
            Provider::Custom("test".to_string()),
        ];
        for provider in &providers {
            let json = serde_json::to_string(provider).expect("serialize provider");
            let deser: Provider = serde_json::from_str(&json).expect("deserialize provider");
            assert_eq!(&deser, provider);
        }
    }

    #[test]
    fn test_provider_serde_rename() {
        assert_eq!(
            serde_json::to_string(&Provider::OpenAI).unwrap(),
            "\"openai\""
        );
        assert_eq!(serde_json::to_string(&Provider::XAI).unwrap(), "\"xai\"");
        assert_eq!(
            serde_json::to_string(&Provider::AzureOpenAi).unwrap(),
            "\"azure_openai\""
        );
        assert_eq!(
            serde_json::to_string(&Provider::Anthropic).unwrap(),
            "\"anthropic\""
        );
        // `DeepSeek` has no explicit rename, so serde's snake_case conversion
        // splits the acronym boundary: the wire format is "deep_seek", which
        // intentionally pins the current divergence from Display ("deepseek").
        assert_eq!(
            serde_json::to_string(&Provider::DeepSeek).unwrap(),
            "\"deep_seek\""
        );
    }

    #[test]
    fn test_provider_equality() {
        assert_eq!(Provider::Anthropic, Provider::Anthropic);
        assert_ne!(Provider::Anthropic, Provider::Google);
        assert_eq!(
            Provider::Custom("x".to_string()),
            Provider::Custom("x".to_string())
        );
        assert_ne!(
            Provider::Custom("x".to_string()),
            Provider::Custom("y".to_string())
        );
    }

    #[test]
    fn test_model_config_serde_roundtrip() {
        let config = ModelConfig {
            provider: Provider::Anthropic,
            model: "claude-sonnet-4-20250514".to_string(),
            api_key_env: Some("ANTHROPIC_API_KEY".to_string()),
            base_url: None,
            max_tokens: Some(4096),
            temperature: Some(0.7),
        };
        let json = serde_json::to_string(&config).expect("serialize");
        let deser: ModelConfig = serde_json::from_str(&json).expect("deserialize");
        assert_eq!(deser.model, "claude-sonnet-4-20250514");
        assert_eq!(deser.provider, Provider::Anthropic);
        assert_eq!(deser.max_tokens, Some(4096));
        assert_eq!(deser.temperature, Some(0.7));
    }

    #[test]
    fn test_model_config_optional_fields() {
        let config = ModelConfig {
            provider: Provider::Ollama,
            model: "llama3".to_string(),
            api_key_env: None,
            base_url: Some("http://localhost:11434".to_string()),
            max_tokens: None,
            temperature: None,
        };
        let json = serde_json::to_string(&config).expect("serialize");
        let deser: ModelConfig = serde_json::from_str(&json).expect("deserialize");
        assert!(deser.api_key_env.is_none());
        assert_eq!(deser.base_url, Some("http://localhost:11434".to_string()));
        assert!(deser.max_tokens.is_none());
        assert!(deser.temperature.is_none());
    }

    #[test]
    fn test_memory_config_defaults() {
        let json = r#"{"db_path": "/tmp/punch.db"}"#;
        let config: MemoryConfig = serde_json::from_str(json).expect("deserialize");
        assert_eq!(config.db_path, "/tmp/punch.db");
        assert!(config.knowledge_graph_enabled);
        assert!(config.max_entries.is_none());
    }

    #[test]
    fn test_memory_config_roundtrip() {
        let config = MemoryConfig {
            db_path: "/data/punch.db".to_string(),
            knowledge_graph_enabled: false,
            max_entries: Some(10000),
        };
        let json = serde_json::to_string(&config).expect("serialize");
        let deser: MemoryConfig = serde_json::from_str(&json).expect("deserialize");
        assert_eq!(deser.db_path, "/data/punch.db");
        assert!(!deser.knowledge_graph_enabled);
        assert_eq!(deser.max_entries, Some(10000));
    }

    #[test]
    fn test_channel_config_serde() {
        let config = ChannelConfig {
            channel_type: "slack".to_string(),
            token_env: Some("SLACK_TOKEN".to_string()),
            settings: HashMap::from([("channel_id".to_string(), serde_json::json!("C123456"))]),
        };
        let json = serde_json::to_string(&config).expect("serialize");
        let deser: ChannelConfig = serde_json::from_str(&json).expect("deserialize");
        assert_eq!(deser.channel_type, "slack");
        assert_eq!(deser.token_env, Some("SLACK_TOKEN".to_string()));
        assert_eq!(deser.settings["channel_id"], "C123456");
    }

    #[test]
    fn test_channel_config_defaults() {
        let json = r#"{"channel_type": "webhook"}"#;
        let config: ChannelConfig = serde_json::from_str(json).expect("deserialize");
        assert_eq!(config.channel_type, "webhook");
        assert!(config.token_env.is_none());
        assert!(config.settings.is_empty());
    }

    #[test]
    fn test_mcp_server_config_serde() {
        let config = McpServerConfig {
            command: "npx".to_string(),
            args: vec!["-y".to_string(), "@modelcontextprotocol/server".to_string()],
            env: HashMap::from([("NODE_ENV".to_string(), "production".to_string())]),
        };
        let json = serde_json::to_string(&config).expect("serialize");
        let deser: McpServerConfig = serde_json::from_str(&json).expect("deserialize");
        assert_eq!(deser.command, "npx");
        assert_eq!(deser.args.len(), 2);
        assert_eq!(deser.env["NODE_ENV"], "production");
    }

    #[test]
    fn test_mcp_server_config_defaults() {
        let json = r#"{"command": "python"}"#;
        let config: McpServerConfig = serde_json::from_str(json).expect("deserialize");
        assert_eq!(config.command, "python");
        assert!(config.args.is_empty());
        assert!(config.env.is_empty());
    }

    #[test]
    fn test_provider_hash() {
        let mut set = std::collections::HashSet::new();
        set.insert(Provider::Anthropic);
        set.insert(Provider::Google);
        // Re-inserting a duplicate must not grow the set.
        set.insert(Provider::Anthropic);
        assert_eq!(set.len(), 2);
    }
}