phalanx_crypto/
crypto.rs

//! Core cryptographic primitives for the Phalanx Protocol: ChaCha20-Poly1305 encryption, BLAKE3/HKDF key derivation, hashing, and OS-backed randomness.
2
3use crate::error::{PhalanxError, Result};
4use chacha20poly1305::{ChaCha20Poly1305, Key, Nonce, aead::{Aead, KeyInit}};
5use blake3::Hasher;
6use hkdf::Hkdf;
7use sha2::Sha256;
8use rand::{RngCore, rngs::OsRng};
9use zeroize::{Zeroize, ZeroizeOnDrop};
10
/// Domain-separation context strings for key derivation.
///
/// These are passed as the `info` argument of `derive_phalanx_key` (see the
/// tests, which use `GROUP_KEY` that way). The `_V1` suffix is presumably a
/// version tag for the derivation scheme — confirm before adding new versions.
pub mod contexts {
    /// Context for group encryption keys.
    pub const GROUP_KEY: &str = "PHALANX_GROUP_KEY_V1";
    /// Context for per-message keys.
    pub const MESSAGE_KEY: &str = "PHALANX_MESSAGE_KEY_V1";
    /// Context for authentication keys.
    pub const AUTH_KEY: &str = "PHALANX_AUTH_KEY_V1";
    /// Context for key-exchange keys.
    pub const KEY_EXCHANGE: &str = "PHALANX_KEY_EXCHANGE_V1";
    /// General-purpose key-derivation context.
    pub const KEY_DERIVATION: &str = "PHALANX_KEY_DERIVE_V1";
}
24
25/// Symmetric encryption key that is automatically zeroized
26#[derive(Clone, Zeroize, ZeroizeOnDrop)]
27pub struct SymmetricKey([u8; 32]);
28
/// Output of [`SymmetricKey::encrypt`]: everything needed, besides the key and
/// the AAD itself, to decrypt a message.
///
/// `PartialEq`/`Eq` are derived so sealed messages can be compared (e.g. in
/// tests or deduplication); none of these fields is secret.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct EncryptedData {
    /// ChaCha20-Poly1305 ciphertext with the authentication tag appended.
    pub ciphertext: Vec<u8>,
    /// The 96-bit nonce used for this encryption (drawn at random per message).
    pub nonce: [u8; 12],
    /// BLAKE3 digest of the AAD supplied at encryption time.
    ///
    /// This field is NOT covered by the AEAD tag. `decrypt` compares it with the
    /// caller's AAD only to fail early; the tag is what actually binds the AAD.
    pub aad_hash: [u8; 32],
}
39
40impl SymmetricKey {
41    /// Generate a new random symmetric key
42    pub fn generate() -> Self {
43        let mut key = [0u8; 32];
44        OsRng.fill_bytes(&mut key);
45        Self(key)
46    }
47    
48    /// Create key from bytes
49    pub fn from_bytes(bytes: [u8; 32]) -> Result<Self> {
50        Ok(Self(bytes))
51    }
52    
53    /// Get key bytes (use with caution)
54    pub fn as_bytes(&self) -> &[u8; 32] {
55        &self.0
56    }
57    
58    /// Encrypt data with associated authenticated data
59    pub fn encrypt(&self, plaintext: &[u8], aad: &[u8]) -> Result<EncryptedData> {
60        // Generate random nonce
61        let mut nonce_bytes = [0u8; 12];
62        OsRng.fill_bytes(&mut nonce_bytes);
63        let nonce = Nonce::from_slice(&nonce_bytes);
64        
65        // Create cipher
66        let cipher = ChaCha20Poly1305::new(Key::from_slice(&self.0));
67        
68        // Encrypt with AAD
69        let ciphertext = cipher.encrypt(nonce, aead::Payload {
70            msg: plaintext,
71            aad,
72        })?;
73        
74        // Hash the AAD for verification
75        let aad_hash = blake3::hash(aad).into();
76        
77        Ok(EncryptedData {
78            ciphertext,
79            nonce: nonce_bytes,
80            aad_hash,
81        })
82    }
83    
84    /// Decrypt data and verify associated authenticated data
85    pub fn decrypt(&self, data: &EncryptedData, aad: &[u8]) -> Result<Vec<u8>> {
86        // Verify AAD hash
87        let expected_hash = blake3::hash(aad);
88        if data.aad_hash != *expected_hash.as_bytes() {
89            return Err(PhalanxError::auth("AAD hash mismatch"));
90        }
91        
92        // Create cipher and decrypt
93        let cipher = ChaCha20Poly1305::new(Key::from_slice(&self.0));
94        let nonce = Nonce::from_slice(&data.nonce);
95        
96        let plaintext = cipher.decrypt(nonce, aead::Payload {
97            msg: &data.ciphertext,
98            aad,
99        })?;
100        
101        Ok(plaintext)
102    }
103}
104
105/// Key derivation function using BLAKE3
106pub fn derive_phalanx_key(ikm: &[u8], _salt: &[u8], info: &str) -> SymmetricKey {
107    let derived = blake3::derive_key(info, ikm);
108    SymmetricKey(derived)
109}
110
111/// Key derivation using HKDF-SHA256
112pub fn hkdf_expand(prk: &[u8], info: &[u8], length: usize) -> Result<Vec<u8>> {
113    let hk = Hkdf::<Sha256>::from_prk(prk)
114        .map_err(|e| PhalanxError::key_derivation(format!("HKDF PRK invalid: {}", e)))?;
115    
116    let mut output = vec![0u8; length];
117    hk.expand(info, &mut output)
118        .map_err(|e| PhalanxError::key_derivation(format!("HKDF expand failed: {}", e)))?;
119    
120    Ok(output)
121}
122
123/// Extract key material using HKDF-SHA256
124pub fn hkdf_extract(salt: &[u8], ikm: &[u8]) -> [u8; 32] {
125    let (prk, _) = Hkdf::<Sha256>::extract(Some(salt), ikm);
126    prk.into()
127}
128
129/// Secure hash function
130pub fn hash(data: &[u8]) -> [u8; 32] {
131    blake3::hash(data).into()
132}
133
134/// Secure hash of multiple inputs
135pub fn hash_multiple(inputs: &[&[u8]]) -> [u8; 32] {
136    let mut hasher = Hasher::new();
137    for input in inputs {
138        hasher.update(input);
139    }
140    hasher.finalize().into()
141}
142
143/// Generate a random nonce
144pub fn generate_nonce() -> [u8; 12] {
145    let mut nonce = [0u8; 12];
146    OsRng.fill_bytes(&mut nonce);
147    nonce
148}
149
150/// Generate random bytes
151pub fn random_bytes(len: usize) -> Vec<u8> {
152    let mut bytes = vec![0u8; len];
153    OsRng.fill_bytes(&mut bytes);
154    bytes
155}
156
157/// Constant-time comparison
158pub fn constant_time_eq(a: &[u8], b: &[u8]) -> bool {
159    use subtle::ConstantTimeEq;
160    a.ct_eq(b).into()
161}
162
163impl std::fmt::Debug for SymmetricKey {
164    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
165        f.debug_struct("SymmetricKey")
166            .field("key", &"[REDACTED]")
167            .finish()
168    }
169}
170
#[cfg(feature = "serde")]
mod serde_impl {
    //! Hand-written serde impls that encode every byte field as a base64 string.
    //!
    //! NOTE(review): `base64::encode` / `base64::decode` are the legacy free
    //! functions (deprecated in newer `base64` releases). Confirm this against
    //! the pinned crate version.

