1use core::iter::once;
2use core::ops::Add;
3use std::net::IpAddr;
4use std::path::Path;
5
6use axum::Json;
7use axum::http::uri;
8use axum::{
9 Router,
10 body::Body,
11 extract::{Request, State},
12 http::StatusCode,
13 http::header,
14 http::uri::Uri,
15 response::{IntoResponse, Response},
16 routing::any,
17 routing::post,
18};
19use blueprint_core::{debug, error, info, warn};
20use hyper_util::{client::legacy::connect::HttpConnector, rt::TokioExecutor};
21use tower_http::cors::CorsLayer;
22use tower_http::request_id::{MakeRequestUuid, PropagateRequestIdLayer, SetRequestIdLayer};
23use tower_http::sensitive_headers::{
24 SetSensitiveRequestHeadersLayer, SetSensitiveResponseHeadersLayer,
25};
26use tower_http::trace::{DefaultMakeSpan, DefaultOnResponse, TraceLayer};
27use tracing::instrument;
28
29use crate::api_keys::{ApiKeyGenerator, ApiKeyModel};
30use crate::certificate_authority::{
31 CertificateAuthority, ClientCertificate, CreateTlsProfileRequest, IssueCertificateRequest,
32 TlsProfileResponse, validate_certificate_request,
33};
34use crate::db::RocksDb;
35use crate::models::{ApiTokenModel, ServiceModel, TlsProfile};
36use crate::paseto_tokens::PasetoTokenManager;
37use crate::request_extensions::{AuthMethod, extract_client_cert_from_request};
38use crate::tls_client::TlsClientManager;
39use crate::tls_envelope::{TlsEnvelope, init_tls_envelope_key};
40use crate::tls_listener::{TlsListenerConfig, TlsListenerManager};
41use pem;
42use prost::Message;
43
/// Process-wide cache for the `-bin` metadata size limit, filled on first use.
static GRPC_BINARY_METADATA_MAX_SIZE: std::sync::OnceLock<usize> = std::sync::OnceLock::new();

