atproto_oauth/
jwt.rs

//! JSON Web Token (JWT) minting and verification for AT Protocol OAuth.
//!
//! Provides functionality for creating and verifying JWTs carrying standard JOSE claims plus
//! custom extensions for AT Protocol OAuth flows. Supports the ES256, ES384, and ES256K elliptic
//! curve signature algorithms, validates the `exp` and `nbf` timestamps during verification, and
//! reports failures through structured errors.
6
7use anyhow::Result;
8use atproto_identity::key::{KeyData, KeyType, sign, to_public, validate};
9use base64::{Engine as _, engine::general_purpose};
10use elliptic_curve::JwkEcKey;
11use serde::{Deserialize, Serialize};
12use std::collections::BTreeMap;
13use std::time::{SystemTime, UNIX_EPOCH};
14
15use crate::encoding::ToBase64;
16use crate::errors::JWTError;
17
18#[cfg(feature = "zeroize")]
19use zeroize::{Zeroize, ZeroizeOnDrop};
20
21/// JWT header containing algorithm and key metadata.
22#[derive(Clone, Default, PartialEq, Serialize, Deserialize)]
23#[cfg_attr(debug_assertions, derive(Debug))]
24#[cfg_attr(feature = "zeroize", derive(Zeroize, ZeroizeOnDrop))]
25pub struct Header {
26    /// Algorithm used for signing (e.g., "ES256", "ES384", "ES256K").
27    #[serde(rename = "alg", skip_serializing_if = "Option::is_none")]
28    pub algorithm: Option<String>,
29
30    /// Key identifier for the signing key.
31    #[serde(rename = "kid", skip_serializing_if = "Option::is_none")]
32    pub key_id: Option<String>,
33
34    /// Token type, typically "JWT".
35    #[serde(rename = "typ", skip_serializing_if = "Option::is_none")]
36    pub type_: Option<String>,
37
38    /// Embedded JSON Web Key.
39    #[serde(rename = "jwk", skip_serializing_if = "Option::is_none")]
40    pub json_web_key: Option<JwkEcKey>,
41}
42
43impl TryFrom<KeyData> for Header {
44    type Error = anyhow::Error;
45
46    fn try_from(value: KeyData) -> std::result::Result<Self, Self::Error> {
47        let algorithm = match value.key_type() {
48            KeyType::P256Public => Some("ES256".to_string()),
49            KeyType::P256Private => Some("ES256".to_string()),
50            KeyType::P384Public => Some("ES384".to_string()),
51            KeyType::P384Private => Some("ES384".to_string()),
52            KeyType::K256Public => Some("ES256K".to_string()),
53            KeyType::K256Private => Some("ES256K".to_string()),
54        };
55
56        let public_key = to_public(&value)?;
57        let key_id = Some(public_key.to_string());
58
59        Ok(Self {
60            algorithm,
61            key_id,
62            type_: None,
63            json_web_key: None,
64        })
65    }
66}
67
68/// JWT claims combining standard JOSE claims with custom private claims.
69#[cfg_attr(debug_assertions, derive(Debug))]
70#[derive(Clone, Default, PartialEq, Serialize, Deserialize)]
71pub struct Claims {
72    /// Standard JOSE claims.
73    #[serde(flatten)]
74    pub jose: JoseClaims,
75    /// Custom private claims.
76    #[serde(flatten)]
77    pub private: BTreeMap<String, serde_json::Value>,
78}
79
80impl Claims {
81    /// Create new Claims with the given JOSE claims.
82    pub fn new(jose: JoseClaims) -> Self {
83        Claims {
84            jose,
85            private: BTreeMap::new(),
86        }
87    }
88}
89
/// Unix timestamp in whole seconds, as carried by the `exp`, `nbf` and `iat` claims.
pub type SecondsSinceEpoch = u64;
92
93/// Standard JOSE claims for JWT tokens.
94#[cfg_attr(debug_assertions, derive(Debug))]
95#[derive(Clone, Default, PartialEq, Serialize, Deserialize)]
96pub struct JoseClaims {
97    /// Issuer of the token.
98    #[serde(rename = "iss", skip_serializing_if = "Option::is_none")]
99    pub issuer: Option<String>,
100
101    /// Subject of the token.
102    #[serde(rename = "sub", skip_serializing_if = "Option::is_none")]
103    pub subject: Option<String>,
104
105    /// Intended audience for the token.
106    #[serde(rename = "aud", skip_serializing_if = "Option::is_none")]
107    pub audience: Option<String>,
108
109    /// Expiration time after which the token is invalid.
110    #[serde(rename = "exp", skip_serializing_if = "Option::is_none")]
111    pub expiration: Option<SecondsSinceEpoch>,
112
113    /// Time before which the token is not valid.
114    #[serde(rename = "nbf", skip_serializing_if = "Option::is_none")]
115    pub not_before: Option<SecondsSinceEpoch>,
116
117    /// Time at which the token was issued.
118    #[serde(rename = "iat", skip_serializing_if = "Option::is_none")]
119    pub issued_at: Option<SecondsSinceEpoch>,
120
121    /// Unique identifier for the token.
122    #[serde(rename = "jti", skip_serializing_if = "Option::is_none")]
123    pub json_web_token_id: Option<String>,
124
125    /// HTTP method for request binding.
126    #[serde(rename = "htm", skip_serializing_if = "Option::is_none")]
127    pub http_method: Option<String>,
128
129    /// HTTP URI for request binding.
130    #[serde(rename = "htu", skip_serializing_if = "Option::is_none")]
131    pub http_uri: Option<String>,
132
133    /// Nonce value for replay protection.
134    #[serde(rename = "nonce", skip_serializing_if = "Option::is_none")]
135    pub nonce: Option<String>,
136
137    /// Authorization token hash.
138    #[serde(rename = "ath", skip_serializing_if = "Option::is_none")]
139    pub auth: Option<String>,
140}
141
142/// Create and sign a new JWT token.
143pub fn mint(key_data: &KeyData, header: &Header, claims: &Claims) -> Result<String> {
144    let header = header.to_base64()?;
145    let claims = claims.to_base64()?;
146    let content = format!("{}.{}", header, claims);
147
148    let signature = sign(key_data, content.as_bytes())?;
149
150    Ok(format!(
151        "{}.{}",
152        content,
153        general_purpose::URL_SAFE_NO_PAD.encode(signature)
154    ))
155}
156
157/// Verify a JWT token and extract its claims.
158pub fn verify(token: &str, key_data: &KeyData) -> Result<Claims> {
159    // Split token into its parts
160    let parts: Vec<&str> = token.split('.').collect();
161    if parts.len() != 3 {
162        return Err(JWTError::InvalidFormat.into());
163    }
164
165    let encoded_header = parts[0];
166    let encoded_claims = parts[1];
167    let encoded_signature = parts[2];
168
169    // Decode header
170    let header_bytes = general_purpose::URL_SAFE_NO_PAD
171        .decode(encoded_header)
172        .map_err(|_| JWTError::InvalidHeader)?;
173
174    let header: Header =
175        serde_json::from_slice(&header_bytes).map_err(|_| JWTError::InvalidHeader)?;
176
177    match (header.algorithm.as_deref(), key_data.key_type()) {
178        (Some("ES256K"), KeyType::K256Private) | (Some("ES256K"), KeyType::K256Public) => {}
179        (Some("ES256"), KeyType::P256Private) | (Some("ES256"), KeyType::P256Public) => {}
180        (Some("ES384"), KeyType::P384Private) | (Some("ES384"), KeyType::P384Public) => {}
181        _ => {
182            return Err(JWTError::UnsupportedAlgorithm {
183                algorithm: header
184                    .algorithm
185                    .clone()
186                    .unwrap_or_else(|| "none".to_string()),
187                key_type: format!("{}", key_data.key_type()),
188            }
189            .into());
190        }
191    }
192
193    // Decode claims
194    let claims_bytes = general_purpose::URL_SAFE_NO_PAD
195        .decode(encoded_claims)
196        .map_err(|_| JWTError::InvalidClaims)?;
197
198    let claims: Claims =
199        serde_json::from_slice(&claims_bytes).map_err(|_| JWTError::InvalidClaims)?;
200
201    // Decode signature
202    let signature_bytes = general_purpose::URL_SAFE_NO_PAD
203        .decode(encoded_signature)
204        .map_err(|_| JWTError::InvalidSignature)?;
205
206    let content = format!("{}.{}", encoded_header, encoded_claims);
207
208    validate(key_data, &signature_bytes, content.as_bytes())
209        .map_err(|_| JWTError::SignatureVerificationFailed)?;
210
211    // Get current timestamp for validation
212    let now = SystemTime::now()
213        .duration_since(UNIX_EPOCH)
214        .map_err(|_| JWTError::SystemTimeError)?
215        .as_secs();
216
217    // Validate expiration time if present
218    if let Some(exp) = claims.jose.expiration {
219        if now >= exp {
220            return Err(JWTError::TokenExpired.into());
221        }
222    }
223
224    // Validate not-before time if present
225    if let Some(nbf) = claims.jose.not_before {
226        if now < nbf {
227            return Err(JWTError::TokenNotValidYet.into());
228        }
229    }
230
231    // Return validated claims
232    Ok(claims)
233}
234
#[cfg(test)]
mod tests {
    use super::*;
    use atproto_identity::key::{KeyType, generate_key, identify_key, to_public};

