1use crate::{DidDocument, DidError, DidResult, VerificationMethod};
11use hmac::{Hmac, Mac};
12use sha2::Sha256;
13use std::collections::HashMap;
14use std::sync::RwLock;
15
/// HMAC-SHA256 instance used by the mock backends (via `hmac_sign`) to produce
/// deterministic, symmetric stand-in "signatures".
type HmacSha256 = Hmac<Sha256>;
17
/// Key algorithms a (mock) KMS key can be created with.
///
/// In this module the algorithm determines how many bytes of mock key material are
/// derived and is mixed into the derivation seed; no real asymmetric crypto is done.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum KmsAlgorithm {
    /// Edwards-curve Ed25519.
    Ed25519,
    /// NIST P-256 elliptic curve.
    EcP256,
    /// NIST P-384 elliptic curve.
    EcP384,
    /// RSA with a 2048-bit modulus.
    Rsa2048,
    /// RSA with a 4096-bit modulus.
    Rsa4096,
}

impl KmsAlgorithm {
    /// Returns the provider-style identifier of the algorithm (e.g. `"EC_P256"`).
    pub const fn as_str(&self) -> &'static str {
        match self {
            Self::Ed25519 => "Ed25519",
            Self::EcP256 => "EC_P256",
            Self::EcP384 => "EC_P384",
            Self::Rsa2048 => "RSA_2048",
            Self::Rsa4096 => "RSA_4096",
        }
    }

    /// Returns the number of key-material bytes generated for this algorithm.
    ///
    /// For the elliptic-curve variants this is the scalar size; for RSA it is the
    /// modulus size (2048 bits = 256 bytes, 4096 bits = 512 bytes).
    pub const fn key_size_bytes(&self) -> usize {
        match self {
            Self::Ed25519 => 32,
            Self::EcP256 => 32,
            Self::EcP384 => 48,
            Self::Rsa2048 => 256,
            Self::Rsa4096 => 512,
        }
    }
}
55
/// Operations a KMS key is permitted to perform.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum KeyUsage {
    /// Signing and signature verification (the usage the mock backends assign).
    SignVerify,
    /// Encryption and decryption.
    EncryptDecrypt,
}
62
63#[derive(Debug, Clone)]
65pub struct KmsKeyMetadata {
66 pub key_id: String,
67 pub algorithm: KmsAlgorithm,
68 pub created_at: i64,
70 pub enabled: bool,
71 pub key_usage: KeyUsage,
72}
73
/// Internal storage record for one mock key: its public metadata plus the secret
/// bytes that the mock backends pass to `hmac_sign` as the HMAC key.
///
/// NOTE(review): this type does not derive `Debug`; keep it that way (or write a
/// redacting impl) so `private_key_bytes` cannot leak through `{:?}` formatting.
struct KmsKeyEntry {
    /// Metadata returned to callers by `create_key` / `list_keys`.
    metadata: KmsKeyMetadata,
    /// Secret key material produced by `derive_key_bytes`; never returned directly.
    private_key_bytes: Vec<u8>,
}
83
/// Abstraction over a key-management service that stores keys and performs
/// cryptographic operations with them on the caller's behalf.
///
/// The `Send + Sync` supertraits allow one backend instance to be shared across threads.
/// All fallible operations report errors through `DidResult`.
pub trait KmsBackend: Send + Sync {
    /// Creates a new key named `key_id` using `algorithm` and returns its metadata.
    fn create_key(&self, key_id: &str, algorithm: KmsAlgorithm) -> DidResult<KmsKeyMetadata>;

    /// Signs `data` with the key `key_id`, returning the raw signature bytes.
    fn sign(&self, key_id: &str, data: &[u8]) -> DidResult<Vec<u8>>;

    /// Reports whether `signature` is valid for `data` under the key `key_id`.
    ///
    /// `Ok(false)` means "well-formed request, signature does not match"; `Err` is
    /// reserved for problems such as an unknown key.
    fn verify(&self, key_id: &str, data: &[u8], signature: &[u8]) -> DidResult<bool>;

