Skip to main content

hyperstack_auth/
verifier.rs

1use crate::claims::AuthContext;
2use crate::error::VerifyError;
3use crate::keys::VerifyingKey;
4use crate::token::{JwksVerifier, TokenVerifier};
5use std::sync::Arc;
6use std::time::{Duration, Instant};
7use tokio::sync::RwLock;
8
9/// Cached JWKS with expiration
10#[derive(Clone)]
11struct CachedJwks {
12    verifier: JwksVerifier,
13    fetched_at: Instant,
14}
15
16/// Async verifier with JWKS caching support
17pub struct AsyncVerifier {
18    inner: VerifierInner,
19    jwks_url: Option<String>,
20    cache_duration: Duration,
21    cached_jwks: Arc<RwLock<Option<CachedJwks>>>,
22    /// Issuer for JWKS-based verification
23    issuer: String,
24    /// Audience for JWKS-based verification
25    audience: String,
26    require_origin: bool,
27}
28
29enum VerifierInner {
30    Static(TokenVerifier),
31    Jwks(JwksVerifier),
32}
33
34impl AsyncVerifier {
35    /// Create a verifier with a static key
36    pub fn with_static_key(
37        key: VerifyingKey,
38        issuer: impl Into<String>,
39        audience: impl Into<String>,
40    ) -> Self {
41        let issuer_str = issuer.into();
42        let audience_str = audience.into();
43        Self {
44            inner: VerifierInner::Static(TokenVerifier::new(
45                key,
46                issuer_str.clone(),
47                audience_str.clone(),
48            )),
49            jwks_url: None,
50            cache_duration: Duration::from_secs(3600), // 1 hour default
51            cached_jwks: Arc::new(RwLock::new(None)),
52            issuer: issuer_str,
53            audience: audience_str,
54            require_origin: false,
55        }
56    }
57
58    /// Create a verifier with JWKS
59    pub fn with_jwks(
60        jwks: crate::token::Jwks,
61        issuer: impl Into<String>,
62        audience: impl Into<String>,
63    ) -> Self {
64        let issuer_str = issuer.into();
65        let audience_str = audience.into();
66        Self {
67            inner: VerifierInner::Jwks(JwksVerifier::new(
68                jwks,
69                issuer_str.clone(),
70                audience_str.clone(),
71            )),
72            jwks_url: None,
73            cache_duration: Duration::from_secs(3600),
74            cached_jwks: Arc::new(RwLock::new(None)),
75            issuer: issuer_str,
76            audience: audience_str,
77            require_origin: false,
78        }
79    }
80
81    /// Create a verifier that fetches JWKS from a URL
82    #[cfg(feature = "jwks")]
83    pub fn with_jwks_url(
84        url: impl Into<String>,
85        issuer: impl Into<String>,
86        audience: impl Into<String>,
87    ) -> Self {
88        let issuer_str = issuer.into();
89        let audience_str = audience.into();
90        Self {
91            inner: VerifierInner::Static(TokenVerifier::new(
92                VerifyingKey::from_bytes(&[0u8; 32]).expect("zero key should be valid"),
93                issuer_str.clone(),
94                audience_str.clone(),
95            )),
96            jwks_url: Some(url.into()),
97            issuer: issuer_str,
98            audience: audience_str,
99            cache_duration: Duration::from_secs(3600),
100            cached_jwks: Arc::new(RwLock::new(None)),
101            require_origin: false,
102        }
103    }
104
105    /// Require origin validation on verified tokens.
106    pub fn with_origin_validation(mut self) -> Self {
107        self.require_origin = true;
108        self.inner = match self.inner {
109            VerifierInner::Static(verifier) => {
110                VerifierInner::Static(verifier.with_origin_validation())
111            }
112            VerifierInner::Jwks(verifier) => VerifierInner::Jwks(verifier.with_origin_validation()),
113        };
114        self
115    }
116
117    /// Set cache duration for JWKS
118    pub fn with_cache_duration(mut self, duration: Duration) -> Self {
119        self.cache_duration = duration;
120        self
121    }
122
123    /// Verify a token with automatic JWKS fetching and caching
124    #[cfg(feature = "jwks")]
125    pub async fn verify(
126        &self,
127        token: &str,
128        expected_origin: Option<&str>,
129        expected_client_ip: Option<&str>,
130    ) -> Result<AuthContext, VerifyError> {
131        // If using static JWKS or static key, use directly
132        match &self.inner {
133            VerifierInner::Static(verifier) => {
134                verifier.verify(token, expected_origin, expected_client_ip)
135            }
136            VerifierInner::Jwks(verifier) => {
137                verifier.verify(token, expected_origin, expected_client_ip)
138            }
139        }
140    }
141
142    /// Verify a token (non-JWKS version)
143    #[cfg(not(feature = "jwks"))]
144    pub fn verify(
145        &self,
146        token: &str,
147        expected_origin: Option<&str>,
148        expected_client_ip: Option<&str>,
149    ) -> Result<AuthContext, VerifyError> {
150        match &self.inner {
151            VerifierInner::Static(verifier) => {
152                verifier.verify(token, expected_origin, expected_client_ip)
153            }
154            VerifierInner::Jwks(verifier) => {
155                verifier.verify(token, expected_origin, expected_client_ip)
156            }
157        }
158    }
159
160    /// Refresh JWKS cache from the configured URL
161    #[cfg(feature = "jwks")]
162    pub async fn refresh_cache(&self) -> Result<(), VerifyError> {
163        if let Some(ref jwks_url) = self.jwks_url {
164            // Fetch JWKS from URL
165            let jwks = crate::token::JwksVerifier::fetch_jwks(jwks_url)
166                .await
167                .map_err(|e| VerifyError::InvalidFormat(format!("Failed to fetch JWKS: {}", e)))?;
168
169            // Create new verifier with fetched JWKS
170            let verifier = if self.require_origin {
171                JwksVerifier::new(jwks, &self.issuer, &self.audience).with_origin_validation()
172            } else {
173                JwksVerifier::new(jwks, &self.issuer, &self.audience)
174            };
175
176            // Update cache
177            let mut cached = self.cached_jwks.write().await;
178            *cached = Some(CachedJwks {
179                verifier,
180                fetched_at: Instant::now(),
181            });
182        }
183        Ok(())
184    }
185
186    /// Get cached verifier if available and not expired
187    async fn get_cached_verifier(&self) -> Option<JwksVerifier> {
188        let cached = self.cached_jwks.read().await;
189        if let Some(ref cached_jwks) = *cached {
190            if cached_jwks.fetched_at.elapsed() < self.cache_duration {
191                return Some(cached_jwks.verifier.clone());
192            }
193        }
194        None
195    }
196
197    /// Verify a token with automatic JWKS caching
198    #[cfg(feature = "jwks")]
199    pub async fn verify_with_cache(
200        &self,
201        token: &str,
202        expected_origin: Option<&str>,
203        expected_client_ip: Option<&str>,
204    ) -> Result<AuthContext, VerifyError> {
205        // Try cached verifier first
206        if let Some(verifier) = self.get_cached_verifier().await {
207            match verifier.verify(token, expected_origin, expected_client_ip) {
208                Ok(ctx) => return Ok(ctx),
209                Err(VerifyError::KeyNotFound(_)) => {
210                    // Key not found in cache, refresh and retry
211                }
212                Err(e) => return Err(e),
213            }
214        }
215
216        // Refresh cache and try again
217        self.refresh_cache().await?;
218
219        if let Some(verifier) = self.get_cached_verifier().await {
220            verifier.verify(token, expected_origin, expected_client_ip)
221        } else if self.jwks_url.is_some() {
222            Err(VerifyError::InvalidFormat(
223                "JWKS cache unavailable after refresh".to_string(),
224            ))
225        } else {
226            // Fallback to inner verifier if no cache available
227            match &self.inner {
228                VerifierInner::Static(verifier) => {
229                    verifier.verify(token, expected_origin, expected_client_ip)
230                }
231                VerifierInner::Jwks(verifier) => {
232                    verifier.verify(token, expected_origin, expected_client_ip)
233                }
234            }
235        }
236    }
237}
238
239/// Simple synchronous verifier for use in non-async contexts
240pub struct SimpleVerifier {
241    inner: TokenVerifier,
242}
243
244impl SimpleVerifier {
245    /// Create a new simple verifier
246    pub fn new(key: VerifyingKey, issuer: impl Into<String>, audience: impl Into<String>) -> Self {
247        Self {
248            inner: TokenVerifier::new(key, issuer, audience),
249        }
250    }
251
252    /// Verify a token synchronously
253    pub fn verify(
254        &self,
255        token: &str,
256        expected_origin: Option<&str>,
257        expected_client_ip: Option<&str>,
258    ) -> Result<AuthContext, VerifyError> {
259        self.inner
260            .verify(token, expected_origin, expected_client_ip)
261    }
262}
263
#[cfg(test)]
mod tests {
    use super::*;
    use crate::claims::{KeyClass, SessionClaims};
    use crate::keys::SigningKey;
    use crate::token::TokenSigner;

