atproto_oauth/
jwt.rs

1//! JSON Web Token (JWT) minting and verification for AT Protocol OAuth.
2//!
3//! Provides functionality for creating and verifying JWT tokens with JOSE standard claims and
4//! custom extensions for AT Protocol OAuth flows. Supports ES256, ES384, and ES256K elliptic curve
5//! signature algorithms with comprehensive timestamp validation and structured error handling.
6
7use crate::encoding::ToBase64;
8use crate::errors::JWTError;
9use anyhow::Result;
10use atproto_identity::key::{KeyData, KeyType, sign, to_public, validate};
11use base64::{Engine as _, engine::general_purpose};
12use elliptic_curve::JwkEcKey;
13use serde::{Deserialize, Serialize};
14use std::collections::BTreeMap;
15use std::time::{SystemTime, UNIX_EPOCH};
16
17/// JWT header containing algorithm and key metadata.
18#[derive(Clone, Debug, Default, PartialEq, Serialize, Deserialize)]
19pub struct Header {
20    /// Algorithm used for signing (e.g., "ES256", "ES384", "ES256K").
21    #[serde(rename = "alg", skip_serializing_if = "Option::is_none")]
22    pub algorithm: Option<String>,
23
24    /// Key identifier for the signing key.
25    #[serde(rename = "kid", skip_serializing_if = "Option::is_none")]
26    pub key_id: Option<String>,
27
28    /// Token type, typically "JWT".
29    #[serde(rename = "typ", skip_serializing_if = "Option::is_none")]
30    pub type_: Option<String>,
31
32    /// Embedded JSON Web Key.
33    #[serde(rename = "jwk", skip_serializing_if = "Option::is_none")]
34    pub json_web_key: Option<JwkEcKey>,
35}
36
37impl TryFrom<KeyData> for Header {
38    type Error = anyhow::Error;
39
40    fn try_from(value: KeyData) -> std::result::Result<Self, Self::Error> {
41        let algorithm = match value.key_type() {
42            KeyType::P256Public => Some("ES256".to_string()),
43            KeyType::P256Private => Some("ES256".to_string()),
44            KeyType::P384Public => Some("ES384".to_string()),
45            KeyType::P384Private => Some("ES384".to_string()),
46            KeyType::K256Public => Some("ES256K".to_string()),
47            KeyType::K256Private => Some("ES256K".to_string()),
48        };
49
50        let public_key = to_public(&value)?;
51        let key_id = Some(public_key.to_string());
52
53        Ok(Self {
54            algorithm,
55            key_id,
56            type_: None,
57            json_web_key: None,
58        })
59    }
60}
61
62/// JWT claims combining standard JOSE claims with custom private claims.
63#[derive(Clone, Debug, Default, PartialEq, Serialize, Deserialize)]
64pub struct Claims {
65    /// Standard JOSE claims.
66    #[serde(flatten)]
67    pub jose: JoseClaims,
68    /// Custom private claims.
69    #[serde(flatten)]
70    pub private: BTreeMap<String, serde_json::Value>,
71}
72
73impl Claims {
74    /// Create new Claims with the given JOSE claims.
75    pub fn new(jose: JoseClaims) -> Self {
76        Claims {
77            jose,
78            private: BTreeMap::new(),
79        }
80    }
81}
82
/// Timestamp expressed as a whole number of seconds elapsed since the Unix epoch.
pub type SecondsSinceEpoch = u64;
85
86/// Standard JOSE claims for JWT tokens.
87#[derive(Clone, Debug, Default, PartialEq, Serialize, Deserialize)]
88pub struct JoseClaims {
89    /// Issuer of the token.
90    #[serde(rename = "iss", skip_serializing_if = "Option::is_none")]
91    pub issuer: Option<String>,
92
93    /// Subject of the token.
94    #[serde(rename = "sub", skip_serializing_if = "Option::is_none")]
95    pub subject: Option<String>,
96
97    /// Intended audience for the token.
98    #[serde(rename = "aud", skip_serializing_if = "Option::is_none")]
99    pub audience: Option<String>,
100
101    /// Expiration time after which the token is invalid.
102    #[serde(rename = "exp", skip_serializing_if = "Option::is_none")]
103    pub expiration: Option<SecondsSinceEpoch>,
104
105    /// Time before which the token is not valid.
106    #[serde(rename = "nbf", skip_serializing_if = "Option::is_none")]
107    pub not_before: Option<SecondsSinceEpoch>,
108
109    /// Time at which the token was issued.
110    #[serde(rename = "iat", skip_serializing_if = "Option::is_none")]
111    pub issued_at: Option<SecondsSinceEpoch>,
112
113    /// Unique identifier for the token.
114    #[serde(rename = "jti", skip_serializing_if = "Option::is_none")]
115    pub json_web_token_id: Option<String>,
116
117    /// HTTP method for request binding.
118    #[serde(rename = "htm", skip_serializing_if = "Option::is_none")]
119    pub http_method: Option<String>,
120
121    /// HTTP URI for request binding.
122    #[serde(rename = "htu", skip_serializing_if = "Option::is_none")]
123    pub http_uri: Option<String>,
124
125    /// Nonce value for replay protection.
126    #[serde(rename = "nonce", skip_serializing_if = "Option::is_none")]
127    pub nonce: Option<String>,
128
129    /// Authorization token hash.
130    #[serde(rename = "ath", skip_serializing_if = "Option::is_none")]
131    pub auth: Option<String>,
132}
133
134/// Create and sign a new JWT token.
135pub fn mint(key_data: &KeyData, header: &Header, claims: &Claims) -> Result<String> {
136    let header = header.to_base64()?;
137    let claims = claims.to_base64()?;
138    let content = format!("{}.{}", header, claims);
139
140    let signature = sign(key_data, content.as_bytes())?;
141
142    Ok(format!(
143        "{}.{}",
144        content,
145        general_purpose::URL_SAFE_NO_PAD.encode(signature)
146    ))
147}
148
149/// Verify a JWT token and extract its claims.
150pub fn verify(token: &str, key_data: &KeyData) -> Result<Claims> {
151    // Split token into its parts
152    let parts: Vec<&str> = token.split('.').collect();
153    if parts.len() != 3 {
154        return Err(JWTError::InvalidFormat.into());
155    }
156
157    let encoded_header = parts[0];
158    let encoded_claims = parts[1];
159    let encoded_signature = parts[2];
160
161    // Decode header
162    let header_bytes = general_purpose::URL_SAFE_NO_PAD
163        .decode(encoded_header)
164        .map_err(|_| JWTError::InvalidHeader)?;
165
166    let header: Header =
167        serde_json::from_slice(&header_bytes).map_err(|_| JWTError::InvalidHeader)?;
168
169    match (header.algorithm.as_deref(), key_data.key_type()) {
170        (Some("ES256K"), KeyType::K256Private) | (Some("ES256K"), KeyType::K256Public) => {}
171        (Some("ES256"), KeyType::P256Private) | (Some("ES256"), KeyType::P256Public) => {}
172        (Some("ES384"), KeyType::P384Private) | (Some("ES384"), KeyType::P384Public) => {}
173        _ => {
174            return Err(JWTError::UnsupportedAlgorithm {
175                algorithm: header.algorithm.unwrap_or_else(|| "none".to_string()),
176                key_type: format!("{:?}", key_data.key_type()),
177            }
178            .into());
179        }
180    }
181
182    // Decode claims
183    let claims_bytes = general_purpose::URL_SAFE_NO_PAD
184        .decode(encoded_claims)
185        .map_err(|_| JWTError::InvalidClaims)?;
186
187    let claims: Claims =
188        serde_json::from_slice(&claims_bytes).map_err(|_| JWTError::InvalidClaims)?;
189
190    // Decode signature
191    let signature_bytes = general_purpose::URL_SAFE_NO_PAD
192        .decode(encoded_signature)
193        .map_err(|_| JWTError::InvalidSignature)?;
194
195    let content = format!("{}.{}", encoded_header, encoded_claims);
196
197    validate(key_data, &signature_bytes, content.as_bytes())
198        .map_err(|_| JWTError::SignatureVerificationFailed)?;
199
200    // Get current timestamp for validation
201    let now = SystemTime::now()
202        .duration_since(UNIX_EPOCH)
203        .map_err(|_| JWTError::SystemTimeError)?
204        .as_secs();
205
206    // Validate expiration time if present
207    if let Some(exp) = claims.jose.expiration {
208        if now >= exp {
209            return Err(JWTError::TokenExpired.into());
210        }
211    }
212
213    // Validate not-before time if present
214    if let Some(nbf) = claims.jose.not_before {
215        if now < nbf {
216            return Err(JWTError::TokenNotValidYet.into());
217        }
218    }
219
220    // Return validated claims
221    Ok(claims)
222}
223
#[cfg(test)]
mod tests {
    use super::*;
    use atproto_identity::key::{KeyType, generate_key, identify_key, to_public};

