1use crate::error::{PhalanxError, Result};
4use chacha20poly1305::{ChaCha20Poly1305, Key, Nonce, aead::{Aead, KeyInit}};
5use blake3::Hasher;
6use hkdf::Hkdf;
7use sha2::Sha256;
8use rand::{RngCore, rngs::OsRng};
9use zeroize::{Zeroize, ZeroizeOnDrop};
10
/// Versioned context labels used to separate the different kinds of derived keys.
///
/// Each constant is a plain string; the unit tests in this file pass
/// [`GROUP_KEY`](contexts::GROUP_KEY) as the `info` argument of `derive_phalanx_key`.
/// The exact purpose of every label is inferred from its name only — confirm
/// against the call sites before relying on it.
pub mod contexts {
    /// Label for group-key derivation (by name; verify against callers).
    pub const GROUP_KEY: &str = "PHALANX_GROUP_KEY_V1";

    /// Label for per-message key derivation (by name; verify against callers).
    pub const MESSAGE_KEY: &str = "PHALANX_MESSAGE_KEY_V1";

    /// Label for authentication-key derivation (by name; verify against callers).
    pub const AUTH_KEY: &str = "PHALANX_AUTH_KEY_V1";

    /// Label for key-exchange derivation (by name; verify against callers).
    pub const KEY_EXCHANGE: &str = "PHALANX_KEY_EXCHANGE_V1";

    /// Label for generic key derivation (by name; verify against callers).
    pub const KEY_DERIVATION: &str = "PHALANX_KEY_DERIVE_V1";
}
24
25#[derive(Clone, Zeroize, ZeroizeOnDrop)]
27pub struct SymmetricKey([u8; 32]);
28
/// Output of an encryption operation: everything needed, besides the key and
/// the associated data itself, to decrypt the message again.
#[derive(Clone, Debug)]
pub struct EncryptedData {
    /// Ciphertext bytes as returned by the AEAD cipher.
    pub ciphertext: Vec<u8>,

    /// The 12-byte nonce that was used for this message.
    pub nonce: [u8; 12],

    /// 32-byte hash of the associated data supplied at encryption time.
    pub aad_hash: [u8; 32],
}
39
40impl SymmetricKey {
41 pub fn generate() -> Self {
43 let mut key = [0u8; 32];
44 OsRng.fill_bytes(&mut key);
45 Self(key)
46 }
47
48 pub fn from_bytes(bytes: [u8; 32]) -> Result<Self> {
50 Ok(Self(bytes))
51 }
52
53 pub fn as_bytes(&self) -> &[u8; 32] {
55 &self.0
56 }
57
58 pub fn encrypt(&self, plaintext: &[u8], aad: &[u8]) -> Result<EncryptedData> {
60 let mut nonce_bytes = [0u8; 12];
62 OsRng.fill_bytes(&mut nonce_bytes);
63 let nonce = Nonce::from_slice(&nonce_bytes);
64
65 let cipher = ChaCha20Poly1305::new(Key::from_slice(&self.0));
67
68 let ciphertext = cipher.encrypt(nonce, aead::Payload {
70 msg: plaintext,
71 aad,
72 })?;
73
74 let aad_hash = blake3::hash(aad).into();
76
77 Ok(EncryptedData {
78 ciphertext,
79 nonce: nonce_bytes,
80 aad_hash,
81 })
82 }
83
84 pub fn decrypt(&self, data: &EncryptedData, aad: &[u8]) -> Result<Vec<u8>> {
86 let expected_hash = blake3::hash(aad);
88 if data.aad_hash != *expected_hash.as_bytes() {
89 return Err(PhalanxError::auth("AAD hash mismatch"));
90 }
91
92 let cipher = ChaCha20Poly1305::new(Key::from_slice(&self.0));
94 let nonce = Nonce::from_slice(&data.nonce);
95
96 let plaintext = cipher.decrypt(nonce, aead::Payload {
97 msg: &data.ciphertext,
98 aad,
99 })?;
100
101 Ok(plaintext)
102 }
103}
104
105pub fn derive_phalanx_key(ikm: &[u8], _salt: &[u8], info: &str) -> SymmetricKey {
107 let derived = blake3::derive_key(info, ikm);
108 SymmetricKey(derived)
109}
110
111pub fn hkdf_expand(prk: &[u8], info: &[u8], length: usize) -> Result<Vec<u8>> {
113 let hk = Hkdf::<Sha256>::from_prk(prk)
114 .map_err(|e| PhalanxError::key_derivation(format!("HKDF PRK invalid: {}", e)))?;
115
116 let mut output = vec![0u8; length];
117 hk.expand(info, &mut output)
118 .map_err(|e| PhalanxError::key_derivation(format!("HKDF expand failed: {}", e)))?;
119
120 Ok(output)
121}
122
123pub fn hkdf_extract(salt: &[u8], ikm: &[u8]) -> [u8; 32] {
125 let (prk, _) = Hkdf::<Sha256>::extract(Some(salt), ikm);
126 prk.into()
127}
128
129pub fn hash(data: &[u8]) -> [u8; 32] {
131 blake3::hash(data).into()
132}
133
134pub fn hash_multiple(inputs: &[&[u8]]) -> [u8; 32] {
136 let mut hasher = Hasher::new();
137 for input in inputs {
138 hasher.update(input);
139 }
140 hasher.finalize().into()
141}
142
143pub fn generate_nonce() -> [u8; 12] {
145 let mut nonce = [0u8; 12];
146 OsRng.fill_bytes(&mut nonce);
147 nonce
148}
149
150pub fn random_bytes(len: usize) -> Vec<u8> {
152 let mut bytes = vec![0u8; len];
153 OsRng.fill_bytes(&mut bytes);
154 bytes
155}
156
157pub fn constant_time_eq(a: &[u8], b: &[u8]) -> bool {
159 use subtle::ConstantTimeEq;
160 a.ct_eq(b).into()
161}
162
163impl std::fmt::Debug for SymmetricKey {
164 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
165 f.debug_struct("SymmetricKey")
166 .field("key", &"[REDACTED]")
167 .finish()
168 }
169}
170
/// Serde support for [`SymmetricKey`] and [`EncryptedData`], compiled only
/// when the `serde` feature is enabled.
///
/// All binary fields are represented as base64 strings.
#[cfg(feature = "serde")]
mod serde_impl {
    use super::*;
    use serde::{Deserialize, Serialize};