    // Only the JWKS test encodes key bytes and runs a TCP server; gating these
    // keeps non-`jwks` builds free of unused-import warnings.
    #[cfg(feature = "jwks")]
    use base64::Engine;
    #[cfg(feature = "jwks")]
    use tokio::io::{AsyncReadExt, AsyncWriteExt};

    #[tokio::test]
    async fn test_async_verifier_with_static_key() {
        let signing_key = SigningKey::generate();
        let verifying_key = signing_key.verifying_key();

        let signer = TokenSigner::new(signing_key, "test-issuer");
        let verifier =
            AsyncVerifier::with_static_key(verifying_key, "test-issuer", "test-audience");

        let claims = SessionClaims::builder("test-issuer", "test-subject", "test-audience")
            .with_scope("read")
            .with_metering_key("meter-123")
            .with_key_class(KeyClass::Publishable)
            .build();

        let token = signer.sign(claims).unwrap();
        // `AsyncVerifier::verify` is async only when the `jwks` feature is enabled;
        // without it, awaiting the returned `Result` would not compile.
        #[cfg(feature = "jwks")]
        let context = verifier.verify(&token, None, None).await.unwrap();
        #[cfg(not(feature = "jwks"))]
        let context = verifier.verify(&token, None, None).unwrap();

        assert_eq!(context.subject, "test-subject");
    }