/// Returns the largest accepted length (compared against `str::len`, i.e. bytes) of a gRPC
/// binary metadata header value.
///
/// The value is read once from the `GRPC_BINARY_METADATA_MAX_SIZE` environment variable; an
/// unset or unparsable value falls back to 16384. Later calls return the cached value.
fn get_max_binary_metadata_size() -> usize {
    const FALLBACK_LIMIT: usize = 16384;

    let cached = GRPC_BINARY_METADATA_MAX_SIZE.get_or_init(|| {
        match std::env::var("GRPC_BINARY_METADATA_MAX_SIZE") {
            Ok(raw) => raw.parse::<usize>().unwrap_or(FALLBACK_LIMIT),
            Err(_) => FALLBACK_LIMIT,
        }
    });
    *cached
}
58use crate::types::{ServiceId, VerifyChallengeResponse};
59use crate::validation;
60
61type HTTPClient =
62 hyper_util::client::legacy::Client<hyper_util::client::legacy::connect::HttpConnector, Body>;
63type HTTP2Client =
64 hyper_util::client::legacy::Client<hyper_util::client::legacy::connect::HttpConnector, Body>;
65
66pub const DEFAULT_AUTH_PROXY_PORT: u16 = 8276;
69
70pub struct AuthenticatedProxy {
71 http_client: HTTPClient,
72 http2_client: HTTP2Client,
73 tls_client_manager: TlsClientManager,
74 db: crate::db::RocksDb,
75 paseto_manager: PasetoTokenManager,
76 tls_envelope: TlsEnvelope,
77 tls_runtime: TlsListenerManager,
78}
79
80#[derive(Clone, Debug)]
81pub struct AuthenticatedProxyState {
82 http_client: HTTPClient,
83 http2_client: HTTP2Client,
84 tls_client_manager: TlsClientManager,
85 db: crate::db::RocksDb,
86 paseto_manager: PasetoTokenManager,
87 tls_envelope: TlsEnvelope,
88 mtls_listener_address: Option<std::net::SocketAddr>,
89 #[cfg(feature = "standalone")]
90 mtls_listener_handle: Option<std::sync::Arc<tokio::task::JoinHandle<()>>>,
91 tls_runtime: TlsListenerManager,
92}
93
94impl AuthenticatedProxyState {
95 pub fn db_ref(&self) -> &crate::db::RocksDb {
96 &self.db
97 }
98 pub fn paseto_manager_ref(&self) -> &PasetoTokenManager {
99 &self.paseto_manager
100 }
101 pub fn tls_envelope_ref(&self) -> &TlsEnvelope {
102 &self.tls_envelope
103 }
104 pub fn tls_client_manager_ref(&self) -> &TlsClientManager {
105 &self.tls_client_manager
106 }
107 pub fn tls_runtime_ref(&self) -> TlsListenerManager {
108 self.tls_runtime.clone()
109 }
110
111 #[cfg(feature = "standalone")]
112 pub fn set_mtls_listener_address(&mut self, addr: std::net::SocketAddr) {
113 self.mtls_listener_address = Some(addr);
114 }
115
116 #[cfg(not(feature = "standalone"))]
117 pub fn set_mtls_listener_address(&mut self, _addr: std::net::SocketAddr) {
118 }
120
121 pub fn get_mtls_listener_address(&self) -> Option<std::net::SocketAddr> {
122 self.mtls_listener_address
123 }
124}
125
126impl AuthenticatedProxy {
127 pub fn new<P: AsRef<Path>>(db_path: P) -> Result<Self, crate::Error> {
128 let executer = TokioExecutor::new();
129
130 let mut http_connector = HttpConnector::new();
132 http_connector.enforce_http(false); http_connector.set_nodelay(true); let http_client: HTTPClient = hyper_util::client::legacy::Builder::new(executer.clone())
137 .http2_only(false) .build(http_connector.clone());
139
140 let mut http2_connector = HttpConnector::new();
142 http2_connector.enforce_http(false); http2_connector.set_nodelay(true); let http2_client: HTTP2Client = hyper_util::client::legacy::Builder::new(executer)
147 .http2_only(true) .http2_adaptive_window(true) .build(http2_connector);
150
151 let db_config = crate::db::RocksDbConfig::default();
152 let db = crate::db::RocksDb::open(&db_path, &db_config)?;
153
154 let tls_envelope = Self::init_tls_envelope(&db_path)?;
156
157 let tls_runtime = TlsListenerManager::new(
158 db.clone(),
159 tls_envelope.clone(),
160 TlsListenerConfig::default(),
161 );
162
163 Self::hydrate_tls_runtime(&tls_runtime, &db)?;
164
165 let tls_client_manager = TlsClientManager::new(db.clone());
167
168 let paseto_manager = Self::init_paseto_manager(&db_path)?;
170
171 Ok(AuthenticatedProxy {
172 http_client,
173 http2_client,
174 tls_client_manager,
175 db,
176 paseto_manager,
177 tls_envelope,
178 tls_runtime,
179 })
180 }
181
182 fn hydrate_tls_runtime(runtime: &TlsListenerManager, db: &RocksDb) -> Result<(), crate::Error> {
183 use crate::db::cf;
184 use rocksdb::IteratorMode;
185
186 let cf_handle = db
187 .cf_handle(cf::SERVICES_USER_KEYS_CF)
188 .ok_or(crate::Error::UnknownColumnFamily(cf::SERVICES_USER_KEYS_CF))?;
189 let iter = db.iterator_cf(&cf_handle, IteratorMode::Start);
190 let mut profiles = Vec::new();
191
192 for item in iter {
193 let (key, value) = item?;
194 if key.len() < 16 {
195 continue;
196 }
197 let mut id_bytes = [0u8; 16];
198 id_bytes.copy_from_slice(&key[..16]);
199 let service_id = ServiceId::from_be_bytes(id_bytes);
200 let service = ServiceModel::decode(value.as_ref())?;
201 if let Some(profile) = service.tls_profile() {
202 if profile.tls_enabled {
203 profiles.push((service_id, profile.clone()));
204 }
205 }
206 }
207
208 if profiles.is_empty() {
209 return Ok(());
210 }
211
212 let runtime_clone = runtime.clone();
213 let future = async move {
214 for (service_id, profile) in profiles {
215 runtime_clone
216 .load_service_profile(service_id, &profile)
217 .await?;
218 }
219 Ok::<(), crate::Error>(())
220 };
221
222 tokio::runtime::Builder::new_current_thread()
223 .enable_all()
224 .build()
225 .map_err(crate::Error::Io)?
226 .block_on(future)
227 }
228
229 fn init_paseto_manager<P: AsRef<Path>>(db_path: P) -> Result<PasetoTokenManager, crate::Error> {
231 use std::fs;
232 use std::io::{Read, Write};
233
234 if let Ok(key_hex) = std::env::var("PASETO_SIGNING_KEY") {
236 if let Ok(key_bytes) = hex::decode(&key_hex) {
237 if key_bytes.len() == 32 {
238 let mut key_array = [0u8; 32];
239 key_array.copy_from_slice(&key_bytes);
240 let key = crate::paseto_tokens::PasetoKey::from_bytes(key_array);
241 return Ok(PasetoTokenManager::with_key(
242 key,
243 std::time::Duration::from_secs(15 * 60),
244 ));
245 }
246 }
247 warn!("Invalid PASETO_SIGNING_KEY environment variable, generating new key");
248 }
249
250 let key_path = db_path.as_ref().join(".paseto_key");
252 if key_path.exists() {
253 let mut file = fs::File::open(&key_path).map_err(crate::Error::Io)?;
254 let mut key_bytes = vec![];
255 file.read_to_end(&mut key_bytes).map_err(crate::Error::Io)?;
256
257 if key_bytes.len() == 32 {
258 let mut key_array = [0u8; 32];
259 key_array.copy_from_slice(&key_bytes);
260 let key = crate::paseto_tokens::PasetoKey::from_bytes(key_array);
261 info!("Loaded existing Paseto signing key from disk");
262 return Ok(PasetoTokenManager::with_key(
263 key,
264 std::time::Duration::from_secs(15 * 60),
265 ));
266 }
267 }
268
269 let manager = PasetoTokenManager::new(std::time::Duration::from_secs(15 * 60));
271 let key = manager.get_key();
272
273 let mut file = fs::File::create(&key_path).map_err(crate::Error::Io)?;
275 let key_bytes = key.as_bytes();
276 file.write_all(&key_bytes).map_err(crate::Error::Io)?;
277 file.sync_all().map_err(crate::Error::Io)?;
278
279 #[cfg(unix)]
281 {
282 use std::os::unix::fs::PermissionsExt;
283 let permissions = std::fs::Permissions::from_mode(0o600);
284 fs::set_permissions(&key_path, permissions).map_err(crate::Error::Io)?;
285 }
286
287 info!("Generated and saved new Paseto signing key");
288 Ok(manager)
289 }
290
291 fn init_tls_envelope<P: AsRef<Path>>(db_path: P) -> Result<TlsEnvelope, crate::Error> {
293 let envelope_key = init_tls_envelope_key(&db_path)
294 .map_err(|e| crate::Error::Io(std::io::Error::other(e.to_string())))?;
295
296 Ok(TlsEnvelope::with_key(envelope_key))
297 }
298
299 pub fn router(self) -> Router<()> {
300 let runtime = self.tls_runtime.clone();
301 let state = AuthenticatedProxyState {
302 http_client: self.http_client,
303 http2_client: self.http2_client,
304 tls_client_manager: self.tls_client_manager,
305 db: self.db,
306 paseto_manager: self.paseto_manager,
307 tls_envelope: self.tls_envelope,
308 mtls_listener_address: None,
309 #[cfg(feature = "standalone")]
310 mtls_listener_handle: None,
311 tls_runtime: runtime.clone(),
312 };
313 let router = Router::new()
314 .nest("/v1", Self::internal_api_router_v1())
315 .fallback(any(unified_proxy))
316 .layer(SetRequestIdLayer::new(
317 header::HeaderName::from_static("x-request-id"),
318 MakeRequestUuid,
319 ))
320 .layer(PropagateRequestIdLayer::new(
322 header::HeaderName::from_static("x-request-id"),
323 ))
324 .layer(SetSensitiveRequestHeadersLayer::new(once(
325 header::AUTHORIZATION,
326 )))
327 .layer(
328 TraceLayer::new_for_http()
329 .make_span_with(DefaultMakeSpan::new().include_headers(true))
330 .on_response(DefaultOnResponse::new().include_headers(true)),
331 )
332 .layer(CorsLayer::permissive())
333 .layer(SetSensitiveResponseHeadersLayer::new(once(
334 header::AUTHORIZATION,
335 )))
336 .with_state(state);
337
338 runtime.install_router(router.clone());
339 router
340 }
341
342 pub fn db(&self) -> RocksDb {
343 self.db.clone()
344 }
345
346 pub fn internal_api_router_v1() -> Router<AuthenticatedProxyState> {
348 Router::new()
349 .route("/auth/challenge", post(auth_challenge))
350 .route("/auth/verify", post(auth_verify))
351 .route("/auth/exchange", post(auth_exchange))
352 .route("/oauth/token", post(crate::oauth::token::oauth_token))
354 .route(
356 "/admin/services/{service_id}/tls-profile",
357 axum::routing::put(update_tls_profile),
358 )
359 .route("/auth/certificates", axum::routing::post(issue_certificate))
360 }
361}
362
363async fn auth_challenge(
365 service_id: ServiceId,
366 State(s): State<AuthenticatedProxyState>,
367 Json(payload): Json<crate::types::ChallengeRequest>,
368) -> Result<Json<crate::types::ChallengeResponse>, StatusCode> {
369 let mut rng = blueprint_std::BlueprintRng::new();
370 let service = ServiceModel::find_by_id(service_id, &s.db)
371 .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?
372 .ok_or(StatusCode::NOT_FOUND)?;
373
374 let public_key = payload.pub_key;
375 if !service.is_owner(payload.key_type, &public_key) {
376 return Err(StatusCode::UNAUTHORIZED);
377 }
378 let challenge = crate::generate_challenge(&mut rng);
379 let now = std::time::SystemTime::now();
380 let expires_at = now
381 .duration_since(std::time::UNIX_EPOCH)
382 .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?
383 .add(std::time::Duration::from_secs(30))
384 .as_secs();
385 Ok(Json(crate::types::ChallengeResponse {
386 challenge,
387 expires_at,
388 }))
389}
390
391async fn auth_verify(
393 service_id: ServiceId,
394 State(s): State<AuthenticatedProxyState>,
395 Json(payload): Json<crate::types::VerifyChallengeRequest>,
396) -> impl IntoResponse {
397 let mut rng = blueprint_std::BlueprintRng::new();
398 let service = match ServiceModel::find_by_id(service_id, &s.db) {
399 Ok(Some(service)) => service,
400 Ok(None) => {
401 return (
402 StatusCode::NOT_FOUND,
403 Json(VerifyChallengeResponse::UnexpectedError {
404 message: "Service not found".to_string(),
405 }),
406 );
407 }
408 Err(e) => {
409 return (
410 StatusCode::INTERNAL_SERVER_ERROR,
411 Json(crate::types::VerifyChallengeResponse::UnexpectedError {
412 message: format!("Internal server error: {e}"),
413 }),
414 );
415 }
416 };
417
418 let public_key = payload.challenge_request.pub_key;
419 if !service.is_owner(payload.challenge_request.key_type, &public_key) {
420 return (
421 StatusCode::UNAUTHORIZED,
422 Json(crate::types::VerifyChallengeResponse::Unauthorized),
423 );
424 }
425 let result = crate::verify_challenge(
427 &payload.challenge,
428 &payload.signature,
429 &public_key,
430 payload.challenge_request.key_type,
431 );
432 match result {
433 Ok(true) => {
434 let validated_headers = match validation::validate_headers(&payload.additional_headers)
436 {
437 Ok(headers) => headers,
438 Err(e) => {
439 return (
440 StatusCode::BAD_REQUEST,
441 Json(VerifyChallengeResponse::UnexpectedError {
442 message: format!("Invalid headers: {e}"),
443 }),
444 );
445 }
446 };
447
448 let protected_headers =
450 validation::process_headers_with_pii_protection(&validated_headers);
451
452 let api_key_gen = ApiKeyGenerator::with_prefix(service.api_key_prefix());
454 let expires_at = payload.expires_at.max(
455 std::time::SystemTime::now()
456 .duration_since(std::time::UNIX_EPOCH)
457 .unwrap_or_default()
458 .as_secs()
459 + (90 * 24 * 60 * 60), );
461
462 let api_key =
463 api_key_gen.generate_key(service_id, expires_at, protected_headers, &mut rng);
464
465 let mut api_key_model = ApiKeyModel::from(&api_key);
466 if let Err(e) = api_key_model.save(&s.db) {
467 return (
468 StatusCode::INTERNAL_SERVER_ERROR,
469 Json(VerifyChallengeResponse::UnexpectedError {
470 message: format!("Internal server error: {e}"),
471 }),
472 );
473 }
474
475 (
476 StatusCode::CREATED,
477 Json(VerifyChallengeResponse::Verified {
478 api_key: api_key.full_key().to_string(),
479 expires_at,
480 }),
481 )
482 }
483 Ok(false) => (
484 StatusCode::UNAUTHORIZED,
485 Json(crate::types::VerifyChallengeResponse::InvalidSignature),
486 ),
487 Err(e) => (
488 StatusCode::INTERNAL_SERVER_ERROR,
489 Json(crate::types::VerifyChallengeResponse::UnexpectedError {
490 message: format!("Internal server error: {e}"),
491 }),
492 ),
493 }
494}
495
496async fn auth_exchange(
498 State(s): State<AuthenticatedProxyState>,
499 headers: axum::http::HeaderMap,
500 Json(payload): Json<crate::auth_token::TokenExchangeRequest>,
501) -> impl IntoResponse {
502 let auth_header = match headers.get(crate::types::headers::AUTHORIZATION) {
504 Some(header_value) => {
505 let header_str = match header_value.to_str() {
506 Ok(s) => s,
507 Err(_) => {
508 return (
509 StatusCode::BAD_REQUEST,
510 Json(serde_json::json!({
511 "error": "invalid_authorization_header",
512 "message": "Authorization header is not valid UTF-8"
513 })),
514 );
515 }
516 };
517
518 if let Some(token) = header_str.strip_prefix("Bearer ") {
520 token
521 } else {
522 return (
523 StatusCode::BAD_REQUEST,
524 Json(serde_json::json!({
525 "error": "invalid_authorization_header",
526 "message": "Authorization header must use Bearer scheme with API key"
527 })),
528 );
529 }
530 }
531 None => {
532 return (
533 StatusCode::UNAUTHORIZED,
534 Json(serde_json::json!({
535 "error": "missing_authorization_header",
536 "message": "Authorization header with Bearer API key is required"
537 })),
538 );
539 }
540 };
541
542 let key_id = if let Some((key_id_part, _)) = auth_header.split_once('.') {
544 key_id_part
545 } else {
546 return (
547 StatusCode::BAD_REQUEST,
548 Json(serde_json::json!({
549 "error": "invalid_api_key_format",
550 "message": "API key must have format ak_xxxxx.yyyyy"
551 })),
552 );
553 };
554
555 let mut api_key_model = match crate::api_keys::ApiKeyModel::find_by_key_id(key_id, &s.db) {
557 Ok(Some(model)) => model,
558 Ok(None) => {
559 return (
560 StatusCode::UNAUTHORIZED,
561 Json(serde_json::json!({
562 "error": "invalid_api_key",
563 "message": "API key not found"
564 })),
565 );
566 }
567 Err(e) => {
568 error!("Database error looking up API key: {}", e);
569 return (
570 StatusCode::INTERNAL_SERVER_ERROR,
571 Json(serde_json::json!({
572 "error": "internal_error",
573 "message": "Failed to validate API key"
574 })),
575 );
576 }
577 };
578
579 if !api_key_model.validates_key(auth_header) {
581 return (
582 StatusCode::UNAUTHORIZED,
583 Json(serde_json::json!({
584 "error": "invalid_api_key",
585 "message": "API key validation failed"
586 })),
587 );
588 }
589
590 if api_key_model.is_expired() {
592 return (
593 StatusCode::UNAUTHORIZED,
594 Json(serde_json::json!({
595 "error": "expired_api_key",
596 "message": "API key has expired"
597 })),
598 );
599 }
600
601 if !api_key_model.is_enabled {
602 return (
603 StatusCode::UNAUTHORIZED,
604 Json(serde_json::json!({
605 "error": "disabled_api_key",
606 "message": "API key is disabled"
607 })),
608 );
609 }
610
611 if let Err(e) = api_key_model.update_last_used(&s.db) {
613 warn!("Failed to update API key last_used timestamp: {}", e);
614 }
615
616 let mut headers = api_key_model.get_default_headers();
618 for (key, value) in payload.additional_headers {
619 headers.insert(key, value);
620 }
621
622 let validated_headers = match crate::validation::validate_headers(&headers) {
624 Ok(headers) => headers,
625 Err(e) => {
626 return (
627 StatusCode::BAD_REQUEST,
628 Json(serde_json::json!({
629 "error": "invalid_headers",
630 "message": format!("Header validation failed: {}", e)
631 })),
632 );
633 }
634 };
635
636 let protected_headers =
638 crate::validation::process_headers_with_pii_protection(&validated_headers);
639
640 let service_id = api_key_model.service_id();
642 let tenant_id = protected_headers.get("X-Tenant-Id").cloned();
643 let custom_ttl = payload.ttl_seconds.map(std::time::Duration::from_secs);
644
645 let access_token = match s.paseto_manager.generate_token(
646 service_id,
647 api_key_model.key_id.clone(),
648 tenant_id,
649 protected_headers,
650 custom_ttl,
651 None,
652 ) {
653 Ok(token) => token,
654 Err(e) => {
655 error!("Failed to create Paseto token: {}", e);
656 return (
657 StatusCode::INTERNAL_SERVER_ERROR,
658 Json(serde_json::json!({
659 "error": "token_generation_failed",
660 "message": "Failed to generate access token"
661 })),
662 );
663 }
664 };
665
666 let expires_at = std::time::SystemTime::now()
668 .duration_since(std::time::UNIX_EPOCH)
669 .unwrap_or_default()
670 .as_secs()
671 + custom_ttl
672 .unwrap_or(s.paseto_manager.default_ttl())
673 .as_secs();
674
675 let response = crate::auth_token::TokenExchangeResponse::new(access_token, expires_at);
676
677 (
678 StatusCode::OK,
679 Json(serde_json::to_value(response).unwrap()),
680 )
681}
682
683async fn update_tls_profile(
684 service_id: ServiceId,
685 State(s): State<AuthenticatedProxyState>,
686 Json(payload): Json<CreateTlsProfileRequest>,
687) -> Result<Json<TlsProfileResponse>, StatusCode> {
688 let envelope = s.tls_envelope_ref().clone();
689 let runtime = s.tls_runtime_ref();
690
691 let mut service = match ServiceModel::find_by_id(service_id, &s.db) {
692 Ok(Some(service)) => service,
693 Ok(None) => return Err(StatusCode::NOT_FOUND),
694 Err(_) => return Err(StatusCode::INTERNAL_SERVER_ERROR),
695 };
696
697 let existing_profile = service.tls_profile.clone();
698
699 let allowlist_to_store = payload
700 .allowed_dns_names
701 .clone()
702 .or_else(|| {
703 existing_profile
704 .as_ref()
705 .map(|p| p.allowed_dns_names.clone())
706 })
707 .unwrap_or_default();
708
709 let server_dns_names = if let Some(list) = payload.allowed_dns_names.clone() {
710 if list.is_empty() { None } else { Some(list) }
711 } else if let Some(profile) = existing_profile.as_ref() {
712 if !profile.allowed_dns_names.is_empty() {
713 Some(profile.allowed_dns_names.clone())
714 } else if let Some(existing_sni) = profile.sni.clone() {
715 Some(vec![existing_sni])
716 } else {
717 None
718 }
719 } else {
720 None
721 }
722 .unwrap_or_else(|| vec!["localhost".to_string()]);
723
724 let mut tls_profile = existing_profile.unwrap_or_else(|| TlsProfile {
725 tls_enabled: false,
726 require_client_mtls: false,
727 encrypted_server_cert: Vec::new(),
728 encrypted_server_key: Vec::new(),
729 encrypted_client_ca_bundle: Vec::new(),
730 encrypted_upstream_ca_bundle: Vec::new(),
731 encrypted_upstream_client_cert: Vec::new(),
732 encrypted_upstream_client_key: Vec::new(),
733 client_cert_ttl_hours: payload.client_cert_ttl_hours,
734 sni: None,
735 subject_alt_name_template: payload.subject_alt_name_template.clone(),
736 allowed_dns_names: allowlist_to_store.clone(),
737 });
738
739 let ca = if !tls_profile.encrypted_client_ca_bundle.is_empty() {
740 load_persisted_ca(&envelope, &tls_profile)?
741 } else {
742 CertificateAuthority::new(envelope.clone()).map_err(|e| {
743 error!("Failed to create certificate authority: {e}");
744 StatusCode::INTERNAL_SERVER_ERROR
745 })?
746 };
747
748 let mut bundle = ca.ca_certificate_pem();
749 if !bundle.ends_with('\n') {
750 bundle.push('\n');
751 }
752 bundle.push_str(&ca.ca_private_key_pem());
753
754 tls_profile.encrypted_client_ca_bundle = envelope.encrypt(bundle.as_bytes()).map_err(|e| {
755 error!("Failed to encrypt CA bundle: {e}");
756 StatusCode::INTERNAL_SERVER_ERROR
757 })?;
758
759 let (server_cert, server_key) = ca
760 .generate_server_certificate(service_id, server_dns_names.clone())
761 .map_err(|e| {
762 error!("Failed to generate server certificate: {e}");
763 StatusCode::INTERNAL_SERVER_ERROR
764 })?;
765
766 tls_profile.encrypted_server_cert = envelope.encrypt(server_cert.as_bytes()).map_err(|e| {
767 error!("Failed to encrypt server certificate: {e}");
768 StatusCode::INTERNAL_SERVER_ERROR
769 })?;
770 tls_profile.encrypted_server_key = envelope.encrypt(server_key.as_bytes()).map_err(|e| {
771 error!("Failed to encrypt server key: {e}");
772 StatusCode::INTERNAL_SERVER_ERROR
773 })?;
774
775 tls_profile.tls_enabled = true;
776 tls_profile.require_client_mtls = payload.require_client_mtls;
777 tls_profile.client_cert_ttl_hours = payload.client_cert_ttl_hours;
778 if let Some(template) = payload.subject_alt_name_template.clone() {
779 tls_profile.subject_alt_name_template = Some(template);
780 }
781 tls_profile.allowed_dns_names = allowlist_to_store.clone();
782 tls_profile.sni = server_dns_names.first().cloned();
783
784 service.tls_profile = Some(tls_profile.clone());
785 if let Err(e) = service.save(service_id, &s.db) {
786 error!("Failed to persist TLS profile: {e}");
787 return Err(StatusCode::INTERNAL_SERVER_ERROR);
788 }
789
790 let bound_addr = runtime
791 .upsert_service_profile(service_id, &tls_profile)
792 .await
793 .map_err(|e| {
794 error!("Failed to activate TLS profile: {e}");
795 StatusCode::INTERNAL_SERVER_ERROR
796 })?;
797
798 info!(
799 "TLS profile enabled for service {} with listener {}",
800 service_id, bound_addr
801 );
802
803 let listener_uri = match bound_addr.ip() {
804 IpAddr::V4(ip) if ip.is_unspecified() => {
805 format!("https://localhost:{}", bound_addr.port())
806 }
807 IpAddr::V6(ip) if ip.is_unspecified() => {
808 format!("https://[::1]:{}", bound_addr.port())
809 }
810 _ => format!("https://{bound_addr}"),
811 };
812
813 Ok(Json(TlsProfileResponse {
814 tls_enabled: true,
815 require_client_mtls: payload.require_client_mtls,
816 client_cert_ttl_hours: payload.client_cert_ttl_hours,
817 mtls_listener: listener_uri,
818 http_listener: Some(format!("http://localhost:{DEFAULT_AUTH_PROXY_PORT}")),
819 ca_certificate_pem: Some(ca.ca_certificate_pem()),
820 subject_alt_name_template: tls_profile.subject_alt_name_template.clone(),
821 allowed_dns_names: tls_profile.allowed_dns_names.clone(),
822 }))
823}
824
825async fn issue_certificate(
827 State(s): State<AuthenticatedProxyState>,
828 Json(payload): Json<IssueCertificateRequest>,
829) -> Result<Json<ClientCertificate>, StatusCode> {
830 let service_id = ServiceId::new(payload.service_id);
831
832 let service = match ServiceModel::find_by_id(service_id, &s.db) {
834 Ok(Some(service)) => service,
835 Ok(None) => return Err(StatusCode::NOT_FOUND),
836 Err(_) => return Err(StatusCode::INTERNAL_SERVER_ERROR),
837 };
838
839 let tls_profile = match &service.tls_profile {
841 Some(profile) => profile,
842 None => return Err(StatusCode::BAD_REQUEST),
843 };
844
845 if let Err(e) = validate_certificate_request(&payload, tls_profile) {
847 warn!("Certificate request validation failed: {}", e);
848 return Err(StatusCode::BAD_REQUEST);
849 }
850
851 let mut service_ref = service.clone();
853 let ca = if !tls_profile.encrypted_client_ca_bundle.is_empty() {
854 load_persisted_ca(&s.tls_envelope, tls_profile)?
855 } else {
856 let ca = CertificateAuthority::new(s.tls_envelope.clone()).map_err(|e| {
857 error!("Failed to initialise certificate authority: {e}");
858 StatusCode::INTERNAL_SERVER_ERROR
859 })?;
860
861 persist_new_ca(&s.db, service_id, &mut service_ref, &ca)?;
862 if let Some(profile) = service_ref.tls_profile.as_ref() {
863 if let Err(err) = s
864 .tls_runtime_ref()
865 .upsert_service_profile(service_id, profile)
866 .await
867 {
868 error!("Failed to refresh TLS runtime after CA creation: {err}");
869 return Err(StatusCode::INTERNAL_SERVER_ERROR);
870 }
871 }
872 ca
873 };
874
875 let mut client_cert = match ca.generate_client_certificate(
877 payload.common_name,
878 payload.subject_alt_names,
879 payload.ttl_hours,
880 ) {
881 Ok(cert) => cert,
882 Err(e) => {
883 error!("Failed to generate client certificate: {}", e);
884 return Err(StatusCode::INTERNAL_SERVER_ERROR);
885 }
886 };
887
888 if client_cert.revocation_url.is_none() {
889 client_cert.revocation_url = Some(format!(
890 "/v1/auth/certificates/{}/revoke",
891 client_cert.serial
892 ));
893 }
894
895 info!(
896 "Issued client certificate for service {} with serial {}",
897 service_id, client_cert.serial
898 );
899 Ok(Json(client_cert))
900}
901
902fn load_persisted_ca(
903 envelope: &TlsEnvelope,
904 profile: &TlsProfile,
905) -> Result<CertificateAuthority, StatusCode> {
906 let decrypted = envelope
907 .decrypt(&profile.encrypted_client_ca_bundle)
908 .map_err(|e| {
909 error!("Failed to decrypt stored CA bundle: {e}");
910 StatusCode::INTERNAL_SERVER_ERROR
911 })?;
912
913 let pem_str = String::from_utf8(decrypted).map_err(|e| {
914 error!("Stored CA bundle is not valid UTF-8: {e}");
915 StatusCode::INTERNAL_SERVER_ERROR
916 })?;
917
918 let blocks = pem::parse_many(&pem_str).map_err(|e| {
919 error!("Failed to parse stored CA bundle: {e}");
920 StatusCode::INTERNAL_SERVER_ERROR
921 })?;
922
923 let mut ca_cert_pem: Option<String> = None;
924 let mut ca_key_pem: Option<String> = None;
925
926 for block in blocks {
927 let encoded = pem::encode(&block);
928 match block.tag.as_str() {
929 "CERTIFICATE" if ca_cert_pem.is_none() => ca_cert_pem = Some(encoded),
930 tag if tag.ends_with("PRIVATE KEY") && ca_key_pem.is_none() => {
931 ca_key_pem = Some(encoded)
932 }
933 _ => {}
934 }
935 }
936
937 let cert = ca_cert_pem.ok_or_else(|| {
938 error!("Stored CA bundle missing certificate block");
939 StatusCode::INTERNAL_SERVER_ERROR
940 })?;
941 let key = ca_key_pem.ok_or_else(|| {
942 error!("Stored CA bundle missing private key block");
943 StatusCode::INTERNAL_SERVER_ERROR
944 })?;
945
946 CertificateAuthority::from_components(cert, key, clone_envelope(envelope)).map_err(|e| {
947 error!("Failed to rehydrate certificate authority: {e}");
948 StatusCode::INTERNAL_SERVER_ERROR
949 })
950}
951
952fn persist_new_ca(
953 db: &RocksDb,
954 service_id: ServiceId,
955 service: &mut ServiceModel,
956 ca: &CertificateAuthority,
957) -> Result<(), StatusCode> {
958 let mut bundle = ca.ca_certificate_pem();
959 if !bundle.ends_with('\n') {
960 bundle.push('\n');
961 }
962 bundle.push_str(&ca.ca_private_key_pem());
963
964 let encrypted = ca.envelope().encrypt(bundle.as_bytes()).map_err(|e| {
965 error!("Failed to encrypt CA bundle: {e}");
966 StatusCode::INTERNAL_SERVER_ERROR
967 })?;
968
969 if let Some(profile) = service.tls_profile.as_mut() {
970 profile.encrypted_client_ca_bundle = encrypted;
971 } else {
972 error!("TLS profile unexpectedly missing while persisting CA");
973 return Err(StatusCode::INTERNAL_SERVER_ERROR);
974 }
975
976 service.save(service_id, db).map_err(|e| {
977 error!("Failed to persist CA bundle to service: {e}");
978 StatusCode::INTERNAL_SERVER_ERROR
979 })
980}
981
982fn clone_envelope(envelope: &TlsEnvelope) -> TlsEnvelope {
983 TlsEnvelope::with_key(envelope.key().clone())
984}
985
986async fn handle_legacy_token(
988 token: crate::api_tokens::ApiToken,
989 db: &crate::db::RocksDb,
990) -> Result<
991 (
992 crate::types::ServiceId,
993 std::collections::BTreeMap<String, String>,
994 ),
995 StatusCode,
996> {
997 let (token_id, token_str) = (token.0, token.1.as_str());
998
999 let api_token = match ApiTokenModel::find_token_id(token_id, db) {
1000 Ok(Some(token)) if token.is(token_str) && !token.is_expired() && token.is_enabled => token,
1001 Ok(Some(_)) | Ok(None) => {
1002 warn!("Invalid or expired legacy token");
1003 return Err(StatusCode::UNAUTHORIZED);
1004 }
1005 Err(_) => return Err(StatusCode::INTERNAL_SERVER_ERROR),
1006 };
1007
1008 let additional_headers = api_token.get_additional_headers();
1009 Ok((api_token.service_id(), additional_headers))
1010}
1011
1012async fn handle_api_key(
1014 api_key: &str,
1015 db: &crate::db::RocksDb,
1016) -> Result<
1017 (
1018 crate::types::ServiceId,
1019 std::collections::BTreeMap<String, String>,
1020 ),
1021 StatusCode,
1022> {
1023 let key_id = api_key
1025 .split_once('.')
1026 .map(|(key_id_part, _)| key_id_part)
1027 .ok_or(StatusCode::BAD_REQUEST)?;
1028
1029 let mut api_key_model = match crate::api_keys::ApiKeyModel::find_by_key_id(key_id, db) {
1031 Ok(Some(model)) => model,
1032 Ok(None) => {
1033 warn!("API key not found: {}", key_id);
1034 return Err(StatusCode::UNAUTHORIZED);
1035 }
1036 Err(e) => {
1037 error!("Database error looking up API key: {}", e);
1038 return Err(StatusCode::INTERNAL_SERVER_ERROR);
1039 }
1040 };
1041
1042 if !api_key_model.validates_key(api_key) {
1044 warn!("API key validation failed: {}", key_id);
1045 return Err(StatusCode::UNAUTHORIZED);
1046 }
1047
1048 if api_key_model.is_expired() {
1050 warn!("API key expired: {}", key_id);
1051 return Err(StatusCode::UNAUTHORIZED);
1052 }
1053
1054 if !api_key_model.is_enabled {
1055 warn!("API key disabled: {}", key_id);
1056 return Err(StatusCode::UNAUTHORIZED);
1057 }
1058
1059 if let Err(e) = api_key_model.update_last_used(db) {
1061 warn!("Failed to update API key last_used timestamp: {}", e);
1062 }
1063
1064 let additional_headers = api_key_model.get_default_headers();
1065 Ok((api_key_model.service_id(), additional_headers))
1066}
1067
1068async fn handle_paseto_token(
1070 token: &str,
1071 paseto_manager: &crate::paseto_tokens::PasetoTokenManager,
1072) -> Result<
1073 (
1074 crate::types::ServiceId,
1075 std::collections::BTreeMap<String, String>,
1076 ),
1077 StatusCode,
1078> {
1079 let claims = match paseto_manager.validate_token(token) {
1080 Ok(claims) => claims,
1081 Err(e) => {
1082 warn!("Paseto token validation failed: {}", e);
1083 return Err(StatusCode::UNAUTHORIZED);
1084 }
1085 };
1086
1087 let mut headers = claims.additional_headers.clone();
1089
1090 if let Some(scopes_vec) = claims.scopes.clone() {
1092 if !scopes_vec.is_empty() {
1093 let mut set = std::collections::BTreeSet::new();
1094 for s in scopes_vec {
1095 set.insert(s.to_lowercase());
1096 }
1097 let scopes_str = set.into_iter().collect::<Vec<_>>().join(" ");
1098 headers.insert("x-scopes".to_string(), scopes_str);
1099 }
1100 }
1101
1102 Ok((claims.service_id, headers))
1104}
1105
/// Reports whether `header_name` (compared case-insensitively via `to_lowercase`) is one of the
/// host, framing, connection-management, proxy-auth or client-address headers listed below.
fn is_forbidden_header(header_name: &str) -> bool {
    const FORBIDDEN: [&str; 12] = [
        "host",
        "content-length",
        "transfer-encoding",
        "connection",
        "upgrade",
        "proxy-authorization",
        "proxy-authenticate",
        "x-forwarded-host",
        "x-real-ip",
        "x-forwarded-for",
        "x-forwarded-proto",
        "forwarded",
    ];

    let normalized = header_name.to_lowercase();
    FORBIDDEN.iter().any(|candidate| *candidate == normalized)
}
1127
/// Returns `true` for headers that carry caller identity or authorization data:
/// `authorization`, `x-scope`, `x-scopes`, and any `x-tenant-*` header (case-insensitive).
///
/// Such headers are stripped from incoming requests so that only proxy-derived values
/// reach the upstream.
fn is_auth_header(header_name: &str) -> bool {
    let normalized = header_name.to_lowercase();
    match normalized.as_str() {
        "authorization" | "x-scope" | "x-scopes" => true,
        other => other.starts_with("x-tenant-"),
    }
}
1136
/// Returns `true` for headers a gRPC call needs end-to-end (case-insensitive):
/// `content-type`, `te`, `grpc-encoding`, and `grpc-accept-encoding`.
fn is_grpc_required_header(header_name: &str) -> bool {
    const REQUIRED: [&str; 4] = ["content-type", "te", "grpc-encoding", "grpc-accept-encoding"];
    let normalized = header_name.to_lowercase();
    REQUIRED.contains(&normalized.as_str())
}
1145
/// Returns `true` for headers the proxy may inject into a gRPC upstream request
/// (case-insensitive).
///
/// Exact names: `x-scope`, `x-scopes`, `x-request-id`, `x-trace-id`, `user-agent`.
/// Prefixes: `x-tenant-` and `grpc-`.
fn is_proxy_injected_header_allowed(header_name: &str) -> bool {
    const EXACT: [&str; 5] = ["x-scope", "x-scopes", "x-request-id", "x-trace-id", "user-agent"];
    const PREFIXES: [&str; 2] = ["x-tenant-", "grpc-"];

    let normalized = header_name.to_lowercase();
    EXACT.contains(&normalized.as_str())
        || PREFIXES.iter().any(|prefix| normalized.starts_with(*prefix))
}
1161
1162fn validate_binary_metadata(header_name: &str, header_value: &str) -> bool {
1165 let lower = header_name.to_lowercase();
1166
1167 if !lower.ends_with("-bin") {
1169 return false;
1170 }
1171
1172 if header_value.len() > get_max_binary_metadata_size() {
1174 warn!(
1175 "Binary metadata header {} exceeds size limit: {} bytes (max: {})",
1176 header_name,
1177 header_value.len(),
1178 get_max_binary_metadata_size()
1179 );
1180 return false;
1181 }
1182
1183 if header_value.contains('\r') || header_value.contains('\n') {
1185 warn!(
1186 "Binary metadata header {} contains CRLF characters",
1187 header_name
1188 );
1189 return false;
1190 }
1191
1192 if !header_value.chars().all(|c| {
1194 c.is_ascii_alphanumeric() || c == '+' || c == '/' || c == '=' || c == '-' || c == '_'
1195 }) {
1196 warn!(
1197 "Binary metadata header {} contains invalid base64 characters",
1198 header_name
1199 );
1200 return false;
1201 }
1202
1203 true
1204}
1205
1206async fn extract_and_validate_auth(
1209 headers: &axum::http::HeaderMap,
1210 db: &crate::db::RocksDb,
1211 paseto_manager: &crate::paseto_tokens::PasetoTokenManager,
1212) -> Result<
1213 (
1214 crate::types::ServiceId,
1215 std::collections::BTreeMap<String, String>,
1216 ),
1217 StatusCode,
1218> {
1219 let auth_header = headers
1220 .get(crate::types::headers::AUTHORIZATION)
1221 .and_then(|h| h.to_str().ok())
1222 .and_then(|h| h.strip_prefix("Bearer "))
1223 .ok_or_else(|| {
1224 warn!("Missing or invalid Authorization header");
1225 StatusCode::UNAUTHORIZED
1226 })?;
1227
1228 debug!("Processing auth header: {}", auth_header);
1229
1230 match crate::auth_token::AuthToken::parse(auth_header) {
1232 Ok(crate::auth_token::AuthToken::Legacy(legacy_token)) => {
1233 handle_legacy_token(legacy_token, db).await
1234 }
1235 Ok(crate::auth_token::AuthToken::ApiKey(api_key)) => handle_api_key(&api_key, db).await,
1236 Ok(crate::auth_token::AuthToken::AccessToken(_)) => {
1237 Err(StatusCode::INTERNAL_SERVER_ERROR)
1239 }
1240 Err(_) if auth_header.starts_with("v4.local.") => {
1241 handle_paseto_token(auth_header, paseto_manager).await
1243 }
1244 Err(e) => {
1245 warn!("Token parsing error for '{}': {:?}", auth_header, e);
1246 Err(StatusCode::UNAUTHORIZED)
1247 }
1248 }
1249}
1250
1251fn apply_additional_headers(
1253 req: &mut Request,
1254 additional_headers: std::collections::BTreeMap<String, String>,
1255 is_grpc: bool,
1256) {
1257 for (header_name, header_value) in additional_headers {
1258 if is_forbidden_header(&header_name) {
1260 warn!("Attempted to inject forbidden header: {}", header_name);
1261 continue;
1262 }
1263
1264 let is_tenant_header = header_name.to_lowercase().starts_with("x-tenant-");
1267 let is_scope_header =
1268 header_name.to_lowercase() == "x-scope" || header_name.to_lowercase() == "x-scopes";
1269 let is_allowed_proxy_header = is_proxy_injected_header_allowed(&header_name);
1270
1271 let is_allowed = !is_auth_header(&header_name)
1273 || (is_grpc && is_grpc_required_header(&header_name))
1274 || is_tenant_header
1275 || is_scope_header
1276 || (is_grpc && is_allowed_proxy_header);
1277
1278 if !is_allowed {
1279 continue;
1280 }
1281
1282 if let Ok(name) = header::HeaderName::from_bytes(header_name.as_bytes()) {
1283 if let Ok(value) = header::HeaderValue::from_str(&header_value) {
1285 if is_grpc
1287 && header_name.to_lowercase().ends_with("-bin")
1288 && !validate_binary_metadata(&header_name, &header_value)
1289 {
1290 warn!(
1291 "Invalid binary metadata header {}: {}",
1292 header_name, header_value
1293 );
1294 continue;
1295 }
1296
1297 if header_value.contains('\r') || header_value.contains('\n') {
1299 warn!("Header value contains CRLF: {}", header_name);
1300 continue;
1301 }
1302 req.headers_mut().insert(name, value);
1303 } else {
1304 warn!("Invalid header value for {}: {}", header_name, header_value);
1305 }
1306 } else {
1307 warn!("Invalid header name: {}", header_name);
1308 }
1309 }
1310}
1311
1312fn sanitize_request_headers(req: &mut Request, is_grpc: bool) {
1314 let mut to_remove: Vec<header::HeaderName> = Vec::new();
1315 for (name, _value) in req.headers().iter() {
1316 let name_str = name.as_str();
1317 if is_auth_header(name_str) || is_forbidden_header(name_str) {
1318 if !(is_grpc && is_grpc_required_header(name_str)) {
1320 to_remove.push(name.clone());
1321 }
1322 }
1323 }
1324 for name in to_remove {
1325 req.headers_mut().remove(name);
1326 }
1327}
1328
1329fn is_grpc_request(headers: &axum::http::HeaderMap, req: &Request) -> bool {
1331 let content_type = headers.get("content-type");
1333 let is_grpc_content_type = content_type
1334 .and_then(|ct| ct.to_str().ok())
1335 .map(|ct| {
1336 debug!("Content-Type header: {}", ct);
1337 ct.to_lowercase().starts_with("application/grpc")
1338 })
1339 .unwrap_or(false);
1340
1341 let te_header = headers.get("te");
1343 let has_te_trailers = te_header
1344 .and_then(|te| te.to_str().ok())
1345 .map(|te| {
1346 debug!("TE header: {}", te);
1347 te.to_lowercase() == "trailers"
1348 })
1349 .unwrap_or(false);
1350
1351 debug!(
1352 "gRPC detection - content-type: {:?}, is_grpc: {}, te: {:?}, has_trailers: {}",
1353 content_type, is_grpc_content_type, te_header, has_te_trailers
1354 );
1355
1356 let is_http2 = req.version() == axum::http::Version::HTTP_2;
1358 debug!("HTTP version: {:?}, is_http2: {}", req.version(), is_http2);
1359
1360 let result = is_grpc_content_type && has_te_trailers && is_http2;
1361 debug!("gRPC request detected: {}", result);
1362 result
1363}
1364
1365#[instrument(skip_all)]
1367async fn unified_proxy(
1368 headers: axum::http::HeaderMap,
1369 State(s): State<AuthenticatedProxyState>,
1370 req: Request,
1371) -> Result<Response, StatusCode> {
1372 info!("Unified proxy called with headers: {:?}", headers);
1373 info!("Request method: {}, URI: {}", req.method(), req.uri());
1374
1375 let client_cert = extract_client_cert_from_request(&req);
1377
1378 let auth_method = if client_cert.is_some() {
1380 AuthMethod::Mtls
1381 } else if headers.get("authorization").is_some() {
1382 AuthMethod::AccessToken
1383 } else if headers.get("x-api-key").is_some() {
1384 AuthMethod::ApiKey
1385 } else {
1386 AuthMethod::OAuth
1387 };
1388
1389 let is_grpc = is_grpc_request(&headers, &req);
1390
1391 if is_grpc {
1392 info!("Detected gRPC request, using gRPC proxy path");
1393 grpc_proxy_with_mtls(headers, State(s), req, client_cert, auth_method).await
1394 } else {
1395 info!("Detected HTTP request, using HTTP proxy path");
1396 reverse_proxy_with_mtls(headers, State(s), req, client_cert, auth_method).await
1397 }
1398}
1399
1400#[instrument(skip_all)]
1402async fn grpc_proxy_with_mtls(
1403 headers: axum::http::HeaderMap,
1404 State(s): State<AuthenticatedProxyState>,
1405 mut req: Request,
1406 client_cert: Option<crate::tls_listener::ClientCertInfo>,
1407 _auth_method: AuthMethod,
1408) -> Result<Response, StatusCode> {
1409 debug!("gRPC proxy with mTLS called with headers: {:?}", headers);
1410
1411 let (service_id, mut additional_headers) =
1413 extract_and_validate_auth(&headers, &s.db, &s.paseto_manager).await?;
1414
1415 let service = match ServiceModel::find_by_id(service_id, &s.db) {
1417 Ok(Some(service)) => service,
1418 Ok(None) => return Err(StatusCode::NOT_FOUND),
1419 Err(_) => return Err(StatusCode::INTERNAL_SERVER_ERROR),
1420 };
1421
1422 if let Some(tls_profile) = &service.tls_profile {
1424 if tls_profile.require_client_mtls && client_cert.is_none() {
1425 warn!(
1426 "mTLS required but no client certificate provided for service {}",
1427 service_id
1428 );
1429 return Err(StatusCode::UNAUTHORIZED);
1430 }
1431 }
1432
1433 if let Some(cert) = &client_cert {
1435 additional_headers.insert("x-client-cert-subject".to_string(), cert.subject.clone());
1436 additional_headers.insert("x-client-cert-issuer".to_string(), cert.issuer.clone());
1437 additional_headers.insert("x-client-cert-serial".to_string(), cert.serial.clone());
1438 additional_headers.insert("x-auth-method".to_string(), "mtls".to_string());
1439 }
1440
1441 let target_host = service
1442 .upstream_url()
1443 .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
1444
1445 debug!("Target host: {:?}", target_host);
1446
1447 let path = req.uri().path();
1448 let path_query = req
1449 .uri()
1450 .path_and_query()
1451 .map(|v| v.as_str())
1452 .unwrap_or(path);
1453 let target_uri = Uri::builder()
1454 .scheme(target_host.scheme().cloned().unwrap_or(uri::Scheme::HTTP))
1455 .authority(
1456 target_host
1457 .authority()
1458 .cloned()
1459 .unwrap_or(uri::Authority::from_static("localhost")),
1460 )
1461 .path_and_query(path_query)
1462 .build()
1463 .map_err(|_| StatusCode::BAD_REQUEST)?;
1464
1465 debug!("Target URI: {:?}", target_uri);
1466
1467 *req.uri_mut() = target_uri;
1469
1470 sanitize_request_headers(&mut req, true);
1472 apply_additional_headers(&mut req, additional_headers, true);
1473
1474 debug!("Forwarding gRPC request with headers: {:?}", req.headers());
1475
1476 let use_tls = target_host.scheme() == Some(&uri::Scheme::HTTPS);
1478
1479 let response = if use_tls {
1481 let tls_client = s
1483 .tls_client_manager
1484 .get_client_for_service(service_id)
1485 .await
1486 .map_err(|e| {
1487 error!(
1488 "Failed to get TLS client for service {}: {:?}",
1489 service_id, e
1490 );
1491 StatusCode::INTERNAL_SERVER_ERROR
1492 })?;
1493
1494 let (parts, body) = req.into_parts();
1496 let req_with_incoming = Request::from_parts(parts, body);
1497 tls_client
1498 .http2_client
1499 .request(req_with_incoming)
1500 .await
1501 .map_err(|e| {
1502 error!("Failed to forward gRPC request with TLS: {:?}", e);
1503 StatusCode::BAD_GATEWAY
1504 })?
1505 } else {
1506 s.http2_client.request(req).await.map_err(|e| {
1508 error!("Failed to forward gRPC request: {:?}", e);
1509 StatusCode::BAD_GATEWAY
1510 })?
1511 };
1512
1513 debug!("gRPC response received: {:?}", response);
1514
1515 Ok(response.into_response())
1516}
1517
1518#[instrument(skip_all)]
1520async fn reverse_proxy_with_mtls(
1521 headers: axum::http::HeaderMap,
1522 State(s): State<AuthenticatedProxyState>,
1523 mut req: Request,
1524 client_cert: Option<crate::tls_listener::ClientCertInfo>,
1525 _auth_method: AuthMethod,
1526) -> Result<Response, StatusCode> {
1527 let (service_id, mut additional_headers) =
1529 extract_and_validate_auth(&headers, &s.db, &s.paseto_manager).await?;
1530
1531 let service = match ServiceModel::find_by_id(service_id, &s.db) {
1533 Ok(Some(service)) => service,
1534 Ok(None) => return Err(StatusCode::NOT_FOUND),
1535 Err(_) => return Err(StatusCode::INTERNAL_SERVER_ERROR),
1536 };
1537
1538 if let Some(tls_profile) = &service.tls_profile {
1539 if tls_profile.require_client_mtls && client_cert.is_none() {
1540 warn!(
1541 "mTLS required but no client certificate provided for service {}",
1542 service_id
1543 );
1544 return Err(StatusCode::UNAUTHORIZED);
1545 }
1546 }
1547
1548 if let Some(cert) = &client_cert {
1550 additional_headers.insert("x-client-cert-subject".to_string(), cert.subject.clone());
1551 additional_headers.insert("x-client-cert-issuer".to_string(), cert.issuer.clone());
1552 additional_headers.insert("x-client-cert-serial".to_string(), cert.serial.clone());
1553 additional_headers.insert("x-auth-method".to_string(), "mtls".to_string());
1554 }
1555
1556 let target_host = service
1557 .upstream_url()
1558 .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
1559
1560 let path = req.uri().path();
1561 let path_query = req
1562 .uri()
1563 .path_and_query()
1564 .map(|v| v.as_str())
1565 .unwrap_or(path);
1566 let target_uri = Uri::builder()
1567 .scheme(target_host.scheme().cloned().unwrap_or(uri::Scheme::HTTP))
1568 .authority(
1569 target_host
1570 .authority()
1571 .cloned()
1572 .unwrap_or(uri::Authority::from_static("localhost")),
1573 )
1574 .path_and_query(path_query)
1575 .build()
1576 .map_err(|_| StatusCode::BAD_REQUEST)?;
1577
1578 *req.uri_mut() = target_uri;
1580
1581 sanitize_request_headers(&mut req, false);
1583 apply_additional_headers(&mut req, additional_headers, false);
1584
1585 let use_tls = target_host.scheme() == Some(&uri::Scheme::HTTPS);
1587
1588 let response = if use_tls {
1590 let tls_client = s
1592 .tls_client_manager
1593 .get_client_for_service(service_id)
1594 .await
1595 .map_err(|e| {
1596 error!(
1597 "Failed to get TLS client for service {}: {:?}",
1598 service_id, e
1599 );
1600 StatusCode::INTERNAL_SERVER_ERROR
1601 })?;
1602
1603 let (parts, body) = req.into_parts();
1605 let req_with_incoming = Request::from_parts(parts, body);
1606 tls_client
1607 .http_client
1608 .request(req_with_incoming)
1609 .await
1610 .map_err(|e| {
1611 error!("Failed to forward HTTP request with TLS: {:?}", e);
1612 StatusCode::BAD_GATEWAY
1613 })?
1614 } else {
1615 s.http_client
1617 .request(req)
1618 .await
1619 .map_err(|_| StatusCode::BAD_GATEWAY)?
1620 };
1621
1622 Ok(response.into_response())
1623}
1624
#[cfg(test)]
mod tests {
    use std::collections::BTreeMap;
    use std::net::Ipv4Addr;

