1use async_trait::async_trait;
2use serde::{Deserialize, Serialize};
3use std::net::IpAddr;
4use std::time::{Duration, Instant};
5use tokio::sync::{Mutex, RwLock};
6
7use crate::types::AuthError;
8
9#[derive(Debug, Clone, Deserialize, Serialize)]
10pub struct Jwk {
11 pub kid: String,
12 pub kty: String,
13 pub alg: Option<String>,
14 #[serde(skip_serializing_if = "Option::is_none")]
15 pub r#use: Option<String>,
16 pub n: String,
17 pub e: String,
18}
19
20#[async_trait]
21pub trait JwksProvider: Send + Sync {
22 async fn get_signing_keys(&self) -> Result<Vec<Jwk>, AuthError>;
23 async fn refresh(&self) -> Result<(), AuthError>;
24}
25
26struct CachedKeys {
27 keys: Vec<Jwk>,
28 fetched_at: Instant,
29 ttl: Duration,
30}
31
32pub struct RemoteJwksProvider {
33 jwks_uri: String,
34 http: reqwest::Client,
35 cache: RwLock<Option<CachedKeys>>,
36 in_flight: Mutex<()>,
37 default_ttl: Duration,
38}
39
40impl RemoteJwksProvider {
41 pub fn new(jwks_uri: String) -> Result<Self, AuthError> {
43 validate_https_public_uri(&jwks_uri, "JWKS URI")?;
44 let http = reqwest::Client::builder()
45 .connect_timeout(Duration::from_secs(5))
46 .timeout(Duration::from_secs(10))
47 .build()
48 .map_err(|e| AuthError::ConfigError(format!("failed to build HTTP client: {e}")))?;
49 Ok(Self::with_client(jwks_uri, http))
50 }
51
52 #[cfg(test)]
55 pub fn new_for_test(jwks_uri: String) -> Self {
56 Self::with_client(jwks_uri, reqwest::Client::new())
57 }
58
59 fn with_client(jwks_uri: String, http: reqwest::Client) -> Self {
60 Self {
61 jwks_uri,
62 http,
63 cache: RwLock::new(None),
64 in_flight: Mutex::new(()),
65 default_ttl: Duration::from_secs(300),
66 }
67 }
68
69 async fn fetch_and_store(&self) -> Result<Vec<Jwk>, AuthError> {
70 let resp = self
71 .http
72 .get(&self.jwks_uri)
73 .send()
74 .await
75 .map_err(|e| AuthError::ProviderUnavailable(format!("JWKS fetch failed: {e}")))?;
76
77 if !resp.status().is_success() {
78 return Err(AuthError::ProviderUnavailable(format!(
79 "JWKS endpoint returned {}",
80 resp.status()
81 )));
82 }
83
84 let ttl = resp
85 .headers()
86 .get("cache-control")
87 .and_then(|v| v.to_str().ok())
88 .and_then(|v| {
89 v.split(',').find_map(|part| {
90 let part = part.trim();
91 part.strip_prefix("max-age=")
92 .and_then(|s| s.parse::<u64>().ok())
93 .map(Duration::from_secs)
94 })
95 })
96 .unwrap_or(self.default_ttl);
97
98 #[derive(Deserialize)]
99 struct JwksResponse {
100 keys: Vec<Jwk>,
101 }
102
103 let body: JwksResponse = resp
104 .json()
105 .await
106 .map_err(|e| AuthError::ProviderUnavailable(format!("JWKS parse failed: {e}")))?;
107
108 let keys = body.keys;
109 *self.cache.write().await = Some(CachedKeys {
110 keys: keys.clone(),
111 fetched_at: Instant::now(),
112 ttl,
113 });
114 Ok(keys)
115 }
116}
117
118pub fn validate_https_public_uri(uri: &str, label: &str) -> Result<(), AuthError> {
124 let parsed = uri
125 .parse::<reqwest::Url>()
126 .map_err(|e| AuthError::ConfigError(format!("invalid {label} '{uri}': {e}")))?;
127
128 if parsed.scheme() != "https" {
129 return Err(AuthError::ConfigError(format!(
130 "{label} must use HTTPS (got scheme '{}')",
131 parsed.scheme()
132 )));
133 }
134
135 if parsed.host_str().is_some_and(is_private_or_loopback_host) {
136 return Err(AuthError::ConfigError(format!(
137 "{label} host '{}' is a private or loopback address (SSRF guard)",
138 parsed.host_str().unwrap_or("")
139 )));
140 }
141
142 Ok(())
143}
144
145fn is_private_or_loopback_host(host: &str) -> bool {
147 if matches!(host, "localhost" | "localhost.localdomain" | "0.0.0.0") {
149 return true;
150 }
151 let ip_str = host
154 .strip_prefix('[')
155 .and_then(|s| s.strip_suffix(']'))
156 .unwrap_or(host);
157 if let Ok(ip) = ip_str.parse::<IpAddr>() {
158 return ip.is_loopback() || is_private_ip(ip);
159 }
160 false
161}
162
/// Returns `true` if `ip` lies in a range that an outbound identity-provider
/// request must never target.
///
/// Covered ranges:
/// * IPv4: `0.0.0.0/8` ("this network", includes the unspecified address),
///   loopback `127.0.0.0/8`, the RFC 1918 private ranges, and link-local
///   `169.254.0.0/16` (which contains the cloud metadata address
///   `169.254.169.254`).
/// * IPv6: loopback `::1`, unspecified `::`, unique-local `fc00::/7`, and
///   link-local `fe80::/10`.
///
/// IPv4-mapped IPv6 addresses (`::ffff:a.b.c.d`) are unwrapped and checked
/// against the IPv4 rules. Otherwise `[::ffff:127.0.0.1]` would bypass them.
fn is_private_ip(ip: IpAddr) -> bool {
    match ip {
        IpAddr::V4(v4) => {
            v4.octets()[0] == 0 || v4.is_loopback() || v4.is_private() || v4.is_link_local()
        }
        IpAddr::V6(v6) => match v6.to_ipv4_mapped() {
            Some(v4) => is_private_ip(IpAddr::V4(v4)),
            None => {
                v6.is_loopback()
                    || v6.is_unspecified()
                    || v6.is_unique_local()
                    || v6.is_unicast_link_local()
            }
        },
    }
}
176
177#[async_trait]
178impl JwksProvider for RemoteJwksProvider {
179 async fn get_signing_keys(&self) -> Result<Vec<Jwk>, AuthError> {
180 {
182 let cache = self.cache.read().await;
183 if let Some(c) = cache.as_ref().filter(|c| c.fetched_at.elapsed() < c.ttl) {
184 return Ok(c.keys.clone());
185 }
186 }
187
188 let _guard = self.in_flight.lock().await;
190
191 {
193 let cache = self.cache.read().await;
194 if let Some(c) = cache.as_ref().filter(|c| c.fetched_at.elapsed() < c.ttl) {
195 return Ok(c.keys.clone());
196 }
197 }
198
199 self.fetch_and_store().await
200 }
201
202 async fn refresh(&self) -> Result<(), AuthError> {
203 self.fetch_and_store().await.map(|_| ())
204 }
205}
206
#[cfg(test)]
mod tests {
    use super::*;

