alpine/handshake/
transport.rs

1use std::net::SocketAddr;
2use std::time::Duration;
3
4use async_trait::async_trait;
5use tokio::net::UdpSocket;
6use tokio::time;
7
8use super::{HandshakeError, HandshakeMessage, HandshakeTransport};
9use crate::messages::{Acknowledge, ControlEnvelope};
10
11/// CBOR-over-UDP transport for handshake and control-plane exchange.
12#[derive(Debug)]
13pub struct CborUdpTransport {
14    socket: UdpSocket,
15    peer: SocketAddr,
16    max_size: usize,
17}
18
19impl CborUdpTransport {
20    pub async fn bind(
21        local: SocketAddr,
22        peer: SocketAddr,
23        max_size: usize,
24    ) -> Result<Self, HandshakeError> {
25        let socket = UdpSocket::bind(local)
26            .await
27            .map_err(|e| HandshakeError::Transport(e.to_string()))?;
28        socket
29            .connect(peer)
30            .await
31            .map_err(|e| HandshakeError::Transport(e.to_string()))?;
32        Ok(Self {
33            socket,
34            peer,
35            max_size,
36        })
37    }
38}
39
40#[async_trait]
41impl HandshakeTransport for CborUdpTransport {
42    async fn send(&mut self, msg: HandshakeMessage) -> Result<(), HandshakeError> {
43        let bytes = serde_cbor::to_vec(&msg)
44            .map_err(|e| HandshakeError::Transport(format!("encode: {}", e)))?;
45        self.socket
46            .send_to(&bytes, self.peer)
47            .await
48            .map_err(|e| HandshakeError::Transport(e.to_string()))?;
49        Ok(())
50    }
51
52    async fn recv(&mut self) -> Result<HandshakeMessage, HandshakeError> {
53        let mut buf = vec![0u8; self.max_size];
54        let (len, _) = self
55            .socket
56            .recv_from(&mut buf)
57            .await
58            .map_err(|e| HandshakeError::Transport(e.to_string()))?;
59        serde_cbor::from_slice(&buf[..len])
60            .map_err(|e| HandshakeError::Transport(format!("decode: {}", e)))
61    }
62}
63
/// Decorator around another transport that bounds how long each `recv` may wait.
#[derive(Debug)]
pub struct TimeoutTransport<T> {
    /// The wrapped transport that performs the actual I/O.
    inner: T,
    /// Maximum time a single `recv` call is allowed to take.
    recv_timeout: Duration,
}

impl<T> TimeoutTransport<T> {
    /// Wraps `inner`, limiting every receive to at most `recv_timeout`.
    pub fn new(inner: T, recv_timeout: Duration) -> Self {
        TimeoutTransport {
            recv_timeout,
            inner,
        }
    }
}
79
80#[async_trait]
81impl<T> HandshakeTransport for TimeoutTransport<T>
82where
83    T: HandshakeTransport + Send,
84{
85    async fn send(&mut self, msg: HandshakeMessage) -> Result<(), HandshakeError> {
86        self.inner.send(msg).await
87    }
88
89    async fn recv(&mut self) -> Result<HandshakeMessage, HandshakeError> {
90        match time::timeout(self.recv_timeout, self.inner.recv()).await {
91            Ok(res) => res,
92            Err(_) => Err(HandshakeError::Transport("recv timeout".into())),
93        }
94    }
95}
96
/// Minimal reliability layer for control envelopes: each envelope is stamped
/// with a sequence number and retransmitted until a matching acknowledgement
/// arrives or the retry limit is reached.
#[derive(Debug)]
pub struct ReliableControlChannel<T> {
    /// Underlying transport used to send envelopes and receive replies.
    transport: T,
    /// Last sequence number handed out; advanced with wrapping arithmetic.
    seq: u64,
    /// Maximum number of transmissions attempted per envelope.
    max_attempts: u8,
    /// Wait for the first attempt; later attempts scale this up.
    base_timeout: Duration,
    /// Second cap on attempts; the send gives up once either cap is reached.
    drop_threshold: u8,
}

impl<T> ReliableControlChannel<T> {
    /// Creates a channel over `transport` with the default tuning: sequence
    /// counter at 0, up to 5 attempts, a 200 ms base timeout and a drop
    /// threshold of 5.
    pub fn new(transport: T) -> Self {
        Self {
            transport,
            seq: 0,
            max_attempts: 5,
            base_timeout: Duration::from_millis(200),
            drop_threshold: 5,
        }
    }
}
117
118impl<T> ReliableControlChannel<T>
119where
120    T: HandshakeTransport + Send,
121{
122    pub async fn send_reliable(
123        &mut self,
124        mut envelope: ControlEnvelope,
125    ) -> Result<Acknowledge, HandshakeError> {
126        self.seq = self.seq.wrapping_add(1);
127        envelope.seq = self.seq;
128
129        let mut attempt: u8 = 0;
130        loop {
131            attempt += 1;
132            self.transport
133                .send(HandshakeMessage::Control(envelope.clone()))
134                .await?;
135
136            let timeout = self
137                .base_timeout
138                .checked_mul(2u32.saturating_pow((attempt - 1) as u32))
139                .unwrap_or(self.base_timeout * 4);
140
141            match time::timeout(timeout, self.transport.recv()).await {
142                Ok(Ok(HandshakeMessage::Ack(ack))) => {
143                    if ack.seq == envelope.seq && ack.ok {
144                        return Ok(ack);
145                    }
146                }
147                Ok(Ok(HandshakeMessage::Keepalive(_))) => {
148                    // keepalive resets attempt counter
149                    attempt = 0;
150                }
151                _ => {
152                    if attempt >= self.max_attempts || attempt >= self.drop_threshold {
153                        return Err(HandshakeError::Transport(
154                            "control channel retransmit limit exceeded".into(),
155                        ));
156                    }
157                }
158            }
159        }
160    }
161
162    pub fn next_seq(&mut self) -> u64 {
163        self.seq = self.seq.wrapping_add(1);
164        self.seq
165    }
166}