Skip to main content

camel_auth/
jwks.rs

1use async_trait::async_trait;
2use serde::{Deserialize, Serialize};
3use std::net::IpAddr;
4use std::time::{Duration, Instant};
5use tokio::sync::{Mutex, RwLock};
6
7use crate::types::AuthError;
8
/// A single key entry from the `keys` array of a JWKS document.
///
/// `kid`, `kty`, `n` and `e` are non-optional, so deserializing a key object
/// that lacks any of them fails.
// NOTE(review): presumably `n`/`e` are the base64url-encoded RSA modulus and
// exponent (the tests in this file use "RSA" / "AQAB") — confirm.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct Jwk {
    /// Key identifier.
    pub kid: String,
    /// Key type (the tests in this file use `"RSA"`).
    pub kty: String,
    /// Algorithm, when the key declares one (the tests in this file use `"RS256"`).
    pub alg: Option<String>,
    /// Raw identifier because `use` is a Rust keyword. Left out of the
    /// serialized output entirely when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub r#use: Option<String>,
    /// Required string field `n` of the key object.
    pub n: String,
    /// Required string field `e` of the key object.
    pub e: String,
}
19
/// Asynchronous source of JWKS keys.
///
/// Implementors must be `Send + Sync`, as required by the supertrait bounds.
#[async_trait]
pub trait JwksProvider: Send + Sync {
    /// Returns the provider's current keys, or an [`AuthError`] if they
    /// cannot be obtained.
    ///
    /// `RemoteJwksProvider` in this file answers from its in-memory cache
    /// while the cached entry is younger than its TTL.
    async fn get_signing_keys(&self) -> Result<Vec<Jwk>, AuthError>;
    /// Asks the provider to reload its keys.
    ///
    /// `RemoteJwksProvider` in this file always performs a fetch here,
    /// regardless of cache freshness.
    async fn refresh(&self) -> Result<(), AuthError>;
}
25
/// A fetched key set together with the data needed to decide whether it is
/// still fresh (`fetched_at.elapsed() < ttl`).
struct CachedKeys {
    /// Keys exactly as returned by the last stored fetch.
    keys: Vec<Jwk>,
    /// Monotonic timestamp taken when the keys were stored.
    fetched_at: Instant,
    /// How long after `fetched_at` the keys are served without re-fetching.
    ttl: Duration,
}
31
/// [`JwksProvider`] that downloads keys over HTTP from `jwks_uri` and keeps
/// the most recent result in memory.
pub struct RemoteJwksProvider {
    /// Endpoint the keys are fetched from.
    jwks_uri: String,
    /// HTTP client used for every fetch.
    http: reqwest::Client,
    /// Most recently stored keys; `None` until keys are stored for the first time.
    cache: RwLock<Option<CachedKeys>>,
    /// Held by `get_signing_keys` around its slow path so that concurrent
    /// cache misses wait for one fetch instead of each issuing their own.
    in_flight: Mutex<()>,
    /// TTL used when a response carries no usable `Cache-Control: max-age`.
    default_ttl: Duration,
}
39
40impl RemoteJwksProvider {
41    /// Creates a production provider with HTTPS enforcement, SSRF guard, and hardened timeouts.
42    pub fn new(jwks_uri: String) -> Result<Self, AuthError> {
43        validate_https_public_uri(&jwks_uri, "JWKS URI")?;
44        let http = reqwest::Client::builder()
45            .connect_timeout(Duration::from_secs(5))
46            .timeout(Duration::from_secs(10))
47            .build()
48            .map_err(|e| AuthError::ConfigError(format!("failed to build HTTP client: {e}")))?;
49        Ok(Self::with_client(jwks_uri, http))
50    }
51
52    /// Creates a provider with a custom HTTP client, bypassing URL validation.
53    /// **For testing only.**
54    #[cfg(test)]
55    pub fn new_for_test(jwks_uri: String) -> Self {
56        Self::with_client(jwks_uri, reqwest::Client::new())
57    }
58
59    fn with_client(jwks_uri: String, http: reqwest::Client) -> Self {
60        Self {
61            jwks_uri,
62            http,
63            cache: RwLock::new(None),
64            in_flight: Mutex::new(()),
65            default_ttl: Duration::from_secs(300),
66        }
67    }
68
69    async fn fetch_and_store(&self) -> Result<Vec<Jwk>, AuthError> {
70        let resp = self
71            .http
72            .get(&self.jwks_uri)
73            .send()
74            .await
75            .map_err(|e| AuthError::ProviderUnavailable(format!("JWKS fetch failed: {e}")))?;
76
77        if !resp.status().is_success() {
78            return Err(AuthError::ProviderUnavailable(format!(
79                "JWKS endpoint returned {}",
80                resp.status()
81            )));
82        }
83
84        let ttl = resp
85            .headers()
86            .get("cache-control")
87            .and_then(|v| v.to_str().ok())
88            .and_then(|v| {
89                v.split(',').find_map(|part| {
90                    let part = part.trim();
91                    part.strip_prefix("max-age=")
92                        .and_then(|s| s.parse::<u64>().ok())
93                        .map(Duration::from_secs)
94                })
95            })
96            .unwrap_or(self.default_ttl);
97
98        #[derive(Deserialize)]
99        struct JwksResponse {
100            keys: Vec<Jwk>,
101        }
102
103        let body: JwksResponse = resp
104            .json()
105            .await
106            .map_err(|e| AuthError::ProviderUnavailable(format!("JWKS parse failed: {e}")))?;
107
108        let keys = body.keys;
109        *self.cache.write().await = Some(CachedKeys {
110            keys: keys.clone(),
111            fetched_at: Instant::now(),
112            ttl,
113        });
114        Ok(keys)
115    }
116}
117
118/// Validates that `uri` is a public HTTPS endpoint safe for outbound requests.
119///
120/// Rules:
121/// - Scheme must be `https`
122/// - Host must not be a loopback or RFC-1918 private address
123pub fn validate_https_public_uri(uri: &str, label: &str) -> Result<(), AuthError> {
124    let parsed = uri
125        .parse::<reqwest::Url>()
126        .map_err(|e| AuthError::ConfigError(format!("invalid {label} '{uri}': {e}")))?;
127
128    if parsed.scheme() != "https" {
129        return Err(AuthError::ConfigError(format!(
130            "{label} must use HTTPS (got scheme '{}')",
131            parsed.scheme()
132        )));
133    }
134
135    if parsed.host_str().is_some_and(is_private_or_loopback_host) {
136        return Err(AuthError::ConfigError(format!(
137            "{label} host '{}' is a private or loopback address (SSRF guard)",
138            parsed.host_str().unwrap_or("")
139        )));
140    }
141
142    Ok(())
143}
144
145/// Returns `true` if the host string resolves to a loopback or private IP.
146fn is_private_or_loopback_host(host: &str) -> bool {
147    // Named loopback / unspecified
148    if matches!(host, "localhost" | "localhost.localdomain" | "0.0.0.0") {
149        return true;
150    }
151    // url::Url::host_str() wraps IPv6 addresses in brackets: "[::1]".
152    // std::net::IpAddr::from_str rejects the bracket form, so strip them first.
153    let ip_str = host
154        .strip_prefix('[')
155        .and_then(|s| s.strip_suffix(']'))
156        .unwrap_or(host);
157    if let Ok(ip) = ip_str.parse::<IpAddr>() {
158        return ip.is_loopback() || is_private_ip(ip);
159    }
160    false
161}
162
/// Returns `true` if `ip` must not be contacted by an outbound fetch because
/// it designates the local machine or a non-public network.
///
/// Covered ranges:
/// - IPv4: RFC 1918 private, link-local (169.254/16 — cloud metadata),
///   loopback (127/8) and unspecified (0.0.0.0).
/// - IPv6: loopback (::1), unique-local (fc00::/7), link-local (fe80::/10)
///   and unspecified (::).
/// - IPv4-mapped IPv6 (`::ffff:a.b.c.d`): classified by the embedded IPv4
///   address, so `::ffff:127.0.0.1` cannot be used to dodge the IPv4 rules.
fn is_private_ip(ip: IpAddr) -> bool {
    match ip {
        IpAddr::V4(v4) => {
            v4.is_private() || v4.is_link_local() || v4.is_loopback() || v4.is_unspecified()
        }
        IpAddr::V6(v6) => {
            // An IPv4-mapped address reaches the IPv4 host it embeds; judge it
            // by that host rather than by its IPv6 form.
            if let Some(mapped) = v6.to_ipv4_mapped() {
                return is_private_ip(IpAddr::V4(mapped));
            }
            v6.is_loopback()
                || v6.is_unique_local()
                || v6.is_unicast_link_local()
                || v6.is_unspecified()
        }
    }
}
176
177#[async_trait]
178impl JwksProvider for RemoteJwksProvider {
179    async fn get_signing_keys(&self) -> Result<Vec<Jwk>, AuthError> {
180        // Fast path: fresh cache
181        {
182            let cache = self.cache.read().await;
183            if let Some(c) = cache.as_ref().filter(|c| c.fetched_at.elapsed() < c.ttl) {
184                return Ok(c.keys.clone());
185            }
186        }
187
188        // Slow path: single-flight fetch
189        let _guard = self.in_flight.lock().await;
190
191        // Re-check after acquiring lock (another task may have refreshed)
192        {
193            let cache = self.cache.read().await;
194            if let Some(c) = cache.as_ref().filter(|c| c.fetched_at.elapsed() < c.ttl) {
195                return Ok(c.keys.clone());
196            }
197        }
198
199        self.fetch_and_store().await
200    }
201
202    async fn refresh(&self) -> Result<(), AuthError> {
203        self.fetch_and_store().await.map(|_| ())
204    }
205}
206
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a production provider for `uri` and returns the message of the
    /// `ConfigError` it was rejected with, or `None` if construction succeeded
    /// or failed with a different error.
    fn config_error_for(uri: &str) -> Option<String> {
        match RemoteJwksProvider::new(uri.into()) {
            Err(AuthError::ConfigError(msg)) => Some(msg),
            _ => None,
        }
    }

    /// True when `uri` is rejected with a message that mentions either
    /// "private" or "loopback".
    fn rejected_as_internal(uri: &str) -> bool {
        config_error_for(uri).is_some_and(|msg| msg.contains("private") || msg.contains("loopback"))
    }

    #[test]
    fn jwk_from_json_fields() {
        let key = Jwk {
            kid: "key-1".into(),
            kty: "RSA".into(),
            alg: Some("RS256".into()),
            r#use: None,
            n: "modulus-base64url".into(),
            e: "AQAB".into(),
        };
        assert_eq!(key.kid, "key-1");
        assert_eq!(key.e, "AQAB");
    }

    #[test]
    fn https_enforcement_rejects_http() {
        let message =
            config_error_for("http://kc.example.com/realms/test/protocol/openid-connect/certs");
        assert!(message.is_some_and(|msg| msg.contains("HTTPS")));
    }

    #[test]
    fn ssrf_guard_rejects_localhost() {
        let message =
            config_error_for("https://localhost/realms/test/protocol/openid-connect/certs");
        assert!(message.is_some_and(|msg| msg.contains("loopback")));
    }

    #[test]
    fn ssrf_guard_rejects_private_ip() {
        let message =
            config_error_for("https://192.168.1.1/realms/test/protocol/openid-connect/certs");
        assert!(message.is_some_and(|msg| msg.contains("private")));
    }

    #[test]
    fn production_url_accepted() {
        // Should not fail URL validation (will fail at network level, not construction)
        let provider = RemoteJwksProvider::new(
            "https://kc.example.com/realms/test/protocol/openid-connect/certs".into(),
        );
        assert!(provider.is_ok());
    }

    #[test]
    fn ssrf_guard_rejects_link_local_metadata_endpoint() {
        // 169.254.169.254 is the cloud instance-metadata address (AWS/GCP/Azure)
        assert!(rejected_as_internal(
            "https://169.254.169.254/latest/meta-data"
        ));
    }

    #[test]
    fn ssrf_guard_rejects_ipv6_unique_local() {
        assert!(rejected_as_internal(
            "https://[fc00::1]/realms/test/protocol/openid-connect/certs"
        ));
    }

    #[test]
    fn ssrf_guard_rejects_ipv6_loopback() {
        assert!(rejected_as_internal(
            "https://[::1]/realms/test/protocol/openid-connect/certs"
        ));
    }

    #[tokio::test]
    async fn cache_returns_fresh_keys_without_http() {
        let provider = RemoteJwksProvider::new_for_test("http://unreachable:9999/certs".into());

        // Seed the cache by hand so the lookup never has to touch the network.
        let seeded = CachedKeys {
            keys: vec![Jwk {
                kid: "cached-key".into(),
                kty: "RSA".into(),
                alg: Some("RS256".into()),
                r#use: None,
                n: "n".into(),
                e: "AQAB".into(),
            }],
            fetched_at: Instant::now(),
            ttl: Duration::from_secs(300),
        };
        *provider.cache.write().await = Some(seeded);

        let keys = provider.get_signing_keys().await.unwrap();
        assert_eq!(keys.len(), 1);
        assert_eq!(keys[0].kid, "cached-key");
    }

    #[tokio::test]
    async fn refresh_via_wiremock() {
        use wiremock::matchers::method;
        use wiremock::{Mock, MockServer, ResponseTemplate};

        let server = MockServer::start().await;
        let jwks_json =
            r#"{"keys":[{"kid":"key-1","kty":"RSA","alg":"RS256","n":"modulus","e":"AQAB"}]}"#;

        Mock::given(method("GET"))
            .respond_with(ResponseTemplate::new(200).set_body_raw(jwks_json, "application/json"))
            .mount(&server)
            .await;

        let endpoint = format!(
            "{}/realms/test/protocol/openid-connect/certs",
            server.uri()
        );
        let provider = RemoteJwksProvider::new_for_test(endpoint);
        provider.refresh().await.unwrap();

        let keys = provider.get_signing_keys().await.unwrap();
        assert_eq!(keys.len(), 1);
        assert_eq!(keys[0].kid, "key-1");
    }
}
336}