1use crate::{
4 error::{PhalanxError, Result},
5 identity::{Identity, PublicKey},
6 crypto::{EncryptedData, derive_phalanx_key, contexts},
7};
8use ed25519_dalek::Signature;
9use x25519_dalek::PublicKey as X25519PublicKey;
10use std::time::{SystemTime, UNIX_EPOCH};
11
12#[cfg(feature = "serde")]
13use serde::{Serialize, Deserialize};
14
/// Wire version of the Phalanx protocol.
///
/// `#[repr(u8)]` pins the discriminant to a single byte, so converting a version to its
/// `u8` wire value is lossless and a discriminant above 255 is rejected at compile time.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[repr(u8)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub enum ProtocolVersion {
    /// Protocol version 1 (wire value `1`).
    V1 = 1,
}
22
23impl ProtocolVersion {
24 pub fn current() -> Self {
26 Self::V1
27 }
28
29 pub fn is_compatible_with(self, other: Self) -> bool {
31 self == other }
33}
34
35impl TryFrom<u8> for ProtocolVersion {
36 type Error = PhalanxError;
37
38 fn try_from(value: u8) -> Result<Self> {
39 match value {
40 1 => Ok(Self::V1),
41 _ => Err(PhalanxError::version(format!("Unsupported protocol version: {}", value))),
42 }
43 }
44}
45
46impl From<ProtocolVersion> for u8 {
47 fn from(version: ProtocolVersion) -> u8 {
48 version as u8
49 }
50}
51
/// Signed, encrypted opening message of a Phalanx handshake.
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub struct HandshakeMessage {
    /// Protocol version the message was built with; included in the signed data.
    pub version: ProtocolVersion,
    /// Sender's long-term public identity; its `id()` also seeds the payload encryption key.
    pub sender_key: PublicKey,
    /// X25519 public key generated for this handshake.
    pub ephemeral_key: X25519PublicKey,
    /// Creation time in whole seconds since the Unix epoch.
    pub timestamp: u64,
    /// The encrypted `HandshakePayload`.
    pub encrypted_payload: EncryptedData,
    /// Sender's signature over the version, sender id, ephemeral key, timestamp and ciphertext.
    pub signature: Signature,
}
69
/// Plaintext carried, encrypted, inside a `HandshakeMessage`.
///
/// Every field is a plain standard-library type, so equality is derived to let decoded
/// payloads be compared field-for-field (for example after a serialize/deserialize round trip).
#[derive(Debug, Clone, PartialEq, Eq)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub struct HandshakePayload {
    /// 32-byte identifier of the target group.
    pub group_id: [u8; 32],
    /// Capability strings advertised by the client (for example `"phalanx/v1"`).
    pub capabilities: Vec<String>,
    /// Free-form client description (for example `"test-client/1.0"`).
    pub client_info: String,
    /// Optional membership-proof bytes; the constructors in this module leave it `None`.
    pub membership_proof: Option<Vec<u8>>,
    /// Optional wrapped group key: a serialized `EncryptedData` produced by
    /// `HandshakeMessage::new_with_group_key`.
    pub encrypted_group_key: Option<Vec<u8>>,
}
85
/// Admin-signed announcement of a group key rotation.
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub struct KeyRotationMessage {
    /// Protocol version the message was built with.
    pub version: ProtocolVersion,
    /// Rotation sequence number chosen by the admin (presumably increasing; not enforced by this type).
    pub sequence: u64,
    /// Creation time in whole seconds since the Unix epoch.
    pub timestamp: u64,
    /// `(identity, X25519 public key)` pair for each member, in signed order.
    pub member_keys: Vec<(PublicKey, X25519PublicKey)>,
    /// Admin signature over the sequence, timestamp and member keys.
    pub signature: Signature,
}
101
/// Record of a single change to group membership.
///
/// NOTE(review): no constructor or verifier for this type appears in the visible source,
/// so what `signature` covers, and who produces it, is not established here.
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub struct MembershipChange {
    /// Kind of change.
    pub change_type: MembershipChangeType,
    /// Public identity of the member the change applies to (presumed from the name).
    pub member_key: PublicKey,
    /// Time of the change; presumably Unix seconds like the other messages — TODO confirm.
    pub timestamp: u64,
    /// Signature over the change; the signed fields are not defined in the visible source.
    pub signature: Signature,
}
115
/// Kind of change recorded in a `MembershipChange`.
///
/// A fieldless enum, so it is `Copy` (pass by value, no `.clone()` needed) and `Hash`
/// (usable as a set or map key).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub enum MembershipChangeType {
    /// A member joined the group.
    Join,
    /// A member left the group (presumably voluntarily, as opposed to `Remove`).
    Leave,
    /// A member was removed from the group (presumably by an admin — TODO confirm).
    Remove,
    /// A member's role changed.
    RoleChange,
}
129
130impl HandshakeMessage {
131 pub fn new(
133 sender: &Identity,
134 group_id: [u8; 32],
135 capabilities: Vec<String>,
136 client_info: String,
137 ) -> Result<Self> {
138 let timestamp = SystemTime::now()
139 .duration_since(UNIX_EPOCH)
140 .map_err(|e| PhalanxError::crypto(format!("System time error: {}", e)))?
141 .as_secs();
142
143 let sender_key = sender.public_key();
144
145 let mut sender_mut = sender.clone();
147 let ephemeral_key = sender_mut.generate_kx_key();
148
149 let payload = HandshakePayload {
151 group_id,
152 capabilities,
153 client_info,
154 membership_proof: None,
155 encrypted_group_key: None,
156 };
157
158 let handshake_key = derive_phalanx_key(
160 &sender.id(),
161 b"PHALANX_HANDSHAKE",
162 contexts::KEY_EXCHANGE,
163 );
164
165 let payload_bytes = Self::serialize_payload(&payload)?;
167 let aad = Self::create_handshake_aad(&sender_key, &ephemeral_key, timestamp);
168 let encrypted_payload = handshake_key.encrypt(&payload_bytes, &aad)?;
169
170 let signature_data = Self::create_signature_data(
172 ProtocolVersion::current(),
173 &sender_key,
174 &ephemeral_key,
175 timestamp,
176 &encrypted_payload,
177 );
178 let signature = sender.sign(&signature_data);
179
180 Ok(Self {
181 version: ProtocolVersion::current(),
182 sender_key,
183 ephemeral_key,
184 timestamp,
185 encrypted_payload,
186 signature,
187 })
188 }
189
190 pub fn new_with_group_key(
192 sender: &mut Identity,
193 recipient_public_key: &PublicKey,
194 group_id: [u8; 32],
195 capabilities: Vec<String>,
196 client_info: String,
197 group_key: &crate::crypto::SymmetricKey,
198 ) -> Result<Self> {
199 use crate::crypto::{derive_phalanx_key, contexts};
200
201 let timestamp = SystemTime::now()
202 .duration_since(UNIX_EPOCH)
203 .map_err(|e| PhalanxError::crypto(format!("System time error: {}", e)))?
204 .as_secs();
205
206 let sender_key = sender.public_key();
207
208 let ephemeral_key = sender.generate_kx_key();
210
211 let shared_secret = sender.key_exchange(&recipient_public_key.kx_public)?;
213
214 let encryption_key = derive_phalanx_key(
216 &shared_secret,
217 b"PHALANX_GROUP_KEY",
218 contexts::KEY_EXCHANGE,
219 );
220
221 let group_key_bytes = group_key.as_bytes();
223 let aad = b"PHALANX_GROUP_KEY_V1";
224 let encrypted_group_key_data = encryption_key.encrypt(group_key_bytes, aad)?;
225 let encrypted_group_key_bytes = serde_json::to_vec(&encrypted_group_key_data)
226 .map_err(|e| PhalanxError::crypto(format!("Group key encryption serialization failed: {}", e)))?;
227
228 let payload = HandshakePayload {
230 group_id,
231 capabilities,
232 client_info,
233 membership_proof: None,
234 encrypted_group_key: Some(encrypted_group_key_bytes),
235 };
236
237 let handshake_key = derive_phalanx_key(
239 &sender.id(),
240 b"PHALANX_HANDSHAKE",
241 contexts::KEY_EXCHANGE,
242 );
243
244 let payload_bytes = Self::serialize_payload(&payload)?;
246 let aad = Self::create_handshake_aad(&sender_key, &ephemeral_key, timestamp);
247 let encrypted_payload = handshake_key.encrypt(&payload_bytes, &aad)?;
248
249 let signature_data = Self::create_signature_data(
251 ProtocolVersion::current(),
252 &sender_key,
253 &ephemeral_key,
254 timestamp,
255 &encrypted_payload,
256 );
257 let signature = sender.sign(&signature_data);
258
259 Ok(Self {
260 version: ProtocolVersion::current(),
261 sender_key,
262 ephemeral_key,
263 timestamp,
264 encrypted_payload,
265 signature,
266 })
267 }
268
269 pub fn extract_group_key(&self, recipient: &mut Identity) -> Result<Option<crate::crypto::SymmetricKey>> {
271 use crate::crypto::{derive_phalanx_key, contexts};
272
273 let payload = self.verify_and_decrypt()?;
275
276 if let Some(encrypted_group_key_bytes) = payload.encrypted_group_key {
277 let shared_secret = recipient.static_key_exchange(&self.ephemeral_key)?;
279
280 let decryption_key = derive_phalanx_key(
282 &shared_secret,
283 b"PHALANX_GROUP_KEY",
284 contexts::KEY_EXCHANGE,
285 );
286
287 let encrypted_group_key_data: crate::crypto::EncryptedData =
289 serde_json::from_slice(&encrypted_group_key_bytes)
290 .map_err(|e| PhalanxError::crypto(format!("Group key decryption deserialization failed: {}", e)))?;
291
292 let aad = b"PHALANX_GROUP_KEY_V1";
294 let group_key_bytes = decryption_key.decrypt(&encrypted_group_key_data, aad)?;
295
296 if group_key_bytes.len() != 32 {
298 return Err(PhalanxError::crypto("Invalid group key size"));
299 }
300 let mut key_array = [0u8; 32];
301 key_array.copy_from_slice(&group_key_bytes);
302 let group_key = crate::crypto::SymmetricKey::from_bytes(key_array)?;
303
304 Ok(Some(group_key))
305 } else {
306 Ok(None)
307 }
308 }
309
310 pub fn verify_and_decrypt(&self) -> Result<HandshakePayload> {
312 let signature_data = Self::create_signature_data(
314 self.version,
315 &self.sender_key,
316 &self.ephemeral_key,
317 self.timestamp,
318 &self.encrypted_payload,
319 );
320
321 self.sender_key.verify(&signature_data, &self.signature)?;
322
323 let handshake_key = derive_phalanx_key(
325 &self.sender_key.id(),
326 b"PHALANX_HANDSHAKE",
327 contexts::KEY_EXCHANGE,
328 );
329
330 let aad = Self::create_handshake_aad(&self.sender_key, &self.ephemeral_key, self.timestamp);
332 let decrypted_bytes = handshake_key.decrypt(&self.encrypted_payload, &aad)?;
333
334 Self::deserialize_payload(&decrypted_bytes)
336 }
337
338 pub fn is_recent(&self) -> bool {
340 if let Ok(now) = SystemTime::now().duration_since(UNIX_EPOCH) {
341 let age = now.as_secs().saturating_sub(self.timestamp);
342 age <= 300 } else {
344 false
345 }
346 }
347
348 fn create_handshake_aad(sender: &PublicKey, ephemeral: &X25519PublicKey, timestamp: u64) -> Vec<u8> {
349 let mut aad = Vec::new();
350 aad.extend_from_slice(&sender.id());
351 aad.extend_from_slice(ephemeral.as_bytes());
352 aad.extend_from_slice(×tamp.to_be_bytes());
353 aad.extend_from_slice(b"PHALANX_HANDSHAKE_V1");
354 aad
355 }
356
357 fn create_signature_data(
358 version: ProtocolVersion,
359 sender: &PublicKey,
360 ephemeral: &X25519PublicKey,
361 timestamp: u64,
362 encrypted_payload: &EncryptedData,
363 ) -> Vec<u8> {
364 let mut data = Vec::new();
365 data.push(version.into());
366 data.extend_from_slice(&sender.id());
367 data.extend_from_slice(ephemeral.as_bytes());
368 data.extend_from_slice(×tamp.to_be_bytes());
369 data.extend_from_slice(&encrypted_payload.ciphertext);
370 data.extend_from_slice(&encrypted_payload.nonce);
371 data.extend_from_slice(&encrypted_payload.aad_hash);
372 data.extend_from_slice(b"PHALANX_HANDSHAKE_SIG_V1");
373 data
374 }
375
376 #[cfg(feature = "serde")]
377 fn serialize_payload(payload: &HandshakePayload) -> Result<Vec<u8>> {
378 serde_json::to_vec(payload)
379 .map_err(|e| PhalanxError::protocol(format!("Handshake payload serialization failed: {}", e)))
380 }
381
382 #[cfg(not(feature = "serde"))]
383 fn serialize_payload(payload: &HandshakePayload) -> Result<Vec<u8>> {
384 let mut bytes = Vec::new();
385
386 bytes.extend_from_slice(&payload.group_id);
388
389 let cap_count = payload.capabilities.len() as u32;
391 bytes.extend_from_slice(&cap_count.to_be_bytes());
392 for cap in &payload.capabilities {
393 let cap_bytes = cap.as_bytes();
394 let cap_len = cap_bytes.len() as u32;
395 bytes.extend_from_slice(&cap_len.to_be_bytes());
396 bytes.extend_from_slice(cap_bytes);
397 }
398
399 let info_bytes = payload.client_info.as_bytes();
401 let info_len = info_bytes.len() as u32;
402 bytes.extend_from_slice(&info_len.to_be_bytes());
403 bytes.extend_from_slice(info_bytes);
404
405 if let Some(proof) = &payload.membership_proof {
407 bytes.push(1); let proof_len = proof.len() as u32;
409 bytes.extend_from_slice(&proof_len.to_be_bytes());
410 bytes.extend_from_slice(proof);
411 } else {
412 bytes.push(0); }
414
415 if let Some(encrypted_key) = &payload.encrypted_group_key {
417 bytes.push(1); let key_len = encrypted_key.len() as u32;
419 bytes.extend_from_slice(&key_len.to_be_bytes());
420 bytes.extend_from_slice(encrypted_key);
421 } else {
422 bytes.push(0); }
424
425 Ok(bytes)
426 }
427
428 #[cfg(feature = "serde")]
429 fn deserialize_payload(bytes: &[u8]) -> Result<HandshakePayload> {
430 serde_json::from_slice(bytes)
431 .map_err(|e| PhalanxError::protocol(format!("Handshake payload deserialization failed: {}", e)))
432 }
433
434 #[cfg(not(feature = "serde"))]
435 fn deserialize_payload(bytes: &[u8]) -> Result<HandshakePayload> {
436 if bytes.len() < 32 + 4 {
437 return Err(PhalanxError::protocol("Invalid handshake payload"));
438 }
439
440 let mut pos = 0;
441
442 let mut group_id = [0u8; 32];
444 group_id.copy_from_slice(&bytes[pos..pos + 32]);
445 pos += 32;
446
447 let cap_count = u32::from_be_bytes([bytes[pos], bytes[pos+1], bytes[pos+2], bytes[pos+3]]) as usize;
449 pos += 4;
450
451 let mut capabilities = Vec::new();
452 for _ in 0..cap_count {
453 if pos + 4 > bytes.len() {
454 return Err(PhalanxError::protocol("Truncated capability"));
455 }
456
457 let cap_len = u32::from_be_bytes([bytes[pos], bytes[pos+1], bytes[pos+2], bytes[pos+3]]) as usize;
458 pos += 4;
459
460 if pos + cap_len > bytes.len() {
461 return Err(PhalanxError::protocol("Truncated capability data"));
462 }
463
464 let cap_str = String::from_utf8(bytes[pos..pos + cap_len].to_vec())
465 .map_err(|_| PhalanxError::protocol("Invalid UTF-8 in capability"))?;
466 capabilities.push(cap_str);
467 pos += cap_len;
468 }
469
470 if pos + 4 > bytes.len() {
472 return Err(PhalanxError::protocol("Truncated client info length"));
473 }
474
475 let info_len = u32::from_be_bytes([bytes[pos], bytes[pos+1], bytes[pos+2], bytes[pos+3]]) as usize;
476 pos += 4;
477
478 if pos + info_len > bytes.len() {
479 return Err(PhalanxError::protocol("Truncated client info"));
480 }
481
482 let client_info = String::from_utf8(bytes[pos..pos + info_len].to_vec())
483 .map_err(|_| PhalanxError::protocol("Invalid UTF-8 in client info"))?;
484 pos += info_len;
485
486 if pos >= bytes.len() {
488 return Err(PhalanxError::protocol("Truncated membership proof marker"));
489 }
490
491 let membership_proof = if bytes[pos] == 1 {
492 pos += 1;
493 if pos + 4 > bytes.len() {
494 return Err(PhalanxError::protocol("Truncated proof length"));
495 }
496
497 let proof_len = u32::from_be_bytes([bytes[pos], bytes[pos+1], bytes[pos+2], bytes[pos+3]]) as usize;
498 pos += 4;
499
500 if pos + proof_len > bytes.len() {
501 return Err(PhalanxError::protocol("Truncated proof data"));
502 }
503
504 Some(bytes[pos..pos + proof_len].to_vec())
505 } else {
506 pos += 1;
507 None
508 };
509
510 let encrypted_group_key = if pos < bytes.len() && bytes[pos] == 1 {
512 pos += 1;
513 if pos + 4 > bytes.len() {
514 return Err(PhalanxError::protocol("Truncated encrypted group key length"));
515 }
516
517 let key_len = u32::from_be_bytes([bytes[pos], bytes[pos+1], bytes[pos+2], bytes[pos+3]]) as usize;
518 pos += 4;
519
520 if pos + key_len > bytes.len() {
521 return Err(PhalanxError::protocol("Truncated encrypted group key data"));
522 }
523
524 Some(bytes[pos..pos + key_len].to_vec())
525 } else {
526 None
527 };
528
529 Ok(HandshakePayload {
530 group_id,
531 capabilities,
532 client_info,
533 membership_proof,
534 encrypted_group_key,
535 })
536 }
537}
538
539impl KeyRotationMessage {
540 pub fn new(
542 admin: &Identity,
543 sequence: u64,
544 member_keys: Vec<(PublicKey, X25519PublicKey)>,
545 ) -> Result<Self> {
546 let timestamp = SystemTime::now()
547 .duration_since(UNIX_EPOCH)
548 .map_err(|e| PhalanxError::crypto(format!("System time error: {}", e)))?
549 .as_secs();
550
551 let signature_data = Self::create_signature_data(sequence, timestamp, &member_keys);
553 let signature = admin.sign(&signature_data);
554
555 Ok(Self {
556 version: ProtocolVersion::current(),
557 sequence,
558 timestamp,
559 member_keys,
560 signature,
561 })
562 }
563
564 pub fn verify(&self, admin_key: &PublicKey) -> Result<()> {
566 let signature_data = Self::create_signature_data(self.sequence, self.timestamp, &self.member_keys);
567 admin_key.verify(&signature_data, &self.signature)
568 }
569
570 fn create_signature_data(
571 sequence: u64,
572 timestamp: u64,
573 member_keys: &[(PublicKey, X25519PublicKey)],
574 ) -> Vec<u8> {
575 let mut data = Vec::new();
576 data.push(ProtocolVersion::current().into());
577 data.extend_from_slice(&sequence.to_be_bytes());
578 data.extend_from_slice(×tamp.to_be_bytes());
579
580 for (pub_key, ephemeral) in member_keys {
581 data.extend_from_slice(&pub_key.id());
582 data.extend_from_slice(ephemeral.as_bytes());
583 }
584
585 data.extend_from_slice(b"PHALANX_KEY_ROTATION_V1");
586 data
587 }
588}
589
#[cfg(test)]
mod tests {
    use super::*;

