atproto_oauth/
jwt.rs

1//! JSON Web Token (JWT) minting and verification.
2//!
3//! Create and verify JWTs with JOSE standard claims supporting
4//! ES256, ES384, and ES256K signature algorithms.
5
6use anyhow::Result;
7use atproto_identity::key::{KeyData, KeyType, sign, to_public, validate};
8use base64::{Engine as _, engine::general_purpose};
9use elliptic_curve::JwkEcKey;
10use serde::{Deserialize, Serialize};
11use std::collections::BTreeMap;
12use std::time::{SystemTime, UNIX_EPOCH};
13
14use crate::encoding::ToBase64;
15use crate::errors::JWTError;
16
17#[cfg(feature = "zeroize")]
18use zeroize::{Zeroize, ZeroizeOnDrop};
19
20/// JWT header containing algorithm and key metadata.
21#[derive(Clone, Default, PartialEq, Serialize, Deserialize)]
22#[cfg_attr(debug_assertions, derive(Debug))]
23#[cfg_attr(feature = "zeroize", derive(Zeroize, ZeroizeOnDrop))]
24pub struct Header {
25    /// Algorithm used for signing (e.g., "ES256", "ES384", "ES256K").
26    #[serde(rename = "alg", skip_serializing_if = "Option::is_none")]
27    pub algorithm: Option<String>,
28
29    /// Key identifier for the signing key.
30    #[serde(rename = "kid", skip_serializing_if = "Option::is_none")]
31    pub key_id: Option<String>,
32
33    /// Token type, typically "JWT".
34    #[serde(rename = "typ", skip_serializing_if = "Option::is_none")]
35    pub type_: Option<String>,
36
37    /// Embedded JSON Web Key.
38    #[serde(rename = "jwk", skip_serializing_if = "Option::is_none")]
39    pub json_web_key: Option<JwkEcKey>,
40}
41
42impl TryFrom<KeyData> for Header {
43    type Error = anyhow::Error;
44
45    fn try_from(value: KeyData) -> std::result::Result<Self, Self::Error> {
46        let algorithm = match value.key_type() {
47            KeyType::P256Public => Some("ES256".to_string()),
48            KeyType::P256Private => Some("ES256".to_string()),
49            KeyType::P384Public => Some("ES384".to_string()),
50            KeyType::P384Private => Some("ES384".to_string()),
51            KeyType::K256Public => Some("ES256K".to_string()),
52            KeyType::K256Private => Some("ES256K".to_string()),
53        };
54
55        let public_key = to_public(&value)?;
56        let key_id = Some(public_key.to_string());
57
58        Ok(Self {
59            algorithm,
60            key_id,
61            type_: None,
62            json_web_key: None,
63        })
64    }
65}
66
67/// JWT claims combining standard JOSE claims with custom private claims.
68#[cfg_attr(debug_assertions, derive(Debug))]
69#[derive(Clone, Default, PartialEq, Serialize, Deserialize)]
70pub struct Claims {
71    /// Standard JOSE claims.
72    #[serde(flatten)]
73    pub jose: JoseClaims,
74    /// Custom private claims.
75    #[serde(flatten)]
76    pub private: BTreeMap<String, serde_json::Value>,
77}
78
79impl Claims {
80    /// Create new Claims with the given JOSE claims.
81    pub fn new(jose: JoseClaims) -> Self {
82        Claims {
83            jose,
84            private: BTreeMap::new(),
85        }
86    }
87}
88
/// Timestamp in whole seconds since the Unix epoch (1970-01-01T00:00:00Z).
///
/// Used for the `exp`, `nbf` and `iat` claims; `verify` compares these against
/// `SystemTime::now()` measured from `UNIX_EPOCH` in seconds.
pub type SecondsSinceEpoch = u64;
91
92/// Standard JOSE claims for JWT tokens.
93#[cfg_attr(debug_assertions, derive(Debug))]
94#[derive(Clone, Default, PartialEq, Serialize, Deserialize)]
95pub struct JoseClaims {
96    /// Issuer of the token.
97    #[serde(rename = "iss", skip_serializing_if = "Option::is_none")]
98    pub issuer: Option<String>,
99
100    /// Subject of the token.
101    #[serde(rename = "sub", skip_serializing_if = "Option::is_none")]
102    pub subject: Option<String>,
103
104    /// Intended audience for the token.
105    #[serde(rename = "aud", skip_serializing_if = "Option::is_none")]
106    pub audience: Option<String>,
107
108    /// Expiration time after which the token is invalid.
109    #[serde(rename = "exp", skip_serializing_if = "Option::is_none")]
110    pub expiration: Option<SecondsSinceEpoch>,
111
112    /// Time before which the token is not valid.
113    #[serde(rename = "nbf", skip_serializing_if = "Option::is_none")]
114    pub not_before: Option<SecondsSinceEpoch>,
115
116    /// Time at which the token was issued.
117    #[serde(rename = "iat", skip_serializing_if = "Option::is_none")]
118    pub issued_at: Option<SecondsSinceEpoch>,
119
120    /// Unique identifier for the token.
121    #[serde(rename = "jti", skip_serializing_if = "Option::is_none")]
122    pub json_web_token_id: Option<String>,
123
124    /// HTTP method for request binding.
125    #[serde(rename = "htm", skip_serializing_if = "Option::is_none")]
126    pub http_method: Option<String>,
127
128    /// HTTP URI for request binding.
129    #[serde(rename = "htu", skip_serializing_if = "Option::is_none")]
130    pub http_uri: Option<String>,
131
132    /// Nonce value for replay protection.
133    #[serde(rename = "nonce", skip_serializing_if = "Option::is_none")]
134    pub nonce: Option<String>,
135
136    /// Authorization token hash.
137    #[serde(rename = "ath", skip_serializing_if = "Option::is_none")]
138    pub auth: Option<String>,
139}
140
141/// Create and sign a new JWT token.
142pub fn mint(key_data: &KeyData, header: &Header, claims: &Claims) -> Result<String> {
143    let header = header.to_base64()?;
144    let claims = claims.to_base64()?;
145    let content = format!("{}.{}", header, claims);
146
147    let signature = sign(key_data, content.as_bytes())?;
148
149    Ok(format!(
150        "{}.{}",
151        content,
152        general_purpose::URL_SAFE_NO_PAD.encode(signature)
153    ))
154}
155
156/// Verify a JWT token and extract its claims.
157pub fn verify(token: &str, key_data: &KeyData) -> Result<Claims> {
158    // Split token into its parts
159    let parts: Vec<&str> = token.split('.').collect();
160    if parts.len() != 3 {
161        return Err(JWTError::InvalidFormat.into());
162    }
163
164    let encoded_header = parts[0];
165    let encoded_claims = parts[1];
166    let encoded_signature = parts[2];
167
168    // Decode header
169    let header_bytes = general_purpose::URL_SAFE_NO_PAD
170        .decode(encoded_header)
171        .map_err(|_| JWTError::InvalidHeader)?;
172
173    let header: Header =
174        serde_json::from_slice(&header_bytes).map_err(|_| JWTError::InvalidHeader)?;
175
176    match (header.algorithm.as_deref(), key_data.key_type()) {
177        (Some("ES256K"), KeyType::K256Private) | (Some("ES256K"), KeyType::K256Public) => {}
178        (Some("ES256"), KeyType::P256Private) | (Some("ES256"), KeyType::P256Public) => {}
179        (Some("ES384"), KeyType::P384Private) | (Some("ES384"), KeyType::P384Public) => {}
180        _ => {
181            return Err(JWTError::UnsupportedAlgorithm {
182                algorithm: header
183                    .algorithm
184                    .clone()
185                    .unwrap_or_else(|| "none".to_string()),
186                key_type: format!("{}", key_data.key_type()),
187            }
188            .into());
189        }
190    }
191
192    // Decode claims
193    let claims_bytes = general_purpose::URL_SAFE_NO_PAD
194        .decode(encoded_claims)
195        .map_err(|_| JWTError::InvalidClaims)?;
196
197    let claims: Claims =
198        serde_json::from_slice(&claims_bytes).map_err(|_| JWTError::InvalidClaims)?;
199
200    // Decode signature
201    let signature_bytes = general_purpose::URL_SAFE_NO_PAD
202        .decode(encoded_signature)
203        .map_err(|_| JWTError::InvalidSignature)?;
204
205    let content = format!("{}.{}", encoded_header, encoded_claims);
206
207    validate(key_data, &signature_bytes, content.as_bytes())
208        .map_err(|_| JWTError::SignatureVerificationFailed)?;
209
210    // Get current timestamp for validation
211    let now = SystemTime::now()
212        .duration_since(UNIX_EPOCH)
213        .map_err(|_| JWTError::SystemTimeError)?
214        .as_secs();
215
216    // Validate expiration time if present
217    if let Some(exp) = claims.jose.expiration
218        && now >= exp {
219            return Err(JWTError::TokenExpired.into());
220        }
221
222    // Validate not-before time if present
223    if let Some(nbf) = claims.jose.not_before
224        && now < nbf {
225            return Err(JWTError::TokenNotValidYet.into());
226        }
227
228    // Return validated claims
229    Ok(claims)
230}
231
#[cfg(test)]
mod tests {
    use super::*;
    use atproto_identity::key::{KeyType, generate_key, identify_key, to_public};