    /// Current Unix time in whole seconds, for building `exp`/`nbf` claims.
    fn now() -> SecondsSinceEpoch {
        SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .expect("system clock is set after the Unix epoch")
            .as_secs()
    }

    /// Asserts that the header derived from `key_data` carries `expected_alg`, the string form
    /// of the key's public half as `kid`, and no `typ`/`jwk`.
    fn assert_header_for(key_data: KeyData, expected_alg: &str) -> Result<()> {
        let expected_key_id = to_public(&key_data)?.to_string();
        let header: Header = key_data.try_into()?;

        assert_eq!(header.algorithm.as_deref(), Some(expected_alg));
        assert_eq!(header.key_id, Some(expected_key_id));
        assert_eq!(header.type_, None);
        assert_eq!(header.json_web_key, None);
        Ok(())
    }

    /// Mints a token for `claims` using the header derived from `key_data`.
    fn mint_for(key_data: &KeyData, claims: &Claims) -> Result<String> {
        let header: Header = key_data.clone().try_into()?;
        mint(key_data, &header, claims)
    }

    /// Unwraps the `JWTError` from a verification that is expected to fail.
    fn expect_jwt_error(result: Result<Claims>) -> JWTError {
        result
            .expect_err("verification should have failed")
            .downcast::<JWTError>()
            .expect("failure should be a JWTError")
    }

