hive-router 0.0.53

GraphQL router/gateway for Federation
use hive_router_config::jwt_auth::{JwksProviderSourceConfig, JwtAuthConfig};
use hive_router_internal::background_tasks::{BackgroundTask, BackgroundTasksManager};
use sonic_rs::from_str;
use std::sync::{Arc, RwLock};
use tokio::fs::read_to_string;
use tokio_util::sync::CancellationToken;
use tracing::{debug, error, info};

use jsonwebtoken::jwk::JwkSet;

pub struct JwksManager {
    sources: Vec<Arc<JwksSource>>,
}

impl JwksManager {
    pub fn from_config(config: &JwtAuthConfig) -> Self {
        let sources = config
            .jwks_providers
            .iter()
            .map(|config| Arc::new(JwksSource::new(config.clone())))
            .collect();

        JwksManager { sources }
    }

    pub fn all(&self) -> Vec<Arc<JwkSet>> {
        self.sources
            .iter()
            .filter_map(|v| match v.get_jwk_set() {
                Ok(set) => Some(set),
                Err(err) => {
                    error!("Failed to use JWK set: {}, ignoring", err);

                    None
                }
            })
            .collect()
    }

    pub async fn prefetch_sources(&self) -> Result<(), JwksSourceError> {
        for source in &self.sources {
            if source.should_prefetch() {
                match source.load_and_store_jwks().await {
                    Ok(_) => {}
                    Err(err) => return Err(err),
                }
            }
        }

        Ok(())
    }

    pub fn register_background_tasks(&self, background_tasks_mgr: &mut BackgroundTasksManager) {
        for source in &self.sources {
            if source.should_poll_in_background() {
                background_tasks_mgr.register_task(JwksSourceTask(source.clone()));
            }
        }
    }
}

#[derive(Debug)]
pub struct JwksSource {
    config: JwksProviderSourceConfig,
    jwk: RwLock<Option<Arc<JwkSet>>>,
}

struct JwksSourceTask(Arc<JwksSource>);

#[async_trait::async_trait]
impl BackgroundTask for JwksSourceTask {
    fn id(&self) -> &str {
        "jwt_auth_jwks"
    }

    async fn run(&self, token: CancellationToken) {
        if let JwksProviderSourceConfig::Remote {
            polling_interval: Some(interval),
            ..
        } = &self.0.config
        {
            debug!(
                "Starting remote jwks polling for source: {:?}",
                self.0.config
            );
            let mut tokio_interval = tokio::time::interval(*interval);

            loop {
                tokio::select! {
                    _ = tokio_interval.tick() => { match self.0.load_and_store_jwks().await {
                        Ok(_) => {}
                        Err(err) => {
                            error!("Failed to load remote jwks: {}", err);
                        }
                    } }
                    _ = token.cancelled() => { info!("Jwks source shutting down."); return; }
                }
            }
        }
    }
}

/// Errors produced while loading a JWKS source or reading its stored key set.
#[derive(thiserror::Error, Debug)]
pub enum JwksSourceError {
    /// The HTTP request for a remote JWKS failed.
    #[error("failed to load remote jwks: {0}")]
    RemoteJwksNetworkError(reqwest::Error),
    /// Reading a local JWKS file failed. Despite the name, this wraps a filesystem I/O error.
    #[error("failed to load file jwks: {0}")]
    FileJwksNetworkError(std::io::Error),
    /// The loaded content could not be parsed as a JWK set.
    /// Note: the message mentions a "file" even when the content came from a remote source.
    #[error("failed to parse jwks json file: {0}")]
    JwksContentInvalidStructure(sonic_rs::Error),
    /// No loaded key set could be read from the source, e.g. nothing has been loaded
    /// successfully yet.
    #[error("failed to acquire jwks handle")]
    FailedToAcquireJwk,
}

