1use crate::claims::AuthContext;
2use crate::error::VerifyError;
3use crate::keys::VerifyingKey;
4use crate::token::{JwksVerifier, TokenVerifier};
5use std::sync::Arc;
6use std::time::{Duration, Instant};
7use tokio::sync::RwLock;
8
9#[derive(Clone)]
11struct CachedJwks {
12 verifier: JwksVerifier,
13 fetched_at: Instant,
14}
15
16pub struct AsyncVerifier {
18 inner: VerifierInner,
19 jwks_url: Option<String>,
20 cache_duration: Duration,
21 cached_jwks: Arc<RwLock<Option<CachedJwks>>>,
22 issuer: String,
24 audience: String,
26 require_origin: bool,
27}
28
29enum VerifierInner {
30 Static(TokenVerifier),
31 Jwks(JwksVerifier),
32}
33
34impl AsyncVerifier {
35 pub fn with_static_key(
37 key: VerifyingKey,
38 issuer: impl Into<String>,
39 audience: impl Into<String>,
40 ) -> Self {
41 let issuer_str = issuer.into();
42 let audience_str = audience.into();
43 Self {
44 inner: VerifierInner::Static(TokenVerifier::new(
45 key,
46 issuer_str.clone(),
47 audience_str.clone(),
48 )),
49 jwks_url: None,
50 cache_duration: Duration::from_secs(3600), cached_jwks: Arc::new(RwLock::new(None)),
52 issuer: issuer_str,
53 audience: audience_str,
54 require_origin: false,
55 }
56 }
57
58 pub fn with_jwks(
60 jwks: crate::token::Jwks,
61 issuer: impl Into<String>,
62 audience: impl Into<String>,
63 ) -> Self {
64 let issuer_str = issuer.into();
65 let audience_str = audience.into();
66 Self {
67 inner: VerifierInner::Jwks(JwksVerifier::new(
68 jwks,
69 issuer_str.clone(),
70 audience_str.clone(),
71 )),
72 jwks_url: None,
73 cache_duration: Duration::from_secs(3600),
74 cached_jwks: Arc::new(RwLock::new(None)),
75 issuer: issuer_str,
76 audience: audience_str,
77 require_origin: false,
78 }
79 }
80
81 #[cfg(feature = "jwks")]
83 pub fn with_jwks_url(
84 url: impl Into<String>,
85 issuer: impl Into<String>,
86 audience: impl Into<String>,
87 ) -> Self {
88 let issuer_str = issuer.into();
89 let audience_str = audience.into();
90 Self {
91 inner: VerifierInner::Static(TokenVerifier::new(
92 VerifyingKey::from_bytes(&[0u8; 32]).expect("zero key should be valid"),
93 issuer_str.clone(),
94 audience_str.clone(),
95 )),
96 jwks_url: Some(url.into()),
97 issuer: issuer_str,
98 audience: audience_str,
99 cache_duration: Duration::from_secs(3600),
100 cached_jwks: Arc::new(RwLock::new(None)),
101 require_origin: false,
102 }
103 }
104
105 pub fn with_origin_validation(mut self) -> Self {
107 self.require_origin = true;
108 self.inner = match self.inner {
109 VerifierInner::Static(verifier) => {
110 VerifierInner::Static(verifier.with_origin_validation())
111 }
112 VerifierInner::Jwks(verifier) => VerifierInner::Jwks(verifier.with_origin_validation()),
113 };
114 self
115 }
116
117 pub fn with_cache_duration(mut self, duration: Duration) -> Self {
119 self.cache_duration = duration;
120 self
121 }
122
123 #[cfg(feature = "jwks")]
125 pub async fn verify(
126 &self,
127 token: &str,
128 expected_origin: Option<&str>,
129 expected_client_ip: Option<&str>,
130 ) -> Result<AuthContext, VerifyError> {
131 match &self.inner {
133 VerifierInner::Static(verifier) => {
134 verifier.verify(token, expected_origin, expected_client_ip)
135 }
136 VerifierInner::Jwks(verifier) => {
137 verifier.verify(token, expected_origin, expected_client_ip)
138 }
139 }
140 }
141
142 #[cfg(not(feature = "jwks"))]
144 pub fn verify(
145 &self,
146 token: &str,
147 expected_origin: Option<&str>,
148 expected_client_ip: Option<&str>,
149 ) -> Result<AuthContext, VerifyError> {
150 match &self.inner {
151 VerifierInner::Static(verifier) => {
152 verifier.verify(token, expected_origin, expected_client_ip)
153 }
154 VerifierInner::Jwks(verifier) => {
155 verifier.verify(token, expected_origin, expected_client_ip)
156 }
157 }
158 }
159
160 #[cfg(feature = "jwks")]
162 pub async fn refresh_cache(&self) -> Result<(), VerifyError> {
163 if let Some(ref jwks_url) = self.jwks_url {
164 let jwks = crate::token::JwksVerifier::fetch_jwks(jwks_url)
166 .await
167 .map_err(|e| VerifyError::InvalidFormat(format!("Failed to fetch JWKS: {}", e)))?;
168
169 let verifier = if self.require_origin {
171 JwksVerifier::new(jwks, &self.issuer, &self.audience).with_origin_validation()
172 } else {
173 JwksVerifier::new(jwks, &self.issuer, &self.audience)
174 };
175
176 let mut cached = self.cached_jwks.write().await;
178 *cached = Some(CachedJwks {
179 verifier,
180 fetched_at: Instant::now(),
181 });
182 }
183 Ok(())
184 }
185
186 async fn get_cached_verifier(&self) -> Option<JwksVerifier> {
188 let cached = self.cached_jwks.read().await;
189 if let Some(ref cached_jwks) = *cached {
190 if cached_jwks.fetched_at.elapsed() < self.cache_duration {
191 return Some(cached_jwks.verifier.clone());
192 }
193 }
194 None
195 }
196
197 #[cfg(feature = "jwks")]
199 pub async fn verify_with_cache(
200 &self,
201 token: &str,
202 expected_origin: Option<&str>,
203 expected_client_ip: Option<&str>,
204 ) -> Result<AuthContext, VerifyError> {
205 if let Some(verifier) = self.get_cached_verifier().await {
207 match verifier.verify(token, expected_origin, expected_client_ip) {
208 Ok(ctx) => return Ok(ctx),
209 Err(VerifyError::KeyNotFound(_)) => {
210 }
212 Err(e) => return Err(e),
213 }
214 }
215
216 self.refresh_cache().await?;
218
219 if let Some(verifier) = self.get_cached_verifier().await {
220 verifier.verify(token, expected_origin, expected_client_ip)
221 } else if self.jwks_url.is_some() {
222 Err(VerifyError::InvalidFormat(
223 "JWKS cache unavailable after refresh".to_string(),
224 ))
225 } else {
226 match &self.inner {
228 VerifierInner::Static(verifier) => {
229 verifier.verify(token, expected_origin, expected_client_ip)
230 }
231 VerifierInner::Jwks(verifier) => {
232 verifier.verify(token, expected_origin, expected_client_ip)
233 }
234 }
235 }
236 }
237}
238
239pub struct SimpleVerifier {
241 inner: TokenVerifier,
242}
243
244impl SimpleVerifier {
245 pub fn new(key: VerifyingKey, issuer: impl Into<String>, audience: impl Into<String>) -> Self {
247 Self {
248 inner: TokenVerifier::new(key, issuer, audience),
249 }
250 }
251
252 pub fn verify(
254 &self,
255 token: &str,
256 expected_origin: Option<&str>,
257 expected_client_ip: Option<&str>,
258 ) -> Result<AuthContext, VerifyError> {
259 self.inner
260 .verify(token, expected_origin, expected_client_ip)
261 }
262}
263
#[cfg(test)]
mod tests {
    use super::*;
    use crate::claims::{KeyClass, SessionClaims};
    use crate::keys::SigningKey;
    use crate::token::TokenSigner;