    /// Current Unix time in whole seconds, for building `exp`/`nbf` claims.
    fn now_secs() -> SecondsSinceEpoch {
        SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .expect("system clock is set before the Unix epoch")
            .as_secs()
    }

    /// Claims with fixed `iss`/`aud`, the given time bounds, and one private claim.
    fn sample_claims(
        expiration: Option<SecondsSinceEpoch>,
        not_before: Option<SecondsSinceEpoch>,
    ) -> Claims {
        let mut claims = Claims::new(JoseClaims {
            issuer: Some("did:plc:issuer".to_string()),
            audience: Some("https://example.com".to_string()),
            expiration,
            not_before,
            ..Default::default()
        });
        claims.private.insert(
            "scope".to_string(),
            serde_json::Value::String("atproto".to_string()),
        );
        claims
    }

    /// The `JWTError` wrapped by `err`, if that is what it carries.
    fn jwt_error(err: &anyhow::Error) -> Option<&JWTError> {
        err.downcast_ref::<JWTError>()
    }

    #[test]
    fn test_header_from_p256_private_key() -> Result<()> {
        let key_data = generate_key(KeyType::P256Private)?;
        let header: Header = key_data.clone().try_into()?;

        assert_eq!(header.algorithm, Some("ES256".to_string()));
        assert!(header.key_id.is_some());
        assert_eq!(header.type_, None);
        assert_eq!(header.json_web_key, None);

        // Verify key_id contains the public key DID
        let public_key = to_public(&key_data)?;
        let expected_key_id = public_key.to_string();
        assert_eq!(header.key_id, Some(expected_key_id));

        Ok(())
    }

