Skip to main content

hyperstack_auth/
token.rs

1use crate::claims::{AuthContext, SessionClaims};
2use crate::error::VerifyError;
3use crate::keys::{SigningKey, VerifyingKey};
4use base64::Engine;
5use serde::{Deserialize, Serialize};
6use serde_json;
7
8/// JWT Header for EdDSA (Ed25519) tokens
9#[derive(Debug, Clone, Serialize, Deserialize)]
10struct JwtHeader {
11    alg: String,
12    typ: String,
13    #[serde(skip_serializing_if = "Option::is_none")]
14    kid: Option<String>,
15}
16
17impl Default for JwtHeader {
18    fn default() -> Self {
19        Self {
20            alg: "EdDSA".to_string(),
21            typ: "JWT".to_string(),
22            kid: None,
23        }
24    }
25}
26
27/// Token signer for issuing session tokens using Ed25519 (EdDSA)
28pub struct TokenSigner {
29    signing_key: SigningKey,
30    issuer: String,
31}
32
33impl TokenSigner {
34    /// Create a new token signer with an Ed25519 signing key
35    ///
36    /// Uses EdDSA (Ed25519) for asymmetric signing. This is the recommended
37    /// algorithm for production use as it provides better security than HMAC.
38    pub fn new(signing_key: SigningKey, issuer: impl Into<String>) -> Self {
39        Self {
40            signing_key,
41            issuer: issuer.into(),
42        }
43    }
44
45    /// Sign a session token using Ed25519
46    pub fn sign(&self, claims: SessionClaims) -> Result<String, TokenError> {
47        // Create header with key ID
48        let header = JwtHeader {
49            kid: Some(self.signing_key.key_id()),
50            ..Default::default()
51        };
52
53        // Encode header
54        let header_json = serde_json::to_string(&header)?;
55        let header_b64 =
56            base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(header_json.as_bytes());
57
58        // Encode claims
59        let claims_json = serde_json::to_string(&claims)?;
60        let claims_b64 =
61            base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(claims_json.as_bytes());
62
63        // Create message to sign
64        let message = format!("{}.{}", header_b64, claims_b64);
65
66        // Sign with Ed25519
67        let signature = self.signing_key.sign(message.as_bytes());
68        let signature_b64 =
69            base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(signature.to_bytes());
70
71        // Combine into JWT
72        Ok(format!("{}.{}.{}", header_b64, claims_b64, signature_b64))
73    }
74
75    /// Get the issuer
76    pub fn issuer(&self) -> &str {
77        &self.issuer
78    }
79}
80
81/// Token error type
82#[derive(Debug)]
83pub enum TokenError {
84    Serialization(serde_json::Error),
85    Base64(base64::DecodeError),
86    InvalidFormat(String),
87}
88
89impl std::fmt::Display for TokenError {
90    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
91        match self {
92            TokenError::Serialization(e) => write!(f, "Serialization error: {}", e),
93            TokenError::Base64(e) => write!(f, "Base64 error: {}", e),
94            TokenError::InvalidFormat(s) => write!(f, "Invalid format: {}", s),
95        }
96    }
97}
98
99impl std::error::Error for TokenError {}
100
101impl From<serde_json::Error> for TokenError {
102    fn from(e: serde_json::Error) -> Self {
103        TokenError::Serialization(e)
104    }
105}
106
107impl From<base64::DecodeError> for TokenError {
108    fn from(e: base64::DecodeError) -> Self {
109        TokenError::Base64(e)
110    }
111}
112
113/// Token verifier for validating session tokens using Ed25519 (EdDSA)
114pub struct TokenVerifier {
115    verifying_key: VerifyingKey,
116    issuer: String,
117    audience: String,
118    require_origin: bool,
119    require_client_ip: bool,
120}
121
122impl TokenVerifier {
123    /// Create a new token verifier with an Ed25519 verifying key
124    ///
125    /// Uses EdDSA (Ed25519) for asymmetric signature verification.
126    /// This is the recommended algorithm for production use.
127    pub fn new(
128        verifying_key: VerifyingKey,
129        issuer: impl Into<String>,
130        audience: impl Into<String>,
131    ) -> Self {
132        Self {
133            verifying_key,
134            issuer: issuer.into(),
135            audience: audience.into(),
136            require_origin: false,
137            require_client_ip: false,
138        }
139    }
140
141    /// Require origin validation
142    pub fn with_origin_validation(mut self) -> Self {
143        self.require_origin = true;
144        self
145    }
146
147    /// Require client IP validation
148    pub fn with_client_ip_validation(mut self) -> Self {
149        self.require_client_ip = true;
150        self
151    }
152
153    /// Verify a token and return the auth context
154    ///
155    /// # Arguments
156    /// * `token` - The JWT token to verify
157    /// * `expected_origin` - Optional expected origin for origin validation
158    /// * `expected_client_ip` - Optional expected client IP for IP binding validation
159    pub fn verify(
160        &self,
161        token: &str,
162        expected_origin: Option<&str>,
163        expected_client_ip: Option<&str>,
164    ) -> Result<AuthContext, VerifyError> {
165        // Split token into parts
166        let parts: Vec<&str> = token.split('.').collect();
167        if parts.len() != 3 {
168            return Err(VerifyError::InvalidFormat("Invalid JWT format".to_string()));
169        }
170
171        let header_b64 = parts[0];
172        let claims_b64 = parts[1];
173        let signature_b64 = parts[2];
174
175        // Decode and verify header
176        let header_json = base64::engine::general_purpose::URL_SAFE_NO_PAD
177            .decode(header_b64)
178            .map_err(|e| VerifyError::InvalidFormat(format!("Invalid header base64: {}", e)))?;
179        let header: JwtHeader = serde_json::from_slice(&header_json)
180            .map_err(|e| VerifyError::InvalidFormat(format!("Invalid header JSON: {}", e)))?;
181
182        if header.alg != "EdDSA" {
183            return Err(VerifyError::InvalidFormat(format!(
184                "Unsupported algorithm: {}",
185                header.alg
186            )));
187        }
188
189        // Decode claims
190        let claims_json = base64::engine::general_purpose::URL_SAFE_NO_PAD
191            .decode(claims_b64)
192            .map_err(|e| VerifyError::InvalidFormat(format!("Invalid claims base64: {}", e)))?;
193        let claims: SessionClaims = serde_json::from_slice(&claims_json)
194            .map_err(|e| VerifyError::InvalidFormat(format!("Invalid claims JSON: {}", e)))?;
195
196        // Decode signature
197        let signature_bytes = base64::engine::general_purpose::URL_SAFE_NO_PAD
198            .decode(signature_b64)
199            .map_err(|e| VerifyError::InvalidFormat(format!("Invalid signature base64: {}", e)))?;
200        if signature_bytes.len() != 64 {
201            return Err(VerifyError::InvalidFormat(
202                "Invalid signature length".to_string(),
203            ));
204        }
205        let signature = ed25519_dalek::Signature::from_bytes(&signature_bytes.try_into().unwrap());
206
207        // Verify signature
208        let message = format!("{}.{}", header_b64, claims_b64);
209        self.verifying_key
210            .verify(message.as_bytes(), &signature)
211            .map_err(|_| VerifyError::InvalidSignature)?;
212
213        // Check issuer
214        if claims.iss != self.issuer {
215            return Err(VerifyError::InvalidIssuer);
216        }
217
218        // Check audience
219        if claims.aud != self.audience {
220            return Err(VerifyError::InvalidAudience);
221        }
222
223        // Check expiration
224        use std::time::{SystemTime, UNIX_EPOCH};
225        let now = SystemTime::now()
226            .duration_since(UNIX_EPOCH)
227            .expect("time should not be before epoch")
228            .as_secs();
229
230        if claims.exp <= now {
231            return Err(VerifyError::Expired);
232        }
233
234        if claims.nbf > now {
235            return Err(VerifyError::NotYetValid);
236        }
237
238        // Validate origin if required or if token has origin binding
239        let token_has_origin = claims.origin.is_some();
240        let origin_provided = expected_origin.is_some();
241
242        if token_has_origin {
243            // Token is origin-bound - must provide matching origin
244            if !origin_provided {
245                return Err(VerifyError::OriginRequired);
246            }
247
248            let expected = expected_origin.unwrap();
249            let actual = claims.origin.as_ref().unwrap();
250
251            if actual != expected {
252                return Err(VerifyError::OriginMismatch {
253                    expected: expected.to_string(),
254                    actual: actual.clone(),
255                });
256            }
257        } else if self.require_origin {
258            // Verifier requires origin but token doesn't have one bound
259            return Err(VerifyError::MissingClaim("origin".to_string()));
260        }
261
262        // Validate client IP if required
263        if self.require_client_ip {
264            if let Some(expected) = expected_client_ip {
265                match &claims.client_ip {
266                    Some(actual) if actual == expected => {}
267                    Some(actual) => {
268                        return Err(VerifyError::OriginMismatch {
269                            expected: expected.to_string(),
270                            actual: actual.clone(),
271                        });
272                    }
273                    None => {
274                        return Err(VerifyError::MissingClaim("client_ip".to_string()));
275                    }
276                }
277            } else if claims.client_ip.is_none() {
278                return Err(VerifyError::MissingClaim("client_ip".to_string()));
279            }
280        }
281
282        Ok(AuthContext::from_claims(claims))
283    }
284
285    /// Get the expected issuer
286    pub fn issuer(&self) -> &str {
287        &self.issuer
288    }
289
290    /// Get the expected audience
291    pub fn audience(&self) -> &str {
292        &self.audience
293    }
294}
295
296/// JWKS structure for key rotation
297#[derive(Debug, Clone, Deserialize)]
298pub struct Jwks {
299    pub keys: Vec<Jwk>,
300}
301
302#[derive(Debug, Clone, Deserialize)]
303pub struct Jwk {
304    pub kty: String,
305    #[serde(rename = "use")]
306    pub use_: Option<String>,
307    pub kid: String,
308    pub x: String, // Base64-encoded public key
309}
310
311/// Token verifier with JWKS support for key rotation
312#[derive(Clone)]
313pub struct JwksVerifier {
314    jwks: Jwks,
315    issuer: String,
316    audience: String,
317    require_origin: bool,
318}
319
320impl JwksVerifier {
321    /// Create a new JWKS verifier
322    pub fn new(jwks: Jwks, issuer: impl Into<String>, audience: impl Into<String>) -> Self {
323        Self {
324            jwks,
325            issuer: issuer.into(),
326            audience: audience.into(),
327            require_origin: false,
328        }
329    }
330
331    /// Require origin validation
332    pub fn with_origin_validation(mut self) -> Self {
333        self.require_origin = true;
334        self
335    }
336
337    /// Verify a token using the appropriate key from JWKS
338    pub fn verify(
339        &self,
340        token: &str,
341        expected_origin: Option<&str>,
342        expected_client_ip: Option<&str>,
343    ) -> Result<AuthContext, VerifyError> {
344        // Decode header to get kid
345        let parts: Vec<&str> = token.split('.').collect();
346        if parts.len() != 3 {
347            return Err(VerifyError::InvalidFormat("Invalid JWT format".to_string()));
348        }
349
350        let header_json = base64::engine::general_purpose::URL_SAFE_NO_PAD
351            .decode(parts[0])
352            .map_err(|e| VerifyError::InvalidFormat(format!("Invalid header: {}", e)))?;
353        let header: JwtHeader = serde_json::from_slice(&header_json)
354            .map_err(|e| VerifyError::InvalidFormat(format!("Invalid header JSON: {}", e)))?;
355
356        let kid = header
357            .kid
358            .ok_or_else(|| VerifyError::MissingClaim("kid".to_string()))?;
359
360        // Find the key
361        let jwk = self
362            .jwks
363            .keys
364            .iter()
365            .find(|k| k.kid == kid)
366            .ok_or(VerifyError::KeyNotFound(kid))?;
367
368        // Decode the public key from hex (first 16 chars of hex = 8 bytes of key id)
369        // Actually, we need to decode the full public key from the JWKS
370        // The JWKS should contain the full base64-encoded public key
371        let public_key_bytes = base64::engine::general_purpose::URL_SAFE_NO_PAD
372            .decode(&jwk.x)
373            .map_err(|_| VerifyError::InvalidFormat("Invalid public key base64".to_string()))?;
374
375        let public_key: [u8; 32] = public_key_bytes
376            .try_into()
377            .map_err(|_| VerifyError::InvalidFormat("Invalid key length".to_string()))?;
378
379        // Create verifier for this key
380        let verifying_key = VerifyingKey::from_bytes(&public_key)
381            .map_err(|e| VerifyError::InvalidFormat(e.to_string()))?;
382
383        let verifier = if self.require_origin {
384            TokenVerifier::new(verifying_key, &self.issuer, &self.audience).with_origin_validation()
385        } else {
386            TokenVerifier::new(verifying_key, &self.issuer, &self.audience)
387        };
388
389        verifier.verify(token, expected_origin, expected_client_ip)
390    }
391
392    /// Fetch JWKS from a URL
393    #[cfg(feature = "jwks")]
394    pub async fn fetch_jwks(url: &str) -> Result<Jwks, reqwest::Error> {
395        let response = reqwest::get(url).await?;
396        let jwks: Jwks = response.json().await?;
397        Ok(jwks)
398    }
399}
400
401#[cfg(test)]
402/// HMAC-based verifier for tests only
403pub struct HmacVerifier {
404    _secret: Vec<u8>,
405    _issuer: String,
406    _audience: String,
407}
408
409#[cfg(test)]
410impl HmacVerifier {
411    /// Create a new HMAC verifier (dev only)
412    pub fn new(
413        secret: impl Into<Vec<u8>>,
414        issuer: impl Into<String>,
415        audience: impl Into<String>,
416    ) -> Self {
417        Self {
418            _secret: secret.into(),
419            _issuer: issuer.into(),
420            _audience: audience.into(),
421        }
422    }
423
424    /// Verify a token using HMAC
425    pub fn verify(
426        &self,
427        token: &str,
428        _expected_origin: Option<&str>,
429    ) -> Result<AuthContext, VerifyError> {
430        // Split token
431        let parts: Vec<&str> = token.split('.').collect();
432        if parts.len() != 3 {
433            return Err(VerifyError::InvalidFormat("Invalid JWT format".to_string()));
434        }
435
436        // For HMAC, we'd need to verify the HMAC signature
437        // This is a simplified implementation - in practice you'd use hmac-sha256
438        // For now, just decode the claims without verification (dev only!)
439        let claims_json = base64::engine::general_purpose::URL_SAFE_NO_PAD
440            .decode(parts[1])
441            .map_err(|e| VerifyError::InvalidFormat(format!("Invalid claims: {}", e)))?;
442        let claims: SessionClaims = serde_json::from_slice(&claims_json)
443            .map_err(|e| VerifyError::InvalidFormat(format!("Invalid claims JSON: {}", e)))?;
444
445        Ok(AuthContext::from_claims(claims))
446    }
447}
448
449#[cfg(test)]
450mod tests {
451    use super::*;
452    use crate::claims::{KeyClass, Limits};
453
454    fn create_test_claims() -> SessionClaims {
455        SessionClaims::builder("test-issuer", "test-subject", "test-audience")
456            .with_ttl(300)
457            .with_scope("read")
458            .with_metering_key("meter-123")
459            .with_key_class(KeyClass::Publishable)
460            .with_limits(Limits {
461                max_connections: Some(10),
462                max_subscriptions: Some(100),
463                max_snapshot_rows: Some(1000),
464                max_messages_per_minute: Some(1000),
465                max_bytes_per_minute: Some(10_000_000),
466            })
467            .build()
468    }
469
470    #[test]
471    fn test_sign_and_verify() {
472        // Generate keys
473        let signing_key = crate::keys::SigningKey::generate();
474        let verifying_key = signing_key.verifying_key();
475
476        // Create signer and verifier
477        let signer = TokenSigner::new(signing_key, "test-issuer");
478        let verifier = TokenVerifier::new(verifying_key, "test-issuer", "test-audience");
479
480        // Sign token
481        let claims = create_test_claims();
482        let token = signer.sign(claims.clone()).unwrap();
483
484        // Verify token
485        let context = verifier.verify(&token, None, None).unwrap();
486
487        assert_eq!(context.subject, "test-subject");
488        assert_eq!(context.issuer, "test-issuer");
489        assert_eq!(context.metering_key, "meter-123");
490    }
491
492    #[test]
493    fn test_expired_token() {
494        let signing_key = crate::keys::SigningKey::generate();
495        let verifying_key = signing_key.verifying_key();
496
497        let signer = TokenSigner::new(signing_key, "test-issuer");
498        let verifier = TokenVerifier::new(verifying_key, "test-issuer", "test-audience");
499
500        // Create expired claims
501        let claims = SessionClaims::builder("test-issuer", "test-subject", "test-audience")
502            .with_ttl(0) // Already expired
503            .with_scope("read")
504            .with_metering_key("meter-123")
505            .with_key_class(KeyClass::Publishable)
506            .build();
507
508        let token = signer.sign(claims).unwrap();
509
510        // Should fail with expired error
511        let result = verifier.verify(&token, None, None);
512        assert!(matches!(result, Err(VerifyError::Expired)));
513    }
514
515    #[test]
516    fn test_invalid_signature() {
517        let signing_key = crate::keys::SigningKey::generate();
518        let wrong_signing_key = crate::keys::SigningKey::generate();
519        let wrong_verifying_key = wrong_signing_key.verifying_key();
520
521        let signer = TokenSigner::new(signing_key, "test-issuer");
522        let verifier = TokenVerifier::new(wrong_verifying_key, "test-issuer", "test-audience");
523
524        let claims = create_test_claims();
525        let token = signer.sign(claims).unwrap();
526
527        // Should fail with invalid signature
528        let result = verifier.verify(&token, None, None);
529        assert!(matches!(result, Err(VerifyError::InvalidSignature)));
530    }
531
532    #[test]
533    fn test_wrong_issuer() {
534        let signing_key = crate::keys::SigningKey::generate();
535        let verifying_key = signing_key.verifying_key();
536
537        let signer = TokenSigner::new(signing_key, "wrong-issuer");
538        let verifier = TokenVerifier::new(verifying_key, "test-issuer", "test-audience");
539
540        // Create claims with the wrong issuer
541        let claims = SessionClaims::builder("wrong-issuer", "test-subject", "test-audience")
542            .with_ttl(300)
543            .with_scope("read")
544            .with_metering_key("meter-123")
545            .with_key_class(KeyClass::Publishable)
546            .build();
547        let token = signer.sign(claims).unwrap();
548
549        // Should fail with invalid issuer
550        let result = verifier.verify(&token, None, None);
551        assert!(matches!(result, Err(VerifyError::InvalidIssuer)));
552    }
553
554    #[test]
555    fn test_wrong_audience() {
556        let signing_key = crate::keys::SigningKey::generate();
557        let verifying_key = signing_key.verifying_key();
558
559        let signer = TokenSigner::new(signing_key, "test-issuer");
560        let verifier = TokenVerifier::new(verifying_key, "test-issuer", "expected-audience");
561
562        let claims = SessionClaims::builder("test-issuer", "test-subject", "wrong-audience")
563            .with_ttl(300)
564            .with_scope("read")
565            .with_metering_key("meter-123")
566            .with_key_class(KeyClass::Publishable)
567            .build();
568        let token = signer.sign(claims).unwrap();
569
570        let result = verifier.verify(&token, None, None);
571        assert!(matches!(result, Err(VerifyError::InvalidAudience)));
572    }
573
574    #[test]
575    fn test_origin_mismatch() {
576        let signing_key = crate::keys::SigningKey::generate();
577        let verifying_key = signing_key.verifying_key();
578
579        let signer = TokenSigner::new(signing_key, "test-issuer");
580        let verifier = TokenVerifier::new(verifying_key, "test-issuer", "test-audience")
581            .with_origin_validation();
582
583        let claims = SessionClaims::builder("test-issuer", "test-subject", "test-audience")
584            .with_ttl(300)
585            .with_scope("read")
586            .with_metering_key("meter-123")
587            .with_origin("https://allowed.example")
588            .with_key_class(KeyClass::Publishable)
589            .build();
590        let token = signer.sign(claims).unwrap();
591
592        let result = verifier.verify(&token, Some("https://other.example"), None);
593        assert!(matches!(result, Err(VerifyError::OriginMismatch { .. })));
594    }
595
596    #[test]
597    fn test_origin_validation_success() {
598        let signing_key = crate::keys::SigningKey::generate();
599        let verifying_key = signing_key.verifying_key();
600
601        let signer = TokenSigner::new(signing_key, "test-issuer");
602        let verifier = TokenVerifier::new(verifying_key, "test-issuer", "test-audience")
603            .with_origin_validation();
604
605        let claims = SessionClaims::builder("test-issuer", "test-subject", "test-audience")
606            .with_ttl(300)
607            .with_scope("read")
608            .with_metering_key("meter-123")
609            .with_origin("https://allowed.example")
610            .with_key_class(KeyClass::Publishable)
611            .build();
612        let token = signer.sign(claims).unwrap();
613
614        let context = verifier
615            .verify(&token, Some("https://allowed.example"), None)
616            .unwrap();
617        assert_eq!(context.origin.as_deref(), Some("https://allowed.example"));
618    }
619
620    #[test]
621    fn test_origin_validation_requires_origin_claim() {
622        let signing_key = crate::keys::SigningKey::generate();
623        let verifying_key = signing_key.verifying_key();
624
625        let signer = TokenSigner::new(signing_key, "test-issuer");
626        let verifier = TokenVerifier::new(verifying_key, "test-issuer", "test-audience")
627            .with_origin_validation();
628
629        let claims = SessionClaims::builder("test-issuer", "test-subject", "test-audience")
630            .with_ttl(300)
631            .with_scope("read")
632            .with_metering_key("meter-123")
633            .with_key_class(KeyClass::Publishable)
634            .build();
635        let token = signer.sign(claims).unwrap();
636
637        let result = verifier.verify(&token, None, None);
638        assert!(matches!(
639            result,
640            Err(VerifyError::MissingClaim(ref claim)) if claim == "origin"
641        ));
642    }
643
644    #[test]
645    fn test_client_ip_validation_requires_client_ip_claim() {
646        let signing_key = crate::keys::SigningKey::generate();
647        let verifying_key = signing_key.verifying_key();
648
649        let signer = TokenSigner::new(signing_key, "test-issuer");
650        let verifier = TokenVerifier::new(verifying_key, "test-issuer", "test-audience")
651            .with_client_ip_validation();
652
653        let claims = SessionClaims::builder("test-issuer", "test-subject", "test-audience")
654            .with_ttl(300)
655            .with_scope("read")
656            .with_metering_key("meter-123")
657            .with_key_class(KeyClass::Publishable)
658            .build();
659        let token = signer.sign(claims).unwrap();
660
661        let result = verifier.verify(&token, None, None);
662        assert!(matches!(
663            result,
664            Err(VerifyError::MissingClaim(ref claim)) if claim == "client_ip"
665        ));
666    }
667}