1use crate::claims::{AuthContext, SessionClaims};
2use crate::error::VerifyError;
3use crate::keys::{SigningKey, VerifyingKey};
4use base64::Engine;
5use serde::{Deserialize, Serialize};
6use serde_json;
7
8#[derive(Debug, Clone, Serialize, Deserialize)]
10struct JwtHeader {
11 alg: String,
12 typ: String,
13 #[serde(skip_serializing_if = "Option::is_none")]
14 kid: Option<String>,
15}
16
17impl Default for JwtHeader {
18 fn default() -> Self {
19 Self {
20 alg: "EdDSA".to_string(),
21 typ: "JWT".to_string(),
22 kid: None,
23 }
24 }
25}
26
27pub struct TokenSigner {
29 signing_key: SigningKey,
30 issuer: String,
31}
32
33impl TokenSigner {
34 pub fn new(signing_key: SigningKey, issuer: impl Into<String>) -> Self {
39 Self {
40 signing_key,
41 issuer: issuer.into(),
42 }
43 }
44
45 pub fn sign(&self, claims: SessionClaims) -> Result<String, TokenError> {
47 let header = JwtHeader {
49 kid: Some(self.signing_key.key_id()),
50 ..Default::default()
51 };
52
53 let header_json = serde_json::to_string(&header)?;
55 let header_b64 =
56 base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(header_json.as_bytes());
57
58 let claims_json = serde_json::to_string(&claims)?;
60 let claims_b64 =
61 base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(claims_json.as_bytes());
62
63 let message = format!("{}.{}", header_b64, claims_b64);
65
66 let signature = self.signing_key.sign(message.as_bytes());
68 let signature_b64 =
69 base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(signature.to_bytes());
70
71 Ok(format!("{}.{}.{}", header_b64, claims_b64, signature_b64))
73 }
74
75 pub fn issuer(&self) -> &str {
77 &self.issuer
78 }
79}
80
81#[derive(Debug)]
83pub enum TokenError {
84 Serialization(serde_json::Error),
85 Base64(base64::DecodeError),
86 InvalidFormat(String),
87}
88
89impl std::fmt::Display for TokenError {
90 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
91 match self {
92 TokenError::Serialization(e) => write!(f, "Serialization error: {}", e),
93 TokenError::Base64(e) => write!(f, "Base64 error: {}", e),
94 TokenError::InvalidFormat(s) => write!(f, "Invalid format: {}", s),
95 }
96 }
97}
98
99impl std::error::Error for TokenError {}
100
101impl From<serde_json::Error> for TokenError {
102 fn from(e: serde_json::Error) -> Self {
103 TokenError::Serialization(e)
104 }
105}
106
107impl From<base64::DecodeError> for TokenError {
108 fn from(e: base64::DecodeError) -> Self {
109 TokenError::Base64(e)
110 }
111}
112
113pub struct TokenVerifier {
115 verifying_key: VerifyingKey,
116 issuer: String,
117 audience: String,
118 require_origin: bool,
119 require_client_ip: bool,
120}
121
122impl TokenVerifier {
123 pub fn new(
128 verifying_key: VerifyingKey,
129 issuer: impl Into<String>,
130 audience: impl Into<String>,
131 ) -> Self {
132 Self {
133 verifying_key,
134 issuer: issuer.into(),
135 audience: audience.into(),
136 require_origin: false,
137 require_client_ip: false,
138 }
139 }
140
141 pub fn with_origin_validation(mut self) -> Self {
143 self.require_origin = true;
144 self
145 }
146
147 pub fn with_client_ip_validation(mut self) -> Self {
149 self.require_client_ip = true;
150 self
151 }
152
153 pub fn verify(
160 &self,
161 token: &str,
162 expected_origin: Option<&str>,
163 expected_client_ip: Option<&str>,
164 ) -> Result<AuthContext, VerifyError> {
165 let parts: Vec<&str> = token.split('.').collect();
167 if parts.len() != 3 {
168 return Err(VerifyError::InvalidFormat("Invalid JWT format".to_string()));
169 }
170
171 let header_b64 = parts[0];
172 let claims_b64 = parts[1];
173 let signature_b64 = parts[2];
174
175 let header_json = base64::engine::general_purpose::URL_SAFE_NO_PAD
177 .decode(header_b64)
178 .map_err(|e| VerifyError::InvalidFormat(format!("Invalid header base64: {}", e)))?;
179 let header: JwtHeader = serde_json::from_slice(&header_json)
180 .map_err(|e| VerifyError::InvalidFormat(format!("Invalid header JSON: {}", e)))?;
181
182 if header.alg != "EdDSA" {
183 return Err(VerifyError::InvalidFormat(format!(
184 "Unsupported algorithm: {}",
185 header.alg
186 )));
187 }
188
189 let claims_json = base64::engine::general_purpose::URL_SAFE_NO_PAD
191 .decode(claims_b64)
192 .map_err(|e| VerifyError::InvalidFormat(format!("Invalid claims base64: {}", e)))?;
193 let claims: SessionClaims = serde_json::from_slice(&claims_json)
194 .map_err(|e| VerifyError::InvalidFormat(format!("Invalid claims JSON: {}", e)))?;
195
196 let signature_bytes = base64::engine::general_purpose::URL_SAFE_NO_PAD
198 .decode(signature_b64)
199 .map_err(|e| VerifyError::InvalidFormat(format!("Invalid signature base64: {}", e)))?;
200 if signature_bytes.len() != 64 {
201 return Err(VerifyError::InvalidFormat(
202 "Invalid signature length".to_string(),
203 ));
204 }
205 let signature = ed25519_dalek::Signature::from_bytes(&signature_bytes.try_into().unwrap());
206
207 let message = format!("{}.{}", header_b64, claims_b64);
209 self.verifying_key
210 .verify(message.as_bytes(), &signature)
211 .map_err(|_| VerifyError::InvalidSignature)?;
212
213 if claims.iss != self.issuer {
215 return Err(VerifyError::InvalidIssuer);
216 }
217
218 if claims.aud != self.audience {
220 return Err(VerifyError::InvalidAudience);
221 }
222
223 use std::time::{SystemTime, UNIX_EPOCH};
225 let now = SystemTime::now()
226 .duration_since(UNIX_EPOCH)
227 .expect("time should not be before epoch")
228 .as_secs();
229
230 if claims.exp <= now {
231 return Err(VerifyError::Expired);
232 }
233
234 if claims.nbf > now {
235 return Err(VerifyError::NotYetValid);
236 }
237
238 let token_has_origin = claims.origin.is_some();
240 let origin_provided = expected_origin.is_some();
241
242 if token_has_origin && origin_provided {
243 let expected = expected_origin.unwrap();
245 let actual = claims.origin.as_ref().unwrap();
246
247 if actual != expected {
248 return Err(VerifyError::OriginMismatch {
249 expected: expected.to_string(),
250 actual: actual.clone(),
251 });
252 }
253 } else if token_has_origin && self.require_origin {
254 return Err(VerifyError::OriginRequired {
256 token_origin: claims.origin.as_ref().unwrap().clone(),
257 });
258 } else if !token_has_origin && self.require_origin {
259 return Err(VerifyError::MissingClaim("origin".to_string()));
261 }
262 if self.require_client_ip {
267 if let Some(expected) = expected_client_ip {
268 match &claims.client_ip {
269 Some(actual) if actual == expected => {}
270 Some(actual) => {
271 return Err(VerifyError::OriginMismatch {
272 expected: expected.to_string(),
273 actual: actual.clone(),
274 });
275 }
276 None => {
277 return Err(VerifyError::MissingClaim("client_ip".to_string()));
278 }
279 }
280 } else if claims.client_ip.is_none() {
281 return Err(VerifyError::MissingClaim("client_ip".to_string()));
282 }
283 }
284
285 Ok(AuthContext::from_claims(claims))
286 }
287
288 pub fn issuer(&self) -> &str {
290 &self.issuer
291 }
292
293 pub fn audience(&self) -> &str {
295 &self.audience
296 }
297}
298
299#[derive(Debug, Clone, Deserialize)]
301pub struct Jwks {
302 pub keys: Vec<Jwk>,
303}
304
305#[derive(Debug, Clone, Deserialize)]
306pub struct Jwk {
307 pub kty: String,
308 #[serde(rename = "use")]
309 pub use_: Option<String>,
310 pub kid: String,
311 pub x: String, }
313
314#[derive(Clone)]
316pub struct JwksVerifier {
317 jwks: Jwks,
318 issuer: String,
319 audience: String,
320 require_origin: bool,
321}
322
323impl JwksVerifier {
324 pub fn new(jwks: Jwks, issuer: impl Into<String>, audience: impl Into<String>) -> Self {
326 Self {
327 jwks,
328 issuer: issuer.into(),
329 audience: audience.into(),
330 require_origin: false,
331 }
332 }
333
334 pub fn with_origin_validation(mut self) -> Self {
336 self.require_origin = true;
337 self
338 }
339
340 pub fn verify(
342 &self,
343 token: &str,
344 expected_origin: Option<&str>,
345 expected_client_ip: Option<&str>,
346 ) -> Result<AuthContext, VerifyError> {
347 let parts: Vec<&str> = token.split('.').collect();
349 if parts.len() != 3 {
350 return Err(VerifyError::InvalidFormat("Invalid JWT format".to_string()));
351 }
352
353 let header_json = base64::engine::general_purpose::URL_SAFE_NO_PAD
354 .decode(parts[0])
355 .map_err(|e| VerifyError::InvalidFormat(format!("Invalid header: {}", e)))?;
356 let header: JwtHeader = serde_json::from_slice(&header_json)
357 .map_err(|e| VerifyError::InvalidFormat(format!("Invalid header JSON: {}", e)))?;
358
359 let kid = header
360 .kid
361 .ok_or_else(|| VerifyError::MissingClaim("kid".to_string()))?;
362
363 let jwk = self
365 .jwks
366 .keys
367 .iter()
368 .find(|k| k.kid == kid)
369 .ok_or(VerifyError::KeyNotFound(kid))?;
370
371 let public_key_bytes = base64::engine::general_purpose::URL_SAFE_NO_PAD
375 .decode(&jwk.x)
376 .map_err(|_| VerifyError::InvalidFormat("Invalid public key base64".to_string()))?;
377
378 let public_key: [u8; 32] = public_key_bytes
379 .try_into()
380 .map_err(|_| VerifyError::InvalidFormat("Invalid key length".to_string()))?;
381
382 let verifying_key = VerifyingKey::from_bytes(&public_key)
384 .map_err(|e| VerifyError::InvalidFormat(e.to_string()))?;
385
386 let verifier = if self.require_origin {
387 TokenVerifier::new(verifying_key, &self.issuer, &self.audience).with_origin_validation()
388 } else {
389 TokenVerifier::new(verifying_key, &self.issuer, &self.audience)
390 };
391
392 verifier.verify(token, expected_origin, expected_client_ip)
393 }
394
395 #[cfg(feature = "jwks")]
397 pub async fn fetch_jwks(url: &str) -> Result<Jwks, reqwest::Error> {
398 let response = reqwest::get(url).await?;
399 let jwks: Jwks = response.json().await?;
400 Ok(jwks)
401 }
402}
403
/// Test-only placeholder for an HMAC-based verifier.
///
/// WARNING: this performs NO verification. The secret, issuer and audience are
/// stored but never read; no signature, expiry or claim check is made. It only
/// requires three dot-separated segments and decodes the middle one as claims.
#[cfg(test)]
pub struct HmacVerifier {
    _secret: Vec<u8>,
    _issuer: String,
    _audience: String,
}

