1use serde::{Deserialize, Serialize};
9use std::env;
10use std::path::Path;
11
12#[derive(Debug, Clone, Serialize, Deserialize)]
14pub struct ProxyConfig {
15 pub transport: TransportConfig,
16 pub mpl: MplConfig,
17 pub observability: ObservabilityConfig,
18 #[serde(default)]
19 pub routing: Vec<RouteConfig>,
20 #[serde(default)]
21 pub limits: ResourceLimits,
22}
23
24impl Default for ProxyConfig {
25 fn default() -> Self {
26 Self {
27 transport: TransportConfig::default(),
28 mpl: MplConfig::default(),
29 observability: ObservabilityConfig::default(),
30 routing: Vec::new(),
31 limits: ResourceLimits::default(),
32 }
33 }
34}
35
36impl ProxyConfig {
37 pub fn load<P: AsRef<Path>>(path: P) -> anyhow::Result<Self> {
39 let contents = std::fs::read_to_string(path)?;
40 let config: Self = serde_yaml::from_str(&contents)?;
41 Ok(config)
42 }
43
44 pub fn load_with_env<P: AsRef<Path>>(path: P) -> anyhow::Result<Self> {
59 let mut config = Self::load(path).unwrap_or_default();
60 config.apply_env_overrides();
61 Ok(config)
62 }
63
64 pub fn apply_env_overrides(&mut self) {
66 if let Ok(val) = env::var("MPL_LISTEN") {
68 self.transport.listen = val;
69 }
70 if let Ok(val) = env::var("MPL_UPSTREAM") {
71 self.transport.upstream = val;
72 }
73 if let Ok(val) = env::var("MPL_CONNECT_TIMEOUT_MS") {
74 if let Ok(ms) = val.parse() {
75 self.transport.connect_timeout_ms = ms;
76 }
77 }
78 if let Ok(val) = env::var("MPL_REQUEST_TIMEOUT_MS") {
79 if let Ok(ms) = val.parse() {
80 self.transport.request_timeout_ms = ms;
81 }
82 }
83
84 if let Ok(val) = env::var("MPL_REGISTRY") {
86 self.mpl.registry = val;
87 }
88 if let Ok(val) = env::var("MPL_MODE") {
89 self.mpl.mode = match val.to_lowercase().as_str() {
90 "strict" => ProxyMode::Strict,
91 _ => ProxyMode::Transparent,
92 };
93 }
94 if let Ok(val) = env::var("MPL_PROFILE") {
95 self.mpl.required_profile = Some(val);
96 }
97 if let Ok(val) = env::var("MPL_ENFORCE_SCHEMA") {
98 self.mpl.enforce_schema = val.to_lowercase() == "true";
99 }
100 if let Ok(val) = env::var("MPL_ENFORCE_ASSERTIONS") {
101 self.mpl.enforce_assertions = val.to_lowercase() == "true";
102 }
103
104 if let Ok(val) = env::var("MPL_METRICS_PORT") {
106 if let Ok(port) = val.parse() {
107 self.observability.metrics_port = Some(port);
108 }
109 }
110 if let Ok(val) = env::var("MPL_LOG_LEVEL") {
111 self.observability.log_level = match val.to_lowercase().as_str() {
112 "trace" => LogLevel::Trace,
113 "debug" => LogLevel::Debug,
114 "warn" => LogLevel::Warn,
115 "error" => LogLevel::Error,
116 _ => LogLevel::Info,
117 };
118 }
119
120 if let Ok(val) = env::var("MPL_MAX_CONNECTIONS") {
122 if let Ok(n) = val.parse() {
123 self.limits.max_connections = n;
124 }
125 }
126 if let Ok(val) = env::var("MPL_RATE_LIMIT") {
127 if let Ok(n) = val.parse() {
128 self.limits.rate_limit_per_second = n;
129 }
130 }
131 }
132
133 pub fn save<P: AsRef<Path>>(&self, path: P) -> anyhow::Result<()> {
135 let contents = serde_yaml::to_string(self)?;
136 std::fs::write(path, contents)?;
137 Ok(())
138 }
139}
140
141#[derive(Debug, Clone, Serialize, Deserialize)]
143pub struct TransportConfig {
144 pub listen: String,
146
147 pub upstream: String,
149
150 #[serde(default)]
152 pub protocol: Protocol,
153
154 #[serde(default = "default_connect_timeout")]
156 pub connect_timeout_ms: u64,
157
158 #[serde(default = "default_request_timeout")]
160 pub request_timeout_ms: u64,
161
162 #[serde(default = "default_idle_timeout")]
164 pub idle_timeout_ms: u64,
165
166 #[serde(default = "default_max_retries")]
168 pub max_retries: u32,
169
170 #[serde(default = "default_max_body_size")]
172 pub max_body_size: usize,
173}
174
/// Serde default for `TransportConfig::connect_timeout_ms`: 5 seconds.
fn default_connect_timeout() -> u64 {
    5_000
}