    #[test]
    fn test_header_from_p256_public_key() -> Result<()> {
        let private_key = generate_key(KeyType::P256Private)?;
        let public_key = to_public(&private_key)?;
        let header: Header = public_key.clone().try_into()?;

        assert_eq!(header.algorithm, Some("ES256".to_string()));
        assert!(header.key_id.is_some());
        assert_eq!(header.type_, None);
        assert_eq!(header.json_web_key, None);

        // Verify key_id contains the public key DID
        let expected_key_id = public_key.to_string();
        assert_eq!(header.key_id, Some(expected_key_id));

        Ok(())
    }

    #[test]
    fn test_header_from_k256_private_key() -> Result<()> {
        let key_data = generate_key(KeyType::K256Private)?;
        let header: Header = key_data.clone().try_into()?;

        assert_eq!(header.algorithm, Some("ES256K".to_string()));
        assert!(header.key_id.is_some());
        assert_eq!(header.type_, None);
        assert_eq!(header.json_web_key, None);

        // Verify key_id contains the public key DID
        let public_key = to_public(&key_data)?;
        let expected_key_id = public_key.to_string();
        assert_eq!(header.key_id, Some(expected_key_id));

        Ok(())
    }

    #[test]
    fn test_header_from_k256_public_key() -> Result<()> {
        let private_key = generate_key(KeyType::K256Private)?;
        let public_key = to_public(&private_key)?;
        let header: Header = public_key.clone().try_into()?;

        assert_eq!(header.algorithm, Some("ES256K".to_string()));
        assert!(header.key_id.is_some());
        assert_eq!(header.type_, None);
        assert_eq!(header.json_web_key, None);

        // Verify key_id contains the public key DID
        let expected_key_id = public_key.to_string();
        assert_eq!(header.key_id, Some(expected_key_id));

        Ok(())
    }

