raphtory-graphql 0.11.1

Raphtory GraphQL server
Documentation
use base64_compat::{decode_config, URL_SAFE_NO_PAD};
use chrono::{Duration, Utc};
use jsonwebtoken::{decode, decode_header, Algorithm, DecodingKey, Validation};
use oauth2::{
    basic::BasicClient, reqwest::async_http_client, AuthorizationCode, CsrfToken,
    PkceCodeChallenge, PkceCodeVerifier, Scope, TokenResponse,
};
use poem::{
    handler,
    http::StatusCode,
    web::{
        cookie::{Cookie, CookieJar},
        Data, Json, Query, Redirect,
    },
    IntoResponse, Response,
};
use reqwest::Client;
use serde::{Deserialize, Serialize};
use serde_json::json;
use std::{
    collections::HashMap,
    env,
    error::Error,
    sync::{Arc, Mutex},
};

/// Query-string parameters accepted by the `auth_callback` handler.
#[derive(Deserialize, Serialize)]
struct AuthRequest {
    /// Authorization code; `auth_callback` wraps it in an `AuthorizationCode`
    /// and exchanges it for a token.
    code: String,
    /// The `state` query parameter of the callback request.
    state: String,
}

/// A set of JSON Web Keys, as deserialized by `get_jwks` from the
/// authority's key-discovery response.
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct Jwks {
    /// The individual keys; `verify` searches them by key id.
    pub(crate) keys: Vec<Jwk>,
}

/// A single key entry of a [`Jwks`] document.
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct Jwk {
    /// Key id; compared against the `kid` of a token header in `verify`.
    pub(crate) kid: String,
    /// Serialized as `kty`. Not read anywhere in this file.
    kty: String,
    /// Serialized as `use` (renamed because `use` is a Rust keyword).
    #[serde(rename = "use")]
    use_: String,
    /// URL-safe, unpadded base64 text; `verify` decodes it and passes the bytes
    /// as the first (modulus) argument of `DecodingKey::from_rsa_raw_components`.
    pub(crate) n: String,
    /// URL-safe, unpadded base64 text; `verify` decodes it and passes the bytes
    /// as the second (exponent) argument of `DecodingKey::from_rsa_raw_components`.
    pub(crate) e: String,
}

/// Shared state handed to the authentication handlers through `Data<&AppState>`.
#[derive(Clone)]
pub struct AppState {
    /// OAuth2 client used to build the authorize URL and to exchange codes.
    pub(crate) oauth_client: Arc<BasicClient>,
    /// CSRF tokens issued by `login`, keyed by the session id stored in the
    /// `session_id` cookie.
    pub(crate) csrf_state: Arc<Mutex<HashMap<String, CsrfToken>>>,
    /// PKCE verifiers issued by `login`, keyed by the same session id;
    /// `auth_callback` removes the entry when it exchanges the code.
    pub(crate) pkce_verifier: Arc<Mutex<HashMap<String, PkceCodeVerifier>>>,
    /// Key set that `verify` uses to look up the signing key of a token.
    pub(crate) jwks: Arc<Jwks>,
}

/// Starts the login flow.
///
/// Generates a fresh session id, a PKCE challenge/verifier pair and a CSRF
/// token, stores the token and the verifier in [`AppState`] under the session
/// id, sets an HTTP-only `session_id` cookie and redirects the browser to the
/// provider's authorize URL.
///
/// # Panics
///
/// Panics if the `CLIENT_ID` environment variable is not set or if one of the
/// state mutexes is poisoned.
#[handler]
pub async fn login(data: Data<&AppState>, jar: &CookieJar) -> Redirect {
    let session_id = uuid::Uuid::new_v4().to_string();
    let (challenge, verifier) = PkceCodeChallenge::new_random_sha256();

    // The API scope is derived from the configured client id.
    let client_id = env::var("CLIENT_ID").expect("CLIENT_ID not set");
    let api_scope = Scope::new(format!("{}/public-scope", client_id));

    let auth_request = data
        .oauth_client
        .authorize_url(CsrfToken::new_random)
        .set_pkce_challenge(challenge)
        .add_scope(Scope::new("openid".to_string()))
        .add_scope(Scope::new("email".to_string()))
        .add_scope(Scope::new("offline_access".to_string()))
        .add_scope(api_scope);
    let (target, csrf_token) = auth_request.url();

    // Each guard is a temporary, so the lock is released at the end of its
    // statement.
    data.csrf_state
        .lock()
        .unwrap()
        .insert(session_id.clone(), csrf_token);
    data.pkce_verifier
        .lock()
        .unwrap()
        .insert(session_id.clone(), verifier);

    let mut cookie = Cookie::new("session_id", session_id);
    cookie.set_path("/");
    cookie.set_http_only(true);
    jar.add(cookie);

    Redirect::temporary(target.as_str())
}

