use std::collections::HashMap;
use std::env;
use netdb_auth_shared::UserBase;
use rocket::http::Status;
use serde::{Deserialize, Serialize};
use tokio::task::block_in_place;
use once_cell::sync::Lazy;
use jsonwebtoken::{decode, Algorithm, DecodingKey, Validation};
use jsonwebtoken::errors::ErrorKind;
use rocket::request::{FromRequest, Outcome};
use rocket::Request;
use std::ops::Deref;
use std::ops::DerefMut;
pub use netdb_auth_macro_derive::has_scope;
pub use netdb_auth_shared;
pub static JWT_KEYS: Lazy<Keys> = Lazy::new(|| get_jwt_keys());
pub static ADMIN_KEY_IDENTIFIER: &str = "admin";
/// Fetches the JWT verification keys from the login service.
///
/// The keys are downloaded from `https://api.login.<DOMAIN>/keys`, where `DOMAIN` is
/// read from the environment, and deserialized into [`Keys`].
///
/// # Panics
///
/// Panics if:
/// - `DOMAIN` is not set;
/// - the HTTP request fails or returns a non-success status;
/// - the body cannot be deserialized into [`Keys`].
///
/// It also panics if called from a tokio current-thread runtime, because it uses
/// `tokio::task::block_in_place`.
pub fn get_jwt_keys() -> Keys {
    let domain = env::var("DOMAIN").expect("DOMAIN env var was not set");
    let url = format!("https://api.login.{domain}/keys");

    // `reqwest::blocking` must not block an async worker thread unannounced;
    // `block_in_place` lets tokio move other tasks off this thread first.
    block_in_place(move || {
        reqwest::blocking::get(url.as_str())
            // Surface HTTP errors as such instead of as a confusing JSON decode failure.
            .and_then(|response| response.error_for_status())
            .and_then(|response| response.json::<Keys>())
            .unwrap_or_else(|err| panic!("failed to fetch JWT keys from {url}: {err}"))
    })
}
/// Public keys used to verify JWT signatures, as served by the login service.
///
/// On the wire the fields are named `adminKey` and `userKeys`. Keys are stored without
/// PEM armor; `add_rsa_headers` wraps one into a PEM document before it is used
/// (presumably the served value is base64 lines joined by spaces — confirm with the
/// login service).
#[derive(Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct Keys {
    /// Key used for tokens whose `kid` header equals [`ADMIN_KEY_IDENTIFIER`].
    admin_key: String,
    /// Keys for all other tokens, looked up by the token's `kid` header.
    user_keys: HashMap<String, String>,
}
fn decode_jwt(token: String) -> Result<Claims, ErrorKind> {
let token = token.trim_start_matches("Bearer").trim();
let header = jsonwebtoken::decode_header(&token).unwrap();
let kid = header.kid.unwrap();
let mut secret: String;
if kid == ADMIN_KEY_IDENTIFIER {
secret = JWT_KEYS.admin_key.clone();
} else {
secret = JWT_KEYS.user_keys.get(&kid).unwrap().clone();
}
secret = add_rsa_headers(&secret);
let mut validation = Validation::new(Algorithm::RS256);
validation.set_audience(&[env::var("AUDIENCE").unwrap()]);
match decode::<Claims>(
&token,
&DecodingKey::from_rsa_pem(secret.as_bytes()).unwrap(),
&validation,
) {
Ok(token) => Ok(token.claims),
Err(err) => Err(err.kind().to_owned())
}
}
/// Wraps a bare public key in PEM `PUBLIC KEY` armor.
///
/// Every space in `key` becomes a line break, then the result is placed between the
/// PEM header and footer lines. No trailing newline is appended.
fn add_rsa_headers(key: &str) -> String {
    const PEM_HEADER: &str = "-----BEGIN PUBLIC KEY-----";
    const PEM_FOOTER: &str = "-----END PUBLIC KEY-----";

    let body = key.replace(' ', "\n");
    format!("{PEM_HEADER}\n{body}\n{PEM_FOOTER}")
}
/// Claims carried by an access token, as deserialized by `decode_jwt`.
///
/// Field names are used as-is for the JSON keys of the token payload (no serde renames).
#[derive(Debug, Serialize, Deserialize)]
pub struct Claims {
/// Subject. Parsed as an `i32` and used as the user id (`UserBase::id`).
sub: String,
/// RFC 7519 `exp` (expiration time) registered claim.
exp: usize,
/// RFC 7519 `iat` (issued-at) registered claim.
iat: usize,
/// RFC 7519 `iss` (issuer) registered claim.
/// NOTE(review): `decode_jwt` configures no expected issuer, so this is not checked — confirm that is intended.
iss: String,
/// RFC 7519 `nbf` (not-before) registered claim.
nbf: usize,
/// Audience; `decode_jwt` requires it to match the `AUDIENCE` environment variable.
/// NOTE(review): typed as `String`, so a token whose `aud` is a JSON array fails to deserialize — confirm the issuer always emits a single string.
aud: String,
/// Optional e-mail address; becomes an empty string in `UserBase` when absent.
email: Option<String>,
/// Copied into `UserBase::username`.
username: String,
/// Copied into `UserBase::lang`.
lang: String,
/// Copied into `UserBase::country`.
country: String,
/// Copied into `UserBase::scopes` (presumably a space-separated scope list — confirm).
scope: String,
}
/// Wrapper holding the decoded [`Claims`] of a token.
///
/// NOTE(review): nothing in this file constructs `JWT`; confirm it is still used elsewhere.
#[derive(Debug)]
pub struct JWT {
pub claims: Claims
}
/// Reasons a request can fail authentication.
///
/// The request guards in this file report every variant with `401 Unauthorized`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum LoginError {
    /// The request carried no `Authorization` header.
    AuthHeaderMissing,
    /// The token was rejected because it has expired.
    TokenExpired,
    /// The token was rejected for any other reason (malformed, bad signature,
    /// unknown key id, wrong audience, ...).
    InvalidToken,
}

