Skip to main content

atproto_devtool/common/
jwt.rs

1//! Minimal hand-rolled JWT (RFC 7515 compact JWS) encoder and decoder for
2//! atproto service-auth.
3//!
4//! This module exists to avoid pulling a full JWT library for a handful of
5//! tightly-scoped use cases: minting self-mint JWTs for labeler conformance
6//! tests, and decoding them in tests to verify round-trip correctness.
7//! Only ES256 and ES256K are supported (RFC 7518 §3.4); raw r||s signature
8//! encoding, unpadded base64url segments, UTF-8 JSON payloads.
9
10use base64::Engine;
11use base64::engine::general_purpose::URL_SAFE_NO_PAD;
12use serde::{Deserialize, Serialize};
13use thiserror::Error;
14
15use crate::common::identity::{AnySignature, AnySignatureError, AnySigningKey, AnyVerifyingKey};
16
17/// Compact JWS header for atproto service-auth tokens.
18///
19/// Field names map 1:1 to the JWS wire format (RFC 7515 §4.1). Do NOT
20/// rename without updating `serde` attributes — atproto wire format
21/// requires exactly `alg` and `typ`.
22#[derive(Debug, Clone, Serialize, Deserialize)]
23pub struct JwtHeader {
24    /// Algorithm identifier: "ES256K" (secp256k1) or "ES256" (P-256).
25    pub alg: String,
26    /// Token type; always "JWT".
27    pub typ: String,
28}
29
30impl JwtHeader {
31    /// Build a header for the given signing key, setting `alg` to match the
32    /// curve and `typ` to `"JWT"`.
33    pub fn for_signing_key(key: &AnySigningKey) -> Self {
34        Self {
35            alg: key.jwt_alg().to_string(),
36            typ: "JWT".to_string(),
37        }
38    }
39}
40
41/// Atproto service-auth JWT claims.
42///
43/// Fields match the atproto inter-service authentication spec:
44/// <https://atproto.com/specs/xrpc#inter-service-authentication>. `nbf` is
45/// deliberately omitted — the spec does not require it and some servers
46/// reject unexpected claims.
47///
48/// **Field names are wire-format-critical:** `iss`, `aud`, `exp`, `iat`,
49/// `lxm`, `jti` are the exact JSON keys atproto labelers expect. Do NOT
50/// rename without adding `#[serde(rename = "...")]` attributes.
51#[derive(Debug, Clone, Serialize, Deserialize)]
52pub struct JwtClaims {
53    /// Issuer DID (e.g., `did:web:127.0.0.1%3A5000`).
54    pub iss: String,
55    /// Audience — the target service's DID, bare (no `#fragment`).
56    pub aud: String,
57    /// Expiration, UNIX seconds.
58    pub exp: i64,
59    /// Issued-at, UNIX seconds.
60    pub iat: i64,
61    /// Lexicon method NSID the token authorizes (e.g.,
62    /// `com.atproto.moderation.createReport`).
63    pub lxm: String,
64    /// Random nonce to prevent replay — hex string, 32 chars (16 bytes).
65    pub jti: String,
66}
67
68/// Errors from JWT encode/decode.
69///
70/// **Not user-rendered:** these errors only surface inside tests and
71/// library helpers. They deliberately do NOT derive `miette::Diagnostic`
72/// with stable codes — the stage converts any failure into a
73/// `CreateReportStageError::Transport` or a specific check SpecViolation
74/// before rendering. If a future caller needs one of these variants
75/// rendered to the user, they must wrap it in a stage-local diagnostic
76/// with a proper `code = "labeler::..."` string.
77#[derive(Debug, Error)]
78pub enum JwtError {
79    /// Compact form was not three `.`-separated base64url segments.
80    #[error("malformed compact JWT: expected three segments")]
81    MalformedCompact,
82    /// A base64url segment failed to decode.
83    #[error("base64url decode failed for {segment}")]
84    Base64Decode {
85        /// Which segment failed: "header", "claims", or "signature".
86        segment: &'static str,
87        /// Underlying base64 error.
88        #[source]
89        source: base64::DecodeError,
90    },
91    /// A segment decoded to valid bytes but invalid JSON.
92    #[error("JSON decode failed for {segment}")]
93    JsonDecode {
94        /// Which segment failed: "header" or "claims".
95        segment: &'static str,
96        /// Underlying serde_json error.
97        #[source]
98        source: serde_json::Error,
99    },
100    /// JSON serialization of header or claims failed (should not happen for
101    /// well-formed structs).
102    #[error("JSON encode failed")]
103    JsonEncode(serde_json::Error),
104    /// Signature was not exactly 64 bytes.
105    #[error("signature was {actual} bytes; expected 64")]
106    SignatureLength {
107        /// Actual length in bytes.
108        actual: usize,
109    },
110    /// Signature had the correct length but invalid scalar values (e.g., r or s
111    /// is 0 or exceeds the curve order).
112    #[error("signature has invalid scalar values")]
113    InvalidSignatureScalar,
114    /// The algorithm identifier in the header is not recognized.
115    #[error("unsupported JWT alg `{alg}` (expected ES256 or ES256K)")]
116    UnsupportedAlg {
117        /// The unrecognized algorithm string.
118        alg: String,
119    },
120    /// Underlying ECDSA verification failure (e.g., curve mismatch).
121    #[error("signature verification failed")]
122    SignatureVerify(#[from] AnySignatureError),
123}
124
125/// Encode a JWT in compact form: `base64url(header).base64url(claims).base64url(signature)`.
126///
127/// Signs the concatenation `header_b64 + "." + claims_b64` with SHA-256
128/// prehash under the supplied key. Returns the full compact token string.
129pub fn encode_compact(
130    header: &JwtHeader,
131    claims: &JwtClaims,
132    signer: &AnySigningKey,
133) -> Result<String, JwtError> {
134    let header_json = serde_json::to_vec(header).map_err(JwtError::JsonEncode)?;
135    let claims_json = serde_json::to_vec(claims).map_err(JwtError::JsonEncode)?;
136    let header_b64 = URL_SAFE_NO_PAD.encode(&header_json);
137    let claims_b64 = URL_SAFE_NO_PAD.encode(&claims_json);
138    let signing_input = format!("{header_b64}.{claims_b64}");
139    let sig = signer.sign(signing_input.as_bytes());
140    let sig_bytes = sig.to_jws_bytes();
141    let sig_b64 = URL_SAFE_NO_PAD.encode(sig_bytes);
142    Ok(format!("{header_b64}.{claims_b64}.{sig_b64}"))
143}
144
145/// Decode a compact JWT into `(header, claims, signature_bytes)`.
146///
147/// Does NOT verify the signature — use `verify_compact` for that. This helper
148/// is primarily for test round-tripping and for negative-test assertions
149/// (e.g., "the minted token has the expected `alg` header").
150pub fn decode_compact(token: &str) -> Result<(JwtHeader, JwtClaims, Vec<u8>), JwtError> {
151    let parts: Vec<&str> = token.split('.').collect();
152    if parts.len() != 3 {
153        return Err(JwtError::MalformedCompact);
154    }
155    let header_b64 = parts[0];
156    let claims_b64 = parts[1];
157    let sig_b64 = parts[2];
158    let header_bytes =
159        URL_SAFE_NO_PAD
160            .decode(header_b64)
161            .map_err(|source| JwtError::Base64Decode {
162                segment: "header",
163                source,
164            })?;
165    let claims_bytes =
166        URL_SAFE_NO_PAD
167            .decode(claims_b64)
168            .map_err(|source| JwtError::Base64Decode {
169                segment: "claims",
170                source,
171            })?;
172    let sig_bytes = URL_SAFE_NO_PAD
173        .decode(sig_b64)
174        .map_err(|source| JwtError::Base64Decode {
175            segment: "signature",
176            source,
177        })?;
178    let header: JwtHeader =
179        serde_json::from_slice(&header_bytes).map_err(|source| JwtError::JsonDecode {
180            segment: "header",
181            source,
182        })?;
183    let claims: JwtClaims =
184        serde_json::from_slice(&claims_bytes).map_err(|source| JwtError::JsonDecode {
185            segment: "claims",
186            source,
187        })?;
188    Ok((header, claims, sig_bytes))
189}
190
191/// Verify a compact JWT against the given verifying key. Does NOT check
192/// claim values (exp/aud/lxm) — that is the labeler's job in production,
193/// or the stage's assertion job in tests. Only verifies the signature.
194pub fn verify_compact(
195    token: &str,
196    vkey: &AnyVerifyingKey,
197) -> Result<(JwtHeader, JwtClaims), JwtError> {
198    let (header, claims, sig_bytes) = decode_compact(token)?;
199    let expected_alg = match vkey {
200        AnyVerifyingKey::K256(_) => "ES256K",
201        AnyVerifyingKey::P256(_) => "ES256",
202    };
203    if header.alg != expected_alg {
204        return Err(JwtError::UnsupportedAlg {
205            alg: header.alg.clone(),
206        });
207    }
208    if sig_bytes.len() != 64 {
209        return Err(JwtError::SignatureLength {
210            actual: sig_bytes.len(),
211        });
212    }
213    let sig_array: [u8; 64] = sig_bytes.as_slice().try_into().expect("len checked above");
214    let any_sig = match vkey {
215        AnyVerifyingKey::K256(_) => {
216            let sig = k256::ecdsa::Signature::from_bytes(&sig_array.into())
217                .map_err(|_| JwtError::InvalidSignatureScalar)?;
218            AnySignature::K256(sig)
219        }
220        AnyVerifyingKey::P256(_) => {
221            let sig = p256::ecdsa::Signature::from_bytes(&sig_array.into())
222                .map_err(|_| JwtError::InvalidSignatureScalar)?;
223            AnySignature::P256(sig)
224        }
225    };
226    // Recompute the signing input and verify.
227    let dot = token
228        .rfind('.')
229        .expect("three-segment token has a last dot");
230    let signing_input = &token[..dot];
231    use sha2::{Digest, Sha256};
232    let prehash: [u8; 32] = Sha256::digest(signing_input.as_bytes()).into();
233    vkey.verify_prehash(&prehash, &any_sig)?;
234    Ok((header, claims))
235}
236
#[cfg(test)]
mod tests {
    use super::*;
    use k256::ecdsa::SigningKey as K256SigningKey;
    use p256::ecdsa::SigningKey as P256SigningKey;

    /// Deterministic secp256k1 signing key shared by the tests.
    fn k256_key() -> AnySigningKey {
        AnySigningKey::K256(K256SigningKey::from_slice(&[1u8; 32]).expect("valid seed"))
    }

    /// Deterministic P-256 signing key shared by the tests.
    fn p256_key() -> AnySigningKey {
        AnySigningKey::P256(P256SigningKey::from_slice(&[2u8; 32]).expect("valid seed"))
    }

    /// Claims with fixed `aud`, timestamps and `lxm`; only `iss` and `jti` vary.
    fn sample_claims(iss: &str, jti: &str) -> JwtClaims {
        JwtClaims {
            iss: iss.to_string(),
            aud: "did:plc:test".to_string(),
            exp: 2000000000,
            iat: 1700000000,
            lxm: "com.atproto.moderation.createReport".to_string(),
            jti: jti.to_string(),
        }
    }

    /// Mint a token for `claims`, signed by `key` under a matching header.
    fn mint(key: &AnySigningKey, claims: &JwtClaims) -> String {
        let header = JwtHeader::for_signing_key(key);
        encode_compact(&header, claims, key).expect("encode succeeds")
    }

    /// Return `token` with segment `index` (0 = header, 1 = claims,
    /// 2 = signature) replaced by `replacement`.
    fn replace_segment(token: &str, index: usize, replacement: &str) -> String {
        let mut parts: Vec<&str> = token.split('.').collect();
        assert_eq!(parts.len(), 3);
        parts[index] = replacement;
        parts.join(".")
    }

    #[test]
    fn encode_decode_roundtrip_k256() {
        let key = k256_key();
        let vkey = key.verifying_key();
        let claims = sample_claims("did:web:127.0.0.1%3A5000", "0123456789abcdef");

        let token = mint(&key, &claims);
        let (decoded_header, decoded_claims) =
            verify_compact(&token, &vkey).expect("verify succeeds");

        assert_eq!(decoded_header.alg, "ES256K");
        assert_eq!(decoded_header.typ, "JWT");
        assert_eq!(decoded_claims.iss, claims.iss);
        assert_eq!(decoded_claims.aud, claims.aud);
        assert_eq!(decoded_claims.exp, claims.exp);
        assert_eq!(decoded_claims.iat, claims.iat);
        assert_eq!(decoded_claims.lxm, claims.lxm);
        assert_eq!(decoded_claims.jti, claims.jti);
    }

    #[test]
    fn encode_decode_roundtrip_p256() {
        let key = p256_key();
        let vkey = key.verifying_key();
        let claims = sample_claims("did:web:example.com", "fedcba9876543210");

        let token = mint(&key, &claims);
        let (decoded_header, decoded_claims) =
            verify_compact(&token, &vkey).expect("verify succeeds");

        assert_eq!(decoded_header.alg, "ES256");
        assert_eq!(decoded_claims.aud, claims.aud);
    }

    #[test]
    fn encode_decode_roundtrip_tampered_claims_fails() {
        let key = k256_key();
        let vkey = key.verifying_key();
        let claims = sample_claims("did:web:127.0.0.1%3A5000", "0123456789abcdef");
        let token = mint(&key, &claims);

        // Swap in a *well-formed* claims segment so decoding succeeds and the
        // rejection has to come from signature verification, not JSON parsing.
        let mut forged = claims.clone();
        forged.aud = "did:plc:attacker".to_string();
        let forged_b64 =
            URL_SAFE_NO_PAD.encode(serde_json::to_vec(&forged).expect("claims serialize"));
        let tampered = replace_segment(&token, 1, &forged_b64);
        let result = verify_compact(&tampered, &vkey);
        assert!(matches!(result, Err(JwtError::SignatureVerify(_))));

        // A claims segment that is not JSON at all ("abc") is rejected earlier.
        let garbage = replace_segment(&token, 1, "YWJj");
        let result = verify_compact(&garbage, &vkey);
        assert!(matches!(
            result,
            Err(JwtError::JsonDecode {
                segment: "claims",
                ..
            })
        ));
    }

    #[test]
    fn decode_compact_malformed_two_segments() {
        let result = decode_compact("header.claims");
        assert!(matches!(result, Err(JwtError::MalformedCompact)));
    }

    #[test]
    fn decode_compact_malformed_four_segments() {
        // Tokens with four or more segments are malformed.
        let result = decode_compact("YQ.Yg.Yw.ZA");
        assert!(matches!(result, Err(JwtError::MalformedCompact)));
    }

    #[test]
    fn decode_compact_invalid_base64() {
        let result = decode_compact("!!!.claims.sig");
        assert!(matches!(
            result,
            Err(JwtError::Base64Decode {
                segment: "header",
                ..
            })
        ));
    }

    #[test]
    fn verify_compact_curve_mismatch() {
        let k256_key = k256_key();
        let p256_vkey = p256_key().verifying_key();

        let token = mint(&k256_key, &sample_claims("did:web:test", "0123456789abcdef"));

        // A K256-signed token (alg ES256K) must be rejected by a P256 key on
        // the alg check, before any signature math runs.
        match verify_compact(&token, &p256_vkey) {
            Err(JwtError::UnsupportedAlg { alg }) => assert_eq!(alg, "ES256K"),
            other => panic!("expected UnsupportedAlg, got {other:?}"),
        }
    }

    #[test]
    fn verify_compact_wrong_signature_length() {
        let key = k256_key();
        let vkey = key.verifying_key();
        let token = mint(&key, &sample_claims("did:web:test", "0123456789abcdef"));

        let short_sig = URL_SAFE_NO_PAD.encode([1u8; 63]);
        let tampered = replace_segment(&token, 2, &short_sig);

        let result = verify_compact(&tampered, &vkey);
        assert!(matches!(
            result,
            Err(JwtError::SignatureLength { actual: 63 })
        ));
    }

    #[test]
    fn encode_compact_produces_valid_structure() {
        let key = k256_key();
        let token = mint(&key, &sample_claims("did:web:test", "0123456789abcdef"));

        // Token must have exactly 3 segments.
        assert_eq!(token.split('.').count(), 3);

        // Each segment must decode as valid base64url.
        for (segment_name, segment) in ["header", "claims", "signature"]
            .iter()
            .zip(token.split('.'))
        {
            let result = URL_SAFE_NO_PAD.decode(segment);
            assert!(
                result.is_ok(),
                "segment {segment_name} failed to decode as base64url"
            );
        }
    }

    #[test]
    fn verify_compact_invalid_signature_scalar_k256() {
        let key = k256_key();
        let vkey = key.verifying_key();
        let token = mint(
            &key,
            &sample_claims("did:web:127.0.0.1%3A5000", "0123456789abcdef"),
        );

        // Replace the signature with all zeros (64 bytes, base64url-encoded).
        let zero_sig = URL_SAFE_NO_PAD.encode([0u8; 64]);
        let tampered = replace_segment(&token, 2, &zero_sig);

        let result = verify_compact(&tampered, &vkey);
        assert!(matches!(result, Err(JwtError::InvalidSignatureScalar)));
    }

    #[test]
    fn verify_compact_invalid_signature_scalar_p256() {
        let key = p256_key();
        let vkey = key.verifying_key();
        let token = mint(&key, &sample_claims("did:web:example.com", "fedcba9876543210"));

        // Replace the signature with all zeros (64 bytes, base64url-encoded).
        let zero_sig = URL_SAFE_NO_PAD.encode([0u8; 64]);
        let tampered = replace_segment(&token, 2, &zero_sig);

        let result = verify_compact(&tampered, &vkey);
        assert!(matches!(result, Err(JwtError::InvalidSignatureScalar)));
    }
}