#[handler]
pub async fn auth_callback(
    data: Data<&AppState>,
    query: Query<AuthRequest>,
    jar: &CookieJar,
) -> impl IntoResponse {
    if let Some(session_cookie) = jar.get("session_id") {
        let session_id = session_cookie.value::<String>().unwrap();

        let code = AuthorizationCode::new(query.0.code.clone());
        let pkce_verifier = data
            .pkce_verifier
            .lock()
            .unwrap()
            .remove(&session_id)
            .unwrap();

        let token_result = data
            .oauth_client
            .exchange_code(code)
            .set_pkce_verifier(pkce_verifier)
            .request_async(async_http_client)
            .await;

        match token_result {
            Ok(token) => {
                let access_token = token.access_token();
                let expires_in = token
                    .expires_in()
                    .unwrap_or(core::time::Duration::from_secs(60 * 60 * 24));
                let expiration = Utc::now() + Duration::from_std(expires_in).unwrap();

                let token_data = json!({
                    "access_token_secret": access_token.secret(),
                    "expires_at": expiration.to_rfc3339()
                });

                let mut auth_cookie = Cookie::new("auth_token", token_data.to_string());
                auth_cookie.set_expires(expiration);
                auth_cookie.set_path("/");
                auth_cookie.set_http_only(true);
                jar.add(auth_cookie);
                return Redirect::temporary("/").into_response();
            }
            Err(_err) => {
                return Response::builder()
                    .status(StatusCode::UNAUTHORIZED)
                    .content_type("application/json")
                    .body(json!({"error": "Login failed"}).to_string());
            }
        }
    } else {
        return Response::builder()
            .status(StatusCode::UNAUTHORIZED)
            .content_type("application/json")
            .body(json!({"error": "Session not found. Please login again"}).to_string());
    }
}

/// Decodes `base64_str` using the URL-safe alphabet without padding.
///
/// # Errors
///
/// Returns the decoder's error, boxed, when the input is not valid for that
/// configuration.
pub fn decode_base64_urlsafe(base64_str: &str) -> Result<Vec<u8>, Box<dyn Error>> {
    decode_config(base64_str, URL_SAFE_NO_PAD).map_err(|err| err.into())
}

#[handler]
pub async fn secure_endpoint() -> Json<serde_json::Value> {
    Json(serde_json::json!({
        "message": "Secured"
    }))
}

#[handler]
pub async fn logout(jar: &CookieJar) -> String {
    if let Some(mut cookie) = jar.get("auth_token") {
        cookie.set_expires(Utc::now() - Duration::days(1));
        jar.remove("auth_token");
    }
    if let Some(mut cookie) = jar.get("session_id") {
        cookie.set_expires(Utc::now() - Duration::days(1));
        jar.remove("session_id");
    }
    "You have been logged out.".to_string()
}

#[handler]
pub async fn verify(data: Data<&AppState>, jar: &CookieJar) -> Json<serde_json::Value> {
    if let Some(_session_cookie) = jar.get("session_id") {
        if let Some(cookie) = jar.get("auth_token") {
            let cookie_value = cookie.value::<String>().expect("Unable to find cookie");
            let token_data: serde_json::Value =
                serde_json::from_str(&cookie_value).expect("Invalid cookie format");
            let token = token_data["access_token_secret"]
                .as_str()
                .expect("No access token found");
            let expires_at_str = token_data["expires_at"]
                .as_str()
                .expect("No expiration time found");

            let expires_at = chrono::DateTime::parse_from_rfc3339(expires_at_str)
                .expect("Invalid expiration format");

            if Utc::now() > expires_at {
                return Json(serde_json::json!({
                    "message": "Access token expired, please login again"
                }));
            }

            let header = decode_header(token).expect("Unable to decode header");
            let kid = header.kid.expect("Token header does not have a kid field");

            let jwk = data
                .jwks
                .keys
                .iter()
                .find(|&jwk| jwk.kid == kid)
                .expect("Key ID not found in JWKS");

            let n = decode_base64_urlsafe(&jwk.n).unwrap();
            let e = decode_base64_urlsafe(&jwk.e).unwrap();

            let decoding_key = DecodingKey::from_rsa_raw_components(&n, &e);

            let validation = Validation::new(Algorithm::RS256);

            let token_data =
                decode::<HashMap<String, serde_json::Value>>(token, &decoding_key, &validation);

            match token_data {
                Ok(_dc) => Json(serde_json::json!({
                    "message": "Valid access token",
                })),
                Err(_err) => Json(serde_json::json!({
                    "message": "No valid auth token found",
                })),
            }
        } else {
            Json(serde_json::json!({
                "message": "No cookie auth_token found, please login"
            }))
        }
    } else {
        Json(serde_json::json!({
            "message": "No session_id found, please login"
        }))
    }
}

/// Downloads the JSON Web Key Set from `{AUTHORITY}/discovery/v2.0/keys`,
/// where `AUTHORITY` is read from the environment.
///
/// # Errors
///
/// Returns an error when `AUTHORITY` is not set, when the request fails, when
/// the server answers with a non-success status, or when the body is not a
/// valid [`Jwks`] document.
// Kept from the original; presumably silences the lint when nothing in this
// crate calls the function.
#[allow(dead_code)]
pub async fn get_jwks() -> Result<Jwks, Box<dyn Error>> {
    // The function already returns `Result`, so report a missing variable as
    // an error instead of panicking.
    let authority = env::var("AUTHORITY").map_err(|_| "AUTHORITY not set")?;
    // Tolerate a trailing slash so the URL never contains `//discovery`.
    let jwks_url = format!(
        "{}/discovery/v2.0/keys",
        authority.trim_end_matches('/')
    );

    let response = Client::new()
        .get(&jwks_url)
        .send()
        .await?
        // Surface HTTP errors directly rather than as a confusing JSON
        // decoding failure on an error page.
        .error_for_status()?;

    let jwks = response.json::<Jwks>().await?;
    Ok(jwks)
}