    /// A handshake built by `HandshakeMessage::new` round-trips through
    /// `verify_and_decrypt` with every payload field intact.
    #[test]
    fn test_handshake_message() {
        let identity = Identity::generate();
        let expected_group = [1u8; 32];
        let expected_caps: Vec<String> = ["phalanx/v1", "threading"]
            .iter()
            .map(|cap| cap.to_string())
            .collect();
        let expected_info = String::from("test-client/1.0");

        let message = HandshakeMessage::new(
            &identity,
            expected_group,
            expected_caps.clone(),
            expected_info.clone(),
        )
        .unwrap();
        let opened = message.verify_and_decrypt().unwrap();

        assert_eq!(opened.group_id, expected_group);
        assert_eq!(opened.capabilities, expected_caps);
        assert_eq!(opened.client_info, expected_info);
    }

    /// A rotation message verifies against the admin key that signed it.
    #[test]
    fn test_key_rotation() {
        let admin = Identity::generate();
        let members = [Identity::generate(), Identity::generate()];

        // Ephemeral keys are generated on throwaway clones, leaving the members untouched.
        let member_keys: Vec<_> = members
            .iter()
            .map(|member| (member.public_key(), member.clone().generate_kx_key()))
            .collect();

        let rotation = KeyRotationMessage::new(&admin, 1, member_keys).unwrap();
        assert!(rotation.verify(&admin.public_key()).is_ok());
    }

    /// V1 is compatible with itself and survives a round trip through its `u8` wire form.
    #[test]
    fn test_protocol_version_compatibility() {
        let version = ProtocolVersion::V1;
        assert!(version.is_compatible_with(ProtocolVersion::V1));

        let wire = u8::from(version);
        let decoded = ProtocolVersion::try_from(wire).unwrap();
        assert_eq!(decoded, version);
    }
}