impl JwksSource {
    /// Returns the HTTP client used for all remote JWKS fetches, creating it on first use.
    ///
    /// `reqwest::Client` keeps an internal connection pool. Reusing one instance lets periodic
    /// polls reuse connections instead of building a fresh client and pool for every fetch.
    fn http_client() -> &'static reqwest::Client {
        static CLIENT: std::sync::OnceLock<reqwest::Client> = std::sync::OnceLock::new();
        CLIENT.get_or_init(reqwest::Client::new)
    }

    /// Loads the JWKS document for this source, parses it and makes it the current key set.
    ///
    /// Remote sources are fetched with an HTTP `GET` on the configured `url`. File sources are
    /// read from `file.absolute`. The stored key set is only replaced after the content has
    /// been read and parsed successfully, so on error the previous key set (if any) stays in use.
    ///
    /// # Errors
    ///
    /// - [`JwksSourceError::RemoteJwksNetworkError`] if the request fails, the server answers
    ///   with an error (4xx/5xx) status, or the response body cannot be read.
    /// - [`JwksSourceError::FileJwksNetworkError`] if the file cannot be read.
    /// - [`JwksSourceError::JwksContentInvalidStructure`] if the content is not a valid JWK set.
    async fn load_and_store_jwks(&self) -> Result<&Self, JwksSourceError> {
        let jwks_str = match &self.config {
            JwksProviderSourceConfig::Remote { url, .. } => {
                debug!("loading jwks from a remote source: {}", url);

                Self::http_client()
                    .get(url)
                    .send()
                    .await
                    // Report error statuses as network errors instead of trying to parse an
                    // error page as JSON.
                    .and_then(reqwest::Response::error_for_status)
                    .map_err(JwksSourceError::RemoteJwksNetworkError)?
                    .text()
                    .await
                    .map_err(JwksSourceError::RemoteJwksNetworkError)?
            }
            JwksProviderSourceConfig::File { file, .. } => {
                debug!("loading jwks from a file source: {}", file.absolute);

                read_to_string(&file.absolute)
                    .await
                    .map_err(JwksSourceError::FileJwksNetworkError)?
            }
        };

        let new_jwk =
            from_str::<JwkSet>(&jwks_str).map_err(JwksSourceError::JwksContentInvalidStructure)?;

        // The lock only ever guards whole-value `Option<Arc<_>>` assignments and clones, so a
        // poisoned lock cannot expose a torn value. Recover it rather than silently dropping the
        // freshly loaded key set.
        *self
            .jwk
            .write()
            .unwrap_or_else(std::sync::PoisonError::into_inner) = Some(Arc::new(new_jwk));

        Ok(self)
    }

    /// Creates a source for `config` with no key set loaded yet.
    pub fn new(config: JwksProviderSourceConfig) -> Self {
        Self {
            config,
            jwk: RwLock::new(None),
        }
    }

    /// Returns `true` only for remote sources that configure a `polling_interval`.
    ///
    /// Those are the only sources the background polling task actually refreshes. Registering
    /// a task for any other source would just spawn a task that returns immediately.
    pub fn should_poll_in_background(&self) -> bool {
        match &self.config {
            JwksProviderSourceConfig::Remote {
                polling_interval, ..
            } => polling_interval.is_some(),
            JwksProviderSourceConfig::File { .. } => false,
        }
    }

    /// Returns whether this source should be loaded eagerly by `JwksManager::prefetch_sources`.
    ///
    /// File sources are always prefetched. Remote sources are prefetched only when `prefetch`
    /// is explicitly `true`.
    ///
    /// NOTE(review): a remote source with neither `prefetch` nor `polling_interval` is never
    /// loaded by this module. Confirm that config validation rules this combination out.
    pub fn should_prefetch(&self) -> bool {
        match &self.config {
            JwksProviderSourceConfig::Remote { prefetch, .. } => prefetch.unwrap_or(false),
            JwksProviderSourceConfig::File { .. } => true,
        }
    }

    /// Returns the most recently loaded key set.
    ///
    /// A blocking read is used on purpose. Writers hold the lock only for a pointer swap, so
    /// waiting is negligible. A non-blocking attempt would instead fail whenever a refresh is
    /// storing a new set at the same moment, and this source's keys would then be skipped by
    /// callers such as `JwksManager::all`.
    ///
    /// # Errors
    ///
    /// Returns [`JwksSourceError::FailedToAcquireJwk`] if no key set has been loaded
    /// successfully yet.
    pub fn get_jwk_set(&self) -> Result<Arc<JwkSet>, JwksSourceError> {
        let guard = self
            .jwk
            .read()
            .unwrap_or_else(std::sync::PoisonError::into_inner);

        guard
            .as_ref()
            .cloned()
            .ok_or(JwksSourceError::FailedToAcquireJwk)
    }
}