    use tempfile::tempdir;

    use super::*;
    use crate::{
        test_client::TestClient,
        types::{ChallengeRequest, ChallengeResponse, KeyType, VerifyChallengeResponse, headers},
    };

    /// End-to-end happy path: register a service owned by a fresh ECDSA key, complete the
    /// challenge/verify handshake to obtain an API key, then use that key as a bearer token
    /// to reach an upstream `/hello` route through the proxy router.
    #[tokio::test]
    async fn auth_flow_works() {
        // Route tracing output through the test writer (respecting RUST_LOG via EnvFilter).
        let _guard = tracing::subscriber::set_default(
            tracing_subscriber::fmt()
                .with_env_filter(tracing_subscriber::EnvFilter::from_default_env())
                .with_line_number(true)
                .with_file(true)
                .with_span_events(tracing_subscriber::fmt::format::FmtSpan::CLOSE)
                .with_test_writer()
                .finish(),
        );
        let mut rng = blueprint_std::BlueprintRng::new();
        let tmp = tempdir().unwrap();
        let proxy = AuthenticatedProxy::new(tmp.path()).unwrap();

        // Minimal upstream that the proxy will forward authenticated requests to.
        let hello_world_router =
            Router::new().route("/hello", axum::routing::get(|| async { "Hello, World!" }));

        // Bind to port 0 so the OS picks a free port; keep the handle to stop the server later.
        let (hello_world_server, local_addr) = {
            let listener = tokio::net::TcpListener::bind((Ipv4Addr::LOCALHOST, 0))
                .await
                .expect("Failed to bind to address");
            let server = axum::serve(listener, hello_world_router);
            let local_address = server.local_addr().unwrap();
            let handle = tokio::spawn(async move {
                if let Err(e) = server.await {
                    eprintln!("Hello world server error: {e}");
                }
            });
            (handle, local_address)
        };

        // Register service 0 pointing at the local upstream.
        let service_id = ServiceId::new(0);
        let mut service = crate::models::ServiceModel {
            api_key_prefix: "test_".to_string(),
            owners: Vec::new(),
            upstream_url: format!("http://localhost:{}", local_addr.port()),
            tls_profile: None,
        };

        let signing_key = k256::ecdsa::SigningKey::random(&mut rng);
        let public_key = signing_key.verifying_key().to_sec1_bytes();

        // The signing key's public half becomes an owner, so it may complete the challenge.
        service.add_owner(KeyType::Ecdsa, public_key.clone().into());
        service.save(service_id, &proxy.db).unwrap();

        let router = proxy.router();
        let client = TestClient::new(router);

        // Step 1: request a challenge for the owner key.
        let req = ChallengeRequest {
            pub_key: public_key.clone().into(),
            key_type: KeyType::Ecdsa,
        };

        let res = client
            .post("/v1/auth/challenge")
            .header(headers::X_SERVICE_ID, service_id.to_string())
            .json(&req)
            .await;

        let res: ChallengeResponse = res.json().await;

        // Step 2: sign the challenge and sanity-check the signature locally.
        let (signature, _) = signing_key
            .sign_prehash_recoverable(&res.challenge)
            .unwrap();
        assert!(
            crate::verify_challenge(
                &res.challenge,
                &signature.to_vec(),
                &public_key,
                KeyType::Ecdsa
            )
            .unwrap()
        );

        // Step 3: submit the signed challenge to obtain an API key.
        let req = crate::types::VerifyChallengeRequest {
            challenge: res.challenge,
            signature: signature.to_bytes().into(),
            challenge_request: req,
            expires_at: 0,
            additional_headers: BTreeMap::new(),
        };

        let res = client
            .post("/v1/auth/verify")
            .header(headers::X_SERVICE_ID, ServiceId::new(0).to_string())
            .json(&req)
            .await;
        let res: VerifyChallengeResponse = res.json().await;

        assert!(matches!(res, VerifyChallengeResponse::Verified { .. }));
        let api_key = match res {
            VerifyChallengeResponse::Verified { api_key, .. } => api_key,
            _ => panic!("Expected a verified response"),
        };

        // Step 4: the API key, presented as a bearer token, must be proxied to the upstream.
        let res = client
            .get("/hello")
            .header(headers::AUTHORIZATION, format!("Bearer {api_key}"))
            .await;
        assert!(
            res.status().is_success(),
            "Request to reverse proxy failed: {res:?}",
        );

        hello_world_server.abort();
    }

