blueprint_auth/
proxy.rs

1use core::iter::once;
2use core::ops::Add;
3use std::net::IpAddr;
4use std::path::Path;
5
6use axum::Json;
7use axum::http::uri;
8use axum::{
9    Router,
10    body::Body,
11    extract::{Request, State},
12    http::StatusCode,
13    http::header,
14    http::uri::Uri,
15    response::{IntoResponse, Response},
16    routing::any,
17    routing::post,
18};
19use blueprint_core::{debug, error, info, warn};
20use hyper_util::{client::legacy::connect::HttpConnector, rt::TokioExecutor};
21use tower_http::cors::CorsLayer;
22use tower_http::request_id::{MakeRequestUuid, PropagateRequestIdLayer, SetRequestIdLayer};
23use tower_http::sensitive_headers::{
24    SetSensitiveRequestHeadersLayer, SetSensitiveResponseHeadersLayer,
25};
26use tower_http::trace::{DefaultMakeSpan, DefaultOnResponse, TraceLayer};
27use tracing::instrument;
28
29use crate::api_keys::{ApiKeyGenerator, ApiKeyModel};
30use crate::certificate_authority::{
31    CertificateAuthority, ClientCertificate, CreateTlsProfileRequest, IssueCertificateRequest,
32    TlsProfileResponse, validate_certificate_request,
33};
34use crate::db::RocksDb;
35use crate::models::{ApiTokenModel, ServiceModel, TlsProfile};
36use crate::paseto_tokens::PasetoTokenManager;
37use crate::request_extensions::{AuthMethod, extract_client_cert_from_request};
38use crate::tls_client::TlsClientManager;
39use crate::tls_envelope::{TlsEnvelope, init_tls_envelope_key};
40use crate::tls_listener::{TlsListenerConfig, TlsListenerManager};
41use pem;
42use prost::Message;
43
/// Default upper bound, in bytes, for gRPC binary metadata headers (16 KiB).
const DEFAULT_GRPC_BINARY_METADATA_MAX_SIZE: usize = 16 * 1024;

/// Cached maximum size, in bytes, for gRPC binary metadata headers.
///
/// Populated lazily by [`get_max_binary_metadata_size`] from the
/// `GRPC_BINARY_METADATA_MAX_SIZE` environment variable. The variable is read at
/// *runtime* (on first use), not at build time, which is why this is a `static`
/// `OnceLock` rather than a `const`.
static GRPC_BINARY_METADATA_MAX_SIZE: std::sync::OnceLock<usize> = std::sync::OnceLock::new();

