http_tunnel_handler/
auth.rs1use anyhow::{Context, Result, anyhow};
7use aws_lambda_events::apigw::ApiGatewayWebsocketProxyRequest;
8use jsonwebtoken::{Algorithm, DecodingKey, Validation, decode};
9use once_cell::sync::Lazy;
10use serde::{Deserialize, Serialize};
11use std::sync::RwLock;
12use tracing::{debug, info, warn};
13
14#[derive(Debug, Clone, Serialize, Deserialize)]
16pub struct Claims {
17 pub sub: String, pub exp: usize, #[serde(skip_serializing_if = "Option::is_none")]
20 pub iat: Option<usize>, }
22
23#[derive(Debug, Clone, Deserialize)]
25struct Jwks {
26 keys: Vec<JwkKey>,
27}
28
29#[derive(Debug, Clone, Deserialize)]
31struct JwkKey {
32 kty: String, kid: String, alg: String, #[serde(skip_serializing_if = "Option::is_none")]
36 n: Option<String>, #[serde(skip_serializing_if = "Option::is_none")]
38 e: Option<String>, #[serde(skip_serializing_if = "Option::is_none")]
40 k: Option<String>, #[serde(skip_serializing_if = "Option::is_none")]
42 #[allow(dead_code)]
43 r#use: Option<String>, }
45
46static JWKS_CACHE: Lazy<RwLock<Option<Jwks>>> = Lazy::new(|| RwLock::new(None));
48
49fn load_jwks() -> Result<Jwks> {
51 {
53 let cache = JWKS_CACHE.read().unwrap();
54 if let Some(jwks) = cache.as_ref() {
55 return Ok(jwks.clone());
56 }
57 }
58
59 let jwks_content = if let Ok(jwks_json) = std::env::var("JWKS") {
61 debug!("Loading JWKS from JWKS environment variable");
62 jwks_json
63 } else {
64 let jwks_path =
66 std::env::var("JWKS_PATH").unwrap_or_else(|_| "/var/task/jwks.json".to_string());
67
68 debug!("Loading JWKS from file: {}", jwks_path);
69
70 std::fs::read_to_string(&jwks_path)
71 .with_context(|| format!("Failed to read JWKS file at {}", jwks_path))?
72 };
73
74 let jwks: Jwks = serde_json::from_str(&jwks_content).context("Failed to parse JWKS JSON")?;
75
76 {
78 let mut cache = JWKS_CACHE.write().unwrap();
79 *cache = Some(jwks.clone());
80 }
81
82 info!("JWKS loaded successfully with {} keys", jwks.keys.len());
83 Ok(jwks)
84}
85
/// Reports whether incoming requests must present a valid token.
///
/// Reads the `REQUIRE_AUTH` environment variable on every call. Only the value
/// `true`, compared without regard to letter case, enables authentication; an
/// unset or unreadable variable, or any other value, yields `false`.
pub fn is_auth_required() -> bool {
    match std::env::var("REQUIRE_AUTH") {
        Ok(flag) => flag.eq_ignore_ascii_case("true"),
        Err(_) => false,
    }
}
93
94fn extract_token(request: &ApiGatewayWebsocketProxyRequest) -> Option<String> {
97 if let Some(auth_header) = request
99 .headers
100 .get("authorization")
101 .or_else(|| request.headers.get("Authorization"))
102 && let Some(token) = auth_header
103 .to_str()
104 .ok()
105 .and_then(|s| s.strip_prefix("Bearer "))
106 {
107 debug!("Token extracted from Authorization header");
108 return Some(token.to_string());
109 }
110
111 if let Some(token) = request.query_string_parameters.first("token") {
113 warn!("Token extracted from query parameter (consider using Authorization header)");
114 return Some(token.to_string());
115 }
116
117 None
118}
119
120pub fn validate_token(token: &str) -> Result<Claims> {
122 if let Ok(jwks) = load_jwks() {
124 for key in &jwks.keys {
126 debug!(
127 "Trying key: {} (type: {}, alg: {})",
128 key.kid, key.kty, key.alg
129 );
130
131 let result = match key.kty.as_str() {
132 "RSA" => validate_with_rsa_key(token, key),
133 "oct" => validate_with_symmetric_key(token, key),
134 _ => {
135 warn!("Unsupported key type: {} (kid: {})", key.kty, key.kid);
136 continue;
137 }
138 };
139
140 match result {
141 Ok(claims) => {
142 info!("✅ Token validated with key: {} ({})", key.kid, key.alg);
143 return Ok(claims);
144 }
145 Err(e) => {
146 debug!("Key {} validation failed: {}", key.kid, e);
147 }
148 }
149 }
150
151 warn!(
152 "Token validation failed with all {} JWKS keys",
153 jwks.keys.len()
154 );
155 return Err(anyhow!("Token validation failed with all JWKS keys"));
156 }
157
158 let secret = std::env::var("JWT_SECRET")
160 .unwrap_or_else(|_| "default-secret-change-in-production".to_string());
161
162 debug!("Using JWT_SECRET for validation (JWKS not available)");
163
164 let validation = Validation::new(Algorithm::HS256);
165 let token_data = decode::<Claims>(
166 token,
167 &DecodingKey::from_secret(secret.as_bytes()),
168 &validation,
169 )?;
170
171 Ok(token_data.claims)
172}
173
174fn validate_with_rsa_key(token: &str, key: &JwkKey) -> Result<Claims> {
176 let n = key
177 .n
178 .as_ref()
179 .ok_or_else(|| anyhow!("Missing 'n' in RSA key"))?;
180 let e = key
181 .e
182 .as_ref()
183 .ok_or_else(|| anyhow!("Missing 'e' in RSA key"))?;
184
185 let algorithm = match key.alg.as_str() {
186 "RS256" => Algorithm::RS256,
187 "RS384" => Algorithm::RS384,
188 "RS512" => Algorithm::RS512,
189 _ => return Err(anyhow!("Unsupported RSA algorithm: {}", key.alg)),
190 };
191
192 let decoding_key = DecodingKey::from_rsa_components(n, e)?;
194
195 let mut validation = Validation::new(algorithm);
197 validation.validate_aud = false;
198 validation.validate_exp = true;
199
200 let token_data = decode::<Claims>(token, &decoding_key, &validation)?;
201 Ok(token_data.claims)
202}
203
204fn validate_with_symmetric_key(token: &str, key: &JwkKey) -> Result<Claims> {
206 let k = key
207 .k
208 .as_ref()
209 .ok_or_else(|| anyhow!("Missing 'k' in symmetric key"))?;
210
211 let key_bytes = base64_url_decode(k)?;
212
213 let algorithm = match key.alg.as_str() {
214 "HS256" => Algorithm::HS256,
215 "HS384" => Algorithm::HS384,
216 "HS512" => Algorithm::HS512,
217 _ => return Err(anyhow!("Unsupported HMAC algorithm: {}", key.alg)),
218 };
219
220 let decoding_key = DecodingKey::from_secret(&key_bytes);
221 let validation = Validation::new(algorithm);
222
223 let token_data = decode::<Claims>(token, &decoding_key, &validation)?;
224 Ok(token_data.claims)
225}
226
227fn base64_url_decode(s: &str) -> Result<Vec<u8>> {
229 use base64::Engine;
230 base64::engine::general_purpose::URL_SAFE_NO_PAD
231 .decode(s)
232 .or_else(|_| base64::engine::general_purpose::URL_SAFE.decode(s))
233 .context("Failed to decode base64url")
234}
235
236pub fn authenticate_request(request: &ApiGatewayWebsocketProxyRequest) -> Result<Option<Claims>> {
242 if !is_auth_required() {
243 debug!("Authentication not required");
244 return Ok(None);
245 }
246
247 info!("Authentication required, validating token");
248
249 let token =
250 extract_token(request).ok_or_else(|| anyhow!("No authentication token provided"))?;
251
252 match validate_token(&token) {
253 Ok(claims) => {
254 info!("Token validated successfully for user: {}", claims.sub);
255 Ok(Some(claims))
256 }
257 Err(e) => {
258 warn!("Token validation failed: {}", e);
259 Err(anyhow!("Invalid or expired token"))
260 }
261 }
262}
263
#[cfg(test)]
mod tests {
    use super::*;
    use jsonwebtoken::{EncodingKey, Header, encode};

    /// Shared HMAC secret used to sign and to validate the test tokens.
    const TEST_SECRET: &str = "test-secret";

    /// Builds claims for `user123` expiring at now + `offset`, publishes
    /// `TEST_SECRET` through `JWT_SECRET`, and returns the signed token.
    ///
    /// NOTE(review): these tests exercise the `JWT_SECRET` path of
    /// `validate_token`, so they presumably rely on no JWKS being loadable in
    /// the test environment — confirm.
    fn token_expiring_in(offset: chrono::Duration) -> String {
        let claims = Claims {
            sub: "user123".to_string(),
            exp: (chrono::Utc::now() + offset).timestamp() as usize,
            iat: Some(chrono::Utc::now().timestamp() as usize),
        };

        // NOTE(review): the test harness runs tests on parallel threads and
        // `validate_token` reads the environment; both tests write the same
        // value here, but confirm that is acceptable or serialize the tests.
        unsafe { std::env::set_var("JWT_SECRET", TEST_SECRET) };

        encode(
            &Header::default(),
            &claims,
            &EncodingKey::from_secret(TEST_SECRET.as_bytes()),
        )
        .unwrap()
    }

    #[test]
    fn test_create_and_validate_token() {
        let token = token_expiring_in(chrono::Duration::hours(1));

        let validated = validate_token(&token).unwrap();
        assert_eq!(validated.sub, "user123");
    }

    #[test]
    fn test_expired_token() {
        let token = token_expiring_in(-chrono::Duration::hours(1));

        assert!(validate_token(&token).is_err());
    }
}