    /// Returns the public key bytes associated with `key_id`.
    fn get_public_key(&self, key_id: &str) -> DidResult<Vec<u8>>;

    /// Removes the key `key_id` from the backend.
    fn delete_key(&self, key_id: &str) -> DidResult<()>;

    /// Returns metadata for every key currently held by the backend.
    fn list_keys(&self) -> DidResult<Vec<KmsKeyMetadata>>;
}
102
103fn hmac_sign(private_key_bytes: &[u8], key_id: &str, data: &[u8]) -> Vec<u8> {
110 let mut mac =
111 HmacSha256::new_from_slice(private_key_bytes).expect("HMAC accepts any key length");
112 mac.update(key_id.as_bytes());
113 mac.update(data);
114 mac.finalize().into_bytes().to_vec()
115}
116
117fn derive_key_bytes(key_id: &str, algorithm: &KmsAlgorithm) -> Vec<u8> {
119 use sha2::Digest;
120 let mut hasher = sha2::Sha256::new();
121 hasher.update(key_id.as_bytes());
122 hasher.update(algorithm.as_str().as_bytes());
123 let base = hasher.finalize().to_vec();
124 let mut key = Vec::with_capacity(algorithm.key_size_bytes());
126 let mut counter: u8 = 0;
127 while key.len() < algorithm.key_size_bytes() {
128 let mut h2 = sha2::Sha256::new();
129 h2.update(&base);
130 h2.update([counter]);
131 key.extend_from_slice(&h2.finalize());
132 counter = counter.wrapping_add(1);
133 }
134 key.truncate(algorithm.key_size_bytes());
135 key
136}
137
/// Derives the mock "public key" that the mock backends publish for a key.
///
/// Returns the first half of `private_key_bytes` in reverse order. The half is at
/// least one byte when the input is non-empty; an empty input yields an empty vector
/// instead of panicking.
///
/// This is NOT a real public-key derivation: the output is a slice of the private
/// bytes themselves and must never be used outside of tests.
fn derive_public_key(private_key_bytes: &[u8]) -> Vec<u8> {
    // Clamp to one byte so a 1-byte input still yields a non-empty key. `take` stops
    // at the end of the slice, so an empty input cannot index out of range.
    let prefix_len = (private_key_bytes.len() / 2).max(1);
    private_key_bytes
        .iter()
        .take(prefix_len)
        .rev()
        .copied()
        .collect()
}
145
/// Current wall-clock time in whole seconds since the Unix epoch.
///
/// Falls back to `0` if the system clock reports a time before the epoch.
fn now_unix() -> i64 {
    use std::time::{SystemTime, UNIX_EPOCH};

    match SystemTime::now().duration_since(UNIX_EPOCH) {
        // u64 seconds only exceed i64::MAX roughly 292 billion years from now.
        Ok(elapsed) => elapsed.as_secs() as i64,
        Err(_) => 0,
    }
}
152
// Generates an in-memory mock KMS backend type `$name` implementing `KmsBackend`.
//
// `$display` is the provider label (e.g. "AWS"). It is interpolated into error
// messages and rustdoc, and is the only thing that differs between the generated
// types. Key material is derived deterministically from the key id
// (`derive_key_bytes`), and "signatures" are HMAC-SHA256 tags (`hmac_sign`). The
// generated types are therefore test doubles only.
macro_rules! impl_mock_kms {
    ($name:ident, $display:literal) => {
        #[doc = concat!(
            "In-memory mock of the ", $display,
            " KMS. Test use only: signatures are symmetric HMAC tags."
        )]
        pub struct $name {
            /// Key entries indexed by key id.
            keys: RwLock<HashMap<String, KmsKeyEntry>>,
        }

        impl Default for $name {
            fn default() -> Self {
                Self::new()
            }
        }

        impl $name {
            /// Creates a backend that holds no keys.
            pub fn new() -> Self {
                Self {
                    keys: RwLock::new(HashMap::new()),
                }
            }

            /// Takes the shared (read) lock on the key map; poisoning becomes `InternalError`.
            fn read_keys(
                &self,
            ) -> DidResult<std::sync::RwLockReadGuard<'_, HashMap<String, KmsKeyEntry>>> {
                self.keys
                    .read()
                    .map_err(|e| DidError::InternalError(format!("KMS lock poisoned: {}", e)))
            }

            /// Takes the exclusive (write) lock on the key map; poisoning becomes `InternalError`.
            fn write_keys(
                &self,
            ) -> DidResult<std::sync::RwLockWriteGuard<'_, HashMap<String, KmsKeyEntry>>> {
                self.keys
                    .write()
                    .map_err(|e| DidError::InternalError(format!("KMS lock poisoned: {}", e)))
            }

            /// Builds the `KeyNotFound` error used by every operation on an unknown key.
            fn key_not_found(key_id: &str) -> DidError {
                DidError::KeyNotFound(format!("Key '{}' not found in {} KMS", key_id, $display))
            }

            /// Looks up `key_id`, failing with `KeyNotFound` if it is absent.
            fn lookup<'a>(
                keys: &'a HashMap<String, KmsKeyEntry>,
                key_id: &str,
            ) -> DidResult<&'a KmsKeyEntry> {
                keys.get(key_id).ok_or_else(|| Self::key_not_found(key_id))
            }

            /// Like `lookup`, but also rejects disabled keys with `SigningFailed`.
            fn lookup_enabled<'a>(
                keys: &'a HashMap<String, KmsKeyEntry>,
                key_id: &str,
            ) -> DidResult<&'a KmsKeyEntry> {
                let entry = Self::lookup(keys, key_id)?;
                if entry.metadata.enabled {
                    Ok(entry)
                } else {
                    Err(DidError::SigningFailed(format!(
                        "Key '{}' is disabled",
                        key_id
                    )))
                }
            }
        }

        impl KmsBackend for $name {
            fn create_key(
                &self,
                key_id: &str,
                algorithm: KmsAlgorithm,
            ) -> DidResult<KmsKeyMetadata> {
                use std::collections::hash_map::Entry;

                let mut keys = self.write_keys()?;
                // A single entry lookup both rejects duplicates and reserves the slot.
                let slot = match keys.entry(key_id.to_string()) {
                    Entry::Occupied(_) => {
                        return Err(DidError::InvalidKey(format!(
                            "Key '{}' already exists in {} KMS",
                            key_id, $display
                        )));
                    }
                    Entry::Vacant(slot) => slot,
                };

                let private_key_bytes = derive_key_bytes(key_id, &algorithm);
                let metadata = KmsKeyMetadata {
                    key_id: key_id.to_string(),
                    algorithm,
                    created_at: now_unix(),
                    enabled: true,
                    key_usage: KeyUsage::SignVerify,
                };
                slot.insert(KmsKeyEntry {
                    metadata: metadata.clone(),
                    private_key_bytes,
                });
                Ok(metadata)
            }

            fn sign(&self, key_id: &str, data: &[u8]) -> DidResult<Vec<u8>> {
                let keys = self.read_keys()?;
                let entry = Self::lookup_enabled(&keys, key_id)?;
                Ok(hmac_sign(&entry.private_key_bytes, key_id, data))
            }

            fn verify(&self, key_id: &str, data: &[u8], signature: &[u8]) -> DidResult<bool> {
                let keys = self.read_keys()?;
                let entry = Self::lookup_enabled(&keys, key_id)?;
                // Recompute the tag over exactly what `hmac_sign` covers (key id, then
                // data). `verify_slice` compares it in constant time; a plain `==` on the
                // bytes may stop at the first mismatch and leak timing information.
                let mut mac = HmacSha256::new_from_slice(&entry.private_key_bytes)
                    .expect("HMAC accepts any key length");
                mac.update(key_id.as_bytes());
                mac.update(data);
                Ok(mac.verify_slice(signature).is_ok())
            }

            fn get_public_key(&self, key_id: &str) -> DidResult<Vec<u8>> {
                let keys = self.read_keys()?;
                let entry = Self::lookup(&keys, key_id)?;
                Ok(derive_public_key(&entry.private_key_bytes))
            }

            fn delete_key(&self, key_id: &str) -> DidResult<()> {
                let mut keys = self.write_keys()?;
                match keys.remove(key_id) {
                    Some(_) => Ok(()),
                    None => Err(Self::key_not_found(key_id)),
                }
            }

            fn list_keys(&self) -> DidResult<Vec<KmsKeyMetadata>> {
                let keys = self.read_keys()?;
                let mut list: Vec<KmsKeyMetadata> =
                    keys.values().map(|e| e.metadata.clone()).collect();
                // Key ids are unique map keys, so an unstable sort is deterministic.
                list.sort_unstable_by(|a, b| a.key_id.cmp(&b.key_id));
                Ok(list)
            }
        }
    };
}
276
// Concrete mock backends. They behave identically and differ only in the provider
// label that `impl_mock_kms!` interpolates into their error messages.
impl_mock_kms!(MockAwsKms, "AWS");
impl_mock_kms!(MockGcpKms, "GCP");
impl_mock_kms!(MockAzureKms, "Azure");
280
/// Selects which mock KMS backend `create_mock_kms` constructs.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum KmsProvider {
    /// Build a `MockAwsKms`.
    MockAws,
    /// Build a `MockGcpKms`.
    MockGcp,
    /// Build a `MockAzureKms`.
    MockAzure,
}
291
292pub fn create_mock_kms(provider: KmsProvider) -> Box<dyn KmsBackend> {
294 match provider {
295 KmsProvider::MockAws => Box::new(MockAwsKms::new()),
296 KmsProvider::MockGcp => Box::new(MockGcpKms::new()),
297 KmsProvider::MockAzure => Box::new(MockAzureKms::new()),
298 }
299}
300
301pub struct KmsDidSigner {
307 backend: Box<dyn KmsBackend>,
308 key_id: String,
309}
310
311impl KmsDidSigner {
312 pub fn new(backend: Box<dyn KmsBackend>, key_id: &str) -> Self {
313 Self {
314 backend,
315 key_id: key_id.to_string(),
316 }
317 }
318
319 pub fn create_did_document(&self, did: &str) -> DidResult<DidDocument> {
322 let public_key = self.backend.get_public_key(&self.key_id)?;
323 let key_fragment = format!("{}#kms-key-0", did);
324
325 let vm = VerificationMethod::ed25519(&key_fragment, did, &public_key);
326
327 use crate::did::document::{DidDocument as DocType, VerificationRelationship};
328 use crate::Did;
329
330 let did_obj = Did::new(did)?;
331 let mut doc = DocType::new(did_obj);
332 doc.verification_method.push(vm);
333 doc.authentication
334 .push(VerificationRelationship::Reference(key_fragment.clone()));
335 doc.assertion_method
336 .push(VerificationRelationship::Reference(key_fragment));
337
338 Ok(doc)
339 }
340
341 pub fn sign_credential(&self, credential: &serde_json::Value) -> DidResult<serde_json::Value> {
344 let serialized = serde_json::to_vec(credential)
345 .map_err(|e| DidError::SerializationError(e.to_string()))?;
346
347 let sig = self.backend.sign(&self.key_id, &serialized)?;
348 use base64::engine::general_purpose::URL_SAFE_NO_PAD;
349 use base64::Engine;
350 let sig_b64 = URL_SAFE_NO_PAD.encode(&sig);
351
352 let mut signed = credential.clone();
353 if let Some(obj) = signed.as_object_mut() {
354 obj.insert(
355 "proof".to_string(),
356 serde_json::json!({
357 "type": "KmsHmacSignature2024",
358 "verificationMethod": self.key_id,
359 "signatureValue": sig_b64
360 }),
361 );
362 }
363 Ok(signed)
364 }
365
366 pub fn verify_credential(&self, signed_credential: &serde_json::Value) -> DidResult<bool> {
368 use base64::engine::general_purpose::URL_SAFE_NO_PAD;
369 use base64::Engine;
370
371 let proof = signed_credential
373 .get("proof")
374 .ok_or_else(|| DidError::InvalidProof("Missing proof field".to_string()))?;
375
376 let sig_b64 = proof
377 .get("signatureValue")
378 .and_then(|v| v.as_str())
379 .ok_or_else(|| DidError::InvalidProof("Missing signatureValue".to_string()))?;
380
381 let signature = URL_SAFE_NO_PAD
382 .decode(sig_b64)
383 .map_err(|e| DidError::InvalidProof(format!("Invalid base64: {}", e)))?;
384
385 let mut without_proof = signed_credential.clone();
387 if let Some(obj) = without_proof.as_object_mut() {
388 obj.remove("proof");
389 }
390 let serialized = serde_json::to_vec(&without_proof)
391 .map_err(|e| DidError::SerializationError(e.to_string()))?;
392
393 self.backend.verify(&self.key_id, &serialized, &signature)
394 }
395}
396
#[cfg(test)]
mod tests {
    use super::*;

