atproto_oauth/
jwt.rs

//! JSON Web Token (JWT) minting and verification for AT Protocol OAuth.
//!
//! Provides functionality for creating and verifying JWTs with JOSE standard claims and
//! custom extensions for AT Protocol OAuth flows. Supports the ES256, ES384, and ES256K elliptic
//! curve signature algorithms, validates the `exp` and `nbf` timestamp claims, and reports
//! failures through structured errors.
6
7use crate::encoding::ToBase64;
8use crate::errors::JWTError;
9use anyhow::Result;
10use atproto_identity::key::{KeyData, KeyType, sign, to_public, validate};
11use base64::{Engine as _, engine::general_purpose};
12use elliptic_curve::JwkEcKey;
13use serde::{Deserialize, Serialize};
14use std::collections::BTreeMap;
15use std::time::{SystemTime, UNIX_EPOCH};
16
17/// JWT header containing algorithm and key metadata.
18#[cfg_attr(debug_assertions, derive(Debug))]
19#[derive(Clone, Default, PartialEq, Serialize, Deserialize)]
20pub struct Header {
21    /// Algorithm used for signing (e.g., "ES256", "ES384", "ES256K").
22    #[serde(rename = "alg", skip_serializing_if = "Option::is_none")]
23    pub algorithm: Option<String>,
24
25    /// Key identifier for the signing key.
26    #[serde(rename = "kid", skip_serializing_if = "Option::is_none")]
27    pub key_id: Option<String>,
28
29    /// Token type, typically "JWT".
30    #[serde(rename = "typ", skip_serializing_if = "Option::is_none")]
31    pub type_: Option<String>,
32
33    /// Embedded JSON Web Key.
34    #[serde(rename = "jwk", skip_serializing_if = "Option::is_none")]
35    pub json_web_key: Option<JwkEcKey>,
36}
37
38impl TryFrom<KeyData> for Header {
39    type Error = anyhow::Error;
40
41    fn try_from(value: KeyData) -> std::result::Result<Self, Self::Error> {
42        let algorithm = match value.key_type() {
43            KeyType::P256Public => Some("ES256".to_string()),
44            KeyType::P256Private => Some("ES256".to_string()),
45            KeyType::P384Public => Some("ES384".to_string()),
46            KeyType::P384Private => Some("ES384".to_string()),
47            KeyType::K256Public => Some("ES256K".to_string()),
48            KeyType::K256Private => Some("ES256K".to_string()),
49        };
50
51        let public_key = to_public(&value)?;
52        let key_id = Some(public_key.to_string());
53
54        Ok(Self {
55            algorithm,
56            key_id,
57            type_: None,
58            json_web_key: None,
59        })
60    }
61}
62
63/// JWT claims combining standard JOSE claims with custom private claims.
64#[cfg_attr(debug_assertions, derive(Debug))]
65#[derive(Clone, Default, PartialEq, Serialize, Deserialize)]
66pub struct Claims {
67    /// Standard JOSE claims.
68    #[serde(flatten)]
69    pub jose: JoseClaims,
70    /// Custom private claims.
71    #[serde(flatten)]
72    pub private: BTreeMap<String, serde_json::Value>,
73}
74
75impl Claims {
76    /// Create new Claims with the given JOSE claims.
77    pub fn new(jose: JoseClaims) -> Self {
78        Claims {
79            jose,
80            private: BTreeMap::new(),
81        }
82    }
83}
84
85/// Type alias for timestamp values representing seconds since Unix epoch.
86pub type SecondsSinceEpoch = u64;
87
88/// Standard JOSE claims for JWT tokens.
89#[cfg_attr(debug_assertions, derive(Debug))]
90#[derive(Clone, Default, PartialEq, Serialize, Deserialize)]
91pub struct JoseClaims {
92    /// Issuer of the token.
93    #[serde(rename = "iss", skip_serializing_if = "Option::is_none")]
94    pub issuer: Option<String>,
95
96    /// Subject of the token.
97    #[serde(rename = "sub", skip_serializing_if = "Option::is_none")]
98    pub subject: Option<String>,
99
100    /// Intended audience for the token.
101    #[serde(rename = "aud", skip_serializing_if = "Option::is_none")]
102    pub audience: Option<String>,
103
104    /// Expiration time after which the token is invalid.
105    #[serde(rename = "exp", skip_serializing_if = "Option::is_none")]
106    pub expiration: Option<SecondsSinceEpoch>,
107
108    /// Time before which the token is not valid.
109    #[serde(rename = "nbf", skip_serializing_if = "Option::is_none")]
110    pub not_before: Option<SecondsSinceEpoch>,
111
112    /// Time at which the token was issued.
113    #[serde(rename = "iat", skip_serializing_if = "Option::is_none")]
114    pub issued_at: Option<SecondsSinceEpoch>,
115
116    /// Unique identifier for the token.
117    #[serde(rename = "jti", skip_serializing_if = "Option::is_none")]
118    pub json_web_token_id: Option<String>,
119
120    /// HTTP method for request binding.
121    #[serde(rename = "htm", skip_serializing_if = "Option::is_none")]
122    pub http_method: Option<String>,
123
124    /// HTTP URI for request binding.
125    #[serde(rename = "htu", skip_serializing_if = "Option::is_none")]
126    pub http_uri: Option<String>,
127
128    /// Nonce value for replay protection.
129    #[serde(rename = "nonce", skip_serializing_if = "Option::is_none")]
130    pub nonce: Option<String>,
131
132    /// Authorization token hash.
133    #[serde(rename = "ath", skip_serializing_if = "Option::is_none")]
134    pub auth: Option<String>,
135}
136
137/// Create and sign a new JWT token.
138pub fn mint(key_data: &KeyData, header: &Header, claims: &Claims) -> Result<String> {
139    let header = header.to_base64()?;
140    let claims = claims.to_base64()?;
141    let content = format!("{}.{}", header, claims);
142
143    let signature = sign(key_data, content.as_bytes())?;
144
145    Ok(format!(
146        "{}.{}",
147        content,
148        general_purpose::URL_SAFE_NO_PAD.encode(signature)
149    ))
150}
151
152/// Verify a JWT token and extract its claims.
153pub fn verify(token: &str, key_data: &KeyData) -> Result<Claims> {
154    // Split token into its parts
155    let parts: Vec<&str> = token.split('.').collect();
156    if parts.len() != 3 {
157        return Err(JWTError::InvalidFormat.into());
158    }
159
160    let encoded_header = parts[0];
161    let encoded_claims = parts[1];
162    let encoded_signature = parts[2];
163
164    // Decode header
165    let header_bytes = general_purpose::URL_SAFE_NO_PAD
166        .decode(encoded_header)
167        .map_err(|_| JWTError::InvalidHeader)?;
168
169    let header: Header =
170        serde_json::from_slice(&header_bytes).map_err(|_| JWTError::InvalidHeader)?;
171
172    match (header.algorithm.as_deref(), key_data.key_type()) {
173        (Some("ES256K"), KeyType::K256Private) | (Some("ES256K"), KeyType::K256Public) => {}
174        (Some("ES256"), KeyType::P256Private) | (Some("ES256"), KeyType::P256Public) => {}
175        (Some("ES384"), KeyType::P384Private) | (Some("ES384"), KeyType::P384Public) => {}
176        _ => {
177            return Err(JWTError::UnsupportedAlgorithm {
178                algorithm: header.algorithm.unwrap_or_else(|| "none".to_string()),
179                key_type: format!("{}", key_data.key_type()),
180            }
181            .into());
182        }
183    }
184
185    // Decode claims
186    let claims_bytes = general_purpose::URL_SAFE_NO_PAD
187        .decode(encoded_claims)
188        .map_err(|_| JWTError::InvalidClaims)?;
189
190    let claims: Claims =
191        serde_json::from_slice(&claims_bytes).map_err(|_| JWTError::InvalidClaims)?;
192
193    // Decode signature
194    let signature_bytes = general_purpose::URL_SAFE_NO_PAD
195        .decode(encoded_signature)
196        .map_err(|_| JWTError::InvalidSignature)?;
197
198    let content = format!("{}.{}", encoded_header, encoded_claims);
199
200    validate(key_data, &signature_bytes, content.as_bytes())
201        .map_err(|_| JWTError::SignatureVerificationFailed)?;
202
203    // Get current timestamp for validation
204    let now = SystemTime::now()
205        .duration_since(UNIX_EPOCH)
206        .map_err(|_| JWTError::SystemTimeError)?
207        .as_secs();
208
209    // Validate expiration time if present
210    if let Some(exp) = claims.jose.expiration {
211        if now >= exp {
212            return Err(JWTError::TokenExpired.into());
213        }
214    }
215
216    // Validate not-before time if present
217    if let Some(nbf) = claims.jose.not_before {
218        if now < nbf {
219            return Err(JWTError::TokenNotValidYet.into());
220        }
221    }
222
223    // Return validated claims
224    Ok(claims)
225}
226
#[cfg(test)]
mod tests {
    use super::*;
    use atproto_identity::key::{KeyType, generate_key, identify_key, to_public};

