Skip to main content

ohttp_gateway/
config.rs

1use serde::{Deserialize, Serialize};
2use std::collections::HashSet;
3use std::time::Duration;
4
5#[derive(Clone, Debug, Deserialize, Serialize)]
6pub struct AppConfig {
7    // Server configuration
8    pub port: String,
9    pub backend_url: String,
10    pub request_timeout: Duration,
11    pub max_body_size: usize,
12
13    // Key management
14    pub key_rotation_interval: Duration,
15    pub key_retention_period: Duration,
16    pub key_rotation_enabled: bool,
17
18    // Security configuration
19    pub allowed_target_origins: Option<HashSet<String>>,
20    pub target_rewrites: Option<TargetRewriteConfig>,
21    pub rate_limit: Option<RateLimitConfig>,
22
23    // Operational configuration
24    pub metrics_enabled: bool,
25    pub debug_mode: bool,
26    pub log_format: LogFormat,
27    pub log_level: String,
28
29    // OHTTP specific
30    pub custom_request_type: Option<String>,
31    pub custom_response_type: Option<String>,
32    pub seed_secret_key: Option<String>,
33}
34
35#[derive(Clone, Debug, Deserialize, Serialize)]
36pub struct TargetRewriteConfig {
37    pub rewrites: std::collections::HashMap<String, TargetRewrite>,
38}
39
40#[derive(Clone, Debug, Deserialize, Serialize)]
41pub struct TargetRewrite {
42    pub scheme: String,
43    pub host: String,
44}
45
46#[derive(Clone, Debug, Deserialize, Serialize)]
47pub struct RateLimitConfig {
48    pub requests_per_second: u32,
49    pub burst_size: u32,
50    pub by_ip: bool,
51}
52
53#[derive(Clone, Debug, Deserialize, Serialize)]
54#[serde(rename_all = "lowercase")]
55pub enum LogFormat {
56    Default,
57    Json,
58}
59
60impl Default for AppConfig {
61    fn default() -> Self {
62        Self {
63            port: "0.0.0.0:8000".to_string(),
64            backend_url: "http://localhost:8080".to_string(),
65            request_timeout: Duration::from_secs(30),
66            max_body_size: 10 * 1024 * 1024, // 10MB
67            key_rotation_interval: Duration::from_secs(30 * 24 * 60 * 60), // 30 days
68            key_retention_period: Duration::from_secs(7 * 24 * 60 * 60), // 7 days
69            key_rotation_enabled: true,
70            allowed_target_origins: None,
71            target_rewrites: None,
72            rate_limit: None,
73            metrics_enabled: true,
74            debug_mode: false,
75            log_format: LogFormat::Default,
76            log_level: "info".to_string(),
77            custom_request_type: None,
78            custom_response_type: None,
79            seed_secret_key: None,
80        }
81    }
82}
83
84impl AppConfig {
85    pub fn from_env() -> Result<Self, Box<dyn std::error::Error>> {
86        let mut config = Self::default();
87
88        // Basic configuration
89        if let Ok(port) = std::env::var("PORT") {
90            config.port = format!("0.0.0.0:{port}");
91        }
92
93        if let Ok(url) = std::env::var("BACKEND_URL") {
94            config.backend_url = url;
95        }
96
97        if let Ok(timeout) = std::env::var("REQUEST_TIMEOUT") {
98            config.request_timeout = Duration::from_secs(timeout.parse()?);
99        }
100
101        if let Ok(size) = std::env::var("MAX_BODY_SIZE") {
102            config.max_body_size = size.parse()?;
103        }
104
105        // Key management
106        if let Ok(interval) = std::env::var("KEY_ROTATION_INTERVAL") {
107            config.key_rotation_interval = Duration::from_secs(interval.parse()?);
108        }
109
110        if let Ok(period) = std::env::var("KEY_RETENTION_PERIOD") {
111            config.key_retention_period = Duration::from_secs(period.parse()?);
112        }
113
114        if let Ok(enabled) = std::env::var("KEY_ROTATION_ENABLED") {
115            config.key_rotation_enabled = enabled.parse()?;
116        }
117
118        // Security configuration
119        if let Ok(origins) = std::env::var("ALLOWED_TARGET_ORIGINS") {
120            let origins_set: HashSet<String> = origins
121                .split(',')
122                .map(|s| s.trim().to_string())
123                .filter(|s| !s.is_empty())
124                .collect();
125
126            if !origins_set.is_empty() {
127                config.allowed_target_origins = Some(origins_set);
128            }
129        }
130
131        if let Ok(rewrites_json) = std::env::var("TARGET_REWRITES") {
132            let rewrites: std::collections::HashMap<String, TargetRewrite> =
133                serde_json::from_str(&rewrites_json)?;
134            config.target_rewrites = Some(TargetRewriteConfig { rewrites });
135        }
136
137        // Rate limiting
138        if let Ok(rps) = std::env::var("RATE_LIMIT_RPS") {
139            let rate_limit = RateLimitConfig {
140                requests_per_second: rps.parse()?,
141                burst_size: std::env::var("RATE_LIMIT_BURST")
142                    .ok()
143                    .and_then(|s| s.parse().ok())
144                    .unwrap_or(100),
145                by_ip: std::env::var("RATE_LIMIT_BY_IP")
146                    .ok()
147                    .and_then(|s| s.parse().ok())
148                    .unwrap_or(true),
149            };
150            config.rate_limit = Some(rate_limit);
151        }
152
153        // Operational configuration
154        if let Ok(enabled) = std::env::var("METRICS_ENABLED") {
155            config.metrics_enabled = enabled.parse()?;
156        }
157
158        if let Ok(debug) = std::env::var("GATEWAY_DEBUG") {
159            config.debug_mode = debug.parse()?;
160        }
161
162        if let Ok(format) = std::env::var("LOG_FORMAT") {
163            config.log_format = match format.to_lowercase().as_str() {
164                "json" => LogFormat::Json,
165                _ => LogFormat::Default,
166            };
167        }
168
169        if let Ok(level) = std::env::var("LOG_LEVEL") {
170            config.log_level = level;
171        }
172
173        // OHTTP specific
174        if let Ok(req_type) = std::env::var("CUSTOM_REQUEST_TYPE") {
175            config.custom_request_type = Some(req_type);
176        }
177
178        if let Ok(resp_type) = std::env::var("CUSTOM_RESPONSE_TYPE") {
179            config.custom_response_type = Some(resp_type);
180        }
181
182        if let Ok(seed) = std::env::var("SEED_SECRET_KEY") {
183            config.seed_secret_key = Some(seed);
184        }
185
186        // Validate configuration
187        config.validate()?;
188
189        Ok(config)
190    }
191
192    fn validate(&self) -> Result<(), Box<dyn std::error::Error>> {
193        // Validate key rotation settings
194        if self.key_retention_period > self.key_rotation_interval {
195            return Err("Key retention period cannot be longer than rotation interval".into());
196        }
197
198        // Validate custom content types
199        match (&self.custom_request_type, &self.custom_response_type) {
200            (Some(req), Some(resp)) if req == resp => {
201                return Err("Request and response content types must be different".into());
202            }
203            (Some(_), None) | (None, Some(_)) => {
204                return Err("Both custom request and response types must be specified".into());
205            }
206            _ => {}
207        }
208
209        // Validate seed if provided
210        if let Some(seed) = &self.seed_secret_key {
211            let decoded =
212                hex::decode(seed).map_err(|_| "SEED_SECRET_KEY must be a hex-encoded string")?;
213
214            if decoded.len() < 32 {
215                return Err("SEED_SECRET_KEY must be at least 32 bytes (64 hex characters)".into());
216            }
217        }
218
219        Ok(())
220    }
221
222    /// Check if a target origin is allowed
223    pub fn is_origin_allowed(&self, origin: &str) -> bool {
224        match &self.allowed_target_origins {
225            Some(allowed) => allowed.contains(origin),
226            None => true, // No restrictions if not configured
227        }
228    }
229
230    /// Get rewrite configuration for a host
231    pub fn get_rewrite(&self, host: &str) -> Option<&TargetRewrite> {
232        self.target_rewrites
233            .as_ref()
234            .and_then(|config| config.rewrites.get(host))
235    }
236}