    // Only the JWKS test needs these; gating them avoids unused-import warnings when the `jwks`
    // feature is disabled.
    #[cfg(feature = "jwks")]
    use base64::Engine;
    #[cfg(feature = "jwks")]
    use tokio::io::{AsyncReadExt, AsyncWriteExt};

    #[tokio::test]
    async fn test_async_verifier_with_static_key() {
        let signing_key = SigningKey::generate();
        let verifying_key = signing_key.verifying_key();

        let signer = TokenSigner::new(signing_key, "test-issuer");
        let verifier =
            AsyncVerifier::with_static_key(verifying_key, "test-issuer", "test-audience");

        let claims = SessionClaims::builder("test-issuer", "test-subject", "test-audience")
            .with_scope("read")
            .with_metering_key("meter-123")
            .with_key_class(KeyClass::Publishable)
            .build();

        let token = signer.sign(claims).unwrap();

        // `AsyncVerifier::verify` is only `async` when the `jwks` feature is enabled; awaiting it
        // unconditionally broke compilation of this test without the feature.
        #[cfg(feature = "jwks")]
        let context = verifier.verify(&token, None, None).await.unwrap();
        #[cfg(not(feature = "jwks"))]
        let context = verifier.verify(&token, None, None).unwrap();

        assert_eq!(context.subject, "test-subject");
    }

