use std::{
fmt::{self, Display},
future::{ready, Ready},
sync::Arc,
};
use actix_web::{
dev::{forward_ready, Service, ServiceRequest, ServiceResponse, Transform},
error, Error, HttpResponse,
};
use futures_util::future::LocalBoxFuture;
use jsonwebtoken::{decode, decode_header, DecodingKey, Validation};
use log::{debug, error, info, warn};
use reqwest::StatusCode;
use serde::{Deserialize, Serialize};
use tokio::sync::Mutex;
/// Actix-web middleware factory that guards wrapped services with JWT verification.
///
/// Each produced [`JwtMiddleware`] shares the same [`CertInvoker`] key cache.
pub struct Jwt {
    cert_invoker: Arc<CertInvoker>,
}

impl Jwt {
    /// Creates the factory from a shared certificate (JWKS) cache.
    pub fn from(cert_invoker: Arc<CertInvoker>) -> Self {
        Self { cert_invoker }
    }
}
impl<S, B> Transform<S, ServiceRequest> for Jwt
where
S: Service<ServiceRequest, Response = ServiceResponse<B>, Error = Error>,
S::Future: 'static,
B: 'static,
{
type Response = ServiceResponse<B>;
type Error = Error;
type Transform = JwtMiddleware<S>;
type InitError = ();
type Future = Ready<Result<Self::Transform, Self::InitError>>;
fn new_transform(&self, service: S) -> Self::Future {
ready(Ok(JwtMiddleware {
service,
cert_invoker: Arc::clone(&self.cert_invoker),
}))
}
}
/// Per-service middleware produced by [`Jwt`]; the wrapped service's response is only awaited
/// when the request carries a JWT that verifies against the cached keys.
pub struct JwtMiddleware<S> {
    /// The wrapped (inner) service.
    service: S,
    /// Shared cache of JWKS signing keys.
    cert_invoker: Arc<CertInvoker>,
}
/// Authorization scheme prefix expected in front of the token (including the trailing space).
const BEARER: &str = "Bearer ";
impl<S, B> Service<ServiceRequest> for JwtMiddleware<S>
where
S: Service<ServiceRequest, Response = ServiceResponse<B>, Error = actix_web::error::Error>,
S::Future: 'static,
B: 'static,
{
type Response = ServiceResponse<B>;
type Error = Error;
type Future = LocalBoxFuture<'static, Result<Self::Response, Self::Error>>;
forward_ready!(service);
fn call(&self, req: ServiceRequest) -> Self::Future {
let headers = req.headers();
let jwt_token = headers
.iter()
.filter(|(header, _)| header.as_str() == "authorization")
.map(|(_, value)| String::from(value.to_str().unwrap()))
.collect::<Vec<String>>();
let fut = self.service.call(req);
let cert = Arc::clone(&self.cert_invoker.cert);
Box::pin(async move {
if jwt_token.is_empty() {
warn!("Missing JWT token");
let x = actix_web::error::Error::from(JWTResponseError::missing_jwt());
return Err(x);
}
let jwt_token = jwt_token.join("");
if !jwt_token.starts_with(BEARER) {
warn!("JWT is not started with Bearer");
let x = actix_web::error::Error::from(JWTResponseError::invalid_jwt());
return Err(x);
}
let jwt_token = jwt_token.replace(BEARER, "");
let jwt_header = decode_header(&jwt_token);
if jwt_header.is_err() {
warn!("JWT header is invalid");
let x = actix_web::error::Error::from(JWTResponseError::invalid_jwt());
return Err(x);
}
let jwt_header = jwt_header.unwrap();
let kid = jwt_header.kid.unwrap();
let jwt_cert = cert.lock().await;
let cert = jwt_cert.clone().unwrap(); let key = cert.keys.iter().find(|key| key.kid == kid).unwrap();
let de_key = DecodingKey::from_rsa_components(key.n.as_str(), key.e.as_str()).unwrap();
let token = decode::<Claims>(&jwt_token, &de_key, &Validation::new(jwt_header.alg));
match token {
Ok(_) => Ok(fut.await?),
Err(err) => {
match err.kind() {
jsonwebtoken::errors::ErrorKind::InvalidSignature => return Err(invalid_invalid_signature()),
jsonwebtoken::errors::ErrorKind::ExpiredSignature => return Err(expired_jwt()),
jsonwebtoken::errors::ErrorKind::InvalidIssuer => return Err(invalid_invalid_iss()),
_ => {
warn!("JWT is invalid {:?}", err);
return Err(invalid_jwt())
},
}
}
}
})
}
}
fn invalid_jwt() -> actix_web::Error {
return actix_web::error::Error::from(
JWTResponseError::invalid_jwt(),
)
}
fn expired_jwt() -> actix_web::Error {
return actix_web::error::Error::from(
JWTResponseError::expired_jwt(),
)
}
fn invalid_invalid_signature() -> actix_web::Error {
return actix_web::error::Error::from(
JWTResponseError::invalid_invalid_signature(),
)
}
fn invalid_invalid_iss() -> actix_web::Error {
return actix_web::error::Error::from(
JWTResponseError::invalid_invalid_issr(),
)
}
async fn get_cert(cert_url: &String) -> Result<Cert, JWKSError> {
debug!("Getting cert");
let response = reqwest::get(cert_url).await;
if response.is_err() {
warn!("Error while getting cert");
return Err(JWKSError::InvokingCertUrl(
"Error while getting cert".to_string(),
));
}
let cert: Result<CertResponse, reqwest::Error> = response.unwrap().json().await;
if cert.is_err() {
warn!("Error while deserialize cert");
return Err(JWKSError::ErrorDeserializingCert(format!(
"Error while deserialize cert {:?}",
cert.err().unwrap()
)));
}
let keys = cert.unwrap().keys.iter().map(|key| {
let de_key = DecodingKey::from_rsa_components(key.n.as_str(), key.e.as_str()).unwrap();
Key::from(key.clone(), de_key)
}).collect();
Ok(Cert{keys})
}
/// Fetches the JWKS from a fixed URL and caches the resulting key set for [`JwtMiddleware`].
pub struct CertInvoker {
    /// Most recently fetched key set; `None` until the first successful fetch.
    cert: Arc<Mutex<Option<Cert>>>,
    /// URL of the JWKS document.
    cert_url: String,
}

