1use crate::{DidError, DidResult};
9use base64::{engine::general_purpose::URL_SAFE_NO_PAD, Engine};
10use serde::{Deserialize, Serialize};
11use sha2::{Digest, Sha256};
12
13#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
17pub enum JwsAlgorithm {
18 EdDSA,
20 Es256,
22 Es384,
24 Rs256,
26 Ps256,
28}
29
30impl JwsAlgorithm {
31 pub fn as_str(&self) -> &'static str {
33 match self {
34 Self::EdDSA => "EdDSA",
35 Self::Es256 => "ES256",
36 Self::Es384 => "ES384",
37 Self::Rs256 => "RS256",
38 Self::Ps256 => "PS256",
39 }
40 }
41
42 pub fn parse(s: &str) -> Option<Self> {
44 match s {
45 "EdDSA" => Some(Self::EdDSA),
46 "ES256" => Some(Self::Es256),
47 "ES384" => Some(Self::Es384),
48 "RS256" => Some(Self::Rs256),
49 "PS256" => Some(Self::Ps256),
50 _ => None,
51 }
52 }
53
54 pub fn is_fixed_length(&self) -> bool {
56 matches!(self, Self::EdDSA | Self::Es256 | Self::Es384)
57 }
58}
59
60impl std::fmt::Display for JwsAlgorithm {
61 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
62 f.write_str(self.as_str())
63 }
64}
65
66#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
70pub struct JwsHeader {
71 pub alg: JwsAlgorithm,
73 #[serde(skip_serializing_if = "Option::is_none")]
75 pub kid: Option<String>,
76 #[serde(skip_serializing_if = "Option::is_none")]
78 pub typ: Option<String>,
79 #[serde(skip_serializing_if = "Option::is_none")]
81 pub b64: Option<bool>,
82}
83
84impl JwsHeader {
85 pub fn ed_dsa(kid: Option<&str>) -> Self {
87 Self {
88 alg: JwsAlgorithm::EdDSA,
89 kid: kid.map(String::from),
90 typ: Some("JWT".to_string()),
91 b64: None,
92 }
93 }
94
95 pub fn es256(kid: Option<&str>) -> Self {
97 Self {
98 alg: JwsAlgorithm::Es256,
99 kid: kid.map(String::from),
100 typ: Some("JWT".to_string()),
101 b64: None,
102 }
103 }
104
105 pub fn rs256(kid: Option<&str>) -> Self {
107 Self {
108 alg: JwsAlgorithm::Rs256,
109 kid: kid.map(String::from),
110 typ: Some("JWT".to_string()),
111 b64: None,
112 }
113 }
114
115 pub fn encode(&self) -> DidResult<String> {
117 let json =
118 serde_json::to_string(self).map_err(|e| DidError::SerializationError(e.to_string()))?;
119 Ok(URL_SAFE_NO_PAD.encode(json.as_bytes()))
120 }
121
122 pub fn decode(encoded: &str) -> DidResult<Self> {
124 let bytes = URL_SAFE_NO_PAD
125 .decode(encoded)
126 .map_err(|e| DidError::InvalidProof(format!("JwsHeader base64url decode: {e}")))?;
127 serde_json::from_slice(&bytes)
128 .map_err(|e| DidError::InvalidProof(format!("JwsHeader JSON parse: {e}")))
129 }
130}
131
132pub trait JwsSigner: Send + Sync {
136 fn sign(&self, payload: &[u8]) -> DidResult<Vec<u8>>;
138
139 fn algorithm(&self) -> JwsAlgorithm;
141}
142
143pub trait JwsVerifier: Send + Sync {
145 fn verify(&self, payload: &[u8], signature: &[u8]) -> DidResult<bool>;
147
148 fn algorithm(&self) -> JwsAlgorithm;
150}
151
/// Mock signer: its "signature" is a SHA-256 digest over the key bytes
/// followed by the payload, while it reports `EdDSA` as its algorithm.
pub struct MockJwsSigner {
    /// Key bytes that are hashed ahead of the payload.
    key: Vec<u8>,

