http_tunnel_handler/
auth.rs

1//! Authentication module for WebSocket connections
2//!
3//! This module provides JWT-based authentication for WebSocket connections.
4//! Authentication can be enabled/disabled via the REQUIRE_AUTH environment variable.
5
6use anyhow::{Context, Result, anyhow};
7use aws_lambda_events::apigw::ApiGatewayWebsocketProxyRequest;
8use jsonwebtoken::{Algorithm, DecodingKey, Validation, decode};
9use once_cell::sync::Lazy;
10use serde::{Deserialize, Serialize};
11use std::sync::RwLock;
12use tracing::{debug, info, warn};
13
/// JWT Claims structure
///
/// The claim set decoded from a validated token and handed back to callers of
/// `validate_token` / `authenticate_request`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Claims {
    /// Subject (user ID).
    pub sub: String,
    /// Expiration time as a Unix timestamp in seconds.
    pub exp: usize,
    /// Issued-at time (Unix seconds); omitted from serialized output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub iat: Option<usize>,
}
22
/// JWKS (JSON Web Key Set) structure
///
/// The top-level `{"keys": [...]}` document read by `load_jwks`.
#[derive(Debug, Clone, Deserialize)]
struct Jwks {
    /// Candidate verification keys, tried in order by `validate_token`.
    keys: Vec<JwkKey>,
}
28
29/// Individual JWK (JSON Web Key)
30#[derive(Debug, Clone, Deserialize)]
31struct JwkKey {
32    kty: String, // Key type (RSA or oct)
33    kid: String, // Key ID
34    alg: String, // Algorithm (RS256, HS256, etc.)
35    #[serde(skip_serializing_if = "Option::is_none")]
36    n: Option<String>, // RSA modulus (base64url)
37    #[serde(skip_serializing_if = "Option::is_none")]
38    e: Option<String>, // RSA exponent (base64url)
39    #[serde(skip_serializing_if = "Option::is_none")]
40    k: Option<String>, // Symmetric key (base64url)
41    #[serde(skip_serializing_if = "Option::is_none")]
42    #[allow(dead_code)]
43    r#use: Option<String>, // Key use (sig, enc)
44}
45
/// Process-wide cache of the parsed JWKS, whether it came from the `JWKS`
/// environment variable or from the JWKS file.
///
/// It stays `None` until `load_jwks` first succeeds. Failed loads are not cached,
/// so they are retried on the next call.
static JWKS_CACHE: Lazy<RwLock<Option<Jwks>>> = Lazy::new(|| RwLock::new(None));
48
49/// Load JWKS from environment variable or file (cached)
50fn load_jwks() -> Result<Jwks> {
51    // Check cache first
52    {
53        let cache = JWKS_CACHE.read().unwrap();
54        if let Some(jwks) = cache.as_ref() {
55            return Ok(jwks.clone());
56        }
57    }
58
59    // Try loading from JWKS environment variable first
60    let jwks_content = if let Ok(jwks_json) = std::env::var("JWKS") {
61        debug!("Loading JWKS from JWKS environment variable");
62        jwks_json
63    } else {
64        // Fallback to file
65        let jwks_path =
66            std::env::var("JWKS_PATH").unwrap_or_else(|_| "/var/task/jwks.json".to_string());
67
68        debug!("Loading JWKS from file: {}", jwks_path);
69
70        std::fs::read_to_string(&jwks_path)
71            .with_context(|| format!("Failed to read JWKS file at {}", jwks_path))?
72    };
73
74    let jwks: Jwks = serde_json::from_str(&jwks_content).context("Failed to parse JWKS JSON")?;
75
76    // Cache it
77    {
78        let mut cache = JWKS_CACHE.write().unwrap();
79        *cache = Some(jwks.clone());
80    }
81
82    info!("JWKS loaded successfully with {} keys", jwks.keys.len());
83    Ok(jwks)
84}
85
/// Check whether incoming connections must present a valid JWT.
///
/// Reads the `REQUIRE_AUTH` environment variable. Surrounding whitespace is
/// ignored and the comparison is case-insensitive. The values `true`, `1`, `yes`
/// and `on` enable authentication. Any other value, or an unset variable, leaves
/// it disabled.
///
/// Accepting the common truthy spellings (and trimming, e.g. a trailing newline
/// from a secrets file) prevents a well-meant `REQUIRE_AUTH=1` from silently
/// turning authentication off.
pub fn is_auth_required() -> bool {
    const ENABLED_VALUES: [&str; 4] = ["true", "1", "yes", "on"];

    std::env::var("REQUIRE_AUTH")
        .map(|raw| {
            let value = raw.trim();
            ENABLED_VALUES
                .iter()
                .any(|enabled| value.eq_ignore_ascii_case(enabled))
        })
        .unwrap_or(false)
}
93
94/// Extract token from WebSocket request
95/// Checks (in order): Authorization header, query parameters
96fn extract_token(request: &ApiGatewayWebsocketProxyRequest) -> Option<String> {
97    // First try Authorization header (preferred - works with custom domains and not logged)
98    if let Some(auth_header) = request
99        .headers
100        .get("authorization")
101        .or_else(|| request.headers.get("Authorization"))
102        && let Some(token) = auth_header
103            .to_str()
104            .ok()
105            .and_then(|s| s.strip_prefix("Bearer "))
106    {
107        debug!("Token extracted from Authorization header");
108        return Some(token.to_string());
109    }
110
111    // Fallback to query parameter (less secure - gets logged)
112    if let Some(token) = request.query_string_parameters.first("token") {
113        warn!("Token extracted from query parameter (consider using Authorization header)");
114        return Some(token.to_string());
115    }
116
117    None
118}
119
120/// Validate JWT token using JWKS file or JWT_SECRET
121pub fn validate_token(token: &str) -> Result<Claims> {
122    // Try JWKS first if available
123    if let Ok(jwks) = load_jwks() {
124        // Try each key in JWKS
125        for key in &jwks.keys {
126            debug!(
127                "Trying key: {} (type: {}, alg: {})",
128                key.kid, key.kty, key.alg
129            );
130
131            let result = match key.kty.as_str() {
132                "RSA" => validate_with_rsa_key(token, key),
133                "oct" => validate_with_symmetric_key(token, key),
134                _ => {
135                    warn!("Unsupported key type: {} (kid: {})", key.kty, key.kid);
136                    continue;
137                }
138            };
139
140            match result {
141                Ok(claims) => {
142                    info!("✅ Token validated with key: {} ({})", key.kid, key.alg);
143                    return Ok(claims);
144                }
145                Err(e) => {
146                    debug!("Key {} validation failed: {}", key.kid, e);
147                }
148            }
149        }
150
151        warn!(
152            "Token validation failed with all {} JWKS keys",
153            jwks.keys.len()
154        );
155        return Err(anyhow!("Token validation failed with all JWKS keys"));
156    }
157
158    // Fallback to JWT_SECRET environment variable
159    let secret = std::env::var("JWT_SECRET")
160        .unwrap_or_else(|_| "default-secret-change-in-production".to_string());
161
162    debug!("Using JWT_SECRET for validation (JWKS not available)");
163
164    let validation = Validation::new(Algorithm::HS256);
165    let token_data = decode::<Claims>(
166        token,
167        &DecodingKey::from_secret(secret.as_bytes()),
168        &validation,
169    )?;
170
171    Ok(token_data.claims)
172}
173
174/// Validate token with RSA public key
175fn validate_with_rsa_key(token: &str, key: &JwkKey) -> Result<Claims> {
176    let n = key
177        .n
178        .as_ref()
179        .ok_or_else(|| anyhow!("Missing 'n' in RSA key"))?;
180    let e = key
181        .e
182        .as_ref()
183        .ok_or_else(|| anyhow!("Missing 'e' in RSA key"))?;
184
185    let algorithm = match key.alg.as_str() {
186        "RS256" => Algorithm::RS256,
187        "RS384" => Algorithm::RS384,
188        "RS512" => Algorithm::RS512,
189        _ => return Err(anyhow!("Unsupported RSA algorithm: {}", key.alg)),
190    };
191
192    // DecodingKey::from_rsa_components expects base64url strings directly
193    let decoding_key = DecodingKey::from_rsa_components(n, e)?;
194
195    // Create validation without audience/issuer checks (accept any)
196    let mut validation = Validation::new(algorithm);
197    validation.validate_aud = false;
198    validation.validate_exp = true;
199
200    let token_data = decode::<Claims>(token, &decoding_key, &validation)?;
201    Ok(token_data.claims)
202}
203
204/// Validate token with symmetric (HMAC) key
205fn validate_with_symmetric_key(token: &str, key: &JwkKey) -> Result<Claims> {
206    let k = key
207        .k
208        .as_ref()
209        .ok_or_else(|| anyhow!("Missing 'k' in symmetric key"))?;
210
211    let key_bytes = base64_url_decode(k)?;
212
213    let algorithm = match key.alg.as_str() {
214        "HS256" => Algorithm::HS256,
215        "HS384" => Algorithm::HS384,
216        "HS512" => Algorithm::HS512,
217        _ => return Err(anyhow!("Unsupported HMAC algorithm: {}", key.alg)),
218    };
219
220    let decoding_key = DecodingKey::from_secret(&key_bytes);
221    let validation = Validation::new(algorithm);
222
223    let token_data = decode::<Claims>(token, &decoding_key, &validation)?;
224    Ok(token_data.claims)
225}
226
227/// Decode base64url string (with or without padding)
228fn base64_url_decode(s: &str) -> Result<Vec<u8>> {
229    use base64::Engine;
230    base64::engine::general_purpose::URL_SAFE_NO_PAD
231        .decode(s)
232        .or_else(|_| base64::engine::general_purpose::URL_SAFE.decode(s))
233        .context("Failed to decode base64url")
234}
235
236/// Authenticate WebSocket connection request
237///
238/// Returns Ok(Some(claims)) if authentication is required and successful
239/// Returns Ok(None) if authentication is not required
240/// Returns Err if authentication is required but failed
241pub fn authenticate_request(request: &ApiGatewayWebsocketProxyRequest) -> Result<Option<Claims>> {
242    if !is_auth_required() {
243        debug!("Authentication not required");
244        return Ok(None);
245    }
246
247    info!("Authentication required, validating token");
248
249    let token =
250        extract_token(request).ok_or_else(|| anyhow!("No authentication token provided"))?;
251
252    match validate_token(&token) {
253        Ok(claims) => {
254            info!("Token validated successfully for user: {}", claims.sub);
255            Ok(Some(claims))
256        }
257        Err(e) => {
258            warn!("Token validation failed: {}", e);
259            Err(anyhow!("Invalid or expired token"))
260        }
261    }
262}
263
#[cfg(test)]
mod tests {
    use super::*;
    use jsonwebtoken::{EncodingKey, Header, encode};