    /// Headers supplied as `additional_headers` at verification time (`X-Tenant-Id`,
    /// `X-Tenant-Name`) must be attached to requests forwarded with the issued API key.
    /// The upstream echoes every `x-tenant-*` header it receives back as JSON.
    #[tokio::test]
    async fn auth_flow_with_additional_headers() {
        use std::collections::BTreeMap;

        let _guard = tracing::subscriber::set_default(
            tracing_subscriber::fmt()
                .with_env_filter(tracing_subscriber::EnvFilter::from_default_env())
                .with_line_number(true)
                .with_file(true)
                .with_span_events(tracing_subscriber::fmt::format::FmtSpan::CLOSE)
                .with_test_writer()
                .finish(),
        );
        let mut rng = blueprint_std::BlueprintRng::new();
        let tmp = tempdir().unwrap();
        let proxy = AuthenticatedProxy::new(tmp.path()).unwrap();

        // Upstream that reflects received tenant headers so the test can inspect them.
        let echo_router = Router::new().route(
            "/echo",
            axum::routing::get(|headers: axum::http::HeaderMap| async move {
                let mut response_headers = BTreeMap::new();
                for (name, value) in headers.iter() {
                    if name.as_str().starts_with("x-tenant-")
                        || name.as_str().starts_with("X-Tenant-")
                    {
                        response_headers
                            .insert(name.to_string(), value.to_str().unwrap_or("").to_string());
                    }
                }
                axum::Json(response_headers)
            }),
        );

        let (echo_server, local_addr) = {
            let listener = tokio::net::TcpListener::bind((Ipv4Addr::LOCALHOST, 0))
                .await
                .expect("Failed to bind to address");
            let server = axum::serve(listener, echo_router);
            let local_address = server.local_addr().unwrap();
            let handle = tokio::spawn(async move {
                if let Err(e) = server.await {
                    eprintln!("Echo server error: {e}");
                }
            });
            (handle, local_address)
        };

        let service_id = ServiceId::new(1);
        let mut service = crate::models::ServiceModel {
            api_key_prefix: "test_".to_string(),
            owners: Vec::new(),
            upstream_url: format!("http://localhost:{}", local_addr.port()),
            tls_profile: None,
        };

        let signing_key = k256::ecdsa::SigningKey::random(&mut rng);
        let public_key = signing_key.verifying_key().to_sec1_bytes();

        service.add_owner(KeyType::Ecdsa, public_key.clone().into());
        service.save(service_id, &proxy.db).unwrap();

        let router = proxy.router();
        let client = TestClient::new(router);

        let req = ChallengeRequest {
            pub_key: public_key.clone().into(),
            key_type: KeyType::Ecdsa,
        };

        let res = client
            .post("/v1/auth/challenge")
            .header(headers::X_SERVICE_ID, service_id.to_string())
            .json(&req)
            .await;

        let res: ChallengeResponse = res.json().await;

        let (signature, _) = signing_key
            .sign_prehash_recoverable(&res.challenge)
            .unwrap();

        // Mixed-case names on purpose: the assertions below look them up in lower case.
        let mut additional_headers = BTreeMap::new();
        let user_id = "user123@example.com";
        let tenant_id = crate::validation::hash_user_id(user_id);
        additional_headers.insert("X-Tenant-Id".to_string(), tenant_id.clone());
        additional_headers.insert("X-Tenant-Name".to_string(), "Acme Corp".to_string());

        let req = crate::types::VerifyChallengeRequest {
            challenge: res.challenge,
            signature: signature.to_bytes().into(),
            challenge_request: req,
            expires_at: 0,
            additional_headers,
        };

        let res = client
            .post("/v1/auth/verify")
            .header(headers::X_SERVICE_ID, service_id.to_string())
            .json(&req)
            .await;
        let res: VerifyChallengeResponse = res.json().await;

        assert!(matches!(res, VerifyChallengeResponse::Verified { .. }));
        let api_key = match res {
            VerifyChallengeResponse::Verified { api_key, .. } => api_key,
            _ => panic!("Expected a verified response"),
        };

        let res = client
            .get("/echo")
            .header(headers::AUTHORIZATION, format!("Bearer {api_key}"))
            .await;

        assert!(res.status().is_success());

        let response_headers: BTreeMap<String, String> = res.json().await;
        assert_eq!(response_headers.get("x-tenant-id"), Some(&tenant_id));
        assert_eq!(
            response_headers.get("x-tenant-name"),
            Some(&"Acme Corp".to_string())
        );

        echo_server.abort();
    }

