1use crate::claims::{AuthContext, SessionClaims};
2use crate::error::VerifyError;
3use crate::keys::{SigningKey, VerifyingKey};
4use base64::Engine;
5use serde::{Deserialize, Serialize};
6use serde_json;
7
/// JOSE header of the compact JWTs produced and consumed in this module.
///
/// The derived `Serialize` writes fields in declaration order, so this order
/// determines the exact encoded header bytes that end up being signed.
#[derive(Debug, Clone, Serialize, Deserialize)]
struct JwtHeader {
    /// Signature algorithm; token verification only accepts `"EdDSA"`.
    alg: String,
    /// Token type; `"JWT"` by default and not checked during verification.
    typ: String,
    /// Key identifier used by `JwksVerifier` to select a key; omitted from
    /// the serialized JSON when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    kid: Option<String>,
}
16
17impl Default for JwtHeader {
18 fn default() -> Self {
19 Self {
20 alg: "EdDSA".to_string(),
21 typ: "JWT".to_string(),
22 kid: None,
23 }
24 }
25}
26
27pub struct TokenSigner {
29 signing_key: SigningKey,
30 issuer: String,
31}
32
33impl TokenSigner {
34 pub fn new(signing_key: SigningKey, issuer: impl Into<String>) -> Self {
39 Self {
40 signing_key,
41 issuer: issuer.into(),
42 }
43 }
44
45 pub fn sign(&self, claims: SessionClaims) -> Result<String, TokenError> {
47 let header = JwtHeader {
49 kid: Some(self.signing_key.key_id()),
50 ..Default::default()
51 };
52
53 let header_json = serde_json::to_string(&header)?;
55 let header_b64 =
56 base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(header_json.as_bytes());
57
58 let claims_json = serde_json::to_string(&claims)?;
60 let claims_b64 =
61 base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(claims_json.as_bytes());
62
63 let message = format!("{}.{}", header_b64, claims_b64);
65
66 let signature = self.signing_key.sign(message.as_bytes());
68 let signature_b64 =
69 base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(signature.to_bytes());
70
71 Ok(format!("{}.{}.{}", header_b64, claims_b64, signature_b64))
73 }
74
75 pub fn issuer(&self) -> &str {
77 &self.issuer
78 }
79}
80
/// Errors returned by [`TokenSigner::sign`].
///
/// `From` conversions exist for `serde_json::Error` and
/// `base64::DecodeError`, so `?` can propagate either.
#[derive(Debug)]
pub enum TokenError {
    /// JSON serialization of the header or claims failed.
    Serialization(serde_json::Error),
    /// Base64 decoding failed (not produced by `sign` itself; presumably used
    /// by callers elsewhere — confirm).
    Base64(base64::DecodeError),
    /// The token is structurally invalid; the string explains why. Not
    /// constructed in this module.
    InvalidFormat(String),
}
88
89impl std::fmt::Display for TokenError {
90 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
91 match self {
92 TokenError::Serialization(e) => write!(f, "Serialization error: {}", e),
93 TokenError::Base64(e) => write!(f, "Base64 error: {}", e),
94 TokenError::InvalidFormat(s) => write!(f, "Invalid format: {}", s),
95 }
96 }
97}
98
99impl std::error::Error for TokenError {}
100
101impl From<serde_json::Error> for TokenError {
102 fn from(e: serde_json::Error) -> Self {
103 TokenError::Serialization(e)
104 }
105}
106
107impl From<base64::DecodeError> for TokenError {
108 fn from(e: base64::DecodeError) -> Self {
109 TokenError::Base64(e)
110 }
111}
112
113pub struct TokenVerifier {
115 verifying_key: VerifyingKey,
116 issuer: String,
117 audience: String,
118 require_origin: bool,
119 require_client_ip: bool,
120}
121
122impl TokenVerifier {
123 pub fn new(
128 verifying_key: VerifyingKey,
129 issuer: impl Into<String>,
130 audience: impl Into<String>,
131 ) -> Self {
132 Self {
133 verifying_key,
134 issuer: issuer.into(),
135 audience: audience.into(),
136 require_origin: false,
137 require_client_ip: false,
138 }
139 }
140
141 pub fn with_origin_validation(mut self) -> Self {
143 self.require_origin = true;
144 self
145 }
146
147 pub fn with_client_ip_validation(mut self) -> Self {
149 self.require_client_ip = true;
150 self
151 }
152
153 pub fn verify(
160 &self,
161 token: &str,
162 expected_origin: Option<&str>,
163 expected_client_ip: Option<&str>,
164 ) -> Result<AuthContext, VerifyError> {
165 let parts: Vec<&str> = token.split('.').collect();
167 if parts.len() != 3 {
168 return Err(VerifyError::InvalidFormat("Invalid JWT format".to_string()));
169 }
170
171 let header_b64 = parts[0];
172 let claims_b64 = parts[1];
173 let signature_b64 = parts[2];
174
175 let header_json = base64::engine::general_purpose::URL_SAFE_NO_PAD
177 .decode(header_b64)
178 .map_err(|e| VerifyError::InvalidFormat(format!("Invalid header base64: {}", e)))?;
179 let header: JwtHeader = serde_json::from_slice(&header_json)
180 .map_err(|e| VerifyError::InvalidFormat(format!("Invalid header JSON: {}", e)))?;
181
182 if header.alg != "EdDSA" {
183 return Err(VerifyError::InvalidFormat(format!(
184 "Unsupported algorithm: {}",
185 header.alg
186 )));
187 }
188
189 let claims_json = base64::engine::general_purpose::URL_SAFE_NO_PAD
191 .decode(claims_b64)
192 .map_err(|e| VerifyError::InvalidFormat(format!("Invalid claims base64: {}", e)))?;
193 let claims: SessionClaims = serde_json::from_slice(&claims_json)
194 .map_err(|e| VerifyError::InvalidFormat(format!("Invalid claims JSON: {}", e)))?;
195
196 let signature_bytes = base64::engine::general_purpose::URL_SAFE_NO_PAD
198 .decode(signature_b64)
199 .map_err(|e| VerifyError::InvalidFormat(format!("Invalid signature base64: {}", e)))?;
200 if signature_bytes.len() != 64 {
201 return Err(VerifyError::InvalidFormat(
202 "Invalid signature length".to_string(),
203 ));
204 }
205 let signature = ed25519_dalek::Signature::from_bytes(&signature_bytes.try_into().unwrap());
206
207 let message = format!("{}.{}", header_b64, claims_b64);
209 self.verifying_key
210 .verify(message.as_bytes(), &signature)
211 .map_err(|_| VerifyError::InvalidSignature)?;
212
213 if claims.iss != self.issuer {
215 return Err(VerifyError::InvalidIssuer);
216 }
217
218 if claims.aud != self.audience {
220 return Err(VerifyError::InvalidAudience);
221 }
222
223 use std::time::{SystemTime, UNIX_EPOCH};
225 let now = SystemTime::now()
226 .duration_since(UNIX_EPOCH)
227 .expect("time should not be before epoch")
228 .as_secs();
229
230 if claims.exp <= now {
231 return Err(VerifyError::Expired);
232 }
233
234 if claims.nbf > now {
235 return Err(VerifyError::NotYetValid);
236 }
237
238 let token_has_origin = claims.origin.is_some();
240 let origin_provided = expected_origin.is_some();
241
242 if token_has_origin {
243 if !origin_provided {
245 return Err(VerifyError::OriginRequired);
246 }
247
248 let expected = expected_origin.unwrap();
249 let actual = claims.origin.as_ref().unwrap();
250
251 if actual != expected {
252 return Err(VerifyError::OriginMismatch {
253 expected: expected.to_string(),
254 actual: actual.clone(),
255 });
256 }
257 } else if self.require_origin {
258 return Err(VerifyError::MissingClaim("origin".to_string()));
260 }
261
262 if self.require_client_ip {
264 if let Some(expected) = expected_client_ip {
265 match &claims.client_ip {
266 Some(actual) if actual == expected => {}
267 Some(actual) => {
268 return Err(VerifyError::OriginMismatch {
269 expected: expected.to_string(),
270 actual: actual.clone(),
271 });
272 }
273 None => {
274 return Err(VerifyError::MissingClaim("client_ip".to_string()));
275 }
276 }
277 } else if claims.client_ip.is_none() {
278 return Err(VerifyError::MissingClaim("client_ip".to_string()));
279 }
280 }
281
282 Ok(AuthContext::from_claims(claims))
283 }
284
285 pub fn issuer(&self) -> &str {
287 &self.issuer
288 }
289
290 pub fn audience(&self) -> &str {
292 &self.audience
293 }
294}
295
/// A JSON Web Key Set: the `{"keys": [...]}` document consumed by `JwksVerifier`.
#[derive(Debug, Clone, Deserialize)]
pub struct Jwks {
    /// Candidate keys; `JwksVerifier` selects one by matching `kid`.
    pub keys: Vec<Jwk>,
}
301
/// One entry of a [`Jwks`] document.
///
/// Only `kid` and `x` are read by `JwksVerifier` in this module.
#[derive(Debug, Clone, Deserialize)]
pub struct Jwk {
    /// Key type as published; not validated by `JwksVerifier` here.
    pub kty: String,
    /// Intended key use (JSON field `use`, renamed because `use` is a Rust
    /// keyword); not validated by `JwksVerifier` here.
    #[serde(rename = "use")]
    pub use_: Option<String>,
    /// Key identifier, matched against the token header's `kid`.
    pub kid: String,
    /// Public key as unpadded base64url; `JwksVerifier` requires it to decode
    /// to exactly 32 bytes.
    pub x: String,
}
310
311#[derive(Clone)]
313pub struct JwksVerifier {
314 jwks: Jwks,
315 issuer: String,
316 audience: String,
317 require_origin: bool,
318}
319
320impl JwksVerifier {
321 pub fn new(jwks: Jwks, issuer: impl Into<String>, audience: impl Into<String>) -> Self {
323 Self {
324 jwks,
325 issuer: issuer.into(),
326 audience: audience.into(),
327 require_origin: false,
328 }
329 }
330
331 pub fn with_origin_validation(mut self) -> Self {
333 self.require_origin = true;
334 self
335 }
336
337 pub fn verify(
339 &self,
340 token: &str,
341 expected_origin: Option<&str>,
342 expected_client_ip: Option<&str>,
343 ) -> Result<AuthContext, VerifyError> {
344 let parts: Vec<&str> = token.split('.').collect();
346 if parts.len() != 3 {
347 return Err(VerifyError::InvalidFormat("Invalid JWT format".to_string()));
348 }
349
350 let header_json = base64::engine::general_purpose::URL_SAFE_NO_PAD
351 .decode(parts[0])
352 .map_err(|e| VerifyError::InvalidFormat(format!("Invalid header: {}", e)))?;
353 let header: JwtHeader = serde_json::from_slice(&header_json)
354 .map_err(|e| VerifyError::InvalidFormat(format!("Invalid header JSON: {}", e)))?;
355
356 let kid = header
357 .kid
358 .ok_or_else(|| VerifyError::MissingClaim("kid".to_string()))?;
359
360 let jwk = self
362 .jwks
363 .keys
364 .iter()
365 .find(|k| k.kid == kid)
366 .ok_or(VerifyError::KeyNotFound(kid))?;
367
368 let public_key_bytes = base64::engine::general_purpose::URL_SAFE_NO_PAD
372 .decode(&jwk.x)
373 .map_err(|_| VerifyError::InvalidFormat("Invalid public key base64".to_string()))?;
374
375 let public_key: [u8; 32] = public_key_bytes
376 .try_into()
377 .map_err(|_| VerifyError::InvalidFormat("Invalid key length".to_string()))?;
378
379 let verifying_key = VerifyingKey::from_bytes(&public_key)
381 .map_err(|e| VerifyError::InvalidFormat(e.to_string()))?;
382
383 let verifier = if self.require_origin {
384 TokenVerifier::new(verifying_key, &self.issuer, &self.audience).with_origin_validation()
385 } else {
386 TokenVerifier::new(verifying_key, &self.issuer, &self.audience)
387 };
388
389 verifier.verify(token, expected_origin, expected_client_ip)
390 }
391
392 #[cfg(feature = "jwks")]
394 pub async fn fetch_jwks(url: &str) -> Result<Jwks, reqwest::Error> {
395 let response = reqwest::get(url).await?;
396 let jwks: Jwks = response.json().await?;
397 Ok(jwks)
398 }
399}
400
/// Test-only verifier stub; its configuration fields are stored but unused.
#[cfg(test)]
pub struct HmacVerifier {
    /// Expected audience (unused).
    _audience: String,
    /// Expected issuer (unused).
    _issuer: String,
    /// Shared secret (unused).
    _secret: Vec<u8>,
}
408
#[cfg(test)]
impl HmacVerifier {
    /// Builds the stub. The secret, issuer and audience are stored but never
    /// consulted by [`HmacVerifier::verify`].
    pub fn new(
        secret: impl Into<Vec<u8>>,
        issuer: impl Into<String>,
        audience: impl Into<String>,
    ) -> Self {
        let _secret = secret.into();
        let _issuer = issuer.into();
        let _audience = audience.into();
        Self {
            _audience,
            _issuer,
            _secret,
        }
    }