    /// Optional key identifier associated with this signer.
    pub kid: Option<String>,
}
164
165impl MockJwsSigner {
166 pub fn new(key: impl Into<Vec<u8>>, kid: Option<&str>) -> Self {
168 Self {
169 key: key.into(),
170 kid: kid.map(String::from),
171 }
172 }
173
174 pub fn test_key() -> Self {
176 Self::new(vec![0u8; 32], Some("mock-key-1"))
177 }
178
179 pub fn hmac_sha256(&self, payload: &[u8]) -> Vec<u8> {
181 let mut hasher = Sha256::new();
183 hasher.update(&self.key);
184 hasher.update(payload);
185 hasher.finalize().to_vec()
186 }
187}
188
189impl JwsSigner for MockJwsSigner {
190 fn sign(&self, payload: &[u8]) -> DidResult<Vec<u8>> {
191 Ok(self.hmac_sha256(payload))
192 }
193
194 fn algorithm(&self) -> JwsAlgorithm {
195 JwsAlgorithm::EdDSA }
197}
198
/// Mock verifier: recomputes the SHA-256 digest over key bytes plus payload
/// and compares it with the presented signature.
pub struct MockJwsVerifier {
    /// Key bytes that are hashed ahead of the payload.
    key: Vec<u8>,
}
203
204impl MockJwsVerifier {
205 pub fn from_signer(signer: &MockJwsSigner) -> Self {
207 Self {
208 key: signer.key.clone(),
209 }
210 }
211
212 pub fn test_key() -> Self {
214 Self { key: vec![0u8; 32] }
215 }
216}
217
218impl JwsVerifier for MockJwsVerifier {
219 fn verify(&self, payload: &[u8], signature: &[u8]) -> DidResult<bool> {
220 let mut hasher = Sha256::new();
221 hasher.update(&self.key);
222 hasher.update(payload);
223 let expected = hasher.finalize();
224 Ok(expected.as_slice() == signature)
225 }
226
227 fn algorithm(&self) -> JwsAlgorithm {
228 JwsAlgorithm::EdDSA
229 }
230}
231
#[cfg(test)]
mod tests {
    use super::*;

    /// Every algorithm paired with the identifier it is expected to use.
    const NAMED_ALGORITHMS: [(JwsAlgorithm, &str); 5] = [
        (JwsAlgorithm::EdDSA, "EdDSA"),
        (JwsAlgorithm::Es256, "ES256"),
        (JwsAlgorithm::Es384, "ES384"),
        (JwsAlgorithm::Rs256, "RS256"),
        (JwsAlgorithm::Ps256, "PS256"),
    ];

    // ---- JwsAlgorithm ----

    #[test]
    fn test_algorithm_as_str() {
        for (alg, name) in NAMED_ALGORITHMS {
            assert_eq!(alg.as_str(), name);
        }
    }

    #[test]
    fn test_algorithm_parse_valid() {
        for (alg, name) in NAMED_ALGORITHMS {
            assert_eq!(JwsAlgorithm::parse(name), Some(alg));
        }
    }

    #[test]
    fn test_algorithm_parse_invalid() {
        // Unsupported algorithm, empty string, and wrong letter case.
        for rejected in ["HS256", "", "edDSA"] {
            assert!(JwsAlgorithm::parse(rejected).is_none());
        }
    }

    #[test]
    fn test_algorithm_display() {
        let rendered = format!("{}", JwsAlgorithm::Es256);
        assert_eq!(rendered, "ES256");
    }

    #[test]
    fn test_algorithm_is_fixed_length() {
        let fixed = [
            JwsAlgorithm::EdDSA,
            JwsAlgorithm::Es256,
            JwsAlgorithm::Es384,
        ];
        let variable = [JwsAlgorithm::Rs256, JwsAlgorithm::Ps256];

        for alg in fixed {
            assert!(alg.is_fixed_length());
        }
        for alg in variable {
            assert!(!alg.is_fixed_length());
        }
    }

    #[test]
    fn test_algorithm_roundtrip_via_str() {
        for (alg, _) in NAMED_ALGORITHMS {
            let reparsed = JwsAlgorithm::parse(alg.as_str()).unwrap();
            assert_eq!(reparsed, alg);
        }
    }

    // ---- JwsHeader ----

    #[test]
    fn test_header_ed_dsa() {
        let header = JwsHeader::ed_dsa(Some("key-1"));
        assert_eq!(header.alg, JwsAlgorithm::EdDSA);
        assert_eq!(header.kid.as_deref(), Some("key-1"));
        assert_eq!(header.typ.as_deref(), Some("JWT"));
        assert_eq!(header.b64, None);
    }