    /// Realm-style path appended to the hosts used in these tests.
    const CERTS_PATH: &str = "/realms/test/protocol/openid-connect/certs";

    /// Builds an RS256 JWK with exponent `AQAB` and no `use` member.
    fn rsa_jwk(kid: &str, modulus: &str) -> Jwk {
        Jwk {
            kid: kid.into(),
            kty: "RSA".into(),
            alg: Some("RS256".into()),
            r#use: None,
            n: modulus.into(),
            e: "AQAB".into(),
        }
    }

    /// Asserts that `RemoteJwksProvider::new(uri)` fails with a `ConfigError`
    /// whose message contains at least one of `any_of`.
    fn assert_rejected(uri: &str, any_of: &[&str]) {
        let outcome = RemoteJwksProvider::new(uri.to_string());
        let rejected = matches!(
            &outcome,
            Err(AuthError::ConfigError(msg)) if any_of.iter().any(|&word| msg.contains(word))
        );
        assert!(rejected, "{uri} should be rejected mentioning one of {any_of:?}");
    }

    #[test]
    fn jwk_from_json_fields() {
        let jwk = rsa_jwk("key-1", "modulus-base64url");
        assert_eq!(jwk.kid, "key-1");
        assert_eq!(jwk.e, "AQAB");
    }

    #[test]
    fn https_enforcement_rejects_http() {
        assert_rejected(&format!("http://kc.example.com{CERTS_PATH}"), &["HTTPS"]);
    }

    #[test]
    fn ssrf_guard_rejects_localhost() {
        assert_rejected(&format!("https://localhost{CERTS_PATH}"), &["loopback"]);
    }

    #[test]
    fn ssrf_guard_rejects_private_ip() {
        assert_rejected(&format!("https://192.168.1.1{CERTS_PATH}"), &["private"]);
    }

    #[test]
    fn production_url_accepted() {
        let outcome = RemoteJwksProvider::new(format!("https://kc.example.com{CERTS_PATH}"));
        assert!(outcome.is_ok());
    }

    #[test]
    fn ssrf_guard_rejects_link_local_metadata_endpoint() {
        assert_rejected(
            "https://169.254.169.254/latest/meta-data",
            &["private", "loopback"],
        );
    }

    #[test]
    fn ssrf_guard_rejects_ipv6_unique_local() {
        assert_rejected(
            &format!("https://[fc00::1]{CERTS_PATH}"),
            &["private", "loopback"],
        );
    }

    #[test]
    fn ssrf_guard_rejects_ipv6_loopback() {
        assert_rejected(&format!("https://[::1]{CERTS_PATH}"), &["private", "loopback"]);
    }

    #[tokio::test]
    async fn cache_returns_fresh_keys_without_http() {
        let provider = RemoteJwksProvider::new_for_test("http://unreachable:9999/certs".into());
        *provider.cache.write().await = Some(CachedKeys {
            keys: vec![rsa_jwk("cached-key", "n")],
            fetched_at: Instant::now(),
            ttl: Duration::from_secs(300),
        });

        let keys = provider.get_signing_keys().await.unwrap();
        assert_eq!(keys.len(), 1);
        assert_eq!(keys[0].kid, "cached-key");
    }

    #[tokio::test]
    async fn refresh_via_wiremock() {
        use wiremock::matchers::method;
        use wiremock::{Mock, MockServer, ResponseTemplate};

        const BODY: &str =
            r#"{"keys":[{"kid":"key-1","kty":"RSA","alg":"RS256","n":"modulus","e":"AQAB"}]}"#;

        let server = MockServer::start().await;
        Mock::given(method("GET"))
            .respond_with(ResponseTemplate::new(200).set_body_raw(BODY, "application/json"))
            .mount(&server)
            .await;

        let provider = RemoteJwksProvider::new_for_test(format!("{}{CERTS_PATH}", server.uri()));
        provider.refresh().await.unwrap();

        let keys = provider.get_signing_keys().await.unwrap();
        assert_eq!(keys.len(), 1);
        assert_eq!(keys[0].kid, "key-1");
    }
}