Skip to main content

structured_proxy/
config.rs

//! YAML-based proxy configuration.
//!
//! All product-specific behavior is driven by config, not code.
//! The same binary with a different YAML file yields a different product proxy.
5
6use serde::Deserialize;
7use std::path::PathBuf;
8
9/// Top-level proxy configuration (loaded from YAML).
10#[derive(Debug, Clone, Deserialize)]
11pub struct ProxyConfig {
12    /// Upstream gRPC service(s).
13    pub upstream: UpstreamConfig,
14
15    /// Proto descriptor sources.
16    #[serde(default, deserialize_with = "deserialize_descriptor_sources")]
17    pub descriptors: Vec<DescriptorSource>,
18
19    /// Listen addresses.
20    #[serde(default)]
21    pub listen: ListenConfig,
22
23    /// Service identity (for health endpoint, metrics namespace).
24    #[serde(default)]
25    pub service: ServiceConfig,
26
27    /// Path aliases (e.g., /oauth2/* → /v1/oauth2/*).
28    #[serde(default)]
29    pub aliases: Vec<AliasConfig>,
30
31    /// OpenAPI generation.
32    #[serde(default)]
33    pub openapi: Option<OpenApiConfig>,
34
35    /// Auth configuration (JWT, forward auth, AuthZ).
36    #[serde(default)]
37    pub auth: Option<AuthConfig>,
38
39    /// Rate limiting (Shield).
40    #[serde(default)]
41    pub shield: Option<ShieldConfig>,
42
43    /// OIDC discovery (optional — for IdP proxies).
44    #[serde(default)]
45    pub oidc_discovery: Option<OidcDiscoveryConfig>,
46
47    /// Maintenance mode.
48    #[serde(default)]
49    pub maintenance: MaintenanceConfig,
50
51    /// CORS configuration.
52    #[serde(default)]
53    pub cors: CorsConfig,
54
55    /// Logging.
56    #[serde(default)]
57    pub logging: LoggingConfig,
58
59    /// Metrics endpoint classification (path patterns → class labels).
60    #[serde(default)]
61    pub metrics_classes: Vec<MetricsClassConfig>,
62
63    /// Headers to forward from HTTP to gRPC metadata.
64    #[serde(default = "default_forwarded_headers")]
65    pub forwarded_headers: Vec<String>,
66}
67
/// Header names forwarded to gRPC metadata when `forwarded_headers` is not
/// set in the YAML. All names are lowercase.
fn default_forwarded_headers() -> Vec<String> {
    const NAMES: [&str; 9] = [
        "authorization",
        "dpop",
        "x-request-id",
        "x-forwarded-for",
        "x-forwarded-proto",
        "x-real-ip",
        "accept-language",
        "user-agent",
        "idempotency-key",
    ];
    NAMES.iter().map(|name| (*name).to_owned()).collect()
}
81
82/// Upstream gRPC service configuration.
83#[derive(Debug, Clone, Deserialize)]
84pub struct UpstreamConfig {
85    /// gRPC upstream address (e.g., "http://localhost:4180").
86    pub default: String,
87}
88
/// Descriptor loading source.
///
/// This type does not implement `Deserialize` itself; YAML input goes through
/// `DescriptorSourceYaml`, which covers only `File` and `Reflection`.
#[derive(Debug, Clone)]
pub enum DescriptorSource {
    /// Pre-compiled descriptor file.
    File { file: PathBuf },
    /// gRPC server reflection (development mode).
    Reflection { reflection: String },
    /// Embedded bytes (set programmatically, not from YAML).
    Embedded { bytes: &'static [u8] },
}
99
100/// Helper for YAML deserialization (only File and Reflection variants).
101#[derive(Debug, Clone, Deserialize)]
102#[serde(untagged)]
103enum DescriptorSourceYaml {
104    File { file: PathBuf },
105    Reflection { reflection: String },
106}
107
108impl From<DescriptorSourceYaml> for DescriptorSource {
109    fn from(yaml: DescriptorSourceYaml) -> Self {
110        match yaml {
111            DescriptorSourceYaml::File { file } => DescriptorSource::File { file },
112            DescriptorSourceYaml::Reflection { reflection } => {
113                DescriptorSource::Reflection { reflection }
114            }
115        }
116    }
117}
118
119fn deserialize_descriptor_sources<'de, D>(
120    deserializer: D,
121) -> std::result::Result<Vec<DescriptorSource>, D::Error>
122where
123    D: serde::Deserializer<'de>,
124{
125    let yaml_sources: Vec<DescriptorSourceYaml> = Vec::deserialize(deserializer)?;
126    Ok(yaml_sources.into_iter().map(Into::into).collect())
127}
128
129/// Listen address configuration.
130#[derive(Debug, Clone, Deserialize)]
131pub struct ListenConfig {
132    /// HTTP listen address (default: "0.0.0.0:8080").
133    #[serde(default = "default_http_listen")]
134    pub http: String,
135}
136
/// Default value for `ListenConfig::http`.
fn default_http_listen() -> String {
    String::from("0.0.0.0:8080")
}
140
141impl Default for ListenConfig {
142    fn default() -> Self {
143        Self {
144            http: default_http_listen(),
145        }
146    }
147}
148
149/// Service identity.
150#[derive(Debug, Clone, Deserialize)]
151pub struct ServiceConfig {
152    /// Service name (appears in /health response and metrics namespace).
153    #[serde(default = "default_service_name")]
154    pub name: String,
155}
156
/// Default value for `ServiceConfig::name`.
fn default_service_name() -> String {
    String::from("structured-proxy")
}
160
161impl Default for ServiceConfig {
162    fn default() -> Self {
163        Self {
164            name: default_service_name(),
165        }
166    }
167}
168
169/// Path alias (rewrite before routing).
170#[derive(Debug, Clone, Deserialize)]
171pub struct AliasConfig {
172    pub from: String,
173    pub to: String,
174}
175
176/// OpenAPI generation config.
177#[derive(Debug, Clone, Deserialize)]
178pub struct OpenApiConfig {
179    #[serde(default = "default_true")]
180    pub enabled: bool,
181    /// Path for OpenAPI JSON spec (default: "/openapi.json").
182    #[serde(default = "default_openapi_path")]
183    pub path: String,
184    /// Path for interactive API docs UI (default: "/docs").
185    #[serde(default = "default_docs_path")]
186    pub docs_path: String,
187    #[serde(default)]
188    pub title: Option<String>,
189    #[serde(default)]
190    pub version: Option<String>,
191}
192
/// Default value for `OpenApiConfig::path`.
fn default_openapi_path() -> String {
    String::from("/openapi.json")
}
196
/// Default value for `OpenApiConfig::docs_path`.
fn default_docs_path() -> String {
    String::from("/docs")
}
200
/// Serde default helper for boolean fields that should be on unless
/// explicitly disabled.
fn default_true() -> bool {
    true
}
204
205/// Auth configuration.
206#[derive(Debug, Clone, Deserialize)]
207pub struct AuthConfig {
208    /// Auth mode: "none", "jwt", "api_key".
209    #[serde(default = "default_auth_mode")]
210    pub mode: String,
211
212    /// JWT validation config.
213    #[serde(default)]
214    pub jwt: Option<JwtConfig>,
215
216    /// Forward auth endpoint.
217    #[serde(default)]
218    pub forward_auth: Option<ForwardAuthConfig>,
219
220    /// AuthZ integration (optional gRPC call).
221    #[serde(default)]
222    pub authz: Option<AuthzConfig>,
223
224    /// BFF (Backend-for-Frontend) session config.
225    #[serde(default)]
226    pub bff: Option<BffConfig>,
227}
228
/// Default value for `AuthConfig::mode`.
fn default_auth_mode() -> String {
    String::from("none")
}
232
233/// JWT validation config.
234#[derive(Debug, Clone, Deserialize)]
235pub struct JwtConfig {
236    /// JWKS URI for key discovery.
237    #[serde(default)]
238    pub jwks_uri: Option<String>,
239    /// Expected issuer.
240    #[serde(default)]
241    pub issuer: Option<String>,
242    /// Expected audience.
243    #[serde(default)]
244    pub audience: Option<String>,
245    /// Path to Ed25519 public key PEM file (alternative to JWKS URI).
246    #[serde(default)]
247    pub public_key_pem_file: Option<PathBuf>,
248    /// Claims → HTTP headers mapping.
249    #[serde(default)]
250    pub claims_headers: std::collections::HashMap<String, String>,
251}
252
253/// Forward auth config.
254#[derive(Debug, Clone, Deserialize)]
255pub struct ForwardAuthConfig {
256    #[serde(default)]
257    pub enabled: bool,
258    #[serde(default = "default_forward_auth_path")]
259    pub path: String,
260    /// Route policies.
261    #[serde(default)]
262    pub policies: Vec<RoutePolicyConfig>,
263    /// Login URL for 401 redirects.
264    #[serde(default)]
265    pub login_url: Option<String>,
266    /// Applications YAML file path.
267    #[serde(default)]
268    pub applications_path: Option<PathBuf>,
269}
270
/// Default value for `ForwardAuthConfig::path`.
fn default_forward_auth_path() -> String {
    String::from("/auth/verify")
}
274
275/// Route policy entry.
276#[derive(Debug, Clone, Deserialize)]
277pub struct RoutePolicyConfig {
278    pub path: String,
279    #[serde(default = "default_methods_all")]
280    pub methods: Vec<String>,
281    #[serde(default)]
282    pub require_auth: bool,
283    #[serde(default)]
284    pub required_roles: Vec<String>,
285}
286
/// Default value for `RoutePolicyConfig::methods`: the single wildcard "*".
fn default_methods_all() -> Vec<String> {
    vec![String::from("*")]
}
290
291/// AuthZ gRPC integration.
292#[derive(Debug, Clone, Deserialize)]
293pub struct AuthzConfig {
294    #[serde(default)]
295    pub enabled: bool,
296    pub service: String,
297    pub method: String,
298    #[serde(default)]
299    pub subject_template: Option<String>,
300    #[serde(default)]
301    pub resource_template: Option<String>,
302    #[serde(default)]
303    pub action_template: Option<String>,
304}
305
306/// BFF session config.
307#[derive(Debug, Clone, Deserialize)]
308pub struct BffConfig {
309    #[serde(default)]
310    pub enabled: bool,
311    #[serde(default = "default_bff_cookie")]
312    pub cookie_name: String,
313    #[serde(default = "default_bff_max_age")]
314    pub max_age: u64,
315    #[serde(default = "default_bff_idle_timeout")]
316    pub idle_timeout: u64,
317    #[serde(default)]
318    pub external_url: Option<String>,
319}
320
/// Default value for `BffConfig::cookie_name`.
fn default_bff_cookie() -> String {
    String::from("__Host-proxy-bff")
}
/// Default value for `BffConfig::max_age`.
fn default_bff_max_age() -> u64 {
    86_400
}
/// Default value for `BffConfig::idle_timeout`.
fn default_bff_idle_timeout() -> u64 {
    3_600
}
330
331/// Shield (rate limiting) configuration.
332#[derive(Debug, Clone, Deserialize)]
333pub struct ShieldConfig {
334    #[serde(default)]
335    pub enabled: bool,
336    /// Endpoint classification (glob pattern → class → rate limit).
337    #[serde(default)]
338    pub endpoint_classes: Vec<EndpointClassConfig>,
339    /// Per-identifier rate limiting.
340    #[serde(default)]
341    pub identifier_endpoints: Vec<IdentifierEndpointConfig>,
342    /// Window size in seconds (default: 60).
343    #[serde(default = "default_window_secs")]
344    pub window_secs: u64,
345}
346
/// Default value for `ShieldConfig::window_secs`.
fn default_window_secs() -> u64 {
    60
}
350
351/// Endpoint classification for rate limiting.
352#[derive(Debug, Clone, Deserialize)]
353pub struct EndpointClassConfig {
354    /// Glob pattern (e.g., "/v1/auth/**").
355    pub pattern: String,
356    /// Class name (e.g., "auth").
357    pub class: String,
358    /// Rate limit string (e.g., "20/min").
359    pub rate: String,
360}
361
362/// Per-identifier rate limiting config.
363#[derive(Debug, Clone, Deserialize)]
364pub struct IdentifierEndpointConfig {
365    pub path: String,
366    pub body_field: String,
367    pub rate: String,
368}
369
370/// OIDC discovery config.
371#[derive(Debug, Clone, Deserialize)]
372pub struct OidcDiscoveryConfig {
373    #[serde(default)]
374    pub enabled: bool,
375    pub issuer: String,
376    #[serde(default)]
377    pub authorization_endpoint: Option<String>,
378    #[serde(default)]
379    pub token_endpoint: Option<String>,
380    #[serde(default)]
381    pub userinfo_endpoint: Option<String>,
382    #[serde(default)]
383    pub jwks_uri: Option<String>,
384    #[serde(default)]
385    pub signing_key: Option<SigningKeyConfig>,
386}
387
388/// Signing key config for JWKS endpoint.
389#[derive(Debug, Clone, Deserialize)]
390pub struct SigningKeyConfig {
391    #[serde(default = "default_algorithm")]
392    pub algorithm: String,
393    pub public_key_pem_file: PathBuf,
394}
395
/// Default value for `SigningKeyConfig::algorithm`.
fn default_algorithm() -> String {
    String::from("EdDSA")
}
399
400/// Maintenance mode config.
401#[derive(Debug, Clone, Deserialize)]
402pub struct MaintenanceConfig {
403    #[serde(default)]
404    pub enabled: bool,
405    /// Paths exempt from maintenance mode (glob patterns).
406    #[serde(default = "default_exempt_paths")]
407    pub exempt_paths: Vec<String>,
408    #[serde(default = "default_maintenance_message")]
409    pub message: String,
410}
411
/// Default value for `MaintenanceConfig::exempt_paths`.
fn default_exempt_paths() -> Vec<String> {
    const PATTERNS: [&str; 4] = [
        "/health/**",
        "/.well-known/**",
        "/metrics",
        "/auth/verify",
    ];
    PATTERNS.iter().map(|pattern| (*pattern).to_owned()).collect()
}
420
/// Default value for `MaintenanceConfig::message`.
fn default_maintenance_message() -> String {
    String::from("Service is under maintenance. Please try again later.")
}
424
425impl Default for MaintenanceConfig {
426    fn default() -> Self {
427        Self {
428            enabled: false,
429            exempt_paths: default_exempt_paths(),
430            message: default_maintenance_message(),
431        }
432    }
433}
434
435/// CORS configuration.
436#[derive(Debug, Clone, Default, Deserialize)]
437pub struct CorsConfig {
438    /// Allowed origins. Empty = permissive (dev mode).
439    #[serde(default)]
440    pub origins: Vec<String>,
441}
442
443/// Logging configuration.
444#[derive(Debug, Clone, Deserialize)]
445pub struct LoggingConfig {
446    #[serde(default = "default_log_level")]
447    pub level: String,
448    #[serde(default = "default_log_format")]
449    pub format: String,
450}
451
/// Default value for `LoggingConfig::level`.
fn default_log_level() -> String {
    String::from("info")
}
/// Default value for `LoggingConfig::format`.
fn default_log_format() -> String {
    String::from("json")
}
458
459impl Default for LoggingConfig {
460    fn default() -> Self {
461        Self {
462            level: default_log_level(),
463            format: default_log_format(),
464        }
465    }
466}
467
468/// Metrics endpoint classification.
469#[derive(Debug, Clone, Deserialize)]
470pub struct MetricsClassConfig {
471    /// Glob pattern for path matching.
472    pub pattern: String,
473    /// Label value for this class.
474    pub class: String,
475}
476
477impl ProxyConfig {
478    /// Load configuration from a YAML file.
479    pub fn from_file(path: &std::path::Path) -> anyhow::Result<Self> {
480        let content = std::fs::read_to_string(path)?;
481        let config: Self = serde_yaml::from_str(&content)?;
482        Ok(config)
483    }
484
485    /// Parse rate string like "20/min" → requests per window.
486    pub fn parse_rate(rate: &str) -> Option<u32> {
487        let parts: Vec<&str> = rate.split('/').collect();
488        if parts.len() != 2 {
489            return None;
490        }
491        parts[0].trim().parse().ok()
492    }
493}
494
#[cfg(test)]
mod tests {
    use super::*;