    use super::*;
    use serde::{Deserialize, Serialize};

    impl Serialize for SymmetricKey {
        /// Serializes the key as a base64 string. The temporary encoding is
        /// wiped once the serializer has consumed it.
        fn serialize<S>(&self, serializer: S) -> std::result::Result<S::Ok, S::Error>
        where
            S: serde::Serializer,
        {
            let mut encoded = base64::encode(self.as_bytes());
            let result = serializer.serialize_str(&encoded);
            encoded.zeroize();
            result
        }
    }

    impl<'de> Deserialize<'de> for SymmetricKey {
        /// Parses a base64 string that must decode to exactly 32 bytes.
        fn deserialize<D>(deserializer: D) -> std::result::Result<Self, D::Error>
        where
            D: serde::Deserializer<'de>,
        {
            use serde::de::{self, Visitor};

            struct SymmetricKeyVisitor;

            impl<'de> Visitor<'de> for SymmetricKeyVisitor {
                type Value = SymmetricKey;

                fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
                    formatter.write_str("a base64 encoded 32-byte key")
                }

                fn visit_str<E>(self, value: &str) -> std::result::Result<SymmetricKey, E>
                where
                    E: de::Error,
                {
                    let mut decoded = base64::decode(value).map_err(de::Error::custom)?;
                    if decoded.len() != 32 {
                        // Wrong length, but the bytes may still be sensitive: wipe them.
                        decoded.zeroize();
                        return Err(de::Error::custom("Invalid key length"));
                    }
                    let mut key_bytes = [0u8; 32];
                    key_bytes.copy_from_slice(&decoded);
                    // Wipe the heap copy and the stack copy; the returned
                    // `SymmetricKey` zeroizes its own bytes on drop.
                    decoded.zeroize();
                    let key = SymmetricKey::from_bytes(key_bytes).map_err(de::Error::custom);
                    key_bytes.zeroize();
                    key
                }
            }

