1#[cfg(all(feature = "rust_crypto", feature = "aws_lc_rs"))]
23compile_error!("features `rust_crypto` and `aws_lc_rs` are mutually exclusive; enable exactly one");
24
25#[cfg(not(any(feature = "rust_crypto", feature = "aws_lc_rs")))]
26compile_error!("exactly one JWT crypto backend must be enabled: `rust_crypto` or `aws_lc_rs`");
27
28pub mod auth;
29pub mod config;
30pub mod oidc;
31pub mod openapi;
32pub mod shield;
33pub mod transcode;
34
35use axum::extract::State;
36use axum::http::{Request, StatusCode};
37use axum::middleware::Next;
38use axum::response::{IntoResponse, Response};
39use axum::routing::get;
40use axum::{Json, Router};
41use prost_reflect::DescriptorPool;
42use std::net::SocketAddr;
43use tower_http::cors::{AllowOrigin, CorsLayer};
44use tower_http::trace::TraceLayer;
45
46use config::{DescriptorSource, ProxyConfig};
47
/// Router-wide state shared by handlers and middleware through axum's `State`.
///
/// Built once in `ProxyServer::router`. axum's `State` extractor hands out
/// clones, so every field (including the `Vec`s) is cloned per extraction.
#[derive(Clone, Debug)]
pub struct ProxyState {
    /// Service name taken from `service.name` in the config.
    pub service_name: String,
    /// Upstream gRPC URL (`upstream.default`) that `grpc_channel` was created from.
    pub grpc_upstream: String,
    /// Lazily connected channel to the upstream; `/health/ready` checks it.
    /// NOTE(review): presumably also used by the transcoding handlers — confirm.
    pub grpc_channel: tonic::transport::Channel,
    /// When `true`, `maintenance_middleware` answers non-exempt requests with 503.
    pub maintenance_mode: bool,
    /// Path patterns that bypass maintenance mode. A pattern ending in `/**`
    /// matches paths under that prefix; any other pattern must match exactly.
    pub maintenance_exempt: Vec<String>,
    /// Response body sent with the 503 during maintenance.
    pub maintenance_message: String,
    /// Copied from `forwarded_headers` in the config.
    /// NOTE(review): consumers are outside this file; presumably the request
    /// headers forwarded to the upstream — confirm.
    pub forwarded_headers: Vec<String>,
    /// `service_name` with every `-` replaced by `_`.
    pub metrics_namespace: String,
    /// Copied from `metrics_classes` in the config; consumers are outside this file.
    pub metrics_classes: Vec<config::MetricsClassConfig>,
}
70
/// Builds and serves the proxy's HTTP router from a [`ProxyConfig`].
pub struct ProxyServer {
    /// Parsed proxy configuration.
    config: ProxyConfig,
    /// Pool supplied through `with_descriptors`; when `None`, descriptors are
    /// loaded from `config.descriptors` at router-build time.
    descriptor_pool: Option<DescriptorPool>,
}
77
78impl ProxyServer {
79 pub fn from_config(config: ProxyConfig) -> Self {
81 Self {
82 config,
83 descriptor_pool: None,
84 }
85 }
86
87 pub fn with_descriptors(mut self, pool: DescriptorPool) -> Self {
89 self.descriptor_pool = Some(pool);
90 self
91 }
92
93 fn load_descriptors(&self) -> anyhow::Result<DescriptorPool> {
98 if let Some(pool) = &self.descriptor_pool {
99 return Ok(pool.clone());
100 }
101
102 let mut pool = DescriptorPool::new();
103
104 for source in &self.config.descriptors {
105 match source {
106 DescriptorSource::File { file } => {
107 let bytes = std::fs::read(file).map_err(|e| {
108 anyhow::anyhow!("Failed to read descriptor file {:?}: {}", file, e)
109 })?;
110 pool.decode_file_descriptor_set(bytes.as_slice())
111 .map_err(|e| {
112 anyhow::anyhow!("Failed to decode descriptor file {:?}: {}", file, e)
113 })?;
114 tracing::info!("Loaded descriptor from {:?}", file);
115 }
116 DescriptorSource::Reflection { reflection } => {
117 tracing::warn!(
118 "gRPC reflection client not supported — use descriptor files instead (reflection endpoint: {})",
119 reflection
120 );
121 }
122 DescriptorSource::Embedded { bytes } => {
123 pool.decode_file_descriptor_set(*bytes).map_err(|e| {
124 anyhow::anyhow!("Failed to decode embedded descriptors: {}", e)
125 })?;
126 }
127 }
128 }
129
130 Ok(pool)
131 }
132
133 pub fn router(&self) -> anyhow::Result<Router> {
135 let pool = self.load_descriptors()?;
136
137 let grpc_upstream = self.config.upstream.default.clone();
138 let grpc_channel = tonic::transport::Channel::from_shared(grpc_upstream.clone())
139 .map_err(|e| anyhow::anyhow!("invalid gRPC upstream URL: {}", e))?
140 .connect_timeout(std::time::Duration::from_secs(5))
141 .timeout(std::time::Duration::from_secs(5))
142 .connect_lazy();
143
144 let service_name = self.config.service.name.clone();
145 let metrics_namespace = service_name.replace('-', "_");
146
147 let state = ProxyState {
148 service_name: service_name.clone(),
149 grpc_upstream,
150 grpc_channel,
151 maintenance_mode: self.config.maintenance.enabled,
152 maintenance_exempt: self.config.maintenance.exempt_paths.clone(),
153 maintenance_message: self.config.maintenance.message.clone(),
154 forwarded_headers: self.config.forwarded_headers.clone(),
155 metrics_namespace,
156 metrics_classes: self.config.metrics_classes.clone(),
157 };
158
159 let cors = self.build_cors();
160
161 let mut transcode_routes = transcode::routes(&pool, &self.config.aliases);
163
164 let authz = match self.config.auth.as_ref().and_then(|a| a.authz.as_ref()) {
169 Some(cfg) => auth::authz::Authz::build(cfg)
170 .map_err(|e| anyhow::anyhow!("invalid authz config: {e}"))?,
171 None => None,
172 };
173 if let Some(authz) = authz {
174 transcode_routes = transcode_routes.layer(axum::middleware::from_fn_with_state(
175 authz,
176 auth::authz::middleware,
177 ));
178 }
179
180 let health_service_name = service_name.clone();
182 let health_routes = Router::new()
183 .route(
184 "/health",
185 get({
186 let name = health_service_name.clone();
187 move || async move {
188 Json(serde_json::json!({
189 "status": "ok",
190 "service": name,
191 }))
192 }
193 }),
194 )
195 .route("/health/live", get(|| async { StatusCode::OK }))
196 .route(
197 "/health/ready",
198 get(|State(state): State<ProxyState>| async move {
199 let mut client =
200 tonic_health::pb::health_client::HealthClient::new(state.grpc_channel);
201 match client
202 .check(tonic_health::pb::HealthCheckRequest {
203 service: String::new(),
204 })
205 .await
206 {
207 Ok(resp) => {
208 let status = resp.into_inner().status;
209 if status
210 == tonic_health::pb::health_check_response::ServingStatus::Serving
211 as i32
212 {
213 StatusCode::OK
214 } else {
215 StatusCode::SERVICE_UNAVAILABLE
216 }
217 }
218 Err(_) => StatusCode::SERVICE_UNAVAILABLE,
219 }
220 }),
221 )
222 .route("/health/startup", get(|| async { StatusCode::OK }))
223 .route(
224 "/metrics",
225 get(|| async {
226 let encoder = prometheus::TextEncoder::new();
227 let metric_families = prometheus::default_registry().gather();
228 match encoder.encode_to_string(&metric_families) {
229 Ok(text) => (
230 StatusCode::OK,
231 [(
232 axum::http::header::CONTENT_TYPE,
233 "text/plain; version=0.0.4; charset=utf-8",
234 )],
235 text,
236 )
237 .into_response(),
238 Err(_) => StatusCode::INTERNAL_SERVER_ERROR.into_response(),
239 }
240 }),
241 );
242
243 let openapi_routes = self.build_openapi_routes(&pool);
245
246 let oidc_routes = match &self.config.oidc_discovery {
248 Some(cfg) => oidc::Oidc::build(cfg)
249 .map_err(|e| anyhow::anyhow!("invalid oidc_discovery config: {e}"))?
250 .map(|o| o.routes())
251 .unwrap_or_default(),
252 None => Router::new(),
253 };
254
255 let shield = match &self.config.shield {
257 Some(cfg) => shield::Shield::build(cfg)
258 .map_err(|e| anyhow::anyhow!("invalid shield config: {e}"))?,
259 None => None,
260 };
261
262 let auth = match &self.config.auth {
264 Some(cfg) => {
265 auth::Auth::build(cfg).map_err(|e| anyhow::anyhow!("invalid auth config: {e}"))?
266 }
267 None => None,
268 };
269
270 let mut router = Router::new()
271 .merge(health_routes)
272 .merge(openapi_routes)
273 .merge(oidc_routes)
274 .merge(transcode_routes)
275 .layer(cors);
276
277 let forward_auth = auth.as_ref().and_then(|built| {
281 auth::forward::ForwardAuth::build(self.config.auth.as_ref()?, built.clone())
282 });
283
284 if let Some(auth) = auth {
287 router = router.layer(axum::middleware::from_fn_with_state(auth, auth::middleware));
288 }
289
290 if let Some(forward_auth) = &forward_auth {
291 router = router.merge(forward_auth.routes());
292 }
293
294 if let Some(shield) = shield {
298 router = router.layer(axum::middleware::from_fn_with_state(
299 shield,
300 shield::middleware,
301 ));
302 }
303
304 let router = router
305 .layer(axum::middleware::from_fn_with_state(
306 state.clone(),
307 maintenance_middleware,
308 ))
309 .layer(TraceLayer::new_for_http())
310 .with_state(state);
311
312 Ok(router)
313 }
314
315 fn build_openapi_routes(&self, pool: &DescriptorPool) -> Router<ProxyState> {
316 let openapi_config = match &self.config.openapi {
317 Some(cfg) if cfg.enabled => cfg,
318 _ => return Router::new(),
319 };
320
321 let spec = openapi::generate(pool, openapi_config, &self.config.aliases);
322 let spec_json = serde_json::to_string_pretty(&spec).unwrap_or_default();
323 let openapi_path = openapi_config.path.clone();
324 let docs_path = openapi_config.docs_path.clone();
325 let title = openapi_config
326 .title
327 .clone()
328 .unwrap_or_else(|| self.config.service.name.clone());
329 let openapi_path_for_docs = openapi_path.clone();
330
331 tracing::info!("OpenAPI spec at {}, docs at {}", openapi_path, docs_path,);
332
333 Router::new()
334 .route(
335 &openapi_path,
336 get(move || async move {
337 (
338 StatusCode::OK,
339 [(
340 axum::http::header::CONTENT_TYPE,
341 "application/json; charset=utf-8",
342 )],
343 spec_json,
344 )
345 }),
346 )
347 .route(
348 &docs_path,
349 get(move || async move {
350 let html = openapi::docs_html(&openapi_path_for_docs, &title);
351 (
352 StatusCode::OK,
353 [(axum::http::header::CONTENT_TYPE, "text/html; charset=utf-8")],
354 html,
355 )
356 }),
357 )
358 }
359
360 fn build_cors(&self) -> CorsLayer {
361 if self.config.cors.origins.is_empty() {
362 tracing::warn!("CORS origins not set — using permissive CORS (dev mode)");
363 CorsLayer::permissive()
364 } else {
365 let origins: Vec<_> = self
366 .config
367 .cors
368 .origins
369 .iter()
370 .filter_map(|o| o.parse().ok())
371 .collect();
372 CorsLayer::new()
373 .allow_origin(AllowOrigin::list(origins))
374 .allow_methods(tower_http::cors::Any)
375 .allow_headers(tower_http::cors::Any)
376 .allow_credentials(true)
377 .expose_headers([
378 "grpc-status".parse().unwrap(),
379 "grpc-message".parse().unwrap(),
380 ])
381 }
382 }
383
384 pub async fn serve(&self) -> anyhow::Result<()> {
386 let router = self.router()?;
387 let app = router.into_make_service_with_connect_info::<SocketAddr>();
388 let addr: SocketAddr = self.config.listen.http.parse()?;
389 let listener = tokio::net::TcpListener::bind(addr).await?;
390
391 tracing::info!("{} listening on {}", self.config.service.name, addr);
392 axum::serve(listener, app).await?;
393 Ok(())
394 }
395}
396
397async fn maintenance_middleware(
399 State(state): State<ProxyState>,
400 request: Request<axum::body::Body>,
401 next: Next,
402) -> Response {
403 if state.maintenance_mode {
404 let path = request.uri().path();
405 let exempt = state.maintenance_exempt.iter().any(|pattern| {
406 if pattern.ends_with("/**") {
407 let prefix = &pattern[..pattern.len() - 3];
408 path.starts_with(prefix)
409 } else {
410 path == pattern
411 }
412 });
413 if !exempt {
414 return (
415 StatusCode::SERVICE_UNAVAILABLE,
416 [("retry-after", "300")],
417 state.maintenance_message.clone(),
418 )
419 .into_response();
420 }
421 }
422 next.run(request).await
423}
424
/// Test-only gRPC channel aimed at `http://127.0.0.1:1`.
///
/// Built with `connect_lazy`, so creating it never dials the address. Tests
/// can therefore build a `ProxyState` without a running upstream.
/// NOTE(review): presumably nothing listens on port 1, so real calls fail
/// within the 100 ms connect timeout — confirm if a test relies on that.
#[cfg(test)]
pub(crate) fn test_channel() -> tonic::transport::Channel {
    let endpoint = tonic::transport::Channel::from_static("http://127.0.0.1:1");
    let connect_timeout = std::time::Duration::from_millis(100);
    endpoint.connect_timeout(connect_timeout).connect_lazy()
}
432
#[cfg(test)]
mod tests {
    use super::*;