    #[test]
    fn test_simple_verifier() {
        let signing_key = SigningKey::generate();
        let verifying_key = signing_key.verifying_key();

        let signer = TokenSigner::new(signing_key, "test-issuer");
        let verifier = SimpleVerifier::new(verifying_key, "test-issuer", "test-audience");

        let claims = SessionClaims::builder("test-issuer", "test-subject", "test-audience")
            .with_scope("read")
            .with_metering_key("meter-123")
            .with_key_class(KeyClass::Publishable)
            .build();

        let token = signer.sign(claims).unwrap();
        let context = verifier.verify(&token, None, None).unwrap();

        assert_eq!(context.subject, "test-subject");
        assert_eq!(context.metering_key, "meter-123");
    }

    #[cfg(feature = "jwks")]
    #[test]
    fn test_verify_with_cache_returns_explicit_error_when_cache_stays_empty() {
        tokio::runtime::Runtime::new().unwrap().block_on(async {
            let signing_key = SigningKey::generate();
            let verifying_key = signing_key.verifying_key();
            let signer = TokenSigner::new(signing_key, "test-issuer");

            let jwks = serde_json::json!({
                "keys": [{
                    "kty": "OKP",
                    "use": "sig",
                    "kid": verifying_key.key_id(),
                    "x": base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(verifying_key.to_bytes()),
                }]
            })
            .to_string();

            // One-shot HTTP server that answers a single request with the JWKS.
            let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap();
            let addr = listener.local_addr().unwrap();
            let response_body = jwks.clone();
            tokio::spawn(async move {
                let (mut socket, _) = listener.accept().await.unwrap();
                let mut buffer = [0u8; 1024];
                let _ = socket.read(&mut buffer).await;

                let response = format!(
                    "HTTP/1.1 200 OK\r\ncontent-type: application/json\r\ncontent-length: {}\r\nconnection: close\r\n\r\n{}",
                    response_body.len(),
                    response_body
                );
                socket.write_all(response.as_bytes()).await.unwrap();
            });

            // A zero cache duration makes every fetched entry stale immediately.
            let verifier = AsyncVerifier::with_jwks_url(
                format!("http://{addr}/jwks"),
                "test-issuer",
                "test-audience",
            )
            .with_cache_duration(Duration::ZERO);

            let claims = SessionClaims::builder("test-issuer", "test-subject", "test-audience")
                .with_scope("read")
                .with_metering_key("meter-123")
                .with_key_class(KeyClass::Publishable)
                .build();
            let token = signer.sign(claims).unwrap();

            let result = verifier.verify_with_cache(&token, None, None).await;
            assert!(matches!(
                result,
                Err(VerifyError::InvalidFormat(ref msg)) if msg == "JWKS cache unavailable after refresh"
            ));
        });
    }
}