/// Returns the maximum permitted size, in bytes, of a gRPC binary metadata header.
///
/// The `GRPC_BINARY_METADATA_MAX_SIZE` environment variable is read the first time
/// this function is called and the result is cached for the life of the process;
/// later changes to the variable have no effect. Surrounding whitespace is ignored.
/// An unset or unparsable value falls back to
/// [`DEFAULT_GRPC_BINARY_METADATA_MAX_SIZE`] (16384 bytes).
fn get_max_binary_metadata_size() -> usize {
    *GRPC_BINARY_METADATA_MAX_SIZE.get_or_init(|| {
        std::env::var("GRPC_BINARY_METADATA_MAX_SIZE")
            .ok()
            .and_then(|raw| raw.trim().parse().ok())
            .unwrap_or(DEFAULT_GRPC_BINARY_METADATA_MAX_SIZE)
    })
}
58use crate::types::{ServiceId, VerifyChallengeResponse};
59use crate::validation;
60
61type HTTPClient =
62    hyper_util::client::legacy::Client<hyper_util::client::legacy::connect::HttpConnector, Body>;
63type HTTP2Client =
64    hyper_util::client::legacy::Client<hyper_util::client::legacy::connect::HttpConnector, Body>;
65
/// Default TCP port on which the authenticated proxy server listens.
///
/// 8276 spells "TBPM" (Tangle Blueprint Manager) on a phone keypad (T9):
/// T = 8, B = 2, P = 7, M = 6.
pub const DEFAULT_AUTH_PROXY_PORT: u16 = 8276;
69
70pub struct AuthenticatedProxy {
71    http_client: HTTPClient,
72    http2_client: HTTP2Client,
73    tls_client_manager: TlsClientManager,
74    db: crate::db::RocksDb,
75    paseto_manager: PasetoTokenManager,
76    tls_envelope: TlsEnvelope,
77    tls_runtime: TlsListenerManager,
78}
79
80#[derive(Clone, Debug)]
81pub struct AuthenticatedProxyState {
82    http_client: HTTPClient,
83    http2_client: HTTP2Client,
84    tls_client_manager: TlsClientManager,
85    db: crate::db::RocksDb,
86    paseto_manager: PasetoTokenManager,
87    tls_envelope: TlsEnvelope,
88    mtls_listener_address: Option<std::net::SocketAddr>,
89    #[cfg(feature = "standalone")]
90    mtls_listener_handle: Option<std::sync::Arc<tokio::task::JoinHandle<()>>>,
91    tls_runtime: TlsListenerManager,
92}
93
94impl AuthenticatedProxyState {
95    pub fn db_ref(&self) -> &crate::db::RocksDb {
96        &self.db
97    }
98    pub fn paseto_manager_ref(&self) -> &PasetoTokenManager {
99        &self.paseto_manager
100    }
101    pub fn tls_envelope_ref(&self) -> &TlsEnvelope {
102        &self.tls_envelope
103    }
104    pub fn tls_client_manager_ref(&self) -> &TlsClientManager {
105        &self.tls_client_manager
106    }
107    pub fn tls_runtime_ref(&self) -> TlsListenerManager {
108        self.tls_runtime.clone()
109    }
110
111    #[cfg(feature = "standalone")]
112    pub fn set_mtls_listener_address(&mut self, addr: std::net::SocketAddr) {
113        self.mtls_listener_address = Some(addr);
114    }
115
116    #[cfg(not(feature = "standalone"))]
117    pub fn set_mtls_listener_address(&mut self, _addr: std::net::SocketAddr) {
118        // No-op when standalone feature is not enabled
119    }
120
121    pub fn get_mtls_listener_address(&self) -> Option<std::net::SocketAddr> {
122        self.mtls_listener_address
123    }
124}
125
126impl AuthenticatedProxy {
127    pub fn new<P: AsRef<Path>>(db_path: P) -> Result<Self, crate::Error> {
128        let executer = TokioExecutor::new();
129
130        // Configure HTTP connector for HTTP/1.1 support (fallback for non-TLS)
131        let mut http_connector = HttpConnector::new();
132        http_connector.enforce_http(false); // Allow both HTTP and HTTPS
133        http_connector.set_nodelay(true); // Improve performance
134
135        // Build HTTP/1.1 client for REST requests (fallback for non-TLS)
136        let http_client: HTTPClient = hyper_util::client::legacy::Builder::new(executer.clone())
137            .http2_only(false) // Allow HTTP/1.1 only
138            .build(http_connector.clone());
139
140        // Configure HTTP connector for HTTP/2 support (fallback for non-TLS)
141        let mut http2_connector = HttpConnector::new();
142        http2_connector.enforce_http(false); // Allow both HTTP and HTTPS
143        http2_connector.set_nodelay(true); // Improve performance for gRPC
144
145        // Build HTTP/2 client for gRPC requests (fallback for non-TLS)
146        let http2_client: HTTP2Client = hyper_util::client::legacy::Builder::new(executer)
147            .http2_only(true) // Use HTTP/2 only for gRPC compatibility
148            .http2_adaptive_window(true) // Enable adaptive flow control for better gRPC performance
149            .build(http2_connector);
150
151        let db_config = crate::db::RocksDbConfig::default();
152        let db = crate::db::RocksDb::open(&db_path, &db_config)?;
153
154        // Initialize TLS envelope with persistent key
155        let tls_envelope = Self::init_tls_envelope(&db_path)?;
156
157        let tls_runtime = TlsListenerManager::new(
158            db.clone(),
159            tls_envelope.clone(),
160            TlsListenerConfig::default(),
161        );
162
163        Self::hydrate_tls_runtime(&tls_runtime, &db)?;
164
165        // Initialize TLS client manager
166        let tls_client_manager = TlsClientManager::new(db.clone());
167
168        // Initialize Paseto token manager with persistent key
169        let paseto_manager = Self::init_paseto_manager(&db_path)?;
170
171        Ok(AuthenticatedProxy {
172            http_client,
173            http2_client,
174            tls_client_manager,
175            db,
176            paseto_manager,
177            tls_envelope,
178            tls_runtime,
179        })
180    }
181
182    fn hydrate_tls_runtime(runtime: &TlsListenerManager, db: &RocksDb) -> Result<(), crate::Error> {
183        use crate::db::cf;
184        use rocksdb::IteratorMode;
185
186        let cf_handle = db
187            .cf_handle(cf::SERVICES_USER_KEYS_CF)
188            .ok_or(crate::Error::UnknownColumnFamily(cf::SERVICES_USER_KEYS_CF))?;
189        let iter = db.iterator_cf(&cf_handle, IteratorMode::Start);
190        let mut profiles = Vec::new();
191
192        for item in iter {
193            let (key, value) = item?;
194            if key.len() < 16 {
195                continue;
196            }
197            let mut id_bytes = [0u8; 16];
198            id_bytes.copy_from_slice(&key[..16]);
199            let service_id = ServiceId::from_be_bytes(id_bytes);
200            let service = ServiceModel::decode(value.as_ref())?;
201            if let Some(profile) = service.tls_profile() {
202                if profile.tls_enabled {
203                    profiles.push((service_id, profile.clone()));
204                }
205            }
206        }
207
208        if profiles.is_empty() {
209            return Ok(());
210        }
211
212        let runtime_clone = runtime.clone();
213        let future = async move {
214            for (service_id, profile) in profiles {
215                runtime_clone
216                    .load_service_profile(service_id, &profile)
217                    .await?;
218            }
219            Ok::<(), crate::Error>(())
220        };
221
222        tokio::runtime::Builder::new_current_thread()
223            .enable_all()
224            .build()
225            .map_err(crate::Error::Io)?
226            .block_on(future)
227    }
228
229    /// Initialize Paseto token manager with persistent key
230    fn init_paseto_manager<P: AsRef<Path>>(db_path: P) -> Result<PasetoTokenManager, crate::Error> {
231        use std::fs;
232        use std::io::{Read, Write};
233
234        // Try to load key from environment variable first
235        if let Ok(key_hex) = std::env::var("PASETO_SIGNING_KEY") {
236            if let Ok(key_bytes) = hex::decode(&key_hex) {
237                if key_bytes.len() == 32 {
238                    let mut key_array = [0u8; 32];
239                    key_array.copy_from_slice(&key_bytes);
240                    let key = crate::paseto_tokens::PasetoKey::from_bytes(key_array);
241                    return Ok(PasetoTokenManager::with_key(
242                        key,
243                        std::time::Duration::from_secs(15 * 60),
244                    ));
245                }
246            }
247            warn!("Invalid PASETO_SIGNING_KEY environment variable, generating new key");
248        }
249
250        // Try to load key from file in db directory
251        let key_path = db_path.as_ref().join(".paseto_key");
252        if key_path.exists() {
253            let mut file = fs::File::open(&key_path).map_err(crate::Error::Io)?;
254            let mut key_bytes = vec![];
255            file.read_to_end(&mut key_bytes).map_err(crate::Error::Io)?;
256
257            if key_bytes.len() == 32 {
258                let mut key_array = [0u8; 32];
259                key_array.copy_from_slice(&key_bytes);
260                let key = crate::paseto_tokens::PasetoKey::from_bytes(key_array);
261                info!("Loaded existing Paseto signing key from disk");
262                return Ok(PasetoTokenManager::with_key(
263                    key,
264                    std::time::Duration::from_secs(15 * 60),
265                ));
266            }
267        }
268
269        // Generate new key and save it
270        let manager = PasetoTokenManager::new(std::time::Duration::from_secs(15 * 60));
271        let key = manager.get_key();
272
273        // Save key to file
274        let mut file = fs::File::create(&key_path).map_err(crate::Error::Io)?;
275        let key_bytes = key.as_bytes();
276        file.write_all(&key_bytes).map_err(crate::Error::Io)?;
277        file.sync_all().map_err(crate::Error::Io)?;
278
279        // Set restrictive permissions on the key file (Unix only)
280        #[cfg(unix)]
281        {
282            use std::os::unix::fs::PermissionsExt;
283            let permissions = std::fs::Permissions::from_mode(0o600);
284            fs::set_permissions(&key_path, permissions).map_err(crate::Error::Io)?;
285        }
286
287        info!("Generated and saved new Paseto signing key");
288        Ok(manager)
289    }
290
291    /// Initialize TLS envelope with persistent key
292    fn init_tls_envelope<P: AsRef<Path>>(db_path: P) -> Result<TlsEnvelope, crate::Error> {
293        let envelope_key = init_tls_envelope_key(&db_path)
294            .map_err(|e| crate::Error::Io(std::io::Error::other(e.to_string())))?;
295
296        Ok(TlsEnvelope::with_key(envelope_key))
297    }
298
299    pub fn router(self) -> Router<()> {
300        let runtime = self.tls_runtime.clone();
301        let state = AuthenticatedProxyState {
302            http_client: self.http_client,
303            http2_client: self.http2_client,
304            tls_client_manager: self.tls_client_manager,
305            db: self.db,
306            paseto_manager: self.paseto_manager,
307            tls_envelope: self.tls_envelope,
308            mtls_listener_address: None,
309            #[cfg(feature = "standalone")]
310            mtls_listener_handle: None,
311            tls_runtime: runtime.clone(),
312        };
313        let router = Router::new()
314            .nest("/v1", Self::internal_api_router_v1())
315            .fallback(any(unified_proxy))
316            .layer(SetRequestIdLayer::new(
317                header::HeaderName::from_static("x-request-id"),
318                MakeRequestUuid,
319            ))
320            // propagate the header to the response before the response reaches `TraceLayer`
321            .layer(PropagateRequestIdLayer::new(
322                header::HeaderName::from_static("x-request-id"),
323            ))
324            .layer(SetSensitiveRequestHeadersLayer::new(once(
325                header::AUTHORIZATION,
326            )))
327            .layer(
328                TraceLayer::new_for_http()
329                    .make_span_with(DefaultMakeSpan::new().include_headers(true))
330                    .on_response(DefaultOnResponse::new().include_headers(true)),
331            )
332            .layer(CorsLayer::permissive())
333            .layer(SetSensitiveResponseHeadersLayer::new(once(
334                header::AUTHORIZATION,
335            )))
336            .with_state(state);
337
338        runtime.install_router(router.clone());
339        router
340    }
341
342    pub fn db(&self) -> RocksDb {
343        self.db.clone()
344    }
345
346    /// Internal API router for version 1
347    pub fn internal_api_router_v1() -> Router<AuthenticatedProxyState> {
348        Router::new()
349            .route("/auth/challenge", post(auth_challenge))
350            .route("/auth/verify", post(auth_verify))
351            .route("/auth/exchange", post(auth_exchange))
352            // OAuth 2.0 JWT Bearer Assertion token endpoint (RFC 7523)
353            .route("/oauth/token", post(crate::oauth::token::oauth_token))
354            // mTLS administration endpoints
355            .route(
356                "/admin/services/{service_id}/tls-profile",
357                axum::routing::put(update_tls_profile),
358            )
359            .route("/auth/certificates", axum::routing::post(issue_certificate))
360    }
361}
362
363/// Auth challenge endpoint that handles authentication challenges
364async fn auth_challenge(
365    service_id: ServiceId,
366    State(s): State<AuthenticatedProxyState>,
367    Json(payload): Json<crate::types::ChallengeRequest>,
368) -> Result<Json<crate::types::ChallengeResponse>, StatusCode> {
369    let mut rng = blueprint_std::BlueprintRng::new();
370    let service = ServiceModel::find_by_id(service_id, &s.db)
371        .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?
372        .ok_or(StatusCode::NOT_FOUND)?;
373
374    let public_key = payload.pub_key;
375    if !service.is_owner(payload.key_type, &public_key) {
376        return Err(StatusCode::UNAUTHORIZED);
377    }
378    let challenge = crate::generate_challenge(&mut rng);
379    let now = std::time::SystemTime::now();
380    let expires_at = now
381        .duration_since(std::time::UNIX_EPOCH)
382        .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?
383        .add(std::time::Duration::from_secs(30))
384        .as_secs();
385    Ok(Json(crate::types::ChallengeResponse {
386        challenge,
387        expires_at,
388    }))
389}
390
391/// Auth verify endpoint that handles authentication verification
392async fn auth_verify(
393    service_id: ServiceId,
394    State(s): State<AuthenticatedProxyState>,
395    Json(payload): Json<crate::types::VerifyChallengeRequest>,
396) -> impl IntoResponse {
397    let mut rng = blueprint_std::BlueprintRng::new();
398    let service = match ServiceModel::find_by_id(service_id, &s.db) {
399        Ok(Some(service)) => service,
400        Ok(None) => {
401            return (
402                StatusCode::NOT_FOUND,
403                Json(VerifyChallengeResponse::UnexpectedError {
404                    message: "Service not found".to_string(),
405                }),
406            );
407        }
408        Err(e) => {
409            return (
410                StatusCode::INTERNAL_SERVER_ERROR,
411                Json(crate::types::VerifyChallengeResponse::UnexpectedError {
412                    message: format!("Internal server error: {e}"),
413                }),
414            );
415        }
416    };
417
418    let public_key = payload.challenge_request.pub_key;
419    if !service.is_owner(payload.challenge_request.key_type, &public_key) {
420        return (
421            StatusCode::UNAUTHORIZED,
422            Json(crate::types::VerifyChallengeResponse::Unauthorized),
423        );
424    }
425    // Verify the challenge
426    let result = crate::verify_challenge(
427        &payload.challenge,
428        &payload.signature,
429        &public_key,
430        payload.challenge_request.key_type,
431    );
432    match result {
433        Ok(true) => {
434            // Validate additional headers before storing
435            let validated_headers = match validation::validate_headers(&payload.additional_headers)
436            {
437                Ok(headers) => headers,
438                Err(e) => {
439                    return (
440                        StatusCode::BAD_REQUEST,
441                        Json(VerifyChallengeResponse::UnexpectedError {
442                            message: format!("Invalid headers: {e}"),
443                        }),
444                    );
445                }
446            };
447
448            // Apply PII protection by hashing sensitive fields
449            let protected_headers =
450                validation::process_headers_with_pii_protection(&validated_headers);
451
452            // Generate long-lived API key (90 days)
453            let api_key_gen = ApiKeyGenerator::with_prefix(service.api_key_prefix());
454            let expires_at = payload.expires_at.max(
455                std::time::SystemTime::now()
456                    .duration_since(std::time::UNIX_EPOCH)
457                    .unwrap_or_default()
458                    .as_secs()
459                    + (90 * 24 * 60 * 60), // 90 days
460            );
461
462            let api_key =
463                api_key_gen.generate_key(service_id, expires_at, protected_headers, &mut rng);
464
465            let mut api_key_model = ApiKeyModel::from(&api_key);
466            if let Err(e) = api_key_model.save(&s.db) {
467                return (
468                    StatusCode::INTERNAL_SERVER_ERROR,
469                    Json(VerifyChallengeResponse::UnexpectedError {
470                        message: format!("Internal server error: {e}"),
471                    }),
472                );
473            }
474
475            (
476                StatusCode::CREATED,
477                Json(VerifyChallengeResponse::Verified {
478                    api_key: api_key.full_key().to_string(),
479                    expires_at,
480                }),
481            )
482        }
483        Ok(false) => (
484            StatusCode::UNAUTHORIZED,
485            Json(crate::types::VerifyChallengeResponse::InvalidSignature),
486        ),
487        Err(e) => (
488            StatusCode::INTERNAL_SERVER_ERROR,
489            Json(crate::types::VerifyChallengeResponse::UnexpectedError {
490                message: format!("Internal server error: {e}"),
491            }),
492        ),
493    }
494}
495
496/// Token exchange endpoint that converts API keys to Paseto access tokens
497async fn auth_exchange(
498    State(s): State<AuthenticatedProxyState>,
499    headers: axum::http::HeaderMap,
500    Json(payload): Json<crate::auth_token::TokenExchangeRequest>,
501) -> impl IntoResponse {
502    // Extract API key from Authorization header
503    let auth_header = match headers.get(crate::types::headers::AUTHORIZATION) {
504        Some(header_value) => {
505            let header_str = match header_value.to_str() {
506                Ok(s) => s,
507                Err(_) => {
508                    return (
509                        StatusCode::BAD_REQUEST,
510                        Json(serde_json::json!({
511                            "error": "invalid_authorization_header",
512                            "message": "Authorization header is not valid UTF-8"
513                        })),
514                    );
515                }
516            };
517
518            // Extract Bearer token
519            if let Some(token) = header_str.strip_prefix("Bearer ") {
520                token
521            } else {
522                return (
523                    StatusCode::BAD_REQUEST,
524                    Json(serde_json::json!({
525                        "error": "invalid_authorization_header",
526                        "message": "Authorization header must use Bearer scheme with API key"
527                    })),
528                );
529            }
530        }
531        None => {
532            return (
533                StatusCode::UNAUTHORIZED,
534                Json(serde_json::json!({
535                    "error": "missing_authorization_header",
536                    "message": "Authorization header with Bearer API key is required"
537                })),
538            );
539        }
540    };
541
542    // Parse API key format: "ak_xxxxx.yyyyy"
543    let key_id = if let Some((key_id_part, _)) = auth_header.split_once('.') {
544        key_id_part
545    } else {
546        return (
547            StatusCode::BAD_REQUEST,
548            Json(serde_json::json!({
549                "error": "invalid_api_key_format",
550                "message": "API key must have format ak_xxxxx.yyyyy"
551            })),
552        );
553    };
554
555    // Find API key in database
556    let mut api_key_model = match crate::api_keys::ApiKeyModel::find_by_key_id(key_id, &s.db) {
557        Ok(Some(model)) => model,
558        Ok(None) => {
559            return (
560                StatusCode::UNAUTHORIZED,
561                Json(serde_json::json!({
562                    "error": "invalid_api_key",
563                    "message": "API key not found"
564                })),
565            );
566        }
567        Err(e) => {
568            error!("Database error looking up API key: {}", e);
569            return (
570                StatusCode::INTERNAL_SERVER_ERROR,
571                Json(serde_json::json!({
572                    "error": "internal_error",
573                    "message": "Failed to validate API key"
574                })),
575            );
576        }
577    };
578
579    // Validate the full key matches
580    if !api_key_model.validates_key(auth_header) {
581        return (
582            StatusCode::UNAUTHORIZED,
583            Json(serde_json::json!({
584                "error": "invalid_api_key",
585                "message": "API key validation failed"
586            })),
587        );
588    }
589
590    // Check if key is expired or disabled
591    if api_key_model.is_expired() {
592        return (
593            StatusCode::UNAUTHORIZED,
594            Json(serde_json::json!({
595                "error": "expired_api_key",
596                "message": "API key has expired"
597            })),
598        );
599    }
600
601    if !api_key_model.is_enabled {
602        return (
603            StatusCode::UNAUTHORIZED,
604            Json(serde_json::json!({
605                "error": "disabled_api_key",
606                "message": "API key is disabled"
607            })),
608        );
609    }
610
611    // Update last used timestamp
612    if let Err(e) = api_key_model.update_last_used(&s.db) {
613        warn!("Failed to update API key last_used timestamp: {}", e);
614    }
615
616    // Merge default headers with request headers
617    let mut headers = api_key_model.get_default_headers();
618    for (key, value) in payload.additional_headers {
619        headers.insert(key, value);
620    }
621
622    // Validate merged headers
623    let validated_headers = match crate::validation::validate_headers(&headers) {
624        Ok(headers) => headers,
625        Err(e) => {
626            return (
627                StatusCode::BAD_REQUEST,
628                Json(serde_json::json!({
629                    "error": "invalid_headers",
630                    "message": format!("Header validation failed: {}", e)
631                })),
632            );
633        }
634    };
635
636    // Apply PII protection
637    let protected_headers =
638        crate::validation::process_headers_with_pii_protection(&validated_headers);
639
640    // Generate Paseto access token
641    let service_id = api_key_model.service_id();
642    let tenant_id = protected_headers.get("X-Tenant-Id").cloned();
643    let custom_ttl = payload.ttl_seconds.map(std::time::Duration::from_secs);
644
645    let access_token = match s.paseto_manager.generate_token(
646        service_id,
647        api_key_model.key_id.clone(),
648        tenant_id,
649        protected_headers,
650        custom_ttl,
651        None,
652    ) {
653        Ok(token) => token,
654        Err(e) => {
655            error!("Failed to create Paseto token: {}", e);
656            return (
657                StatusCode::INTERNAL_SERVER_ERROR,
658                Json(serde_json::json!({
659                    "error": "token_generation_failed",
660                    "message": "Failed to generate access token"
661                })),
662            );
663        }
664    };
665
666    // Calculate expiration time
667    let expires_at = std::time::SystemTime::now()
668        .duration_since(std::time::UNIX_EPOCH)
669        .unwrap_or_default()
670        .as_secs()
671        + custom_ttl
672            .unwrap_or(s.paseto_manager.default_ttl())
673            .as_secs();
674
675    let response = crate::auth_token::TokenExchangeResponse::new(access_token, expires_at);
676
677    (
678        StatusCode::OK,
679        Json(serde_json::to_value(response).unwrap()),
680    )
681}
682
683async fn update_tls_profile(
684    service_id: ServiceId,
685    State(s): State<AuthenticatedProxyState>,
686    Json(payload): Json<CreateTlsProfileRequest>,
687) -> Result<Json<TlsProfileResponse>, StatusCode> {
688    let envelope = s.tls_envelope_ref().clone();
689    let runtime = s.tls_runtime_ref();
690
691    let mut service = match ServiceModel::find_by_id(service_id, &s.db) {
692        Ok(Some(service)) => service,
693        Ok(None) => return Err(StatusCode::NOT_FOUND),
694        Err(_) => return Err(StatusCode::INTERNAL_SERVER_ERROR),
695    };
696
697    let existing_profile = service.tls_profile.clone();
698
699    let allowlist_to_store = payload
700        .allowed_dns_names
701        .clone()
702        .or_else(|| {
703            existing_profile
704                .as_ref()
705                .map(|p| p.allowed_dns_names.clone())
706        })
707        .unwrap_or_default();
708
709    let server_dns_names = if let Some(list) = payload.allowed_dns_names.clone() {
710        if list.is_empty() { None } else { Some(list) }
711    } else if let Some(profile) = existing_profile.as_ref() {
712        if !profile.allowed_dns_names.is_empty() {
713            Some(profile.allowed_dns_names.clone())
714        } else if let Some(existing_sni) = profile.sni.clone() {
715            Some(vec![existing_sni])
716        } else {
717            None
718        }
719    } else {
720        None
721    }
722    .unwrap_or_else(|| vec!["localhost".to_string()]);
723
724    let mut tls_profile = existing_profile.unwrap_or_else(|| TlsProfile {
725        tls_enabled: false,
726        require_client_mtls: false,
727        encrypted_server_cert: Vec::new(),
728        encrypted_server_key: Vec::new(),
729        encrypted_client_ca_bundle: Vec::new(),
730        encrypted_upstream_ca_bundle: Vec::new(),
731        encrypted_upstream_client_cert: Vec::new(),
732        encrypted_upstream_client_key: Vec::new(),
733        client_cert_ttl_hours: payload.client_cert_ttl_hours,
734        sni: None,
735        subject_alt_name_template: payload.subject_alt_name_template.clone(),
736        allowed_dns_names: allowlist_to_store.clone(),
737    });
738
739    let ca = if !tls_profile.encrypted_client_ca_bundle.is_empty() {
740        load_persisted_ca(&envelope, &tls_profile)?
741    } else {
742        CertificateAuthority::new(envelope.clone()).map_err(|e| {
743            error!("Failed to create certificate authority: {e}");
744            StatusCode::INTERNAL_SERVER_ERROR
745        })?
746    };
747
748    let mut bundle = ca.ca_certificate_pem();
749    if !bundle.ends_with('\n') {
750        bundle.push('\n');
751    }
752    bundle.push_str(&ca.ca_private_key_pem());
753
754    tls_profile.encrypted_client_ca_bundle = envelope.encrypt(bundle.as_bytes()).map_err(|e| {
755        error!("Failed to encrypt CA bundle: {e}");
756        StatusCode::INTERNAL_SERVER_ERROR
757    })?;
758
759    let (server_cert, server_key) = ca
760        .generate_server_certificate(service_id, server_dns_names.clone())
761        .map_err(|e| {
762            error!("Failed to generate server certificate: {e}");
763            StatusCode::INTERNAL_SERVER_ERROR
764        })?;
765
766    tls_profile.encrypted_server_cert = envelope.encrypt(server_cert.as_bytes()).map_err(|e| {
767        error!("Failed to encrypt server certificate: {e}");
768        StatusCode::INTERNAL_SERVER_ERROR
769    })?;
770    tls_profile.encrypted_server_key = envelope.encrypt(server_key.as_bytes()).map_err(|e| {
771        error!("Failed to encrypt server key: {e}");
772        StatusCode::INTERNAL_SERVER_ERROR
773    })?;
774
775    tls_profile.tls_enabled = true;
776    tls_profile.require_client_mtls = payload.require_client_mtls;
777    tls_profile.client_cert_ttl_hours = payload.client_cert_ttl_hours;
778    if let Some(template) = payload.subject_alt_name_template.clone() {
779        tls_profile.subject_alt_name_template = Some(template);
780    }
781    tls_profile.allowed_dns_names = allowlist_to_store.clone();
782    tls_profile.sni = server_dns_names.first().cloned();
783
784    service.tls_profile = Some(tls_profile.clone());
785    if let Err(e) = service.save(service_id, &s.db) {
786        error!("Failed to persist TLS profile: {e}");
787        return Err(StatusCode::INTERNAL_SERVER_ERROR);
788    }
789
790    let bound_addr = runtime
791        .upsert_service_profile(service_id, &tls_profile)
792        .await
793        .map_err(|e| {
794            error!("Failed to activate TLS profile: {e}");
795            StatusCode::INTERNAL_SERVER_ERROR
796        })?;
797
798    info!(
799        "TLS profile enabled for service {} with listener {}",
800        service_id, bound_addr
801    );
802
803    let listener_uri = match bound_addr.ip() {
804        IpAddr::V4(ip) if ip.is_unspecified() => {
805            format!("https://localhost:{}", bound_addr.port())
806        }
807        IpAddr::V6(ip) if ip.is_unspecified() => {
808            format!("https://[::1]:{}", bound_addr.port())
809        }
810        _ => format!("https://{bound_addr}"),
811    };
812
813    Ok(Json(TlsProfileResponse {
814        tls_enabled: true,
815        require_client_mtls: payload.require_client_mtls,
816        client_cert_ttl_hours: payload.client_cert_ttl_hours,
817        mtls_listener: listener_uri,
818        http_listener: Some(format!("http://localhost:{DEFAULT_AUTH_PROXY_PORT}")),
819        ca_certificate_pem: Some(ca.ca_certificate_pem()),
820        subject_alt_name_template: tls_profile.subject_alt_name_template.clone(),
821        allowed_dns_names: tls_profile.allowed_dns_names.clone(),
822    }))
823}
824
825/// Issue a client certificate for mTLS authentication
826async fn issue_certificate(
827    State(s): State<AuthenticatedProxyState>,
828    Json(payload): Json<IssueCertificateRequest>,
829) -> Result<Json<ClientCertificate>, StatusCode> {
830    let service_id = ServiceId::new(payload.service_id);
831
832    // Find the service
833    let service = match ServiceModel::find_by_id(service_id, &s.db) {
834        Ok(Some(service)) => service,
835        Ok(None) => return Err(StatusCode::NOT_FOUND),
836        Err(_) => return Err(StatusCode::INTERNAL_SERVER_ERROR),
837    };
838
839    // Check if TLS profile is configured
840    let tls_profile = match &service.tls_profile {
841        Some(profile) => profile,
842        None => return Err(StatusCode::BAD_REQUEST),
843    };
844
845    // Validate the certificate request against the profile
846    if let Err(e) = validate_certificate_request(&payload, tls_profile) {
847        warn!("Certificate request validation failed: {}", e);
848        return Err(StatusCode::BAD_REQUEST);
849    }
850
851    // Initialize certificate authority using persisted CA material if available
852    let mut service_ref = service.clone();
853    let ca = if !tls_profile.encrypted_client_ca_bundle.is_empty() {
854        load_persisted_ca(&s.tls_envelope, tls_profile)?
855    } else {
856        let ca = CertificateAuthority::new(s.tls_envelope.clone()).map_err(|e| {
857            error!("Failed to initialise certificate authority: {e}");
858            StatusCode::INTERNAL_SERVER_ERROR
859        })?;
860
861        persist_new_ca(&s.db, service_id, &mut service_ref, &ca)?;
862        if let Some(profile) = service_ref.tls_profile.as_ref() {
863            if let Err(err) = s
864                .tls_runtime_ref()
865                .upsert_service_profile(service_id, profile)
866                .await
867            {
868                error!("Failed to refresh TLS runtime after CA creation: {err}");
869                return Err(StatusCode::INTERNAL_SERVER_ERROR);
870            }
871        }
872        ca
873    };
874
875    // Generate client certificate
876    let mut client_cert = match ca.generate_client_certificate(
877        payload.common_name,
878        payload.subject_alt_names,
879        payload.ttl_hours,
880    ) {
881        Ok(cert) => cert,
882        Err(e) => {
883            error!("Failed to generate client certificate: {}", e);
884            return Err(StatusCode::INTERNAL_SERVER_ERROR);
885        }
886    };
887
888    if client_cert.revocation_url.is_none() {
889        client_cert.revocation_url = Some(format!(
890            "/v1/auth/certificates/{}/revoke",
891            client_cert.serial
892        ));
893    }
894
895    info!(
896        "Issued client certificate for service {} with serial {}",
897        service_id, client_cert.serial
898    );
899    Ok(Json(client_cert))
900}
901
902fn load_persisted_ca(
903    envelope: &TlsEnvelope,
904    profile: &TlsProfile,
905) -> Result<CertificateAuthority, StatusCode> {
906    let decrypted = envelope
907        .decrypt(&profile.encrypted_client_ca_bundle)
908        .map_err(|e| {
909            error!("Failed to decrypt stored CA bundle: {e}");
910            StatusCode::INTERNAL_SERVER_ERROR
911        })?;
912
913    let pem_str = String::from_utf8(decrypted).map_err(|e| {
914        error!("Stored CA bundle is not valid UTF-8: {e}");
915        StatusCode::INTERNAL_SERVER_ERROR
916    })?;
917
918    let blocks = pem::parse_many(&pem_str).map_err(|e| {
919        error!("Failed to parse stored CA bundle: {e}");
920        StatusCode::INTERNAL_SERVER_ERROR
921    })?;
922
923    let mut ca_cert_pem: Option<String> = None;
924    let mut ca_key_pem: Option<String> = None;
925
926    for block in blocks {
927        let encoded = pem::encode(&block);
928        match block.tag.as_str() {
929            "CERTIFICATE" if ca_cert_pem.is_none() => ca_cert_pem = Some(encoded),
930            tag if tag.ends_with("PRIVATE KEY") && ca_key_pem.is_none() => {
931                ca_key_pem = Some(encoded)
932            }
933            _ => {}
934        }
935    }
936
937    let cert = ca_cert_pem.ok_or_else(|| {
938        error!("Stored CA bundle missing certificate block");
939        StatusCode::INTERNAL_SERVER_ERROR
940    })?;
941    let key = ca_key_pem.ok_or_else(|| {
942        error!("Stored CA bundle missing private key block");
943        StatusCode::INTERNAL_SERVER_ERROR
944    })?;
945
946    CertificateAuthority::from_components(cert, key, clone_envelope(envelope)).map_err(|e| {
947        error!("Failed to rehydrate certificate authority: {e}");
948        StatusCode::INTERNAL_SERVER_ERROR
949    })
950}
951
952fn persist_new_ca(
953    db: &RocksDb,
954    service_id: ServiceId,
955    service: &mut ServiceModel,
956    ca: &CertificateAuthority,
957) -> Result<(), StatusCode> {
958    let mut bundle = ca.ca_certificate_pem();
959    if !bundle.ends_with('\n') {
960        bundle.push('\n');
961    }
962    bundle.push_str(&ca.ca_private_key_pem());
963
964    let encrypted = ca.envelope().encrypt(bundle.as_bytes()).map_err(|e| {
965        error!("Failed to encrypt CA bundle: {e}");
966        StatusCode::INTERNAL_SERVER_ERROR
967    })?;
968
969    if let Some(profile) = service.tls_profile.as_mut() {
970        profile.encrypted_client_ca_bundle = encrypted;
971    } else {
972        error!("TLS profile unexpectedly missing while persisting CA");
973        return Err(StatusCode::INTERNAL_SERVER_ERROR);
974    }
975
976    service.save(service_id, db).map_err(|e| {
977        error!("Failed to persist CA bundle to service: {e}");
978        StatusCode::INTERNAL_SERVER_ERROR
979    })
980}
981
982fn clone_envelope(envelope: &TlsEnvelope) -> TlsEnvelope {
983    TlsEnvelope::with_key(envelope.key().clone())
984}
985
986/// Handle legacy API token validation
987async fn handle_legacy_token(
988    token: crate::api_tokens::ApiToken,
989    db: &crate::db::RocksDb,
990) -> Result<
991    (
992        crate::types::ServiceId,
993        std::collections::BTreeMap<String, String>,
994    ),
995    StatusCode,
996> {
997    let (token_id, token_str) = (token.0, token.1.as_str());
998
999    let api_token = match ApiTokenModel::find_token_id(token_id, db) {
1000        Ok(Some(token)) if token.is(token_str) && !token.is_expired() && token.is_enabled => token,
1001        Ok(Some(_)) | Ok(None) => {
1002            warn!("Invalid or expired legacy token");
1003            return Err(StatusCode::UNAUTHORIZED);
1004        }
1005        Err(_) => return Err(StatusCode::INTERNAL_SERVER_ERROR),
1006    };
1007
1008    let additional_headers = api_token.get_additional_headers();
1009    Ok((api_token.service_id(), additional_headers))
1010}
1011
1012/// Handle API key validation
1013async fn handle_api_key(
1014    api_key: &str,
1015    db: &crate::db::RocksDb,
1016) -> Result<
1017    (
1018        crate::types::ServiceId,
1019        std::collections::BTreeMap<String, String>,
1020    ),
1021    StatusCode,
1022> {
1023    // Parse key_id from "ak_xxxxx.yyyyy"
1024    let key_id = api_key
1025        .split_once('.')
1026        .map(|(key_id_part, _)| key_id_part)
1027        .ok_or(StatusCode::BAD_REQUEST)?;
1028
1029    // Find and validate API key
1030    let mut api_key_model = match crate::api_keys::ApiKeyModel::find_by_key_id(key_id, db) {
1031        Ok(Some(model)) => model,
1032        Ok(None) => {
1033            warn!("API key not found: {}", key_id);
1034            return Err(StatusCode::UNAUTHORIZED);
1035        }
1036        Err(e) => {
1037            error!("Database error looking up API key: {}", e);
1038            return Err(StatusCode::INTERNAL_SERVER_ERROR);
1039        }
1040    };
1041
1042    // Validate the full key
1043    if !api_key_model.validates_key(api_key) {
1044        warn!("API key validation failed: {}", key_id);
1045        return Err(StatusCode::UNAUTHORIZED);
1046    }
1047
1048    // Check expiration and enabled status
1049    if api_key_model.is_expired() {
1050        warn!("API key expired: {}", key_id);
1051        return Err(StatusCode::UNAUTHORIZED);
1052    }
1053
1054    if !api_key_model.is_enabled {
1055        warn!("API key disabled: {}", key_id);
1056        return Err(StatusCode::UNAUTHORIZED);
1057    }
1058
1059    // Update last used timestamp
1060    if let Err(e) = api_key_model.update_last_used(db) {
1061        warn!("Failed to update API key last_used timestamp: {}", e);
1062    }
1063
1064    let additional_headers = api_key_model.get_default_headers();
1065    Ok((api_key_model.service_id(), additional_headers))
1066}
1067
1068/// Handle Paseto access token validation
1069async fn handle_paseto_token(
1070    token: &str,
1071    paseto_manager: &crate::paseto_tokens::PasetoTokenManager,
1072) -> Result<
1073    (
1074        crate::types::ServiceId,
1075        std::collections::BTreeMap<String, String>,
1076    ),
1077    StatusCode,
1078> {
1079    let claims = match paseto_manager.validate_token(token) {
1080        Ok(claims) => claims,
1081        Err(e) => {
1082            warn!("Paseto token validation failed: {}", e);
1083            return Err(StatusCode::UNAUTHORIZED);
1084        }
1085    };
1086
1087    // Build upstream headers starting from token's additional headers
1088    let mut headers = claims.additional_headers.clone();
1089
1090    // Inject canonical X-Scopes from claims.scopes if present
1091    if let Some(scopes_vec) = claims.scopes.clone() {
1092        if !scopes_vec.is_empty() {
1093            let mut set = std::collections::BTreeSet::new();
1094            for s in scopes_vec {
1095                set.insert(s.to_lowercase());
1096            }
1097            let scopes_str = set.into_iter().collect::<Vec<_>>().join(" ");
1098            headers.insert("x-scopes".to_string(), scopes_str);
1099        }
1100    }
1101
1102    // Token expiration is already checked in validate_token
1103    Ok((claims.service_id, headers))
1104}
1105
/// Returns `true` for header names that must never be forwarded from clients or injected by
/// the proxy.
///
/// These are framing, connection-management and routing headers whose values must come from
/// the proxy and transport, not from callers or token claims. The comparison lowercases the
/// name first, so it is case-insensitive even if the caller has not normalised it.
fn is_forbidden_header(header_name: &str) -> bool {
    const FORBIDDEN: &[&str] = &[
        "host",
        "content-length",
        "transfer-encoding",
        "connection",
        "upgrade",
        "proxy-authorization",
        "proxy-authenticate",
        "x-forwarded-host",
        "x-real-ip",
        "x-forwarded-for",
        "x-forwarded-proto",
        "forwarded",
    ];

    let normalized = header_name.to_lowercase();
    FORBIDDEN.contains(&normalized.as_str())
}
1127
/// Returns `true` for auth-specific headers that clients must not supply directly.
///
/// Matches, case-insensitively:
/// - `authorization`;
/// - `x-scope` and `x-scopes`;
/// - any header starting with `x-tenant-`.
///
/// The proxy derives these from the validated credential instead.
fn is_auth_header(header_name: &str) -> bool {
    let lower = header_name.to_lowercase();
    match lower.as_str() {
        "authorization" | "x-scope" | "x-scopes" => true,
        other => other.starts_with("x-tenant-"),
    }
}
1136
/// Returns `true` for headers that gRPC requests rely on and that must survive sanitisation:
/// `content-type`, `te`, `grpc-encoding` and `grpc-accept-encoding`. The match is
/// case-insensitive.
fn is_grpc_required_header(header_name: &str) -> bool {
    const GRPC_REQUIRED: [&str; 4] = [
        "content-type",
        "te",
        "grpc-encoding",
        "grpc-accept-encoding",
    ];
    let lower = header_name.to_lowercase();
    GRPC_REQUIRED.iter().any(|required| *required == lower)
}
1145
/// Returns `true` if a proxy-derived header may be forwarded as upstream (gRPC) metadata.
///
/// The check is case-insensitive. Allowed:
/// - tenant headers (`x-tenant-*`);
/// - scope headers (`x-scope`, `x-scopes`);
/// - tracing and identification headers (`x-request-id`, `x-trace-id`, `user-agent`);
/// - any `grpc-*` header.
fn is_proxy_injected_header_allowed(header_name: &str) -> bool {
    const EXACT: &[&str] = &[
        "x-scope",
        "x-scopes",
        "x-request-id",
        "x-trace-id",
        "user-agent",
    ];
    const PREFIXES: &[&str] = &["x-tenant-", "grpc-"];

    let lower = header_name.to_lowercase();
    EXACT.contains(&lower.as_str()) || PREFIXES.iter().any(|prefix| lower.starts_with(prefix))
}
1161
1162/// Validate binary metadata header according to gRPC specification
1163/// Binary metadata keys must end with -bin, be base64 encoded, and have size limits
1164fn validate_binary_metadata(header_name: &str, header_value: &str) -> bool {
1165    let lower = header_name.to_lowercase();
1166
1167    // Check if header ends with -bin suffix
1168    if !lower.ends_with("-bin") {
1169        return false;
1170    }
1171
1172    // Check size limit (configurable max size for binary metadata to prevent large payloads)
1173    if header_value.len() > get_max_binary_metadata_size() {
1174        warn!(
1175            "Binary metadata header {} exceeds size limit: {} bytes (max: {})",
1176            header_name,
1177            header_value.len(),
1178            get_max_binary_metadata_size()
1179        );
1180        return false;
1181    }
1182
1183    // Validate base64 encoding
1184    if header_value.contains('\r') || header_value.contains('\n') {
1185        warn!(
1186            "Binary metadata header {} contains CRLF characters",
1187            header_name
1188        );
1189        return false;
1190    }
1191
1192    // Check if it's valid base64 (allowing padding and standard/base64url variants)
1193    if !header_value.chars().all(|c| {
1194        c.is_ascii_alphanumeric() || c == '+' || c == '/' || c == '=' || c == '-' || c == '_'
1195    }) {
1196        warn!(
1197            "Binary metadata header {} contains invalid base64 characters",
1198            header_name
1199        );
1200        return false;
1201    }
1202
1203    true
1204}
1205
1206/// Extract and validate authentication token from headers
1207/// Returns (service_id, additional_headers) on success
1208async fn extract_and_validate_auth(
1209    headers: &axum::http::HeaderMap,
1210    db: &crate::db::RocksDb,
1211    paseto_manager: &crate::paseto_tokens::PasetoTokenManager,
1212) -> Result<
1213    (
1214        crate::types::ServiceId,
1215        std::collections::BTreeMap<String, String>,
1216    ),
1217    StatusCode,
1218> {
1219    let auth_header = headers
1220        .get(crate::types::headers::AUTHORIZATION)
1221        .and_then(|h| h.to_str().ok())
1222        .and_then(|h| h.strip_prefix("Bearer "))
1223        .ok_or_else(|| {
1224            warn!("Missing or invalid Authorization header");
1225            StatusCode::UNAUTHORIZED
1226        })?;
1227
1228    debug!("Processing auth header: {}", auth_header);
1229
1230    // Use unified token parsing
1231    match crate::auth_token::AuthToken::parse(auth_header) {
1232        Ok(crate::auth_token::AuthToken::Legacy(legacy_token)) => {
1233            handle_legacy_token(legacy_token, db).await
1234        }
1235        Ok(crate::auth_token::AuthToken::ApiKey(api_key)) => handle_api_key(&api_key, db).await,
1236        Ok(crate::auth_token::AuthToken::AccessToken(_)) => {
1237            // This shouldn't happen as parse() returns an error for Paseto tokens
1238            Err(StatusCode::INTERNAL_SERVER_ERROR)
1239        }
1240        Err(_) if auth_header.starts_with("v4.local.") => {
1241            // Special case for Paseto tokens which need the manager to validate
1242            handle_paseto_token(auth_header, paseto_manager).await
1243        }
1244        Err(e) => {
1245            warn!("Token parsing error for '{}': {:?}", auth_header, e);
1246            Err(StatusCode::UNAUTHORIZED)
1247        }
1248    }
1249}
1250
1251/// Apply additional headers to request with security validation
1252fn apply_additional_headers(
1253    req: &mut Request,
1254    additional_headers: std::collections::BTreeMap<String, String>,
1255    is_grpc: bool,
1256) {
1257    for (header_name, header_value) in additional_headers {
1258        // Re-validate header names against security-sensitive headers
1259        if is_forbidden_header(&header_name) {
1260            warn!("Attempted to inject forbidden header: {}", header_name);
1261            continue;
1262        }
1263
1264        // For gRPC, allow gRPC-required headers even if they might be auth-related
1265        // For gRPC, also enforce dedicated allowlist for proxy-injected headers to prevent sensitive internal headers from leaking
1266        let is_tenant_header = header_name.to_lowercase().starts_with("x-tenant-");
1267        let is_scope_header =
1268            header_name.to_lowercase() == "x-scope" || header_name.to_lowercase() == "x-scopes";
1269        let is_allowed_proxy_header = is_proxy_injected_header_allowed(&header_name);
1270
1271        // Skip (continue) if header is not allowed
1272        let is_allowed = !is_auth_header(&header_name)
1273            || (is_grpc && is_grpc_required_header(&header_name))
1274            || is_tenant_header
1275            || is_scope_header
1276            || (is_grpc && is_allowed_proxy_header);
1277
1278        if !is_allowed {
1279            continue;
1280        }
1281
1282        if let Ok(name) = header::HeaderName::from_bytes(header_name.as_bytes()) {
1283            // Additional validation for header values
1284            if let Ok(value) = header::HeaderValue::from_str(&header_value) {
1285                // For gRPC binary metadata, apply additional validation
1286                if is_grpc
1287                    && header_name.to_lowercase().ends_with("-bin")
1288                    && !validate_binary_metadata(&header_name, &header_value)
1289                {
1290                    warn!(
1291                        "Invalid binary metadata header {}: {}",
1292                        header_name, header_value
1293                    );
1294                    continue;
1295                }
1296
1297                // Prevent header value injection attacks
1298                if header_value.contains('\r') || header_value.contains('\n') {
1299                    warn!("Header value contains CRLF: {}", header_name);
1300                    continue;
1301                }
1302                req.headers_mut().insert(name, value);
1303            } else {
1304                warn!("Invalid header value for {}: {}", header_name, header_value);
1305            }
1306        } else {
1307            warn!("Invalid header name: {}", header_name);
1308        }
1309    }
1310}
1311
1312/// Sanitize request headers by removing auth-specific and forbidden headers
1313fn sanitize_request_headers(req: &mut Request, is_grpc: bool) {
1314    let mut to_remove: Vec<header::HeaderName> = Vec::new();
1315    for (name, _value) in req.headers().iter() {
1316        let name_str = name.as_str();
1317        if is_auth_header(name_str) || is_forbidden_header(name_str) {
1318            // For gRPC, don't remove gRPC-required headers
1319            if !(is_grpc && is_grpc_required_header(name_str)) {
1320                to_remove.push(name.clone());
1321            }
1322        }
1323    }
1324    for name in to_remove {
1325        req.headers_mut().remove(name);
1326    }
1327}
1328
1329/// Detect if a request is a gRPC request based on headers and HTTP version
1330fn is_grpc_request(headers: &axum::http::HeaderMap, req: &Request) -> bool {
1331    // Check for content-type: application/grpc (case-insensitive)
1332    let content_type = headers.get("content-type");
1333    let is_grpc_content_type = content_type
1334        .and_then(|ct| ct.to_str().ok())
1335        .map(|ct| {
1336            debug!("Content-Type header: {}", ct);
1337            ct.to_lowercase().starts_with("application/grpc")
1338        })
1339        .unwrap_or(false);
1340
1341    // Check for te: trailers header (required for gRPC)
1342    let te_header = headers.get("te");
1343    let has_te_trailers = te_header
1344        .and_then(|te| te.to_str().ok())
1345        .map(|te| {
1346            debug!("TE header: {}", te);
1347            te.to_lowercase() == "trailers"
1348        })
1349        .unwrap_or(false);
1350
1351    debug!(
1352        "gRPC detection - content-type: {:?}, is_grpc: {}, te: {:?}, has_trailers: {}",
1353        content_type, is_grpc_content_type, te_header, has_te_trailers
1354    );
1355
1356    // Check HTTP version - gRPC requires HTTP/2 to prevent HTTP/1.1 downgrade attempts
1357    let is_http2 = req.version() == axum::http::Version::HTTP_2;
1358    debug!("HTTP version: {:?}, is_http2: {}", req.version(), is_http2);
1359
1360    let result = is_grpc_content_type && has_te_trailers && is_http2;
1361    debug!("gRPC request detected: {}", result);
1362    result
1363}
1364
1365/// gRPC-aware proxy handler that can handle both HTTP and gRPC requests
1366#[instrument(skip_all)]
1367async fn unified_proxy(
1368    headers: axum::http::HeaderMap,
1369    State(s): State<AuthenticatedProxyState>,
1370    req: Request,
1371) -> Result<Response, StatusCode> {
1372    info!("Unified proxy called with headers: {:?}", headers);
1373    info!("Request method: {}, URI: {}", req.method(), req.uri());
1374
1375    // Extract client certificate information from request extensions
1376    let client_cert = extract_client_cert_from_request(&req);
1377
1378    // Determine authentication method based on available information
1379    let auth_method = if client_cert.is_some() {
1380        AuthMethod::Mtls
1381    } else if headers.get("authorization").is_some() {
1382        AuthMethod::AccessToken
1383    } else if headers.get("x-api-key").is_some() {
1384        AuthMethod::ApiKey
1385    } else {
1386        AuthMethod::OAuth
1387    };
1388
1389    let is_grpc = is_grpc_request(&headers, &req);
1390
1391    if is_grpc {
1392        info!("Detected gRPC request, using gRPC proxy path");
1393        grpc_proxy_with_mtls(headers, State(s), req, client_cert, auth_method).await
1394    } else {
1395        info!("Detected HTTP request, using HTTP proxy path");
1396        reverse_proxy_with_mtls(headers, State(s), req, client_cert, auth_method).await
1397    }
1398}
1399
1400/// gRPC proxy handler with mTLS support
1401#[instrument(skip_all)]
1402async fn grpc_proxy_with_mtls(
1403    headers: axum::http::HeaderMap,
1404    State(s): State<AuthenticatedProxyState>,
1405    mut req: Request,
1406    client_cert: Option<crate::tls_listener::ClientCertInfo>,
1407    _auth_method: AuthMethod,
1408) -> Result<Response, StatusCode> {
1409    debug!("gRPC proxy with mTLS called with headers: {:?}", headers);
1410
1411    // Extract and validate authentication
1412    let (service_id, mut additional_headers) =
1413        extract_and_validate_auth(&headers, &s.db, &s.paseto_manager).await?;
1414
1415    // Check if service requires mTLS and enforce it
1416    let service = match ServiceModel::find_by_id(service_id, &s.db) {
1417        Ok(Some(service)) => service,
1418        Ok(None) => return Err(StatusCode::NOT_FOUND),
1419        Err(_) => return Err(StatusCode::INTERNAL_SERVER_ERROR),
1420    };
1421
1422    // Enforce mTLS requirement if configured
1423    if let Some(tls_profile) = &service.tls_profile {
1424        if tls_profile.require_client_mtls && client_cert.is_none() {
1425            warn!(
1426                "mTLS required but no client certificate provided for service {}",
1427                service_id
1428            );
1429            return Err(StatusCode::UNAUTHORIZED);
1430        }
1431    }
1432
1433    // Add client certificate information to headers if available
1434    if let Some(cert) = &client_cert {
1435        additional_headers.insert("x-client-cert-subject".to_string(), cert.subject.clone());
1436        additional_headers.insert("x-client-cert-issuer".to_string(), cert.issuer.clone());
1437        additional_headers.insert("x-client-cert-serial".to_string(), cert.serial.clone());
1438        additional_headers.insert("x-auth-method".to_string(), "mtls".to_string());
1439    }
1440
1441    let target_host = service
1442        .upstream_url()
1443        .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
1444
1445    debug!("Target host: {:?}", target_host);
1446
1447    let path = req.uri().path();
1448    let path_query = req
1449        .uri()
1450        .path_and_query()
1451        .map(|v| v.as_str())
1452        .unwrap_or(path);
1453    let target_uri = Uri::builder()
1454        .scheme(target_host.scheme().cloned().unwrap_or(uri::Scheme::HTTP))
1455        .authority(
1456            target_host
1457                .authority()
1458                .cloned()
1459                .unwrap_or(uri::Authority::from_static("localhost")),
1460        )
1461        .path_and_query(path_query)
1462        .build()
1463        .map_err(|_| StatusCode::BAD_REQUEST)?;
1464
1465    debug!("Target URI: {:?}", target_uri);
1466
1467    // Set the target URI in the request
1468    *req.uri_mut() = target_uri;
1469
1470    // Sanitize inbound headers and apply additional headers (gRPC-aware)
1471    sanitize_request_headers(&mut req, true);
1472    apply_additional_headers(&mut req, additional_headers, true);
1473
1474    debug!("Forwarding gRPC request with headers: {:?}", req.headers());
1475
1476    // Determine if we need TLS for the upstream connection
1477    let use_tls = target_host.scheme() == Some(&uri::Scheme::HTTPS);
1478
1479    // Forward the request using appropriate client
1480    let response = if use_tls {
1481        // Use TLS client for HTTPS upstreams
1482        let tls_client = s
1483            .tls_client_manager
1484            .get_client_for_service(service_id)
1485            .await
1486            .map_err(|e| {
1487                error!(
1488                    "Failed to get TLS client for service {}: {:?}",
1489                    service_id, e
1490                );
1491                StatusCode::INTERNAL_SERVER_ERROR
1492            })?;
1493
1494        // Convert request body to Incoming type
1495        let (parts, body) = req.into_parts();
1496        let req_with_incoming = Request::from_parts(parts, body);
1497        tls_client
1498            .http2_client
1499            .request(req_with_incoming)
1500            .await
1501            .map_err(|e| {
1502                error!("Failed to forward gRPC request with TLS: {:?}", e);
1503                StatusCode::BAD_GATEWAY
1504            })?
1505    } else {
1506        // Use fallback HTTP/2 client for HTTP upstreams
1507        s.http2_client.request(req).await.map_err(|e| {
1508            error!("Failed to forward gRPC request: {:?}", e);
1509            StatusCode::BAD_GATEWAY
1510        })?
1511    };
1512
1513    debug!("gRPC response received: {:?}", response);
1514
1515    Ok(response.into_response())
1516}
1517
1518/// Reverse proxy handler with mTLS support
1519#[instrument(skip_all)]
1520async fn reverse_proxy_with_mtls(
1521    headers: axum::http::HeaderMap,
1522    State(s): State<AuthenticatedProxyState>,
1523    mut req: Request,
1524    client_cert: Option<crate::tls_listener::ClientCertInfo>,
1525    _auth_method: AuthMethod,
1526) -> Result<Response, StatusCode> {
1527    // Extract and validate authentication
1528    let (service_id, mut additional_headers) =
1529        extract_and_validate_auth(&headers, &s.db, &s.paseto_manager).await?;
1530
1531    // Check if service requires mTLS and enforce it
1532    let service = match ServiceModel::find_by_id(service_id, &s.db) {
1533        Ok(Some(service)) => service,
1534        Ok(None) => return Err(StatusCode::NOT_FOUND),
1535        Err(_) => return Err(StatusCode::INTERNAL_SERVER_ERROR),
1536    };
1537
1538    if let Some(tls_profile) = &service.tls_profile {
1539        if tls_profile.require_client_mtls && client_cert.is_none() {
1540            warn!(
1541                "mTLS required but no client certificate provided for service {}",
1542                service_id
1543            );
1544            return Err(StatusCode::UNAUTHORIZED);
1545        }
1546    }
1547
1548    // Add client certificate information to headers if available
1549    if let Some(cert) = &client_cert {
1550        additional_headers.insert("x-client-cert-subject".to_string(), cert.subject.clone());
1551        additional_headers.insert("x-client-cert-issuer".to_string(), cert.issuer.clone());
1552        additional_headers.insert("x-client-cert-serial".to_string(), cert.serial.clone());
1553        additional_headers.insert("x-auth-method".to_string(), "mtls".to_string());
1554    }
1555
1556    let target_host = service
1557        .upstream_url()
1558        .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
1559
1560    let path = req.uri().path();
1561    let path_query = req
1562        .uri()
1563        .path_and_query()
1564        .map(|v| v.as_str())
1565        .unwrap_or(path);
1566    let target_uri = Uri::builder()
1567        .scheme(target_host.scheme().cloned().unwrap_or(uri::Scheme::HTTP))
1568        .authority(
1569            target_host
1570                .authority()
1571                .cloned()
1572                .unwrap_or(uri::Authority::from_static("localhost")),
1573        )
1574        .path_and_query(path_query)
1575        .build()
1576        .map_err(|_| StatusCode::BAD_REQUEST)?;
1577
1578    // Set the target URI in the request
1579    *req.uri_mut() = target_uri;
1580
1581    // Sanitize inbound headers and apply additional headers
1582    sanitize_request_headers(&mut req, false);
1583    apply_additional_headers(&mut req, additional_headers, false);
1584
1585    // Determine if we need TLS for the upstream connection
1586    let use_tls = target_host.scheme() == Some(&uri::Scheme::HTTPS);
1587
1588    // Forward the request using appropriate client
1589    let response = if use_tls {
1590        // Use TLS client for HTTPS upstreams
1591        let tls_client = s
1592            .tls_client_manager
1593            .get_client_for_service(service_id)
1594            .await
1595            .map_err(|e| {
1596                error!(
1597                    "Failed to get TLS client for service {}: {:?}",
1598                    service_id, e
1599                );
1600                StatusCode::INTERNAL_SERVER_ERROR
1601            })?;
1602
1603        // Convert request body to Incoming type
1604        let (parts, body) = req.into_parts();
1605        let req_with_incoming = Request::from_parts(parts, body);
1606        tls_client
1607            .http_client
1608            .request(req_with_incoming)
1609            .await
1610            .map_err(|e| {
1611                error!("Failed to forward HTTP request with TLS: {:?}", e);
1612                StatusCode::BAD_GATEWAY
1613            })?
1614    } else {
1615        // Use fallback HTTP/1.1 client for HTTP upstreams
1616        s.http_client
1617            .request(req)
1618            .await
1619            .map_err(|_| StatusCode::BAD_GATEWAY)?
1620    };
1621
1622    Ok(response.into_response())
1623}
1624
#[cfg(test)]
mod tests {
    use std::collections::BTreeMap;
    use std::net::Ipv4Addr;

