phalanx_crypto/
protocol.rs

1//! Protocol messages and handshake logic for Phalanx
2
3use crate::{
4    error::{PhalanxError, Result},
5    identity::{Identity, PublicKey},
6    crypto::{EncryptedData, derive_phalanx_key, contexts},
7};
8use ed25519_dalek::Signature;
9use x25519_dalek::PublicKey as X25519PublicKey;
10use std::time::{SystemTime, UNIX_EPOCH};
11
12#[cfg(feature = "serde")]
13use serde::{Serialize, Deserialize};
14
/// Protocol versions understood by this implementation.
///
/// Each variant carries an explicit discriminant, which is the byte produced
/// when the version is cast to `u8`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub enum ProtocolVersion {
    /// Version 1 - Initial Phalanx protocol
    V1 = 1,
}

impl ProtocolVersion {
    /// Returns the protocol version used for newly created messages.
    pub fn current() -> Self {
        ProtocolVersion::V1
    }

    /// Reports whether `self` can interoperate with `other`.
    ///
    /// Only identical versions are considered compatible for now.
    pub fn is_compatible_with(self, other: Self) -> bool {
        // Fieldless enum: comparing discriminants is the same as comparing variants.
        self as u8 == other as u8
    }
}
34
35impl TryFrom<u8> for ProtocolVersion {
36    type Error = PhalanxError;
37    
38    fn try_from(value: u8) -> Result<Self> {
39        match value {
40            1 => Ok(Self::V1),
41            _ => Err(PhalanxError::version(format!("Unsupported protocol version: {}", value))),
42        }
43    }
44}
45
46impl From<ProtocolVersion> for u8 {
47    fn from(version: ProtocolVersion) -> u8 {
48        version as u8
49    }
50}
51
/// Initial handshake message for group joining.
///
/// Bundles the sender's public key, a per-handshake X25519 public key, a
/// creation timestamp, the encrypted [`HandshakePayload`], and a signature.
/// The signature covers the version byte, the sender id, the ephemeral key,
/// the timestamp, and the encrypted payload's ciphertext, nonce and AAD hash.
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub struct HandshakeMessage {
    /// Protocol version the message was created with
    pub version: ProtocolVersion,
    /// Sender's public key; also used to verify `signature`
    pub sender_key: PublicKey,
    /// X25519 public key generated for this handshake
    pub ephemeral_key: X25519PublicKey,
    /// Creation time of the handshake, in seconds since the Unix epoch
    pub timestamp: u64,
    /// Encrypted, serialized [`HandshakePayload`]
    pub encrypted_payload: EncryptedData,
    /// Sender's signature over the fields listed in the type-level docs
    pub signature: Signature,
}
69
/// Handshake payload content.
///
/// This is the plaintext that gets serialized and encrypted into
/// `HandshakeMessage::encrypted_payload`.
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub struct HandshakePayload {
    /// Group ID being joined (32 bytes)
    pub group_id: [u8; 32],
    /// Supported capabilities, as free-form strings
    pub capabilities: Vec<String>,
    /// Client information string
    pub client_info: String,
    /// Proof of membership (if required); opaque bytes.
    /// The constructors in this file always set this to `None`.
    pub membership_proof: Option<Vec<u8>>,
    /// Encrypted group key for secure key sharing.
    /// When set by `HandshakeMessage::new_with_group_key`, this holds a
    /// JSON-serialized `EncryptedData` wrapping the group key bytes.
    pub encrypted_group_key: Option<Vec<u8>>,
}
85
/// Key rotation message for forward secrecy.
///
/// Lists a new X25519 key for each member and is signed by the identity
/// passed to `KeyRotationMessage::new` (the group admin).
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub struct KeyRotationMessage {
    /// Protocol version the message was created with
    pub version: ProtocolVersion,
    /// Rotation sequence number, supplied by the caller
    pub sequence: u64,
    /// Creation time of the rotation, in seconds since the Unix epoch
    pub timestamp: u64,
    /// New ephemeral keys for each member, as (member public key, new X25519 key) pairs
    pub member_keys: Vec<(PublicKey, X25519PublicKey)>,
    /// Signature by group admin over the sequence, timestamp and member keys
    pub signature: Signature,
}
101
/// Group membership change notification.
///
/// Plain data record describing one change to a group's member list.
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub struct MembershipChange {
    /// Kind of change being announced
    pub change_type: MembershipChangeType,
    /// Public key of the member the change applies to
    pub member_key: PublicKey,
    /// Timestamp of change
    // NOTE(review): presumably seconds since the Unix epoch like the other
    // message types in this file — confirm against the code that builds it.
    pub timestamp: u64,
    /// Admin signature
    // NOTE(review): the exact bytes covered by this signature are not defined
    // in this part of the file — confirm before relying on it.
    pub signature: Signature,
}
115
/// Types of membership changes carried by [`MembershipChange`].
#[derive(Debug, Clone, PartialEq, Eq)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub enum MembershipChangeType {
    /// Member joined the group
    Join,
    /// Member left the group
    Leave,
    /// Member was removed from the group
    Remove,
    /// Member role changed
    RoleChange,
}
129
130impl HandshakeMessage {
131    /// Create a new handshake message
132    pub fn new(
133        sender: &Identity,
134        group_id: [u8; 32],
135        capabilities: Vec<String>,
136        client_info: String,
137    ) -> Result<Self> {
138        let timestamp = SystemTime::now()
139            .duration_since(UNIX_EPOCH)
140            .map_err(|e| PhalanxError::crypto(format!("System time error: {}", e)))?
141            .as_secs();
142        
143        let sender_key = sender.public_key();
144        
145        // Generate ephemeral key for this handshake
146        let mut sender_mut = sender.clone();
147        let ephemeral_key = sender_mut.generate_kx_key();
148        
149        // Create handshake payload
150        let payload = HandshakePayload {
151            group_id,
152            capabilities,
153            client_info,
154            membership_proof: None,
155            encrypted_group_key: None,
156        };
157        
158        // Derive handshake key from sender's identity
159        let handshake_key = derive_phalanx_key(
160            &sender.id(),
161            b"PHALANX_HANDSHAKE",
162            contexts::KEY_EXCHANGE,
163        );
164        
165        // Encrypt payload
166        let payload_bytes = Self::serialize_payload(&payload)?;
167        let aad = Self::create_handshake_aad(&sender_key, &ephemeral_key, timestamp);
168        let encrypted_payload = handshake_key.encrypt(&payload_bytes, &aad)?;
169        
170        // Sign the handshake
171        let signature_data = Self::create_signature_data(
172            ProtocolVersion::current(),
173            &sender_key,
174            &ephemeral_key,
175            timestamp,
176            &encrypted_payload,
177        );
178        let signature = sender.sign(&signature_data);
179        
180        Ok(Self {
181            version: ProtocolVersion::current(),
182            sender_key,
183            ephemeral_key,
184            timestamp,
185            encrypted_payload,
186            signature,
187        })
188    }
189    
190    /// Create a handshake message with encrypted group key for secure key sharing
191    pub fn new_with_group_key(
192        sender: &mut Identity,
193        recipient_public_key: &PublicKey,
194        group_id: [u8; 32],
195        capabilities: Vec<String>,
196        client_info: String,
197        group_key: &crate::crypto::SymmetricKey,
198    ) -> Result<Self> {
199        use crate::crypto::{derive_phalanx_key, contexts};
200        
201        let timestamp = SystemTime::now()
202            .duration_since(UNIX_EPOCH)
203            .map_err(|e| PhalanxError::crypto(format!("System time error: {}", e)))?
204            .as_secs();
205        
206        let sender_key = sender.public_key();
207        
208        // Generate ephemeral key for this handshake
209        let ephemeral_key = sender.generate_kx_key();
210        
211        // Perform X25519 key exchange to get shared secret
212        let shared_secret = sender.key_exchange(&recipient_public_key.kx_public)?;
213        
214        // Derive encryption key from shared secret
215        let encryption_key = derive_phalanx_key(
216            &shared_secret,
217            b"PHALANX_GROUP_KEY",
218            contexts::KEY_EXCHANGE,
219        );
220        
221        // Encrypt the group key
222        let group_key_bytes = group_key.as_bytes();
223        let aad = b"PHALANX_GROUP_KEY_V1";
224        let encrypted_group_key_data = encryption_key.encrypt(group_key_bytes, aad)?;
225        let encrypted_group_key_bytes = serde_json::to_vec(&encrypted_group_key_data)
226            .map_err(|e| PhalanxError::crypto(format!("Group key encryption serialization failed: {}", e)))?;
227        
228        // Create handshake payload with encrypted group key
229        let payload = HandshakePayload {
230            group_id,
231            capabilities,
232            client_info,
233            membership_proof: None,
234            encrypted_group_key: Some(encrypted_group_key_bytes),
235        };
236        
237        // Derive handshake key from sender's identity
238        let handshake_key = derive_phalanx_key(
239            &sender.id(),
240            b"PHALANX_HANDSHAKE",
241            contexts::KEY_EXCHANGE,
242        );
243        
244        // Encrypt payload
245        let payload_bytes = Self::serialize_payload(&payload)?;
246        let aad = Self::create_handshake_aad(&sender_key, &ephemeral_key, timestamp);
247        let encrypted_payload = handshake_key.encrypt(&payload_bytes, &aad)?;
248        
249        // Sign the handshake
250        let signature_data = Self::create_signature_data(
251            ProtocolVersion::current(),
252            &sender_key,
253            &ephemeral_key,
254            timestamp,
255            &encrypted_payload,
256        );
257        let signature = sender.sign(&signature_data);
258        
259        Ok(Self {
260            version: ProtocolVersion::current(),
261            sender_key,
262            ephemeral_key,
263            timestamp,
264            encrypted_payload,
265            signature,
266        })
267    }
268    
269    /// Extract group key from handshake message
270    pub fn extract_group_key(&self, recipient: &mut Identity) -> Result<Option<crate::crypto::SymmetricKey>> {
271        use crate::crypto::{derive_phalanx_key, contexts};
272        
273        // First verify and decrypt the handshake payload
274        let payload = self.verify_and_decrypt()?;
275        
276        if let Some(encrypted_group_key_bytes) = payload.encrypted_group_key {
277            // Perform key exchange to get shared secret using sender's ephemeral key
278            let shared_secret = recipient.static_key_exchange(&self.ephemeral_key)?;
279            
280            // Derive decryption key from shared secret
281            let decryption_key = derive_phalanx_key(
282                &shared_secret,
283                b"PHALANX_GROUP_KEY",
284                contexts::KEY_EXCHANGE,
285            );
286            
287            // Deserialize encrypted data
288            let encrypted_group_key_data: crate::crypto::EncryptedData = 
289                serde_json::from_slice(&encrypted_group_key_bytes)
290                    .map_err(|e| PhalanxError::crypto(format!("Group key decryption deserialization failed: {}", e)))?;
291            
292            // Decrypt the group key
293            let aad = b"PHALANX_GROUP_KEY_V1";
294            let group_key_bytes = decryption_key.decrypt(&encrypted_group_key_data, aad)?;
295            
296            // Convert back to SymmetricKey
297            if group_key_bytes.len() != 32 {
298                return Err(PhalanxError::crypto("Invalid group key size"));
299            }
300            let mut key_array = [0u8; 32];
301            key_array.copy_from_slice(&group_key_bytes);
302            let group_key = crate::crypto::SymmetricKey::from_bytes(key_array)?;
303            
304            Ok(Some(group_key))
305        } else {
306            Ok(None)
307        }
308    }
309    
310    /// Verify and decrypt a handshake message
311    pub fn verify_and_decrypt(&self) -> Result<HandshakePayload> {
312        // Verify signature first
313        let signature_data = Self::create_signature_data(
314            self.version,
315            &self.sender_key,
316            &self.ephemeral_key,
317            self.timestamp,
318            &self.encrypted_payload,
319        );
320        
321        self.sender_key.verify(&signature_data, &self.signature)?;
322        
323        // Derive handshake key
324        let handshake_key = derive_phalanx_key(
325            &self.sender_key.id(),
326            b"PHALANX_HANDSHAKE",
327            contexts::KEY_EXCHANGE,
328        );
329        
330        // Decrypt payload
331        let aad = Self::create_handshake_aad(&self.sender_key, &self.ephemeral_key, self.timestamp);
332        let decrypted_bytes = handshake_key.decrypt(&self.encrypted_payload, &aad)?;
333        
334        // Deserialize payload
335        Self::deserialize_payload(&decrypted_bytes)
336    }
337    
338    /// Check if handshake is recent (within last 5 minutes)
339    pub fn is_recent(&self) -> bool {
340        if let Ok(now) = SystemTime::now().duration_since(UNIX_EPOCH) {
341            let age = now.as_secs().saturating_sub(self.timestamp);
342            age <= 300 // 5 minutes
343        } else {
344            false
345        }
346    }
347    
348    fn create_handshake_aad(sender: &PublicKey, ephemeral: &X25519PublicKey, timestamp: u64) -> Vec<u8> {
349        let mut aad = Vec::new();
350        aad.extend_from_slice(&sender.id());
351        aad.extend_from_slice(ephemeral.as_bytes());
352        aad.extend_from_slice(&timestamp.to_be_bytes());
353        aad.extend_from_slice(b"PHALANX_HANDSHAKE_V1");
354        aad
355    }
356    
357    fn create_signature_data(
358        version: ProtocolVersion,
359        sender: &PublicKey,
360        ephemeral: &X25519PublicKey,
361        timestamp: u64,
362        encrypted_payload: &EncryptedData,
363    ) -> Vec<u8> {
364        let mut data = Vec::new();
365        data.push(version.into());
366        data.extend_from_slice(&sender.id());
367        data.extend_from_slice(ephemeral.as_bytes());
368        data.extend_from_slice(&timestamp.to_be_bytes());
369        data.extend_from_slice(&encrypted_payload.ciphertext);
370        data.extend_from_slice(&encrypted_payload.nonce);
371        data.extend_from_slice(&encrypted_payload.aad_hash);
372        data.extend_from_slice(b"PHALANX_HANDSHAKE_SIG_V1");
373        data
374    }
375    
376    #[cfg(feature = "serde")]
377    fn serialize_payload(payload: &HandshakePayload) -> Result<Vec<u8>> {
378        serde_json::to_vec(payload)
379            .map_err(|e| PhalanxError::protocol(format!("Handshake payload serialization failed: {}", e)))
380    }
381    
382    #[cfg(not(feature = "serde"))]
383    fn serialize_payload(payload: &HandshakePayload) -> Result<Vec<u8>> {
384        let mut bytes = Vec::new();
385        
386        // Group ID
387        bytes.extend_from_slice(&payload.group_id);
388        
389        // Capabilities count and data
390        let cap_count = payload.capabilities.len() as u32;
391        bytes.extend_from_slice(&cap_count.to_be_bytes());
392        for cap in &payload.capabilities {
393            let cap_bytes = cap.as_bytes();
394            let cap_len = cap_bytes.len() as u32;
395            bytes.extend_from_slice(&cap_len.to_be_bytes());
396            bytes.extend_from_slice(cap_bytes);
397        }
398        
399        // Client info
400        let info_bytes = payload.client_info.as_bytes();
401        let info_len = info_bytes.len() as u32;
402        bytes.extend_from_slice(&info_len.to_be_bytes());
403        bytes.extend_from_slice(info_bytes);
404        
405        // Membership proof
406        if let Some(proof) = &payload.membership_proof {
407            bytes.push(1); // Present
408            let proof_len = proof.len() as u32;
409            bytes.extend_from_slice(&proof_len.to_be_bytes());
410            bytes.extend_from_slice(proof);
411        } else {
412            bytes.push(0); // Not present
413        }
414        
415        // Encrypted group key
416        if let Some(encrypted_key) = &payload.encrypted_group_key {
417            bytes.push(1); // Present
418            let key_len = encrypted_key.len() as u32;
419            bytes.extend_from_slice(&key_len.to_be_bytes());
420            bytes.extend_from_slice(encrypted_key);
421        } else {
422            bytes.push(0); // Not present
423        }
424        
425        Ok(bytes)
426    }
427    
428    #[cfg(feature = "serde")]
429    fn deserialize_payload(bytes: &[u8]) -> Result<HandshakePayload> {
430        serde_json::from_slice(bytes)
431            .map_err(|e| PhalanxError::protocol(format!("Handshake payload deserialization failed: {}", e)))
432    }
433    
434    #[cfg(not(feature = "serde"))]
435    fn deserialize_payload(bytes: &[u8]) -> Result<HandshakePayload> {
436        if bytes.len() < 32 + 4 {
437            return Err(PhalanxError::protocol("Invalid handshake payload"));
438        }
439        
440        let mut pos = 0;
441        
442        // Group ID
443        let mut group_id = [0u8; 32];
444        group_id.copy_from_slice(&bytes[pos..pos + 32]);
445        pos += 32;
446        
447        // Capabilities
448        let cap_count = u32::from_be_bytes([bytes[pos], bytes[pos+1], bytes[pos+2], bytes[pos+3]]) as usize;
449        pos += 4;
450        
451        let mut capabilities = Vec::new();
452        for _ in 0..cap_count {
453            if pos + 4 > bytes.len() {
454                return Err(PhalanxError::protocol("Truncated capability"));
455            }
456            
457            let cap_len = u32::from_be_bytes([bytes[pos], bytes[pos+1], bytes[pos+2], bytes[pos+3]]) as usize;
458            pos += 4;
459            
460            if pos + cap_len > bytes.len() {
461                return Err(PhalanxError::protocol("Truncated capability data"));
462            }
463            
464            let cap_str = String::from_utf8(bytes[pos..pos + cap_len].to_vec())
465                .map_err(|_| PhalanxError::protocol("Invalid UTF-8 in capability"))?;
466            capabilities.push(cap_str);
467            pos += cap_len;
468        }
469        
470        // Client info
471        if pos + 4 > bytes.len() {
472            return Err(PhalanxError::protocol("Truncated client info length"));
473        }
474        
475        let info_len = u32::from_be_bytes([bytes[pos], bytes[pos+1], bytes[pos+2], bytes[pos+3]]) as usize;
476        pos += 4;
477        
478        if pos + info_len > bytes.len() {
479            return Err(PhalanxError::protocol("Truncated client info"));
480        }
481        
482        let client_info = String::from_utf8(bytes[pos..pos + info_len].to_vec())
483            .map_err(|_| PhalanxError::protocol("Invalid UTF-8 in client info"))?;
484        pos += info_len;
485        
486        // Membership proof
487        if pos >= bytes.len() {
488            return Err(PhalanxError::protocol("Truncated membership proof marker"));
489        }
490        
491        let membership_proof = if bytes[pos] == 1 {
492            pos += 1;
493            if pos + 4 > bytes.len() {
494                return Err(PhalanxError::protocol("Truncated proof length"));
495            }
496            
497            let proof_len = u32::from_be_bytes([bytes[pos], bytes[pos+1], bytes[pos+2], bytes[pos+3]]) as usize;
498            pos += 4;
499            
500            if pos + proof_len > bytes.len() {
501                return Err(PhalanxError::protocol("Truncated proof data"));
502            }
503            
504            Some(bytes[pos..pos + proof_len].to_vec())
505        } else {
506            pos += 1;
507            None
508        };
509        
510        // Encrypted group key
511        let encrypted_group_key = if pos < bytes.len() && bytes[pos] == 1 {
512            pos += 1;
513            if pos + 4 > bytes.len() {
514                return Err(PhalanxError::protocol("Truncated encrypted group key length"));
515            }
516            
517            let key_len = u32::from_be_bytes([bytes[pos], bytes[pos+1], bytes[pos+2], bytes[pos+3]]) as usize;
518            pos += 4;
519            
520            if pos + key_len > bytes.len() {
521                return Err(PhalanxError::protocol("Truncated encrypted group key data"));
522            }
523            
524            Some(bytes[pos..pos + key_len].to_vec())
525        } else {
526            None
527        };
528        
529        Ok(HandshakePayload {
530            group_id,
531            capabilities,
532            client_info,
533            membership_proof,
534            encrypted_group_key,
535        })
536    }
537}
538
539impl KeyRotationMessage {
540    /// Create a new key rotation message
541    pub fn new(
542        admin: &Identity,
543        sequence: u64,
544        member_keys: Vec<(PublicKey, X25519PublicKey)>,
545    ) -> Result<Self> {
546        let timestamp = SystemTime::now()
547            .duration_since(UNIX_EPOCH)
548            .map_err(|e| PhalanxError::crypto(format!("System time error: {}", e)))?
549            .as_secs();
550        
551        // Sign the rotation message
552        let signature_data = Self::create_signature_data(sequence, timestamp, &member_keys);
553        let signature = admin.sign(&signature_data);
554        
555        Ok(Self {
556            version: ProtocolVersion::current(),
557            sequence,
558            timestamp,
559            member_keys,
560            signature,
561        })
562    }
563    
564    /// Verify key rotation message
565    pub fn verify(&self, admin_key: &PublicKey) -> Result<()> {
566        let signature_data = Self::create_signature_data(self.sequence, self.timestamp, &self.member_keys);
567        admin_key.verify(&signature_data, &self.signature)
568    }
569    
570    fn create_signature_data(
571        sequence: u64,
572        timestamp: u64,
573        member_keys: &[(PublicKey, X25519PublicKey)],
574    ) -> Vec<u8> {
575        let mut data = Vec::new();
576        data.push(ProtocolVersion::current().into());
577        data.extend_from_slice(&sequence.to_be_bytes());
578        data.extend_from_slice(&timestamp.to_be_bytes());
579        
580        for (pub_key, ephemeral) in member_keys {
581            data.extend_from_slice(&pub_key.id());
582            data.extend_from_slice(ephemeral.as_bytes());
583        }
584        
585        data.extend_from_slice(b"PHALANX_KEY_ROTATION_V1");
586        data
587    }
588}
589
#[cfg(test)]
mod tests {
    use super::*;