    // ---------------------------------------------------------------------
    // MockAwsKms
    // ---------------------------------------------------------------------

    #[test]
    fn test_aws_create_ed25519_key() {
        let kms = MockAwsKms::new();
        let meta = kms.create_key("my-key", KmsAlgorithm::Ed25519).unwrap();
        assert_eq!(meta.key_id, "my-key");
        assert_eq!(meta.algorithm.as_str(), "Ed25519");
        assert!(meta.enabled);
    }

    #[test]
    fn test_aws_create_duplicate_key_fails() {
        let kms = MockAwsKms::new();
        kms.create_key("dup", KmsAlgorithm::Ed25519).unwrap();
        assert!(kms.create_key("dup", KmsAlgorithm::Ed25519).is_err());
    }

    #[test]
    fn test_aws_sign_and_verify() {
        let kms = MockAwsKms::new();
        kms.create_key("signing-key", KmsAlgorithm::EcP256).unwrap();

        let data = b"hello world";
        let sig = kms.sign("signing-key", data).unwrap();
        assert!(!sig.is_empty());

        let valid = kms.verify("signing-key", data, &sig).unwrap();
        assert!(valid);
    }

    #[test]
    fn test_aws_verify_wrong_data_fails() {
        let kms = MockAwsKms::new();
        kms.create_key("k1", KmsAlgorithm::Ed25519).unwrap();

        let sig = kms.sign("k1", b"original").unwrap();
        let valid = kms.verify("k1", b"tampered", &sig).unwrap();
        assert!(!valid);
    }

