use crate::{ tags, EncryptedMessage, Nonce, Salt };
use bc_crypto::{
    hash::{ hkdf_hmac_sha512, pbkdf2_hmac_sha512 },
    hkdf_hmac_sha256,
    pbkdf2_hmac_sha256,
    scrypt_opt,
};
use anyhow::{ Result, Error };
use dcbor::prelude::*;

use super::SymmetricKey;
11
/// Length, in bytes, of the salt created by the parameter types' `new()` constructors.
const SALT_LEN: usize = 16;
13
14#[derive(Debug, Clone, PartialEq, Eq)]
54pub struct EncryptedKey {
55 params: DerivationParams,
56 encrypted_key: EncryptedMessage,
57}
58
59impl EncryptedKey {
60 pub fn lock(
61 method: KeyDerivationMethod,
62 secret: impl Into<Vec<u8>>,
63 content_key: &SymmetricKey
64 ) -> Self {
65 match method {
66 KeyDerivationMethod::HKDF => {
67 let params = HKDF::new();
68 let encrypted_key = params.lock(content_key, secret);
69 Self { params: DerivationParams::HKDF(params), encrypted_key }
70 }
71 KeyDerivationMethod::PBKDF2 => {
72 let params = PBKDF2::new();
73 let encrypted_key = params.lock(content_key, secret);
74 Self { params: DerivationParams::PBKDF2(params), encrypted_key }
75 }
76 KeyDerivationMethod::Scrypt => {
77 let params = Scrypt::new();
78 let encrypted_key = params.lock(content_key, secret);
79 Self { params: DerivationParams::Scrypt(params), encrypted_key }
80 }
81 }
82 }
83
84 pub fn unlock(&self, secret: impl Into<Vec<u8>>) -> Result<SymmetricKey> {
85 let encrypted_message = &self.encrypted_key;
86 let aad = encrypted_message.aad();
87 let cbor = CBOR::try_from_data(aad)?;
88 let array = cbor.clone().try_into_array()?;
89 let method: KeyDerivationMethod = array
90 .get(0)
91 .ok_or_else(|| Error::msg("Missing method"))?
92 .try_into()?;
93 match method {
94 KeyDerivationMethod::HKDF => {
95 let params = HKDF::try_from(cbor)?;
96 params.unlock(&encrypted_message, secret)
97 }
98 KeyDerivationMethod::PBKDF2 => {
99 let params = PBKDF2::try_from(cbor)?;
100 params.unlock(&encrypted_message, secret)
101 }
102 KeyDerivationMethod::Scrypt => {
103 let params = Scrypt::try_from(cbor)?;
104 params.unlock(&encrypted_message, secret)
105 }
106 }
107 }
108}
109
110impl std::fmt::Display for EncryptedKey {
111 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
112 write!(f, "EncryptedKey({})", self.params)
113 }
114}
115
116impl CBORTagged for EncryptedKey {
117 fn cbor_tags() -> Vec<Tag> {
118 tags_for_values(&[tags::TAG_ENCRYPTED_KEY])
119 }
120}
121
122impl From<EncryptedKey> for CBOR {
123 fn from(value: EncryptedKey) -> Self {
124 value.tagged_cbor()
125 }
126}
127
128impl CBORTaggedEncodable for EncryptedKey {
129 fn untagged_cbor(&self) -> CBOR {
130 return self.encrypted_key.clone().into();
131 }
132}
133
134impl TryFrom<CBOR> for EncryptedKey {
135 type Error = dcbor::Error;
136
137 fn try_from(value: CBOR) -> dcbor::Result<Self> {
138 Self::from_tagged_cbor(value)
139 }
140}
141
142impl CBORTaggedDecodable for EncryptedKey {
143 fn from_untagged_cbor(untagged_cbor: CBOR) -> dcbor::Result<Self> {
144 let encrypted_key: EncryptedMessage = untagged_cbor.try_into()?;
145 let params_cbor = CBOR::try_from_data(encrypted_key.aad().clone())?;
146 let params = params_cbor.try_into()?;
147 Ok(Self { params, encrypted_key })
148 }
149}
150
/// Hash function selected for the HMAC-based derivations (HKDF and PBKDF2).
///
/// The explicit discriminants are the numeric values used in CBOR.
#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash)]
pub enum HashType {
    /// SHA-256, encoded as `0`.
    SHA256 = 0,
    /// SHA-512, encoded as `1`.
    SHA512 = 1,
}
164
165impl std::fmt::Display for HashType {
166 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
167 match self {
168 HashType::SHA256 => write!(f, "SHA256"),
169 HashType::SHA512 => write!(f, "SHA512"),
170 }
171 }
172}
173
174impl Into<CBOR> for HashType {
175 fn into(self) -> CBOR {
176 CBOR::from(self as u8)
177 }
178}
179
180impl TryFrom<CBOR> for HashType {
181 type Error = Error;
182
183 fn try_from(cbor: CBOR) -> Result<Self> {
184 let i: u8 = cbor.try_into()?;
185 match i {
186 0 => Ok(HashType::SHA256),
187 1 => Ok(HashType::SHA512),
188 _ => Err(Error::msg("Invalid HashType")),
189 }
190 }
191}
192
/// Identifies a key-derivation method.
///
/// Each discriminant is the method's index, which is the first element of the
/// method's encoded parameters.
#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash)]
pub enum KeyDerivationMethod {
    /// HKDF (index `0`).
    HKDF = 0,
    /// PBKDF2 (index `1`).
    PBKDF2 = 1,
    /// scrypt (index `2`).
    Scrypt = 2,
}
205
206impl KeyDerivationMethod {
207 pub fn index(&self) -> usize {
209 *self as usize
210 }
211
212 pub fn from_index(index: usize) -> Option<Self> {
214 match index {
215 0 => Some(KeyDerivationMethod::HKDF),
216 1 => Some(KeyDerivationMethod::PBKDF2),
217 2 => Some(KeyDerivationMethod::Scrypt),
218 _ => None,
219 }
220 }
221}
222
223impl std::fmt::Display for KeyDerivationMethod {
224 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
225 match self {
226 KeyDerivationMethod::HKDF => write!(f, "HKDF"),
227 KeyDerivationMethod::PBKDF2 => write!(f, "PBKDF2"),
228 KeyDerivationMethod::Scrypt => write!(f, "Scrypt"),
229 }
230 }
231}
232
233impl TryFrom<&CBOR> for KeyDerivationMethod {
234 type Error = Error;
235
236 fn try_from(cbor: &CBOR) -> Result<Self> {
237 let i: usize = cbor.clone().try_into()?;
238 KeyDerivationMethod::from_index(i).ok_or_else(|| Error::msg("Invalid KeyDerivationMethod"))
239 }
240}
241
242pub trait KeyDerivation: Into<CBOR> + TryFrom<CBOR> + Clone {
244 const INDEX: usize;
245
246 fn lock(&self, content_key: &SymmetricKey, secret: impl Into<Vec<u8>>) -> EncryptedMessage;
247 fn unlock(
248 &self,
249 encrypted_key: &EncryptedMessage,
250 secret: impl Into<Vec<u8>>
251 ) -> Result<SymmetricKey>;
252}
253
254#[derive(Debug, Clone, PartialEq, Eq)]
261pub struct HKDF {
262 salt: Salt,
263 hash_type: HashType,
264}
265
266impl KeyDerivation for HKDF {
267 const INDEX: usize = KeyDerivationMethod::HKDF as usize;
268
269 fn lock(&self, content_key: &SymmetricKey, secret: impl Into<Vec<u8>>) -> EncryptedMessage {
270 let derived_key: SymmetricKey = (
271 match self.hash_type {
272 HashType::SHA256 => hkdf_hmac_sha256(secret.into(), &self.salt, 32),
273 HashType::SHA512 => hkdf_hmac_sha512(secret.into(), &self.salt, 32),
274 }
275 )
276 .try_into()
277 .unwrap();
278 let encoded_method: Vec<u8> = self.to_cbor_data();
279 derived_key.encrypt(content_key, Some(encoded_method), Option::<Nonce>::None)
280 }
281
282 fn unlock(
283 &self,
284 encrypted_key: &EncryptedMessage,
285 secret: impl Into<Vec<u8>>
286 ) -> Result<SymmetricKey> {
287 let derived_key: SymmetricKey = (
288 match self.hash_type {
289 HashType::SHA256 => hkdf_hmac_sha256(secret.into(), &self.salt, 32),
290 HashType::SHA512 => hkdf_hmac_sha512(secret.into(), &self.salt, 32),
291 }
292 ).try_into()?;
293 let content_key = derived_key.decrypt(encrypted_key)?.try_into()?;
294 Ok(content_key)
295 }
296}
297
298impl std::fmt::Display for HKDF {
299 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
300 write!(f, "HKDF({})", self.hash_type)
301 }
302}
303
304impl Into<CBOR> for HKDF {
305 fn into(self) -> CBOR {
306 vec![CBOR::from(Self::INDEX), self.salt.into(), self.hash_type.into()].into()
307 }
308}
309
310impl TryFrom<CBOR> for HKDF {
311 type Error = Error;
312
313 fn try_from(cbor: CBOR) -> Result<Self> {
314 let a = cbor.try_into_array()?;
315 a
316 .len()
317 .eq(&3)
318 .then_some(())
319 .ok_or_else(|| Error::msg("Invalid HKDF CBOR"))?;
320 let mut iter = a.into_iter();
321 let _index: usize = iter
322 .next()
323 .ok_or_else(|| Error::msg("Missing index"))?
324 .try_into()?;
325 let salt: Salt = iter
326 .next()
327 .ok_or_else(|| Error::msg("Missing salt"))?
328 .try_into()?;
329 let hash_type: HashType = iter
330 .next()
331 .ok_or_else(|| Error::msg("Missing hash type"))?
332 .try_into()?;
333 Ok(Self { salt, hash_type })
334 }
335}
336
337impl HKDF {
338 pub fn new() -> Self {
339 Self::new_opt(Salt::new_with_len(SALT_LEN).unwrap(), HashType::SHA256)
340 }
341
342 pub fn new_opt(salt: Salt, hash_type: HashType) -> Self {
343 Self { salt, hash_type }
344 }
345
346 pub fn salt(&self) -> &Salt {
347 &self.salt
348 }
349
350 pub fn hash_type(&self) -> HashType {
351 self.hash_type
352 }
353}
354
355#[derive(Debug, Clone, PartialEq, Eq)]
362pub struct PBKDF2 {
363 salt: Salt,
364 iterations: u32,
365 hash_type: HashType,
366}
367
368impl KeyDerivation for PBKDF2 {
369 const INDEX: usize = KeyDerivationMethod::PBKDF2 as usize;
370
371 fn lock(&self, content_key: &SymmetricKey, secret: impl Into<Vec<u8>>) -> EncryptedMessage {
372 let derived_key: SymmetricKey = (
373 match self.hash_type {
374 HashType::SHA256 =>
375 pbkdf2_hmac_sha256(secret.into(), &self.salt, self.iterations, 32),
376 HashType::SHA512 =>
377 pbkdf2_hmac_sha256(secret.into(), &self.salt, self.iterations, 32),
378 }
379 )
380 .try_into()
381 .unwrap();
382 let encoded_method: Vec<u8> = self.to_cbor_data();
383 derived_key.encrypt(content_key, Some(encoded_method), Option::<Nonce>::None)
384 }
385
386 fn unlock(
387 &self,
388 encrypted_key: &EncryptedMessage,
389 secret: impl Into<Vec<u8>>
390 ) -> Result<SymmetricKey> {
391 let derived_key: SymmetricKey = (
392 match self.hash_type {
393 HashType::SHA256 =>
394 pbkdf2_hmac_sha256(secret.into(), &self.salt, self.iterations, 32),
395 HashType::SHA512 =>
396 pbkdf2_hmac_sha256(secret.into(), &self.salt, self.iterations, 32),
397 }
398 ).try_into()?;
399 derived_key.decrypt(encrypted_key)?.try_into()
400 }
401}
402
403impl std::fmt::Display for PBKDF2 {
404 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
405 write!(f, "PBKDF2({})", self.hash_type)
406 }
407}
408
409impl Into<CBOR> for PBKDF2 {
410 fn into(self) -> CBOR {
411 vec![
412 CBOR::from(Self::INDEX),
413 self.salt.into(),
414 self.iterations.into(),
415 self.hash_type.into()
416 ].into()
417 }
418}
419
420impl TryFrom<CBOR> for PBKDF2 {
421 type Error = Error;
422
423 fn try_from(cbor: CBOR) -> Result<Self> {
424 let a = cbor.try_into_array()?;
425 a
426 .len()
427 .eq(&4)
428 .then_some(())
429 .ok_or_else(|| Error::msg("Invalid PBKDF2 CBOR"))?;
430 let mut iter = a.into_iter();
431 let _index: usize = iter
432 .next()
433 .ok_or_else(|| Error::msg("Missing index"))?
434 .try_into()?;
435 let salt: Salt = iter
436 .next()
437 .ok_or_else(|| Error::msg("Missing salt"))?
438 .try_into()?;
439 let iterations: u32 = iter
440 .next()
441 .ok_or_else(|| Error::msg("Missing iterations"))?
442 .try_into()?;
443 let hash_type: HashType = iter
444 .next()
445 .ok_or_else(|| Error::msg("Missing hash type"))?
446 .try_into()?;
447 Ok(Self { salt, iterations, hash_type })
448 }
449}
450
451impl PBKDF2 {
452 pub fn new() -> Self {
453 Self::new_opt(Salt::new_with_len(SALT_LEN).unwrap(), 100_000, HashType::SHA256)
454 }
455
456 pub fn new_opt(salt: Salt, iterations: u32, hash_type: HashType) -> Self {
457 Self { salt, iterations, hash_type }
458 }
459
460 pub fn salt(&self) -> &Salt {
461 &self.salt
462 }
463
464 pub fn iterations(&self) -> u32 {
465 self.iterations
466 }
467
468 pub fn hash_type(&self) -> HashType {
469 self.hash_type
470 }
471}
472
473#[derive(Debug, Clone, PartialEq, Eq)]
480pub struct Scrypt {
481 salt: Salt,
482 log_n: u8,
483 r: u32,
484 p: u32,
485}
486
487impl KeyDerivation for Scrypt {
488 const INDEX: usize = KeyDerivationMethod::Scrypt as usize;
489 fn lock(&self, content_key: &SymmetricKey, secret: impl Into<Vec<u8>>) -> EncryptedMessage {
490 let derived_key: SymmetricKey = scrypt_opt(
491 secret.into(),
492 &self.salt,
493 32,
494 self.log_n,
495 self.r,
496 self.p
497 )
498 .try_into()
499 .unwrap();
500 let encoded_method: Vec<u8> = self.to_cbor_data();
501 derived_key.encrypt(content_key, Some(encoded_method), Option::<Nonce>::None)
502 }
503
504 fn unlock(
505 &self,
506 encrypted_key: &EncryptedMessage,
507 secret: impl Into<Vec<u8>>
508 ) -> Result<SymmetricKey> {
509 let derived_key: SymmetricKey = scrypt_opt(
510 secret.into(),
511 &self.salt,
512 32,
513 self.log_n,
514 self.r,
515 self.p
516 ).try_into()?;
517 derived_key.decrypt(encrypted_key)?.try_into()
518 }
519}
520
521impl std::fmt::Display for Scrypt {
522 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
523 write!(f, "Scrypt")
524 }
525}
526
527impl Into<CBOR> for Scrypt {
528 fn into(self) -> CBOR {
529 vec![
530 CBOR::from(Self::INDEX),
531 self.salt.into(),
532 self.log_n.into(),
533 self.r.into(),
534 self.p.into()
535 ].into()
536 }
537}
538
539impl TryFrom<CBOR> for Scrypt {
540 type Error = Error;
541
542 fn try_from(cbor: CBOR) -> Result<Self> {
543 let a = cbor.try_into_array()?;
544 a
545 .len()
546 .eq(&5)
547 .then_some(())
548 .ok_or_else(|| Error::msg("Invalid Scrypt CBOR"))?;
549 let mut iter = a.into_iter();
550 let _index: usize = iter
551 .next()
552 .ok_or_else(|| Error::msg("Missing index"))?
553 .try_into()?;
554 let salt: Salt = iter
555 .next()
556 .ok_or_else(|| Error::msg("Missing salt"))?
557 .try_into()?;
558 let log_n: u8 = iter
559 .next()
560 .ok_or_else(|| Error::msg("Missing log_n"))?
561 .try_into()?;
562 let r: u32 = iter
563 .next()
564 .ok_or_else(|| Error::msg("Missing r"))?
565 .try_into()?;
566 let p: u32 = iter
567 .next()
568 .ok_or_else(|| Error::msg("Missing p"))?
569 .try_into()?;
570 Ok(Self { salt, log_n, r, p })
571 }
572}
573
574impl Scrypt {
575 pub fn new() -> Self {
576 Self::new_opt(Salt::new_with_len(SALT_LEN).unwrap(), 15, 8, 1)
577 }
578
579 pub fn new_opt(salt: Salt, log_n: u8, r: u32, p: u32) -> Self {
580 Self { salt, log_n, r, p }
581 }
582
583 pub fn salt(&self) -> &Salt {
584 &self.salt
585 }
586
587 pub fn log_n(&self) -> u8 {
588 self.log_n
589 }
590
591 pub fn r(&self) -> u32 {
592 self.r
593 }
594
595 pub fn p(&self) -> u32 {
596 self.p
597 }
598}
599
600#[derive(Debug, Clone, PartialEq, Eq)]
607pub enum DerivationParams {
608 HKDF(HKDF),
609 PBKDF2(PBKDF2),
610 Scrypt(Scrypt),
611}
612
613impl std::fmt::Display for DerivationParams {
614 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
615 match self {
616 DerivationParams::HKDF(params) => write!(f, "{}", params),
617 DerivationParams::PBKDF2(params) => write!(f, "{}", params),
618 DerivationParams::Scrypt(params) => write!(f, "{}", params),
619 }
620 }
621}
622
623impl From<DerivationParams> for CBOR {
624 fn from(value: DerivationParams) -> Self {
625 match value {
626 DerivationParams::HKDF(params) => params.into(),
627 DerivationParams::PBKDF2(params) => params.into(),
628 DerivationParams::Scrypt(params) => params.into(),
629 }
630 }
631}
632
633impl TryFrom<CBOR> for DerivationParams {
634 type Error = Error;
635
636 fn try_from(cbor: CBOR) -> Result<Self> {
637 let a = cbor.clone().try_into_array()?;
638 let mut iter = a.into_iter();
639 let index: usize = iter
640 .next()
641 .ok_or_else(|| Error::msg("Missing index"))?
642 .try_into()?;
643 match KeyDerivationMethod::from_index(index) {
644 Some(KeyDerivationMethod::HKDF) => Ok(DerivationParams::HKDF(HKDF::try_from(cbor)?)),
645 Some(KeyDerivationMethod::PBKDF2) =>
646 Ok(DerivationParams::PBKDF2(PBKDF2::try_from(cbor)?)),
647 Some(KeyDerivationMethod::Scrypt) =>
648 Ok(DerivationParams::Scrypt(Scrypt::try_from(cbor)?)),
649 None => Err(Error::msg("Invalid KeyDerivationMethod")),
650 }
651 }
652}
653
#[cfg(test)]
mod tests {
    use super::*;