    // ---------------------------------------------------------------------------------------
    // Header conversion
    // ---------------------------------------------------------------------------------------

    #[test]
    fn test_header_from_p256_private_key() -> Result<()> {
        let key_data = generate_key(KeyType::P256Private)?;
        let header: Header = key_data.clone().try_into()?;

        assert_eq!(header.algorithm, Some("ES256".to_string()));
        assert!(header.key_id.is_some());
        assert_eq!(header.type_, None);
        assert_eq!(header.json_web_key, None);

        // Verify key_id contains the public key DID
        let public_key = to_public(&key_data)?;
        let expected_key_id = format!("{}", public_key);
        assert_eq!(header.key_id, Some(expected_key_id));

        Ok(())
    }

    #[test]
    fn test_header_from_p256_public_key() -> Result<()> {
        let private_key = generate_key(KeyType::P256Private)?;
        let public_key = to_public(&private_key)?;
        let header: Header = public_key.clone().try_into()?;

        assert_eq!(header.algorithm, Some("ES256".to_string()));
        assert!(header.key_id.is_some());
        assert_eq!(header.type_, None);
        assert_eq!(header.json_web_key, None);

        // Verify key_id contains the public key DID
        let expected_key_id = format!("{}", public_key);
        assert_eq!(header.key_id, Some(expected_key_id));

        Ok(())
    }