    /// Converts `key` into a [`Header`] and checks the algorithm, the key id, and that the
    /// `type_` and `json_web_key` fields are unset.
    fn assert_header_for_key(
        key: KeyData,
        algorithm: &str,
        expected_key_id: String,
    ) -> Result<()> {
        let header = Header::try_from(key)?;

        assert_eq!(header.algorithm, Some(algorithm.to_string()));
        assert!(header.key_id.is_some());
        assert_eq!(header.type_, None);
        assert_eq!(header.json_web_key, None);

        // The key id must be the string form of the public key.
        assert_eq!(header.key_id, Some(expected_key_id));

        Ok(())
    }

    #[test]
    fn test_header_from_p256_private_key() -> Result<()> {
        let private_key = generate_key(KeyType::P256Private)?;
        let expected_key_id = to_public(&private_key)?.to_string();

        assert_header_for_key(private_key, "ES256", expected_key_id)
    }

    #[test]
    fn test_header_from_p256_public_key() -> Result<()> {
        let public_key = to_public(&generate_key(KeyType::P256Private)?)?;
        let expected_key_id = public_key.to_string();

        assert_header_for_key(public_key, "ES256", expected_key_id)
    }

    #[test]
    fn test_header_from_k256_private_key() -> Result<()> {
        let private_key = generate_key(KeyType::K256Private)?;
        let expected_key_id = to_public(&private_key)?.to_string();

        assert_header_for_key(private_key, "ES256K", expected_key_id)
    }