/// Serde default for `TransportConfig::request_timeout_ms`: 30 seconds.
fn default_request_timeout() -> u64 {
    30_000
}

/// Serde default for `TransportConfig::idle_timeout_ms`: 60 seconds.
fn default_idle_timeout() -> u64 {
    60_000
}

/// Serde default for `TransportConfig::max_retries`.
fn default_max_retries() -> u32 {
    3
}

/// Serde default for `TransportConfig::max_body_size`: `10 * 2^20`
/// (10 MiB if the unit is bytes).
fn default_max_body_size() -> usize {
    10 << 20
}
194
195impl Default for TransportConfig {
196 fn default() -> Self {
197 Self {
198 listen: "0.0.0.0:9443".to_string(),
199 upstream: "localhost:8080".to_string(),
200 protocol: Protocol::Http,
201 connect_timeout_ms: default_connect_timeout(),
202 request_timeout_ms: default_request_timeout(),
203 idle_timeout_ms: default_idle_timeout(),
204 max_retries: default_max_retries(),
205 max_body_size: default_max_body_size(),
206 }
207 }
208}
209
210impl TransportConfig {
211 pub fn connect_timeout(&self) -> std::time::Duration {
213 std::time::Duration::from_millis(self.connect_timeout_ms)
214 }
215
216 pub fn request_timeout(&self) -> std::time::Duration {
218 std::time::Duration::from_millis(self.request_timeout_ms)
219 }
220
221 pub fn idle_timeout(&self) -> std::time::Duration {
223 std::time::Duration::from_millis(self.idle_timeout_ms)
224 }
225}
226
227#[derive(Debug, Clone, Copy, Default, Serialize, Deserialize)]
229#[serde(rename_all = "lowercase")]
230pub enum Protocol {
231 #[default]
232 Http,
233 WebSocket,
234 Grpc,
235}
236
237#[derive(Debug, Clone, Serialize, Deserialize)]
239pub struct MplConfig {
240 #[serde(default = "default_registry")]
242 pub registry: String,
243
244 #[serde(default)]
246 pub mode: ProxyMode,
247
248 pub required_profile: Option<String>,
250
251 #[serde(default = "default_true")]
253 pub enforce_schema: bool,
254
255 #[serde(default = "default_true")]
257 pub enforce_assertions: bool,
258
259 #[serde(default)]
261 pub policy_engine: bool,
262}
263
/// Serde default for `MplConfig::registry`: the `registry` directory of the
/// Skelf-Research/mpl GitHub repository.
fn default_registry() -> String {
    String::from("https://github.com/Skelf-Research/mpl/raw/main/registry")
}

