1use crate::{ tags, EncryptedMessage };
6use anyhow::{ Result, Error };
7use dcbor::prelude::*;
8
9use super::{ Argon2id, DerivationParams, KeyDerivation, Scrypt, SymmetricKey, HKDF, PBKDF2 };
10
11#[derive(Debug, Clone, PartialEq, Eq)]
57pub struct EncryptedKey {
58 params: DerivationParams,
59 encrypted_key: EncryptedMessage,
60}
61
62impl EncryptedKey {
63 pub fn lock(
64 method: KeyDerivationMethod,
65 secret: impl AsRef<[u8]>,
66 content_key: &SymmetricKey
67 ) -> Self {
68 match method {
69 KeyDerivationMethod::HKDF => {
70 let params = HKDF::new();
71 let encrypted_key = params.lock(content_key, secret);
72 Self { params: DerivationParams::HKDF(params), encrypted_key }
73 }
74 KeyDerivationMethod::PBKDF2 => {
75 let params = PBKDF2::new();
76 let encrypted_key = params.lock(content_key, secret);
77 Self { params: DerivationParams::PBKDF2(params), encrypted_key }
78 }
79 KeyDerivationMethod::Scrypt => {
80 let params = Scrypt::new();
81 let encrypted_key = params.lock(content_key, secret);
82 Self { params: DerivationParams::Scrypt(params), encrypted_key }
83 }
84 KeyDerivationMethod::Argon2id => {
85 let params = Argon2id::new();
86 let encrypted_key = params.lock(content_key, secret);
87 Self { params: DerivationParams::Argon2id(params), encrypted_key }
88 }
89 }
90 }
91
92 pub fn unlock(&self, secret: impl AsRef<[u8]>) -> Result<SymmetricKey> {
93 let encrypted_message = &self.encrypted_key;
94 let aad = encrypted_message.aad();
95 let cbor = CBOR::try_from_data(aad)?;
96 let array = cbor.clone().try_into_array()?;
97 let method: KeyDerivationMethod = array
98 .get(0)
99 .ok_or_else(|| Error::msg("Missing method"))?
100 .try_into()?;
101 match method {
102 KeyDerivationMethod::HKDF => {
103 let params = HKDF::try_from(cbor)?;
104 params.unlock(&encrypted_message, secret)
105 }
106 KeyDerivationMethod::PBKDF2 => {
107 let params = PBKDF2::try_from(cbor)?;
108 params.unlock(&encrypted_message, secret)
109 }
110 KeyDerivationMethod::Scrypt => {
111 let params = Scrypt::try_from(cbor)?;
112 params.unlock(&encrypted_message, secret)
113 }
114 KeyDerivationMethod::Argon2id => {
115 let params = Argon2id::try_from(cbor)?;
116 params.unlock(&encrypted_message, secret)
117 }
118 }
119 }
120}
121
122impl std::fmt::Display for EncryptedKey {
123 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
124 write!(f, "EncryptedKey({})", self.params)
125 }
126}
127
128impl CBORTagged for EncryptedKey {
129 fn cbor_tags() -> Vec<Tag> {
130 tags_for_values(&[tags::TAG_ENCRYPTED_KEY])
131 }
132}
133
134impl From<EncryptedKey> for CBOR {
135 fn from(value: EncryptedKey) -> Self {
136 value.tagged_cbor()
137 }
138}
139
140impl CBORTaggedEncodable for EncryptedKey {
141 fn untagged_cbor(&self) -> CBOR {
142 return self.encrypted_key.clone().into();
143 }
144}
145
146impl TryFrom<CBOR> for EncryptedKey {
147 type Error = dcbor::Error;
148
149 fn try_from(value: CBOR) -> dcbor::Result<Self> {
150 Self::from_tagged_cbor(value)
151 }
152}
153
154impl CBORTaggedDecodable for EncryptedKey {
155 fn from_untagged_cbor(untagged_cbor: CBOR) -> dcbor::Result<Self> {
156 let encrypted_key: EncryptedMessage = untagged_cbor.try_into()?;
157 let params_cbor = CBOR::try_from_data(encrypted_key.aad())?;
158 let params = params_cbor.try_into()?;
159 Ok(Self { params, encrypted_key })
160 }
161}
162
/// Identifies a key-derivation algorithm.
///
/// The explicit discriminants are the numeric indices used by `index` and
/// `from_index`. They are also used when reading the method from CBOR, so they
/// must not be renumbered.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum KeyDerivationMethod {
    /// HKDF.
    HKDF = 0,
    /// PBKDF2.
    PBKDF2 = 1,
    /// scrypt.
    Scrypt = 2,
    /// Argon2id.
    Argon2id = 3,
}
176
177impl KeyDerivationMethod {
178 pub fn index(&self) -> usize {
180 *self as usize
181 }
182
183 pub fn from_index(index: usize) -> Option<Self> {
185 match index {
186 0 => Some(KeyDerivationMethod::HKDF),
187 1 => Some(KeyDerivationMethod::PBKDF2),
188 2 => Some(KeyDerivationMethod::Scrypt),
189 3 => Some(KeyDerivationMethod::Argon2id),
190 _ => None,
191 }
192 }
193}
194
195impl std::fmt::Display for KeyDerivationMethod {
196 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
197 match self {
198 KeyDerivationMethod::HKDF => write!(f, "HKDF"),
199 KeyDerivationMethod::PBKDF2 => write!(f, "PBKDF2"),
200 KeyDerivationMethod::Scrypt => write!(f, "Scrypt"),
201 KeyDerivationMethod::Argon2id => write!(f, "Argon2id"),
202 }
203 }
204}
205
206impl TryFrom<&CBOR> for KeyDerivationMethod {
207 type Error = Error;
208
209 fn try_from(cbor: &CBOR) -> Result<Self> {
210 let i: usize = cbor.clone().try_into()?;
211 KeyDerivationMethod::from_index(i).ok_or_else(|| Error::msg("Invalid KeyDerivationMethod"))
212 }
213}
214
#[cfg(test)]
mod tests {
    use super::*;