    #[test]
    fn test_header_from_k256_private_key() -> Result<()> {
        let key_data = generate_key(KeyType::K256Private)?;
        let header: Header = key_data.clone().try_into()?;

        assert_eq!(header.algorithm, Some("ES256K".to_string()));
        assert!(header.key_id.is_some());
        assert_eq!(header.type_, None);
        assert_eq!(header.json_web_key, None);

        // Verify key_id contains the public key DID
        let public_key = to_public(&key_data)?;
        let expected_key_id = format!("{}", public_key);
        assert_eq!(header.key_id, Some(expected_key_id));

        Ok(())
    }

    #[test]
    fn test_header_from_k256_public_key() -> Result<()> {
        let private_key = generate_key(KeyType::K256Private)?;
        let public_key = to_public(&private_key)?;
        let header: Header = public_key.clone().try_into()?;

        assert_eq!(header.algorithm, Some("ES256K".to_string()));
        assert!(header.key_id.is_some());
        assert_eq!(header.type_, None);
        assert_eq!(header.json_web_key, None);

        // Verify key_id contains the public key DID
        let expected_key_id = format!("{}", public_key);
        assert_eq!(header.key_id, Some(expected_key_id));

        Ok(())
    }

    #[test]
    fn test_header_consistency_private_vs_public_key() -> Result<()> {
        // Test that private key and its derived public key produce headers with same key_id
        let p256_private = generate_key(KeyType::P256Private)?;
        let p256_public = to_public(&p256_private)?;

        let header_from_private: Header = p256_private.try_into()?;
        let header_from_public: Header = p256_public.try_into()?;

        assert_eq!(header_from_private.algorithm, header_from_public.algorithm);
        assert_eq!(header_from_private.key_id, header_from_public.key_id);
        assert_eq!(header_from_private.type_, header_from_public.type_);
        assert_eq!(
            header_from_private.json_web_key,
            header_from_public.json_web_key
        );

        // Test with K256 as well
        let k256_private = generate_key(KeyType::K256Private)?;
        let k256_public = to_public(&k256_private)?;

        let k256_header_from_private: Header = k256_private.try_into()?;
        let k256_header_from_public: Header = k256_public.try_into()?;

        assert_eq!(
            k256_header_from_private.algorithm,
            k256_header_from_public.algorithm
        );
        assert_eq!(
            k256_header_from_private.key_id,
            k256_header_from_public.key_id
        );
        assert_eq!(
            k256_header_from_private.type_,
            k256_header_from_public.type_
        );
        assert_eq!(
            k256_header_from_private.json_web_key,
            k256_header_from_public.json_web_key
        );

        Ok(())
    }