    /// Decodes the claims segment of a three-part token WITHOUT checking the
    /// signature, issuer, audience, expiry or origin.
    ///
    /// Errors only with `InvalidFormat` when the token does not have three
    /// segments or the claims are not valid base64url/JSON.
    pub fn verify(
        &self,
        token: &str,
        _expected_origin: Option<&str>,
    ) -> Result<AuthContext, VerifyError> {
        let segments: Vec<&str> = token.split('.').collect();
        let payload = match segments.as_slice() {
            [_, payload, _] => *payload,
            _ => return Err(VerifyError::InvalidFormat("Invalid JWT format".to_string())),
        };

        let decoded = base64::engine::general_purpose::URL_SAFE_NO_PAD
            .decode(payload)
            .map_err(|e| VerifyError::InvalidFormat(format!("Invalid claims: {}", e)))?;
        let claims = serde_json::from_slice::<SessionClaims>(&decoded)
            .map_err(|e| VerifyError::InvalidFormat(format!("Invalid claims JSON: {}", e)))?;

        Ok(AuthContext::from_claims(claims))
    }
}
448
#[cfg(test)]
mod tests {
    use super::*;
    use crate::claims::{KeyClass, Limits};
    use base64::Engine as _;

    /// Valid claims for `test-issuer` / `test-subject` / `test-audience`, 5-minute TTL.
    fn create_test_claims() -> SessionClaims {
        SessionClaims::builder("test-issuer", "test-subject", "test-audience")
            .with_ttl(300)
            .with_scope("read")
            .with_metering_key("meter-123")
            .with_key_class(KeyClass::Publishable)
            .with_limits(Limits {
                max_connections: Some(10),
                max_subscriptions: Some(100),
                max_snapshot_rows: Some(1000),
                max_messages_per_minute: Some(1000),
                max_bytes_per_minute: Some(10_000_000),
            })
            .build()
    }

