alpine/handshake/transport.rs

1use std::net::SocketAddr;
2use std::time::Duration;
3
4use async_trait::async_trait;
5use hex;
6use serde_cbor::value::Value as CborValue;
7use std::sync::Arc;
8use tokio::net::UdpSocket;
9use tokio::time;
10use tracing::{debug, info, trace};
11
12use super::{HandshakeError, HandshakeMessage, HandshakeTransport};
13use crate::messages::{Acknowledge, ControlEnvelope};
14
15/// CBOR-over-UDP transport for handshake and control-plane exchange.
16#[derive(Debug)]
17pub struct CborUdpTransport {
18    socket: Arc<UdpSocket>,
19    peer: SocketAddr,
20    max_size: usize,
21    debug_cbor: bool,
22}
23
24impl CborUdpTransport {
25    pub async fn bind(
26        local: SocketAddr,
27        peer: SocketAddr,
28        max_size: usize,
29        debug_cbor: bool,
30    ) -> Result<Self, HandshakeError> {
31        let socket = UdpSocket::bind(local)
32            .await
33            .map_err(|e| HandshakeError::Transport(e.to_string()))?;
34        socket
35            .connect(peer)
36            .await
37            .map_err(|e| HandshakeError::Transport(e.to_string()))?;
38        let bound = socket
39            .local_addr()
40            .unwrap_or_else(|_| SocketAddr::from(([0, 0, 0, 0], 0)));
41        info!(
42            "[ALPINE][HANDSHAKE][SOCKET] UDP transport bound local_addr={} peer={} max_size={}",
43            bound, peer, max_size
44        );
45        Ok(Self {
46            socket: Arc::new(socket),
47            peer,
48            max_size,
49            debug_cbor,
50        })
51    }
52
53    pub fn from_socket(
54        socket: UdpSocket,
55        peer: SocketAddr,
56        max_size: usize,
57        debug_cbor: bool,
58    ) -> Result<Self, HandshakeError> {
59        let bound = socket
60            .local_addr()
61            .unwrap_or_else(|_| SocketAddr::from(([0, 0, 0, 0], 0)));
62        info!(
63            "[ALPINE][HANDSHAKE][SOCKET] UDP transport using provided socket local_addr={} peer={} max_size={}",
64            bound, peer, max_size
65        );
66        Ok(Self {
67            socket: Arc::new(socket),
68            peer,
69            max_size,
70            debug_cbor,
71        })
72    }
73
74    pub fn local_addr(&self) -> SocketAddr {
75        self.socket
76            .local_addr()
77            .unwrap_or_else(|_| SocketAddr::from(([0, 0, 0, 0], 0)))
78    }
79
80    pub fn peer_addr(&self) -> SocketAddr {
81        self.peer
82    }
83
84    pub fn socket(&self) -> Arc<UdpSocket> {
85        self.socket.clone()
86    }
87}
88
89#[async_trait]
90impl HandshakeTransport for CborUdpTransport {
91    async fn send(&mut self, msg: HandshakeMessage) -> Result<(), HandshakeError> {
92        let bytes = serde_cbor::to_vec(&msg)
93            .map_err(|e| HandshakeError::Transport(format!("encode: {}", e)))?;
94        let local_addr = self.local_addr();
95        info!(
96            "[ALPINE][TX] msg_type={} local_addr={} remote_addr={} bytes={}",
97            message_label(&msg),
98            local_addr,
99            self.peer,
100            bytes.len()
101        );
102        trace!(peer=%self.peer, len=%bytes.len(), message=?msg, "handshake send");
103        self.socket
104            .send_to(&bytes, self.peer)
105            .await
106            .map_err(|e| HandshakeError::Transport(e.to_string()))?;
107        Ok(())
108    }
109
110    async fn recv(&mut self) -> Result<HandshakeMessage, HandshakeError> {
111        let mut buf = vec![0u8; self.max_size];
112        let (len, from) = self
113            .socket
114            .recv_from(&mut buf)
115            .await
116            .map_err(|e| HandshakeError::Transport(e.to_string()))?;
117        let local_addr = self.local_addr();
118        let preview_len = len.min(32);
119        let preview = hex::encode(&buf[..preview_len]);
120        let tail_len = len.min(16);
121        let tail = hex::encode(&buf[len.saturating_sub(tail_len)..len]);
122        info!(
123            "[ALPINE][RX] raw packet received local_addr={} from={} bytes={} buf_cap={} first32={} last16={}",
124            local_addr,
125            from,
126            len,
127            self.max_size,
128            preview,
129            tail
130        );
131        trace!(peer=%from, len=%len, "handshake raw recv");
132        if self.debug_cbor {
133            debug!(
134                "[ALPINE][HANDSHAKE][DEBUG_CBOR] raw_hex={}",
135                hex::encode(&buf[..len])
136            );
137            log_cbor_structure(&buf[..len]);
138        }
139        let mut msg = serde_cbor::from_slice(&buf[..len])
140            .map_err(|e| HandshakeError::Transport(format!("decode: {}", e)));
141        if let Err(_) = &msg {
142            if !buf.is_empty() && (buf[0] & 0xE0) == 0xA0 {
143                debug!(
144                    "[ALPINE][HANDSHAKE][RX] attempting truncated CBOR map repair len={} cap={} first_byte=0x{:x}",
145                    len,
146                    self.max_size,
147                    buf[0]
148                );
149                let mut repaired = buf[..len].to_vec();
150                repaired[0] = 0xBF; // Indefinite-length map.
151                repaired.push(0xFF); // Break.
152                msg = serde_cbor::from_slice(&repaired)
153                    .map_err(|e| HandshakeError::Transport(format!("decode(repaired): {}", e)));
154            }
155        }
156        if let Ok(parsed) = &msg {
157            info!(
158                "[ALPINE][HANDSHAKE][RX][parsed] variant={} local_addr={} from={} fields={}",
159                message_label(parsed),
160                local_addr,
161                from,
162                describe_fields(parsed)
163            );
164        }
165        trace!(peer=%from, result=?msg, "handshake parsed message");
166        msg
167    }
168}
169
/// Decorator that bounds every `recv` on the wrapped transport by a fixed duration.
///
/// Sends are forwarded unchanged; only receives are time-limited.
#[derive(Debug)]
pub struct TimeoutTransport<T> {
    /// The transport being wrapped.
    inner: T,
    /// Longest a single `recv` may take before it fails.
    recv_timeout: Duration,
}