    /// Only `upstream` is given; every other section must fall back to its default.
    #[test]
    fn test_minimal_config_deserialize() {
        let yaml = r#"
upstream:
  default: "grpc://localhost:4180"
"#;
        let config: ProxyConfig = serde_yaml::from_str(yaml).unwrap();
        assert_eq!(config.upstream.default, "grpc://localhost:4180");
        assert_eq!(config.listen.http, "0.0.0.0:8080");
        assert_eq!(config.service.name, "structured-proxy");
        assert!(config.descriptors.is_empty());
        assert!(config.auth.is_none());
        assert!(config.shield.is_none());
    }

    /// A config exercising every top-level section deserializes as expected.
    #[test]
    fn test_full_config_deserialize() {
        let yaml = r#"
upstream:
  default: "grpc://sid-identity:4180"

descriptors:
  - file: "/etc/proxy/sid.descriptor.bin"

listen:
  http: "0.0.0.0:9090"

service:
  name: "sid-proxy"

aliases:
  - from: "/oauth2/{path}"
    to: "/v1/oauth2/{path}"

auth:
  mode: "jwt"
  jwt:
    issuer: "https://auth.example.com"
    public_key_pem_file: "/etc/proxy/signing.pub"
    claims_headers:
      sub: "x-forwarded-user"
      acr: "x-sid-auth-level"
  forward_auth:
    enabled: true
    path: "/auth/verify"
    policies:
      - path: "/v1/admin/**"
        require_auth: true
        required_roles: ["admin"]
      - path: "/v1/public/**"
        require_auth: false

shield:
  enabled: true
  endpoint_classes:
    - pattern: "/v1/auth/**"
      class: "auth"
      rate: "20/min"
    - pattern: "/**"
      class: "default"
      rate: "100/min"
  identifier_endpoints:
    - path: "/v1/auth/opaque/login/start"
      body_field: "identifier"
      rate: "10/min"

oidc_discovery:
  enabled: true
  issuer: "https://auth.example.com"

maintenance:
  enabled: false
  exempt_paths:
    - "/health/**"
    - "/.well-known/**"

cors:
  origins:
    - "https://app.example.com"

metrics_classes:
  - pattern: "/v1/auth/**"
    class: "auth"
  - pattern: "/v1/admin/**"
    class: "admin"

forwarded_headers:
  - "authorization"
  - "dpop"
  - "x-request-id"
"#;
        let config: ProxyConfig = serde_yaml::from_str(yaml).unwrap();
        assert_eq!(config.upstream.default, "grpc://sid-identity:4180");
        assert_eq!(config.listen.http, "0.0.0.0:9090");
        assert_eq!(config.service.name, "sid-proxy");
        assert_eq!(config.aliases.len(), 1);
        assert!(config.auth.is_some());
        assert!(config.shield.is_some());
        assert!(config.oidc_discovery.is_some());
        assert_eq!(config.cors.origins.len(), 1);
        assert_eq!(config.metrics_classes.len(), 2);
        assert_eq!(config.forwarded_headers.len(), 3);
    }