    #[test]
    fn test_header_from_p256_private_key() -> Result<()> {
        assert_header_for(generate_key(KeyType::P256Private)?, "ES256")
    }

    #[test]
    fn test_header_from_p256_public_key() -> Result<()> {
        assert_header_for(to_public(&generate_key(KeyType::P256Private)?)?, "ES256")
    }

    #[test]
    fn test_header_from_p384_private_key() -> Result<()> {
        assert_header_for(generate_key(KeyType::P384Private)?, "ES384")
    }

    #[test]
    fn test_header_from_p384_public_key() -> Result<()> {
        assert_header_for(to_public(&generate_key(KeyType::P384Private)?)?, "ES384")
    }

    #[test]
    fn test_header_from_k256_private_key() -> Result<()> {
        assert_header_for(generate_key(KeyType::K256Private)?, "ES256K")
    }

    #[test]
    fn test_header_from_k256_public_key() -> Result<()> {
        assert_header_for(to_public(&generate_key(KeyType::K256Private)?)?, "ES256K")
    }

    #[test]
    fn test_header_consistency_private_vs_public_key() -> Result<()> {
        // A private key and its derived public key must produce identical headers.
        for key_type in [KeyType::P256Private, KeyType::P384Private, KeyType::K256Private] {
            let private_key = generate_key(key_type)?;
            let public_key = to_public(&private_key)?;

            let header_from_private: Header = private_key.try_into()?;
            let header_from_public: Header = public_key.try_into()?;

            assert_eq!(header_from_private, header_from_public);
        }
        Ok(())
    }