    use tempfile::tempdir;

    use super::*;
    use crate::{
        test_client::TestClient,
        types::{ChallengeRequest, ChallengeResponse, KeyType, VerifyChallengeResponse, headers},
    };

    /// Installs a thread-local tracing subscriber that writes through the test harness.
    ///
    /// The subscriber is only active while the returned guard is alive, so callers must bind
    /// it (e.g. `let _guard = init_test_tracing();`) for the duration of the test.
    #[must_use = "dropping the guard immediately uninstalls the test subscriber"]
    fn init_test_tracing() -> tracing::subscriber::DefaultGuard {
        tracing::subscriber::set_default(
            tracing_subscriber::fmt()
                .with_env_filter(tracing_subscriber::EnvFilter::from_default_env())
                .with_line_number(true)
                .with_file(true)
                .with_span_events(tracing_subscriber::fmt::format::FmtSpan::CLOSE)
                .with_test_writer()
                .finish(),
        )
    }

    /// Serves `router` on an ephemeral localhost port in a background task.
    ///
    /// Returns the task handle (so the test can abort it) and the bound address. `label`
    /// prefixes the message printed if the server exits with an error.
    async fn spawn_upstream(
        router: Router,
        label: &'static str,
    ) -> (tokio::task::JoinHandle<()>, std::net::SocketAddr) {
        let listener = tokio::net::TcpListener::bind((Ipv4Addr::LOCALHOST, 0))
            .await
            .expect("Failed to bind to address");
        let server = axum::serve(listener, router);
        let local_address = server.local_addr().unwrap();
        let handle = tokio::spawn(async move {
            if let Err(e) = server.await {
                eprintln!("{label} server error: {e}");
            }
        });
        (handle, local_address)
    }

