chie_crypto/
kdf.rs

1//! Key derivation functions using HKDF.
2
3use hkdf::Hkdf;
4use sha2::Sha256;
5use thiserror::Error;
6
7use crate::EncryptionKey;
8
/// Errors reported by the key-derivation helpers in this module.
#[derive(Debug, Error)]
pub enum KdfError {
    /// The HKDF expand step rejected the request. Every expand failure in this
    /// module is mapped to this variant.
    #[error("Output key material too long")]
    OutputTooLong,

    /// Key material had an unexpected length.
    // NOTE(review): no code visible in this module constructs this variant;
    // presumably it is used by other modules of the crate — confirm.
    #[error("Invalid key length")]
    InvalidKeyLength,
}
18
19/// HKDF context for deriving keys.
20pub struct KeyDerivation {
21    hkdf: Hkdf<Sha256>,
22}
23
24impl KeyDerivation {
25    /// Create a new HKDF instance from input key material.
26    ///
27    /// # Arguments
28    /// * `ikm` - Input key material (e.g., master key)
29    /// * `salt` - Optional salt (use None for unsalted, not recommended)
30    pub fn new(ikm: &[u8], salt: Option<&[u8]>) -> Self {
31        let hkdf = Hkdf::<Sha256>::new(salt, ikm);
32        Self { hkdf }
33    }
34
35    /// Derive an encryption key from the master key.
36    ///
37    /// # Arguments
38    /// * `info` - Context/application-specific info (e.g., "content-encryption")
39    pub fn derive_encryption_key(&self, info: &[u8]) -> Result<EncryptionKey, KdfError> {
40        let mut okm = [0u8; 32];
41        self.hkdf
42            .expand(info, &mut okm)
43            .map_err(|_| KdfError::OutputTooLong)?;
44        Ok(okm)
45    }
46
47    /// Derive a key of arbitrary length.
48    ///
49    /// # Arguments
50    /// * `info` - Context/application-specific info
51    /// * `length` - Desired output length
52    pub fn derive_bytes(&self, info: &[u8], length: usize) -> Result<Vec<u8>, KdfError> {
53        let mut okm = vec![0u8; length];
54        self.hkdf
55            .expand(info, &mut okm)
56            .map_err(|_| KdfError::OutputTooLong)?;
57        Ok(okm)
58    }
59}
60
61/// Derive a content encryption key from a master key and content ID.
62///
63/// This is the recommended way to derive per-content keys.
64pub fn derive_content_key(
65    master_key: &EncryptionKey,
66    content_cid: &str,
67    chunk_index: u64,
68) -> Result<EncryptionKey, KdfError> {
69    let kdf = KeyDerivation::new(master_key, Some(b"chie-content-v1"));
70
71    // Create info with content ID and chunk index
72    let mut info = Vec::new();
73    info.extend_from_slice(content_cid.as_bytes());
74    info.extend_from_slice(&chunk_index.to_le_bytes());
75
76    kdf.derive_encryption_key(&info)
77}
78
79/// Derive a nonce for a specific chunk.
80///
81/// This ensures each chunk has a unique nonce without storing them.
82pub fn derive_chunk_nonce(
83    master_key: &EncryptionKey,
84    content_cid: &str,
85    chunk_index: u64,
86) -> Result<[u8; 12], KdfError> {
87    let kdf = KeyDerivation::new(master_key, Some(b"chie-nonce-v1"));
88
89    let mut info = Vec::new();
90    info.extend_from_slice(content_cid.as_bytes());
91    info.extend_from_slice(&chunk_index.to_le_bytes());
92
93    let bytes = kdf.derive_bytes(&info, 12)?;
94    let mut nonce = [0u8; 12];
95    nonce.copy_from_slice(&bytes);
96    Ok(nonce)
97}
98
99/// Derive multiple chunk keys at once for efficiency.
100pub fn derive_chunk_keys(
101    master_key: &EncryptionKey,
102    content_cid: &str,
103    start_chunk: u64,
104    count: usize,
105) -> Result<Vec<EncryptionKey>, KdfError> {
106    let mut keys = Vec::with_capacity(count);
107    for i in 0..count as u64 {
108        keys.push(derive_content_key(
109            master_key,
110            content_cid,
111            start_chunk + i,
112        )?);
113    }
114    Ok(keys)
115}
116
117/// Simple HKDF extract-and-expand in one operation.
118///
119/// # Arguments
120/// * `ikm` - Input key material
121/// * `salt` - Salt value
122/// * `info` - Context information
123///
124/// # Returns
125/// 32-byte derived key
126pub fn hkdf_extract_expand(ikm: &[u8], salt: &[u8], info: &[u8]) -> [u8; 32] {
127    let hkdf = Hkdf::<Sha256>::new(Some(salt), ikm);
128    let mut okm = [0u8; 32];
129    hkdf.expand(info, &mut okm)
130        .expect("32 bytes is a valid HKDF output length");
131    okm
132}
133
#[cfg(test)]
mod tests {
    //! Unit tests for the HKDF helpers: determinism, input separation,
    //! batch derivation and the output-length error path.