    #[test]
    fn test_header_consistency_private_vs_public_key() -> Result<()> {
        // Test that private key and its derived public key produce headers with same key_id
        let p256_private = generate_key(KeyType::P256Private)?;
        let p256_public = to_public(&p256_private)?;

        let header_from_private: Header = p256_private.try_into()?;
        let header_from_public: Header = p256_public.try_into()?;

        assert_eq!(header_from_private.algorithm, header_from_public.algorithm);
        assert_eq!(header_from_private.key_id, header_from_public.key_id);
        assert_eq!(header_from_private.type_, header_from_public.type_);
        assert_eq!(
            header_from_private.json_web_key,
            header_from_public.json_web_key
        );

        // Test with K256 as well
        let k256_private = generate_key(KeyType::K256Private)?;
        let k256_public = to_public(&k256_private)?;

        let k256_header_from_private: Header = k256_private.try_into()?;
        let k256_header_from_public: Header = k256_public.try_into()?;

        assert_eq!(
            k256_header_from_private.algorithm,
            k256_header_from_public.algorithm
        );
        assert_eq!(
            k256_header_from_private.key_id,
            k256_header_from_public.key_id
        );
        assert_eq!(
            k256_header_from_private.type_,
            k256_header_from_public.type_
        );
        assert_eq!(
            k256_header_from_private.json_web_key,
            k256_header_from_public.json_web_key
        );

        Ok(())
    }