/// Serde default for boolean fields that should be enabled unless stated
/// otherwise.
fn default_true() -> bool {
    true
}
271
272impl Default for MplConfig {
273 fn default() -> Self {
274 Self {
275 registry: default_registry(),
276 mode: ProxyMode::Transparent,
277 required_profile: Some("qom-basic".to_string()),
278 enforce_schema: true,
279 enforce_assertions: true,
280 policy_engine: false,
281 }
282 }
283}
284
285#[derive(Debug, Clone, Copy, Default, Serialize, Deserialize)]
287#[serde(rename_all = "lowercase")]
288pub enum ProxyMode {
289 #[default]
291 Transparent,
292 Strict,
294}
295
296#[derive(Debug, Clone, Serialize, Deserialize)]
298pub struct ObservabilityConfig {
299 pub metrics_port: Option<u16>,
301
302 #[serde(default)]
304 pub metrics_format: MetricsFormat,
305
306 #[serde(default)]
308 pub logs: LogOutput,
309
310 #[serde(default)]
312 pub log_format: LogFormat,
313
314 #[serde(default)]
316 pub log_level: LogLevel,
317}
318
319impl Default for ObservabilityConfig {
320 fn default() -> Self {
321 Self {
322 metrics_port: Some(9100),
323 metrics_format: MetricsFormat::Prometheus,
324 logs: LogOutput::Stdout,
325 log_format: LogFormat::Json,
326 log_level: LogLevel::Info,
327 }
328 }
329}
330
331#[derive(Debug, Clone, Copy, Default, Serialize, Deserialize)]
332#[serde(rename_all = "lowercase")]
333pub enum MetricsFormat {
334 #[default]
335 Prometheus,
336 OpenTelemetry,
337}
338
339#[derive(Debug, Clone, Copy, Default, Serialize, Deserialize)]
340#[serde(rename_all = "lowercase")]
341pub enum LogOutput {
342 #[default]
343 Stdout,
344 Stderr,
345 File,
346}
347
348#[derive(Debug, Clone, Copy, Default, Serialize, Deserialize)]
349#[serde(rename_all = "lowercase")]
350pub enum LogFormat {
351 #[default]
352 Json,
353 Text,
354}
355
356#[derive(Debug, Clone, Copy, Default, Serialize, Deserialize)]
357#[serde(rename_all = "lowercase")]
358pub enum LogLevel {
359 Trace,
360 Debug,
361 #[default]
362 Info,
363 Warn,
364 Error,
365}
366
367#[derive(Debug, Clone, Serialize, Deserialize)]
369pub struct RouteConfig {
370 pub stype_pattern: String,
372
373 pub upstream: String,
375}
376
377#[derive(Debug, Clone, Serialize, Deserialize)]
379pub struct ResourceLimits {
380 #[serde(default = "default_max_connections")]
382 pub max_connections: usize,
383
384 #[serde(default = "default_rate_limit")]
386 pub rate_limit_per_second: u32,
387
388 #[serde(default = "default_burst_size")]
390 pub burst_size: u32,
391
392 #[serde(default = "default_max_pending")]
394 pub max_pending_requests: usize,
395
396 #[serde(default = "default_failure_threshold")]
398 pub failure_threshold: u32,
399
400 #[serde(default = "default_recovery_time")]
402 pub recovery_time_ms: u64,
403}
404
/// Serde default for `ResourceLimits::max_connections`.
fn default_max_connections() -> usize {
    10_000
}

/// Serde default for `ResourceLimits::rate_limit_per_second`.
fn default_rate_limit() -> u32 {
    100
}

/// Serde default for `ResourceLimits::burst_size`.
fn default_burst_size() -> u32 {
    50
}

/// Serde default for `ResourceLimits::max_pending_requests`.
fn default_max_pending() -> usize {
    1_000
}

/// Serde default for `ResourceLimits::failure_threshold`.
fn default_failure_threshold() -> u32 {
    5
}

/// Serde default for `ResourceLimits::recovery_time_ms`: 30 seconds.
fn default_recovery_time() -> u64 {
    30_000
}
428
429impl Default for ResourceLimits {
430 fn default() -> Self {
431 Self {
432 max_connections: default_max_connections(),
433 rate_limit_per_second: default_rate_limit(),
434 burst_size: default_burst_size(),
435 max_pending_requests: default_max_pending(),
436 failure_threshold: default_failure_threshold(),
437 recovery_time_ms: default_recovery_time(),
438 }
439 }
440}
441
442impl ResourceLimits {
443 pub fn recovery_time(&self) -> std::time::Duration {
445 std::time::Duration::from_millis(self.recovery_time_ms)
446 }
447}