            deserializer.deserialize_str(SymmetricKeyVisitor)
        }
    }

    impl Serialize for EncryptedData {
        /// Serializes as a struct of three base64 strings:
        /// `ciphertext`, `nonce`, `aad_hash`.
        fn serialize<S>(&self, serializer: S) -> std::result::Result<S::Ok, S::Error>
        where
            S: serde::Serializer,
        {
            use serde::ser::SerializeStruct;

            let mut state = serializer.serialize_struct("EncryptedData", 3)?;
            state.serialize_field("ciphertext", &base64::encode(&self.ciphertext))?;
            state.serialize_field("nonce", &base64::encode(&self.nonce))?;
            state.serialize_field("aad_hash", &base64::encode(&self.aad_hash))?;
            state.end()
        }
    }

    impl<'de> Deserialize<'de> for EncryptedData {
        /// Accepts either a map (self-describing formats such as JSON) or a
        /// three-element sequence in field order (formats such as bincode that
        /// encode structs positionally). Duplicate map keys are rejected.
        fn deserialize<D>(deserializer: D) -> std::result::Result<Self, D::Error>
        where
            D: serde::Deserializer<'de>,
        {
            use serde::de::{self, MapAccess, SeqAccess, Visitor};

            const FIELDS: &[&str] = &["ciphertext", "nonce", "aad_hash"];

            /// Decodes the base64 `ciphertext` field (any length).
            fn decode_ciphertext<E: de::Error>(encoded: &str) -> std::result::Result<Vec<u8>, E> {
                base64::decode(encoded).map_err(E::custom)
            }

            /// Decodes the base64 `nonce` field, which must be exactly 12 bytes.
            fn decode_nonce<E: de::Error>(encoded: &str) -> std::result::Result<[u8; 12], E> {
                let decoded = base64::decode(encoded).map_err(E::custom)?;
                if decoded.len() != 12 {
                    return Err(E::custom("Invalid nonce length"));
                }
                let mut nonce = [0u8; 12];
                nonce.copy_from_slice(&decoded);
                Ok(nonce)
            }

            /// Decodes the base64 `aad_hash` field, which must be exactly 32 bytes.
            fn decode_aad_hash<E: de::Error>(encoded: &str) -> std::result::Result<[u8; 32], E> {
                let decoded = base64::decode(encoded).map_err(E::custom)?;
                if decoded.len() != 32 {
                    return Err(E::custom("Invalid AAD hash length"));
                }
                let mut hash = [0u8; 32];
                hash.copy_from_slice(&decoded);
                Ok(hash)
            }

            struct EncryptedDataVisitor;

            impl<'de> Visitor<'de> for EncryptedDataVisitor {
                type Value = EncryptedData;

                fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
                    formatter.write_str("struct EncryptedData")
                }

                /// Positional form: `[ciphertext, nonce, aad_hash]`, each a base64 string.
                fn visit_seq<A>(self, mut seq: A) -> std::result::Result<EncryptedData, A::Error>
                where
                    A: SeqAccess<'de>,
                {
                    let ciphertext: String = seq
                        .next_element()?
                        .ok_or_else(|| de::Error::invalid_length(0, &self))?;
                    let nonce: String = seq
                        .next_element()?
                        .ok_or_else(|| de::Error::invalid_length(1, &self))?;
                    let aad_hash: String = seq
                        .next_element()?
                        .ok_or_else(|| de::Error::invalid_length(2, &self))?;

                    Ok(EncryptedData {
                        ciphertext: decode_ciphertext::<A::Error>(&ciphertext)?,
                        nonce: decode_nonce::<A::Error>(&nonce)?,
                        aad_hash: decode_aad_hash::<A::Error>(&aad_hash)?,
                    })
                }

                /// Keyed form. Unknown keys are ignored; duplicate known keys are errors.
                fn visit_map<V>(self, mut map: V) -> std::result::Result<EncryptedData, V::Error>
                where
                    V: MapAccess<'de>,
                {
                    let mut ciphertext: Option<Vec<u8>> = None;
                    let mut nonce: Option<[u8; 12]> = None;
                    let mut aad_hash: Option<[u8; 32]> = None;

                    // Keys are read as owned `String`s. A borrowed `&'de str` key
                    // fails with deserializers that cannot lend from their input,
                    // such as readers or keys containing escape sequences.
                    while let Some(key) = map.next_key::<String>()? {
                        match key.as_str() {
                            "ciphertext" => {
                                if ciphertext.is_some() {
                                    return Err(de::Error::duplicate_field("ciphertext"));
                                }
                                let encoded: String = map.next_value()?;
                                ciphertext = Some(decode_ciphertext::<V::Error>(&encoded)?);
                            }
                            "nonce" => {
                                if nonce.is_some() {
                                    return Err(de::Error::duplicate_field("nonce"));
                                }
                                let encoded: String = map.next_value()?;
                                nonce = Some(decode_nonce::<V::Error>(&encoded)?);
                            }
                            "aad_hash" => {
                                if aad_hash.is_some() {
                                    return Err(de::Error::duplicate_field("aad_hash"));
                                }
                                let encoded: String = map.next_value()?;
                                aad_hash = Some(decode_aad_hash::<V::Error>(&encoded)?);
                            }
                            _ => {
                                let _: de::IgnoredAny = map.next_value()?;
                            }
                        }
                    }

                    let ciphertext = ciphertext.ok_or_else(|| de::Error::missing_field("ciphertext"))?;
                    let nonce = nonce.ok_or_else(|| de::Error::missing_field("nonce"))?;
                    let aad_hash = aad_hash.ok_or_else(|| de::Error::missing_field("aad_hash"))?;

                    Ok(EncryptedData {
                        ciphertext,
                        nonce,
                        aad_hash,
                    })
                }
            }

            deserializer.deserialize_struct("EncryptedData", FIELDS, EncryptedDataVisitor)
        }
    }
}
311
// Brings the `aead` module into scope for `aead::Payload`, used by `SymmetricKey::encrypt` / `decrypt`.
313use chacha20poly1305::aead;
314
#[cfg(test)]
mod tests {
    use super::*;

