mpl_proxy/
config.rs

1//! Proxy configuration
2//!
3//! Configuration can be loaded from:
4//! 1. YAML file (default: mpl-config.yaml)
5//! 2. Environment variables (MPL_* prefix)
6//! 3. CLI arguments (highest priority)
7
use std::env;
use std::path::Path;

use anyhow::Context;
use serde::{Deserialize, Serialize};
11
12/// Main proxy configuration
13#[derive(Debug, Clone, Serialize, Deserialize)]
14pub struct ProxyConfig {
15    pub transport: TransportConfig,
16    pub mpl: MplConfig,
17    pub observability: ObservabilityConfig,
18    #[serde(default)]
19    pub routing: Vec<RouteConfig>,
20    #[serde(default)]
21    pub limits: ResourceLimits,
22}
23
24impl Default for ProxyConfig {
25    fn default() -> Self {
26        Self {
27            transport: TransportConfig::default(),
28            mpl: MplConfig::default(),
29            observability: ObservabilityConfig::default(),
30            routing: Vec::new(),
31            limits: ResourceLimits::default(),
32        }
33    }
34}
35
36impl ProxyConfig {
37    /// Load configuration from a YAML file
38    pub fn load<P: AsRef<Path>>(path: P) -> anyhow::Result<Self> {
39        let contents = std::fs::read_to_string(path)?;
40        let config: Self = serde_yaml::from_str(&contents)?;
41        Ok(config)
42    }
43
44    /// Load configuration with environment variable overrides
45    ///
46    /// Environment variables (all optional):
47    /// - MPL_LISTEN: Listen address (e.g., "0.0.0.0:9443")
48    /// - MPL_UPSTREAM: Upstream server address
49    /// - MPL_REGISTRY: Registry path or URL
50    /// - MPL_MODE: "transparent" or "strict"
51    /// - MPL_PROFILE: QoM profile name
52    /// - MPL_ENFORCE_SCHEMA: "true" or "false"
53    /// - MPL_ENFORCE_ASSERTIONS: "true" or "false"
54    /// - MPL_CONNECT_TIMEOUT_MS: Connection timeout
55    /// - MPL_REQUEST_TIMEOUT_MS: Request timeout
56    /// - MPL_METRICS_PORT: Metrics server port
57    /// - MPL_LOG_LEVEL: Log level (trace, debug, info, warn, error)
58    pub fn load_with_env<P: AsRef<Path>>(path: P) -> anyhow::Result<Self> {
59        let mut config = Self::load(path).unwrap_or_default();
60        config.apply_env_overrides();
61        Ok(config)
62    }
63
64    /// Apply environment variable overrides to configuration
65    pub fn apply_env_overrides(&mut self) {
66        // Transport settings
67        if let Ok(val) = env::var("MPL_LISTEN") {
68            self.transport.listen = val;
69        }
70        if let Ok(val) = env::var("MPL_UPSTREAM") {
71            self.transport.upstream = val;
72        }
73        if let Ok(val) = env::var("MPL_CONNECT_TIMEOUT_MS") {
74            if let Ok(ms) = val.parse() {
75                self.transport.connect_timeout_ms = ms;
76            }
77        }
78        if let Ok(val) = env::var("MPL_REQUEST_TIMEOUT_MS") {
79            if let Ok(ms) = val.parse() {
80                self.transport.request_timeout_ms = ms;
81            }
82        }
83
84        // MPL settings
85        if let Ok(val) = env::var("MPL_REGISTRY") {
86            self.mpl.registry = val;
87        }
88        if let Ok(val) = env::var("MPL_MODE") {
89            self.mpl.mode = match val.to_lowercase().as_str() {
90                "strict" => ProxyMode::Strict,
91                _ => ProxyMode::Transparent,
92            };
93        }
94        if let Ok(val) = env::var("MPL_PROFILE") {
95            self.mpl.required_profile = Some(val);
96        }
97        if let Ok(val) = env::var("MPL_ENFORCE_SCHEMA") {
98            self.mpl.enforce_schema = val.to_lowercase() == "true";
99        }
100        if let Ok(val) = env::var("MPL_ENFORCE_ASSERTIONS") {
101            self.mpl.enforce_assertions = val.to_lowercase() == "true";
102        }
103
104        // Observability settings
105        if let Ok(val) = env::var("MPL_METRICS_PORT") {
106            if let Ok(port) = val.parse() {
107                self.observability.metrics_port = Some(port);
108            }
109        }
110        if let Ok(val) = env::var("MPL_LOG_LEVEL") {
111            self.observability.log_level = match val.to_lowercase().as_str() {
112                "trace" => LogLevel::Trace,
113                "debug" => LogLevel::Debug,
114                "warn" => LogLevel::Warn,
115                "error" => LogLevel::Error,
116                _ => LogLevel::Info,
117            };
118        }
119
120        // Resource limits
121        if let Ok(val) = env::var("MPL_MAX_CONNECTIONS") {
122            if let Ok(n) = val.parse() {
123                self.limits.max_connections = n;
124            }
125        }
126        if let Ok(val) = env::var("MPL_RATE_LIMIT") {
127            if let Ok(n) = val.parse() {
128                self.limits.rate_limit_per_second = n;
129            }
130        }
131    }
132
133    /// Save configuration to a YAML file
134    pub fn save<P: AsRef<Path>>(&self, path: P) -> anyhow::Result<()> {
135        let contents = serde_yaml::to_string(self)?;
136        std::fs::write(path, contents)?;
137        Ok(())
138    }
139}
140
141/// Transport configuration
142#[derive(Debug, Clone, Serialize, Deserialize)]
143pub struct TransportConfig {
144    /// Listen address (e.g., "0.0.0.0:9443")
145    pub listen: String,
146
147    /// Upstream server address (e.g., "mcp-server:8080")
148    pub upstream: String,
149
150    /// Protocol type
151    #[serde(default)]
152    pub protocol: Protocol,
153
154    /// Connection timeout in milliseconds
155    #[serde(default = "default_connect_timeout")]
156    pub connect_timeout_ms: u64,
157
158    /// Request timeout in milliseconds
159    #[serde(default = "default_request_timeout")]
160    pub request_timeout_ms: u64,
161
162    /// Idle connection timeout in milliseconds
163    #[serde(default = "default_idle_timeout")]
164    pub idle_timeout_ms: u64,
165
166    /// Maximum number of retries for transient failures
167    #[serde(default = "default_max_retries")]
168    pub max_retries: u32,
169
170    /// Maximum request body size in bytes
171    #[serde(default = "default_max_body_size")]
172    pub max_body_size: usize,
173}
174
/// Serde default for `connect_timeout_ms`: 5 seconds.
fn default_connect_timeout() -> u64 {
    5 * 1_000
}