    /// Field names of the serialized [`EncryptedData`] struct.
    const ENCRYPTED_DATA_FIELDS: &[&str] = &["ciphertext", "nonce", "aad_hash"];

    /// Serializes the key as a single base64 string.
    impl Serialize for SymmetricKey {
        fn serialize<S>(&self, serializer: S) -> std::result::Result<S::Ok, S::Error>
        where
            S: serde::Serializer,
        {
            let mut encoded = base64::encode(self.as_bytes());
            let outcome = serializer.serialize_str(&encoded);
            // The base64 text is key material too; wipe our temporary copy
            // once the serializer has consumed it.
            encoded.zeroize();
            outcome
        }
    }

    /// Deserializes a key from a base64 string that decodes to exactly 32 bytes.
    impl<'de> Deserialize<'de> for SymmetricKey {
        fn deserialize<D>(deserializer: D) -> std::result::Result<Self, D::Error>
        where
            D: serde::Deserializer<'de>,
        {
            use serde::de::{self, Visitor};

            struct SymmetricKeyVisitor;

            impl<'de> Visitor<'de> for SymmetricKeyVisitor {
                type Value = SymmetricKey;

                fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
                    formatter.write_str("a base64 encoded 32-byte key")
                }

                fn visit_str<E>(self, value: &str) -> std::result::Result<SymmetricKey, E>
                where
                    E: de::Error,
                {
                    let mut decoded = base64::decode(value).map_err(de::Error::custom)?;
                    if decoded.len() != 32 {
                        // Wipe whatever was decoded before bailing out.
                        decoded.zeroize();
                        return Err(de::Error::custom("Invalid key length"));
                    }

                    let mut key_bytes = [0u8; 32];
                    key_bytes.copy_from_slice(&decoded);
                    decoded.zeroize();

                    let key = SymmetricKey::from_bytes(key_bytes).map_err(de::Error::custom);
                    // `[u8; 32]` is `Copy`, so the local array still holds the
                    // key after it was handed to `from_bytes`; clear it.
                    key_bytes.zeroize();
                    key
                }
            }