    #[test]
    fn test_aws_sign_missing_key_error() {
        let kms = MockAwsKms::new();
        assert!(kms.sign("nonexistent", b"data").is_err());
    }

    #[test]
    fn test_aws_verify_missing_key_error() {
        let kms = MockAwsKms::new();
        assert!(kms.verify("nonexistent", b"data", &[0u8; 32]).is_err());
    }

    #[test]
    fn test_aws_get_public_key() {
        let kms = MockAwsKms::new();
        kms.create_key("pk-key", KmsAlgorithm::Ed25519).unwrap();
        let pub_key = kms.get_public_key("pk-key").unwrap();
        assert!(!pub_key.is_empty());
    }

    #[test]
    fn test_aws_delete_key() {
        let kms = MockAwsKms::new();
        kms.create_key("del-key", KmsAlgorithm::Ed25519).unwrap();
        kms.delete_key("del-key").unwrap();
        assert!(kms.sign("del-key", b"data").is_err());
    }

    #[test]
    fn test_aws_delete_missing_key_error() {
        let kms = MockAwsKms::new();
        assert!(kms.delete_key("ghost").is_err());
    }

    #[test]
    fn test_aws_delete_key_twice_errors() {
        let kms = MockAwsKms::new();
        kms.create_key("once", KmsAlgorithm::Ed25519).unwrap();
        kms.delete_key("once").unwrap();
        assert!(kms.delete_key("once").is_err());
    }