impl<T> TimeoutTransport<T> {
    /// Wraps `inner` so that each receive gives up after `recv_timeout`.
    pub fn new(inner: T, recv_timeout: Duration) -> Self {
        TimeoutTransport {
            recv_timeout,
            inner,
        }
    }
}
185
186#[async_trait]
187impl<T> HandshakeTransport for TimeoutTransport<T>
188where
189    T: HandshakeTransport + Send,
190{
191    async fn send(&mut self, msg: HandshakeMessage) -> Result<(), HandshakeError> {
192        self.inner.send(msg).await
193    }
194
195    async fn recv(&mut self) -> Result<HandshakeMessage, HandshakeError> {
196        debug!(
197            "[ALPINE][HANDSHAKE] recv with timeout_ms={}",
198            self.recv_timeout.as_millis()
199        );
200        match time::timeout(self.recv_timeout, self.inner.recv()).await {
201            Ok(res) => res,
202            Err(_) => Err(HandshakeError::Transport("recv timeout".into())),
203        }
204    }
205}
206
/// Minimal reliability layer for control envelopes with retransmissions and replay protection.
///
/// Each envelope sent through [`ReliableControlChannel::send_reliable`] is stamped with a fresh
/// sequence number. Only an acknowledgement carrying that sequence number completes the
/// exchange.
#[derive(Debug)]
pub struct ReliableControlChannel<T> {
    /// Underlying transport used for both sending envelopes and receiving acks.
    transport: T,
    /// Last sequence number handed out; wraps on overflow.
    seq: u64,
    /// Cap on consecutive failed transmissions before giving up.
    max_attempts: u8,
    /// Response window for the first transmission; doubled for each later attempt.
    base_timeout: Duration,
    /// Secondary cap on consecutive failed attempts. The effective limit is the smaller of
    /// this and `max_attempts`.
    drop_threshold: u8,
}