    #[test]
    fn test_header_from_k256_public_key() -> Result<()> {
        let public_key = to_public(&generate_key(KeyType::K256Private)?)?;
        let expected_key_id = public_key.to_string();

        assert_header_for_key(public_key, "ES256K", expected_key_id)
    }

    #[test]
    fn test_header_consistency_private_vs_public_key() -> Result<()> {
        // A private key and the public key derived from it must yield matching headers,
        // checked for P256 first and then K256.
        for private_key_type in [KeyType::P256Private, KeyType::K256Private] {
            let private_key = generate_key(private_key_type)?;
            let public_key = to_public(&private_key)?;

            let from_private = Header::try_from(private_key)?;
            let from_public = Header::try_from(public_key)?;

            assert_eq!(from_private.algorithm, from_public.algorithm);
            assert_eq!(from_private.key_id, from_public.key_id);
            assert_eq!(from_private.type_, from_public.type_);
            assert_eq!(from_private.json_web_key, from_public.json_web_key);
        }

        Ok(())
    }

    #[test]
    fn test_header_from_existing_test_keys() -> Result<()> {
        // Known keys taken from the identity crate test suite.
        let p256_private_did = "did:key:z42tnbHmmnhF11nwSnp5kQJbcZQw2Vbw5WF3ABDSxPtDgU2o";
        let p256_public_did = "did:key:zDnaeXduWbJ1b1Kgjf3uCdCpMDF1LEDizUiyxAxGwerou3Nh2";
        let k256_private_did = "did:key:z3vLY4nbXy2rV4Qr65gUtfnSF3A8Be7gmYzUiCX6eo2PR1Rt";
        let k256_public_did = "did:key:zQ3shNzMp4oaaQ1gQRzCxMGXFrSW3NEM1M9T6KCY9eA7HhyEA";

        let p256_private = identify_key(p256_private_did)?;
        let p256_public = identify_key(p256_public_did)?;
        let k256_private = identify_key(k256_private_did)?;
        let k256_public = identify_key(k256_public_did)?;

        // Public keys derived from the private keys, used as the reference for key ids.
        let p256_derived_public = to_public(&p256_private)?;
        let k256_derived_public = to_public(&k256_private)?;

        // P256 private key
        let p256_private_header = Header::try_from(p256_private)?;
        assert_eq!(p256_private_header.algorithm, Some("ES256".to_string()));
        assert_eq!(
            p256_private_header.key_id,
            Some(p256_derived_public.to_string())
        );

        // P256 public key (standalone)
        let p256_public_header = Header::try_from(p256_public)?;
        assert_eq!(p256_public_header.algorithm, Some("ES256".to_string()));
        assert_eq!(p256_public_header.key_id, Some(p256_public_did.to_string()));

        // K256 private key
        let k256_private_header = Header::try_from(k256_private)?;
        assert_eq!(k256_private_header.algorithm, Some("ES256K".to_string()));
        assert_eq!(
            k256_private_header.key_id,
            Some(k256_derived_public.to_string())
        );

        // K256 public key (standalone)
        let k256_public_header = Header::try_from(k256_public)?;
        assert_eq!(k256_public_header.algorithm, Some("ES256K".to_string()));
        assert_eq!(k256_public_header.key_id, Some(k256_public_did.to_string()));

        // Headers built from the derived public keys must equal those from the private keys.
        let p256_derived_header = Header::try_from(p256_derived_public)?;
        let k256_derived_header = Header::try_from(k256_derived_public)?;

        assert_eq!(p256_private_header, p256_derived_header);
        assert_eq!(k256_private_header, k256_derived_header);

        Ok(())
    }

