tower-oauth2-resource-server 0.12.0

Tower middleware that provides JWT authorization against an OpenID Connect (OIDC) Provider
Documentation
use async_trait::async_trait;
use jsonwebtoken::jwk::JwkSet;
use log::warn;
use reqwest::{Client, Url};
use std::{sync::Arc, time::Duration};
use tokio::time;

use crate::error::JwkError;

/// Source of JWK sets that pushes every set it obtains to its registered
/// [`JwksConsumer`]s.
pub trait JwksProducer {
    /// Registers `consumer` so that it receives the JWK sets produced by this
    /// producer.
    fn add_consumer(&mut self, consumer: Arc<dyn JwksConsumer>);

    /// Starts producing JWK sets and delivering them to the registered consumers.
    fn start(&self);
}

/// Receiver of JWK sets emitted by a [`JwksProducer`].
///
/// The `Send + Sync` bound allows a producer to share one consumer (behind an
/// `Arc`) with a background task.
#[async_trait]
pub trait JwksConsumer: Send + Sync {
    /// Handles a newly produced JWK set.
    async fn receive_jwks(&self, jwk_set: JwkSet);
}

/// [`JwksProducer`] that, once started, re-downloads the JWK set from `jwks_url`
/// on a fixed timer and forwards each successfully fetched set to its consumers.
pub struct TimerJwksProducer {
    /// URL of the JWKS endpoint that is polled.
    jwks_url: Url,
    /// Period between two consecutive fetches.
    refresh_interval: Duration,
    /// Consumers registered through `add_consumer`; each one is notified with
    /// every successfully fetched JWK set.
    receivers: Vec<Arc<dyn JwksConsumer>>,
    /// HTTP client used to issue the JWKS requests.
    http_client: Client,
}

impl TimerJwksProducer {
    pub fn new(jwks_url: Url, refresh_interval: Duration, http_client: Client) -> Self {
        Self {
            jwks_url,
            refresh_interval,
            receivers: Vec::new(),
            http_client,
        }
    }
}

impl JwksProducer for TimerJwksProducer {
    /// Appends `consumer` to the list handed to jobs started after this call.
    fn add_consumer(&mut self, consumer: Arc<dyn JwksConsumer>) {
        self.receivers.push(consumer);
    }

    /// Spawns the background refresh job on the current Tokio runtime.
    ///
    /// The job receives a snapshot of the consumers registered so far, so
    /// consumers added later are not seen by an already running job. Each call
    /// spawns a new, independent job.
    ///
    /// # Panics
    ///
    /// Panics if called outside of a Tokio runtime (a `tokio::spawn` requirement).
    fn start(&self) {
        let consumers_snapshot: Vec<Arc<dyn JwksConsumer>> =
            self.receivers.iter().cloned().collect();
        let refresh_job = fetch_jwks_job(
            self.jwks_url.clone(),
            self.refresh_interval,
            consumers_snapshot,
            self.http_client.clone(),
        );
        // The JoinHandle is not kept: the job is detached and loops until the
        // runtime shuts down.
        tokio::spawn(refresh_job);
    }
}

/// Background loop that refreshes the JWK set every `refresh_interval` and
/// forwards each successfully fetched set to every consumer in `consumers`.
///
/// The first fetch happens immediately, because the first tick of a Tokio
/// interval completes right away. A failed fetch is logged and retried on the
/// next tick. The loop never terminates on its own.
///
/// # Panics
///
/// Panics if `refresh_interval` is zero, because `tokio::time::interval`
/// rejects a zero period.
async fn fetch_jwks_job(
    jwks_url: Url,
    refresh_interval: Duration,
    consumers: Vec<Arc<dyn JwksConsumer>>,
    http_client: Client,
) {
    let mut interval = time::interval(refresh_interval);
    // The default `MissedTickBehavior::Burst` fires every missed tick back-to-back
    // when a fetch (e.g. a slow or timing-out request) outlasts the period. That
    // would hammer the JWKS endpoint exactly when it is struggling. `Delay`
    // schedules the next tick a full period after the late one instead.
    interval.set_missed_tick_behavior(time::MissedTickBehavior::Delay);
    loop {
        interval.tick().await;
        match fetch_jwks(jwks_url.clone(), http_client.clone()).await {
            Ok(jwks) => {
                // Consumers are notified one after another; a slow consumer
                // delays the ones after it.
                for consumer in &consumers {
                    consumer.receive_jwks(jwks.clone()).await;
                }
            }
            Err(e) => {
                warn!("Failed to fetch JWK set: {:?}", e);
            }
        }
    }
}