#[cfg(test)]
impl HmacVerifier {
    /// Stores the configuration (currently unused by `verify`).
    pub fn new(
        secret: impl Into<Vec<u8>>,
        issuer: impl Into<String>,
        audience: impl Into<String>,
    ) -> Self {
        let (secret, issuer, audience) = (secret.into(), issuer.into(), audience.into());
        Self {
            _secret: secret,
            _issuer: issuer,
            _audience: audience,
        }
    }

    /// Decodes the claims segment of `token` WITHOUT checking its signature or contents.
    ///
    /// # Errors
    /// `InvalidFormat` if the token is not three segments or the claims do not decode.
    pub fn verify(
        &self,
        token: &str,
        _expected_origin: Option<&str>,
    ) -> Result<AuthContext, VerifyError> {
        let segments: Vec<&str> = token.split('.').collect();
        let payload = match segments.as_slice() {
            [_, payload, _] => *payload,
            _ => return Err(VerifyError::InvalidFormat("Invalid JWT format".to_string())),
        };

        let decoded = base64::engine::general_purpose::URL_SAFE_NO_PAD
            .decode(payload)
            .map_err(|e| VerifyError::InvalidFormat(format!("Invalid claims: {}", e)))?;
        let parsed = serde_json::from_slice::<SessionClaims>(&decoded)
            .map_err(|e| VerifyError::InvalidFormat(format!("Invalid claims JSON: {}", e)))?;

        Ok(AuthContext::from_claims(parsed))
    }
}
451
#[cfg(test)]
mod tests {
    use super::*;
    use crate::claims::{KeyClass, Limits};
    use base64::engine::general_purpose::URL_SAFE_NO_PAD;