    #[test]
    fn test_aws_list_keys() {
        let kms = MockAwsKms::new();
        kms.create_key("a", KmsAlgorithm::Ed25519).unwrap();
        kms.create_key("b", KmsAlgorithm::EcP256).unwrap();

        let keys = kms.list_keys().unwrap();
        assert_eq!(keys.len(), 2);
        assert_eq!(keys[0].key_id, "a");
        assert_eq!(keys[1].key_id, "b");
    }

    #[test]
    fn test_aws_list_keys_sorted_regardless_of_insertion_order() {
        let kms = MockAwsKms::new();
        for id in ["zeta", "alpha", "mid"] {
            kms.create_key(id, KmsAlgorithm::Ed25519).unwrap();
        }
        let ids: Vec<String> = kms
            .list_keys()
            .unwrap()
            .into_iter()
            .map(|m| m.key_id)
            .collect();
        assert_eq!(ids, vec!["alpha", "mid", "zeta"]);
    }

    #[test]
    fn test_aws_all_algorithms_create() {
        let kms = MockAwsKms::new();
        kms.create_key("ed", KmsAlgorithm::Ed25519).unwrap();
        kms.create_key("p256", KmsAlgorithm::EcP256).unwrap();
        kms.create_key("p384", KmsAlgorithm::EcP384).unwrap();
        kms.create_key("rsa2048", KmsAlgorithm::Rsa2048).unwrap();
        kms.create_key("rsa4096", KmsAlgorithm::Rsa4096).unwrap();

        let keys = kms.list_keys().unwrap();
        assert_eq!(keys.len(), 5);
    }