    /// A config containing only the upstream address parses, and a server
    /// built from it has no pre-supplied descriptor pool.
    #[test]
    fn test_minimal_config_server() {
        let config: ProxyConfig = serde_yaml::from_str(
            r#"
upstream:
  default: "http://127.0.0.1:50051"
"#,
        )
        .unwrap();
        assert!(ProxyServer::from_config(config).descriptor_pool.is_none());
    }

    /// Exercises the `/**` prefix form and the exact-match form of
    /// maintenance exemptions.
    ///
    /// NOTE(review): the closure below re-implements the matching rule instead
    /// of calling `maintenance_middleware`, so it cannot catch regressions in
    /// the middleware itself.
    #[tokio::test]
    async fn test_maintenance_exempt_matching() {
        let state = ProxyState {
            service_name: "test".into(),
            grpc_upstream: "http://localhost:50051".into(),
            grpc_channel: test_channel(),
            maintenance_mode: true,
            maintenance_exempt: vec![
                "/health/**".into(),
                "/.well-known/**".into(),
                "/metrics".into(),
            ],
            maintenance_message: "Down".into(),
            forwarded_headers: vec![],
            metrics_namespace: "test".into(),
            metrics_classes: vec![],
        };

        let check = |path: &str| -> bool {
            state
                .maintenance_exempt
                .iter()
                .any(|pattern| match pattern.strip_suffix("/**") {
                    Some(prefix) => path.starts_with(prefix),
                    None => path == pattern,
                })
        };

        for exempt in [
            "/health",
            "/health/ready",
            "/.well-known/openid-configuration",
            "/metrics",
        ] {
            assert!(check(exempt), "{exempt} should be exempt");
        }
        for blocked in ["/v1/auth/login", "/oauth2/token"] {
            assert!(!check(blocked), "{blocked} should not be exempt");
        }
    }
}