/// Serde default for `request_timeout_ms`: 30 seconds.
fn default_request_timeout() -> u64 {
    30 * 1_000
}

/// Serde default for `idle_timeout_ms`: 60 seconds.
fn default_idle_timeout() -> u64 {
    60 * 1_000
}

/// Serde default for `max_retries`.
fn default_max_retries() -> u32 {
    3
}

/// Serde default for `max_body_size`: 10 MiB.
fn default_max_body_size() -> usize {
    10 << 20
}
194
195impl Default for TransportConfig {
196    fn default() -> Self {
197        Self {
198            listen: "0.0.0.0:9443".to_string(),
199            upstream: "localhost:8080".to_string(),
200            protocol: Protocol::Http,
201            connect_timeout_ms: default_connect_timeout(),
202            request_timeout_ms: default_request_timeout(),
203            idle_timeout_ms: default_idle_timeout(),
204            max_retries: default_max_retries(),
205            max_body_size: default_max_body_size(),
206        }
207    }
208}
209
210impl TransportConfig {
211    /// Get connect timeout as Duration
212    pub fn connect_timeout(&self) -> std::time::Duration {
213        std::time::Duration::from_millis(self.connect_timeout_ms)
214    }
215
216    /// Get request timeout as Duration
217    pub fn request_timeout(&self) -> std::time::Duration {
218        std::time::Duration::from_millis(self.request_timeout_ms)
219    }
220
221    /// Get idle timeout as Duration
222    pub fn idle_timeout(&self) -> std::time::Duration {
223        std::time::Duration::from_millis(self.idle_timeout_ms)
224    }
225}
226
227/// Supported protocols
228#[derive(Debug, Clone, Copy, Default, Serialize, Deserialize)]
229#[serde(rename_all = "lowercase")]
230pub enum Protocol {
231    #[default]
232    Http,
233    WebSocket,
234    Grpc,
235}
236
237/// MPL-specific configuration
238#[derive(Debug, Clone, Serialize, Deserialize)]
239pub struct MplConfig {
240    /// Registry URL
241    #[serde(default = "default_registry")]
242    pub registry: String,
243
244    /// Proxy mode
245    #[serde(default)]
246    pub mode: ProxyMode,
247
248    /// Required QoM profile
249    pub required_profile: Option<String>,
250
251    /// Enforce schema validation
252    #[serde(default = "default_true")]
253    pub enforce_schema: bool,
254
255    /// Enforce assertion checks
256    #[serde(default = "default_true")]
257    pub enforce_assertions: bool,
258
259    /// Enable policy engine
260    #[serde(default)]
261    pub policy_engine: bool,
262}
263
/// Serde default for `MplConfig::registry`.
fn default_registry() -> String {
    String::from("https://github.com/Skelf-Research/mpl/raw/main/registry")
}