    #[test]
    fn test_aws_disabled_key_rejects_sign_and_verify() {
        let kms = MockAwsKms::new();
        kms.create_key("off", KmsAlgorithm::Ed25519).unwrap();
        let sig = kms.sign("off", b"data").unwrap();

        // No public API disables a key yet, so flip the flag through the private store.
        kms.keys
            .write()
            .unwrap()
            .get_mut("off")
            .unwrap()
            .metadata
            .enabled = false;

        assert!(kms.sign("off", b"data").is_err());
        assert!(kms.verify("off", b"data", &sig).is_err());
        let keys = kms.list_keys().unwrap();
        assert_eq!(keys.len(), 1);
        assert!(!keys[0].enabled);
    }

    // ---------------------------------------------------------------------
    // MockGcpKms
    // ---------------------------------------------------------------------

    #[test]
    fn test_gcp_create_and_sign() {
        let kms = MockGcpKms::new();
        kms.create_key("gcp-key", KmsAlgorithm::EcP256).unwrap();

        let sig = kms.sign("gcp-key", b"gcp-data").unwrap();
        let valid = kms.verify("gcp-key", b"gcp-data", &sig).unwrap();
        assert!(valid);
    }

    #[test]
    fn test_gcp_list_empty() {
        let kms = MockGcpKms::new();
        let keys = kms.list_keys().unwrap();
        assert!(keys.is_empty());
    }

    #[test]
    fn test_gcp_public_key_differs_from_private() {
        let kms = MockGcpKms::new();
        kms.create_key("gcp-pk", KmsAlgorithm::Ed25519).unwrap();

        let pub_key = kms.get_public_key("gcp-pk").unwrap();
        assert!(!pub_key.is_empty());
        let private = derive_key_bytes("gcp-pk", &KmsAlgorithm::Ed25519);
        assert_ne!(pub_key, private);
    }

    #[test]
    fn test_gcp_delete_and_recreate() {
        let kms = MockGcpKms::new();
        kms.create_key("reuse", KmsAlgorithm::Ed25519).unwrap();
        kms.delete_key("reuse").unwrap();
        kms.create_key("reuse", KmsAlgorithm::Ed25519).unwrap();

        let keys = kms.list_keys().unwrap();
        assert_eq!(keys.len(), 1);
        assert_eq!(keys[0].key_id, "reuse");
    }

    // ---------------------------------------------------------------------
    // MockAzureKms
    // ---------------------------------------------------------------------