impl<T> ReliableControlChannel<T> {
    /// Default cap on consecutive failed transmissions.
    const DEFAULT_MAX_ATTEMPTS: u8 = 5;
    /// Default response window for the first transmission.
    const DEFAULT_BASE_TIMEOUT: Duration = Duration::from_millis(200);
    /// Default secondary failure threshold.
    const DEFAULT_DROP_THRESHOLD: u8 = 5;

    /// Creates a channel over `transport` with the default retry policy
    /// (5 attempts, 200 ms initial window, drop threshold 5).
    pub fn new(transport: T) -> Self {
        Self::with_retry_policy(
            transport,
            Self::DEFAULT_MAX_ATTEMPTS,
            Self::DEFAULT_BASE_TIMEOUT,
            Self::DEFAULT_DROP_THRESHOLD,
        )
    }

    /// Creates a channel over `transport` with an explicit retry policy.
    ///
    /// `base_timeout` is the response window for the first transmission. Retransmission
    /// stops after `min(max_attempts, drop_threshold)` consecutive failed attempts.
    pub fn with_retry_policy(
        transport: T,
        max_attempts: u8,
        base_timeout: Duration,
        drop_threshold: u8,
    ) -> Self {
        Self {
            transport,
            seq: 0,
            max_attempts,
            base_timeout,
            drop_threshold,
        }
    }
}
227
228impl<T> ReliableControlChannel<T>
229where
230    T: HandshakeTransport + Send,
231{
232    pub async fn send_reliable(
233        &mut self,
234        mut envelope: ControlEnvelope,
235    ) -> Result<Acknowledge, HandshakeError> {
236        self.seq = self.seq.wrapping_add(1);
237        envelope.seq = self.seq;
238
239        let mut attempt: u8 = 0;
240        loop {
241            attempt += 1;
242            self.transport
243                .send(HandshakeMessage::Control(envelope.clone()))
244                .await?;
245
246            let timeout = self
247                .base_timeout
248                .checked_mul(2u32.saturating_pow((attempt - 1) as u32))
249                .unwrap_or(self.base_timeout * 4);
250
251            match time::timeout(timeout, self.transport.recv()).await {
252                Ok(Ok(HandshakeMessage::Ack(ack))) => {
253                    if ack.seq == envelope.seq && ack.ok {
254                        return Ok(ack);
255                    }
256                }
257                Ok(Ok(HandshakeMessage::Keepalive(_))) => {
258                    // keepalive resets attempt counter
259                    attempt = 0;
260                }
261                _ => {
262                    if attempt >= self.max_attempts || attempt >= self.drop_threshold {
263                        return Err(HandshakeError::Transport(
264                            "control channel retransmit limit exceeded".into(),
265                        ));
266                    }
267                }
268            }
269        }
270    }
271
272    pub fn next_seq(&mut self) -> u64 {
273        self.seq = self.seq.wrapping_add(1);
274        self.seq
275    }
276}
277
278fn message_label(msg: &HandshakeMessage) -> &'static str {
279    match msg {
280        HandshakeMessage::SessionInit(_) => "SessionInit",
281        HandshakeMessage::SessionAck(_) => "SessionAck",
282        HandshakeMessage::SessionReady(_) => "SessionReady",
283        HandshakeMessage::SessionComplete(_) => "SessionComplete",
284        HandshakeMessage::SessionEstablished(_) => "SessionEstablished",
285        HandshakeMessage::Keepalive(_) => "Keepalive",
286        HandshakeMessage::Control(_) => "Control",
287        HandshakeMessage::Ack(_) => "Ack",
288    }
289}
290
291fn log_cbor_structure(bytes: &[u8]) {
292    if let Ok(value) = serde_cbor::from_slice::<CborValue>(bytes) {
293        match value {
294            CborValue::Map(map) => {
295                debug!(
296                    "[ALPINE][HANDSHAKE][DEBUG_CBOR] map_len={} entries={}",
297                    map.len(),
298                    map.len()
299                );
300                for (idx, (key, val)) in map.iter().enumerate() {
301                    debug!(
302                        "[ALPINE][HANDSHAKE][DEBUG_CBOR] entry={} key_type={} value_type={}",
303                        idx,
304                        describe_value(key),
305                        describe_value(val)
306                    );
307                }
308            }
309            other => {
310                debug!(
311                    "[ALPINE][HANDSHAKE][DEBUG_CBOR] non-map top-level type={}",
312                    describe_value(&other)
313                );
314            }
315        }
316    } else {
317        debug!("[ALPINE][HANDSHAKE][DEBUG_CBOR] decode failed");
318    }
319}
320
321fn describe_value(val: &CborValue) -> &'static str {
322    match val {
323        CborValue::Null => "null",
324        CborValue::Bool(_) => "bool",
325        CborValue::Integer(_) => "integer",
326        CborValue::Bytes(_) => "bytes",
327        CborValue::Text(_) => "text",
328        CborValue::Array(_) => "array",
329        CborValue::Map(_) => "map",
330        CborValue::Tag(_, _) => "tag",
331        CborValue::Float(_) => "float",
332        _ => "other",
333    }
334}
335
336fn describe_fields(msg: &HandshakeMessage) -> String {
337    match msg {
338        HandshakeMessage::SessionInit(init) => format!(
339            "session_id={} controller_nonce_len={} controller_pubkey_len={} requested={:?}",
340            init.session_id,
341            init.controller_nonce.len(),
342            init.controller_pubkey.len(),
343            init.requested
344        ),
345        HandshakeMessage::SessionAck(ack) => format!(
346            "session_id={} device_nonce_len={} device_pubkey_len={} device_identity_pubkey_len={} device_id={}",
347            ack.session_id,
348            ack.device_nonce.len(),
349            ack.device_pubkey.len(),
350            ack.device_identity_pubkey.len(),
351            ack.device_identity.device_id
352        ),
353        HandshakeMessage::SessionReady(ready) => format!(
354            "session_id={} mac_len={}",
355            ready.session_id,
356            ready.mac.len()
357        ),
358        HandshakeMessage::SessionComplete(comp) => format!(
359            "session_id={} ok={} error={:?}",
360            comp.session_id, comp.ok, comp.error
361        ),
362        HandshakeMessage::SessionEstablished(est) => format!(
363            "session_id={} controller_nonce_len={} device_nonce_len={} device_id={}",
364            est.session_id,
365            est.controller_nonce.len(),
366            est.device_nonce.len(),
367            est.device_identity.device_id
368        ),
369        HandshakeMessage::Keepalive(k) => {
370            format!("session_id={} tick_ms={}", k.session_id, k.tick_ms)
371        }
372        HandshakeMessage::Control(ctrl) => format!(
373            "session_id={} seq={} op={:?} mac_len={}",
374            ctrl.session_id,
375            ctrl.seq,
376            ctrl.op,
377            ctrl.mac.len()
378        ),
379        HandshakeMessage::Ack(ack) => format!(
380            "session_id={} seq={} ok={} detail_present={} payload_present={}",
381            ack.session_id,
382            ack.seq,
383            ack.ok,
384            ack.detail.is_some(),
385            ack.payload.is_some()
386        ),
387    }
388}