1use po_crypto::aead::SessionCipher;
11use po_crypto::exchange::EphemeralKeypair;
12use po_crypto::identity::Identity;
13use po_transport::traits::AsyncFrameTransport;
14use po_wire::{FrameHeader, FrameType};
15
16use crate::framer::{Framer, FramerError};
17use crate::message::{HandshakeComplete, HandshakeInit, HandshakeReply};
18
19use ed25519_dalek::{Signature, VerifyingKey};
20use std::time::{SystemTime, UNIX_EPOCH};
21
/// Session state produced by a successful handshake.
///
/// Returned by both `perform_handshake_initiator` and `perform_handshake_responder`
/// once the peer's Ed25519 signature has been verified and the session key derived.
pub struct HandshakeResult {
    /// Cipher (from `po_crypto::aead`) keyed with the session key derived during the
    /// handshake.
    pub cipher: SessionCipher,
    /// The peer's Ed25519 identity public key, exactly as it appeared in the peer's
    /// handshake message.
    pub peer_pubkey: [u8; 32],
    /// Node identifier computed from the peer's verified Ed25519 key via
    /// `NodeId::from_public_key`.
    pub peer_node_id: po_crypto::identity::NodeId,
}
31
32pub async fn perform_handshake_initiator(
34 identity: &Identity,
35 transport: &mut dyn AsyncFrameTransport,
36 framer: &mut Framer,
37) -> Result<HandshakeResult, HandshakeError> {
38 let ephemeral = EphemeralKeypair::generate();
40 let our_eph_pub = ephemeral.public_bytes();
41
42 let timestamp = now_millis();
44 let mut sign_data = Vec::with_capacity(41);
45 sign_data.push(1u8); sign_data.extend_from_slice(&our_eph_pub);
47 sign_data.extend_from_slice(×tamp.to_le_bytes());
48
49 let signature = identity.sign(&sign_data);
50
51 let init = HandshakeInit {
53 version: 1,
54 ed25519_pubkey: identity.public_key_bytes(),
55 x25519_ephemeral: our_eph_pub,
56 timestamp,
57 signature: signature.to_bytes().to_vec(),
58 };
59 let payload =
60 bincode::serialize(&init).map_err(|e| HandshakeError::Serialization(e.to_string()))?;
61 let header = FrameHeader {
62 frame_type: FrameType::HandshakeInit,
63 flags: po_wire::FrameFlags::default(),
64 channel_id: 0,
65 stream_id: 0,
66 payload_len: payload.len() as u64,
67 };
68 framer
69 .write_frame(transport, &header, &payload)
70 .await
71 .map_err(HandshakeError::Framer)?;
72
73 let (reply_header, reply_payload) = framer
75 .read_frame(transport)
76 .await
77 .map_err(HandshakeError::Framer)?
78 .ok_or(HandshakeError::ConnectionClosed)?;
79
80 if reply_header.frame_type != FrameType::HandshakeReply {
81 return Err(HandshakeError::UnexpectedFrame(reply_header.frame_type));
82 }
83
84 let reply: HandshakeReply = bincode::deserialize(&reply_payload)
85 .map_err(|e| HandshakeError::Serialization(e.to_string()))?;
86
87 let peer_verifying =
89 VerifyingKey::from_bytes(&reply.ed25519_pubkey).map_err(|_| HandshakeError::InvalidKey)?;
90 let mut verify_data = Vec::with_capacity(64);
91 verify_data.extend_from_slice(&our_eph_pub);
92 verify_data.extend_from_slice(&reply.x25519_ephemeral);
93
94 let peer_sig = Signature::from_bytes(
95 reply
96 .signature
97 .as_slice()
98 .try_into()
99 .map_err(|_| HandshakeError::InvalidSignature)?,
100 );
101 if !Identity::verify(&peer_verifying, &verify_data, &peer_sig) {
102 return Err(HandshakeError::InvalidSignature);
103 }
104
105 let context = build_session_context(&identity.public_key_bytes(), &reply.ed25519_pubkey);
107 let session_key = ephemeral
108 .derive_session_key(&reply.x25519_ephemeral, &context)
109 .map_err(|e| HandshakeError::KeyDerivation(e.to_string()))?;
110
111 let mut cipher = SessionCipher::new(session_key.as_bytes());
113 let confirmation = cipher
114 .encrypt(b"PO_READY", b"handshake-complete")
115 .map_err(|e| HandshakeError::Encryption(e.to_string()))?;
116
117 let complete = HandshakeComplete { confirmation };
118 let complete_payload =
119 bincode::serialize(&complete).map_err(|e| HandshakeError::Serialization(e.to_string()))?;
120 let complete_header = FrameHeader {
121 frame_type: FrameType::HandshakeComplete,
122 flags: po_wire::FrameFlags::default(),
123 channel_id: 0,
124 stream_id: 0,
125 payload_len: complete_payload.len() as u64,
126 };
127 framer
128 .write_frame(transport, &complete_header, &complete_payload)
129 .await
130 .map_err(HandshakeError::Framer)?;
131
132 let peer_node_id = po_crypto::identity::NodeId::from_public_key(&peer_verifying);
133
134 Ok(HandshakeResult {
135 cipher,
136 peer_pubkey: reply.ed25519_pubkey,
137 peer_node_id,
138 })
139}
140
141pub async fn perform_handshake_responder(
143 identity: &Identity,
144 transport: &mut dyn AsyncFrameTransport,
145 framer: &mut Framer,
146) -> Result<HandshakeResult, HandshakeError> {
147 let (init_header, init_payload) = framer
149 .read_frame(transport)
150 .await
151 .map_err(HandshakeError::Framer)?
152 .ok_or(HandshakeError::ConnectionClosed)?;
153
154 if init_header.frame_type != FrameType::HandshakeInit {
155 return Err(HandshakeError::UnexpectedFrame(init_header.frame_type));
156 }
157
158 let init: HandshakeInit = bincode::deserialize(&init_payload)
159 .map_err(|e| HandshakeError::Serialization(e.to_string()))?;
160
161 if init.version != 1 {
162 return Err(HandshakeError::UnsupportedVersion(init.version));
163 }
164
165 let peer_verifying =
167 VerifyingKey::from_bytes(&init.ed25519_pubkey).map_err(|_| HandshakeError::InvalidKey)?;
168 let mut verify_data = Vec::with_capacity(41);
169 verify_data.push(init.version);
170 verify_data.extend_from_slice(&init.x25519_ephemeral);
171 verify_data.extend_from_slice(&init.timestamp.to_le_bytes());
172
173 let peer_sig = Signature::from_bytes(
174 init.signature
175 .as_slice()
176 .try_into()
177 .map_err(|_| HandshakeError::InvalidSignature)?,
178 );
179 if !Identity::verify(&peer_verifying, &verify_data, &peer_sig) {
180 return Err(HandshakeError::InvalidSignature);
181 }
182
183 let now = now_millis();
185 let drift = now.abs_diff(init.timestamp);
186 if drift > 30_000 {
187 return Err(HandshakeError::TimestampExpired);
188 }
189
190 let ephemeral = EphemeralKeypair::generate();
192 let our_eph_pub = ephemeral.public_bytes();
193
194 let mut sign_data = Vec::with_capacity(64);
196 sign_data.extend_from_slice(&init.x25519_ephemeral);
197 sign_data.extend_from_slice(&our_eph_pub);
198 let signature = identity.sign(&sign_data);
199
200 let reply = HandshakeReply {
202 ed25519_pubkey: identity.public_key_bytes(),
203 x25519_ephemeral: our_eph_pub,
204 signature: signature.to_bytes().to_vec(),
205 };
206 let payload =
207 bincode::serialize(&reply).map_err(|e| HandshakeError::Serialization(e.to_string()))?;
208 let header = FrameHeader {
209 frame_type: FrameType::HandshakeReply,
210 flags: po_wire::FrameFlags::default(),
211 channel_id: 0,
212 stream_id: 0,
213 payload_len: payload.len() as u64,
214 };
215 framer
216 .write_frame(transport, &header, &payload)
217 .await
218 .map_err(HandshakeError::Framer)?;
219
220 let context = build_session_context(&init.ed25519_pubkey, &identity.public_key_bytes());
222 let session_key = ephemeral
223 .derive_session_key(&init.x25519_ephemeral, &context)
224 .map_err(|e| HandshakeError::KeyDerivation(e.to_string()))?;
225 let cipher = SessionCipher::new(session_key.as_bytes());
226
227 let (complete_header, complete_payload) = framer
229 .read_frame(transport)
230 .await
231 .map_err(HandshakeError::Framer)?
232 .ok_or(HandshakeError::ConnectionClosed)?;
233
234 if complete_header.frame_type != FrameType::HandshakeComplete {
235 return Err(HandshakeError::UnexpectedFrame(complete_header.frame_type));
236 }
237
238 let complete: HandshakeComplete = bincode::deserialize(&complete_payload)
239 .map_err(|e| HandshakeError::Serialization(e.to_string()))?;
240
241 let decrypted = cipher
243 .decrypt(&complete.confirmation, b"handshake-complete")
244 .map_err(|_| HandshakeError::ConfirmationFailed)?;
245
246 if decrypted != b"PO_READY" {
247 return Err(HandshakeError::ConfirmationFailed);
248 }
249
250 let peer_node_id = po_crypto::identity::NodeId::from_public_key(&peer_verifying);
251
252 Ok(HandshakeResult {
253 cipher,
254 peer_pubkey: init.ed25519_pubkey,
255 peer_node_id,
256 })
257}
258
/// Builds the key-derivation context `b"po-v1-" || initiator_pubkey || responder_pubkey`.
///
/// Both peers must pass the identity keys in the same (initiator, responder) order so
/// that they produce identical context bytes. The result is always 6 + 32 + 32 = 70
/// bytes long.
fn build_session_context(initiator_pubkey: &[u8; 32], responder_pubkey: &[u8; 32]) -> Vec<u8> {
    const LABEL: &[u8] = b"po-v1-";
    [LABEL, initiator_pubkey.as_slice(), responder_pubkey.as_slice()].concat()
}
269
/// Current wall-clock time in whole milliseconds since the Unix epoch.
///
/// A system clock set before the epoch yields `0` rather than an error. The `u128`
/// millisecond count is narrowed with `as u64`; that only wraps after roughly 584
/// million years.
fn now_millis() -> u64 {
    SystemTime::now()
        .duration_since(UNIX_EPOCH)
        .map(|elapsed| elapsed.as_millis() as u64)
        .unwrap_or(0)
}
276
277#[derive(Debug)]
279pub enum HandshakeError {
280 Framer(FramerError),
281 Serialization(String),
282 InvalidSignature,
283 InvalidKey,
284 UnsupportedVersion(u8),
285 TimestampExpired,
286 KeyDerivation(String),
287 Encryption(String),
288 ConfirmationFailed,
289 ConnectionClosed,
290 UnexpectedFrame(FrameType),
291}
292
293impl std::fmt::Display for HandshakeError {
294 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
295 match self {
296 Self::Framer(e) => write!(f, "framer: {e}"),
297 Self::Serialization(e) => write!(f, "serialization: {e}"),
298 Self::InvalidSignature => write!(f, "invalid signature"),
299 Self::InvalidKey => write!(f, "invalid public key"),
300 Self::UnsupportedVersion(v) => write!(f, "unsupported protocol version: {v}"),
301 Self::TimestampExpired => write!(f, "handshake timestamp expired"),
302 Self::KeyDerivation(e) => write!(f, "key derivation: {e}"),
303 Self::Encryption(e) => write!(f, "encryption: {e}"),
304 Self::ConfirmationFailed => write!(f, "handshake confirmation failed"),
305 Self::ConnectionClosed => write!(f, "connection closed during handshake"),
306 Self::UnexpectedFrame(t) => write!(f, "unexpected frame type: {t}"),
307 }
308 }
309}
310
311impl std::error::Error for HandshakeError {}