    #[test]
    fn test_azure_create_and_sign() {
        let kms = MockAzureKms::new();
        kms.create_key("az-key", KmsAlgorithm::Rsa2048).unwrap();

        let sig = kms.sign("az-key", b"azure-data").unwrap();
        let valid = kms.verify("az-key", b"azure-data", &sig).unwrap();
        assert!(valid);
    }

    #[test]
    fn test_azure_wrong_signature() {
        let kms = MockAzureKms::new();
        kms.create_key("az2", KmsAlgorithm::EcP256).unwrap();

        let bad_sig = vec![0u8; 32];
        let valid = kms.verify("az2", b"some-data", &bad_sig).unwrap();
        assert!(!valid);
    }

    #[test]
    fn test_azure_truncated_signature_fails() {
        let kms = MockAzureKms::new();
        kms.create_key("trunc", KmsAlgorithm::EcP256).unwrap();

        let sig = kms.sign("trunc", b"payload").unwrap();
        assert!(!kms.verify("trunc", b"payload", &sig[..sig.len() / 2]).unwrap());
        assert!(!kms.verify("trunc", b"payload", &[]).unwrap());
    }

    #[test]
    fn test_azure_list_after_delete() {
        let kms = MockAzureKms::new();
        kms.create_key("x", KmsAlgorithm::Ed25519).unwrap();
        kms.create_key("y", KmsAlgorithm::Ed25519).unwrap();
        kms.delete_key("x").unwrap();

        let keys = kms.list_keys().unwrap();
        assert_eq!(keys.len(), 1);
        assert_eq!(keys[0].key_id, "y");
    }

    // ---------------------------------------------------------------------
    // Factory
    // ---------------------------------------------------------------------

    #[test]
    fn test_create_mock_kms_aws() {
        let kms = create_mock_kms(KmsProvider::MockAws);
        kms.create_key("factory-aws", KmsAlgorithm::Ed25519)
            .unwrap();
        let keys = kms.list_keys().unwrap();
        assert_eq!(keys.len(), 1);
    }

    #[test]
    fn test_create_mock_kms_gcp() {
        let kms = create_mock_kms(KmsProvider::MockGcp);
        kms.create_key("factory-gcp", KmsAlgorithm::EcP256).unwrap();
        let keys = kms.list_keys().unwrap();
        assert_eq!(keys.len(), 1);
    }

    #[test]
    fn test_create_mock_kms_azure() {
        let kms = create_mock_kms(KmsProvider::MockAzure);
        kms.create_key("factory-azure", KmsAlgorithm::Rsa2048)
            .unwrap();
        let keys = kms.list_keys().unwrap();
        assert_eq!(keys.len(), 1);
    }

    // ---------------------------------------------------------------------
    // KmsDidSigner
    // ---------------------------------------------------------------------

    #[test]
    fn test_kms_did_signer_create_document() {
        let backend = create_mock_kms(KmsProvider::MockAws);
        backend
            .create_key("did-signer-key", KmsAlgorithm::Ed25519)
            .unwrap();

        let signer = KmsDidSigner::new(backend, "did-signer-key");
        let did_str = "did:key:z6MkhaXgBZDvotDkL5257faiztiGiC2QtKLGpbnnEGta2doK";
        let doc = signer.create_did_document(did_str).unwrap();

        assert_eq!(doc.id.as_str(), did_str);
        assert_eq!(doc.verification_method.len(), 1);
        assert!(!doc.authentication.is_empty());
    }

    #[test]
    fn test_kms_did_signer_sign_credential() {
        let backend = create_mock_kms(KmsProvider::MockGcp);
        backend.create_key("vc-key", KmsAlgorithm::EcP256).unwrap();

        let signer = KmsDidSigner::new(backend, "vc-key");
        let credential = serde_json::json!({
            "@context": ["https://www.w3.org/2018/credentials/v1"],
            "type": ["VerifiableCredential"],
            "issuer": "did:example:issuer",
            "credentialSubject": { "id": "did:example:subject", "name": "Alice" }
        });

        let signed = signer.sign_credential(&credential).unwrap();
        assert!(signed.get("proof").is_some());
        let proof = signed.get("proof").unwrap();
        assert_eq!(proof["type"].as_str().unwrap(), "KmsHmacSignature2024");
        assert!(proof.get("signatureValue").is_some());
    }