impl std::fmt::Display for LoginError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let message = match self {
            LoginError::AuthHeaderMissing => "missing Authorization header",
            LoginError::TokenExpired => "token has expired",
            LoginError::InvalidToken => "invalid token",
        };
        f.write_str(message)
    }
}

impl std::error::Error for LoginError {}
/// Async request guard with the same shape as Rocket's [`FromRequest`].
///
/// Presumably this exists because `UserBase` is defined in `netdb_auth_shared`, and the
/// orphan rule forbids implementing Rocket's foreign `FromRequest` trait for it in this
/// crate. The local `User` newtype implements `FromRequest` by delegating to this trait.
pub trait MyFromRequest<'r>: Sized {
/// Error reported alongside the HTTP status when the guard fails.
type Error;
/// Inspects `req` and resolves to success, an error (status + `Self::Error`), or a forward.
///
/// NOTE(review): the returned future is not declared `Send`; callers inside
/// `#[rocket::async_trait]` rely on the concrete impl's future being `Send`.
fn from_request(req: &'r Request<'_>) -> impl std::future::Future<Output = Outcome<Self, Self::Error>>;
}
impl<'r> MyFromRequest<'r> for UserBase {
type Error = LoginError;
async fn from_request(req: &'r Request<'_>) -> Outcome<Self, Self::Error> {
if env::var("DISABLE_AUTH").is_ok() && env::var("DISABLE_AUTH").unwrap() == "true" {
return Outcome::Success(UserBase {
id: 1,
username: String::from("testuser"),
email: String::from("testuser@".to_owned() + &env::var("DOMAIN").unwrap()),
lang: String::from("en-US"),
country: String::from("US"),
scopes: String::from(""),
});
}
fn is_valid(key: &str) -> Result<Claims, jsonwebtoken::errors::Error> {
Ok(decode_jwt(String::from(key))?)
}
match req.headers().get_one("authorization") {
None => {
return Outcome::Error((Status::Unauthorized, LoginError::AuthHeaderMissing))
},
Some(key) => match is_valid(key) {
Ok(claims) => {
Outcome::Success(UserBase {
id: str::parse::<i32>(&claims.sub).unwrap(),
username: claims.username,
email: claims.email.unwrap_or_default(),
lang: claims.lang,
country: claims.country,
scopes: claims.scope,
})
},
Err(err) => match &err.kind() {
ErrorKind::ExpiredSignature => {
Outcome::Error((Status::Unauthorized, LoginError::TokenExpired))
},
_ => {
Outcome::Error((Status::Unauthorized, LoginError::InvalidToken))
}
}
},
}
}
}
/// Authenticated user, usable as a Rocket request guard.
///
/// A thin newtype over [`UserBase`] (presumably so this crate can implement Rocket's
/// `FromRequest` for it). It dereferences to the wrapped `UserBase`, so its fields
/// can be read and modified directly.
pub struct User(UserBase);

impl Deref for User {
    type Target = UserBase;

    fn deref(&self) -> &UserBase {
        let User(inner) = self;
        inner
    }
}

impl DerefMut for User {
    fn deref_mut(&mut self) -> &mut UserBase {
        let User(inner) = self;
        inner
    }
}
#[rocket::async_trait]
impl<'r> FromRequest<'r> for User {
type Error = LoginError;
async fn from_request(req: &'r Request<'_>) -> Outcome<Self, Self::Error> {
match UserBase::from_request(req).await {
Outcome::Success(user) => Outcome::Success(User(user)),
Outcome::Error((status, error)) => Outcome::Error((status, error)),
Outcome::Forward(status) => Outcome::Forward(status),
}
}
}