netdb_auth 0.2.0

Netdb Auth token validation for Rocket
Documentation
use std::collections::HashMap;
use std::env;
use netdb_auth_shared::UserBase;
use rocket::http::Status;
use serde::{Deserialize, Serialize};
use tokio::task::block_in_place;
use once_cell::sync::Lazy;
use jsonwebtoken::{decode, Algorithm, DecodingKey, Validation};
use jsonwebtoken::errors::ErrorKind;
use rocket::request::{FromRequest, Outcome};
use rocket::Request;
use std::ops::Deref;
use std::ops::DerefMut;

pub use netdb_auth_macro_derive::has_scope;
pub use netdb_auth_shared;

pub static JWT_KEYS: Lazy<Keys> = Lazy::new(|| get_jwt_keys());
pub static ADMIN_KEY_IDENTIFIER: &str = "admin";

/// Downloads the verification keys published by the login service.
///
/// The keys are fetched from `https://api.login.<DOMAIN>/keys`, where `DOMAIN` is read
/// from the environment, and deserialized into [`Keys`].
///
/// The blocking `reqwest` client is run inside [`block_in_place`] so this function can be
/// called from a Tokio worker thread (for example when [`JWT_KEYS`] is first dereferenced
/// while a request is being handled).
///
/// # Panics
///
/// Panics if `DOMAIN` is not set, if the request fails, if the server answers with a
/// non-success HTTP status, or if the body cannot be deserialized into [`Keys`].
/// `block_in_place` itself panics when called from a current-thread Tokio runtime.
pub fn get_jwt_keys() -> Keys {
    let domain = env::var("DOMAIN").expect("DOMAIN env var was not set");
    let url = format!("https://api.login.{domain}/keys");

    block_in_place(move || {
        reqwest::blocking::get(url.as_str())
            // Without this, an error page would be parsed as `Keys` and
            // fail with a confusing deserialization error.
            .and_then(reqwest::blocking::Response::error_for_status)
            .unwrap_or_else(|err| panic!("failed to fetch JWT keys from {url}: {err}"))
            .json::<Keys>()
            .unwrap_or_else(|err| panic!("failed to parse JWT keys from {url}: {err}"))
    })
}

/// Verification keys published by the login service.
///
/// Each value appears to be the body of a PEM public key with spaces where the line
/// breaks belong; `decode_jwt` restores the PEM form via `add_rsa_headers` before use.
#[derive(Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct Keys {
    /// Key for tokens whose `kid` header equals `ADMIN_KEY_IDENTIFIER`
    /// (wire name: `adminKey`).
    admin_key: String,

    /// Keys for all other tokens, looked up by `kid` (wire name: `userKeys`).
    user_keys: HashMap<String, String>,
}

fn decode_jwt(token: String) -> Result<Claims, ErrorKind> {
    let token = token.trim_start_matches("Bearer").trim();
    let header = jsonwebtoken::decode_header(&token).unwrap();
    let kid = header.kid.unwrap();
    let mut secret: String;

    if kid == ADMIN_KEY_IDENTIFIER {
        secret = JWT_KEYS.admin_key.clone();
    } else {
        secret = JWT_KEYS.user_keys.get(&kid).unwrap().clone();
    }

    secret = add_rsa_headers(&secret);

    let mut validation = Validation::new(Algorithm::RS256);
    validation.set_audience(&[env::var("AUDIENCE").unwrap()]);

    match decode::<Claims>(
        &token,
        &DecodingKey::from_rsa_pem(secret.as_bytes()).unwrap(),
        &validation,
    ) {
        Ok(token) => Ok(token.claims),
        Err(err) => Err(err.kind().to_owned())
    }
}

/// Turns a space-separated key body back into a PEM `PUBLIC KEY` document.
///
/// Every space in `key` becomes a newline, and the result is wrapped in the
/// `-----BEGIN PUBLIC KEY-----` / `-----END PUBLIC KEY-----` armor lines.
fn add_rsa_headers(key: &str) -> String {
    let body = key.replace(' ', "\n");
    format!("-----BEGIN PUBLIC KEY-----\n{body}\n-----END PUBLIC KEY-----")
}