async fn fetch_jwks(jwks_url: Url, http_client: Client) -> Result<JwkSet, JwkError> {
    let response = http_client
        .get(jwks_url)
        .send()
        .await
        .map_err(|_| JwkError::FetchFailed)?;
    let parsed = response
        .json::<JwkSet>()
        .await
        .map_err(|_| JwkError::ParseFailed)?;
    Ok(parsed)
}

#[cfg(test)]
mod tests {
    use jsonwebtoken::jwk::Jwk;
    use serde_json::json;
    use tokio::sync::RwLock;
    use wiremock::{
        Mock, MockServer, ResponseTemplate,
        matchers::{method, path},
    };

    use super::*;

    /// Consumer that records the most recently received JWK set.
    struct TestConsumer {
        jwks: Arc<RwLock<Option<JwkSet>>>,
    }

    impl TestConsumer {
        pub fn new() -> Self {
            Self {
                jwks: Arc::new(RwLock::new(None)),
            }
        }
        pub async fn has_jwks(&self) -> bool {
            self.jwks.read().await.is_some()
        }
    }

    #[async_trait]
    impl JwksConsumer for TestConsumer {
        async fn receive_jwks(&self, jwks: JwkSet) {
            self.jwks.write().await.replace(jwks);
        }
    }

    #[tokio::test]
    async fn test_should_notify_consumers() {
        let mock_server = MockServer::start().await;
        mock_jwks(&mock_server, "/jwks.json").await;

        let consumer = Arc::new(TestConsumer::new());
        let mut producer = TimerJwksProducer::new(
            format!("{}/jwks.json", &mock_server.uri())
                .parse::<Url>()
                .unwrap(),
            Duration::from_millis(5),
            Client::builder()
                .build()
                .expect("Could not create reqwest client"),
        );
        producer.add_consumer(consumer.clone());
        producer.start();

        // Poll until the background job has delivered a JWK set, bounded by a timeout.
        let delivered = tokio::time::timeout(Duration::from_millis(500), async {
            while !consumer.has_jwks().await {
                tokio::time::sleep(Duration::from_millis(10)).await;
            }
        })
        .await
        .is_ok();
        assert!(delivered, "Consumer did not receive JWKS in time");

        let received = consumer.jwks.read().await;
        assert!(
            received
                .as_ref()
                .and_then(|set| set.find("test-kid"))
                .is_some(),
            "Received JWKS does not contain the mocked key"
        );
    }

    /// Mounts a GET handler on `server` at `jwks_path` that serves a one-key JWK set.
    async fn mock_jwks(server: &MockServer, jwks_path: &str) {
        let jwk: Jwk = serde_json::from_value(json!({
            "kty": "RSA",
            // The JWK parameter is named `use` (RFC 7517 §4.2); a misspelled key
            // would be silently ignored during deserialization.
            "use": "sig",
            "alg": "RS256",
            "kid": "test-kid",
            "n": "oEz_RrupHP9d9XiFbXLoJMwG-75Z18t4ziBy2PHTZHxkHOep7aFeNj-13NmIcL4ooj-2nxrLhWbgA2iBaWr95wKkf5peTsc-5Q6-B2uCcn9xPSQK08Y_jNVhtly3mAOdsT4Y9mQIO_oqaqEyzutypZBEu-18NkbGVwkNhG9sxvUjFXHvMoJs5iwILaDA2FhuEioIDzOy-ZjD8p928ye2v8CdPWl1xPxoBXd2KIe3RkocRDxLeeBg3wH8a9tQ5Z7fOmiXiAI8_lN57zYf078yazvLUlKzCo1pQoR25MU51d7zgI_I7H2Fb5PZGcCmfvN1Up41OfEQyMLL6JYyoP23XQ",
            "e": "AQAB"
        }))
        .unwrap();
        Mock::given(method("GET"))
            .and(path(jwks_path))
            .respond_with(ResponseTemplate::new(200).set_body_json(JwkSet { keys: vec![jwk] }))
            .mount(server)
            .await
    }
}