atproto_oauth/
jwt.rs

1//! JSON Web Token (JWT) minting and verification.
2//!
//! Create and verify JWTs that carry standard JOSE claims, signed with the
//! ES256 (P-256), ES384 (P-384), or ES256K (secp256k1) algorithms.
5
6use anyhow::Result;
7use atproto_identity::key::{KeyData, KeyType, sign, to_public, validate};
8use base64::{Engine as _, engine::general_purpose};
9use elliptic_curve::JwkEcKey;
10use serde::{Deserialize, Serialize};
11use std::collections::BTreeMap;
12use std::time::{SystemTime, UNIX_EPOCH};
13
14use crate::encoding::ToBase64;
15use crate::errors::JWTError;
16
17#[cfg(feature = "zeroize")]
18use zeroize::{Zeroize, ZeroizeOnDrop};
19
20/// JWT header containing algorithm and key metadata.
21#[derive(Clone, Default, PartialEq, Serialize, Deserialize)]
22#[cfg_attr(debug_assertions, derive(Debug))]
23#[cfg_attr(feature = "zeroize", derive(Zeroize, ZeroizeOnDrop))]
24pub struct Header {
25    /// Algorithm used for signing (e.g., "ES256", "ES384", "ES256K").
26    #[serde(rename = "alg", skip_serializing_if = "Option::is_none")]
27    pub algorithm: Option<String>,
28
29    /// Key identifier for the signing key.
30    #[serde(rename = "kid", skip_serializing_if = "Option::is_none")]
31    pub key_id: Option<String>,
32
33    /// Token type, typically "JWT".
34    #[serde(rename = "typ", skip_serializing_if = "Option::is_none")]
35    pub type_: Option<String>,
36
37    /// Embedded JSON Web Key.
38    #[serde(rename = "jwk", skip_serializing_if = "Option::is_none")]
39    pub json_web_key: Option<JwkEcKey>,
40}
41
42impl TryFrom<KeyData> for Header {
43    type Error = anyhow::Error;
44
45    fn try_from(value: KeyData) -> std::result::Result<Self, Self::Error> {
46        let algorithm = match value.key_type() {
47            KeyType::P256Public => Some("ES256".to_string()),
48            KeyType::P256Private => Some("ES256".to_string()),
49            KeyType::P384Public => Some("ES384".to_string()),
50            KeyType::P384Private => Some("ES384".to_string()),
51            KeyType::K256Public => Some("ES256K".to_string()),
52            KeyType::K256Private => Some("ES256K".to_string()),
53        };
54
55        let public_key = to_public(&value)?;
56        let key_id = Some(public_key.to_string());
57
58        Ok(Self {
59            algorithm,
60            key_id,
61            type_: None,
62            json_web_key: None,
63        })
64    }
65}
66
67/// JWT claims combining standard JOSE claims with custom private claims.
68#[cfg_attr(debug_assertions, derive(Debug))]
69#[derive(Clone, Default, PartialEq, Serialize, Deserialize)]
70pub struct Claims {
71    /// Standard JOSE claims.
72    #[serde(flatten)]
73    pub jose: JoseClaims,
74    /// Custom private claims.
75    #[serde(flatten)]
76    pub private: BTreeMap<String, serde_json::Value>,
77}
78
79impl Claims {
80    /// Create new Claims with the given JOSE claims.
81    pub fn new(jose: JoseClaims) -> Self {
82        Claims {
83            jose,
84            private: BTreeMap::new(),
85        }
86    }
87}
88
/// Whole seconds since the Unix epoch (1970-01-01T00:00:00Z).
///
/// This is the unit of the `exp`, `nbf` and `iat` claims. Being unsigned, it
/// cannot represent fractional or pre-epoch times.
pub type SecondsSinceEpoch = u64;
91
92/// Standard JOSE claims for JWT tokens.
93#[cfg_attr(debug_assertions, derive(Debug))]
94#[derive(Clone, Default, PartialEq, Serialize, Deserialize)]
95pub struct JoseClaims {
96    /// Issuer of the token.
97    #[serde(rename = "iss", skip_serializing_if = "Option::is_none")]
98    pub issuer: Option<String>,
99
100    /// Subject of the token.
101    #[serde(rename = "sub", skip_serializing_if = "Option::is_none")]
102    pub subject: Option<String>,
103
104    /// Intended audience for the token.
105    #[serde(rename = "aud", skip_serializing_if = "Option::is_none")]
106    pub audience: Option<String>,
107
108    /// Expiration time after which the token is invalid.
109    #[serde(rename = "exp", skip_serializing_if = "Option::is_none")]
110    pub expiration: Option<SecondsSinceEpoch>,
111
112    /// Time before which the token is not valid.
113    #[serde(rename = "nbf", skip_serializing_if = "Option::is_none")]
114    pub not_before: Option<SecondsSinceEpoch>,
115
116    /// Time at which the token was issued.
117    #[serde(rename = "iat", skip_serializing_if = "Option::is_none")]
118    pub issued_at: Option<SecondsSinceEpoch>,
119
120    /// Unique identifier for the token.
121    #[serde(rename = "jti", skip_serializing_if = "Option::is_none")]
122    pub json_web_token_id: Option<String>,
123
124    /// HTTP method for request binding.
125    #[serde(rename = "htm", skip_serializing_if = "Option::is_none")]
126    pub http_method: Option<String>,
127
128    /// HTTP URI for request binding.
129    #[serde(rename = "htu", skip_serializing_if = "Option::is_none")]
130    pub http_uri: Option<String>,
131
132    /// Nonce value for replay protection.
133    #[serde(rename = "nonce", skip_serializing_if = "Option::is_none")]
134    pub nonce: Option<String>,
135
136    /// Authorization token hash.
137    #[serde(rename = "ath", skip_serializing_if = "Option::is_none")]
138    pub auth: Option<String>,
139}
140
141/// Create and sign a new JWT token.
142pub fn mint(key_data: &KeyData, header: &Header, claims: &Claims) -> Result<String> {
143    let header = header.to_base64()?;
144    let claims = claims.to_base64()?;
145    let content = format!("{}.{}", header, claims);
146
147    let signature = sign(key_data, content.as_bytes())?;
148
149    Ok(format!(
150        "{}.{}",
151        content,
152        general_purpose::URL_SAFE_NO_PAD.encode(signature)
153    ))
154}
155
156/// Verify a JWT token and extract its claims.
157pub fn verify(token: &str, key_data: &KeyData) -> Result<Claims> {
158    // Split token into its parts
159    let parts: Vec<&str> = token.split('.').collect();
160    if parts.len() != 3 {
161        return Err(JWTError::InvalidFormat.into());
162    }
163
164    let encoded_header = parts[0];
165    let encoded_claims = parts[1];
166    let encoded_signature = parts[2];
167
168    // Decode header
169    let header_bytes = general_purpose::URL_SAFE_NO_PAD
170        .decode(encoded_header)
171        .map_err(|_| JWTError::InvalidHeader)?;
172
173    let header: Header =
174        serde_json::from_slice(&header_bytes).map_err(|_| JWTError::InvalidHeader)?;
175
176    match (header.algorithm.as_deref(), key_data.key_type()) {
177        (Some("ES256K"), KeyType::K256Private) | (Some("ES256K"), KeyType::K256Public) => {}
178        (Some("ES256"), KeyType::P256Private) | (Some("ES256"), KeyType::P256Public) => {}
179        (Some("ES384"), KeyType::P384Private) | (Some("ES384"), KeyType::P384Public) => {}
180        _ => {
181            return Err(JWTError::UnsupportedAlgorithm {
182                algorithm: header
183                    .algorithm
184                    .clone()
185                    .unwrap_or_else(|| "none".to_string()),
186                key_type: format!("{}", key_data.key_type()),
187            }
188            .into());
189        }
190    }
191
192    // Decode claims
193    let claims_bytes = general_purpose::URL_SAFE_NO_PAD
194        .decode(encoded_claims)
195        .map_err(|_| JWTError::InvalidClaims)?;
196
197    let claims: Claims =
198        serde_json::from_slice(&claims_bytes).map_err(|_| JWTError::InvalidClaims)?;
199
200    // Decode signature
201    let signature_bytes = general_purpose::URL_SAFE_NO_PAD
202        .decode(encoded_signature)
203        .map_err(|_| JWTError::InvalidSignature)?;
204
205    let content = format!("{}.{}", encoded_header, encoded_claims);
206
207    validate(key_data, &signature_bytes, content.as_bytes())
208        .map_err(|_| JWTError::SignatureVerificationFailed)?;
209
210    // Get current timestamp for validation
211    let now = SystemTime::now()
212        .duration_since(UNIX_EPOCH)
213        .map_err(|_| JWTError::SystemTimeError)?
214        .as_secs();
215
216    // Validate expiration time if present
217    if let Some(exp) = claims.jose.expiration
218        && now >= exp
219    {
220        return Err(JWTError::TokenExpired.into());
221    }
222
223    // Validate not-before time if present
224    if let Some(nbf) = claims.jose.not_before
225        && now < nbf
226    {
227        return Err(JWTError::TokenNotValidYet.into());
228    }
229
230    // Return validated claims
231    Ok(claims)
232}
233
#[cfg(test)]
mod tests {
    use super::*;
    use atproto_identity::key::{KeyType, generate_key, identify_key, to_public};
    use base64::Engine as _;