    #[test]
    fn test_header_from_existing_test_keys() -> Result<()> {
        // Test with known keys from the identity crate test suite
        let p256_private_key = "did:key:z42tnbHmmnhF11nwSnp5kQJbcZQw2Vbw5WF3ABDSxPtDgU2o";
        let p256_public_key = "did:key:zDnaeXduWbJ1b1Kgjf3uCdCpMDF1LEDizUiyxAxGwerou3Nh2";
        let k256_private_key = "did:key:z3vLY4nbXy2rV4Qr65gUtfnSF3A8Be7gmYzUiCX6eo2PR1Rt";
        let k256_public_key = "did:key:zQ3shNzMp4oaaQ1gQRzCxMGXFrSW3NEM1M9T6KCY9eA7HhyEA";

        // Parse the keys
        let parsed_p256_private = identify_key(p256_private_key)?;
        let parsed_p256_public = identify_key(p256_public_key)?;
        let parsed_k256_private = identify_key(k256_private_key)?;
        let parsed_k256_public = identify_key(k256_public_key)?;

        // Derive the actual public keys from the private keys for comparison
        let derived_p256_public = to_public(&parsed_p256_private)?;
        let derived_k256_public = to_public(&parsed_k256_private)?;

        // Test P256 private key
        let p256_private_header: Header = parsed_p256_private.try_into()?;
        assert_eq!(p256_private_header.algorithm, Some("ES256".to_string()));
        let expected_p256_key_id = format!("{}", derived_p256_public);
        assert_eq!(p256_private_header.key_id, Some(expected_p256_key_id));

        // Test P256 public key (standalone)
        let p256_public_header: Header = parsed_p256_public.try_into()?;
        assert_eq!(p256_public_header.algorithm, Some("ES256".to_string()));
        assert_eq!(p256_public_header.key_id, Some(p256_public_key.to_string()));

        // Test K256 private key
        let k256_private_header: Header = parsed_k256_private.try_into()?;
        assert_eq!(k256_private_header.algorithm, Some("ES256K".to_string()));
        let expected_k256_key_id = format!("{}", derived_k256_public);
        assert_eq!(k256_private_header.key_id, Some(expected_k256_key_id));

        // Test K256 public key (standalone)
        let k256_public_header: Header = parsed_k256_public.try_into()?;
        assert_eq!(k256_public_header.algorithm, Some("ES256K".to_string()));
        assert_eq!(k256_public_header.key_id, Some(k256_public_key.to_string()));

        // Test that derived public keys produce consistent headers
        let derived_p256_public_header: Header = derived_p256_public.try_into()?;
        let derived_k256_public_header: Header = derived_k256_public.try_into()?;

        assert_eq!(p256_private_header, derived_p256_public_header);
        assert_eq!(k256_private_header, derived_k256_public_header);

        Ok(())
    }

    #[test]
    fn test_header_multiple_conversions_same_key() -> Result<()> {
        // Test that multiple conversions of the same key produce identical headers
        let key_data = generate_key(KeyType::P256Private)?;

        let header1: Header = key_data.clone().try_into()?;
        let header2: Header = key_data.try_into()?;

        assert_eq!(header1, header2);

        Ok(())
    }

    #[test]
    fn test_header_different_keys_different_headers() -> Result<()> {
        // Test that different keys produce different headers
        let p256_key = generate_key(KeyType::P256Private)?;
        let k256_key = generate_key(KeyType::K256Private)?;

        let p256_header: Header = p256_key.try_into()?;
        let k256_header: Header = k256_key.try_into()?;

        // Algorithm should be different
        assert_ne!(p256_header.algorithm, k256_header.algorithm);
        assert_eq!(p256_header.algorithm, Some("ES256".to_string()));
        assert_eq!(k256_header.algorithm, Some("ES256K".to_string()));

        // Key IDs should be different (different public keys)
        assert_ne!(p256_header.key_id, k256_header.key_id);

        Ok(())
    }