    /// A handshake round-trips through verify/decrypt, and changing a signed
    /// field afterwards is rejected.
    #[test]
    fn test_handshake_message() {
        let sender = Identity::generate();
        let group_id = [1u8; 32];
        let capabilities = vec!["phalanx/v1".to_string(), "threading".to_string()];
        let client_info = "test-client/1.0".to_string();

        let handshake = HandshakeMessage::new(
            &sender,
            group_id,
            capabilities.clone(),
            client_info.clone(),
        ).unwrap();

        let payload = handshake.verify_and_decrypt().unwrap();

        assert_eq!(payload.group_id, group_id);
        assert_eq!(payload.capabilities, capabilities);
        assert_eq!(payload.client_info, client_info);
        assert!(payload.membership_proof.is_none());
        assert!(payload.encrypted_group_key.is_none());

        // A freshly created handshake is inside the freshness window.
        assert!(handshake.is_recent());

        // The timestamp is covered by the signature, so altering it must fail.
        let mut tampered = handshake.clone();
        tampered.timestamp += 1;
        assert!(tampered.verify_and_decrypt().is_err());
    }

    /// A rotation verifies under the admin key only, and only while unmodified.
    #[test]
    fn test_key_rotation() {
        let admin = Identity::generate();
        let member1 = Identity::generate();
        let member2 = Identity::generate();

        let mut member1_clone = member1.clone();
        let mut member2_clone = member2.clone();

        let member_keys = vec![
            (member1.public_key(), member1_clone.generate_kx_key()),
            (member2.public_key(), member2_clone.generate_kx_key()),
        ];

        let rotation = KeyRotationMessage::new(&admin, 1, member_keys).unwrap();

        assert!(rotation.verify(&admin.public_key()).is_ok());

        // A key other than the signer's must not verify the message.
        assert!(rotation.verify(&member1.public_key()).is_err());

        // The sequence number is covered by the signature.
        let mut tampered = rotation.clone();
        tampered.sequence += 1;
        assert!(tampered.verify(&admin.public_key()).is_err());
    }

    /// Version compatibility and the `u8` conversions in both directions.
    #[test]
    fn test_protocol_version_compatibility() {
        let v1 = ProtocolVersion::V1;
        assert!(v1.is_compatible_with(ProtocolVersion::V1));
        assert_eq!(ProtocolVersion::current(), v1);

        let converted: u8 = v1.into();
        assert_eq!(converted, 1);
        let back: ProtocolVersion = converted.try_into().unwrap();
        assert_eq!(v1, back);

        // Unknown version bytes are rejected rather than mapped to a default.
        assert!(ProtocolVersion::try_from(0u8).is_err());
        assert!(ProtocolVersion::try_from(2u8).is_err());
    }
}