    // ------------------------------------------------------------------
    // Helpers
    // ------------------------------------------------------------------

    /// Current Unix time in whole seconds, used to build `exp`/`nbf` claims.
    fn now_secs() -> SecondsSinceEpoch {
        SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .expect("system clock is set after the Unix epoch")
            .as_secs()
    }

    /// Asserts the shape of a header produced by `Header::try_from(KeyData)`.
    fn assert_header(header: &Header, expected_alg: &str, expected_kid: &str) {
        assert_eq!(header.algorithm.as_deref(), Some(expected_alg));
        assert_eq!(header.key_id.as_deref(), Some(expected_kid));
        assert_eq!(header.type_, None);
        assert_eq!(header.json_web_key, None);
    }

    /// Checks the header derived from a freshly generated private key.
    fn check_private_key_header(key_type: KeyType, expected_alg: &str) -> Result<()> {
        let key_data = generate_key(key_type)?;
        let expected_kid = to_public(&key_data)?.to_string();
        let header: Header = key_data.try_into()?;
        assert_header(&header, expected_alg, &expected_kid);
        Ok(())
    }

    /// Checks the header derived from the public half of a freshly generated key.
    fn check_public_key_header(private_type: KeyType, expected_alg: &str) -> Result<()> {
        let public_key = to_public(&generate_key(private_type)?)?;
        let expected_kid = public_key.to_string();
        let header: Header = public_key.try_into()?;
        assert_header(&header, expected_alg, &expected_kid);
        Ok(())
    }

