// hyperapi/auth/jwt.rs

1use std::{collections::HashMap, sync::Mutex, time::SystemTime};
2use super::{AuthProvider, AuthResult, authenticator::GatewayAuthError};
3use hyper::http::request::Parts;
4use crate::config::{ClientInfo, ConfigUpdate};
5use jsonwebtoken as jwt;
6use serde::{Serialize, Deserialize};
7use tracing::{event, Level};
8use lru::LruCache;
9
10#[derive(Debug)]
11pub struct JWTAuthProvider {
12    apps: HashMap<String, ClientInfo>,
13    token_cache: Mutex<LruCache<String, String>>,
14}
15
16impl AuthProvider for JWTAuthProvider {
17    fn update_config(&mut self, update: ConfigUpdate) {
18        match update {
19            ConfigUpdate::ClientUpdate(client) => {
20                let client_id = client.client_id.clone();
21                self.apps.insert(client_id, client);
22            },
23            ConfigUpdate::ClientRemove(cid) => {
24                self.apps.remove(&cid);
25            },
26            _ => {},
27        }
28    }
29
30    fn identify_client(&self, head: Parts, service_id: &str) -> Result<(Parts, AuthResult), GatewayAuthError> {
31        let token =  Self::extract_token(&head)?;
32        let client_id = Self::extract_client_id(&token)?;
33        let client = self.apps.get(&client_id).ok_or(GatewayAuthError::UnknownClient)?;
34        let sla = client.services.get(service_id).ok_or(GatewayAuthError::InvalidSLA)?;
35
36        // check cache
37        let mut cache = self.token_cache.lock().unwrap();
38        if let Some(cached_key) = cache.get(&token) {
39            event!(Level::DEBUG, "cached data {} {}", cached_key, client.app_key);
40            if cached_key.eq(&client.app_key) {
41                return Ok((head, AuthResult {client_id: client.client_id.clone(), sla: sla.clone()}))
42            } else {
43                return Err(GatewayAuthError::InvalidToken);
44            }
45        } else {
46            Self::verify_token(token.clone(), &client.pub_key)?;
47            cache.put(token, client.app_key.clone());
48            return Ok((head, AuthResult {client_id: client.client_id.clone(), sla: sla.clone()}))
49        }
50    }
51}
52
53
54impl JWTAuthProvider {
55
56    pub fn new() -> Self {
57        JWTAuthProvider {
58            apps: HashMap::new(),
59            token_cache: Mutex::new(LruCache::new(1024)),
60        }
61    }
62
63    fn extract_token(head: &Parts) -> Result<String, GatewayAuthError> {
64        if let Some(token) = head.headers.get(hyper::header::AUTHORIZATION) {  // find in authorization header
65            let segs: Vec<&str> = token.to_str().unwrap().split(' ').collect();
66            let token = *(segs.get(1).unwrap_or(&""));
67            Ok(String::from(token))
68        } else {
69            Err(GatewayAuthError::TokenNotFound)
70        }
71    }
72
73    fn extract_client_id(token: &str) -> Result<String, GatewayAuthError> {
74        let ts = SystemTime::now().duration_since(SystemTime::UNIX_EPOCH).unwrap();
75        if let Ok(t) = jwt::dangerous_insecure_decode::<JwtClaims>(token) {
76            if t.claims.exp > ts.as_secs() {
77                return Ok(t.claims.sub);
78            }
79        }
80        Err(GatewayAuthError::InvalidToken)
81    }
82
83    fn verify_token(token: String, pubkey: &str) -> Result<(), GatewayAuthError> {
84        let verifier = jwt::Validation::new(jwt::Algorithm::ES256);
85        let verify_key = jwt::DecodingKey::from_ec_pem(pubkey.as_bytes()).unwrap();
86        if let Ok(_td) = jwt::decode::<JwtClaims>(&token, &verify_key, &verifier) {
87            Ok(())
88        } else {
89            Err(GatewayAuthError::InvalidToken)
90        }
91    }
92
93}
94
95
96#[derive(Debug, Serialize, Deserialize)]
97pub struct JwtClaims {
98    pub exp: u64,                    // Required (validate_exp defaults to true in validation). Expiration time (as UTC timestamp)
99    pub iat: Option<u64>,            // Optional. Issued at (as UTC timestamp)
100    pub iss: Option<String>,         // Optional. Issuer
101    pub sub: String,                 // Optional. Subject (whom token refers to)
102}