    #[test]
    fn test_sign_and_verify() {
        let signing_key = crate::keys::SigningKey::generate();
        let verifying_key = signing_key.verifying_key();

        let signer = TokenSigner::new(signing_key, "test-issuer");
        let verifier = TokenVerifier::new(verifying_key, "test-issuer", "test-audience");

        let claims = create_test_claims();
        let token = signer.sign(claims.clone()).unwrap();

        let context = verifier.verify(&token, None, None).unwrap();

        assert_eq!(context.subject, "test-subject");
        assert_eq!(context.issuer, "test-issuer");
        assert_eq!(context.metering_key, "meter-123");
    }

    #[test]
    fn test_expired_token() {
        let signing_key = crate::keys::SigningKey::generate();
        let verifying_key = signing_key.verifying_key();

        let signer = TokenSigner::new(signing_key, "test-issuer");
        let verifier = TokenVerifier::new(verifying_key, "test-issuer", "test-audience");

        let claims = SessionClaims::builder("test-issuer", "test-subject", "test-audience")
            .with_ttl(0)
            .with_scope("read")
            .with_metering_key("meter-123")
            .with_key_class(KeyClass::Publishable)
            .build();

        let token = signer.sign(claims).unwrap();

        let result = verifier.verify(&token, None, None);
        assert!(matches!(result, Err(VerifyError::Expired)));
    }

