Skip to main content

po_session/
handshake.rs

//! Cryptographic handshake protocol for PO connections.
//!
//! Implements the 3-way handshake:
//! 1. **Initiator → Responder**: `HandshakeInit` (version + Ed25519 pubkey + X25519 ephemeral +
//!    timestamp, signed with the initiator's Ed25519 key)
//! 2. **Responder → Initiator**: `HandshakeReply` (Ed25519 pubkey + X25519 ephemeral + a
//!    signature over both X25519 ephemerals)
//! 3. **Initiator → Responder**: `HandshakeComplete` (a fixed confirmation message encrypted
//!    under the new session key)
//!
//! Both sides derive the same ChaCha20-Poly1305 session key once step 2 completes; step 3 lets
//! the responder confirm that the initiator derived that same key.
9
10use po_crypto::aead::SessionCipher;
11use po_crypto::exchange::EphemeralKeypair;
12use po_crypto::identity::Identity;
13use po_transport::traits::AsyncFrameTransport;
14use po_wire::{FrameHeader, FrameType};
15
16use crate::framer::{Framer, FramerError};
17use crate::message::{HandshakeComplete, HandshakeInit, HandshakeReply};
18
19use ed25519_dalek::{Signature, VerifyingKey};
20use std::time::{SystemTime, UNIX_EPOCH};
21
22/// Result of a successful handshake.
23pub struct HandshakeResult {
24    /// The session cipher for encrypting/decrypting frames.
25    pub cipher: SessionCipher,
26    /// The peer's verified Ed25519 public key.
27    pub peer_pubkey: [u8; 32],
28    /// The peer's NodeId (SHA-256 of their pubkey).
29    pub peer_node_id: po_crypto::identity::NodeId,
30}
31
32/// Perform the handshake as the initiator (client side).
33pub async fn perform_handshake_initiator(
34    identity: &Identity,
35    transport: &mut dyn AsyncFrameTransport,
36    framer: &mut Framer,
37) -> Result<HandshakeResult, HandshakeError> {
38    // Generate ephemeral X25519 keypair
39    let ephemeral = EphemeralKeypair::generate();
40    let our_eph_pub = ephemeral.public_bytes();
41
42    // Build signed data: [version(1) || x25519_ephemeral(32) || timestamp(8)]
43    let timestamp = now_millis();
44    let mut sign_data = Vec::with_capacity(41);
45    sign_data.push(1u8); // version
46    sign_data.extend_from_slice(&our_eph_pub);
47    sign_data.extend_from_slice(&timestamp.to_le_bytes());
48
49    let signature = identity.sign(&sign_data);
50
51    // Send HandshakeInit
52    let init = HandshakeInit {
53        version: 1,
54        ed25519_pubkey: identity.public_key_bytes(),
55        x25519_ephemeral: our_eph_pub,
56        timestamp,
57        signature: signature.to_bytes().to_vec(),
58    };
59    let payload =
60        bincode::serialize(&init).map_err(|e| HandshakeError::Serialization(e.to_string()))?;
61    let header = FrameHeader {
62        frame_type: FrameType::HandshakeInit,
63        flags: po_wire::FrameFlags::default(),
64        channel_id: 0,
65        stream_id: 0,
66        payload_len: payload.len() as u64,
67    };
68    framer
69        .write_frame(transport, &header, &payload)
70        .await
71        .map_err(HandshakeError::Framer)?;
72
73    // Wait for HandshakeReply
74    let (reply_header, reply_payload) = framer
75        .read_frame(transport)
76        .await
77        .map_err(HandshakeError::Framer)?
78        .ok_or(HandshakeError::ConnectionClosed)?;
79
80    if reply_header.frame_type != FrameType::HandshakeReply {
81        return Err(HandshakeError::UnexpectedFrame(reply_header.frame_type));
82    }
83
84    let reply: HandshakeReply = bincode::deserialize(&reply_payload)
85        .map_err(|e| HandshakeError::Serialization(e.to_string()))?;
86
87    // Verify responder's signature over [initiator_x25519_pub || responder_x25519_pub]
88    let peer_verifying =
89        VerifyingKey::from_bytes(&reply.ed25519_pubkey).map_err(|_| HandshakeError::InvalidKey)?;
90    let mut verify_data = Vec::with_capacity(64);
91    verify_data.extend_from_slice(&our_eph_pub);
92    verify_data.extend_from_slice(&reply.x25519_ephemeral);
93
94    let peer_sig = Signature::from_bytes(
95        reply
96            .signature
97            .as_slice()
98            .try_into()
99            .map_err(|_| HandshakeError::InvalidSignature)?,
100    );
101    if !Identity::verify(&peer_verifying, &verify_data, &peer_sig) {
102        return Err(HandshakeError::InvalidSignature);
103    }
104
105    // Derive session key
106    let context = build_session_context(&identity.public_key_bytes(), &reply.ed25519_pubkey);
107    let session_key = ephemeral
108        .derive_session_key(&reply.x25519_ephemeral, &context)
109        .map_err(|e| HandshakeError::KeyDerivation(e.to_string()))?;
110
111    // Send HandshakeComplete with encrypted confirmation
112    let mut cipher = SessionCipher::new(session_key.as_bytes());
113    let confirmation = cipher
114        .encrypt(b"PO_READY", b"handshake-complete")
115        .map_err(|e| HandshakeError::Encryption(e.to_string()))?;
116
117    let complete = HandshakeComplete { confirmation };
118    let complete_payload =
119        bincode::serialize(&complete).map_err(|e| HandshakeError::Serialization(e.to_string()))?;
120    let complete_header = FrameHeader {
121        frame_type: FrameType::HandshakeComplete,
122        flags: po_wire::FrameFlags::default(),
123        channel_id: 0,
124        stream_id: 0,
125        payload_len: complete_payload.len() as u64,
126    };
127    framer
128        .write_frame(transport, &complete_header, &complete_payload)
129        .await
130        .map_err(HandshakeError::Framer)?;
131
132    let peer_node_id = po_crypto::identity::NodeId::from_public_key(&peer_verifying);
133
134    Ok(HandshakeResult {
135        cipher,
136        peer_pubkey: reply.ed25519_pubkey,
137        peer_node_id,
138    })
139}
140
141/// Perform the handshake as the responder (server side).
142pub async fn perform_handshake_responder(
143    identity: &Identity,
144    transport: &mut dyn AsyncFrameTransport,
145    framer: &mut Framer,
146) -> Result<HandshakeResult, HandshakeError> {
147    // Wait for HandshakeInit
148    let (init_header, init_payload) = framer
149        .read_frame(transport)
150        .await
151        .map_err(HandshakeError::Framer)?
152        .ok_or(HandshakeError::ConnectionClosed)?;
153
154    if init_header.frame_type != FrameType::HandshakeInit {
155        return Err(HandshakeError::UnexpectedFrame(init_header.frame_type));
156    }
157
158    let init: HandshakeInit = bincode::deserialize(&init_payload)
159        .map_err(|e| HandshakeError::Serialization(e.to_string()))?;
160
161    if init.version != 1 {
162        return Err(HandshakeError::UnsupportedVersion(init.version));
163    }
164
165    // Verify initiator's signature over [version || x25519_ephemeral || timestamp]
166    let peer_verifying =
167        VerifyingKey::from_bytes(&init.ed25519_pubkey).map_err(|_| HandshakeError::InvalidKey)?;
168    let mut verify_data = Vec::with_capacity(41);
169    verify_data.push(init.version);
170    verify_data.extend_from_slice(&init.x25519_ephemeral);
171    verify_data.extend_from_slice(&init.timestamp.to_le_bytes());
172
173    let peer_sig = Signature::from_bytes(
174        init.signature
175            .as_slice()
176            .try_into()
177            .map_err(|_| HandshakeError::InvalidSignature)?,
178    );
179    if !Identity::verify(&peer_verifying, &verify_data, &peer_sig) {
180        return Err(HandshakeError::InvalidSignature);
181    }
182
183    // Timestamp freshness check (allow 30 seconds drift)
184    let now = now_millis();
185    let drift = now.abs_diff(init.timestamp);
186    if drift > 30_000 {
187        return Err(HandshakeError::TimestampExpired);
188    }
189
190    // Generate our ephemeral X25519 keypair
191    let ephemeral = EphemeralKeypair::generate();
192    let our_eph_pub = ephemeral.public_bytes();
193
194    // Sign [initiator_x25519_pub || our_x25519_pub]
195    let mut sign_data = Vec::with_capacity(64);
196    sign_data.extend_from_slice(&init.x25519_ephemeral);
197    sign_data.extend_from_slice(&our_eph_pub);
198    let signature = identity.sign(&sign_data);
199
200    // Send HandshakeReply
201    let reply = HandshakeReply {
202        ed25519_pubkey: identity.public_key_bytes(),
203        x25519_ephemeral: our_eph_pub,
204        signature: signature.to_bytes().to_vec(),
205    };
206    let payload =
207        bincode::serialize(&reply).map_err(|e| HandshakeError::Serialization(e.to_string()))?;
208    let header = FrameHeader {
209        frame_type: FrameType::HandshakeReply,
210        flags: po_wire::FrameFlags::default(),
211        channel_id: 0,
212        stream_id: 0,
213        payload_len: payload.len() as u64,
214    };
215    framer
216        .write_frame(transport, &header, &payload)
217        .await
218        .map_err(HandshakeError::Framer)?;
219
220    // Derive session key
221    let context = build_session_context(&init.ed25519_pubkey, &identity.public_key_bytes());
222    let session_key = ephemeral
223        .derive_session_key(&init.x25519_ephemeral, &context)
224        .map_err(|e| HandshakeError::KeyDerivation(e.to_string()))?;
225    let cipher = SessionCipher::new(session_key.as_bytes());
226
227    // Wait for HandshakeComplete
228    let (complete_header, complete_payload) = framer
229        .read_frame(transport)
230        .await
231        .map_err(HandshakeError::Framer)?
232        .ok_or(HandshakeError::ConnectionClosed)?;
233
234    if complete_header.frame_type != FrameType::HandshakeComplete {
235        return Err(HandshakeError::UnexpectedFrame(complete_header.frame_type));
236    }
237
238    let complete: HandshakeComplete = bincode::deserialize(&complete_payload)
239        .map_err(|e| HandshakeError::Serialization(e.to_string()))?;
240
241    // Decrypt and verify confirmation
242    let decrypted = cipher
243        .decrypt(&complete.confirmation, b"handshake-complete")
244        .map_err(|_| HandshakeError::ConfirmationFailed)?;
245
246    if decrypted != b"PO_READY" {
247        return Err(HandshakeError::ConfirmationFailed);
248    }
249
250    let peer_node_id = po_crypto::identity::NodeId::from_public_key(&peer_verifying);
251
252    Ok(HandshakeResult {
253        cipher,
254        peer_pubkey: init.ed25519_pubkey,
255        peer_node_id,
256    })
257}
258
/// Builds the HKDF context that binds the session key to both identity keys.
///
/// Layout: `b"po-v1-" || initiator_pubkey(32) || responder_pubkey(32)`.
///
/// The keys are *not* sorted: the initiator's key always comes first. Both handshake roles
/// pass `(initiator, responder)` in that order, so their contexts (and therefore their
/// derived session keys) agree. Swapping the arguments produces a different context,
/// so callers must not confuse the roles.
fn build_session_context(initiator_pubkey: &[u8; 32], responder_pubkey: &[u8; 32]) -> Vec<u8> {
    // Domain-separation prefix; the "v1" tag matches the handshake protocol version.
    const DOMAIN_PREFIX: &[u8] = b"po-v1-";

    let mut ctx =
        Vec::with_capacity(DOMAIN_PREFIX.len() + initiator_pubkey.len() + responder_pubkey.len());
    ctx.extend_from_slice(DOMAIN_PREFIX);
    ctx.extend_from_slice(initiator_pubkey);
    ctx.extend_from_slice(responder_pubkey);
    ctx
}
269
/// Returns the current wall-clock time in milliseconds since the Unix epoch.
///
/// A system clock set before the epoch yields `0`. A millisecond count too large for `u64`
/// saturates to `u64::MAX` instead of being silently truncated by an `as` cast.
fn now_millis() -> u64 {
    let since_epoch = SystemTime::now()
        .duration_since(UNIX_EPOCH)
        .unwrap_or_default();
    u64::try_from(since_epoch.as_millis()).unwrap_or(u64::MAX)
}
276
277/// Handshake errors.
278#[derive(Debug)]
279pub enum HandshakeError {
280    Framer(FramerError),
281    Serialization(String),
282    InvalidSignature,
283    InvalidKey,
284    UnsupportedVersion(u8),
285    TimestampExpired,
286    KeyDerivation(String),
287    Encryption(String),
288    ConfirmationFailed,
289    ConnectionClosed,
290    UnexpectedFrame(FrameType),
291}
292
293impl std::fmt::Display for HandshakeError {
294    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
295        match self {
296            Self::Framer(e) => write!(f, "framer: {e}"),
297            Self::Serialization(e) => write!(f, "serialization: {e}"),
298            Self::InvalidSignature => write!(f, "invalid signature"),
299            Self::InvalidKey => write!(f, "invalid public key"),
300            Self::UnsupportedVersion(v) => write!(f, "unsupported protocol version: {v}"),
301            Self::TimestampExpired => write!(f, "handshake timestamp expired"),
302            Self::KeyDerivation(e) => write!(f, "key derivation: {e}"),
303            Self::Encryption(e) => write!(f, "encryption: {e}"),
304            Self::ConfirmationFailed => write!(f, "handshake confirmation failed"),
305            Self::ConnectionClosed => write!(f, "connection closed during handshake"),
306            Self::UnexpectedFrame(t) => write!(f, "unexpected frame type: {t}"),
307        }
308    }
309}
310
311impl std::error::Error for HandshakeError {}