            deserializer.deserialize_str(SymmetricKeyVisitor)
        }
    }

    /// Serializes as a three-field struct (`ciphertext`, `nonce`, `aad_hash`),
    /// each field a base64 string.
    impl Serialize for EncryptedData {
        fn serialize<S>(&self, serializer: S) -> std::result::Result<S::Ok, S::Error>
        where
            S: serde::Serializer,
        {
            use serde::ser::SerializeStruct;

            let mut state = serializer.serialize_struct("EncryptedData", 3)?;
            state.serialize_field("ciphertext", &base64::encode(&self.ciphertext))?;
            state.serialize_field("nonce", &base64::encode(&self.nonce))?;
            state.serialize_field("aad_hash", &base64::encode(&self.aad_hash))?;
            state.end()
        }
    }

    /// Deserializes from a map with base64 string fields.
    ///
    /// * `nonce` must decode to 12 bytes and `aad_hash` to 32 bytes.
    /// * Unknown fields are ignored.
    /// * A field that appears more than once is rejected.
    /// * All three fields are required.
    impl<'de> Deserialize<'de> for EncryptedData {
        fn deserialize<D>(deserializer: D) -> std::result::Result<Self, D::Error>
        where
            D: serde::Deserializer<'de>,
        {
            use serde::de::{self, MapAccess, Visitor};

            struct EncryptedDataVisitor;

            impl<'de> Visitor<'de> for EncryptedDataVisitor {
                type Value = EncryptedData;

                fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
                    formatter.write_str("struct EncryptedData")
                }

                fn visit_map<V>(self, mut map: V) -> std::result::Result<EncryptedData, V::Error>
                where
                    V: MapAccess<'de>,
                {
                    let mut ciphertext = None;
                    let mut nonce = None;
                    let mut aad_hash = None;

                    // Keys are read as owned `String`s: a borrowed `&str` key
                    // only works with deserializers that can lend string data
                    // straight from the input, and fails otherwise.
                    while let Some(key) = map.next_key::<String>()? {
                        match key.as_str() {
                            "ciphertext" => {
                                if ciphertext.is_some() {
                                    return Err(de::Error::duplicate_field("ciphertext"));
                                }
                                let encoded: String = map.next_value()?;
                                ciphertext =
                                    Some(base64::decode(&encoded).map_err(de::Error::custom)?);
                            }
                            "nonce" => {
                                if nonce.is_some() {
                                    return Err(de::Error::duplicate_field("nonce"));
                                }
                                let encoded: String = map.next_value()?;
                                let decoded =
                                    base64::decode(&encoded).map_err(de::Error::custom)?;
                                if decoded.len() != 12 {
                                    return Err(de::Error::custom("Invalid nonce length"));
                                }
                                let mut n = [0u8; 12];
                                n.copy_from_slice(&decoded);
                                nonce = Some(n);
                            }
                            "aad_hash" => {
                                if aad_hash.is_some() {
                                    return Err(de::Error::duplicate_field("aad_hash"));
                                }
                                let encoded: String = map.next_value()?;
                                let decoded =
                                    base64::decode(&encoded).map_err(de::Error::custom)?;
                                if decoded.len() != 32 {
                                    return Err(de::Error::custom("Invalid AAD hash length"));
                                }
                                let mut h = [0u8; 32];
                                h.copy_from_slice(&decoded);
                                aad_hash = Some(h);
                            }
                            _ => {
                                // Skip the value of any field we do not know.
                                let _: serde::de::IgnoredAny = map.next_value()?;
                            }
                        }
                    }

                    let ciphertext =
                        ciphertext.ok_or_else(|| de::Error::missing_field("ciphertext"))?;
                    let nonce = nonce.ok_or_else(|| de::Error::missing_field("nonce"))?;
                    let aad_hash = aad_hash.ok_or_else(|| de::Error::missing_field("aad_hash"))?;

                    Ok(EncryptedData {
                        ciphertext,
                        nonce,
                        aad_hash,
                    })
                }
            }

            deserializer.deserialize_struct(
                "EncryptedData",
                ENCRYPTED_DATA_FIELDS,
                EncryptedDataVisitor,
            )
        }
    }
}
311
312use chacha20poly1305::aead;
314
#[cfg(test)]
mod tests {
    use super::*;

    /// Encrypting and then decrypting with the same key and associated data
    /// gives back the original message.
    #[test]
    fn test_symmetric_encryption() {
        let message: &[u8] = b"Hello, world!";
        let associated: &[u8] = b"additional data";
        let key = SymmetricKey::generate();

        let sealed = key.encrypt(message, associated).unwrap();
        let opened = key.decrypt(&sealed, associated).unwrap();

        assert_eq!(opened, message);
    }

    /// Deriving twice from identical inputs produces identical keys.
    #[test]
    fn test_key_derivation() {
        let material = b"input key material";
        let salt = b"salt";
        let context = contexts::GROUP_KEY;

        let first = derive_phalanx_key(material, salt, context);
        let second = derive_phalanx_key(material, salt, context);

        assert_eq!(first.as_bytes(), second.as_bytes());
    }

    /// Extract followed by expand yields output of the requested length.
    #[test]
    fn test_hkdf() {
        let material = b"input key material";
        let salt = b"salt";
        let context = b"info";

        let pseudo_random_key = hkdf_extract(salt, material);
        let output = hkdf_expand(&pseudo_random_key, context, 32).unwrap();

        assert_eq!(output.len(), 32);
    }

    /// `hash` is deterministic, and `hash_multiple` matches hashing the
    /// concatenated input.
    #[test]
    fn test_hash_functions() {
        let payload = b"test data";
        let first = hash(payload);
        let second = hash(payload);

        assert_eq!(first, second);

        let parts: [&[u8]; 2] = [b"part1", b"part2"];
        let chunked = hash_multiple(&parts);
        let whole = hash(b"part1part2");

        assert_eq!(chunked, whole);
    }
}