    #[test]
    fn test_header_from_invalid_key_data() {
        // Test with invalid key data that would cause to_public() to fail
        let invalid_key_data = KeyData::new(KeyType::P256Private, vec![0u8; 10]); // Too short

        let result: Result<Header> = invalid_key_data.try_into();
        assert!(result.is_err());
    }

    #[test]
    fn test_header_serialization_deserialization() -> Result<()> {
        // Test that Header can be serialized and deserialized correctly
        let key_data = generate_key(KeyType::P256Private)?;
        let header: Header = key_data.try_into()?;

        // Serialize to JSON
        let json = serde_json::to_string(&header)?;

        // Deserialize back
        let deserialized_header: Header = serde_json::from_str(&json)?;

        assert_eq!(header, deserialized_header);

        Ok(())
    }

    #[test]
    fn test_header_json_field_names() -> Result<()> {
        // Test that Header uses correct JSON field names (alg, kid, typ, jwk)
        let key_data = generate_key(KeyType::P256Private)?;
        let header: Header = key_data.try_into()?;

        let json = serde_json::to_string(&header)?;
        let json_value: serde_json::Value = serde_json::from_str(&json)?;

        // Check that the correct field names are used
        assert!(json_value.get("alg").is_some());
        assert!(json_value.get("kid").is_some());
        assert!(json_value.get("typ").is_none()); // Should be None and thus omitted
        assert!(json_value.get("jwk").is_none()); // Should be None and thus omitted

        // Verify values
        assert_eq!(json_value["alg"], "ES256");
        assert!(json_value["kid"].is_string());

        Ok(())
    }

    #[test]
    fn test_header_complete_workflow() -> Result<()> {
        println!("\n=== Header TryFrom<KeyData> Test Workflow ===");

        // Generate keys for all curves
        println!("1. Generating test keys...");
        let p256_private = generate_key(KeyType::P256Private)?;
        let p384_private = generate_key(KeyType::P384Private)?;
        let k256_private = generate_key(KeyType::K256Private)?;
        let p256_public = to_public(&p256_private)?;
        let p384_public = to_public(&p384_private)?;
        let k256_public = to_public(&k256_private)?;

        // Convert to headers
        println!("2. Converting KeyData to Headers...");
        let p256_private_header: Header = p256_private.try_into()?;
        let p256_public_header: Header = p256_public.try_into()?;
        let p384_private_header: Header = p384_private.try_into()?;
        let p384_public_header: Header = p384_public.try_into()?;
        let k256_private_header: Header = k256_private.try_into()?;
        let k256_public_header: Header = k256_public.try_into()?;

        // Verify algorithms
        println!("3. Verifying algorithms...");
        assert_eq!(p256_private_header.algorithm, Some("ES256".to_string()));
        assert_eq!(p256_public_header.algorithm, Some("ES256".to_string()));
        assert_eq!(p384_private_header.algorithm, Some("ES384".to_string()));
        assert_eq!(p384_public_header.algorithm, Some("ES384".to_string()));
        assert_eq!(k256_private_header.algorithm, Some("ES256K".to_string()));
        assert_eq!(k256_public_header.algorithm, Some("ES256K".to_string()));
        println!("   ✓ P-256 keys → ES256");
        println!("   ✓ P-384 keys → ES384");
        println!("   ✓ K-256 keys → ES256K");

        // Verify key IDs match between private and public
        println!("4. Verifying key ID consistency...");
        assert_eq!(p256_private_header.key_id, p256_public_header.key_id);
        assert_eq!(p384_private_header.key_id, p384_public_header.key_id);
        assert_eq!(k256_private_header.key_id, k256_public_header.key_id);
        println!("   ✓ Private and public keys produce same key_id");

        // Verify other fields are None
        println!("5. Verifying optional fields are None...");
        for header in [
            &p256_private_header,
            &p256_public_header,
            &p384_private_header,
            &p384_public_header,
            &k256_private_header,
            &k256_public_header,
        ] {
            assert_eq!(header.type_, None);
            assert_eq!(header.json_web_key, None);
        }
        println!("   ✓ type_ and json_web_key fields are None");

        // Test JSON serialization
        println!("6. Testing JSON serialization...");
        let json = serde_json::to_string(&p384_private_header)?;
        let parsed: Header = serde_json::from_str(&json)?;
        assert_eq!(p384_private_header, parsed);
        println!("   ✓ Headers serialize/deserialize correctly");

        println!("=== All Header conversion tests passed! ===\n");

        Ok(())
    }