    /// Claims for `test-subject` with a 300 s TTL and explicit per-session limits.
    fn create_test_claims() -> SessionClaims {
        SessionClaims::builder("test-issuer", "test-subject", "test-audience")
            .with_ttl(300)
            .with_scope("read")
            .with_metering_key("meter-123")
            .with_key_class(KeyClass::Publishable)
            .with_limits(Limits {
                max_connections: Some(10),
                max_subscriptions: Some(100),
                max_snapshot_rows: Some(1000),
                max_messages_per_minute: Some(1000),
                max_bytes_per_minute: Some(10_000_000),
            })
            .build()
    }

    /// A freshly generated key pair as a signer (issuer `test-issuer`) and a
    /// verifier expecting issuer `test-issuer` and audience `test-audience`.
    fn signer_and_verifier() -> (TokenSigner, TokenVerifier) {
        let signing_key = crate::keys::SigningKey::generate();
        let verifying_key = signing_key.verifying_key();
        let signer = TokenSigner::new(signing_key, "test-issuer");
        let verifier = TokenVerifier::new(verifying_key, "test-issuer", "test-audience");
        (signer, verifier)
    }

    /// Returns `token` with its dot-separated segment at `index` replaced by `replacement`.
    fn replace_segment(token: &str, index: usize, replacement: &str) -> String {
        let mut segments: Vec<&str> = token.split('.').collect();
        segments[index] = replacement;
        segments.join(".")
    }