    #[test]
    fn test_simple_verifier() {
        let signing_key = SigningKey::generate();
        let verifying_key = signing_key.verifying_key();

        let signer = TokenSigner::new(signing_key, "test-issuer");
        let verifier = SimpleVerifier::new(verifying_key, "test-issuer", "test-audience");

        let claims = SessionClaims::builder("test-issuer", "test-subject", "test-audience")
            .with_scope("read")
            .with_metering_key("meter-123")
            .with_key_class(KeyClass::Publishable)
            .build();

        let token = signer.sign(claims).unwrap();
        let context = verifier.verify(&token, None, None).unwrap();

        assert_eq!(context.subject, "test-subject");
        assert_eq!(context.metering_key, "meter-123");
    }

    #[cfg(feature = "jwks")]
    #[test]
    fn test_verify_with_cache_returns_explicit_error_when_cache_stays_empty() {
        tokio::runtime::Runtime::new().unwrap().block_on(async {
            let signing_key = SigningKey::generate();
            let verifying_key = signing_key.verifying_key();
            let signer = TokenSigner::new(signing_key, "test-issuer");

            let jwks = serde_json::json!({
                "keys": [{
                    "kty": "OKP",
                    "use": "sig",
                    "kid": verifying_key.key_id(),
                    "x": base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(verifying_key.to_bytes()),
                }]
            })
            .to_string();

            // Minimal one-shot HTTP server that answers the single JWKS fetch.
            let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap();
            let addr = listener.local_addr().unwrap();
            let response_body = jwks;
            tokio::spawn(async move {
                let (mut socket, _) = listener.accept().await.unwrap();
                let mut buffer = [0u8; 1024];
                let _ = socket.read(&mut buffer).await;

                let response = format!(
                    "HTTP/1.1 200 OK\r\ncontent-type: application/json\r\ncontent-length: {}\r\nconnection: close\r\n\r\n{}",
                    response_body.len(),
                    response_body
                );
                socket.write_all(response.as_bytes()).await.unwrap();
            });

            // A zero cache duration makes the freshly fetched entry stale immediately.
            let verifier = AsyncVerifier::with_jwks_url(
                format!("http://{addr}/jwks"),
                "test-issuer",
                "test-audience",
            )
            .with_cache_duration(Duration::ZERO);

            let claims = SessionClaims::builder("test-issuer", "test-subject", "test-audience")
                .with_scope("read")
                .with_metering_key("meter-123")
                .with_key_class(KeyClass::Publishable)
                .build();
            let token = signer.sign(claims).unwrap();

            let result = verifier.verify_with_cache(&token, None, None).await;
            assert!(matches!(
                result,
                Err(VerifyError::InvalidFormat(ref msg)) if msg == "JWKS cache unavailable after refresh"
            ));
        });
    }
}