1use base64::Engine;
11use base64::engine::general_purpose::URL_SAFE_NO_PAD;
12use serde::{Deserialize, Serialize};
13use thiserror::Error;
14
15use crate::common::identity::{AnySignature, AnySignatureError, AnySigningKey, AnyVerifyingKey};
16
/// JOSE header of a compact JWT, produced by [`encode_compact`] and parsed by
/// [`decode_compact`].
///
/// Field order matters: serde serializes fields in declaration order, and the resulting
/// JSON is what `encode_compact` base64url-encodes and signs.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct JwtHeader {
    /// Signature algorithm. [`verify_compact`] requires `"ES256K"` for K256 verifying keys
    /// and `"ES256"` for P256 verifying keys.
    pub alg: String,
    /// Token type; [`JwtHeader::for_signing_key`] sets it to `"JWT"`. `verify_compact`
    /// does not inspect it.
    pub typ: String,
}
29
30impl JwtHeader {
31 pub fn for_signing_key(key: &AnySigningKey) -> Self {
34 Self {
35 alg: key.jwt_alg().to_string(),
36 typ: "JWT".to_string(),
37 }
38 }
39}
40
/// Claims payload carried in the second segment of a compact JWT.
///
/// Field order matters: serde serializes fields in declaration order, and that JSON is what
/// [`encode_compact`] signs. Neither [`decode_compact`] nor [`verify_compact`] validates the
/// claim values (for example `exp`/`iat` are never compared to the current time), so any
/// such checks have to be done by the caller.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct JwtClaims {
    /// Issuer (`iss`, RFC 7519), e.g. a DID such as `did:web:example.com`.
    pub iss: String,
    /// Audience (`aud`, RFC 7519). Typed as a single string, so a JSON array here fails to
    /// deserialize.
    pub aud: String,
    /// Expiration time (`exp`, RFC 7519 NumericDate, i.e. seconds since the Unix epoch).
    pub exp: i64,
    /// Issued-at time (`iat`, RFC 7519 NumericDate, i.e. seconds since the Unix epoch).
    pub iat: i64,
    /// Method the token is scoped to, e.g. `com.atproto.moderation.createReport`.
    /// NOTE(review): presumably the atproto service-auth `lxm` claim — confirm with consumers.
    pub lxm: String,
    /// Token identifier (`jti`, RFC 7519). No uniqueness/replay check happens in
    /// `decode_compact` or `verify_compact`.
    pub jti: String,
}
67
68#[derive(Debug, Error)]
78pub enum JwtError {
79 #[error("malformed compact JWT: expected three segments")]
81 MalformedCompact,
82 #[error("base64url decode failed for {segment}")]
84 Base64Decode {
85 segment: &'static str,
87 #[source]
89 source: base64::DecodeError,
90 },
91 #[error("JSON decode failed for {segment}")]
93 JsonDecode {
94 segment: &'static str,
96 #[source]
98 source: serde_json::Error,
99 },
100 #[error("JSON encode failed")]
103 JsonEncode(serde_json::Error),
104 #[error("signature was {actual} bytes; expected 64")]
106 SignatureLength {
107 actual: usize,
109 },
110 #[error("signature has invalid scalar values")]
113 InvalidSignatureScalar,
114 #[error("unsupported JWT alg `{alg}` (expected ES256 or ES256K)")]
116 UnsupportedAlg {
117 alg: String,
119 },
120 #[error("signature verification failed")]
122 SignatureVerify(#[from] AnySignatureError),
123}
124
125pub fn encode_compact(
130 header: &JwtHeader,
131 claims: &JwtClaims,
132 signer: &AnySigningKey,
133) -> Result<String, JwtError> {
134 let header_json = serde_json::to_vec(header).map_err(JwtError::JsonEncode)?;
135 let claims_json = serde_json::to_vec(claims).map_err(JwtError::JsonEncode)?;
136 let header_b64 = URL_SAFE_NO_PAD.encode(&header_json);
137 let claims_b64 = URL_SAFE_NO_PAD.encode(&claims_json);
138 let signing_input = format!("{header_b64}.{claims_b64}");
139 let sig = signer.sign(signing_input.as_bytes());
140 let sig_bytes = sig.to_jws_bytes();
141 let sig_b64 = URL_SAFE_NO_PAD.encode(sig_bytes);
142 Ok(format!("{header_b64}.{claims_b64}.{sig_b64}"))
143}
144
145pub fn decode_compact(token: &str) -> Result<(JwtHeader, JwtClaims, Vec<u8>), JwtError> {
151 let parts: Vec<&str> = token.split('.').collect();
152 if parts.len() != 3 {
153 return Err(JwtError::MalformedCompact);
154 }
155 let header_b64 = parts[0];
156 let claims_b64 = parts[1];
157 let sig_b64 = parts[2];
158 let header_bytes =
159 URL_SAFE_NO_PAD
160 .decode(header_b64)
161 .map_err(|source| JwtError::Base64Decode {
162 segment: "header",
163 source,
164 })?;
165 let claims_bytes =
166 URL_SAFE_NO_PAD
167 .decode(claims_b64)
168 .map_err(|source| JwtError::Base64Decode {
169 segment: "claims",
170 source,
171 })?;
172 let sig_bytes = URL_SAFE_NO_PAD
173 .decode(sig_b64)
174 .map_err(|source| JwtError::Base64Decode {
175 segment: "signature",
176 source,
177 })?;
178 let header: JwtHeader =
179 serde_json::from_slice(&header_bytes).map_err(|source| JwtError::JsonDecode {
180 segment: "header",
181 source,
182 })?;
183 let claims: JwtClaims =
184 serde_json::from_slice(&claims_bytes).map_err(|source| JwtError::JsonDecode {
185 segment: "claims",
186 source,
187 })?;
188 Ok((header, claims, sig_bytes))
189}
190
191pub fn verify_compact(
195 token: &str,
196 vkey: &AnyVerifyingKey,
197) -> Result<(JwtHeader, JwtClaims), JwtError> {
198 let (header, claims, sig_bytes) = decode_compact(token)?;
199 let expected_alg = match vkey {
200 AnyVerifyingKey::K256(_) => "ES256K",
201 AnyVerifyingKey::P256(_) => "ES256",
202 };
203 if header.alg != expected_alg {
204 return Err(JwtError::UnsupportedAlg {
205 alg: header.alg.clone(),
206 });
207 }
208 if sig_bytes.len() != 64 {
209 return Err(JwtError::SignatureLength {
210 actual: sig_bytes.len(),
211 });
212 }
213 let sig_array: [u8; 64] = sig_bytes.as_slice().try_into().expect("len checked above");
214 let any_sig = match vkey {
215 AnyVerifyingKey::K256(_) => {
216 let sig = k256::ecdsa::Signature::from_bytes(&sig_array.into())
217 .map_err(|_| JwtError::InvalidSignatureScalar)?;
218 AnySignature::K256(sig)
219 }
220 AnyVerifyingKey::P256(_) => {
221 let sig = p256::ecdsa::Signature::from_bytes(&sig_array.into())
222 .map_err(|_| JwtError::InvalidSignatureScalar)?;
223 AnySignature::P256(sig)
224 }
225 };
226 let dot = token
228 .rfind('.')
229 .expect("three-segment token has a last dot");
230 let signing_input = &token[..dot];
231 use sha2::{Digest, Sha256};
232 let prehash: [u8; 32] = Sha256::digest(signing_input.as_bytes()).into();
233 vkey.verify_prehash(&prehash, &any_sig)?;
234 Ok((header, claims))
235}
236
#[cfg(test)]
mod tests {
    use super::*;
    use k256::ecdsa::SigningKey as K256SigningKey;
    use p256::ecdsa::SigningKey as P256SigningKey;

    /// K256 key from a fixed seed so tokens are reproducible across runs.
    fn k256_key_from_seed(seed: u8) -> AnySigningKey {
        AnySigningKey::K256(K256SigningKey::from_slice(&[seed; 32]).expect("valid seed"))
    }

    /// The default K256 test key.
    fn k256_key() -> AnySigningKey {
        k256_key_from_seed(1)
    }

    /// The default P256 test key.
    fn p256_key() -> AnySigningKey {
        AnySigningKey::P256(P256SigningKey::from_slice(&[2u8; 32]).expect("valid seed"))
    }

    /// Claims shared by the tests; only `iss` and `jti` vary between cases.
    fn sample_claims(iss: &str, jti: &str) -> JwtClaims {
        JwtClaims {
            iss: iss.to_string(),
            aud: "did:plc:test".to_string(),
            exp: 2_000_000_000,
            iat: 1_700_000_000,
            lxm: "com.atproto.moderation.createReport".to_string(),
            jti: jti.to_string(),
        }
    }

    /// Signs `claims` with `key` under the header derived from that key.
    fn sign(key: &AnySigningKey, claims: &JwtClaims) -> String {
        let header = JwtHeader::for_signing_key(key);
        encode_compact(&header, claims, key).expect("encode succeeds")
    }

    /// Splits a compact token into its raw (still base64url) segments.
    fn split3(token: &str) -> (&str, &str, &str) {
        let parts: Vec<&str> = token.split('.').collect();
        assert_eq!(parts.len(), 3, "compact token must have exactly three segments");
        (parts[0], parts[1], parts[2])
    }

    /// `JwtClaims` does not derive `PartialEq`, so compare field by field.
    fn assert_claims_eq(actual: &JwtClaims, expected: &JwtClaims) {
        assert_eq!(actual.iss, expected.iss);
        assert_eq!(actual.aud, expected.aud);
        assert_eq!(actual.exp, expected.exp);
        assert_eq!(actual.iat, expected.iat);
        assert_eq!(actual.lxm, expected.lxm);
        assert_eq!(actual.jti, expected.jti);
    }

    #[test]
    fn encode_decode_roundtrip_k256() {
        let key = k256_key();
        let claims = sample_claims("did:web:127.0.0.1%3A5000", "0123456789abcdef");

        let token = sign(&key, &claims);
        let (decoded_header, decoded_claims) =
            verify_compact(&token, &key.verifying_key()).expect("verify succeeds");

        assert_eq!(decoded_header.alg, "ES256K");
        assert_eq!(decoded_header.typ, "JWT");
        assert_claims_eq(&decoded_claims, &claims);
    }

    #[test]
    fn encode_decode_roundtrip_p256() {
        let key = p256_key();
        let claims = sample_claims("did:web:example.com", "fedcba9876543210");

        let token = sign(&key, &claims);
        let (decoded_header, decoded_claims) =
            verify_compact(&token, &key.verifying_key()).expect("verify succeeds");

        assert_eq!(decoded_header.alg, "ES256");
        assert_eq!(decoded_header.typ, "JWT");
        assert_claims_eq(&decoded_claims, &claims);
    }

    #[test]
    fn encode_decode_roundtrip_tampered_claims_fails() {
        let key = k256_key();
        let claims = sample_claims("did:web:127.0.0.1%3A5000", "0123456789abcdef");
        let token = sign(&key, &claims);
        let (header_b64, _, sig_b64) = split3(&token);

        // Substitute well-formed claims so the failure must come from the signature check,
        // not from JSON decoding.
        let mut forged = claims.clone();
        forged.aud = "did:plc:attacker".to_string();
        let forged_b64 =
            URL_SAFE_NO_PAD.encode(serde_json::to_vec(&forged).expect("claims serialize"));
        let tampered = format!("{header_b64}.{forged_b64}.{sig_b64}");

        let result = verify_compact(&tampered, &key.verifying_key());
        assert!(matches!(result, Err(JwtError::SignatureVerify(_))));
    }

    #[test]
    fn verify_compact_non_json_claims_is_decode_error() {
        let key = k256_key();
        let token = sign(&key, &sample_claims("did:web:test", "0123456789abcdef"));
        let (header_b64, _, sig_b64) = split3(&token);

        // "YWJj" is base64url for "abc", which is not JSON.
        let tampered = format!("{header_b64}.YWJj.{sig_b64}");

        let result = verify_compact(&tampered, &key.verifying_key());
        assert!(matches!(
            result,
            Err(JwtError::JsonDecode {
                segment: "claims",
                ..
            })
        ));
    }

    #[test]
    fn verify_compact_wrong_key_same_curve_fails() {
        let token = sign(&k256_key(), &sample_claims("did:web:test", "0123456789abcdef"));
        let other_key = k256_key_from_seed(3);

        let result = verify_compact(&token, &other_key.verifying_key());
        assert!(matches!(result, Err(JwtError::SignatureVerify(_))));
    }

    #[test]
    fn decode_compact_malformed_two_segments() {
        let result = decode_compact("header.claims");
        assert!(matches!(result, Err(JwtError::MalformedCompact)));
    }

    #[test]
    fn decode_compact_malformed_four_segments() {
        let result = decode_compact("YQ.Yg.Yw.ZA");
        assert!(matches!(result, Err(JwtError::MalformedCompact)));
    }

    #[test]
    fn decode_compact_invalid_base64() {
        let result = decode_compact("!!!.claims.sig");
        assert!(matches!(
            result,
            Err(JwtError::Base64Decode {
                segment: "header",
                ..
            })
        ));
    }

    #[test]
    fn verify_compact_curve_mismatch() {
        let token = sign(&k256_key(), &sample_claims("did:web:test", "0123456789abcdef"));

        let result = verify_compact(&token, &p256_key().verifying_key());
        // The ES256K header is rejected before any P256 signature parsing happens.
        assert!(matches!(result, Err(JwtError::UnsupportedAlg { alg }) if alg == "ES256K"));
    }

    #[test]
    fn verify_compact_rejects_header_alg_not_matching_key() {
        let key = k256_key();
        let header = JwtHeader {
            alg: "ES256".to_string(),
            typ: "JWT".to_string(),
        };
        let claims = sample_claims("did:web:test", "0123456789abcdef");
        let token = encode_compact(&header, &claims, &key).expect("encode succeeds");

        let result = verify_compact(&token, &key.verifying_key());
        assert!(matches!(result, Err(JwtError::UnsupportedAlg { alg }) if alg == "ES256"));
    }

    #[test]
    fn encode_compact_produces_valid_structure() {
        let token = sign(&k256_key(), &sample_claims("did:web:test", "0123456789abcdef"));
        let (header_b64, claims_b64, sig_b64) = split3(&token);

        for (segment_name, segment) in [
            ("header", header_b64),
            ("claims", claims_b64),
            ("signature", sig_b64),
        ] {
            assert!(
                URL_SAFE_NO_PAD.decode(segment).is_ok(),
                "segment {segment_name} failed to decode as base64url"
            );
        }

        let sig = URL_SAFE_NO_PAD.decode(sig_b64).expect("signature decodes");
        assert_eq!(sig.len(), 64);
    }

    #[test]
    fn verify_compact_short_signature() {
        let key = k256_key();
        let token = sign(&key, &sample_claims("did:web:test", "0123456789abcdef"));
        let (header_b64, claims_b64, _) = split3(&token);

        let short_sig = URL_SAFE_NO_PAD.encode([0u8; 3]);
        let tampered = format!("{header_b64}.{claims_b64}.{short_sig}");

        let result = verify_compact(&tampered, &key.verifying_key());
        assert!(matches!(result, Err(JwtError::SignatureLength { actual: 3 })));
    }

    #[test]
    fn verify_compact_invalid_signature_scalar_k256() {
        let key = k256_key();
        let token = sign(&key, &sample_claims("did:web:127.0.0.1%3A5000", "0123456789abcdef"));
        let (header_b64, claims_b64, _) = split3(&token);

        // r = s = 0 is never a valid ECDSA signature.
        let zero_sig = URL_SAFE_NO_PAD.encode([0u8; 64]);
        let tampered = format!("{header_b64}.{claims_b64}.{zero_sig}");

        let result = verify_compact(&tampered, &key.verifying_key());
        assert!(matches!(result, Err(JwtError::InvalidSignatureScalar)));
    }

    #[test]
    fn verify_compact_invalid_signature_scalar_p256() {
        let key = p256_key();
        let token = sign(&key, &sample_claims("did:web:example.com", "fedcba9876543210"));
        let (header_b64, claims_b64, _) = split3(&token);

        // r = s = 0 is never a valid ECDSA signature.
        let zero_sig = URL_SAFE_NO_PAD.encode([0u8; 64]);
        let tampered = format!("{header_b64}.{claims_b64}.{zero_sig}");

        let result = verify_compact(&tampered, &key.verifying_key());
        assert!(matches!(result, Err(JwtError::InvalidSignatureScalar)));
    }
}