1use hkdf::Hkdf;
4use sha2::Sha256;
5use thiserror::Error;
6
7use crate::EncryptionKey;
8
9#[derive(Debug, Error)]
11pub enum KdfError {
12 #[error("Output key material too long")]
13 OutputTooLong,
14
15 #[error("Invalid key length")]
16 InvalidKeyLength,
17}
18
19pub struct KeyDerivation {
21 hkdf: Hkdf<Sha256>,
22}
23
24impl KeyDerivation {
25 pub fn new(ikm: &[u8], salt: Option<&[u8]>) -> Self {
31 let hkdf = Hkdf::<Sha256>::new(salt, ikm);
32 Self { hkdf }
33 }
34
35 pub fn derive_encryption_key(&self, info: &[u8]) -> Result<EncryptionKey, KdfError> {
40 let mut okm = [0u8; 32];
41 self.hkdf
42 .expand(info, &mut okm)
43 .map_err(|_| KdfError::OutputTooLong)?;
44 Ok(okm)
45 }
46
47 pub fn derive_bytes(&self, info: &[u8], length: usize) -> Result<Vec<u8>, KdfError> {
53 let mut okm = vec![0u8; length];
54 self.hkdf
55 .expand(info, &mut okm)
56 .map_err(|_| KdfError::OutputTooLong)?;
57 Ok(okm)
58 }
59}
60
61pub fn derive_content_key(
65 master_key: &EncryptionKey,
66 content_cid: &str,
67 chunk_index: u64,
68) -> Result<EncryptionKey, KdfError> {
69 let kdf = KeyDerivation::new(master_key, Some(b"chie-content-v1"));
70
71 let mut info = Vec::new();
73 info.extend_from_slice(content_cid.as_bytes());
74 info.extend_from_slice(&chunk_index.to_le_bytes());
75
76 kdf.derive_encryption_key(&info)
77}
78
79pub fn derive_chunk_nonce(
83 master_key: &EncryptionKey,
84 content_cid: &str,
85 chunk_index: u64,
86) -> Result<[u8; 12], KdfError> {
87 let kdf = KeyDerivation::new(master_key, Some(b"chie-nonce-v1"));
88
89 let mut info = Vec::new();
90 info.extend_from_slice(content_cid.as_bytes());
91 info.extend_from_slice(&chunk_index.to_le_bytes());
92
93 let bytes = kdf.derive_bytes(&info, 12)?;
94 let mut nonce = [0u8; 12];
95 nonce.copy_from_slice(&bytes);
96 Ok(nonce)
97}
98
99pub fn derive_chunk_keys(
101 master_key: &EncryptionKey,
102 content_cid: &str,
103 start_chunk: u64,
104 count: usize,
105) -> Result<Vec<EncryptionKey>, KdfError> {
106 let mut keys = Vec::with_capacity(count);
107 for i in 0..count as u64 {
108 keys.push(derive_content_key(
109 master_key,
110 content_cid,
111 start_chunk + i,
112 )?);
113 }
114 Ok(keys)
115}
116
117pub fn hkdf_extract_expand(ikm: &[u8], salt: &[u8], info: &[u8]) -> [u8; 32] {
127 let hkdf = Hkdf::<Sha256>::new(Some(salt), ikm);
128 let mut okm = [0u8; 32];
129 hkdf.expand(info, &mut okm)
130 .expect("32 bytes is a valid HKDF output length");
131 okm
132}
133
#[cfg(test)]
mod tests {
    use super::*;
    use crate::generate_key;

    /// Largest output HKDF-SHA256 can produce (RFC 5869: 255 * HashLen).
    const HKDF_SHA256_MAX_OUTPUT: usize = 255 * 32;

    #[test]
    fn test_key_derivation_deterministic() {
        let master = generate_key();
        let kdf = KeyDerivation::new(&master, Some(b"test-salt"));

        let key1 = kdf.derive_encryption_key(b"test-info").unwrap();
        let key2 = kdf.derive_encryption_key(b"test-info").unwrap();

        assert_eq!(key1, key2);
    }

    #[test]
    fn test_different_info_different_keys() {
        let master = generate_key();
        let kdf = KeyDerivation::new(&master, Some(b"test-salt"));

        let key1 = kdf.derive_encryption_key(b"info-1").unwrap();
        let key2 = kdf.derive_encryption_key(b"info-2").unwrap();

        assert_ne!(key1, key2);
    }

    #[test]
    fn test_content_key_derivation() {
        let master = generate_key();

        let key1 = derive_content_key(&master, "QmTest123", 0).unwrap();
        let key2 = derive_content_key(&master, "QmTest123", 1).unwrap();
        let key3 = derive_content_key(&master, "QmOther", 0).unwrap();

        // Different chunk index or different CID must give a different key.
        assert_ne!(key1, key2);
        assert_ne!(key1, key3);

        // Same inputs must reproduce the same key.
        let key1_again = derive_content_key(&master, "QmTest123", 0).unwrap();
        assert_eq!(key1, key1_again);
    }

    #[test]
    fn test_chunk_nonce_derivation() {
        let master = generate_key();

        let nonce1 = derive_chunk_nonce(&master, "QmTest", 0).unwrap();
        let nonce2 = derive_chunk_nonce(&master, "QmTest", 1).unwrap();
        let nonce1_again = derive_chunk_nonce(&master, "QmTest", 0).unwrap();

        assert_ne!(nonce1, nonce2);
        assert_eq!(nonce1, nonce1_again);
    }