    #[test]
    fn test_sign_and_verify() {
        let (signer, verifier) = signer_and_verifier();

        let token = signer.sign(create_test_claims()).unwrap();
        let context = verifier.verify(&token, None, None).unwrap();

        assert_eq!(context.subject, "test-subject");
        assert_eq!(context.issuer, "test-issuer");
        assert_eq!(context.metering_key, "meter-123");
    }

    #[test]
    fn test_expired_token() {
        let (signer, verifier) = signer_and_verifier();

        let claims = SessionClaims::builder("test-issuer", "test-subject", "test-audience")
            .with_ttl(0)
            .with_scope("read")
            .with_metering_key("meter-123")
            .with_key_class(KeyClass::Publishable)
            .build();
        let token = signer.sign(claims).unwrap();

        let result = verifier.verify(&token, None, None);
        assert!(matches!(result, Err(VerifyError::Expired)));
    }

    #[test]
    fn test_invalid_signature() {
        let signing_key = crate::keys::SigningKey::generate();
        let wrong_signing_key = crate::keys::SigningKey::generate();
        let wrong_verifying_key = wrong_signing_key.verifying_key();

        let signer = TokenSigner::new(signing_key, "test-issuer");
        let verifier = TokenVerifier::new(wrong_verifying_key, "test-issuer", "test-audience");

        let token = signer.sign(create_test_claims()).unwrap();

        let result = verifier.verify(&token, None, None);
        assert!(matches!(result, Err(VerifyError::InvalidSignature)));
    }

    #[test]
    fn test_wrong_issuer() {
        let signing_key = crate::keys::SigningKey::generate();
        let verifying_key = signing_key.verifying_key();

        let signer = TokenSigner::new(signing_key, "wrong-issuer");
        let verifier = TokenVerifier::new(verifying_key, "test-issuer", "test-audience");

        let claims = SessionClaims::builder("wrong-issuer", "test-subject", "test-audience")
            .with_ttl(300)
            .with_scope("read")
            .with_metering_key("meter-123")
            .with_key_class(KeyClass::Publishable)
            .build();
        let token = signer.sign(claims).unwrap();

        let result = verifier.verify(&token, None, None);
        assert!(matches!(result, Err(VerifyError::InvalidIssuer)));
    }

