//! Configuration types (`punch_types/config.rs`).

use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::fmt;

4/// Top-level Punch configuration.
5#[derive(Debug, Clone, Serialize, Deserialize)]
6pub struct PunchConfig {
7    /// Socket address for the Arena API server (e.g. "127.0.0.1:6660").
8    /// Use 127.0.0.1 for local-only access. Only bind to 0.0.0.0 if you
9    /// need external access AND have authentication configured.
10    pub api_listen: String,
11    /// API key for authentication. If empty, auth is disabled (dev mode).
12    #[serde(default)]
13    pub api_key: String,
14    /// Per-IP rate limit in requests per minute. Default: 60.
15    #[serde(default = "default_rate_limit_rpm")]
16    pub rate_limit_rpm: u32,
17    /// Default model to use when none is specified.
18    pub default_model: ModelConfig,
19    /// Memory subsystem configuration.
20    pub memory: MemoryConfig,
21    /// Tunnel / public URL configuration shared by all channel webhooks.
22    #[serde(default)]
23    pub tunnel: Option<TunnelConfig>,
24    /// Channel configurations keyed by channel name (e.g. "slack", "discord").
25    #[serde(default)]
26    pub channels: HashMap<String, ChannelConfig>,
27    /// MCP server definitions keyed by server name.
28    #[serde(default)]
29    pub mcp_servers: HashMap<String, McpServerConfig>,
30    /// Smart model routing configuration. When enabled, messages are routed
31    /// to cheap / mid / expensive models based on query complexity.
32    #[serde(default)]
33    pub model_routing: ModelRoutingConfig,
34    /// Budget limits and eco mode configuration. When set, fighters enter
35    /// eco mode (cheap tier, no reflection, minimal tools) as limits approach.
36    #[serde(default)]
37    pub budget: BudgetConfig,
38}
39
40/// Configuration for a language model.
41#[derive(Debug, Clone, Serialize, Deserialize)]
42pub struct ModelConfig {
43    /// The provider to use.
44    pub provider: Provider,
45    /// Model identifier (e.g. "claude-sonnet-4-20250514").
46    pub model: String,
47    /// Environment variable name that holds the API key.
48    pub api_key_env: Option<String>,
49    /// Optional base URL override for the provider API.
50    pub base_url: Option<String>,
51    /// Maximum tokens to generate per request.
52    pub max_tokens: Option<u32>,
53    /// Sampling temperature.
54    pub temperature: Option<f32>,
55}
56
57/// Supported model providers.
58#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
59#[serde(rename_all = "snake_case")]
60pub enum Provider {
61    Anthropic,
62    Google,
63    #[serde(rename = "openai")]
64    OpenAI,
65    Groq,
66    DeepSeek,
67    Ollama,
68    Mistral,
69    Together,
70    Fireworks,
71    Cerebras,
72    #[serde(rename = "xai")]
73    XAI,
74    Cohere,
75    Bedrock,
76    #[serde(rename = "azure_openai")]
77    AzureOpenAi,
78    Custom(String),
79}
80
81impl std::fmt::Display for Provider {
82    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
83        match self {
84            Self::Anthropic => write!(f, "anthropic"),
85            Self::Google => write!(f, "google"),
86            Self::OpenAI => write!(f, "openai"),
87            Self::Groq => write!(f, "groq"),
88            Self::DeepSeek => write!(f, "deepseek"),
89            Self::Ollama => write!(f, "ollama"),
90            Self::Mistral => write!(f, "mistral"),
91            Self::Together => write!(f, "together"),
92            Self::Fireworks => write!(f, "fireworks"),
93            Self::Cerebras => write!(f, "cerebras"),
94            Self::XAI => write!(f, "xai"),
95            Self::Cohere => write!(f, "cohere"),
96            Self::Bedrock => write!(f, "bedrock"),
97            Self::AzureOpenAi => write!(f, "azure_openai"),
98            Self::Custom(name) => write!(f, "custom({})", name),
99        }
100    }
101}
102
103/// Memory subsystem configuration.
104#[derive(Debug, Clone, Serialize, Deserialize)]
105pub struct MemoryConfig {
106    /// Path to the SQLite database file.
107    pub db_path: String,
108    /// Whether to enable the knowledge graph.
109    #[serde(default = "default_true")]
110    pub knowledge_graph_enabled: bool,
111    /// Maximum number of memory entries to retain.
112    pub max_entries: Option<u64>,
113}
114
115/// Configuration for the public tunnel / base URL used by all channel webhooks.
116#[derive(Debug, Clone, Serialize, Deserialize)]
117pub struct TunnelConfig {
118    /// The public base URL that all channel webhooks share
119    /// (e.g. "https://abc.trycloudflare.com" or "https://channels.yourdomain.com").
120    pub base_url: String,
121    /// How this tunnel was set up: "quick", "named", or "manual".
122    #[serde(default = "default_tunnel_mode")]
123    pub mode: String,
124}
125
/// Serde default for `TunnelConfig::mode`: a tunnel is treated as manually
/// configured unless the config says otherwise.
fn default_tunnel_mode() -> String {
    String::from("manual")
}

