Skip to main content

hive_router/jwt/
jwks_manager.rs

1use hive_router_config::jwt_auth::{JwksProviderSourceConfig, JwtAuthConfig};
2use sonic_rs::from_str;
3use std::sync::{Arc, RwLock};
4use tokio::fs::read_to_string;
5use tokio_util::sync::CancellationToken;
6use tracing::{debug, error, info};
7
8use jsonwebtoken::jwk::JwkSet;
9
10use crate::background_tasks::{BackgroundTask, BackgroundTasksManager};
11
12pub struct JwksManager {
13    sources: Vec<Arc<JwksSource>>,
14}
15
16impl JwksManager {
17    pub fn from_config(config: &JwtAuthConfig) -> Self {
18        let sources = config
19            .jwks_providers
20            .iter()
21            .map(|config| Arc::new(JwksSource::new(config.clone())))
22            .collect();
23
24        JwksManager { sources }
25    }
26
27    pub fn all(&self) -> Vec<Arc<JwkSet>> {
28        self.sources
29            .iter()
30            .filter_map(|v| match v.get_jwk_set() {
31                Ok(set) => Some(set),
32                Err(err) => {
33                    error!("Failed to use JWK set: {}, ignoring", err);
34
35                    None
36                }
37            })
38            .collect()
39    }
40
41    pub async fn prefetch_sources(&self) -> Result<(), JwksSourceError> {
42        for source in &self.sources {
43            if source.should_prefetch() {
44                match source.load_and_store_jwks().await {
45                    Ok(_) => {}
46                    Err(err) => return Err(err),
47                }
48            }
49        }
50
51        Ok(())
52    }
53
54    pub fn register_background_tasks(&self, background_tasks_mgr: &mut BackgroundTasksManager) {
55        for source in &self.sources {
56            if source.should_poll_in_background() {
57                background_tasks_mgr.register_task(source.clone());
58            }
59        }
60    }
61}
62
63#[derive(Debug)]
64pub struct JwksSource {
65    config: JwksProviderSourceConfig,
66    jwk: RwLock<Option<Arc<JwkSet>>>,
67}
68
69#[async_trait::async_trait]
70impl BackgroundTask for Arc<JwksSource> {
71    fn id(&self) -> &str {
72        "jwt_auth_jwks"
73    }
74
75    async fn run(&self, token: CancellationToken) {
76        if let JwksProviderSourceConfig::Remote {
77            polling_interval: Some(interval),
78            ..
79        } = &self.config
80        {
81            debug!("Starting remote jwks polling for source: {:?}", self.config);
82            let mut tokio_interval = tokio::time::interval(*interval);
83
84            loop {
85                tokio::select! {
86                    _ = tokio_interval.tick() => { match self.load_and_store_jwks().await {
87                        Ok(_) => {}
88                        Err(err) => {
89                            error!("Failed to load remote jwks: {}", err);
90                        }
91                    } }
92                    _ = token.cancelled() => { info!("Jwks source shutting down."); return; }
93                }
94            }
95        }
96    }
97}
98
99#[derive(thiserror::Error, Debug)]
100pub enum JwksSourceError {
101    #[error("failed to load remote jwks: {0}")]
102    RemoteJwksNetworkError(reqwest::Error),
103    #[error("failed to load file jwks: {0}")]
104    FileJwksNetworkError(std::io::Error),
105    #[error("failed to parse jwks json file: {0}")]
106    JwksContentInvalidStructure(sonic_rs::Error),
107    #[error("failed to acquire jwks handle")]
108    FailedToAcquireJwk,
109}
110
111impl JwksSource {
112    async fn load_and_store_jwks(&self) -> Result<&Self, JwksSourceError> {
113        let jwks_str = match &self.config {
114            JwksProviderSourceConfig::Remote { url, .. } => {
115                let client = reqwest::Client::new();
116                debug!("loading jwks from a remote source: {}", url);
117
118                let response_text = client
119                    .get(url)
120                    .send()
121                    .await
122                    .map_err(JwksSourceError::RemoteJwksNetworkError)?
123                    .text()
124                    .await
125                    .map_err(JwksSourceError::RemoteJwksNetworkError)?;
126
127                response_text
128            }
129            JwksProviderSourceConfig::File { file, .. } => {
130                debug!("loading jwks from a file source: {}", file.absolute);
131
132                let file_contents = read_to_string(&file.absolute)
133                    .await
134                    .map_err(JwksSourceError::FileJwksNetworkError)?;
135
136                file_contents
137            }
138        };
139
140        let new_jwk = Arc::new(
141            from_str::<JwkSet>(&jwks_str).map_err(JwksSourceError::JwksContentInvalidStructure)?,
142        );
143
144        if let Ok(mut w_jwk) = self.jwk.write() {
145            *w_jwk = Some(new_jwk);
146        }
147
148        Ok(self)
149    }
150
151    pub fn new(config: JwksProviderSourceConfig) -> Self {
152        Self {
153            config,
154            jwk: RwLock::new(None),
155        }
156    }
157
158    pub fn should_poll_in_background(&self) -> bool {
159        match &self.config {
160            JwksProviderSourceConfig::Remote { .. } => true,
161            JwksProviderSourceConfig::File { .. } => false,
162        }
163    }
164
165    pub fn should_prefetch(&self) -> bool {
166        match &self.config {
167            JwksProviderSourceConfig::Remote { prefetch, .. } => match prefetch {
168                Some(prefetch) => *prefetch,
169                None => false,
170            },
171            JwksProviderSourceConfig::File { .. } => true,
172        }
173    }
174
175    pub fn get_jwk_set(&self) -> Result<Arc<JwkSet>, JwksSourceError> {
176        if let Ok(jwk) = self.jwk.try_read() {
177            if let Some(jwk) = jwk.as_ref() {
178                return Ok(jwk.clone());
179            }
180        }
181
182        Err(JwksSourceError::FailedToAcquireJwk)
183    }
184}