    fn test_secret() -> &'static [u8] {
        b"correct horse battery staple"
    }

    fn test_content_key() -> SymmetricKey {
        SymmetricKey::new()
    }

    /// Runs a full CBOR roundtrip for `method` and checks the recovered key.
    ///
    /// Locks a fresh content key with `method` and round-trips the result
    /// through CBOR. It then unlocks the decoded value and asserts that the
    /// recovered key equals the original. Returns the locked value so callers
    /// can make further assertions on it.
    fn assert_roundtrip(method: KeyDerivationMethod) -> EncryptedKey {
        let secret = test_secret();
        let content_key = test_content_key();

        let encrypted = EncryptedKey::lock(method, secret, &content_key);
        let cbor = encrypted.clone().to_cbor();
        let decoded = EncryptedKey::try_from(cbor).unwrap();
        let decrypted = decoded.unlock(secret).unwrap();

        assert_eq!(content_key, decrypted);
        encrypted
    }

    #[test]
    fn test_encrypted_key_hkdf_roundtrip() {
        crate::register_tags();
        let encrypted = assert_roundtrip(KeyDerivationMethod::HKDF);
        assert_eq!(format!("{}", encrypted), "EncryptedKey(HKDF(SHA256))");
    }

    #[test]
    fn test_encrypted_key_pbkdf2_roundtrip() {
        let encrypted = assert_roundtrip(KeyDerivationMethod::PBKDF2);
        assert_eq!(format!("{}", encrypted), "EncryptedKey(PBKDF2(SHA256))");
    }

    #[test]
    fn test_encrypted_key_scrypt_roundtrip() {
        let encrypted = assert_roundtrip(KeyDerivationMethod::Scrypt);
        assert_eq!(format!("{}", encrypted), "EncryptedKey(Scrypt)");
    }

    #[test]
    fn test_encrypted_key_argon2id_roundtrip() {
        // The Display text for Argon2id parameters is not pinned here; only the
        // lock -> CBOR -> unlock roundtrip is checked.
        assert_roundtrip(KeyDerivationMethod::Argon2id);
    }

    #[test]
    fn test_encrypted_key_wrong_secret_fails() {
        let secret = test_secret();
        let wrong_secret = b"wrong secret";
        let content_key = test_content_key();

        for method in [
            KeyDerivationMethod::HKDF,
            KeyDerivationMethod::PBKDF2,
            KeyDerivationMethod::Scrypt,
            KeyDerivationMethod::Argon2id,
        ] {
            let encrypted = EncryptedKey::lock(method, secret, &content_key);
            let result = EncryptedKey::unlock(&encrypted, wrong_secret);
            assert!(result.is_err(), "{method} accepted the wrong secret");
        }
    }

    #[test]
    fn test_encrypted_key_params_variant() {
        let secret = test_secret();
        let content_key = test_content_key();

        // `matches!` only evaluates to a bool; without `assert!` the check is a no-op.
        let hkdf = EncryptedKey::lock(KeyDerivationMethod::HKDF, secret, &content_key);
        assert!(matches!(hkdf.params, DerivationParams::HKDF(_)));

        let pbkdf2 = EncryptedKey::lock(KeyDerivationMethod::PBKDF2, secret, &content_key);
        assert!(matches!(pbkdf2.params, DerivationParams::PBKDF2(_)));

        let scrypt = EncryptedKey::lock(KeyDerivationMethod::Scrypt, secret, &content_key);
        assert!(matches!(scrypt.params, DerivationParams::Scrypt(_)));

        let argon2id = EncryptedKey::lock(KeyDerivationMethod::Argon2id, secret, &content_key);
        assert!(matches!(argon2id.params, DerivationParams::Argon2id(_)));
    }
}