130/// Configuration for a communication channel.
131#[derive(Debug, Clone, Serialize, Deserialize)]
132pub struct ChannelConfig {
133    /// Channel type identifier (e.g. "slack", "discord", "webhook").
134    pub channel_type: String,
135    /// Environment variable holding the authentication token.
136    pub token_env: Option<String>,
137    /// Environment variable holding the webhook signing secret (for signature verification).
138    /// Slack: signing secret. Telegram: secret_token header value. Discord: public key.
139    pub webhook_secret_env: Option<String>,
140    /// Allowlisted user/chat IDs. Only these users can interact with fighters.
141    /// Empty list = open access (logs a security warning on startup).
142    #[serde(default)]
143    pub allowed_user_ids: Vec<String>,
144    /// Per-user rate limit in messages per minute. Default: 20.
145    #[serde(default = "default_channel_rate_limit")]
146    pub rate_limit_per_user: u32,
147    /// Additional channel-specific settings.
148    #[serde(default)]
149    pub settings: HashMap<String, serde_json::Value>,
150}
151
/// Serde default for `ChannelConfig::rate_limit_per_user`.
fn default_channel_rate_limit() -> u32 {
    const MESSAGES_PER_MINUTE: u32 = 20;
    MESSAGES_PER_MINUTE
}

156/// Configuration for an MCP (Model Context Protocol) server.
157#[derive(Debug, Clone, Serialize, Deserialize)]
158pub struct McpServerConfig {
159    /// Command to start the MCP server.
160    pub command: String,
161    /// Arguments to pass to the command.
162    #[serde(default)]
163    pub args: Vec<String>,
164    /// Environment variables to set for the MCP server process.
165    #[serde(default)]
166    pub env: HashMap<String, String>,
167}
168
/// Serde helper for boolean fields that are enabled unless explicitly disabled.
fn default_true() -> bool {
    true
}

/// Serde default for `PunchConfig::rate_limit_rpm` (per-IP requests per minute).
fn default_rate_limit_rpm() -> u32 {
    const REQUESTS_PER_MINUTE: u32 = 60;
    REQUESTS_PER_MINUTE
}

177/// Configuration for smart model routing based on query complexity.
178///
179/// When enabled, messages are classified into tiers (cheap / mid / expensive)
180/// using keyword heuristics, and routed to the appropriate model. If a tier's
181/// model is not configured, the default model is used as fallback.
182#[derive(Debug, Clone, Serialize, Deserialize, Default)]
183pub struct ModelRoutingConfig {
184    /// Whether model routing is enabled. When `false`, the default model is
185    /// used for all messages (backward-compatible default).
186    #[serde(default)]
187    pub enabled: bool,
188    /// Model for simple messages (greetings, yes/no answers). Cheap nano-tier.
189    pub cheap: Option<ModelConfig>,
190    /// Model for tool-calling messages (search, email, calendar, etc.).
191    pub mid: Option<ModelConfig>,
192    /// Model for complex reasoning (analysis, comparison, code review, etc.).
193    pub expensive: Option<ModelConfig>,
194}
195
196/// Budget limits and eco mode configuration.
197///
198/// When budget limits are set, fighters automatically enter "eco mode" as
199/// spending approaches the configured thresholds. Eco mode degrades to:
200/// cheap model tier only, no post-bout reflection, minimal tool loading.
201#[derive(Debug, Clone, Serialize, Deserialize)]
202pub struct BudgetConfig {
203    /// Maximum daily cost in USD. When approached, eco mode activates.
204    pub daily_cost_limit_usd: Option<f64>,
205    /// Maximum monthly cost in USD (30-day rolling window).
206    pub monthly_cost_limit_usd: Option<f64>,
207    /// Percentage of any limit at which eco mode activates (default: 80).
208    /// At 100% the fighter is fully blocked.
209    #[serde(default = "default_eco_threshold")]
210    pub eco_mode_threshold_percent: u8,
211}
212
/// Serde default for `BudgetConfig::eco_mode_threshold_percent`.
fn default_eco_threshold() -> u8 {
    const ECO_THRESHOLD_PERCENT: u8 = 80;
    ECO_THRESHOLD_PERCENT
}