    /// Claims that are valid right now, including one private claim.
    fn sample_claims() -> Claims {
        let now = now_secs();
        let mut claims = Claims::new(JoseClaims {
            issuer: Some("did:plc:issuer".to_string()),
            audience: Some("did:web:audience.example".to_string()),
            issued_at: Some(now),
            not_before: Some(now.saturating_sub(1)),
            expiration: Some(now + 300),
            json_web_token_id: Some("jti-1".to_string()),
            ..Default::default()
        });
        claims
            .private
            .insert("scope".to_string(), serde_json::json!("atproto"));
        claims
    }

    /// Mints `claims` with a fresh key of `key_type`; returns the token and the
    /// matching public key.
    fn mint_with_fresh_key(key_type: KeyType, claims: &Claims) -> Result<(String, KeyData)> {
        let private_key = generate_key(key_type)?;
        let header = Header::try_from(private_key.clone())?;
        let token = mint(&private_key, &header, claims)?;
        Ok((token, to_public(&private_key)?))
    }

    /// Runs `verify` and returns its error, panicking if the token was accepted.
    fn expect_verify_error(token: &str, key_data: &KeyData) -> anyhow::Error {
        match verify(token, key_data) {
            Ok(_) => panic!("verify unexpectedly accepted token {token:?}"),
            Err(err) => err,
        }
    }