    /// Verification must fail with an `UnexpectedError` mentioning "Invalid headers" when
    /// `additional_headers` contains a disallowed header (`Connection`). No request is proxied
    /// in this test, so the `localhost:9999` upstream is never contacted.
    #[tokio::test]
    async fn auth_flow_rejects_invalid_headers() {
        use std::collections::BTreeMap;

        let mut rng = blueprint_std::BlueprintRng::new();
        let tmp = tempdir().unwrap();
        let proxy = AuthenticatedProxy::new(tmp.path()).unwrap();

        let service_id = ServiceId::new(2);
        let mut service = crate::models::ServiceModel {
            api_key_prefix: "test_".to_string(),
            owners: Vec::new(),
            upstream_url: "http://localhost:9999".to_string(),
            tls_profile: None,
        };

        let signing_key = k256::ecdsa::SigningKey::random(&mut rng);
        let public_key = signing_key.verifying_key().to_sec1_bytes();

        service.add_owner(KeyType::Ecdsa, public_key.clone().into());
        service.save(service_id, &proxy.db).unwrap();

        let router = proxy.router();
        let client = TestClient::new(router);

        let req = ChallengeRequest {
            pub_key: public_key.clone().into(),
            key_type: KeyType::Ecdsa,
        };

        let res = client
            .post("/v1/auth/challenge")
            .header(headers::X_SERVICE_ID, service_id.to_string())
            .json(&req)
            .await;

        let res: ChallengeResponse = res.json().await;

        let (signature, _) = signing_key
            .sign_prehash_recoverable(&res.challenge)
            .unwrap();

        // `Connection` is a hop-by-hop header and must be refused at verification time.
        let mut additional_headers = BTreeMap::new();
        additional_headers.insert("Connection".to_string(), "close".to_string());

        let req = crate::types::VerifyChallengeRequest {
            challenge: res.challenge,
            signature: signature.to_bytes().into(),
            challenge_request: req,
            expires_at: 0,
            additional_headers,
        };

        let res = client
            .post("/v1/auth/verify")
            .header(headers::X_SERVICE_ID, service_id.to_string())
            .json(&req)
            .await;

        let res: VerifyChallengeResponse = res.json().await;

        assert!(
            matches!(res, VerifyChallengeResponse::UnexpectedError { message } if message.contains("Invalid headers"))
        );
    }
}