    #[test]
    fn test_invalid_signature() {
        let signing_key = crate::keys::SigningKey::generate();
        let wrong_signing_key = crate::keys::SigningKey::generate();
        let wrong_verifying_key = wrong_signing_key.verifying_key();

        let signer = TokenSigner::new(signing_key, "test-issuer");
        let verifier = TokenVerifier::new(wrong_verifying_key, "test-issuer", "test-audience");

        let claims = create_test_claims();
        let token = signer.sign(claims).unwrap();

        let result = verifier.verify(&token, None, None);
        assert!(matches!(result, Err(VerifyError::InvalidSignature)));
    }

    #[test]
    fn test_wrong_issuer() {
        let signing_key = crate::keys::SigningKey::generate();
        let verifying_key = signing_key.verifying_key();

        let signer = TokenSigner::new(signing_key, "wrong-issuer");
        let verifier = TokenVerifier::new(verifying_key, "test-issuer", "test-audience");

        let claims = SessionClaims::builder("wrong-issuer", "test-subject", "test-audience")
            .with_ttl(300)
            .with_scope("read")
            .with_metering_key("meter-123")
            .with_key_class(KeyClass::Publishable)
            .build();
        let token = signer.sign(claims).unwrap();

        let result = verifier.verify(&token, None, None);
        assert!(matches!(result, Err(VerifyError::InvalidIssuer)));
    }