    /// Decodes a hex string into bytes (test vectors only; panics on bad input).
    fn from_hex(s: &str) -> Vec<u8> {
        assert!(s.len() % 2 == 0, "hex string must have an even length");
        (0..s.len())
            .step_by(2)
            .map(|i| u8::from_str_radix(&s[i..i + 2], 16).expect("invalid hex digit"))
            .collect()
    }

    #[test]
    fn test_symmetric_encryption() {
        let key = SymmetricKey::generate();
        let plaintext = b"Hello, world!";
        let aad = b"additional data";

        let encrypted = key.encrypt(plaintext, aad).unwrap();
        let decrypted = key.decrypt(&encrypted, aad).unwrap();

        assert_eq!(decrypted, plaintext);
    }

    #[test]
    fn test_decrypt_rejects_wrong_aad() {
        let key = SymmetricKey::generate();
        let encrypted = key.encrypt(b"payload", b"aad-1").unwrap();
        assert!(key.decrypt(&encrypted, b"aad-2").is_err());
    }

    #[test]
    fn test_decrypt_rejects_tampered_ciphertext() {
        let key = SymmetricKey::generate();
        let mut encrypted = key.encrypt(b"payload", b"aad").unwrap();
        encrypted.ciphertext[0] ^= 0x01;
        assert!(key.decrypt(&encrypted, b"aad").is_err());
    }

    #[test]
    fn test_decrypt_rejects_tampered_aad_hash() {
        let key = SymmetricKey::generate();
        let mut encrypted = key.encrypt(b"payload", b"aad").unwrap();
        encrypted.aad_hash[0] ^= 0x01;
        assert!(key.decrypt(&encrypted, b"aad").is_err());
    }

    #[test]
    fn test_decrypt_rejects_wrong_key() {
        let key = SymmetricKey::generate();
        let other = SymmetricKey::generate();
        let encrypted = key.encrypt(b"payload", b"aad").unwrap();
        assert!(other.decrypt(&encrypted, b"aad").is_err());
    }