    #[test]
    fn test_header_from_p384_private_key() -> Result<()> {
        let key_data = generate_key(KeyType::P384Private)?;
        let header: Header = key_data.clone().try_into()?;

        assert_eq!(header.algorithm, Some("ES384".to_string()));
        assert!(header.key_id.is_some());
        assert_eq!(header.type_, None);
        assert_eq!(header.json_web_key, None);

        // Verify key_id contains the public key DID
        let public_key = to_public(&key_data)?;
        let expected_key_id = format!("{}", public_key);
        assert_eq!(header.key_id, Some(expected_key_id));

        Ok(())
    }

    #[test]
    fn test_header_from_p384_public_key() -> Result<()> {
        let private_key = generate_key(KeyType::P384Private)?;
        let public_key = to_public(&private_key)?;
        let header: Header = public_key.clone().try_into()?;

        assert_eq!(header.algorithm, Some("ES384".to_string()));
        assert!(header.key_id.is_some());
        assert_eq!(header.type_, None);
        assert_eq!(header.json_web_key, None);

        // Verify key_id contains the public key DID
        let expected_key_id = format!("{}", public_key);
        assert_eq!(header.key_id, Some(expected_key_id));

        Ok(())
    }

    // ---------------------------------------------------------------------------------------
    // mint / verify
    // ---------------------------------------------------------------------------------------

    /// Current Unix time in whole seconds, used to build `exp`/`nbf` relative to now.
    fn now_secs() -> SecondsSinceEpoch {
        SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .expect("system clock is set after the Unix epoch")
            .as_secs()
    }

    /// Mints a token for `claims`, signed by `key`, using the header derived from `key`.
    fn mint_for(key: &KeyData, claims: &Claims) -> Result<String> {
        let header: Header = key.clone().try_into()?;
        mint(key, &header, claims)
    }

    /// Claims that are valid right now: issued now, expiring in five minutes.
    fn fresh_claims(subject: &str) -> Claims {
        let now = now_secs();
        Claims::new(JoseClaims {
            issuer: Some("did:plc:issuer".to_string()),
            subject: Some(subject.to_string()),
            audience: Some("did:web:audience.example".to_string()),
            issued_at: Some(now),
            expiration: Some(now + 300),
            ..Default::default()
        })
    }

    /// Asserts that `result` failed with a `JWTError` accepted by `predicate`.
    ///
    /// Written without `unwrap_err` so it does not require `Claims: Debug`.
    fn assert_jwt_error(result: Result<Claims>, predicate: impl Fn(&JWTError) -> bool) {
        match result {
            Ok(_) => panic!("expected verification to fail, but it succeeded"),
            Err(err) => {
                let jwt_error = err.downcast_ref::<JWTError>();
                assert!(
                    matches!(jwt_error, Some(e) if predicate(e)),
                    "unexpected error: {}",
                    err
                );
            }
        }
    }

    #[test]
    fn test_mint_verify_round_trip_all_curves() -> Result<()> {
        for key_type in [
            KeyType::P256Private,
            KeyType::P384Private,
            KeyType::K256Private,
        ] {
            let private_key = generate_key(key_type)?;
            let public_key = to_public(&private_key)?;

            let mut claims = fresh_claims("did:plc:subject");
            claims.private.insert(
                "scope".to_string(),
                serde_json::Value::String("atproto".to_string()),
            );

            let token = mint_for(&private_key, &claims)?;
            assert_eq!(token.split('.').count(), 3);

            let verified = verify(&token, &public_key)?;
            assert_eq!(verified.jose.issuer, claims.jose.issuer);
            assert_eq!(verified.jose.subject, claims.jose.subject);
            assert_eq!(verified.jose.audience, claims.jose.audience);
            assert_eq!(verified.jose.issued_at, claims.jose.issued_at);
            assert_eq!(verified.jose.expiration, claims.jose.expiration);
            assert_eq!(verified.private.get("scope"), claims.private.get("scope"));
        }

        Ok(())
    }