/// Serde default for boolean fields that are enabled unless configured otherwise.
fn default_true() -> bool {
    !false
}
271
272impl Default for MplConfig {
273    fn default() -> Self {
274        Self {
275            registry: default_registry(),
276            mode: ProxyMode::Transparent,
277            required_profile: Some("qom-basic".to_string()),
278            enforce_schema: true,
279            enforce_assertions: true,
280            policy_engine: false,
281        }
282    }
283}
284
285/// Proxy mode
286#[derive(Debug, Clone, Copy, Default, Serialize, Deserialize)]
287#[serde(rename_all = "lowercase")]
288pub enum ProxyMode {
289    /// Log only, don't block invalid requests
290    #[default]
291    Transparent,
292    /// Block requests that fail validation
293    Strict,
294}
295
296/// Observability configuration
297#[derive(Debug, Clone, Serialize, Deserialize)]
298pub struct ObservabilityConfig {
299    /// Metrics port (Prometheus)
300    pub metrics_port: Option<u16>,
301
302    /// Metrics format
303    #[serde(default)]
304    pub metrics_format: MetricsFormat,
305
306    /// Log output
307    #[serde(default)]
308    pub logs: LogOutput,
309
310    /// Log format
311    #[serde(default)]
312    pub log_format: LogFormat,
313
314    /// Log level
315    #[serde(default)]
316    pub log_level: LogLevel,
317}
318
319impl Default for ObservabilityConfig {
320    fn default() -> Self {
321        Self {
322            metrics_port: Some(9100),
323            metrics_format: MetricsFormat::Prometheus,
324            logs: LogOutput::Stdout,
325            log_format: LogFormat::Json,
326            log_level: LogLevel::Info,
327        }
328    }
329}
330
331#[derive(Debug, Clone, Copy, Default, Serialize, Deserialize)]
332#[serde(rename_all = "lowercase")]
333pub enum MetricsFormat {
334    #[default]
335    Prometheus,
336    OpenTelemetry,
337}
338
339#[derive(Debug, Clone, Copy, Default, Serialize, Deserialize)]
340#[serde(rename_all = "lowercase")]
341pub enum LogOutput {
342    #[default]
343    Stdout,
344    Stderr,
345    File,
346}
347
348#[derive(Debug, Clone, Copy, Default, Serialize, Deserialize)]
349#[serde(rename_all = "lowercase")]
350pub enum LogFormat {
351    #[default]
352    Json,
353    Text,
354}
355
356#[derive(Debug, Clone, Copy, Default, Serialize, Deserialize)]
357#[serde(rename_all = "lowercase")]
358pub enum LogLevel {
359    Trace,
360    Debug,
361    #[default]
362    Info,
363    Warn,
364    Error,
365}
366
367/// Route configuration for SType-based routing
368#[derive(Debug, Clone, Serialize, Deserialize)]
369pub struct RouteConfig {
370    /// SType pattern (e.g., "org.calendar.*")
371    pub stype_pattern: String,
372
373    /// Target upstream for matching requests
374    pub upstream: String,
375}
376
377/// Resource limits configuration
378#[derive(Debug, Clone, Serialize, Deserialize)]
379pub struct ResourceLimits {
380    /// Maximum concurrent connections
381    #[serde(default = "default_max_connections")]
382    pub max_connections: usize,
383
384    /// Maximum requests per second (per client IP)
385    #[serde(default = "default_rate_limit")]
386    pub rate_limit_per_second: u32,
387
388    /// Burst size for rate limiting
389    #[serde(default = "default_burst_size")]
390    pub burst_size: u32,
391
392    /// Maximum pending requests in queue
393    #[serde(default = "default_max_pending")]
394    pub max_pending_requests: usize,
395
396    /// Circuit breaker: failure threshold before opening
397    #[serde(default = "default_failure_threshold")]
398    pub failure_threshold: u32,
399
400    /// Circuit breaker: recovery time in milliseconds
401    #[serde(default = "default_recovery_time")]
402    pub recovery_time_ms: u64,
403}
404
/// Serde default for `max_connections`.
fn default_max_connections() -> usize {
    10_000
}

/// Serde default for `rate_limit_per_second`.
fn default_rate_limit() -> u32 {
    100
}

/// Serde default for `burst_size`.
fn default_burst_size() -> u32 {
    50
}

/// Serde default for `max_pending_requests`.
fn default_max_pending() -> usize {
    1_000
}

/// Serde default for `failure_threshold`.
fn default_failure_threshold() -> u32 {
    5
}

/// Serde default for `recovery_time_ms`: 30 seconds.
fn default_recovery_time() -> u64 {
    30 * 1_000
}
428
429impl Default for ResourceLimits {
430    fn default() -> Self {
431        Self {
432            max_connections: default_max_connections(),
433            rate_limit_per_second: default_rate_limit(),
434            burst_size: default_burst_size(),
435            max_pending_requests: default_max_pending(),
436            failure_threshold: default_failure_threshold(),
437            recovery_time_ms: default_recovery_time(),
438        }
439    }
440}
441
442impl ResourceLimits {
443    /// Get recovery time as Duration
444    pub fn recovery_time(&self) -> std::time::Duration {
445        std::time::Duration::from_millis(self.recovery_time_ms)
446    }
447}