    #[test]
    fn test_encrypt_uses_fresh_nonces() {
        let key = SymmetricKey::generate();
        let first = key.encrypt(b"same", b"aad").unwrap();
        let second = key.encrypt(b"same", b"aad").unwrap();
        // Random 96-bit nonces: a collision here is astronomically unlikely.
        assert_ne!(first.nonce, second.nonce);
        assert_ne!(first.ciphertext, second.ciphertext);
    }

    #[test]
    fn test_key_derivation() {
        let ikm = b"input key material";
        let salt = b"salt";
        let info = contexts::GROUP_KEY;

        let key1 = derive_phalanx_key(ikm, salt, info);
        let key2 = derive_phalanx_key(ikm, salt, info);

        // Should be deterministic
        assert_eq!(key1.as_bytes(), key2.as_bytes());
    }

    #[test]
    fn test_key_derivation_separates_contexts() {
        let ikm = b"input key material";
        let salt = b"salt";
        let group = derive_phalanx_key(ikm, salt, contexts::GROUP_KEY);
        let message = derive_phalanx_key(ikm, salt, contexts::MESSAGE_KEY);
        assert_ne!(group.as_bytes(), message.as_bytes());
    }

    #[test]
    fn test_key_derivation_empty_salt_matches_blake3_derive_key() {
        let ikm = b"input key material";
        let info = contexts::KEY_DERIVATION;
        let key = derive_phalanx_key(ikm, b"", info);
        assert_eq!(key.as_bytes(), &blake3::derive_key(info, ikm));
    }

    #[test]
    fn test_hkdf() {
        let ikm = b"input key material";
        let salt = b"salt";
        let info = b"info";

        let prk = hkdf_extract(salt, ikm);
        let okm = hkdf_expand(&prk, info, 32).unwrap();

        assert_eq!(okm.len(), 32);
    }

    #[test]
    fn test_hkdf_rfc5869_case_1() {
        // RFC 5869, Appendix A.1 (HKDF-SHA256 basic test case).
        let ikm = [0x0b_u8; 22];
        let salt: Vec<u8> = (0x00..=0x0c).collect();
        let info: Vec<u8> = (0xf0..=0xf9).collect();

        let prk = hkdf_extract(&salt, &ikm);
        assert_eq!(
            prk.to_vec(),
            from_hex("077709362c2e32df0ddc3f0dc47bba6390b6c73bb50f9c3122ec844ad7c2b3e5")
        );

        let okm = hkdf_expand(&prk, &info, 42).unwrap();
        assert_eq!(
            okm,
            from_hex(
                "3cb25f25faacd57a90434f64d0362f2a2d2d0a90cf1a5a4c5db02d56ecc4c5bf34007208d5b887185865"
            )
        );
    }

    #[test]
    fn test_hkdf_expand_rejects_oversized_output() {
        let prk = hkdf_extract(b"salt", b"ikm");
        // HKDF-SHA256 can emit at most 255 * 32 bytes.
        assert!(hkdf_expand(&prk, b"info", 255 * 32).is_ok());
        assert!(hkdf_expand(&prk, b"info", 255 * 32 + 1).is_err());
    }

    #[test]
    fn test_hash_functions() {
        let data = b"test data";
        let hash1 = hash(data);
        let hash2 = hash(data);

        assert_eq!(hash1, hash2);

        let multi_hash = hash_multiple(&[b"part1", b"part2"]);
        let single_hash = hash(b"part1part2");

        // These should be the same since hash_multiple just concatenates
        assert_eq!(multi_hash, single_hash);
    }

    #[test]
    fn test_constant_time_eq() {
        assert!(constant_time_eq(b"abc", b"abc"));
        assert!(!constant_time_eq(b"abc", b"abd"));
        assert!(!constant_time_eq(b"abc", b"abcd"));
        assert!(constant_time_eq(b"", b""));
    }

    #[test]
    fn test_random_helper_sizes() {
        assert!(random_bytes(0).is_empty());
        assert_eq!(random_bytes(64).len(), 64);
        assert_eq!(generate_nonce().len(), 12);
    }

    #[test]
    fn test_debug_redacts_key() {
        let key = SymmetricKey::from_bytes([0xAB; 32]).unwrap();
        let rendered = format!("{:?}", key);
        assert!(rendered.contains("[REDACTED]"));
        // 0xAB == 171: no raw key byte may leak into the Debug output.
        assert!(!rendered.contains("171"));
    }
}
370}