/// Claims carried by a login-service token.
///
/// `sub` through `aud` are the registered JWT claims (RFC 7519); the remaining fields
/// are custom claims that are copied into a [`UserBase`] by the request guard.
#[derive(Debug, Deserialize, Serialize)]
pub struct Claims {
    /// Subject; the request guard parses it as the numeric user id.
    sub: String,
    /// Expiration time.
    exp: usize,
    /// Issued-at time.
    iat: usize,
    /// Issuer.
    iss: String,
    /// Not-before time.
    nbf: usize,
    /// Audience; checked against the `AUDIENCE` environment variable during decoding.
    aud: String,
    /// E-mail address, if the token carries one.
    email: Option<String>,
    /// Login name of the user.
    username: String,
    /// User language (NOTE(review): presumably a tag like "en-US" — confirm).
    lang: String,
    /// User country (NOTE(review): presumably a code like "US" — confirm).
    country: String,
    /// Granted scopes as a single string; the delimiter is not visible here.
    scope: String,
}

/// A decoded token, reduced to its claims.
#[derive(Debug)]
pub struct JWT {
    /// The token's decoded claims.
    pub claims: Claims,
}

/// Reasons the authentication request guard rejects a request.
///
/// The guard reports every variant together with `401 Unauthorized`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum LoginError {
    /// The request carried no `authorization` header.
    AuthHeaderMissing,
    /// The token has expired.
    TokenExpired,
    /// The token could not be decoded or validated for any other reason.
    InvalidToken,
}

impl std::fmt::Display for LoginError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let message = match self {
            LoginError::AuthHeaderMissing => "authorization header is missing",
            LoginError::TokenExpired => "token has expired",
            LoginError::InvalidToken => "token is invalid",
        };
        f.write_str(message)
    }
}

impl std::error::Error for LoginError {}

pub trait MyFromRequest<'r>: Sized {
    type Error;

    fn from_request(req: &'r Request<'_>) -> impl std::future::Future<Output = Outcome<Self, Self::Error>>;
}

impl<'r> MyFromRequest<'r> for UserBase {
    type Error = LoginError;

    async fn from_request(req: &'r Request<'_>) -> Outcome<Self, Self::Error> {
        if env::var("DISABLE_AUTH").is_ok() && env::var("DISABLE_AUTH").unwrap() == "true" {
            return Outcome::Success(UserBase {
                id: 1,
                username: String::from("testuser"),
                email: String::from("testuser@".to_owned() + &env::var("DOMAIN").unwrap()),
                lang: String::from("en-US"),
                country: String::from("US"),
                scopes: String::from(""),
            });
        }

        fn is_valid(key: &str) -> Result<Claims, jsonwebtoken::errors::Error> {
            Ok(decode_jwt(String::from(key))?)
        }

        match req.headers().get_one("authorization") {
            None => {
                return Outcome::Error((Status::Unauthorized, LoginError::AuthHeaderMissing))
            },
            Some(key) => match is_valid(key) {
                Ok(claims) => {
                    Outcome::Success(UserBase {
                        id: str::parse::<i32>(&claims.sub).unwrap(),
                        username: claims.username,
                        email: claims.email.unwrap_or_default(),
                        lang: claims.lang,
                        country: claims.country,
                        scopes: claims.scope,
                    })
                },
                Err(err) => match &err.kind() {
                    ErrorKind::ExpiredSignature => {
                        Outcome::Error((Status::Unauthorized, LoginError::TokenExpired))
                    },
                    _ => {
                        Outcome::Error((Status::Unauthorized, LoginError::InvalidToken))
                    }
                }
            },
        }
    }
}

/// Authenticated user, usable as a Rocket request guard.
///
/// A newtype around [`UserBase`]: Rocket's `FromRequest` cannot be implemented for the
/// foreign `UserBase` here, but it can be implemented for this local wrapper.
/// Dereferences to the wrapped [`UserBase`], so its fields are reachable directly.
pub struct User(UserBase);

impl Deref for User {
    type Target = UserBase;

    fn deref(&self) -> &UserBase {
        let User(base) = self;
        base
    }
}

impl DerefMut for User {
    fn deref_mut(&mut self) -> &mut UserBase {
        let User(base) = self;
        base
    }
}

#[rocket::async_trait]
impl<'r> FromRequest<'r> for User {
    type Error = LoginError;

    async fn from_request(req: &'r Request<'_>) -> Outcome<Self, Self::Error> {
        match UserBase::from_request(req).await {
            Outcome::Success(user) => Outcome::Success(User(user)),
            Outcome::Error((status, error)) => Outcome::Error((status, error)),
            Outcome::Forward(status) => Outcome::Forward(status),
        }
    }
}