Skip to main content

structured_proxy/
lib.rs

1//! Universal gRPC→REST transcoding proxy.
2//!
3//! Config-driven: same binary, different YAML = different product proxy.
4//! Works with ANY gRPC service via proto descriptors as config.
5//!
6//! ## Usage
7//!
8//! ```bash
9//! structured-proxy --config sid-proxy.yaml
10//! structured-proxy --config sflow-proxy.yaml
11//! ```
12//!
13//! ## JWT crypto backend
14//!
15//! Exactly one crypto backend feature must be enabled (they are mutually
16//! exclusive): `rust_crypto` (default, pure Rust) or `aws_lc_rs` (opt-in,
17//! constant-time / FIPS-capable, links aws-lc via C FFI). Enabling both or
18//! neither is rejected at compile time by the guards below.
19
20// jsonwebtoken selects its provider from these features and would otherwise
21// panic at runtime on an invalid combination; turn that into a build error.
22#[cfg(all(feature = "rust_crypto", feature = "aws_lc_rs"))]
23compile_error!("features `rust_crypto` and `aws_lc_rs` are mutually exclusive; enable exactly one");
24
25#[cfg(not(any(feature = "rust_crypto", feature = "aws_lc_rs")))]
26compile_error!("exactly one JWT crypto backend must be enabled: `rust_crypto` or `aws_lc_rs`");
27
28pub mod auth;
29pub mod config;
30pub mod oidc;
31pub mod openapi;
32pub mod shield;
33pub mod transcode;
34
35use axum::extract::State;
36use axum::http::{Request, StatusCode};
37use axum::middleware::Next;
38use axum::response::{IntoResponse, Response};
39use axum::routing::get;
40use axum::{Json, Router};
41use prost_reflect::DescriptorPool;
42use std::net::SocketAddr;
43use tower_http::cors::{AllowOrigin, CorsLayer};
44use tower_http::trace::TraceLayer;
45
46use config::{DescriptorSource, ProxyConfig};
47
48/// Shared state for all proxy handlers.
49#[derive(Clone, Debug)]
50pub struct ProxyState {
51    /// Service name from config.
52    pub service_name: String,
53    /// gRPC upstream address.
54    pub grpc_upstream: String,
55    /// Lazy gRPC channel to upstream service.
56    pub grpc_channel: tonic::transport::Channel,
57    /// Maintenance mode active.
58    pub maintenance_mode: bool,
59    /// Maintenance exempt path patterns.
60    pub maintenance_exempt: Vec<String>,
61    /// Maintenance message.
62    pub maintenance_message: String,
63    /// Headers to forward from HTTP to gRPC.
64    pub forwarded_headers: Vec<String>,
65    /// Metrics namespace (derived from service name).
66    pub metrics_namespace: String,
67    /// Path class patterns for metrics.
68    pub metrics_classes: Vec<config::MetricsClassConfig>,
69}
70
71/// Universal proxy server.
72pub struct ProxyServer {
73    config: ProxyConfig,
74    /// Optional pre-loaded descriptor pool (for embedded mode).
75    descriptor_pool: Option<DescriptorPool>,
76}
77
78impl ProxyServer {
79    /// Create from YAML config file.
80    pub fn from_config(config: ProxyConfig) -> Self {
81        Self {
82            config,
83            descriptor_pool: None,
84        }
85    }
86
87    /// Create with an embedded descriptor pool (for sid-proxy backward compat).
88    pub fn with_descriptors(mut self, pool: DescriptorPool) -> Self {
89        self.descriptor_pool = Some(pool);
90        self
91    }
92
93    /// Load descriptor pool from configured sources.
94    ///
95    /// Multiple descriptor files are merged into a single pool,
96    /// enabling multi-service proxying from one binary.
97    fn load_descriptors(&self) -> anyhow::Result<DescriptorPool> {
98        if let Some(pool) = &self.descriptor_pool {
99            return Ok(pool.clone());
100        }
101
102        let mut pool = DescriptorPool::new();
103
104        for source in &self.config.descriptors {
105            match source {
106                DescriptorSource::File { file } => {
107                    let bytes = std::fs::read(file).map_err(|e| {
108                        anyhow::anyhow!("Failed to read descriptor file {:?}: {}", file, e)
109                    })?;
110                    pool.decode_file_descriptor_set(bytes.as_slice())
111                        .map_err(|e| {
112                            anyhow::anyhow!("Failed to decode descriptor file {:?}: {}", file, e)
113                        })?;
114                    tracing::info!("Loaded descriptor from {:?}", file);
115                }
116                DescriptorSource::Reflection { reflection } => {
117                    tracing::warn!(
118                        "gRPC reflection client not supported — use descriptor files instead (reflection endpoint: {})",
119                        reflection
120                    );
121                }
122                DescriptorSource::Embedded { bytes } => {
123                    pool.decode_file_descriptor_set(*bytes).map_err(|e| {
124                        anyhow::anyhow!("Failed to decode embedded descriptors: {}", e)
125                    })?;
126                }
127            }
128        }
129
130        Ok(pool)
131    }
132
133    /// Build the axum router with all endpoints.
134    pub fn router(&self) -> anyhow::Result<Router> {
135        let pool = self.load_descriptors()?;
136
137        let grpc_upstream = self.config.upstream.default.clone();
138        let grpc_channel = tonic::transport::Channel::from_shared(grpc_upstream.clone())
139            .map_err(|e| anyhow::anyhow!("invalid gRPC upstream URL: {}", e))?
140            .connect_timeout(std::time::Duration::from_secs(5))
141            .timeout(std::time::Duration::from_secs(5))
142            .connect_lazy();
143
144        let service_name = self.config.service.name.clone();
145        let metrics_namespace = service_name.replace('-', "_");
146
147        let state = ProxyState {
148            service_name: service_name.clone(),
149            grpc_upstream,
150            grpc_channel,
151            maintenance_mode: self.config.maintenance.enabled,
152            maintenance_exempt: self.config.maintenance.exempt_paths.clone(),
153            maintenance_message: self.config.maintenance.message.clone(),
154            forwarded_headers: self.config.forwarded_headers.clone(),
155            metrics_namespace,
156            metrics_classes: self.config.metrics_classes.clone(),
157        };
158
159        let cors = self.build_cors();
160
161        // Build transcoding routes from descriptor pool.
162        let mut transcode_routes = transcode::routes(&pool, &self.config.aliases);
163
164        // External authorization (Envoy ext_authz) gates only the proxied API
165        // routes, never health / metrics / discovery. It runs inside the auth
166        // layer below, so the Check call sees the identity headers the JWT
167        // middleware injected.
168        let authz = match self.config.auth.as_ref().and_then(|a| a.authz.as_ref()) {
169            Some(cfg) => auth::authz::Authz::build(cfg)
170                .map_err(|e| anyhow::anyhow!("invalid authz config: {e}"))?,
171            None => None,
172        };
173        if let Some(authz) = authz {
174            transcode_routes = transcode_routes.layer(axum::middleware::from_fn_with_state(
175                authz,
176                auth::authz::middleware,
177            ));
178        }
179
180        // Health routes
181        let health_service_name = service_name.clone();
182        let health_routes = Router::new()
183            .route(
184                "/health",
185                get({
186                    let name = health_service_name.clone();
187                    move || async move {
188                        Json(serde_json::json!({
189                            "status": "ok",
190                            "service": name,
191                        }))
192                    }
193                }),
194            )
195            .route("/health/live", get(|| async { StatusCode::OK }))
196            .route(
197                "/health/ready",
198                get(|State(state): State<ProxyState>| async move {
199                    let mut client =
200                        tonic_health::pb::health_client::HealthClient::new(state.grpc_channel);
201                    match client
202                        .check(tonic_health::pb::HealthCheckRequest {
203                            service: String::new(),
204                        })
205                        .await
206                    {
207                        Ok(resp) => {
208                            let status = resp.into_inner().status;
209                            if status
210                                == tonic_health::pb::health_check_response::ServingStatus::Serving
211                                    as i32
212                            {
213                                StatusCode::OK
214                            } else {
215                                StatusCode::SERVICE_UNAVAILABLE
216                            }
217                        }
218                        Err(_) => StatusCode::SERVICE_UNAVAILABLE,
219                    }
220                }),
221            )
222            .route("/health/startup", get(|| async { StatusCode::OK }))
223            .route(
224                "/metrics",
225                get(|| async {
226                    let encoder = prometheus::TextEncoder::new();
227                    let metric_families = prometheus::default_registry().gather();
228                    match encoder.encode_to_string(&metric_families) {
229                        Ok(text) => (
230                            StatusCode::OK,
231                            [(
232                                axum::http::header::CONTENT_TYPE,
233                                "text/plain; version=0.0.4; charset=utf-8",
234                            )],
235                            text,
236                        )
237                            .into_response(),
238                        Err(_) => StatusCode::INTERNAL_SERVER_ERROR.into_response(),
239                    }
240                }),
241            );
242
243        // OpenAPI + docs routes (if enabled).
244        let openapi_routes = self.build_openapi_routes(&pool);
245
246        // OIDC discovery routes (if enabled). Public, like the health endpoints.
247        let oidc_routes = match &self.config.oidc_discovery {
248            Some(cfg) => oidc::Oidc::build(cfg)
249                .map_err(|e| anyhow::anyhow!("invalid oidc_discovery config: {e}"))?
250                .map(|o| o.routes())
251                .unwrap_or_default(),
252            None => Router::new(),
253        };
254
255        // Rate limiting (Shield), if configured and enabled.
256        let shield = match &self.config.shield {
257            Some(cfg) => shield::Shield::build(cfg)
258                .map_err(|e| anyhow::anyhow!("invalid shield config: {e}"))?,
259            None => None,
260        };
261
262        // JWT auth, if configured (auth.mode == "jwt").
263        let auth = match &self.config.auth {
264            Some(cfg) => {
265                auth::Auth::build(cfg).map_err(|e| anyhow::anyhow!("invalid auth config: {e}"))?
266            }
267            None => None,
268        };
269
270        let mut router = Router::new()
271            .merge(health_routes)
272            .merge(openapi_routes)
273            .merge(oidc_routes)
274            .merge(transcode_routes)
275            .layer(cors);
276
277        // Forward-auth verification endpoint, sharing the built Auth. Mounted
278        // after the auth layer below so the endpoint itself is not gated by the
279        // JWT middleware (it answers the gate, it isn't behind it).
280        let forward_auth = auth.as_ref().and_then(|built| {
281            auth::forward::ForwardAuth::build(self.config.auth.as_ref()?, built.clone())
282        });
283
284        // Auth runs inside Shield (added first = inner): rate limiting sheds
285        // load before any signature verification work.
286        if let Some(auth) = auth {
287            router = router.layer(axum::middleware::from_fn_with_state(auth, auth::middleware));
288        }
289
290        if let Some(forward_auth) = &forward_auth {
291            router = router.merge(forward_auth.routes());
292        }
293
294        // Shield is added before maintenance so maintenance wraps it (outer
295        // layers run first): a request rejected by the maintenance gate must
296        // not be charged against its rate-limit budget.
297        if let Some(shield) = shield {
298            router = router.layer(axum::middleware::from_fn_with_state(
299                shield,
300                shield::middleware,
301            ));
302        }
303
304        let router = router
305            .layer(axum::middleware::from_fn_with_state(
306                state.clone(),
307                maintenance_middleware,
308            ))
309            .layer(TraceLayer::new_for_http())
310            .with_state(state);
311
312        Ok(router)
313    }
314
315    fn build_openapi_routes(&self, pool: &DescriptorPool) -> Router<ProxyState> {
316        let openapi_config = match &self.config.openapi {
317            Some(cfg) if cfg.enabled => cfg,
318            _ => return Router::new(),
319        };
320
321        let spec = openapi::generate(pool, openapi_config, &self.config.aliases);
322        let spec_json = serde_json::to_string_pretty(&spec).unwrap_or_default();
323        let openapi_path = openapi_config.path.clone();
324        let docs_path = openapi_config.docs_path.clone();
325        let title = openapi_config
326            .title
327            .clone()
328            .unwrap_or_else(|| self.config.service.name.clone());
329        let openapi_path_for_docs = openapi_path.clone();
330
331        tracing::info!("OpenAPI spec at {}, docs at {}", openapi_path, docs_path,);
332
333        Router::new()
334            .route(
335                &openapi_path,
336                get(move || async move {
337                    (
338                        StatusCode::OK,
339                        [(
340                            axum::http::header::CONTENT_TYPE,
341                            "application/json; charset=utf-8",
342                        )],
343                        spec_json,
344                    )
345                }),
346            )
347            .route(
348                &docs_path,
349                get(move || async move {
350                    let html = openapi::docs_html(&openapi_path_for_docs, &title);
351                    (
352                        StatusCode::OK,
353                        [(axum::http::header::CONTENT_TYPE, "text/html; charset=utf-8")],
354                        html,
355                    )
356                }),
357            )
358    }
359
360    fn build_cors(&self) -> CorsLayer {
361        if self.config.cors.origins.is_empty() {
362            tracing::warn!("CORS origins not set — using permissive CORS (dev mode)");
363            CorsLayer::permissive()
364        } else {
365            let origins: Vec<_> = self
366                .config
367                .cors
368                .origins
369                .iter()
370                .filter_map(|o| o.parse().ok())
371                .collect();
372            CorsLayer::new()
373                .allow_origin(AllowOrigin::list(origins))
374                .allow_methods(tower_http::cors::Any)
375                .allow_headers(tower_http::cors::Any)
376                .allow_credentials(true)
377                .expose_headers([
378                    "grpc-status".parse().unwrap(),
379                    "grpc-message".parse().unwrap(),
380                ])
381        }
382    }
383
384    /// Start serving on configured address.
385    pub async fn serve(&self) -> anyhow::Result<()> {
386        let router = self.router()?;
387        let app = router.into_make_service_with_connect_info::<SocketAddr>();
388        let addr: SocketAddr = self.config.listen.http.parse()?;
389        let listener = tokio::net::TcpListener::bind(addr).await?;
390
391        tracing::info!("{} listening on {}", self.config.service.name, addr);
392        axum::serve(listener, app).await?;
393        Ok(())
394    }
395}
396
397/// Maintenance mode middleware.
398async fn maintenance_middleware(
399    State(state): State<ProxyState>,
400    request: Request<axum::body::Body>,
401    next: Next,
402) -> Response {
403    if state.maintenance_mode {
404        let path = request.uri().path();
405        let exempt = state.maintenance_exempt.iter().any(|pattern| {
406            if pattern.ends_with("/**") {
407                let prefix = &pattern[..pattern.len() - 3];
408                path.starts_with(prefix)
409            } else {
410                path == pattern
411            }
412        });
413        if !exempt {
414            return (
415                StatusCode::SERVICE_UNAVAILABLE,
416                [("retry-after", "300")],
417                state.maintenance_message.clone(),
418            )
419                .into_response();
420        }
421    }
422    next.run(request).await
423}
424
425/// Create a lazy gRPC channel for testing (connects to nowhere).
426#[cfg(test)]
427pub(crate) fn test_channel() -> tonic::transport::Channel {
428    tonic::transport::Channel::from_static("http://127.0.0.1:1")
429        .connect_timeout(std::time::Duration::from_millis(100))
430        .connect_lazy()
431}
432
433#[cfg(test)]
434mod tests {
435    use super::*;
436
437    #[test]
438    fn test_minimal_config_server() {
439        let yaml = r#"
440upstream:
441  default: "http://127.0.0.1:50051"
442"#;
443        let config: ProxyConfig = serde_yaml::from_str(yaml).unwrap();
444        let server = ProxyServer::from_config(config);
445        assert!(server.descriptor_pool.is_none());
446    }
447
448    #[tokio::test]
449    async fn test_maintenance_exempt_matching() {
450        let state = ProxyState {
451            service_name: "test".into(),
452            grpc_upstream: "http://localhost:50051".into(),
453            grpc_channel: test_channel(),
454            maintenance_mode: true,
455            maintenance_exempt: vec![
456                "/health/**".into(),
457                "/.well-known/**".into(),
458                "/metrics".into(),
459            ],
460            maintenance_message: "Down".into(),
461            forwarded_headers: vec![],
462            metrics_namespace: "test".into(),
463            metrics_classes: vec![],
464        };
465
466        let check = |path: &str| -> bool {
467            state.maintenance_exempt.iter().any(|pattern| {
468                if pattern.ends_with("/**") {
469                    let prefix = &pattern[..pattern.len() - 3];
470                    path.starts_with(prefix)
471                } else {
472                    path == pattern
473                }
474            })
475        };
476
477        assert!(check("/health"));
478        assert!(check("/health/ready"));
479        assert!(check("/.well-known/openid-configuration"));
480        assert!(check("/metrics"));
481        assert!(!check("/v1/auth/login"));
482        assert!(!check("/oauth2/token"));
483    }
484}