    // ------------------------------------------------------------------
    // Header::try_from(KeyData)
    // ------------------------------------------------------------------

    #[test]
    fn test_header_from_p256_private_key() -> Result<()> {
        check_private_key_header(KeyType::P256Private, "ES256")
    }

    #[test]
    fn test_header_from_p256_public_key() -> Result<()> {
        check_public_key_header(KeyType::P256Private, "ES256")
    }

    #[test]
    fn test_header_from_k256_private_key() -> Result<()> {
        check_private_key_header(KeyType::K256Private, "ES256K")
    }

    #[test]
    fn test_header_from_k256_public_key() -> Result<()> {
        check_public_key_header(KeyType::K256Private, "ES256K")
    }

    #[test]
    fn test_header_from_p384_private_key() -> Result<()> {
        check_private_key_header(KeyType::P384Private, "ES384")
    }

    #[test]
    fn test_header_from_p384_public_key() -> Result<()> {
        check_public_key_header(KeyType::P384Private, "ES384")
    }

    #[test]
    fn test_header_consistency_private_vs_public_key() -> Result<()> {
        // A private key and its derived public key must yield identical headers.
        for key_type in [
            KeyType::P256Private,
            KeyType::P384Private,
            KeyType::K256Private,
        ] {
            let private_key = generate_key(key_type)?;
            let public_key = to_public(&private_key)?;

            let header_from_private: Header = private_key.try_into()?;
            let header_from_public: Header = public_key.try_into()?;

            assert_eq!(header_from_private, header_from_public);
        }
        Ok(())
    }

    #[test]
    fn test_header_from_existing_test_keys() -> Result<()> {
        // Known keys from the identity crate test suite.
        let p256_private_key = "did:key:z42tnbHmmnhF11nwSnp5kQJbcZQw2Vbw5WF3ABDSxPtDgU2o";
        let p256_public_key = "did:key:zDnaeXduWbJ1b1Kgjf3uCdCpMDF1LEDizUiyxAxGwerou3Nh2";
        let k256_private_key = "did:key:z3vLY4nbXy2rV4Qr65gUtfnSF3A8Be7gmYzUiCX6eo2PR1Rt";
        let k256_public_key = "did:key:zQ3shNzMp4oaaQ1gQRzCxMGXFrSW3NEM1M9T6KCY9eA7HhyEA";

        let parsed_p256_private = identify_key(p256_private_key)?;
        let parsed_p256_public = identify_key(p256_public_key)?;
        let parsed_k256_private = identify_key(k256_private_key)?;
        let parsed_k256_public = identify_key(k256_public_key)?;

        let derived_p256_public = to_public(&parsed_p256_private)?;
        let derived_k256_public = to_public(&parsed_k256_private)?;

        let p256_private_header: Header = parsed_p256_private.try_into()?;
        assert_header(&p256_private_header, "ES256", &derived_p256_public.to_string());

        let p256_public_header: Header = parsed_p256_public.try_into()?;
        assert_header(&p256_public_header, "ES256", p256_public_key);

        let k256_private_header: Header = parsed_k256_private.try_into()?;
        assert_header(&k256_private_header, "ES256K", &derived_k256_public.to_string());

        let k256_public_header: Header = parsed_k256_public.try_into()?;
        assert_header(&k256_public_header, "ES256K", k256_public_key);

        // Headers from the derived public keys must match the private-key headers.
        let derived_p256_public_header: Header = derived_p256_public.try_into()?;
        let derived_k256_public_header: Header = derived_k256_public.try_into()?;
        assert_eq!(p256_private_header, derived_p256_public_header);
        assert_eq!(k256_private_header, derived_k256_public_header);

        Ok(())
    }