    #[test]
    fn test_header_es256() {
        let header = JwsHeader::es256(None);
        assert_eq!(header.alg, JwsAlgorithm::Es256);
        assert_eq!(header.kid, None);
    }

    #[test]
    fn test_header_rs256() {
        let header = JwsHeader::rs256(Some("rsa-key"));
        assert_eq!(header.alg, JwsAlgorithm::Rs256);
        assert_eq!(header.kid.as_deref(), Some("rsa-key"));
    }

    #[test]
    fn test_header_encode_decode_roundtrip() {
        let original = JwsHeader::ed_dsa(Some("did:key:z123#key-1"));
        let wire = original.encode().unwrap();
        let restored = JwsHeader::decode(&wire).unwrap();
        assert_eq!(original, restored);
    }

    #[test]
    fn test_header_decode_invalid_base64() {
        let outcome = JwsHeader::decode("!!!invalid!!!");
        assert!(outcome.is_err());
    }

    #[test]
    fn test_header_encode_contains_alg() {
        let wire = JwsHeader::es256(None).encode().unwrap();
        let restored = JwsHeader::decode(&wire).unwrap();
        assert_eq!(restored.alg.as_str(), "ES256");
    }

    #[test]
    fn test_header_b64_field_serialised() {
        let header = JwsHeader {
            alg: JwsAlgorithm::EdDSA,
            kid: None,
            typ: None,
            b64: Some(false),
        };
        let wire = header.encode().unwrap();
        let restored = JwsHeader::decode(&wire).unwrap();
        assert_eq!(restored.b64, Some(false));
    }

    // ---- MockJwsSigner / MockJwsVerifier ----

    #[test]
    fn test_mock_signer_deterministic() {
        let signer = MockJwsSigner::test_key();
        let message = b"hello world";
        let first = signer.sign(message).unwrap();
        let second = signer.sign(message).unwrap();
        assert_eq!(first, second);
    }

    #[test]
    fn test_mock_signer_different_payload_different_sig() {
        let signer = MockJwsSigner::test_key();
        let first = signer.sign(b"payload one").unwrap();
        let second = signer.sign(b"payload two").unwrap();
        assert_ne!(first, second);
    }

    #[test]
    fn test_mock_signer_sign_verify_roundtrip() {
        let message = b"test payload for verification";
        let signature = MockJwsSigner::test_key().sign(message).unwrap();
        let accepted = MockJwsVerifier::test_key()
            .verify(message, &signature)
            .unwrap();
        assert!(accepted);
    }

    #[test]
    fn test_mock_verifier_wrong_payload_fails() {
        let signature = MockJwsSigner::test_key().sign(b"original").unwrap();
        let accepted = MockJwsVerifier::test_key()
            .verify(b"tampered", &signature)
            .unwrap();
        assert!(!accepted);
    }

    #[test]
    fn test_mock_verifier_wrong_key_fails() {
        let signer = MockJwsSigner::new(vec![1u8; 32], None);
        let signature = signer.sign(b"data").unwrap();
        let mismatched = MockJwsVerifier { key: vec![2u8; 32] };
        let accepted = mismatched.verify(b"data", &signature).unwrap();
        assert!(!accepted);
    }

    #[test]
    fn test_mock_signer_algorithm() {
        let reported = MockJwsSigner::test_key().algorithm();
        assert_eq!(reported, JwsAlgorithm::EdDSA);
    }

    #[test]
    fn test_mock_verifier_from_signer() {
        let signer = MockJwsSigner::new(vec![42u8; 32], Some("test-kid"));
        let verifier = MockJwsVerifier::from_signer(&signer);
        let signature = signer.sign(b"hello").unwrap();
        let accepted = verifier.verify(b"hello", &signature).unwrap();
        assert!(accepted);
    }

    #[test]
    fn test_mock_signer_kid_field() {
        let signer = MockJwsSigner::new(vec![0u8; 32], Some("did:key:z123#k1"));
        assert_eq!(signer.kid.as_deref(), Some("did:key:z123#k1"));
    }

    #[test]
    fn test_mock_signer_test_key_has_kid() {
        let signer = MockJwsSigner::test_key();
        assert!(signer.kid.is_some());
    }

    #[test]
    fn test_mock_signer_signature_length_is_32() {
        // SHA-256 output is 32 bytes.
        let signature = MockJwsSigner::test_key().sign(b"data").unwrap();
        assert_eq!(signature.len(), 32);
    }
}