    use super::*;
    use crate::generate_key;

    /// Largest output HKDF-SHA256 can produce: 255 blocks of 32 bytes (RFC 5869).
    const MAX_HKDF_SHA256_OUTPUT: usize = 255 * 32;

    #[test]
    fn test_key_derivation_deterministic() {
        let master = generate_key();
        let kdf = KeyDerivation::new(&master, Some(b"test-salt"));

        let key1 = kdf.derive_encryption_key(b"test-info").unwrap();
        let key2 = kdf.derive_encryption_key(b"test-info").unwrap();

        assert_eq!(key1, key2);
    }

    #[test]
    fn test_different_info_different_keys() {
        let master = generate_key();
        let kdf = KeyDerivation::new(&master, Some(b"test-salt"));

        let key1 = kdf.derive_encryption_key(b"info-1").unwrap();
        let key2 = kdf.derive_encryption_key(b"info-2").unwrap();

        assert_ne!(key1, key2);
    }

    #[test]
    fn test_content_key_derivation() {
        let master = generate_key();

        let key1 = derive_content_key(&master, "QmTest123", 0).unwrap();
        let key2 = derive_content_key(&master, "QmTest123", 1).unwrap();
        let key3 = derive_content_key(&master, "QmOther", 0).unwrap();

        // Different chunks = different keys
        assert_ne!(key1, key2);
        // Different content = different keys
        assert_ne!(key1, key3);

        // Same params = same key (deterministic)
        let key1_again = derive_content_key(&master, "QmTest123", 0).unwrap();
        assert_eq!(key1, key1_again);
    }

    #[test]
    fn test_chunk_nonce_derivation() {
        let master = generate_key();

        let nonce1 = derive_chunk_nonce(&master, "QmTest", 0).unwrap();
        let nonce2 = derive_chunk_nonce(&master, "QmTest", 1).unwrap();

        assert_ne!(nonce1, nonce2);
        assert_eq!(nonce1.len(), 12);
    }

    #[test]
    fn test_chunk_nonce_deterministic_and_content_bound() {
        let master = generate_key();

        let nonce = derive_chunk_nonce(&master, "QmTest", 7).unwrap();
        let nonce_again = derive_chunk_nonce(&master, "QmTest", 7).unwrap();
        let nonce_other_content = derive_chunk_nonce(&master, "QmOther", 7).unwrap();

        // Nonces are never stored, so they must be reproducible.
        assert_eq!(nonce, nonce_again);
        // Different content = different nonce
        assert_ne!(nonce, nonce_other_content);
    }

    #[test]
    fn test_derive_bytes_various_lengths() {
        let master = generate_key();
        let kdf = KeyDerivation::new(&master, Some(b"test-salt"));

        // Test various output lengths
        for len in [16, 32, 64, 128] {
            let bytes = kdf.derive_bytes(b"test-info", len).unwrap();
            assert_eq!(bytes.len(), len);
        }
    }

    #[test]
    fn test_derive_bytes_shorter_output_is_prefix_of_longer() {
        let master = generate_key();
        let kdf = KeyDerivation::new(&master, Some(b"test-salt"));

        let bytes32 = kdf.derive_bytes(b"test-info", 32).unwrap();
        let bytes64 = kdf.derive_bytes(b"test-info", 64).unwrap();

        // First 32 bytes should be the same
        assert_eq!(&bytes64[..32], &bytes32[..]);
    }