    #[test]
    fn test_header_from_existing_test_keys() -> Result<()> {
        // Test with known keys from the identity crate test suite
        let p256_private_key = "did:key:z42tnbHmmnhF11nwSnp5kQJbcZQw2Vbw5WF3ABDSxPtDgU2o";
        let p256_public_key = "did:key:zDnaeXduWbJ1b1Kgjf3uCdCpMDF1LEDizUiyxAxGwerou3Nh2";
        let k256_private_key = "did:key:z3vLY4nbXy2rV4Qr65gUtfnSF3A8Be7gmYzUiCX6eo2PR1Rt";
        let k256_public_key = "did:key:zQ3shNzMp4oaaQ1gQRzCxMGXFrSW3NEM1M9T6KCY9eA7HhyEA";

        // Parse the keys
        let parsed_p256_private = identify_key(p256_private_key)?;
        let parsed_p256_public = identify_key(p256_public_key)?;
        let parsed_k256_private = identify_key(k256_private_key)?;
        let parsed_k256_public = identify_key(k256_public_key)?;

        // Derive the actual public keys from the private keys for comparison
        let derived_p256_public = to_public(&parsed_p256_private)?;
        let derived_k256_public = to_public(&parsed_k256_private)?;

        // Test P256 private key
        let p256_private_header: Header = parsed_p256_private.try_into()?;
        assert_eq!(p256_private_header.algorithm, Some("ES256".to_string()));
        let expected_p256_key_id = derived_p256_public.to_string();
        assert_eq!(p256_private_header.key_id, Some(expected_p256_key_id));

        // Test P256 public key (standalone)
        let p256_public_header: Header = parsed_p256_public.try_into()?;
        assert_eq!(p256_public_header.algorithm, Some("ES256".to_string()));
        assert_eq!(p256_public_header.key_id, Some(p256_public_key.to_string()));

        // Test K256 private key
        let k256_private_header: Header = parsed_k256_private.try_into()?;
        assert_eq!(k256_private_header.algorithm, Some("ES256K".to_string()));
        let expected_k256_key_id = derived_k256_public.to_string();
        assert_eq!(k256_private_header.key_id, Some(expected_k256_key_id));

        // Test K256 public key (standalone)
        let k256_public_header: Header = parsed_k256_public.try_into()?;
        assert_eq!(k256_public_header.algorithm, Some("ES256K".to_string()));
        assert_eq!(k256_public_header.key_id, Some(k256_public_key.to_string()));

        // Test that derived public keys produce consistent headers
        let derived_p256_public_header: Header = derived_p256_public.try_into()?;
        let derived_k256_public_header: Header = derived_k256_public.try_into()?;

        assert_eq!(p256_private_header, derived_p256_public_header);
        assert_eq!(k256_private_header, derived_k256_public_header);

        Ok(())
    }

    #[test]
    fn test_header_multiple_conversions_same_key() -> Result<()> {
        // Test that multiple conversions of the same key produce identical headers
        let key_data = generate_key(KeyType::P256Private)?;

        let header1: Header = key_data.clone().try_into()?;
        let header2: Header = key_data.try_into()?;

        assert_eq!(header1, header2);

        Ok(())
    }

    #[test]
    fn test_header_different_keys_different_headers() -> Result<()> {
        // Test that different keys produce different headers
        let p256_key = generate_key(KeyType::P256Private)?;
        let k256_key = generate_key(KeyType::K256Private)?;

        let p256_header: Header = p256_key.try_into()?;
        let k256_header: Header = k256_key.try_into()?;

        // Algorithm should be different
        assert_ne!(p256_header.algorithm, k256_header.algorithm);
        assert_eq!(p256_header.algorithm, Some("ES256".to_string()));
        assert_eq!(k256_header.algorithm, Some("ES256K".to_string()));

        // Key IDs should be different (different public keys)
        assert_ne!(p256_header.key_id, k256_header.key_id);

        Ok(())
    }

    #[test]
    fn test_header_from_invalid_key_data() {
        // Test with invalid key data that would cause to_public() to fail
        let invalid_key_data = KeyData::new(KeyType::P256Private, vec![0u8; 10]); // Too short

        let result: Result<Header> = invalid_key_data.try_into();
        assert!(result.is_err());
    }

    #[test]
    fn test_header_serialization_deserialization() -> Result<()> {
        // Test that Header can be serialized and deserialized correctly
        let key_data = generate_key(KeyType::P256Private)?;
        let header: Header = key_data.try_into()?;

        // Serialize to JSON
        let json = serde_json::to_string(&header)?;

        // Deserialize back
        let deserialized_header: Header = serde_json::from_str(&json)?;

        assert_eq!(header, deserialized_header);

        Ok(())
    }

    #[test]
    fn test_header_json_field_names() -> Result<()> {
        // Test that Header uses correct JSON field names (alg, kid, typ, jwk)
        let key_data = generate_key(KeyType::P256Private)?;
        let header: Header = key_data.try_into()?;

        let json = serde_json::to_string(&header)?;
        let json_value: serde_json::Value = serde_json::from_str(&json)?;

        // Check that the correct field names are used
        assert!(json_value.get("alg").is_some());
        assert!(json_value.get("kid").is_some());
        assert!(json_value.get("typ").is_none()); // Should be None and thus omitted
        assert!(json_value.get("jwk").is_none()); // Should be None and thus omitted

        // Verify values
        assert_eq!(json_value["alg"], "ES256");
        assert!(json_value["kid"].is_string());

        Ok(())
    }

