// punch_types/config.rs

1use serde::{Deserialize, Serialize};
2use std::collections::HashMap;
3
4/// Top-level Punch configuration.
5#[derive(Debug, Clone, Serialize, Deserialize)]
6pub struct PunchConfig {
7    /// Socket address for the Arena API server (e.g. "127.0.0.1:6660").
8    /// Use 127.0.0.1 for local-only access. Only bind to 0.0.0.0 if you
9    /// need external access AND have authentication configured.
10    pub api_listen: String,
11    /// API key for authentication. If empty, auth is disabled (dev mode).
12    #[serde(default)]
13    pub api_key: String,
14    /// Per-IP rate limit in requests per minute. Default: 60.
15    #[serde(default = "default_rate_limit_rpm")]
16    pub rate_limit_rpm: u32,
17    /// Default model to use when none is specified.
18    pub default_model: ModelConfig,
19    /// Memory subsystem configuration.
20    pub memory: MemoryConfig,
21    /// Tunnel / public URL configuration shared by all channel webhooks.
22    #[serde(default)]
23    pub tunnel: Option<TunnelConfig>,
24    /// Channel configurations keyed by channel name (e.g. "slack", "discord").
25    #[serde(default)]
26    pub channels: HashMap<String, ChannelConfig>,
27    /// MCP server definitions keyed by server name.
28    #[serde(default)]
29    pub mcp_servers: HashMap<String, McpServerConfig>,
30}
31
32/// Configuration for a language model.
33#[derive(Debug, Clone, Serialize, Deserialize)]
34pub struct ModelConfig {
35    /// The provider to use.
36    pub provider: Provider,
37    /// Model identifier (e.g. "claude-sonnet-4-20250514").
38    pub model: String,
39    /// Environment variable name that holds the API key.
40    pub api_key_env: Option<String>,
41    /// Optional base URL override for the provider API.
42    pub base_url: Option<String>,
43    /// Maximum tokens to generate per request.
44    pub max_tokens: Option<u32>,
45    /// Sampling temperature.
46    pub temperature: Option<f32>,
47}
48
49/// Supported model providers.
50#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
51#[serde(rename_all = "snake_case")]
52pub enum Provider {
53    Anthropic,
54    Google,
55    #[serde(rename = "openai")]
56    OpenAI,
57    Groq,
58    DeepSeek,
59    Ollama,
60    Mistral,
61    Together,
62    Fireworks,
63    Cerebras,
64    #[serde(rename = "xai")]
65    XAI,
66    Cohere,
67    Bedrock,
68    #[serde(rename = "azure_openai")]
69    AzureOpenAi,
70    Custom(String),
71}
72
73impl std::fmt::Display for Provider {
74    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
75        match self {
76            Self::Anthropic => write!(f, "anthropic"),
77            Self::Google => write!(f, "google"),
78            Self::OpenAI => write!(f, "openai"),
79            Self::Groq => write!(f, "groq"),
80            Self::DeepSeek => write!(f, "deepseek"),
81            Self::Ollama => write!(f, "ollama"),
82            Self::Mistral => write!(f, "mistral"),
83            Self::Together => write!(f, "together"),
84            Self::Fireworks => write!(f, "fireworks"),
85            Self::Cerebras => write!(f, "cerebras"),
86            Self::XAI => write!(f, "xai"),
87            Self::Cohere => write!(f, "cohere"),
88            Self::Bedrock => write!(f, "bedrock"),
89            Self::AzureOpenAi => write!(f, "azure_openai"),
90            Self::Custom(name) => write!(f, "custom({})", name),
91        }
92    }
93}
94
95/// Memory subsystem configuration.
96#[derive(Debug, Clone, Serialize, Deserialize)]
97pub struct MemoryConfig {
98    /// Path to the SQLite database file.
99    pub db_path: String,
100    /// Whether to enable the knowledge graph.
101    #[serde(default = "default_true")]
102    pub knowledge_graph_enabled: bool,
103    /// Maximum number of memory entries to retain.
104    pub max_entries: Option<u64>,
105}
106
107/// Configuration for the public tunnel / base URL used by all channel webhooks.
108#[derive(Debug, Clone, Serialize, Deserialize)]
109pub struct TunnelConfig {
110    /// The public base URL that all channel webhooks share
111    /// (e.g. "https://abc.trycloudflare.com" or "https://channels.yourdomain.com").
112    pub base_url: String,
113    /// How this tunnel was set up: "quick", "named", or "manual".
114    #[serde(default = "default_tunnel_mode")]
115    pub mode: String,
116}
117
/// Serde default for `TunnelConfig::mode`: a manually configured tunnel.
fn default_tunnel_mode() -> String {
    String::from("manual")
}
121
122/// Configuration for a communication channel.
123#[derive(Debug, Clone, Serialize, Deserialize)]
124pub struct ChannelConfig {
125    /// Channel type identifier (e.g. "slack", "discord", "webhook").
126    pub channel_type: String,
127    /// Environment variable holding the authentication token.
128    pub token_env: Option<String>,
129    /// Environment variable holding the webhook signing secret (for signature verification).
130    /// Slack: signing secret. Telegram: secret_token header value. Discord: public key.
131    pub webhook_secret_env: Option<String>,
132    /// Allowlisted user/chat IDs. Only these users can interact with fighters.
133    /// Empty list = open access (logs a security warning on startup).
134    #[serde(default)]
135    pub allowed_user_ids: Vec<String>,
136    /// Per-user rate limit in messages per minute. Default: 20.
137    #[serde(default = "default_channel_rate_limit")]
138    pub rate_limit_per_user: u32,
139    /// Additional channel-specific settings.
140    #[serde(default)]
141    pub settings: HashMap<String, serde_json::Value>,
142}
143
/// Serde default for `ChannelConfig::rate_limit_per_user` (messages per minute).
const fn default_channel_rate_limit() -> u32 {
    20_u32
}
147
148/// Configuration for an MCP (Model Context Protocol) server.
149#[derive(Debug, Clone, Serialize, Deserialize)]
150pub struct McpServerConfig {
151    /// Command to start the MCP server.
152    pub command: String,
153    /// Arguments to pass to the command.
154    #[serde(default)]
155    pub args: Vec<String>,
156    /// Environment variables to set for the MCP server process.
157    #[serde(default)]
158    pub env: HashMap<String, String>,
159}
160
/// Serde default that yields `true` (used by `MemoryConfig::knowledge_graph_enabled`).
const fn default_true() -> bool {
    true
}
164
/// Serde default for `PunchConfig::rate_limit_rpm` (requests per minute, per IP).
const fn default_rate_limit_rpm() -> u32 {
    60_u32
}
168
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashSet;

    /// Serializes `value` to JSON and parses the result back into a fresh value.
    ///
    /// Panics with "serialize" or "deserialize" if either direction fails.
    fn json_roundtrip<T>(value: &T) -> T
    where
        T: serde::Serialize + serde::de::DeserializeOwned,
    {
        let encoded = serde_json::to_string(value).expect("serialize");
        serde_json::from_str(&encoded).expect("deserialize")
    }

    #[test]
    fn test_provider_display_all_variants() {
        let cases = [
            (Provider::Anthropic, "anthropic"),
            (Provider::Google, "google"),
            (Provider::OpenAI, "openai"),
            (Provider::Groq, "groq"),
            (Provider::DeepSeek, "deepseek"),
            (Provider::Ollama, "ollama"),
            (Provider::Mistral, "mistral"),
            (Provider::Together, "together"),
            (Provider::Fireworks, "fireworks"),
            (Provider::Cerebras, "cerebras"),
            (Provider::XAI, "xai"),
            (Provider::Cohere, "cohere"),
            (Provider::Bedrock, "bedrock"),
            (Provider::AzureOpenAi, "azure_openai"),
            (
                Provider::Custom("my_provider".to_string()),
                "custom(my_provider)",
            ),
        ];
        for (provider, expected) in cases.iter() {
            assert_eq!(provider.to_string(), *expected);
        }
    }

    #[test]
    fn test_provider_serde_roundtrip() {
        let providers = [
            Provider::Anthropic,
            Provider::Google,
            Provider::OpenAI,
            Provider::Groq,
            Provider::DeepSeek,
            Provider::Ollama,
            Provider::Mistral,
            Provider::Together,
            Provider::Fireworks,
            Provider::Cerebras,
            Provider::XAI,
            Provider::Cohere,
            Provider::Bedrock,
            Provider::AzureOpenAi,
            Provider::Custom("test".to_string()),
        ];
        for original in providers.iter() {
            let encoded = serde_json::to_string(original).expect("serialize provider");
            let decoded: Provider =
                serde_json::from_str(&encoded).expect("deserialize provider");
            assert_eq!(&decoded, original);
        }
    }

    #[test]
    fn test_provider_serde_rename() {
        let expectations = [
            (Provider::OpenAI, "\"openai\""),
            (Provider::XAI, "\"xai\""),
            (Provider::AzureOpenAi, "\"azure_openai\""),
            (Provider::Anthropic, "\"anthropic\""),
        ];
        for (provider, wire) in expectations.iter() {
            assert_eq!(serde_json::to_string(provider).unwrap(), *wire);
        }
    }

    #[test]
    fn test_provider_equality() {
        let custom = |name: &str| Provider::Custom(name.to_string());
        assert_eq!(Provider::Anthropic, Provider::Anthropic);
        assert_ne!(Provider::Anthropic, Provider::Google);
        assert_eq!(custom("x"), custom("x"));
        assert_ne!(custom("x"), custom("y"));
    }

    #[test]
    fn test_model_config_serde_roundtrip() {
        let restored = json_roundtrip(&ModelConfig {
            provider: Provider::Anthropic,
            model: "claude-sonnet-4-20250514".to_string(),
            api_key_env: Some("ANTHROPIC_API_KEY".to_string()),
            base_url: None,
            max_tokens: Some(4096),
            temperature: Some(0.7),
        });
        assert_eq!(restored.model, "claude-sonnet-4-20250514");
        assert_eq!(restored.provider, Provider::Anthropic);
        assert_eq!(restored.max_tokens, Some(4096));
        assert_eq!(restored.temperature, Some(0.7));
    }

    #[test]
    fn test_model_config_optional_fields() {
        let restored = json_roundtrip(&ModelConfig {
            provider: Provider::Ollama,
            model: "llama3".to_string(),
            api_key_env: None,
            base_url: Some("http://localhost:11434".to_string()),
            max_tokens: None,
            temperature: None,
        });
        assert_eq!(restored.api_key_env, None);
        assert_eq!(restored.base_url.as_deref(), Some("http://localhost:11434"));
        assert_eq!(restored.max_tokens, None);
        assert_eq!(restored.temperature, None);
    }

    #[test]
    fn test_memory_config_defaults() {
        let parsed: MemoryConfig =
            serde_json::from_str(r#"{"db_path": "/tmp/punch.db"}"#).expect("deserialize");
        assert_eq!(parsed.db_path, "/tmp/punch.db");
        // Supplied by `default_true` when the key is missing.
        assert!(parsed.knowledge_graph_enabled);
        assert_eq!(parsed.max_entries, None);
    }

    #[test]
    fn test_memory_config_roundtrip() {
        let restored = json_roundtrip(&MemoryConfig {
            db_path: "/data/punch.db".to_string(),
            knowledge_graph_enabled: false,
            max_entries: Some(10_000),
        });
        assert_eq!(restored.db_path, "/data/punch.db");
        assert!(!restored.knowledge_graph_enabled);
        assert_eq!(restored.max_entries, Some(10_000));
    }

    #[test]
    fn test_channel_config_serde() {
        let mut settings = HashMap::new();
        settings.insert("channel_id".to_string(), serde_json::json!("C123456"));
        let restored = json_roundtrip(&ChannelConfig {
            channel_type: "slack".to_string(),
            token_env: Some("SLACK_TOKEN".to_string()),
            webhook_secret_env: Some("SLACK_SIGNING_SECRET".to_string()),
            allowed_user_ids: vec!["U123".to_string()],
            rate_limit_per_user: 20,
            settings,
        });
        assert_eq!(restored.channel_type, "slack");
        assert_eq!(restored.token_env.as_deref(), Some("SLACK_TOKEN"));
        assert_eq!(restored.settings["channel_id"], "C123456");
    }

    #[test]
    fn test_channel_config_defaults() {
        let parsed: ChannelConfig =
            serde_json::from_str(r#"{"channel_type": "webhook"}"#).expect("deserialize");
        assert_eq!(parsed.channel_type, "webhook");
        assert_eq!(parsed.token_env, None);
        assert!(parsed.settings.is_empty());
    }

    #[test]
    fn test_mcp_server_config_serde() {
        let restored = json_roundtrip(&McpServerConfig {
            command: "npx".to_string(),
            args: ["-y", "@modelcontextprotocol/server"]
                .iter()
                .map(|arg| arg.to_string())
                .collect(),
            env: std::iter::once(("NODE_ENV".to_string(), "production".to_string()))
                .collect(),
        });
        assert_eq!(restored.command, "npx");
        assert_eq!(restored.args.len(), 2);
        assert_eq!(restored.env["NODE_ENV"], "production");
    }

    #[test]
    fn test_mcp_server_config_defaults() {
        let parsed: McpServerConfig =
            serde_json::from_str(r#"{"command": "python"}"#).expect("deserialize");
        assert_eq!(parsed.command, "python");
        assert!(parsed.args.is_empty());
        assert!(parsed.env.is_empty());
    }

    #[test]
    fn test_tunnel_config_serde() {
        let restored = json_roundtrip(&TunnelConfig {
            base_url: "https://abc.trycloudflare.com".to_string(),
            mode: "quick".to_string(),
        });
        assert_eq!(restored.base_url, "https://abc.trycloudflare.com");
        assert_eq!(restored.mode, "quick");
    }

    #[test]
    fn test_tunnel_config_default_mode() {
        let parsed: TunnelConfig =
            serde_json::from_str(r#"{"base_url": "https://example.com"}"#).expect("deserialize");
        assert_eq!(parsed.mode, "manual");
    }

    #[test]
    fn test_provider_hash() {
        let unique: HashSet<Provider> = [
            Provider::Anthropic,
            Provider::Google,
            Provider::Anthropic, // duplicate collapses into the first entry
        ]
        .iter()
        .cloned()
        .collect();
        assert_eq!(unique.len(), 2);
    }
}