    #[test]
    fn test_wrong_audience() {
        let signing_key = crate::keys::SigningKey::generate();
        let verifying_key = signing_key.verifying_key();

        let signer = TokenSigner::new(signing_key, "test-issuer");
        let verifier = TokenVerifier::new(verifying_key, "test-issuer", "expected-audience");

        let claims = SessionClaims::builder("test-issuer", "test-subject", "wrong-audience")
            .with_ttl(300)
            .with_scope("read")
            .with_metering_key("meter-123")
            .with_key_class(KeyClass::Publishable)
            .build();
        let token = signer.sign(claims).unwrap();

        let result = verifier.verify(&token, None, None);
        assert!(matches!(result, Err(VerifyError::InvalidAudience)));
    }

    #[test]
    fn test_origin_mismatch() {
        let signing_key = crate::keys::SigningKey::generate();
        let verifying_key = signing_key.verifying_key();

        let signer = TokenSigner::new(signing_key, "test-issuer");
        let verifier = TokenVerifier::new(verifying_key, "test-issuer", "test-audience")
            .with_origin_validation();

        let claims = SessionClaims::builder("test-issuer", "test-subject", "test-audience")
            .with_ttl(300)
            .with_scope("read")
            .with_metering_key("meter-123")
            .with_origin("https://allowed.example")
            .with_key_class(KeyClass::Publishable)
            .build();
        let token = signer.sign(claims).unwrap();

        let result = verifier.verify(&token, Some("https://other.example"), None);
        assert!(matches!(result, Err(VerifyError::OriginMismatch { .. })));
    }

    #[test]
    fn test_origin_validation_success() {
        let signing_key = crate::keys::SigningKey::generate();
        let verifying_key = signing_key.verifying_key();

        let signer = TokenSigner::new(signing_key, "test-issuer");
        let verifier = TokenVerifier::new(verifying_key, "test-issuer", "test-audience")
            .with_origin_validation();

        let claims = SessionClaims::builder("test-issuer", "test-subject", "test-audience")
            .with_ttl(300)
            .with_scope("read")
            .with_metering_key("meter-123")
            .with_origin("https://allowed.example")
            .with_key_class(KeyClass::Publishable)
            .build();
        let token = signer.sign(claims).unwrap();

        let context = verifier
            .verify(&token, Some("https://allowed.example"), None)
            .unwrap();
        assert_eq!(context.origin.as_deref(), Some("https://allowed.example"));
    }

    #[test]
    fn test_origin_validation_requires_origin_claim() {
        let signing_key = crate::keys::SigningKey::generate();
        let verifying_key = signing_key.verifying_key();

        let signer = TokenSigner::new(signing_key, "test-issuer");
        let verifier = TokenVerifier::new(verifying_key, "test-issuer", "test-audience")
            .with_origin_validation();

        let claims = SessionClaims::builder("test-issuer", "test-subject", "test-audience")
            .with_ttl(300)
            .with_scope("read")
            .with_metering_key("meter-123")
            .with_key_class(KeyClass::Publishable)
            .build();
        let token = signer.sign(claims).unwrap();

        let result = verifier.verify(&token, None, None);
        assert!(matches!(
            result,
            Err(VerifyError::MissingClaim(ref claim)) if claim == "origin"
        ));
    }

    #[test]
    fn test_client_ip_validation_requires_client_ip_claim() {
        let signing_key = crate::keys::SigningKey::generate();
        let verifying_key = signing_key.verifying_key();

        let signer = TokenSigner::new(signing_key, "test-issuer");
        let verifier = TokenVerifier::new(verifying_key, "test-issuer", "test-audience")
            .with_client_ip_validation();

        let claims = SessionClaims::builder("test-issuer", "test-subject", "test-audience")
            .with_ttl(300)
            .with_scope("read")
            .with_metering_key("meter-123")
            .with_key_class(KeyClass::Publishable)
            .build();
        let token = signer.sign(claims).unwrap();

        let result = verifier.verify(&token, None, None);
        assert!(matches!(
            result,
            Err(VerifyError::MissingClaim(ref claim)) if claim == "client_ip"
        ));
    }