    #[test]
    fn test_verify_rejects_expired_token() -> Result<()> {
        let private_key = generate_key(KeyType::P256Private)?;
        let public_key = to_public(&private_key)?;

        let mut claims = fresh_claims("did:plc:subject");
        claims.jose.expiration = Some(now_secs() - 60);
        let token = mint_for(&private_key, &claims)?;

        assert_jwt_error(verify(&token, &public_key), |e| {
            matches!(e, JWTError::TokenExpired)
        });
        Ok(())
    }

    #[test]
    fn test_verify_rejects_token_not_yet_valid() -> Result<()> {
        let private_key = generate_key(KeyType::P256Private)?;
        let public_key = to_public(&private_key)?;

        let mut claims = fresh_claims("did:plc:subject");
        claims.jose.not_before = Some(now_secs() + 3600);
        let token = mint_for(&private_key, &claims)?;

        assert_jwt_error(verify(&token, &public_key), |e| {
            matches!(e, JWTError::TokenNotValidYet)
        });
        Ok(())
    }

    #[test]
    fn test_verify_rejects_signature_from_other_key() -> Result<()> {
        let signing_key = generate_key(KeyType::P256Private)?;
        let other_public_key = to_public(&generate_key(KeyType::P256Private)?)?;

        let token = mint_for(&signing_key, &fresh_claims("did:plc:subject"))?;

        assert_jwt_error(verify(&token, &other_public_key), |e| {
            matches!(e, JWTError::SignatureVerificationFailed)
        });
        Ok(())
    }

    #[test]
    fn test_verify_rejects_tampered_claims() -> Result<()> {
        let private_key = generate_key(KeyType::K256Private)?;
        let public_key = to_public(&private_key)?;

        let genuine = mint_for(&private_key, &fresh_claims("did:plc:alice"))?;
        let other = mint_for(&private_key, &fresh_claims("did:plc:mallory"))?;

        // Splice another token's claims segment under the genuine header and signature.
        let genuine_parts: Vec<&str> = genuine.split('.').collect();
        let other_parts: Vec<&str> = other.split('.').collect();
        let tampered = format!(
            "{}.{}.{}",
            genuine_parts[0], other_parts[1], genuine_parts[2]
        );

        assert_jwt_error(verify(&tampered, &public_key), |e| {
            matches!(e, JWTError::SignatureVerificationFailed)
        });
        Ok(())
    }

    #[test]
    fn test_verify_rejects_algorithm_key_mismatch() -> Result<()> {
        let p256_key = generate_key(KeyType::P256Private)?;
        let k256_public_key = to_public(&generate_key(KeyType::K256Private)?)?;

        let token = mint_for(&p256_key, &fresh_claims("did:plc:subject"))?;

        assert_jwt_error(verify(&token, &k256_public_key), |e| {
            matches!(e, JWTError::UnsupportedAlgorithm { .. })
        });
        Ok(())
    }

    #[test]
    fn test_verify_rejects_missing_algorithm() -> Result<()> {
        let private_key = generate_key(KeyType::P256Private)?;
        let public_key = to_public(&private_key)?;

        let token = mint(
            &private_key,
            &Header::default(),
            &fresh_claims("did:plc:subject"),
        )?;

        assert_jwt_error(verify(&token, &public_key), |e| {
            matches!(e, JWTError::UnsupportedAlgorithm { .. })
        });
        Ok(())
    }

    #[test]
    fn test_verify_rejects_malformed_tokens() -> Result<()> {
        let public_key = to_public(&generate_key(KeyType::P256Private)?)?;

        for token in ["", "one-segment", "two.segments", "far.too.many.segments"] {
            assert_jwt_error(verify(token, &public_key), |e| {
                matches!(e, JWTError::InvalidFormat)
            });
        }
        Ok(())
    }
}