impl CertInvoker {
    /// Creates an invoker for `cert_url` with an empty cache; call [`CertInvoker::get_cert`]
    /// to populate it.
    pub fn from(cert_url: String) -> Self {
        CertInvoker {
            cert: Arc::new(Mutex::new(None)),
            cert_url,
        }
    }

    /// Downloads the JWKS from `cert_url` and, on success, replaces the cached key set.
    ///
    /// On failure the error is logged and the previously cached key set (if any) is kept, so a
    /// transient outage of the key endpoint does not reject every request.
    pub async fn get_cert(&self) {
        info!("Getting cert from {}", self.cert_url);
        // Resolves to the module-level `get_cert` function, not this method.
        match get_cert(&self.cert_url).await {
            Ok(cert) => {
                // Lock only after the network round-trip, so readers are not blocked during it.
                *self.cert.lock().await = Some(cert);
            }
            Err(er) => {
                error!("Error while getting cert {:?}", er);
            }
        }
    }
}
/// Reasons fetching the JWKS document can fail; each variant carries a human-readable detail.
#[derive(Debug)]
pub enum JWKSError {
    /// The HTTP request to the certificate URL did not succeed.
    InvokingCertUrl(String),
    /// The response could not be turned into a usable key set.
    ErrorDeserializingCert(String),
}
/// Authentication failure reported to the client as an HTTP error with a JSON message body.
#[derive(Debug)]
struct JWTResponseError {
    status_code: StatusCode,
    message: String,
}

impl JWTResponseError {
    /// Shared constructor: every JWT failure produced here uses 401 Unauthorized.
    fn unauthorized(message: &str) -> Self {
        JWTResponseError {
            status_code: StatusCode::UNAUTHORIZED,
            message: message.to_string(),
        }
    }

    /// Token is malformed or otherwise could not be verified.
    pub fn invalid_jwt() -> Self {
        Self::unauthorized("Invalid JWT")
    }

    /// Token's expiry has passed.
    pub fn expired_jwt() -> Self {
        Self::unauthorized("Expired JWT")
    }

    /// Token signature does not verify against the signing key.
    pub fn invalid_invalid_signature() -> Self {
        Self::unauthorized("Invalid Signature")
    }

    /// Token was rejected because of its issuer.
    pub fn invalid_invalid_issr() -> Self {
        Self::unauthorized("Invalid Issuer")
    }

    /// Request carried no Authorization header.
    pub fn missing_jwt() -> Self {
        Self::unauthorized("Missing JWT")
    }
}
impl error::ResponseError for JWTResponseError {
    /// Reports the status code stored on the error.
    fn status_code(&self) -> StatusCode {
        self.status_code
    }