    #[test]
    fn test_header_complete_workflow() -> Result<()> {
        // 1. Generate keys for all curves.
        let p256_private = generate_key(KeyType::P256Private)?;
        let p384_private = generate_key(KeyType::P384Private)?;
        let k256_private = generate_key(KeyType::K256Private)?;
        let p256_public = to_public(&p256_private)?;
        let p384_public = to_public(&p384_private)?;
        let k256_public = to_public(&k256_private)?;

        // 2. Convert KeyData to headers.
        let p256_private_header: Header = p256_private.try_into()?;
        let p256_public_header: Header = p256_public.try_into()?;
        let p384_private_header: Header = p384_private.try_into()?;
        let p384_public_header: Header = p384_public.try_into()?;
        let k256_private_header: Header = k256_private.try_into()?;
        let k256_public_header: Header = k256_public.try_into()?;

        // 3. P-256 → ES256, P-384 → ES384, K-256 → ES256K.
        assert_eq!(p256_private_header.algorithm, Some("ES256".to_string()));
        assert_eq!(p256_public_header.algorithm, Some("ES256".to_string()));
        assert_eq!(p384_private_header.algorithm, Some("ES384".to_string()));
        assert_eq!(p384_public_header.algorithm, Some("ES384".to_string()));
        assert_eq!(k256_private_header.algorithm, Some("ES256K".to_string()));
        assert_eq!(k256_public_header.algorithm, Some("ES256K".to_string()));

        // 4. Private and public keys produce the same key_id.
        assert_eq!(p256_private_header.key_id, p256_public_header.key_id);
        assert_eq!(p384_private_header.key_id, p384_public_header.key_id);
        assert_eq!(k256_private_header.key_id, k256_public_header.key_id);

        // 5. Optional fields stay None.
        for header in [
            &p256_private_header,
            &p256_public_header,
            &p384_private_header,
            &p384_public_header,
            &k256_private_header,
            &k256_public_header,
        ] {
            assert_eq!(header.type_, None);
            assert_eq!(header.json_web_key, None);
        }

        // 6. Headers survive a JSON round trip.
        let json = serde_json::to_string(&p384_private_header)?;
        let parsed: Header = serde_json::from_str(&json)?;
        assert_eq!(p384_private_header, parsed);

        Ok(())
    }

    #[test]
    fn test_header_from_p384_private_key() -> Result<()> {
        let key_data = generate_key(KeyType::P384Private)?;
        let header: Header = key_data.clone().try_into()?;

        assert_eq!(header.algorithm, Some("ES384".to_string()));
        assert!(header.key_id.is_some());
        assert_eq!(header.type_, None);
        assert_eq!(header.json_web_key, None);

        // Verify key_id contains the public key DID
        let public_key = to_public(&key_data)?;
        let expected_key_id = public_key.to_string();
        assert_eq!(header.key_id, Some(expected_key_id));

        Ok(())
    }

    #[test]
    fn test_header_from_p384_public_key() -> Result<()> {
        let private_key = generate_key(KeyType::P384Private)?;
        let public_key = to_public(&private_key)?;
        let header: Header = public_key.clone().try_into()?;

        assert_eq!(header.algorithm, Some("ES384".to_string()));
        assert!(header.key_id.is_some());
        assert_eq!(header.type_, None);
        assert_eq!(header.json_web_key, None);

        // Verify key_id contains the public key DID
        let expected_key_id = public_key.to_string();
        assert_eq!(header.key_id, Some(expected_key_id));

        Ok(())
    }