    fn test_secret() -> &'static [u8] {
        b"correct horse battery staple"
    }

    fn test_content_key() -> SymmetricKey {
        SymmetricKey::new()
    }

    #[test]
    fn test_encrypted_key_hkdf_roundtrip() {
        crate::register_tags();
        let secret = test_secret();
        let content_key = test_content_key();

        let encrypted = EncryptedKey::lock(KeyDerivationMethod::HKDF, secret, &content_key);
        assert_eq!(format!("{}", encrypted), "EncryptedKey(HKDF(SHA256))");
        let cbor = encrypted.clone().to_cbor();
        let encrypted2 = EncryptedKey::try_from(cbor).unwrap();
        let decrypted = EncryptedKey::unlock(&encrypted2, secret).unwrap();

        assert_eq!(content_key, decrypted);
    }

    #[test]
    fn test_encrypted_key_pbkdf2_roundtrip() {
        let secret = test_secret();
        let content_key = test_content_key();

        let encrypted = EncryptedKey::lock(KeyDerivationMethod::PBKDF2, secret, &content_key);
        assert_eq!(format!("{}", encrypted), "EncryptedKey(PBKDF2(SHA256))");
        let cbor = encrypted.clone().to_cbor();
        let encrypted2 = EncryptedKey::try_from(cbor).unwrap();
        let decrypted = EncryptedKey::unlock(&encrypted2, secret).unwrap();

        assert_eq!(content_key, decrypted);
    }

    #[test]
    fn test_encrypted_key_scrypt_roundtrip() {
        let secret = test_secret();
        let content_key = test_content_key();

        let encrypted = EncryptedKey::lock(KeyDerivationMethod::Scrypt, secret, &content_key);
        assert_eq!(format!("{}", encrypted), "EncryptedKey(Scrypt)");
        let cbor = encrypted.clone().to_cbor();
        let encrypted2 = EncryptedKey::try_from(cbor).unwrap();
        let decrypted = EncryptedKey::unlock(&encrypted2, secret).unwrap();

        assert_eq!(content_key, decrypted);
    }

    #[test]
    fn test_sha512_params_roundtrip() {
        let secret = test_secret();
        let content_key = test_content_key();

        let hkdf = HKDF::new_opt(Salt::new_with_len(SALT_LEN).unwrap(), HashType::SHA512);
        let locked = hkdf.lock(&content_key, secret);
        assert_eq!(hkdf.unlock(&locked, secret).unwrap(), content_key);

        let pbkdf2 =
            PBKDF2::new_opt(Salt::new_with_len(SALT_LEN).unwrap(), 1_000, HashType::SHA512);
        let locked = pbkdf2.lock(&content_key, secret);
        assert_eq!(pbkdf2.unlock(&locked, secret).unwrap(), content_key);

        // Explicit parameters must survive a CBOR round trip unchanged.
        let params = DerivationParams::PBKDF2(pbkdf2.clone());
        let decoded = DerivationParams::try_from(CBOR::from(params.clone())).unwrap();
        assert_eq!(decoded, params);
    }

    #[test]
    fn test_encrypted_key_wrong_secret_fails() {
        let secret = test_secret();
        let wrong_secret = b"wrong secret";
        let content_key = test_content_key();

        let encrypted = EncryptedKey::lock(KeyDerivationMethod::HKDF, secret, &content_key);
        let result = EncryptedKey::unlock(&encrypted, wrong_secret);
        assert!(result.is_err());

        let encrypted = EncryptedKey::lock(KeyDerivationMethod::PBKDF2, secret, &content_key);
        let result = EncryptedKey::unlock(&encrypted, wrong_secret);
        assert!(result.is_err());

        let encrypted = EncryptedKey::lock(KeyDerivationMethod::Scrypt, secret, &content_key);
        let result = EncryptedKey::unlock(&encrypted, wrong_secret);
        assert!(result.is_err());
    }

    #[test]
    fn test_encrypted_key_params_variant() {
        let secret = test_secret();
        let content_key = test_content_key();

        // `matches!` only yields a bool. It must be asserted, or the test
        // checks nothing.
        let hkdf = EncryptedKey::lock(KeyDerivationMethod::HKDF, secret, &content_key);
        assert!(matches!(hkdf.params, DerivationParams::HKDF(_)));

        let pbkdf2 = EncryptedKey::lock(KeyDerivationMethod::PBKDF2, secret, &content_key);
        assert!(matches!(pbkdf2.params, DerivationParams::PBKDF2(_)));

        let scrypt = EncryptedKey::lock(KeyDerivationMethod::Scrypt, secret, &content_key);
        assert!(matches!(scrypt.params, DerivationParams::Scrypt(_)));
    }
}