    #[test]
    fn test_derive_bytes_rejects_oversized_output() {
        let master = generate_key();
        let kdf = KeyDerivation::new(&master, Some(b"test-salt"));

        // The maximum length is still accepted...
        let max = kdf.derive_bytes(b"test-info", MAX_HKDF_SHA256_OUTPUT).unwrap();
        assert_eq!(max.len(), MAX_HKDF_SHA256_OUTPUT);

        // ...one byte more is reported as an error rather than a panic.
        let too_long = kdf.derive_bytes(b"test-info", MAX_HKDF_SHA256_OUTPUT + 1);
        assert!(matches!(too_long, Err(KdfError::OutputTooLong)));
    }

    #[test]
    fn test_derive_chunk_keys_batch() {
        let master = generate_key();
        let cid = "QmTestContent";

        // Derive keys individually
        let key0 = derive_content_key(&master, cid, 0).unwrap();
        let key1 = derive_content_key(&master, cid, 1).unwrap();
        let key2 = derive_content_key(&master, cid, 2).unwrap();

        // Derive keys in batch
        let batch_keys = derive_chunk_keys(&master, cid, 0, 3).unwrap();

        assert_eq!(batch_keys.len(), 3);
        assert_eq!(batch_keys[0], key0);
        assert_eq!(batch_keys[1], key1);
        assert_eq!(batch_keys[2], key2);
    }

    #[test]
    fn test_derive_chunk_keys_empty_batch() {
        let master = generate_key();

        let batch_keys = derive_chunk_keys(&master, "QmTestContent", 0, 0).unwrap();

        assert!(batch_keys.is_empty());
    }

    #[test]
    fn test_derive_chunk_keys_batch_with_offset() {
        let master = generate_key();
        let cid = "QmTestContent";

        let key5 = derive_content_key(&master, cid, 5).unwrap();
        let key6 = derive_content_key(&master, cid, 6).unwrap();

        // A batch that does not start at chunk 0 must honour the offset.
        let batch_keys = derive_chunk_keys(&master, cid, 5, 2).unwrap();

        assert_eq!(batch_keys, vec![key5, key6]);
    }

    #[test]
    fn test_hkdf_extract_expand_deterministic() {
        let ikm = b"input-key-material";
        let salt = b"salt-value";
        let info = b"context-info";

        let key1 = hkdf_extract_expand(ikm, salt, info);
        let key2 = hkdf_extract_expand(ikm, salt, info);

        assert_eq!(key1, key2);
        assert_eq!(key1.len(), 32);
    }

    #[test]
    fn test_hkdf_extract_expand_different_inputs() {
        let ikm = b"input-key-material";
        let salt = b"salt-value";

        let key1 = hkdf_extract_expand(ikm, salt, b"info1");
        let key2 = hkdf_extract_expand(ikm, salt, b"info2");
        let key3 = hkdf_extract_expand(ikm, b"other-salt", b"info1");

        assert_ne!(key1, key2); // Different info
        assert_ne!(key1, key3); // Different salt
    }

    #[test]
    fn test_kdf_with_no_salt() {
        let master = generate_key();
        let kdf_with_salt = KeyDerivation::new(&master, Some(b"salt"));
        let kdf_no_salt = KeyDerivation::new(&master, None);

        let key_with_salt = kdf_with_salt.derive_encryption_key(b"info").unwrap();
        let key_no_salt = kdf_no_salt.derive_encryption_key(b"info").unwrap();

        // Keys should be different due to different salts
        assert_ne!(key_with_salt, key_no_salt);
    }

    #[test]
    fn test_content_key_with_large_chunk_index() {
        let master = generate_key();
        let cid = "QmTest";

        let key1 = derive_content_key(&master, cid, u64::MAX / 2).unwrap();
        let key2 = derive_content_key(&master, cid, u64::MAX / 2 + 1).unwrap();

        assert_ne!(key1, key2);
    }

    #[test]
    fn test_derive_bytes_empty_info() {
        let master = generate_key();
        let kdf = KeyDerivation::new(&master, Some(b"salt"));

        let key_empty = kdf.derive_bytes(b"", 32).unwrap();
        let key_nonempty = kdf.derive_bytes(b"info", 32).unwrap();

        assert_ne!(key_empty, key_nonempty);
        assert_eq!(key_empty.len(), 32);
    }
}