// raphtory-graphql 0.11.1
//
// Raphtory GraphQL server: Azure auth token middleware (source listing from the crate
// documentation).
use crate::azure_auth::common::{decode_base64_urlsafe, AppState};
use jsonwebtoken::{decode, decode_header, Algorithm, DecodingKey, Validation};
use poem::{
    http::StatusCode, web::Redirect, Endpoint, Error, IntoResponse, Middleware, Request, Response,
    Result,
};
use std::{collections::HashMap, sync::Arc};

/// Poem middleware that only lets requests through when they carry valid auth cookies.
///
/// It owns a shared handle to [`AppState`] and passes a clone of that `Arc` to every
/// endpoint it wraps. The wrapped endpoint uses the state's `jwks` keys to verify access
/// tokens.
#[derive(Clone)]
pub struct TokenMiddleware {
    /// Shared application state; its `jwks.keys` are consulted for signature verification.
    app_state: Arc<AppState>,
}

impl TokenMiddleware {
    /// Creates the middleware around the given shared application state.
    // NOTE(review): presumably only constructed when Azure auth is enabled, hence the
    // dead_code allowance — confirm against the server setup code.
    #[allow(dead_code)]
    pub fn new(app_state: Arc<AppState>) -> Self {
        Self { app_state }
    }
}

impl<E: Endpoint> Middleware<E> for TokenMiddleware {
    type Output = TokenMiddlewareImpl<E>;

    fn transform(&self, ep: E) -> Self::Output {
        TokenMiddlewareImpl {
            ep,
            app_state: self.app_state.clone(),
        }
    }
}

/// Endpoint produced by [`TokenMiddleware`]: checks auth cookies, then delegates to `ep`.
pub struct TokenMiddlewareImpl<E> {
    /// The endpoint that handles requests which pass the token check.
    ep: E,
    /// Shared application state holding the JWKS used to verify access tokens.
    app_state: Arc<AppState>,
}

/// A verified access token, attached to the request extensions before the inner endpoint runs.
///
/// Nothing in this file reads the inner string back; downstream handlers presumably extract
/// it from the request extensions. Verify this before removing the `dead_code` allowance.
#[derive(Clone)]
#[allow(dead_code)]
struct Token(String);

impl<E: Endpoint> Endpoint for TokenMiddlewareImpl<E> {
    type Output = Response;

    async fn call(&self, mut req: Request) -> Result<Self::Output> {
        let jar = req.cookie().clone();
        if let Some(_session_cookie) = jar.get("session_id") {
            if let Some(auth_cookie) = jar.get("auth_token") {
                let token_data: serde_json::Value = serde_json::from_str(
                    &auth_cookie
                        .value::<String>()
                        .expect("Unable to find cookie"),
                )
                .map_err(|_| Error::from_status(StatusCode::UNAUTHORIZED))?;
                let access_token = token_data["access_token_secret"]
                    .as_str()
                    .ok_or_else(|| Error::from_status(StatusCode::UNAUTHORIZED))?;
                let expires_at_str = token_data["expires_at"]
                    .as_str()
                    .ok_or_else(|| Error::from_status(StatusCode::UNAUTHORIZED))?;
                let expires_at = chrono::DateTime::parse_from_rfc3339(expires_at_str)
                    .map_err(|_| Error::from_status(StatusCode::UNAUTHORIZED))?;
                if chrono::Utc::now() > expires_at {
                    return Err(Error::from_status(StatusCode::UNAUTHORIZED));
                }

                let header = decode_header(access_token)
                    .map_err(|_| Error::from_status(StatusCode::UNAUTHORIZED))?;
                let kid = header
                    .kid
                    .ok_or_else(|| Error::from_status(StatusCode::UNAUTHORIZED))?;

                let jwk = self
                    .app_state
                    .jwks
                    .keys
                    .iter()
                    .find(|&jwk| jwk.kid == kid)
                    .ok_or_else(|| Error::from_status(StatusCode::UNAUTHORIZED))?;

                let n = decode_base64_urlsafe(&jwk.n)
                    .map_err(|_| Error::from_status(StatusCode::UNAUTHORIZED))?;
                let e = decode_base64_urlsafe(&jwk.e)
                    .map_err(|_| Error::from_status(StatusCode::UNAUTHORIZED))?;

                let decoding_key = DecodingKey::from_rsa_raw_components(&n, &e);

                let validation = Validation::new(Algorithm::RS256);
                decode::<HashMap<String, serde_json::Value>>(
                    access_token,
                    &decoding_key,
                    &validation,
                )
                .map_err(|_| Error::from_status(StatusCode::UNAUTHORIZED))?;

                req.extensions_mut().insert(Token(access_token.to_string()));

                return self.ep.call(req).await.map(IntoResponse::into_response);
            }
            Ok(Redirect::temporary("/login").into_response())
        } else {
            Ok(Redirect::temporary("/login").into_response())
        }
    }
}