    #[test]
    fn test_kms_did_signer_verify_credential() {
        let backend = create_mock_kms(KmsProvider::MockAzure);
        backend
            .create_key("verify-key", KmsAlgorithm::Ed25519)
            .unwrap();

        let signer = KmsDidSigner::new(backend, "verify-key");
        let credential = serde_json::json!({
            "id": "http://example.com/vc/1",
            "type": ["VerifiableCredential"],
            "issuer": "did:example:issuer"
        });

        let signed = signer.sign_credential(&credential).unwrap();
        let valid = signer.verify_credential(&signed).unwrap();
        assert!(valid);
    }

    #[test]
    fn test_kms_did_signer_tampered_credential_fails() {
        let backend = create_mock_kms(KmsProvider::MockAws);
        backend
            .create_key("tamper-key", KmsAlgorithm::EcP256)
            .unwrap();

        let signer = KmsDidSigner::new(backend, "tamper-key");
        let credential = serde_json::json!({
            "type": ["VerifiableCredential"],
            "issuer": "did:example:issuer"
        });

        let mut signed = signer.sign_credential(&credential).unwrap();
        if let Some(obj) = signed.as_object_mut() {
            // Swap the issuer after signing; the proof must no longer verify.
            obj.insert("issuer".to_string(), serde_json::json!("did:evil:attacker"));
        }

        let valid = signer.verify_credential(&signed).unwrap();
        assert!(!valid);
    }

    #[test]
    fn test_kms_did_signer_missing_proof_error() {
        let backend = create_mock_kms(KmsProvider::MockAws);
        backend
            .create_key("no-proof-key", KmsAlgorithm::Ed25519)
            .unwrap();

        let signer = KmsDidSigner::new(backend, "no-proof-key");
        let credential = serde_json::json!({ "type": "VerifiableCredential" });
        assert!(signer.verify_credential(&credential).is_err());
    }

    // ---------------------------------------------------------------------
    // Metadata, algorithms and determinism
    // ---------------------------------------------------------------------

    #[test]
    fn test_key_metadata_fields() {
        let kms = MockAwsKms::new();
        let meta = kms.create_key("meta-test", KmsAlgorithm::EcP384).unwrap();

        assert_eq!(meta.key_id, "meta-test");
        assert_eq!(meta.algorithm.as_str(), "EC_P384");
        assert!(matches!(meta.key_usage, KeyUsage::SignVerify));
        assert!(meta.created_at > 0);
        assert!(meta.enabled);
    }

    #[test]
    fn test_algorithm_key_sizes() {
        assert_eq!(KmsAlgorithm::Ed25519.key_size_bytes(), 32);
        assert_eq!(KmsAlgorithm::EcP256.key_size_bytes(), 32);
        assert_eq!(KmsAlgorithm::EcP384.key_size_bytes(), 48);
        assert_eq!(KmsAlgorithm::Rsa2048.key_size_bytes(), 256);
        assert_eq!(KmsAlgorithm::Rsa4096.key_size_bytes(), 512);
    }

    #[test]
    fn test_signatures_are_deterministic() {
        let kms = MockAwsKms::new();
        kms.create_key("det-key", KmsAlgorithm::Ed25519).unwrap();

        let sig1 = kms.sign("det-key", b"same data").unwrap();
        let sig2 = kms.sign("det-key", b"same data").unwrap();
        assert_eq!(sig1, sig2);
    }

    #[test]
    fn test_different_keys_produce_different_signatures() {
        let kms = MockAwsKms::new();
        kms.create_key("key-a", KmsAlgorithm::Ed25519).unwrap();
        kms.create_key("key-b", KmsAlgorithm::Ed25519).unwrap();

        let sig_a = kms.sign("key-a", b"data").unwrap();
        let sig_b = kms.sign("key-b", b"data").unwrap();
        assert_ne!(sig_a, sig_b);
    }
}