    #[test]
    fn test_nonce_is_domain_separated_from_key() {
        let master = generate_key();

        let key = derive_content_key(&master, "QmTest", 0).unwrap();
        let nonce = derive_chunk_nonce(&master, "QmTest", 0).unwrap();

        // Different salts: the nonce must not simply be a prefix of the chunk key.
        assert_ne!(&key[..12], &nonce[..]);
    }

    #[test]
    fn test_derive_bytes_various_lengths() {
        let master = generate_key();
        let kdf = KeyDerivation::new(&master, Some(b"test-salt"));

        for len in [16, 32, 64, 128] {
            let bytes = kdf.derive_bytes(b"test-info", len).unwrap();
            assert_eq!(bytes.len(), len);
        }
    }

    #[test]
    fn test_derive_bytes_shorter_output_is_prefix_of_longer() {
        let master = generate_key();
        let kdf = KeyDerivation::new(&master, Some(b"test-salt"));

        let bytes32 = kdf.derive_bytes(b"test-info", 32).unwrap();
        let bytes64 = kdf.derive_bytes(b"test-info", 64).unwrap();

        // HKDF-Expand is prefix-consistent for the same PRK and info.
        assert_eq!(&bytes64[..32], &bytes32[..]);
    }

    #[test]
    fn test_derive_bytes_max_length_accepted() {
        let master = generate_key();
        let kdf = KeyDerivation::new(&master, Some(b"salt"));

        let bytes = kdf.derive_bytes(b"info", HKDF_SHA256_MAX_OUTPUT).unwrap();
        assert_eq!(bytes.len(), HKDF_SHA256_MAX_OUTPUT);
    }

    #[test]
    fn test_derive_bytes_rejects_oversized_output() {
        let master = generate_key();
        let kdf = KeyDerivation::new(&master, Some(b"salt"));

        let result = kdf.derive_bytes(b"info", HKDF_SHA256_MAX_OUTPUT + 1);
        assert!(matches!(result, Err(KdfError::OutputTooLong)));
    }

    #[test]
    fn test_derive_chunk_keys_batch() {
        let master = generate_key();
        let cid = "QmTestContent";

        let key0 = derive_content_key(&master, cid, 0).unwrap();
        let key1 = derive_content_key(&master, cid, 1).unwrap();
        let key2 = derive_content_key(&master, cid, 2).unwrap();

        let batch_keys = derive_chunk_keys(&master, cid, 0, 3).unwrap();

        assert_eq!(batch_keys.len(), 3);
        assert_eq!(batch_keys[0], key0);
        assert_eq!(batch_keys[1], key1);
        assert_eq!(batch_keys[2], key2);
    }

    #[test]
    fn test_derive_chunk_keys_empty_batch() {
        let master = generate_key();

        let keys = derive_chunk_keys(&master, "QmTest", 42, 0).unwrap();
        assert!(keys.is_empty());
    }

    #[test]
    fn test_derive_chunk_keys_ending_at_max_index() {
        let master = generate_key();
        let cid = "QmTest";

        let keys = derive_chunk_keys(&master, cid, u64::MAX, 1).unwrap();
        assert_eq!(keys.len(), 1);
        assert_eq!(keys[0], derive_content_key(&master, cid, u64::MAX).unwrap());
    }

    #[test]
    fn test_hkdf_extract_expand_deterministic() {
        let ikm = b"input-key-material";
        let salt = b"salt-value";
        let info = b"context-info";

        let key1 = hkdf_extract_expand(ikm, salt, info);
        let key2 = hkdf_extract_expand(ikm, salt, info);

        assert_eq!(key1, key2);
        assert_eq!(key1.len(), 32);
    }

    #[test]
    fn test_hkdf_extract_expand_different_inputs() {
        let ikm = b"input-key-material";
        let salt = b"salt-value";

        let key1 = hkdf_extract_expand(ikm, salt, b"info1");
        let key2 = hkdf_extract_expand(ikm, salt, b"info2");
        let key3 = hkdf_extract_expand(ikm, b"other-salt", b"info1");

        assert_ne!(key1, key2);
        assert_ne!(key1, key3);
    }

    #[test]
    fn test_kdf_with_no_salt() {
        let master = generate_key();
        let kdf_with_salt = KeyDerivation::new(&master, Some(b"salt"));
        let kdf_no_salt = KeyDerivation::new(&master, None);

        let key_with_salt = kdf_with_salt.derive_encryption_key(b"info").unwrap();
        let key_no_salt = kdf_no_salt.derive_encryption_key(b"info").unwrap();

        assert_ne!(key_with_salt, key_no_salt);
    }

    #[test]
    fn test_content_key_with_large_chunk_index() {
        let master = generate_key();
        let cid = "QmTest";

        let key1 = derive_content_key(&master, cid, u64::MAX / 2).unwrap();
        let key2 = derive_content_key(&master, cid, u64::MAX / 2 + 1).unwrap();

        assert_ne!(key1, key2);
    }

    #[test]
    fn test_derive_bytes_empty_info() {
        let master = generate_key();
        let kdf = KeyDerivation::new(&master, Some(b"salt"));

        let key_empty = kdf.derive_bytes(b"", 32).unwrap();
        let key_nonempty = kdf.derive_bytes(b"info", 32).unwrap();

        assert_ne!(key_empty, key_nonempty);
        assert_eq!(key_empty.len(), 32);
    }
}