    #[test]
    fn test_wrong_audience() {
        let signing_key = crate::keys::SigningKey::generate();
        let verifying_key = signing_key.verifying_key();

        let signer = TokenSigner::new(signing_key, "test-issuer");
        let verifier = TokenVerifier::new(verifying_key, "test-issuer", "expected-audience");

        let claims = SessionClaims::builder("test-issuer", "test-subject", "wrong-audience")
            .with_ttl(300)
            .with_scope("read")
            .with_metering_key("meter-123")
            .with_key_class(KeyClass::Publishable)
            .build();
        let token = signer.sign(claims).unwrap();

        let result = verifier.verify(&token, None, None);
        assert!(matches!(result, Err(VerifyError::InvalidAudience)));
    }

    #[test]
    fn test_origin_mismatch() {
        let (signer, verifier) = signer_and_verifier();
        let verifier = verifier.with_origin_validation();

        let claims = SessionClaims::builder("test-issuer", "test-subject", "test-audience")
            .with_ttl(300)
            .with_scope("read")
            .with_metering_key("meter-123")
            .with_origin("https://allowed.example")
            .with_key_class(KeyClass::Publishable)
            .build();
        let token = signer.sign(claims).unwrap();

        let result = verifier.verify(&token, Some("https://other.example"), None);
        assert!(matches!(result, Err(VerifyError::OriginMismatch { .. })));
    }

    #[test]
    fn test_origin_validation_success() {
        let (signer, verifier) = signer_and_verifier();
        let verifier = verifier.with_origin_validation();

        let claims = SessionClaims::builder("test-issuer", "test-subject", "test-audience")
            .with_ttl(300)
            .with_scope("read")
            .with_metering_key("meter-123")
            .with_origin("https://allowed.example")
            .with_key_class(KeyClass::Publishable)
            .build();
        let token = signer.sign(claims).unwrap();

        let context = verifier
            .verify(&token, Some("https://allowed.example"), None)
            .unwrap();
        assert_eq!(context.origin.as_deref(), Some("https://allowed.example"));
    }

    #[test]
    fn test_origin_validation_requires_origin_claim() {
        let (signer, verifier) = signer_and_verifier();
        let verifier = verifier.with_origin_validation();

        let claims = SessionClaims::builder("test-issuer", "test-subject", "test-audience")
            .with_ttl(300)
            .with_scope("read")
            .with_metering_key("meter-123")
            .with_key_class(KeyClass::Publishable)
            .build();
        let token = signer.sign(claims).unwrap();

        let result = verifier.verify(&token, None, None);
        assert!(matches!(
            result,
            Err(VerifyError::MissingClaim(ref claim)) if claim == "origin"
        ));
    }

    #[test]
    fn test_client_ip_validation_requires_client_ip_claim() {
        let (signer, verifier) = signer_and_verifier();
        let verifier = verifier.with_client_ip_validation();

        let claims = SessionClaims::builder("test-issuer", "test-subject", "test-audience")
            .with_ttl(300)
            .with_scope("read")
            .with_metering_key("meter-123")
            .with_key_class(KeyClass::Publishable)
            .build();
        let token = signer.sign(claims).unwrap();

        let result = verifier.verify(&token, None, None);
        assert!(matches!(
            result,
            Err(VerifyError::MissingClaim(ref claim)) if claim == "client_ip"
        ));
    }

    #[test]
    fn test_client_ip_mismatch_rejected() {
        let (signer, verifier) = signer_and_verifier();
        let verifier = verifier.with_client_ip_validation();

        let mut claims = create_test_claims();
        claims.client_ip = Some("203.0.113.7".to_string());
        let token = signer.sign(claims).unwrap();

        assert!(verifier.verify(&token, None, Some("198.51.100.1")).is_err());
        assert!(verifier.verify(&token, None, Some("203.0.113.7")).is_ok());
    }