    /// An origin-bound token must be rejected when the caller supplies no
    /// origin, even if the verifier does not require origins in general.
    #[test]
    fn test_origin_bound_token_requires_origin() {
        let signing_key = crate::keys::SigningKey::generate();
        let verifying_key = signing_key.verifying_key();

        let signer = TokenSigner::new(signing_key, "test-issuer");
        let verifier = TokenVerifier::new(verifying_key, "test-issuer", "test-audience");

        let claims = SessionClaims::builder("test-issuer", "test-subject", "test-audience")
            .with_ttl(300)
            .with_scope("read")
            .with_metering_key("meter-123")
            .with_origin("https://allowed.example")
            .with_key_class(KeyClass::Publishable)
            .build();
        let token = signer.sign(claims).unwrap();

        let result = verifier.verify(&token, None, None);
        assert!(matches!(result, Err(VerifyError::OriginRequired)));
    }

    /// Swapping in another token's payload under the original signature must fail.
    #[test]
    fn test_tampered_claims_rejected() {
        let signing_key = crate::keys::SigningKey::generate();
        let verifying_key = signing_key.verifying_key();

        let signer = TokenSigner::new(signing_key, "test-issuer");
        let verifier = TokenVerifier::new(verifying_key, "test-issuer", "test-audience");

        let genuine = signer.sign(create_test_claims()).unwrap();
        let other_claims = SessionClaims::builder("test-issuer", "other-subject", "test-audience")
            .with_ttl(300)
            .with_scope("read")
            .with_metering_key("meter-123")
            .with_key_class(KeyClass::Publishable)
            .build();
        let other = signer.sign(other_claims).unwrap();

        let genuine_parts: Vec<&str> = genuine.split('.').collect();
        let other_parts: Vec<&str> = other.split('.').collect();
        let forged = format!(
            "{}.{}.{}",
            genuine_parts[0], other_parts[1], genuine_parts[2]
        );

        let result = verifier.verify(&forged, None, None);
        assert!(matches!(result, Err(VerifyError::InvalidSignature)));
    }

    /// A header advertising `alg: none` (and no signature) must be refused.
    #[test]
    fn test_unsupported_algorithm_rejected() {
        let signing_key = crate::keys::SigningKey::generate();
        let verifying_key = signing_key.verifying_key();

        let signer = TokenSigner::new(signing_key, "test-issuer");
        let verifier = TokenVerifier::new(verifying_key, "test-issuer", "test-audience");

        let token = signer.sign(create_test_claims()).unwrap();
        let parts: Vec<&str> = token.split('.').collect();
        let none_header = base64::engine::general_purpose::URL_SAFE_NO_PAD
            .encode(br#"{"alg":"none","typ":"JWT"}"#);
        let forged = format!("{}.{}.", none_header, parts[1]);

        let result = verifier.verify(&forged, None, None);
        assert!(matches!(result, Err(VerifyError::InvalidFormat(_))));
    }

    #[test]
    fn test_malformed_token_rejected() {
        let signing_key = crate::keys::SigningKey::generate();
        let verifier =
            TokenVerifier::new(signing_key.verifying_key(), "test-issuer", "test-audience");

        for token in &["", "abc", "a.b", "a.b.c.d"] {
            let result = verifier.verify(token, None, None);
            assert!(
                matches!(result, Err(VerifyError::InvalidFormat(_))),
                "token {:?} should be rejected as malformed",
                token
            );
        }
    }

    #[test]
    fn test_jwks_unknown_kid() {
        let signing_key = crate::keys::SigningKey::generate();
        let signer = TokenSigner::new(signing_key, "test-issuer");
        let verifier = JwksVerifier::new(
            Jwks { keys: Vec::new() },
            "test-issuer",
            "test-audience",
        );

        let token = signer.sign(create_test_claims()).unwrap();

        let result = verifier.verify(&token, None, None);
        assert!(matches!(result, Err(VerifyError::KeyNotFound(_))));
    }
}