    /// Persists a service forwarding to `upstream_url` and owned by a freshly generated ECDSA
    /// key, returning that key so the test can complete the challenge flow.
    fn register_service(
        proxy: &AuthenticatedProxy,
        service_id: ServiceId,
        upstream_url: String,
    ) -> k256::ecdsa::SigningKey {
        let mut rng = blueprint_std::BlueprintRng::new();
        let signing_key = k256::ecdsa::SigningKey::random(&mut rng);
        let public_key = signing_key.verifying_key().to_sec1_bytes();

        let mut service = crate::models::ServiceModel {
            api_key_prefix: "test_".to_string(),
            owners: Vec::new(),
            upstream_url,
            tls_profile: None,
        };
        service.add_owner(KeyType::Ecdsa, public_key.into());
        service.save(service_id, &proxy.db).unwrap();
        signing_key
    }

    /// Runs the challenge/response handshake against `service_id` and returns the verify
    /// response.
    ///
    /// Requests a challenge for `signing_key`'s public key and signs it. The signature is
    /// sanity-checked locally before it is submitted together with `additional_headers`.
    async fn authenticate(
        client: &TestClient,
        service_id: ServiceId,
        signing_key: &k256::ecdsa::SigningKey,
        additional_headers: BTreeMap<String, String>,
    ) -> VerifyChallengeResponse {
        let public_key = signing_key.verifying_key().to_sec1_bytes();

        // Step 1: request a challenge.
        let challenge_request = ChallengeRequest {
            pub_key: public_key.clone().into(),
            key_type: KeyType::Ecdsa,
        };
        let res = client
            .post("/v1/auth/challenge")
            .header(headers::X_SERVICE_ID, service_id.to_string())
            .json(&challenge_request)
            .await;
        let challenge: ChallengeResponse = res.json().await;

        // Sign the challenge; verify locally first so a signing problem is reported here
        // rather than as an opaque server-side rejection.
        let (signature, _) = signing_key
            .sign_prehash_recoverable(&challenge.challenge)
            .unwrap();
        assert!(
            crate::verify_challenge(
                &challenge.challenge,
                &signature.to_vec(),
                &public_key,
                KeyType::Ecdsa
            )
            .unwrap()
        );

        // Step 2: submit the signed challenge.
        let verify_request = crate::types::VerifyChallengeRequest {
            challenge: challenge.challenge,
            signature: signature.to_bytes().into(),
            challenge_request,
            expires_at: 0,
            additional_headers,
        };
        let res = client
            .post("/v1/auth/verify")
            .header(headers::X_SERVICE_ID, service_id.to_string())
            .json(&verify_request)
            .await;
        res.json().await
    }