    #[test]
    fn test_header_multiple_conversions_same_key() -> Result<()> {
        // Converting the same key twice must give identical headers.
        let key = generate_key(KeyType::P256Private)?;

        let first = Header::try_from(key.clone())?;
        let second = Header::try_from(key)?;

        assert_eq!(first, second);

        Ok(())
    }

    #[test]
    fn test_header_different_keys_different_headers() -> Result<()> {
        // Keys on different curves must give different headers.
        let p256_header = Header::try_from(generate_key(KeyType::P256Private)?)?;
        let k256_header = Header::try_from(generate_key(KeyType::K256Private)?)?;

        // The algorithms differ...
        assert_ne!(p256_header.algorithm, k256_header.algorithm);
        assert_eq!(p256_header.algorithm, Some("ES256".to_string()));
        assert_eq!(k256_header.algorithm, Some("ES256K".to_string()));

        // ...and so do the key ids, since the public keys differ.
        assert_ne!(p256_header.key_id, k256_header.key_id);

        Ok(())
    }

    #[test]
    fn test_header_from_invalid_key_data() {
        // Key material that is too short, so that deriving the public key fails.
        let truncated_key = KeyData::new(KeyType::P256Private, vec![0u8; 10]);

        let conversion = Header::try_from(truncated_key);
        assert!(conversion.is_err());
    }

    #[test]
    fn test_header_serialization_deserialization() -> Result<()> {
        // A header must survive a JSON round trip unchanged.
        let header = Header::try_from(generate_key(KeyType::P256Private)?)?;

        let json = serde_json::to_string(&header)?;
        let round_tripped = serde_json::from_str::<Header>(&json)?;

        assert_eq!(header, round_tripped);

        Ok(())
    }