    /// Renders the error as a JSON body of the form `{"message": "..."}` with the stored status.
    fn error_response(&self) -> HttpResponse {
        let body = JWTResponse {
            message: self.message.clone(),
        };
        HttpResponse::build(self.status_code()).json(body)
    }
}
impl Display for JWTResponseError {
    /// Displays the error using its `Debug` rendering (status code and message).
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> Result<(), fmt::Error> {
        let rendered = format!("{:?}", self);
        f.write_str(&rendered)
    }
}
/// JSON body sent to clients when JWT authentication fails.
#[derive(Serialize)]
pub struct JWTResponse {
    /// Human-readable reason for the rejection.
    pub message: String,
}
/// JWKS document as returned by the certificate endpoint (`{"keys": [...]}`).
#[derive(Debug, Clone, Serialize, Deserialize)]
struct CertResponse {
    pub keys: Vec<KeyResponse>,
}

/// One JSON Web Key entry from the JWKS document.
///
/// Only `x5c`, `x5t` and `x5t#S256` are optional; in particular the RSA parameters `n` and `e`
/// and the `alg` member must be present for the document to deserialize.
#[derive(Debug, Clone, Serialize, Deserialize)]
struct KeyResponse {
    /// Key id, compared with the `kid` of incoming token headers.
    pub kid: String,
    pub kty: String,
    /// The JWK `use` member (renamed because `use` is a Rust keyword).
    #[serde(rename = "use")]
    pub use_key: String,
    pub n: String,
    pub e: String,
    pub x5c: Option<Vec<String>>,
    pub x5t: Option<String>,
    #[serde(rename = "x5t#S256")]
    pub x5t_s256: Option<String>,
    pub alg: String,
}
pub struct Cert {
pub keys: Vec<Key>,
}
pub struct Key {
pub kid: String,
pub kty: String,
pub use_key: String,
pub n: String,
pub e: String,
pub x5c: Option<Vec<String>>,
pub x5t: Option<String>,
pub x5t_s256: Option<String>,
pub alg: String,
pub de_key: DecodingKey,
}
impl Clone for Cert {
fn clone(&self) -> Self {
Cert {
keys: self.keys.clone(),
}
}
}
impl Clone for Key {
fn clone(&self) -> Self {
Key {
kid: self.kid.clone(),
kty: self.kty.clone(),
use_key: self.use_key.clone(),
n: self.n.clone(),
e: self.e.clone(),
x5c: self.x5c.clone(),
x5t: self.x5t.clone(),
x5t_s256: self.x5t_s256.clone(),
alg: self.alg.clone(),
de_key: self.de_key.clone(),
}
}
}
impl Key {
fn from(key_response: KeyResponse, de_key: DecodingKey) -> Self {
Key { kid: key_response.kid, kty: key_response.kty, use_key: key_response.use_key,
n: key_response.n, e: key_response.e, x5c: key_response.x5c, x5t: key_response.x5t, x5t_s256: key_response.x5t_s256,
alg: key_response.alg, de_key }
}
}
/// Claims deserialized from a token's payload by the JWT middleware.
///
/// `iss`, `sub`, `exp` and `iat` are non-optional, so a payload missing any of them fails to
/// deserialize.
/// NOTE(review): `aud` is modelled as a single string; RFC 7519 also allows an array of
/// strings, which would not deserialize into this type — confirm the issuers in use never
/// send one.
#[derive(Debug, Serialize, Deserialize)]
pub struct Claims {
    pub iss: String,
    pub sub: String,
    pub aud: Option<String>,
    pub exp: usize,
    pub nbf: Option<usize>,
    pub iat: usize,
    pub jti: Option<String>,
    pub azp: Option<String>,
    pub scope: Option<String>,
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Live JWKS endpoint used as a fixture; these tests need network access.
    const GOOGLE_CERTS_URL: &str = "https://www.googleapis.com/oauth2/v3/certs";

    #[tokio::test]
    async fn get_cert_test() {
        let url = GOOGLE_CERTS_URL.to_string();
        let outcome = get_cert(&url).await;
        assert!(outcome.is_ok());
    }

    #[tokio::test]
    async fn get_cert_wrong_url_test() {
        // Same host with a junk character appended to the path: no usable key set expected.
        let url = format!("{}x", GOOGLE_CERTS_URL);
        let outcome = get_cert(&url).await;
        assert!(outcome.is_err());
    }
}