    #[test]
    fn test_header_from_existing_test_keys() -> Result<()> {
        // Known keys from the identity crate test suite.
        let p256_private_key = "did:key:z42tnbHmmnhF11nwSnp5kQJbcZQw2Vbw5WF3ABDSxPtDgU2o";
        let p256_public_key = "did:key:zDnaeXduWbJ1b1Kgjf3uCdCpMDF1LEDizUiyxAxGwerou3Nh2";
        let k256_private_key = "did:key:z3vLY4nbXy2rV4Qr65gUtfnSF3A8Be7gmYzUiCX6eo2PR1Rt";
        let k256_public_key = "did:key:zQ3shNzMp4oaaQ1gQRzCxMGXFrSW3NEM1M9T6KCY9eA7HhyEA";

        let parsed_p256_private = identify_key(p256_private_key)?;
        let parsed_p256_public = identify_key(p256_public_key)?;
        let parsed_k256_private = identify_key(k256_private_key)?;
        let parsed_k256_public = identify_key(k256_public_key)?;

        let derived_p256_public = to_public(&parsed_p256_private)?;
        let derived_k256_public = to_public(&parsed_k256_private)?;

        let p256_private_header: Header = parsed_p256_private.try_into()?;
        assert_eq!(p256_private_header.algorithm, Some("ES256".to_string()));
        assert_eq!(
            p256_private_header.key_id,
            Some(derived_p256_public.to_string())
        );

        let p256_public_header: Header = parsed_p256_public.try_into()?;
        assert_eq!(p256_public_header.algorithm, Some("ES256".to_string()));
        assert_eq!(p256_public_header.key_id, Some(p256_public_key.to_string()));

        let k256_private_header: Header = parsed_k256_private.try_into()?;
        assert_eq!(k256_private_header.algorithm, Some("ES256K".to_string()));
        assert_eq!(
            k256_private_header.key_id,
            Some(derived_k256_public.to_string())
        );

        let k256_public_header: Header = parsed_k256_public.try_into()?;
        assert_eq!(k256_public_header.algorithm, Some("ES256K".to_string()));
        assert_eq!(k256_public_header.key_id, Some(k256_public_key.to_string()));

        let derived_p256_public_header: Header = derived_p256_public.try_into()?;
        let derived_k256_public_header: Header = derived_k256_public.try_into()?;
        assert_eq!(p256_private_header, derived_p256_public_header);
        assert_eq!(k256_private_header, derived_k256_public_header);

        Ok(())
    }

    #[test]
    fn test_header_multiple_conversions_same_key() -> Result<()> {
        let key_data = generate_key(KeyType::P256Private)?;

        let header1: Header = key_data.clone().try_into()?;
        let header2: Header = key_data.try_into()?;

        assert_eq!(header1, header2);
        Ok(())
    }

    #[test]
    fn test_header_different_keys_different_headers() -> Result<()> {
        let p256_header: Header = generate_key(KeyType::P256Private)?.try_into()?;
        let k256_header: Header = generate_key(KeyType::K256Private)?.try_into()?;

        assert_ne!(p256_header.algorithm, k256_header.algorithm);
        assert_eq!(p256_header.algorithm, Some("ES256".to_string()));
        assert_eq!(k256_header.algorithm, Some("ES256K".to_string()));
        assert_ne!(p256_header.key_id, k256_header.key_id);
        Ok(())
    }

    #[test]
    fn test_header_from_invalid_key_data() {
        // Too short to be a P-256 private key, so deriving the public key must fail.
        let invalid_key_data = KeyData::new(KeyType::P256Private, vec![0u8; 10]);

        let result: Result<Header> = invalid_key_data.try_into();
        assert!(result.is_err());
    }

    #[test]
    fn test_header_serialization_deserialization() -> Result<()> {
        for key_type in [KeyType::P256Private, KeyType::P384Private, KeyType::K256Private] {
            let header: Header = generate_key(key_type)?.try_into()?;

            let json = serde_json::to_string(&header)?;
            let deserialized_header: Header = serde_json::from_str(&json)?;

            assert_eq!(header, deserialized_header);
        }
        Ok(())
    }

    #[test]
    fn test_header_json_field_names() -> Result<()> {
        let header: Header = generate_key(KeyType::P256Private)?.try_into()?;

        let json_value: serde_json::Value = serde_json::to_value(&header)?;

        assert_eq!(json_value["alg"], "ES256");
        assert!(json_value["kid"].is_string());
        // `None` members must be omitted, not serialized as null.
        assert!(json_value.get("typ").is_none());
        assert!(json_value.get("jwk").is_none());
        Ok(())
    }