    #[test]
    fn test_header_multiple_conversions_same_key() -> Result<()> {
        let key_data = generate_key(KeyType::P256Private)?;

        let header1: Header = key_data.clone().try_into()?;
        let header2: Header = key_data.try_into()?;

        assert_eq!(header1, header2);
        Ok(())
    }

    #[test]
    fn test_header_different_keys_different_headers() -> Result<()> {
        let p256_header: Header = generate_key(KeyType::P256Private)?.try_into()?;
        let k256_header: Header = generate_key(KeyType::K256Private)?.try_into()?;

        assert_ne!(p256_header.algorithm, k256_header.algorithm);
        assert_eq!(p256_header.algorithm.as_deref(), Some("ES256"));
        assert_eq!(k256_header.algorithm.as_deref(), Some("ES256K"));

        // Different public keys must give different key ids.
        assert_ne!(p256_header.key_id, k256_header.key_id);
        Ok(())
    }

    #[test]
    fn test_header_from_invalid_key_data() {
        // Too short to be a P-256 private key, so `to_public` must fail.
        let invalid_key_data = KeyData::new(KeyType::P256Private, vec![0u8; 10]);

        let result: Result<Header> = invalid_key_data.try_into();
        assert!(result.is_err());
    }

    #[test]
    fn test_header_serialization_deserialization() -> Result<()> {
        let header: Header = generate_key(KeyType::P256Private)?.try_into()?;

        let json = serde_json::to_string(&header)?;
        let deserialized_header: Header = serde_json::from_str(&json)?;

        assert_eq!(header, deserialized_header);
        Ok(())
    }

    #[test]
    fn test_header_json_field_names() -> Result<()> {
        let header: Header = generate_key(KeyType::P256Private)?.try_into()?;

        let json_value: serde_json::Value = serde_json::to_value(&header)?;

        assert!(json_value.get("alg").is_some());
        assert!(json_value.get("kid").is_some());
        // `None` members are skipped entirely.
        assert!(json_value.get("typ").is_none());
        assert!(json_value.get("jwk").is_none());

        assert_eq!(json_value["alg"], "ES256");
        assert!(json_value["kid"].is_string());
        Ok(())
    }

    #[test]
    fn test_header_complete_workflow() -> Result<()> {
        for (key_type, expected_alg) in [
            (KeyType::P256Private, "ES256"),
            (KeyType::P384Private, "ES384"),
            (KeyType::K256Private, "ES256K"),
        ] {
            let private_key = generate_key(key_type)?;
            let public_key = to_public(&private_key)?;
            let expected_kid = public_key.to_string();

            let private_header: Header = private_key.try_into()?;
            let public_header: Header = public_key.try_into()?;

            assert_header(&private_header, expected_alg, &expected_kid);
            assert_header(&public_header, expected_alg, &expected_kid);

            // Headers must survive a JSON round trip unchanged.
            let json = serde_json::to_string(&private_header)?;
            let parsed: Header = serde_json::from_str(&json)?;
            assert_eq!(private_header, parsed);
        }
        Ok(())
    }

    // ------------------------------------------------------------------
    // mint / verify
    // ------------------------------------------------------------------

    #[test]
    fn test_mint_verify_round_trip_all_curves() -> Result<()> {
        for key_type in [
            KeyType::P256Private,
            KeyType::P384Private,
            KeyType::K256Private,
        ] {
            let claims = sample_claims();
            let (token, public_key) = mint_with_fresh_key(key_type, &claims)?;

            assert_eq!(token.split('.').count(), 3);

            let verified = verify(&token, &public_key)?;
            assert!(verified == claims, "claims changed across mint/verify");
        }
        Ok(())
    }

    #[test]
    fn test_verify_rejects_expired_token() -> Result<()> {
        let claims = Claims::new(JoseClaims {
            expiration: Some(now_secs().saturating_sub(60)),
            ..Default::default()
        });
        let (token, public_key) = mint_with_fresh_key(KeyType::P256Private, &claims)?;

        let err = expect_verify_error(&token, &public_key);
        assert!(
            matches!(err.downcast_ref::<JWTError>(), Some(JWTError::TokenExpired)),
            "unexpected error: {err:?}"
        );
        Ok(())
    }