    /// HMAC secret used both to sign test tokens and as `JWT_SECRET`.
    const TEST_SECRET: &str = "test-secret";

    /// Build claims for `user123` that expire `offset` from now (negative means
    /// already expired), set `JWT_SECRET` to `TEST_SECRET`, and return the
    /// HS256-signed token.
    ///
    /// These tests exercise the `JWT_SECRET` branch of `validate_token`, which is
    /// only taken when no JWKS can be loaded in the test environment.
    fn signed_token(offset: chrono::Duration) -> String {
        let claims = Claims {
            sub: "user123".to_string(),
            exp: (chrono::Utc::now() + offset).timestamp() as usize,
            iat: Some(chrono::Utc::now().timestamp() as usize),
        };

        // SAFETY: `set_var` is unsafe because another thread may read the environment
        // at the same time. Every test here writes the same value, but the harness
        // runs tests on parallel threads.
        // NOTE(review): not strictly sound; consider serializing env access.
        unsafe { std::env::set_var("JWT_SECRET", TEST_SECRET) };

        let signing_key = EncodingKey::from_secret(TEST_SECRET.as_bytes());
        encode(&Header::default(), &claims, &signing_key).unwrap()
    }

    #[test]
    fn test_create_and_validate_token() {
        let token = signed_token(chrono::Duration::hours(1));

        let validated = validate_token(&token).unwrap();
        assert_eq!(validated.sub, "user123");
    }

    #[test]
    fn test_expired_token() {
        let token = signed_token(-chrono::Duration::hours(1));

        assert!(validate_token(&token).is_err());
    }
}
311}