    #[test]
    fn test_mint_verify_round_trip() -> Result<()> {
        for key_type in [KeyType::P256Private, KeyType::P384Private, KeyType::K256Private] {
            let private_key = generate_key(key_type)?;
            let public_key = to_public(&private_key)?;

            let issued_at = now();
            let mut claims = Claims::new(JoseClaims {
                issuer: Some("did:plc:issuer".to_string()),
                audience: Some("did:web:audience.example".to_string()),
                issued_at: Some(issued_at),
                expiration: Some(issued_at + 300),
                ..Default::default()
            });
            claims.private.insert(
                "scope".to_string(),
                serde_json::Value::String("atproto".to_string()),
            );

            let token = mint_for(&private_key, &claims)?;
            assert_eq!(token.split('.').count(), 3);

            assert_eq!(verify(&token, &public_key)?, claims);
            assert_eq!(verify(&token, &private_key)?, claims);
        }
        Ok(())
    }

    #[test]
    fn test_verify_rejects_expired_token() -> Result<()> {
        let key = generate_key(KeyType::P256Private)?;
        let claims = Claims::new(JoseClaims {
            expiration: Some(now() - 60),
            ..Default::default()
        });
        let token = mint_for(&key, &claims)?;

        let error = expect_jwt_error(verify(&token, &to_public(&key)?));
        assert!(matches!(error, JWTError::TokenExpired));
        Ok(())
    }

    #[test]
    fn test_verify_rejects_token_not_valid_yet() -> Result<()> {
        let key = generate_key(KeyType::P256Private)?;
        let claims = Claims::new(JoseClaims {
            not_before: Some(now() + 3600),
            ..Default::default()
        });
        let token = mint_for(&key, &claims)?;

        let error = expect_jwt_error(verify(&token, &to_public(&key)?));
        assert!(matches!(error, JWTError::TokenNotValidYet));
        Ok(())
    }

    #[test]
    fn test_verify_rejects_algorithm_mismatch() -> Result<()> {
        let p256_key = generate_key(KeyType::P256Private)?;
        let k256_public = to_public(&generate_key(KeyType::K256Private)?)?;
        let token = mint_for(&p256_key, &Claims::default())?;

        let error = expect_jwt_error(verify(&token, &k256_public));
        assert!(matches!(error, JWTError::UnsupportedAlgorithm { .. }));
        Ok(())
    }

    #[test]
    fn test_verify_rejects_missing_algorithm() -> Result<()> {
        let key = generate_key(KeyType::P256Private)?;
        let mut header: Header = key.clone().try_into()?;
        header.algorithm = None;
        let token = mint(&key, &header, &Claims::default())?;

        let error = expect_jwt_error(verify(&token, &to_public(&key)?));
        assert!(matches!(
            &error,
            JWTError::UnsupportedAlgorithm { algorithm, .. } if algorithm == "none"
        ));
        Ok(())
    }

    #[test]
    fn test_verify_rejects_tampered_claims() -> Result<()> {
        let key = generate_key(KeyType::K256Private)?;
        let public_key = to_public(&key)?;

        let honest = mint_for(
            &key,
            &Claims::new(JoseClaims {
                subject: Some("did:plc:alice".to_string()),
                ..Default::default()
            }),
        )?;
        let forged = mint_for(
            &key,
            &Claims::new(JoseClaims {
                subject: Some("did:plc:mallory".to_string()),
                ..Default::default()
            }),
        )?;

        // Graft the forged payload onto the honest header and signature.
        let honest_parts: Vec<&str> = honest.split('.').collect();
        let forged_parts: Vec<&str> = forged.split('.').collect();
        let tampered = format!(
            "{}.{}.{}",
            honest_parts[0], forged_parts[1], honest_parts[2]
        );

        let error = expect_jwt_error(verify(&tampered, &public_key));
        assert!(matches!(error, JWTError::SignatureVerificationFailed));
        Ok(())
    }

    #[test]
    fn test_verify_rejects_wrong_segment_count() -> Result<()> {
        let key = generate_key(KeyType::P256Private)?;
        let public_key = to_public(&key)?;
        let token = mint_for(&key, &Claims::default())?;
        let extended = format!("{token}.extra");

        for malformed in ["", "a.b", "a.b.c.d", extended.as_str()] {
            let error = expect_jwt_error(verify(malformed, &public_key));
            assert!(
                matches!(error, JWTError::InvalidFormat),
                "expected InvalidFormat for {malformed:?}"
            );
        }
        Ok(())
    }
}
590}