Skip to main content

punch_types/
config.rs

1use serde::{Deserialize, Serialize};
2use std::collections::HashMap;
3
4/// Top-level Punch configuration.
5#[derive(Debug, Clone, Serialize, Deserialize)]
6pub struct PunchConfig {
7    /// Socket address for the Arena API server (e.g. "0.0.0.0:6660").
8    pub api_listen: String,
9    /// API key for authentication. If empty, auth is disabled (dev mode).
10    #[serde(default)]
11    pub api_key: String,
12    /// Per-IP rate limit in requests per minute. Default: 60.
13    #[serde(default = "default_rate_limit_rpm")]
14    pub rate_limit_rpm: u32,
15    /// Default model to use when none is specified.
16    pub default_model: ModelConfig,
17    /// Memory subsystem configuration.
18    pub memory: MemoryConfig,
19    /// Channel configurations keyed by channel name (e.g. "slack", "discord").
20    #[serde(default)]
21    pub channels: HashMap<String, ChannelConfig>,
22    /// MCP server definitions keyed by server name.
23    #[serde(default)]
24    pub mcp_servers: HashMap<String, McpServerConfig>,
25}
26
27/// Configuration for a language model.
28#[derive(Debug, Clone, Serialize, Deserialize)]
29pub struct ModelConfig {
30    /// The provider to use.
31    pub provider: Provider,
32    /// Model identifier (e.g. "claude-sonnet-4-20250514").
33    pub model: String,
34    /// Environment variable name that holds the API key.
35    pub api_key_env: Option<String>,
36    /// Optional base URL override for the provider API.
37    pub base_url: Option<String>,
38    /// Maximum tokens to generate per request.
39    pub max_tokens: Option<u32>,
40    /// Sampling temperature.
41    pub temperature: Option<f32>,
42}
43
44/// Supported model providers.
45#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
46#[serde(rename_all = "snake_case")]
47pub enum Provider {
48    Anthropic,
49    Google,
50    #[serde(rename = "openai")]
51    OpenAI,
52    Groq,
53    DeepSeek,
54    Ollama,
55    Mistral,
56    Together,
57    Fireworks,
58    Cerebras,
59    #[serde(rename = "xai")]
60    XAI,
61    Cohere,
62    Bedrock,
63    #[serde(rename = "azure_openai")]
64    AzureOpenAi,
65    Custom(String),
66}
67
68impl std::fmt::Display for Provider {
69    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
70        match self {
71            Self::Anthropic => write!(f, "anthropic"),
72            Self::Google => write!(f, "google"),
73            Self::OpenAI => write!(f, "openai"),
74            Self::Groq => write!(f, "groq"),
75            Self::DeepSeek => write!(f, "deepseek"),
76            Self::Ollama => write!(f, "ollama"),
77            Self::Mistral => write!(f, "mistral"),
78            Self::Together => write!(f, "together"),
79            Self::Fireworks => write!(f, "fireworks"),
80            Self::Cerebras => write!(f, "cerebras"),
81            Self::XAI => write!(f, "xai"),
82            Self::Cohere => write!(f, "cohere"),
83            Self::Bedrock => write!(f, "bedrock"),
84            Self::AzureOpenAi => write!(f, "azure_openai"),
85            Self::Custom(name) => write!(f, "custom({})", name),
86        }
87    }
88}
89
90/// Memory subsystem configuration.
91#[derive(Debug, Clone, Serialize, Deserialize)]
92pub struct MemoryConfig {
93    /// Path to the SQLite database file.
94    pub db_path: String,
95    /// Whether to enable the knowledge graph.
96    #[serde(default = "default_true")]
97    pub knowledge_graph_enabled: bool,
98    /// Maximum number of memory entries to retain.
99    pub max_entries: Option<u64>,
100}
101
102/// Configuration for a communication channel.
103#[derive(Debug, Clone, Serialize, Deserialize)]
104pub struct ChannelConfig {
105    /// Channel type identifier (e.g. "slack", "discord", "webhook").
106    pub channel_type: String,
107    /// Environment variable holding the authentication token.
108    pub token_env: Option<String>,
109    /// Additional channel-specific settings.
110    #[serde(default)]
111    pub settings: HashMap<String, serde_json::Value>,
112}
113
114/// Configuration for an MCP (Model Context Protocol) server.
115#[derive(Debug, Clone, Serialize, Deserialize)]
116pub struct McpServerConfig {
117    /// Command to start the MCP server.
118    pub command: String,
119    /// Arguments to pass to the command.
120    #[serde(default)]
121    pub args: Vec<String>,
122    /// Environment variables to set for the MCP server process.
123    #[serde(default)]
124    pub env: HashMap<String, String>,
125}
126
/// Serde default for boolean flags that should be on unless explicitly
/// disabled (used by [`MemoryConfig::knowledge_graph_enabled`]).
fn default_true() -> bool {
    // `bool::default()` is `false`, so a `true` default needs its own fn path.
    true
}
130
/// Serde default for [`PunchConfig::rate_limit_rpm`]: 60 requests per minute.
fn default_rate_limit_rpm() -> u32 {
    const REQUESTS_PER_MINUTE: u32 = 60;
    REQUESTS_PER_MINUTE
}
134
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_provider_display_all_variants() {
        assert_eq!(Provider::Anthropic.to_string(), "anthropic");
        assert_eq!(Provider::Google.to_string(), "google");
        assert_eq!(Provider::OpenAI.to_string(), "openai");
        assert_eq!(Provider::Groq.to_string(), "groq");
        assert_eq!(Provider::DeepSeek.to_string(), "deepseek");
        assert_eq!(Provider::Ollama.to_string(), "ollama");
        assert_eq!(Provider::Mistral.to_string(), "mistral");
        assert_eq!(Provider::Together.to_string(), "together");
        assert_eq!(Provider::Fireworks.to_string(), "fireworks");
        assert_eq!(Provider::Cerebras.to_string(), "cerebras");
        assert_eq!(Provider::XAI.to_string(), "xai");
        assert_eq!(Provider::Cohere.to_string(), "cohere");
        assert_eq!(Provider::Bedrock.to_string(), "bedrock");
        assert_eq!(Provider::AzureOpenAi.to_string(), "azure_openai");
        assert_eq!(
            Provider::Custom("my_provider".to_string()).to_string(),
            "custom(my_provider)"
        );
    }

    #[test]
    fn test_provider_serde_roundtrip() {
        let providers = vec![
            Provider::Anthropic,
            Provider::Google,
            Provider::OpenAI,
            Provider::Groq,
            Provider::DeepSeek,
            Provider::Ollama,
            Provider::Mistral,
            Provider::Together,
            Provider::Fireworks,
            Provider::Cerebras,
            Provider::XAI,
            Provider::Cohere,
            Provider::Bedrock,
            Provider::AzureOpenAi,
            Provider::Custom("test".to_string()),
        ];
        for provider in &providers {
            let json = serde_json::to_string(provider).expect("serialize provider");
            let deser: Provider = serde_json::from_str(&json).expect("deserialize provider");
            assert_eq!(&deser, provider);
        }
    }

    #[test]
    fn test_provider_serde_rename() {
        assert_eq!(
            serde_json::to_string(&Provider::OpenAI).unwrap(),
            "\"openai\""
        );
        assert_eq!(serde_json::to_string(&Provider::XAI).unwrap(), "\"xai\"");
        assert_eq!(
            serde_json::to_string(&Provider::AzureOpenAi).unwrap(),
            "\"azure_openai\""
        );
        assert_eq!(
            serde_json::to_string(&Provider::Anthropic).unwrap(),
            "\"anthropic\""
        );
    }

    #[test]
    fn test_provider_equality() {
        assert_eq!(Provider::Anthropic, Provider::Anthropic);
        assert_ne!(Provider::Anthropic, Provider::Google);
        assert_eq!(
            Provider::Custom("x".to_string()),
            Provider::Custom("x".to_string())
        );
        assert_ne!(
            Provider::Custom("x".to_string()),
            Provider::Custom("y".to_string())
        );
    }

    #[test]
    fn test_model_config_serde_roundtrip() {
        let config = ModelConfig {
            provider: Provider::Anthropic,
            model: "claude-sonnet-4-20250514".to_string(),
            api_key_env: Some("ANTHROPIC_API_KEY".to_string()),
            base_url: None,
            max_tokens: Some(4096),
            temperature: Some(0.7),
        };
        let json = serde_json::to_string(&config).expect("serialize");
        let deser: ModelConfig = serde_json::from_str(&json).expect("deserialize");
        assert_eq!(deser.model, "claude-sonnet-4-20250514");
        assert_eq!(deser.provider, Provider::Anthropic);
        // Every field must survive the roundtrip, not just a sample of them.
        assert_eq!(deser.api_key_env, Some("ANTHROPIC_API_KEY".to_string()));
        assert!(deser.base_url.is_none());
        assert_eq!(deser.max_tokens, Some(4096));
        assert_eq!(deser.temperature, Some(0.7));
    }

    #[test]
    fn test_model_config_optional_fields() {
        let config = ModelConfig {
            provider: Provider::Ollama,
            model: "llama3".to_string(),
            api_key_env: None,
            base_url: Some("http://localhost:11434".to_string()),
            max_tokens: None,
            temperature: None,
        };
        let json = serde_json::to_string(&config).expect("serialize");
        let deser: ModelConfig = serde_json::from_str(&json).expect("deserialize");
        assert!(deser.api_key_env.is_none());
        assert_eq!(deser.base_url, Some("http://localhost:11434".to_string()));
        assert!(deser.max_tokens.is_none());
        assert!(deser.temperature.is_none());
    }

    #[test]
    fn test_memory_config_defaults() {
        let json = r#"{"db_path": "/tmp/punch.db"}"#;
        let config: MemoryConfig = serde_json::from_str(json).expect("deserialize");
        assert_eq!(config.db_path, "/tmp/punch.db");
        assert!(config.knowledge_graph_enabled); // default_true
        assert!(config.max_entries.is_none());
    }

    #[test]
    fn test_memory_config_roundtrip() {
        let config = MemoryConfig {
            db_path: "/data/punch.db".to_string(),
            knowledge_graph_enabled: false,
            max_entries: Some(10000),
        };
        let json = serde_json::to_string(&config).expect("serialize");
        let deser: MemoryConfig = serde_json::from_str(&json).expect("deserialize");
        assert_eq!(deser.db_path, "/data/punch.db");
        assert!(!deser.knowledge_graph_enabled);
        assert_eq!(deser.max_entries, Some(10000));
    }

    #[test]
    fn test_channel_config_serde() {
        let config = ChannelConfig {
            channel_type: "slack".to_string(),
            token_env: Some("SLACK_TOKEN".to_string()),
            settings: HashMap::from([("channel_id".to_string(), serde_json::json!("C123456"))]),
        };
        let json = serde_json::to_string(&config).expect("serialize");
        let deser: ChannelConfig = serde_json::from_str(&json).expect("deserialize");
        assert_eq!(deser.channel_type, "slack");
        assert_eq!(deser.token_env, Some("SLACK_TOKEN".to_string()));
        assert_eq!(deser.settings["channel_id"], "C123456");
    }

    #[test]
    fn test_channel_config_defaults() {
        let json = r#"{"channel_type": "webhook"}"#;
        let config: ChannelConfig = serde_json::from_str(json).expect("deserialize");
        assert_eq!(config.channel_type, "webhook");
        assert!(config.token_env.is_none());
        assert!(config.settings.is_empty());
    }

    #[test]
    fn test_mcp_server_config_serde() {
        let config = McpServerConfig {
            command: "npx".to_string(),
            args: vec!["-y".to_string(), "@modelcontextprotocol/server".to_string()],
            env: HashMap::from([("NODE_ENV".to_string(), "production".to_string())]),
        };
        let json = serde_json::to_string(&config).expect("serialize");
        let deser: McpServerConfig = serde_json::from_str(&json).expect("deserialize");
        assert_eq!(deser.command, "npx");
        assert_eq!(deser.args.len(), 2);
        assert_eq!(deser.env["NODE_ENV"], "production");
    }

    #[test]
    fn test_mcp_server_config_defaults() {
        let json = r#"{"command": "python"}"#;
        let config: McpServerConfig = serde_json::from_str(json).expect("deserialize");
        assert_eq!(config.command, "python");
        assert!(config.args.is_empty());
        assert!(config.env.is_empty());
    }

    #[test]
    fn test_provider_hash() {
        let mut set = std::collections::HashSet::new();
        set.insert(Provider::Anthropic);
        set.insert(Provider::Google);
        set.insert(Provider::Anthropic); // duplicate
        assert_eq!(set.len(), 2);
    }

    /// A minimal top-level config must pick up every serde default:
    /// empty api_key (auth disabled), 60 rpm, and empty channel/MCP maps.
    #[test]
    fn test_punch_config_defaults() {
        let json = r#"{
            "api_listen": "0.0.0.0:6660",
            "default_model": {"provider": "anthropic", "model": "claude-sonnet-4-20250514"},
            "memory": {"db_path": "/tmp/punch.db"}
        }"#;
        let config: PunchConfig = serde_json::from_str(json).expect("deserialize");
        assert_eq!(config.api_listen, "0.0.0.0:6660");
        assert!(config.api_key.is_empty());
        assert_eq!(config.rate_limit_rpm, 60); // default_rate_limit_rpm
        assert_eq!(config.default_model.provider, Provider::Anthropic);
        assert!(config.default_model.api_key_env.is_none());
        assert!(config.default_model.max_tokens.is_none());
        assert!(config.memory.knowledge_graph_enabled);
        assert!(config.channels.is_empty());
        assert!(config.mcp_servers.is_empty());
    }

    #[test]
    fn test_punch_config_roundtrip() {
        let config = PunchConfig {
            api_listen: "127.0.0.1:6660".to_string(),
            api_key: "secret".to_string(),
            rate_limit_rpm: 120,
            default_model: ModelConfig {
                provider: Provider::Groq,
                model: "llama3-70b".to_string(),
                api_key_env: Some("GROQ_API_KEY".to_string()),
                base_url: None,
                max_tokens: Some(1024),
                temperature: None,
            },
            memory: MemoryConfig {
                db_path: "/data/punch.db".to_string(),
                knowledge_graph_enabled: false,
                max_entries: Some(500),
            },
            channels: HashMap::from([(
                "slack".to_string(),
                ChannelConfig {
                    channel_type: "slack".to_string(),
                    token_env: Some("SLACK_TOKEN".to_string()),
                    settings: HashMap::new(),
                },
            )]),
            mcp_servers: HashMap::from([(
                "fs".to_string(),
                McpServerConfig {
                    command: "npx".to_string(),
                    args: vec!["-y".to_string()],
                    env: HashMap::new(),
                },
            )]),
        };
        let json = serde_json::to_string(&config).expect("serialize");
        let deser: PunchConfig = serde_json::from_str(&json).expect("deserialize");
        assert_eq!(deser.api_listen, "127.0.0.1:6660");
        assert_eq!(deser.api_key, "secret");
        assert_eq!(deser.rate_limit_rpm, 120);
        assert_eq!(deser.default_model.provider, Provider::Groq);
        assert_eq!(deser.default_model.max_tokens, Some(1024));
        assert!(!deser.memory.knowledge_graph_enabled);
        assert_eq!(deser.memory.max_entries, Some(500));
        assert_eq!(deser.channels["slack"].channel_type, "slack");
        assert_eq!(deser.mcp_servers["fs"].command, "npx");
    }
}