    #[test]
    fn test_header_json_field_names() -> Result<()> {
        // The serialized header must use the JSON field names alg, kid, typ and jwk.
        let header = Header::try_from(generate_key(KeyType::P256Private)?)?;

        let json = serde_json::to_string(&header)?;
        let document = serde_json::from_str::<serde_json::Value>(&json)?;

        // Fields that are set are present; fields that are None are omitted.
        assert!(document.get("alg").is_some());
        assert!(document.get("kid").is_some());
        assert!(document.get("typ").is_none());
        assert!(document.get("jwk").is_none());

        // Values
        assert_eq!(document["alg"], "ES256");
        assert!(document["kid"].is_string());

        Ok(())
    }

    #[test]
    fn test_header_complete_workflow() -> Result<()> {
        println!("\n=== Header TryFrom<KeyData> Test Workflow ===");

        // One private key per curve, plus the public key derived from each.
        println!("1. Generating test keys...");
        let p256_secret = generate_key(KeyType::P256Private)?;
        let p384_secret = generate_key(KeyType::P384Private)?;
        let k256_secret = generate_key(KeyType::K256Private)?;
        let p256_derived = to_public(&p256_secret)?;
        let p384_derived = to_public(&p384_secret)?;
        let k256_derived = to_public(&k256_secret)?;

        println!("2. Converting KeyData to Headers...");
        let p256_private_header = Header::try_from(p256_secret)?;
        let p256_public_header = Header::try_from(p256_derived)?;
        let p384_private_header = Header::try_from(p384_secret)?;
        let p384_public_header = Header::try_from(p384_derived)?;
        let k256_private_header = Header::try_from(k256_secret)?;
        let k256_public_header = Header::try_from(k256_derived)?;

        println!("3. Verifying algorithms...");
        let expected_algorithms = [
            (&p256_private_header, "ES256"),
            (&p256_public_header, "ES256"),
            (&p384_private_header, "ES384"),
            (&p384_public_header, "ES384"),
            (&k256_private_header, "ES256K"),
            (&k256_public_header, "ES256K"),
        ];
        for (header, algorithm) in expected_algorithms {
            assert_eq!(header.algorithm, Some(algorithm.to_string()));
        }
        println!("   ✓ P-256 keys → ES256");
        println!("   ✓ P-384 keys → ES384");
        println!("   ✓ K-256 keys → ES256K");

        println!("4. Verifying key ID consistency...");
        let key_pairs = [
            (&p256_private_header, &p256_public_header),
            (&p384_private_header, &p384_public_header),
            (&k256_private_header, &k256_public_header),
        ];
        for (from_private, from_public) in key_pairs {
            assert_eq!(from_private.key_id, from_public.key_id);
        }
        println!("   ✓ Private and public keys produce same key_id");

        println!("5. Verifying optional fields are None...");
        for (header, _) in expected_algorithms {
            assert_eq!(header.type_, None);
            assert_eq!(header.json_web_key, None);
        }
        println!("   ✓ type_ and json_web_key fields are None");

        println!("6. Testing JSON serialization...");
        let json = serde_json::to_string(&p384_private_header)?;
        let round_tripped = serde_json::from_str::<Header>(&json)?;
        assert_eq!(p384_private_header, round_tripped);
        println!("   ✓ Headers serialize/deserialize correctly");

        println!("=== All Header conversion tests passed! ===\n");

        Ok(())
    }

    #[test]
    fn test_header_from_p384_private_key() -> Result<()> {
        let private_key = generate_key(KeyType::P384Private)?;
        let expected_key_id = to_public(&private_key)?.to_string();

        assert_header_for_key(private_key, "ES384", expected_key_id)
    }

    #[test]
    fn test_header_from_p384_public_key() -> Result<()> {
        let public_key = to_public(&generate_key(KeyType::P384Private)?)?;
        let expected_key_id = public_key.to_string();

        assert_header_for_key(public_key, "ES384", expected_key_id)
    }
}