    #[test]
    fn test_verify_rejects_token_not_yet_valid() -> Result<()> {
        let claims = Claims::new(JoseClaims {
            not_before: Some(now_secs() + 3600),
            ..Default::default()
        });
        let (token, public_key) = mint_with_fresh_key(KeyType::K256Private, &claims)?;

        let err = expect_verify_error(&token, &public_key);
        assert!(
            matches!(err.downcast_ref::<JWTError>(), Some(JWTError::TokenNotValidYet)),
            "unexpected error: {err:?}"
        );
        Ok(())
    }

    #[test]
    fn test_verify_rejects_algorithm_curve_mismatch() -> Result<()> {
        let (token, _) = mint_with_fresh_key(KeyType::P256Private, &sample_claims())?;
        let k256_public = to_public(&generate_key(KeyType::K256Private)?)?;

        let err = expect_verify_error(&token, &k256_public);
        assert!(
            matches!(
                err.downcast_ref::<JWTError>(),
                Some(JWTError::UnsupportedAlgorithm { .. })
            ),
            "unexpected error: {err:?}"
        );
        Ok(())
    }

    #[test]
    fn test_verify_rejects_alg_none() -> Result<()> {
        let public_key = to_public(&generate_key(KeyType::P256Private)?)?;
        let header = general_purpose::URL_SAFE_NO_PAD.encode(br#"{"alg":"none"}"#);
        let claims = general_purpose::URL_SAFE_NO_PAD.encode(br#"{"sub":"did:plc:attacker"}"#);
        // Unsigned token: empty third segment.
        let token = format!("{header}.{claims}.");

        let err = expect_verify_error(&token, &public_key);
        assert!(
            matches!(
                err.downcast_ref::<JWTError>(),
                Some(JWTError::UnsupportedAlgorithm { algorithm, .. }) if algorithm == "none"
            ),
            "unexpected error: {err:?}"
        );
        Ok(())
    }

    #[test]
    fn test_verify_rejects_wrong_key_same_curve() -> Result<()> {
        let (token, _) = mint_with_fresh_key(KeyType::P256Private, &sample_claims())?;
        let other_public = to_public(&generate_key(KeyType::P256Private)?)?;

        let err = expect_verify_error(&token, &other_public);
        assert!(
            matches!(
                err.downcast_ref::<JWTError>(),
                Some(JWTError::SignatureVerificationFailed)
            ),
            "unexpected error: {err:?}"
        );
        Ok(())
    }

    #[test]
    fn test_verify_rejects_forged_payload() -> Result<()> {
        let (token, public_key) = mint_with_fresh_key(KeyType::P256Private, &sample_claims())?;

        // Swap in different (well-formed) claims but keep the original signature.
        let forged_claims =
            general_purpose::URL_SAFE_NO_PAD.encode(br#"{"sub":"did:plc:attacker"}"#);
        let mut segments: Vec<&str> = token.split('.').collect();
        segments[1] = &forged_claims;
        let forged = segments.join(".");

        let err = expect_verify_error(&forged, &public_key);
        assert!(
            matches!(
                err.downcast_ref::<JWTError>(),
                Some(JWTError::SignatureVerificationFailed)
            ),
            "unexpected error: {err:?}"
        );
        Ok(())
    }

    #[test]
    fn test_verify_rejects_wrong_segment_count() -> Result<()> {
        let public_key = to_public(&generate_key(KeyType::P256Private)?)?;

        for token in ["", "a", "a.b", "a.b.c.d"] {
            let err = expect_verify_error(token, &public_key);
            assert!(
                matches!(err.downcast_ref::<JWTError>(), Some(JWTError::InvalidFormat)),
                "unexpected error for {token:?}: {err:?}"
            );
        }
        Ok(())
    }
}