    #[tokio::test]
    async fn auth_flow_works() {
        let _guard = init_test_tracing();
        let tmp = tempdir().unwrap();
        let proxy = AuthenticatedProxy::new(tmp.path()).unwrap();

        // Upstream: a minimal hello-world server.
        let hello_world_router =
            Router::new().route("/hello", axum::routing::get(|| async { "Hello, World!" }));
        let (hello_world_server, local_addr) =
            spawn_upstream(hello_world_router, "Hello world").await;

        let service_id = ServiceId::new(0);
        let signing_key = register_service(
            &proxy,
            service_id,
            format!("http://localhost:{}", local_addr.port()),
        );

        let client = TestClient::new(proxy.router());

        let res = authenticate(&client, service_id, &signing_key, BTreeMap::new()).await;
        let VerifyChallengeResponse::Verified { api_key, .. } = res else {
            panic!("Expected a verified response");
        };

        // Try to send a request to the reverse proxy with the token in the header
        let res = client
            .get("/hello")
            .header(headers::AUTHORIZATION, format!("Bearer {api_key}"))
            .await;
        assert!(
            res.status().is_success(),
            "Request to reverse proxy failed: {res:?}",
        );

        hello_world_server.abort(); // Stop the hello world server
    }

    #[tokio::test]
    async fn auth_flow_with_additional_headers() {
        let _guard = init_test_tracing();
        let tmp = tempdir().unwrap();
        let proxy = AuthenticatedProxy::new(tmp.path()).unwrap();

        // Upstream: echoes every `x-tenant-*` request header back as a JSON object.
        let echo_router = Router::new().route(
            "/echo",
            axum::routing::get(|headers: axum::http::HeaderMap| async move {
                // `HeaderName` is always stored lowercase, so a lowercase prefix check covers
                // every spelling the client may have used.
                let response_headers: BTreeMap<String, String> = headers
                    .iter()
                    .filter(|(name, _)| name.as_str().starts_with("x-tenant-"))
                    .map(|(name, value)| {
                        (name.to_string(), value.to_str().unwrap_or("").to_string())
                    })
                    .collect();
                axum::Json(response_headers)
            }),
        );
        let (echo_server, local_addr) = spawn_upstream(echo_router, "Echo").await;

        let service_id = ServiceId::new(1);
        let signing_key = register_service(
            &proxy,
            service_id,
            format!("http://localhost:{}", local_addr.port()),
        );

        let client = TestClient::new(proxy.router());

        // Verify with additional headers that should be attached to proxied requests.
        let user_id = "user123@example.com";
        let tenant_id = crate::validation::hash_user_id(user_id);
        let additional_headers = BTreeMap::from([
            ("X-Tenant-Id".to_string(), tenant_id.clone()),
            ("X-Tenant-Name".to_string(), "Acme Corp".to_string()),
        ]);

        let res = authenticate(&client, service_id, &signing_key, additional_headers).await;
        let VerifyChallengeResponse::Verified { api_key, .. } = res else {
            panic!("Expected a verified response");
        };

        // Make a request with the token and verify the headers are forwarded upstream.
        let res = client
            .get("/echo")
            .header(headers::AUTHORIZATION, format!("Bearer {api_key}"))
            .await;

        assert!(res.status().is_success());

        let response_headers: BTreeMap<String, String> = res.json().await;
        assert_eq!(response_headers.get("x-tenant-id"), Some(&tenant_id));
        assert_eq!(
            response_headers.get("x-tenant-name"),
            Some(&"Acme Corp".to_string())
        );

        echo_server.abort();
    }

    #[tokio::test]
    async fn auth_flow_rejects_invalid_headers() {
        let tmp = tempdir().unwrap();
        let proxy = AuthenticatedProxy::new(tmp.path()).unwrap();

        // This test never issues a proxied request, so the upstream URL is not contacted.
        let service_id = ServiceId::new(2);
        let signing_key =
            register_service(&proxy, service_id, "http://localhost:9999".to_string());

        let client = TestClient::new(proxy.router());

        // Try to verify with a forbidden header.
        let additional_headers =
            BTreeMap::from([("Connection".to_string(), "close".to_string())]);

        let res = authenticate(&client, service_id, &signing_key, additional_headers).await;

        // Should fail with an error about invalid headers
        let VerifyChallengeResponse::UnexpectedError { message } = res else {
            panic!("Expected an UnexpectedError response");
        };
        assert!(
            message.contains("Invalid headers"),
            "Unexpected error message: {message}"
        );
    }
}