1use std::{collections::HashMap, sync::Mutex, time::SystemTime};
2use super::{AuthProvider, AuthResult, authenticator::GatewayAuthError};
3use hyper::http::request::Parts;
4use crate::config::{ClientInfo, ConfigUpdate};
5use jsonwebtoken as jwt;
6use serde::{Serialize, Deserialize};
7use tracing::{event, Level};
8use lru::LruCache;
9
/// Authenticates gateway requests carrying ES256-signed JWTs.
///
/// A token's `sub` claim selects one of the configured clients, and that
/// client's public key is used to verify the token's signature.
#[derive(Debug)]
pub struct JWTAuthProvider {
    /// Configured clients, keyed by client id.
    apps: HashMap<String, ClientInfo>,
    /// LRU cache mapping a token whose signature has been verified to the
    /// `app_key` of the client it was verified for. Behind a `Mutex` because
    /// `identify_client` only has `&self`.
    token_cache: Mutex<LruCache<String, String>>,
}
15
16impl AuthProvider for JWTAuthProvider {
17 fn update_config(&mut self, update: ConfigUpdate) {
18 match update {
19 ConfigUpdate::ClientUpdate(client) => {
20 let client_id = client.client_id.clone();
21 self.apps.insert(client_id, client);
22 },
23 ConfigUpdate::ClientRemove(cid) => {
24 self.apps.remove(&cid);
25 },
26 _ => {},
27 }
28 }
29
30 fn identify_client(&self, head: Parts, service_id: &str) -> Result<(Parts, AuthResult), GatewayAuthError> {
31 let token = Self::extract_token(&head)?;
32 let client_id = Self::extract_client_id(&token)?;
33 let client = self.apps.get(&client_id).ok_or(GatewayAuthError::UnknownClient)?;
34 let sla = client.services.get(service_id).ok_or(GatewayAuthError::InvalidSLA)?;
35
36 let mut cache = self.token_cache.lock().unwrap();
38 if let Some(cached_key) = cache.get(&token) {
39 event!(Level::DEBUG, "cached data {} {}", cached_key, client.app_key);
40 if cached_key.eq(&client.app_key) {
41 return Ok((head, AuthResult {client_id: client.client_id.clone(), sla: sla.clone()}))
42 } else {
43 return Err(GatewayAuthError::InvalidToken);
44 }
45 } else {
46 Self::verify_token(token.clone(), &client.pub_key)?;
47 cache.put(token, client.app_key.clone());
48 return Ok((head, AuthResult {client_id: client.client_id.clone(), sla: sla.clone()}))
49 }
50 }
51}
52
53
54impl JWTAuthProvider {
55
56 pub fn new() -> Self {
57 JWTAuthProvider {
58 apps: HashMap::new(),
59 token_cache: Mutex::new(LruCache::new(1024)),
60 }
61 }
62
63 fn extract_token(head: &Parts) -> Result<String, GatewayAuthError> {
64 if let Some(token) = head.headers.get(hyper::header::AUTHORIZATION) { let segs: Vec<&str> = token.to_str().unwrap().split(' ').collect();
66 let token = *(segs.get(1).unwrap_or(&""));
67 Ok(String::from(token))
68 } else {
69 Err(GatewayAuthError::TokenNotFound)
70 }
71 }
72
73 fn extract_client_id(token: &str) -> Result<String, GatewayAuthError> {
74 let ts = SystemTime::now().duration_since(SystemTime::UNIX_EPOCH).unwrap();
75 if let Ok(t) = jwt::dangerous_insecure_decode::<JwtClaims>(token) {
76 if t.claims.exp > ts.as_secs() {
77 return Ok(t.claims.sub);
78 }
79 }
80 Err(GatewayAuthError::InvalidToken)
81 }
82
83 fn verify_token(token: String, pubkey: &str) -> Result<(), GatewayAuthError> {
84 let verifier = jwt::Validation::new(jwt::Algorithm::ES256);
85 let verify_key = jwt::DecodingKey::from_ec_pem(pubkey.as_bytes()).unwrap();
86 if let Ok(_td) = jwt::decode::<JwtClaims>(&token, &verify_key, &verifier) {
87 Ok(())
88 } else {
89 Err(GatewayAuthError::InvalidToken)
90 }
91 }
92
93}
94
95
/// Claims this module reads from gateway JWTs.
///
/// Deserialized both by the unverified pre-decode (to learn the client id and
/// expiry) and by the signature-checked decode.
#[derive(Debug, Serialize, Deserialize)]
pub struct JwtClaims {
    /// Expiry, compared against the current time in seconds since the UNIX
    /// epoch; a token is rejected once the current time reaches this value.
    pub exp: u64,
    /// Optional `iat` claim; not read by this module.
    pub iat: Option<u64>,
    /// Optional `iss` claim; not read by this module.
    pub iss: Option<String>,
    /// Subject; used as the client id to look up the client configuration.
    pub sub: String,
}