Skip to main content

vcl_protocol/
connection.rs

1use crate::packet::{VCLPacket, PacketType};
2use crate::crypto::{KeyPair, encrypt_payload, decrypt_payload};
3use crate::handshake::{HandshakeMessage, create_client_hello, process_client_hello, process_server_hello};
4use crate::error::VCLError;
5use crate::event::VCLEvent;
6use ed25519_dalek::SigningKey;
7use x25519_dalek::{EphemeralSecret, PublicKey};
8use rand::rngs::OsRng;
9use tokio::net::UdpSocket;
10use tokio::sync::mpsc;
11use std::net::SocketAddr;
12use std::collections::HashSet;
13use std::time::Instant;
14
15pub struct VCLConnection {
16    socket: UdpSocket,
17    keypair: KeyPair,
18    send_sequence: u64,
19    send_hash: Vec<u8>,
20    recv_hash: Vec<u8>,
21    last_sequence: u64,
22    seen_nonces: HashSet<[u8; 24]>,
23    peer_addr: Option<SocketAddr>,
24    peer_public_key: Option<Vec<u8>>,
25    shared_secret: Option<[u8; 32]>,
26    #[allow(dead_code)]
27    is_server: bool,
28    closed: bool,
29    last_activity: Instant,
30    timeout_secs: u64,
31    event_tx: Option<mpsc::Sender<VCLEvent>>,
32    ping_sent_at: Option<Instant>,
33}
34
35impl VCLConnection {
36    pub async fn bind(addr: &str) -> Result<Self, VCLError> {
37        let socket = UdpSocket::bind(addr).await?;
38        Ok(VCLConnection {
39            socket,
40            keypair: KeyPair::generate(),
41            send_sequence: 0,
42            send_hash: vec![0; 32],
43            recv_hash: vec![0; 32],
44            last_sequence: 0,
45            seen_nonces: HashSet::new(),
46            peer_addr: None,
47            peer_public_key: None,
48            shared_secret: None,
49            is_server: false,
50            closed: false,
51            last_activity: Instant::now(),
52            timeout_secs: 60,
53            event_tx: None,
54            ping_sent_at: None,
55        })
56    }
57
58    // ─── Events ──────────────────────────────────────────────────────────────
59
60    /// Subscribe to connection events. Returns an async receiver channel.
61    /// Call before connect() / accept_handshake() to catch Connected event.
62    pub fn subscribe(&mut self) -> mpsc::Receiver<VCLEvent> {
63        let (tx, rx) = mpsc::channel(64);
64        self.event_tx = Some(tx);
65        rx
66    }
67
68    fn emit(&self, event: VCLEvent) {
69        if let Some(tx) = &self.event_tx {
70            let _ = tx.try_send(event);
71        }
72    }
73
74    // ─── Configuration ────────────────────────────────────────────────────────
75
76    pub fn set_timeout(&mut self, secs: u64) {
77        self.timeout_secs = secs;
78    }
79
80    pub fn get_timeout(&self) -> u64 {
81        self.timeout_secs
82    }
83
84    pub fn last_activity(&self) -> Instant {
85        self.last_activity
86    }
87
88    pub fn set_shared_key(&mut self, private_key: &[u8]) {
89        let key_bytes: &[u8; 32] = private_key.try_into().unwrap();
90        let signing_key = SigningKey::from_bytes(key_bytes);
91        let verifying_key = signing_key.verifying_key();
92        self.keypair.private_key = private_key.to_vec();
93        self.keypair.public_key = verifying_key.to_bytes().to_vec();
94    }
95
96    // ─── Handshake ────────────────────────────────────────────────────────────
97
98    pub async fn connect(&mut self, addr: &str) -> Result<(), VCLError> {
99        let parsed: SocketAddr = addr.parse()?;
100        self.peer_addr = Some(parsed);
101
102        let (hello_msg, ephemeral) = create_client_hello();
103        let hello_bytes = bincode::serialize(&hello_msg)?;
104        self.socket.send_to(&hello_bytes, parsed).await?;
105
106        let mut buf = vec![0u8; 65535];
107        let (len, _) = self.socket.recv_from(&mut buf).await?;
108        let server_hello: HandshakeMessage = bincode::deserialize(&buf[..len])?;
109
110        match server_hello {
111            HandshakeMessage::ServerHello { public_key } => {
112                let shared = process_server_hello(ephemeral, public_key)
113                    .ok_or_else(|| VCLError::HandshakeFailed("Key exchange failed".to_string()))?;
114                self.shared_secret = Some(shared);
115            }
116            _ => return Err(VCLError::ExpectedServerHello),
117        }
118
119        self.last_activity = Instant::now();
120        self.emit(VCLEvent::Connected);
121        Ok(())
122    }
123
124    pub async fn accept_handshake(&mut self) -> Result<(), VCLError> {
125        let ephemeral = EphemeralSecret::random_from_rng(OsRng);
126
127        let mut buf = vec![0u8; 65535];
128        let (len, addr) = self.socket.recv_from(&mut buf).await?;
129        self.peer_addr = Some(addr);
130
131        let client_hello: HandshakeMessage = bincode::deserialize(&buf[..len])?;
132
133        match client_hello {
134            HandshakeMessage::ClientHello { public_key } => {
135                let (server_hello, shared) = process_client_hello(ephemeral, public_key);
136                let hello_bytes = bincode::serialize(&server_hello)?;
137                self.socket.send_to(&hello_bytes, addr).await?;
138                self.shared_secret = Some(
139                    shared.ok_or_else(|| VCLError::HandshakeFailed("Key exchange failed".to_string()))?
140                );
141                self.is_server = true;
142            }
143            _ => return Err(VCLError::ExpectedClientHello),
144        }
145
146        self.last_activity = Instant::now();
147        self.emit(VCLEvent::Connected);
148        Ok(())
149    }
150
151    // ─── Internal send ────────────────────────────────────────────────────────
152
153    async fn send_internal(&mut self, data: &[u8], packet_type: PacketType) -> Result<(), VCLError> {
154        let key = self.shared_secret.ok_or(VCLError::NoSharedSecret)?;
155        let (encrypted_payload, nonce) = encrypt_payload(data, &key)?;
156
157        let mut packet = VCLPacket::new_typed(
158            self.send_sequence,
159            self.send_hash.clone(),
160            encrypted_payload,
161            nonce,
162            packet_type,
163        );
164        packet.sign(&self.keypair.private_key)?;
165
166        let serialized = packet.serialize();
167        let addr = self.peer_addr.ok_or(VCLError::NoPeerAddress)?;
168        self.socket.send_to(&serialized, addr).await?;
169
170        self.send_hash = packet.compute_hash();
171        self.send_sequence += 1;
172        self.last_activity = Instant::now();
173        Ok(())
174    }
175
176    // ─── Public send ──────────────────────────────────────────────────────────
177
178    pub async fn send(&mut self, data: &[u8]) -> Result<(), VCLError> {
179        if self.closed { return Err(VCLError::ConnectionClosed); }
180        self.check_timeout()?;
181        self.send_internal(data, PacketType::Data).await
182    }
183
184    // ─── Ping / Heartbeat ─────────────────────────────────────────────────────
185
186    /// Send a ping to the peer. The pong reply is handled automatically inside
187    /// recv() — subscribe to events to receive PongReceived { latency }.
188    /// You must keep calling recv() for the pong to be processed.
189    pub async fn ping(&mut self) -> Result<(), VCLError> {
190        if self.closed { return Err(VCLError::ConnectionClosed); }
191        self.check_timeout()?;
192        self.ping_sent_at = Some(Instant::now());
193        self.send_internal(&[], PacketType::Ping).await
194    }
195
196    async fn handle_ping(&mut self) -> Result<(), VCLError> {
197        self.send_internal(&[], PacketType::Pong).await?;
198        self.emit(VCLEvent::PingReceived);
199        Ok(())
200    }
201
202    fn handle_pong(&mut self) {
203        if let Some(sent_at) = self.ping_sent_at.take() {
204            self.emit(VCLEvent::PongReceived { latency: sent_at.elapsed() });
205        }
206    }
207
208    // ─── Key Rotation ──────────────────────────────────────────────────────────
209
210    /// Initiate a key rotation. Generates a new X25519 ephemeral key pair,
211    /// sends the public key to the peer, and waits for the peer's response.
212    /// Both sides atomically switch to the new shared secret.
213    /// The peer must be in an active recv() loop to handle the rotation.
214    pub async fn rotate_keys(&mut self) -> Result<(), VCLError> {
215        if self.closed { return Err(VCLError::ConnectionClosed); }
216        self.check_timeout()?;
217
218        let our_ephemeral = EphemeralSecret::random_from_rng(OsRng);
219        let our_public = PublicKey::from(&our_ephemeral);
220
221        self.send_internal(&our_public.to_bytes(), PacketType::KeyRotation).await?;
222
223        let mut buf = vec![0u8; 65535];
224        let (len, _) = self.socket.recv_from(&mut buf).await?;
225        let packet = VCLPacket::deserialize(&buf[..len])?;
226
227        if self.seen_nonces.contains(&packet.nonce) {
228            return Err(VCLError::ReplayDetected("Duplicate nonce in key rotation".to_string()));
229        }
230        self.seen_nonces.insert(packet.nonce);
231
232        if !packet.validate_chain(&self.recv_hash) {
233            return Err(VCLError::ChainValidationFailed);
234        }
235
236        let verify_key = self.peer_public_key.as_ref().unwrap_or(&self.keypair.public_key);
237        if !packet.verify(verify_key)? {
238            return Err(VCLError::SignatureInvalid);
239        }
240
241        self.recv_hash = packet.compute_hash();
242        self.last_sequence = packet.sequence;
243        self.last_activity = Instant::now();
244
245        let old_key = self.shared_secret.ok_or(VCLError::NoSharedSecret)?;
246        let decrypted = decrypt_payload(&packet.payload, &old_key, &packet.nonce)?;
247
248        if packet.packet_type != PacketType::KeyRotation {
249            return Err(VCLError::HandshakeFailed("Expected KeyRotation response".to_string()));
250        }
251        if decrypted.len() != 32 {
252            return Err(VCLError::InvalidPacket("KeyRotation payload must be 32 bytes".to_string()));
253        }
254
255        let their_bytes: [u8; 32] = decrypted
256            .try_into()
257            .map_err(|_| VCLError::InvalidPacket("Invalid peer pubkey".to_string()))?;
258        let their_pubkey = PublicKey::from(their_bytes);
259        let new_secret = our_ephemeral.diffie_hellman(&their_pubkey);
260        self.shared_secret = Some(new_secret.to_bytes());
261        self.emit(VCLEvent::KeyRotated);
262        Ok(())
263    }
264
265    async fn handle_key_rotation_request(&mut self, their_pubkey_bytes: &[u8]) -> Result<(), VCLError> {
266        if their_pubkey_bytes.len() != 32 {
267            return Err(VCLError::InvalidPacket("KeyRotation payload must be 32 bytes".to_string()));
268        }
269
270        let their_bytes: [u8; 32] = their_pubkey_bytes
271            .try_into()
272            .map_err(|_| VCLError::InvalidPacket("Invalid peer pubkey".to_string()))?;
273        let their_pubkey = PublicKey::from(their_bytes);
274
275        let our_ephemeral = EphemeralSecret::random_from_rng(OsRng);
276        let our_public = PublicKey::from(&our_ephemeral);
277        let new_secret = our_ephemeral.diffie_hellman(&their_pubkey);
278
279        self.send_internal(&our_public.to_bytes(), PacketType::KeyRotation).await?;
280
281        self.shared_secret = Some(new_secret.to_bytes());
282        self.emit(VCLEvent::KeyRotated);
283        Ok(())
284    }
285
286    // ─── Receive ──────────────────────────────────────────────────────────────
287
288    pub async fn recv(&mut self) -> Result<VCLPacket, VCLError> {
289        if self.closed { return Err(VCLError::ConnectionClosed); }
290
291        loop {
292            self.check_timeout()?;
293
294            let mut buf = vec![0u8; 65535];
295            let (len, addr) = self.socket.recv_from(&mut buf).await?;
296            if self.peer_addr.is_none() {
297                self.peer_addr = Some(addr);
298            }
299
300            let packet = VCLPacket::deserialize(&buf[..len])?;
301
302            if self.last_sequence > 0 && packet.sequence <= self.last_sequence {
303                return Err(VCLError::ReplayDetected("Old sequence number".to_string()));
304            }
305            if self.seen_nonces.contains(&packet.nonce) {
306                return Err(VCLError::ReplayDetected("Duplicate nonce".to_string()));
307            }
308            self.seen_nonces.insert(packet.nonce);
309            if self.seen_nonces.len() > 1000 {
310                self.seen_nonces.clear();
311            }
312
313            if !packet.validate_chain(&self.recv_hash) {
314                return Err(VCLError::ChainValidationFailed);
315            }
316
317            let verify_key = self.peer_public_key.as_ref().unwrap_or(&self.keypair.public_key);
318            if !packet.verify(verify_key)? {
319                return Err(VCLError::SignatureInvalid);
320            }
321
322            self.recv_hash = packet.compute_hash();
323            self.last_sequence = packet.sequence;
324            self.last_activity = Instant::now();
325
326            let key = self.shared_secret.ok_or(VCLError::NoSharedSecret)?;
327            let decrypted = decrypt_payload(&packet.payload, &key, &packet.nonce)?;
328
329            match packet.packet_type {
330                PacketType::Data => {
331                    self.emit(VCLEvent::PacketReceived {
332                        sequence: packet.sequence,
333                        size: decrypted.len(),
334                    });
335                    return Ok(VCLPacket {
336                        version: packet.version,
337                        packet_type: PacketType::Data,
338                        sequence: packet.sequence,
339                        prev_hash: packet.prev_hash,
340                        nonce: packet.nonce,
341                        payload: decrypted,
342                        signature: packet.signature,
343                    });
344                }
345                PacketType::Ping => {
346                    self.handle_ping().await?;
347                }
348                PacketType::Pong => {
349                    self.handle_pong();
350                }
351                PacketType::KeyRotation => {
352                    self.handle_key_rotation_request(&decrypted).await?;
353                }
354            }
355        }
356    }
357
358    // ─── Session management ───────────────────────────────────────────────────
359
360    fn check_timeout(&self) -> Result<(), VCLError> {
361        if self.last_activity.elapsed().as_secs() > self.timeout_secs {
362            return Err(VCLError::Timeout);
363        }
364        Ok(())
365    }
366
367    pub fn close(&mut self) -> Result<(), VCLError> {
368        if self.closed {
369            return Err(VCLError::ConnectionClosed);
370        }
371        self.closed = true;
372        self.send_sequence = 0;
373        self.send_hash = vec![0; 32];
374        self.recv_hash = vec![0; 32];
375        self.last_sequence = 0;
376        self.seen_nonces.clear();
377        self.shared_secret = None;
378        self.ping_sent_at = None;
379        self.emit(VCLEvent::Disconnected);
380        Ok(())
381    }
382
383    pub fn is_closed(&self) -> bool {
384        self.closed
385    }
386
387    pub fn get_public_key(&self) -> Vec<u8> {
388        self.keypair.public_key.clone()
389    }
390
391    pub fn get_shared_secret(&self) -> Option<[u8; 32]> {
392        self.shared_secret
393    }
394}