    #[test]
    fn test_origin_bound_token_allows_no_origin_when_not_required() {
        let (signer, verifier) = signer_and_verifier();

        let claims = SessionClaims::builder("test-issuer", "test-subject", "test-audience")
            .with_ttl(300)
            .with_scope("read")
            .with_metering_key("meter-123")
            .with_origin("https://example.com")
            .with_key_class(KeyClass::Publishable)
            .build();
        let token = signer.sign(claims).unwrap();

        let context = verifier.verify(&token, None, None).unwrap();
        assert_eq!(context.origin.as_deref(), Some("https://example.com"));
    }

    #[test]
    fn test_origin_bound_token_validates_when_origin_provided_even_when_not_required() {
        let (signer, verifier) = signer_and_verifier();

        let claims = SessionClaims::builder("test-issuer", "test-subject", "test-audience")
            .with_ttl(300)
            .with_scope("read")
            .with_metering_key("meter-123")
            .with_origin("https://allowed.example")
            .with_key_class(KeyClass::Publishable)
            .build();
        let token = signer.sign(claims).unwrap();

        let context = verifier
            .verify(&token, Some("https://allowed.example"), None)
            .unwrap();
        assert_eq!(context.origin.as_deref(), Some("https://allowed.example"));

        let result = verifier.verify(&token, Some("https://evil.example"), None);
        assert!(matches!(result, Err(VerifyError::OriginMismatch { .. })));
    }

    #[test]
    fn test_malformed_token_rejected() {
        let (_, verifier) = signer_and_verifier();

        for token in &["", "abc", "a.b", "a.b.c.d"] {
            let result = verifier.verify(token, None, None);
            assert!(
                matches!(result, Err(VerifyError::InvalidFormat(_))),
                "token {:?} should be rejected as malformed",
                token
            );
        }
    }

    #[test]
    fn test_unsupported_algorithm_rejected() {
        let (signer, verifier) = signer_and_verifier();
        let token = signer.sign(create_test_claims()).unwrap();

        // Algorithm-substitution attempt: same payload and signature, `alg: none` header.
        let forged_header = URL_SAFE_NO_PAD.encode(br#"{"alg":"none","typ":"JWT"}"#);
        let forged = replace_segment(&token, 0, &forged_header);

        let result = verifier.verify(&forged, None, None);
        assert!(matches!(result, Err(VerifyError::InvalidFormat(_))));
    }

    #[test]
    fn test_tampered_claims_rejected() {
        let (signer, verifier) = signer_and_verifier();
        let genuine = signer.sign(create_test_claims()).unwrap();

        let other_claims = SessionClaims::builder("test-issuer", "attacker", "test-audience")
            .with_ttl(300)
            .with_scope("read")
            .with_metering_key("meter-123")
            .with_key_class(KeyClass::Publishable)
            .build();
        let other = signer.sign(other_claims).unwrap();
        let other_payload = other.split('.').nth(1).unwrap();

        // Well-formed claims from another token spliced under the genuine signature.
        let forged = replace_segment(&genuine, 1, other_payload);

        let result = verifier.verify(&forged, None, None);
        assert!(matches!(result, Err(VerifyError::InvalidSignature)));
    }

    #[test]
    fn test_truncated_signature_rejected() {
        let (signer, verifier) = signer_and_verifier();
        let token = signer.sign(create_test_claims()).unwrap();

        let short_signature = URL_SAFE_NO_PAD.encode([0u8; 32]);
        let forged = replace_segment(&token, 2, &short_signature);

        let result = verifier.verify(&forged, None, None);
        assert!(matches!(result, Err(VerifyError::InvalidFormat(_))));
    }

    #[test]
    fn test_jwks_unknown_kid_rejected() {
        let (signer, _) = signer_and_verifier();
        let token = signer.sign(create_test_claims()).unwrap();

        let jwks = Jwks {
            keys: vec![Jwk {
                kty: "OKP".to_string(),
                use_: Some("sig".to_string()),
                kid: "no-such-key".to_string(),
                x: String::new(),
            }],
        };
        let verifier = JwksVerifier::new(jwks, "test-issuer", "test-audience");

        let result = verifier.verify(&token, None, None);
        assert!(matches!(result, Err(VerifyError::KeyNotFound(_))));
    }
}