217impl Default for BudgetConfig {
218    fn default() -> Self {
219        Self {
220            daily_cost_limit_usd: None,
221            monthly_cost_limit_usd: None,
222            eco_mode_threshold_percent: default_eco_threshold(),
223        }
224    }
225}
226
227impl BudgetConfig {
228    /// Returns true if any budget limit is configured.
229    pub fn has_any_limit(&self) -> bool {
230        self.daily_cost_limit_usd.is_some() || self.monthly_cost_limit_usd.is_some()
231    }
232}
233
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_provider_display_all_variants() {
        assert_eq!(Provider::Anthropic.to_string(), "anthropic");
        assert_eq!(Provider::Google.to_string(), "google");
        assert_eq!(Provider::OpenAI.to_string(), "openai");
        assert_eq!(Provider::Groq.to_string(), "groq");
        assert_eq!(Provider::DeepSeek.to_string(), "deepseek");
        assert_eq!(Provider::Ollama.to_string(), "ollama");
        assert_eq!(Provider::Mistral.to_string(), "mistral");
        assert_eq!(Provider::Together.to_string(), "together");
        assert_eq!(Provider::Fireworks.to_string(), "fireworks");
        assert_eq!(Provider::Cerebras.to_string(), "cerebras");
        assert_eq!(Provider::XAI.to_string(), "xai");
        assert_eq!(Provider::Cohere.to_string(), "cohere");
        assert_eq!(Provider::Bedrock.to_string(), "bedrock");
        assert_eq!(Provider::AzureOpenAi.to_string(), "azure_openai");
        assert_eq!(
            Provider::Custom("my_provider".to_string()).to_string(),
            "custom(my_provider)"
        );
    }

    #[test]
    fn test_provider_serde_roundtrip() {
        let providers = [
            Provider::Anthropic,
            Provider::Google,
            Provider::OpenAI,
            Provider::Groq,
            Provider::DeepSeek,
            Provider::Ollama,
            Provider::Mistral,
            Provider::Together,
            Provider::Fireworks,
            Provider::Cerebras,
            Provider::XAI,
            Provider::Cohere,
            Provider::Bedrock,
            Provider::AzureOpenAi,
            Provider::Custom("test".to_string()),
        ];
        for provider in &providers {
            let json = serde_json::to_string(provider).expect("serialize provider");
            let deser: Provider = serde_json::from_str(&json).expect("deserialize provider");
            assert_eq!(&deser, provider);
        }
    }

    #[test]
    fn test_provider_serde_rename() {
        assert_eq!(
            serde_json::to_string(&Provider::OpenAI).unwrap(),
            "\"openai\""
        );
        assert_eq!(serde_json::to_string(&Provider::XAI).unwrap(), "\"xai\"");
        assert_eq!(
            serde_json::to_string(&Provider::AzureOpenAi).unwrap(),
            "\"azure_openai\""
        );
        assert_eq!(
            serde_json::to_string(&Provider::Anthropic).unwrap(),
            "\"anthropic\""
        );
    }

    #[test]
    fn test_provider_custom_serde_shape() {
        // Newtype variants are externally tagged under the snake_case variant name.
        assert_eq!(
            serde_json::to_string(&Provider::Custom("my_llm".to_string())).unwrap(),
            r#"{"custom":"my_llm"}"#
        );
    }

    #[test]
    fn test_provider_equality() {
        assert_eq!(Provider::Anthropic, Provider::Anthropic);
        assert_ne!(Provider::Anthropic, Provider::Google);
        assert_eq!(
            Provider::Custom("x".to_string()),
            Provider::Custom("x".to_string())
        );
        assert_ne!(
            Provider::Custom("x".to_string()),
            Provider::Custom("y".to_string())
        );
    }

    #[test]
    fn test_model_config_serde_roundtrip() {
        let config = ModelConfig {
            provider: Provider::Anthropic,
            model: "claude-sonnet-4-20250514".to_string(),
            api_key_env: Some("ANTHROPIC_API_KEY".to_string()),
            base_url: None,
            max_tokens: Some(4096),
            temperature: Some(0.7),
        };
        let json = serde_json::to_string(&config).expect("serialize");
        let deser: ModelConfig = serde_json::from_str(&json).expect("deserialize");
        assert_eq!(deser.model, "claude-sonnet-4-20250514");
        assert_eq!(deser.provider, Provider::Anthropic);
        assert_eq!(deser.max_tokens, Some(4096));
        assert_eq!(deser.temperature, Some(0.7));
    }

    #[test]
    fn test_model_config_optional_fields() {
        let config = ModelConfig {
            provider: Provider::Ollama,
            model: "llama3".to_string(),
            api_key_env: None,
            base_url: Some("http://localhost:11434".to_string()),
            max_tokens: None,
            temperature: None,
        };
        let json = serde_json::to_string(&config).expect("serialize");
        let deser: ModelConfig = serde_json::from_str(&json).expect("deserialize");
        assert!(deser.api_key_env.is_none());
        assert_eq!(deser.base_url, Some("http://localhost:11434".to_string()));
        assert!(deser.max_tokens.is_none());
        assert!(deser.temperature.is_none());
    }

    #[test]
    fn test_memory_config_defaults() {
        let json = r#"{"db_path": "/tmp/punch.db"}"#;
        let config: MemoryConfig = serde_json::from_str(json).expect("deserialize");
        assert_eq!(config.db_path, "/tmp/punch.db");
        assert!(config.knowledge_graph_enabled); // default_true
        assert!(config.max_entries.is_none());
    }

    #[test]
    fn test_memory_config_roundtrip() {
        let config = MemoryConfig {
            db_path: "/data/punch.db".to_string(),
            knowledge_graph_enabled: false,
            max_entries: Some(10000),
        };
        let json = serde_json::to_string(&config).expect("serialize");
        let deser: MemoryConfig = serde_json::from_str(&json).expect("deserialize");
        assert_eq!(deser.db_path, "/data/punch.db");
        assert!(!deser.knowledge_graph_enabled);
        assert_eq!(deser.max_entries, Some(10000));
    }

    #[test]
    fn test_channel_config_serde() {
        let config = ChannelConfig {
            channel_type: "slack".to_string(),
            token_env: Some("SLACK_TOKEN".to_string()),
            webhook_secret_env: Some("SLACK_SIGNING_SECRET".to_string()),
            allowed_user_ids: vec!["U123".to_string()],
            rate_limit_per_user: 20,
            settings: HashMap::from([("channel_id".to_string(), serde_json::json!("C123456"))]),
        };
        let json = serde_json::to_string(&config).expect("serialize");
        let deser: ChannelConfig = serde_json::from_str(&json).expect("deserialize");
        assert_eq!(deser.channel_type, "slack");
        assert_eq!(deser.token_env, Some("SLACK_TOKEN".to_string()));
        assert_eq!(deser.settings["channel_id"], "C123456");
    }

    #[test]
    fn test_channel_config_defaults() {
        let json = r#"{"channel_type": "webhook"}"#;
        let config: ChannelConfig = serde_json::from_str(json).expect("deserialize");
        assert_eq!(config.channel_type, "webhook");
        assert!(config.token_env.is_none());
        assert!(config.webhook_secret_env.is_none());
        assert!(config.allowed_user_ids.is_empty());
        assert_eq!(config.rate_limit_per_user, 20);
        assert!(config.settings.is_empty());
    }

    #[test]
    fn test_mcp_server_config_serde() {
        let config = McpServerConfig {
            command: "npx".to_string(),
            args: vec!["-y".to_string(), "@modelcontextprotocol/server".to_string()],
            env: HashMap::from([("NODE_ENV".to_string(), "production".to_string())]),
        };
        let json = serde_json::to_string(&config).expect("serialize");
        let deser: McpServerConfig = serde_json::from_str(&json).expect("deserialize");
        assert_eq!(deser.command, "npx");
        assert_eq!(deser.args.len(), 2);
        assert_eq!(deser.env["NODE_ENV"], "production");
    }

    #[test]
    fn test_mcp_server_config_defaults() {
        let json = r#"{"command": "python"}"#;
        let config: McpServerConfig = serde_json::from_str(json).expect("deserialize");
        assert_eq!(config.command, "python");
        assert!(config.args.is_empty());
        assert!(config.env.is_empty());
    }

    #[test]
    fn test_tunnel_config_serde() {
        let config = TunnelConfig {
            base_url: "https://abc.trycloudflare.com".to_string(),
            mode: "quick".to_string(),
        };
        let json = serde_json::to_string(&config).expect("serialize");
        let deser: TunnelConfig = serde_json::from_str(&json).expect("deserialize");
        assert_eq!(deser.base_url, "https://abc.trycloudflare.com");
        assert_eq!(deser.mode, "quick");
    }

    #[test]
    fn test_tunnel_config_default_mode() {
        let json = r#"{"base_url": "https://example.com"}"#;
        let config: TunnelConfig = serde_json::from_str(json).expect("deserialize");
        assert_eq!(config.mode, "manual");
    }

    #[test]
    fn test_provider_hash() {
        let mut set = std::collections::HashSet::new();
        set.insert(Provider::Anthropic);
        set.insert(Provider::Google);
        set.insert(Provider::Anthropic); // duplicate
        assert_eq!(set.len(), 2);
    }

    #[test]
    fn test_model_routing_config_defaults() {
        let config: ModelRoutingConfig = serde_json::from_str("{}").expect("deserialize");
        assert!(!config.enabled);
        assert!(config.cheap.is_none());
        assert!(config.mid.is_none());
        assert!(config.expensive.is_none());
    }

    #[test]
    fn test_punch_config_minimal_defaults() {
        let json = r#"{
            "api_listen": "127.0.0.1:6660",
            "default_model": {"provider": "anthropic", "model": "claude-sonnet-4-20250514"},
            "memory": {"db_path": "/tmp/punch.db"}
        }"#;
        let config: PunchConfig = serde_json::from_str(json).expect("deserialize");
        assert_eq!(config.api_listen, "127.0.0.1:6660");
        assert!(config.api_key.is_empty());
        assert_eq!(config.rate_limit_rpm, 60);
        assert!(config.tunnel.is_none());
        assert!(config.channels.is_empty());
        assert!(config.mcp_servers.is_empty());
        assert!(!config.model_routing.enabled);
        assert!(!config.budget.has_any_limit());
        assert_eq!(config.budget.eco_mode_threshold_percent, 80);
    }

    #[test]
    fn test_budget_config_defaults() {
        let config = BudgetConfig::default();
        assert!(!config.has_any_limit());
        assert!(config.daily_cost_limit_usd.is_none());
        assert!(config.monthly_cost_limit_usd.is_none());
        assert_eq!(config.eco_mode_threshold_percent, 80);
    }

    #[test]
    fn test_budget_config_has_limit() {
        let config = BudgetConfig {
            daily_cost_limit_usd: Some(1.0),
            ..Default::default()
        };
        assert!(config.has_any_limit());
    }

    #[test]
    fn test_budget_config_serde() {
        let json = r#"{"daily_cost_limit_usd": 5.0, "monthly_cost_limit_usd": 15.0, "eco_mode_threshold_percent": 70}"#;
        let config: BudgetConfig = serde_json::from_str(json).expect("deserialize");
        assert_eq!(config.daily_cost_limit_usd, Some(5.0));
        assert_eq!(config.monthly_cost_limit_usd, Some(15.0));
        assert_eq!(config.eco_mode_threshold_percent, 70);
        assert!(config.has_any_limit());
    }

    #[test]
    fn test_budget_config_serde_empty() {
        let json = "{}";
        let config: BudgetConfig = serde_json::from_str(json).expect("deserialize");
        assert!(!config.has_any_limit());
        assert_eq!(config.eco_mode_threshold_percent, 80);
    }
}