    #[test]
    fn test_mint_verify_round_trip_all_curves() -> Result<()> {
        // A token minted with a private key verifies with that key and with its public half.
        for key_type in [
            KeyType::P256Private,
            KeyType::P384Private,
            KeyType::K256Private,
        ] {
            let private_key = generate_key(key_type)?;
            let public_key = to_public(&private_key)?;
            let header: Header = private_key.clone().try_into()?;
            let now = now_secs();
            let claims = sample_claims(Some(now + 300), Some(now - 1));

            let token = mint(&private_key, &header, &claims)?;
            assert_eq!(token.split('.').count(), 3);

            assert_eq!(verify(&token, &private_key)?, claims);
            assert_eq!(verify(&token, &public_key)?, claims);
        }
        Ok(())
    }

    #[test]
    fn test_verify_rejects_expired_token() -> Result<()> {
        let key = generate_key(KeyType::P256Private)?;
        let header: Header = key.clone().try_into()?;
        let token = mint(&key, &header, &sample_claims(Some(now_secs() - 60), None))?;

        let Err(err) = verify(&token, &key) else {
            panic!("expired token must be rejected");
        };
        assert!(matches!(jwt_error(&err), Some(JWTError::TokenExpired)));
        Ok(())
    }

    #[test]
    fn test_verify_rejects_token_not_yet_valid() -> Result<()> {
        let key = generate_key(KeyType::K256Private)?;
        let header: Header = key.clone().try_into()?;
        let token = mint(&key, &header, &sample_claims(None, Some(now_secs() + 3600)))?;

        let Err(err) = verify(&token, &key) else {
            panic!("token with future nbf must be rejected");
        };
        assert!(matches!(jwt_error(&err), Some(JWTError::TokenNotValidYet)));
        Ok(())
    }

    #[test]
    fn test_verify_rejects_algorithm_key_mismatch() -> Result<()> {
        // An ES256 token must not be accepted by a K-256 verifying key.
        let p256 = generate_key(KeyType::P256Private)?;
        let k256 = generate_key(KeyType::K256Private)?;
        let header: Header = p256.clone().try_into()?;
        let token = mint(&p256, &header, &sample_claims(None, None))?;

        let Err(err) = verify(&token, &k256) else {
            panic!("alg/key mismatch must be rejected");
        };
        assert!(matches!(
            jwt_error(&err),
            Some(JWTError::UnsupportedAlgorithm { .. })
        ));
        Ok(())
    }

    #[test]
    fn test_verify_rejects_forged_signature() -> Result<()> {
        // Graft token A's signature onto token B's header and claims.
        let key = generate_key(KeyType::P256Private)?;
        let header: Header = key.clone().try_into()?;
        let exp = now_secs() + 300;
        let token_a = mint(&key, &header, &sample_claims(Some(exp), None))?;
        let token_b = mint(&key, &header, &sample_claims(Some(exp + 1), None))?;

        let (_, signature_a) = token_a
            .rsplit_once('.')
            .expect("minted token has a signature segment");
        let (content_b, _) = token_b
            .rsplit_once('.')
            .expect("minted token has a signature segment");
        let forged = format!("{content_b}.{signature_a}");

        let Err(err) = verify(&forged, &key) else {
            panic!("forged signature must be rejected");
        };
        assert!(matches!(
            jwt_error(&err),
            Some(JWTError::SignatureVerificationFailed)
        ));
        Ok(())
    }

    #[test]
    fn test_verify_rejects_wrong_segment_count() -> Result<()> {
        let key = generate_key(KeyType::P256Private)?;
        for token in ["", "only-one-segment", "two.segments", "a.b.c.d"] {
            let Err(err) = verify(token, &key) else {
                panic!("malformed token {token:?} must be rejected");
            };
            assert!(matches!(jwt_error(&err), Some(JWTError::InvalidFormat)));
        }
        Ok(())
    }
}
587}