    /// A `file:` entry maps to `DescriptorSource::File`.
    #[test]
    fn test_descriptor_source_file() {
        let yaml = r#"
upstream:
  default: "grpc://localhost:4180"
descriptors:
  - file: "/etc/proxy/service.descriptor.bin"
"#;
        let config: ProxyConfig = serde_yaml::from_str(yaml).unwrap();
        assert_eq!(config.descriptors.len(), 1);
        match &config.descriptors[0] {
            DescriptorSource::File { file } => {
                assert_eq!(file.to_str().unwrap(), "/etc/proxy/service.descriptor.bin");
            }
            _ => panic!("expected File descriptor source"),
        }
    }

    /// A `reflection:` entry maps to `DescriptorSource::Reflection`.
    #[test]
    fn test_descriptor_source_reflection() {
        let yaml = r#"
upstream:
  default: "grpc://localhost:4180"
descriptors:
  - reflection: "grpc://localhost:4180"
"#;
        let config: ProxyConfig = serde_yaml::from_str(yaml).unwrap();
        match &config.descriptors[0] {
            DescriptorSource::Reflection { reflection } => {
                assert_eq!(reflection, "grpc://localhost:4180");
            }
            _ => panic!("expected Reflection descriptor source"),
        }
    }

    /// `parse_rate` returns the count for `<count>/<unit>` and `None` otherwise.
    #[test]
    fn test_parse_rate() {
        assert_eq!(ProxyConfig::parse_rate("20/min"), Some(20));
        assert_eq!(ProxyConfig::parse_rate("100/min"), Some(100));
        assert_eq!(ProxyConfig::parse_rate("5/min"), Some(5));
        assert_eq!(ProxyConfig::parse_rate("invalid"), None);
        // Whitespace around the count is tolerated.
        assert_eq!(ProxyConfig::parse_rate(" 20 /min"), Some(20));
        // Missing, repeated, or malformed parts are rejected.
        assert_eq!(ProxyConfig::parse_rate(""), None);
        assert_eq!(ProxyConfig::parse_rate("abc/min"), None);
        assert_eq!(ProxyConfig::parse_rate("20/min/extra"), None);
    }

