hive_router/jwt/
jwks_manager.rs1use hive_router_config::jwt_auth::{JwksProviderSourceConfig, JwtAuthConfig};
2use sonic_rs::from_str;
3use std::sync::{Arc, RwLock};
4use tokio::fs::read_to_string;
5use tokio_util::sync::CancellationToken;
6use tracing::{debug, error, info};
7
8use jsonwebtoken::jwk::JwkSet;
9
10use crate::background_tasks::{BackgroundTask, BackgroundTasksManager};
11
12pub struct JwksManager {
13 sources: Vec<Arc<JwksSource>>,
14}
15
16impl JwksManager {
17 pub fn from_config(config: &JwtAuthConfig) -> Self {
18 let sources = config
19 .jwks_providers
20 .iter()
21 .map(|config| Arc::new(JwksSource::new(config.clone())))
22 .collect();
23
24 JwksManager { sources }
25 }
26
27 pub fn all(&self) -> Vec<Arc<JwkSet>> {
28 self.sources
29 .iter()
30 .filter_map(|v| match v.get_jwk_set() {
31 Ok(set) => Some(set),
32 Err(err) => {
33 error!("Failed to use JWK set: {}, ignoring", err);
34
35 None
36 }
37 })
38 .collect()
39 }
40
41 pub async fn prefetch_sources(&self) -> Result<(), JwksSourceError> {
42 for source in &self.sources {
43 if source.should_prefetch() {
44 match source.load_and_store_jwks().await {
45 Ok(_) => {}
46 Err(err) => return Err(err),
47 }
48 }
49 }
50
51 Ok(())
52 }
53
54 pub fn register_background_tasks(&self, background_tasks_mgr: &mut BackgroundTasksManager) {
55 for source in &self.sources {
56 if source.should_poll_in_background() {
57 background_tasks_mgr.register_task(source.clone());
58 }
59 }
60 }
61}
62
63#[derive(Debug)]
64pub struct JwksSource {
65 config: JwksProviderSourceConfig,
66 jwk: RwLock<Option<Arc<JwkSet>>>,
67}
68
69#[async_trait::async_trait]
70impl BackgroundTask for Arc<JwksSource> {
71 fn id(&self) -> &str {
72 "jwt_auth_jwks"
73 }
74
75 async fn run(&self, token: CancellationToken) {
76 if let JwksProviderSourceConfig::Remote {
77 polling_interval: Some(interval),
78 ..
79 } = &self.config
80 {
81 debug!("Starting remote jwks polling for source: {:?}", self.config);
82 let mut tokio_interval = tokio::time::interval(*interval);
83
84 loop {
85 tokio::select! {
86 _ = tokio_interval.tick() => { match self.load_and_store_jwks().await {
87 Ok(_) => {}
88 Err(err) => {
89 error!("Failed to load remote jwks: {}", err);
90 }
91 } }
92 _ = token.cancelled() => { info!("Jwks source shutting down."); return; }
93 }
94 }
95 }
96 }
97}
98
99#[derive(thiserror::Error, Debug)]
100pub enum JwksSourceError {
101 #[error("failed to load remote jwks: {0}")]
102 RemoteJwksNetworkError(reqwest::Error),
103 #[error("failed to load file jwks: {0}")]
104 FileJwksNetworkError(std::io::Error),
105 #[error("failed to parse jwks json file: {0}")]
106 JwksContentInvalidStructure(sonic_rs::Error),
107 #[error("failed to acquire jwks handle")]
108 FailedToAcquireJwk,
109}
110
111impl JwksSource {
112 async fn load_and_store_jwks(&self) -> Result<&Self, JwksSourceError> {
113 let jwks_str = match &self.config {
114 JwksProviderSourceConfig::Remote { url, .. } => {
115 let client = reqwest::Client::new();
116 debug!("loading jwks from a remote source: {}", url);
117
118 let response_text = client
119 .get(url)
120 .send()
121 .await
122 .map_err(JwksSourceError::RemoteJwksNetworkError)?
123 .text()
124 .await
125 .map_err(JwksSourceError::RemoteJwksNetworkError)?;
126
127 response_text
128 }
129 JwksProviderSourceConfig::File { file, .. } => {
130 debug!("loading jwks from a file source: {}", file.absolute);
131
132 let file_contents = read_to_string(&file.absolute)
133 .await
134 .map_err(JwksSourceError::FileJwksNetworkError)?;
135
136 file_contents
137 }
138 };
139
140 let new_jwk = Arc::new(
141 from_str::<JwkSet>(&jwks_str).map_err(JwksSourceError::JwksContentInvalidStructure)?,
142 );
143
144 if let Ok(mut w_jwk) = self.jwk.write() {
145 *w_jwk = Some(new_jwk);
146 }
147
148 Ok(self)
149 }
150
151 pub fn new(config: JwksProviderSourceConfig) -> Self {
152 Self {
153 config,
154 jwk: RwLock::new(None),
155 }
156 }
157
158 pub fn should_poll_in_background(&self) -> bool {
159 match &self.config {
160 JwksProviderSourceConfig::Remote { .. } => true,
161 JwksProviderSourceConfig::File { .. } => false,
162 }
163 }
164
165 pub fn should_prefetch(&self) -> bool {
166 match &self.config {
167 JwksProviderSourceConfig::Remote { prefetch, .. } => match prefetch {
168 Some(prefetch) => *prefetch,
169 None => false,
170 },
171 JwksProviderSourceConfig::File { .. } => true,
172 }
173 }
174
175 pub fn get_jwk_set(&self) -> Result<Arc<JwkSet>, JwksSourceError> {
176 if let Ok(jwk) = self.jwk.try_read() {
177 if let Some(jwk) = jwk.as_ref() {
178 return Ok(jwk.clone());
179 }
180 }
181
182 Err(JwksSourceError::FailedToAcquireJwk)
183 }
184}