    /// Explicit values in the `openapi` section override every default.
    #[test]
    fn test_openapi_config_deserialize() {
        let yaml = r#"
upstream:
  default: "grpc://localhost:4180"
openapi:
  enabled: true
  path: "/api/openapi.json"
  docs_path: "/api/docs"
  title: "Test API"
  version: "2.0.0"
"#;
        let config: ProxyConfig = serde_yaml::from_str(yaml).unwrap();
        let openapi = config.openapi.unwrap();
        assert!(openapi.enabled);
        assert_eq!(openapi.path, "/api/openapi.json");
        assert_eq!(openapi.docs_path, "/api/docs");
        assert_eq!(openapi.title.unwrap(), "Test API");
        assert_eq!(openapi.version.unwrap(), "2.0.0");
    }

    /// Omitted `openapi` keys fall back to the documented defaults.
    #[test]
    fn test_openapi_config_defaults() {
        let yaml = r#"
upstream:
  default: "grpc://localhost:4180"
openapi:
  enabled: true
"#;
        let config: ProxyConfig = serde_yaml::from_str(yaml).unwrap();
        let openapi = config.openapi.unwrap();
        assert!(openapi.enabled);
        assert_eq!(openapi.path, "/openapi.json");
        assert_eq!(openapi.docs_path, "/docs");
        assert!(openapi.title.is_none());
        assert!(openapi.version.is_none());
    }
}
679        assert_eq!(openapi.docs_path, "/docs");
680        assert!(openapi.title.is_none());
681        assert!(openapi.version.is_none());
682    }
683}