alpine/handshake/
transport.rs1use std::net::SocketAddr;
2use std::time::Duration;
3
4use async_trait::async_trait;
5use tokio::net::UdpSocket;
6use tokio::time;
7
8use super::{HandshakeError, HandshakeMessage, HandshakeTransport};
9use crate::messages::{Acknowledge, ControlEnvelope};
10
11#[derive(Debug)]
13pub struct CborUdpTransport {
14 socket: UdpSocket,
15 peer: SocketAddr,
16 max_size: usize,
17}
18
19impl CborUdpTransport {
20 pub async fn bind(
21 local: SocketAddr,
22 peer: SocketAddr,
23 max_size: usize,
24 ) -> Result<Self, HandshakeError> {
25 let socket = UdpSocket::bind(local)
26 .await
27 .map_err(|e| HandshakeError::Transport(e.to_string()))?;
28 socket
29 .connect(peer)
30 .await
31 .map_err(|e| HandshakeError::Transport(e.to_string()))?;
32 Ok(Self {
33 socket,
34 peer,
35 max_size,
36 })
37 }
38}
39
40#[async_trait]
41impl HandshakeTransport for CborUdpTransport {
42 async fn send(&mut self, msg: HandshakeMessage) -> Result<(), HandshakeError> {
43 let bytes = serde_cbor::to_vec(&msg)
44 .map_err(|e| HandshakeError::Transport(format!("encode: {}", e)))?;
45 self.socket
46 .send_to(&bytes, self.peer)
47 .await
48 .map_err(|e| HandshakeError::Transport(e.to_string()))?;
49 Ok(())
50 }
51
52 async fn recv(&mut self) -> Result<HandshakeMessage, HandshakeError> {
53 let mut buf = vec![0u8; self.max_size];
54 let (len, _) = self
55 .socket
56 .recv_from(&mut buf)
57 .await
58 .map_err(|e| HandshakeError::Transport(e.to_string()))?;
59 serde_cbor::from_slice(&buf[..len])
60 .map_err(|e| HandshakeError::Transport(format!("decode: {}", e)))
61 }
62}
63
/// Wrapper around another handshake transport that bounds how long a receive may wait.
///
/// Sends are forwarded to the wrapped transport unchanged.
#[derive(Debug)]
pub struct TimeoutTransport<T> {
    /// Wrapped transport that performs the actual message exchange.
    inner: T,
    /// Longest time a single receive is allowed to take.
    recv_timeout: Duration,
}

impl<T> TimeoutTransport<T> {
    /// Wraps `inner` so that each receive gives up after `recv_timeout`.
    pub fn new(inner: T, recv_timeout: Duration) -> Self {
        TimeoutTransport {
            recv_timeout,
            inner,
        }
    }
}
79
80#[async_trait]
81impl<T> HandshakeTransport for TimeoutTransport<T>
82where
83 T: HandshakeTransport + Send,
84{
85 async fn send(&mut self, msg: HandshakeMessage) -> Result<(), HandshakeError> {
86 self.inner.send(msg).await
87 }
88
89 async fn recv(&mut self) -> Result<HandshakeMessage, HandshakeError> {
90 match time::timeout(self.recv_timeout, self.inner.recv()).await {
91 Ok(res) => res,
92 Err(_) => Err(HandshakeError::Transport("recv timeout".into())),
93 }
94 }
95}
96
/// Sends control envelopes over a handshake transport one at a time, retransmitting each
/// until the peer acknowledges it or the retry budget runs out.
pub struct ReliableControlChannel<T> {
    /// Underlying message transport.
    transport: T,
    /// Last sequence number handed out; advanced (wrapping) before each use.
    seq: u64,
    /// Cap on the number of transmissions of a single envelope.
    max_attempts: u8,
    /// Acknowledgement timeout for the first transmission; later retries back off from it.
    base_timeout: Duration,
    /// Second cap on transmissions; the smaller of this and `max_attempts` applies.
    drop_threshold: u8,
}

impl<T> ReliableControlChannel<T> {
    /// Wraps `transport` with the default retry policy: at most 5 transmissions per
    /// envelope, starting from a 200 ms acknowledgement timeout. The sequence counter
    /// starts at 0, so the first envelope sent carries sequence number 1.
    pub fn new(transport: T) -> Self {
        const DEFAULT_LIMIT: u8 = 5;
        let base_timeout = Duration::from_millis(200);
        ReliableControlChannel {
            transport,
            seq: 0,
            max_attempts: DEFAULT_LIMIT,
            base_timeout,
            drop_threshold: DEFAULT_LIMIT,
        }
    }
}
117
118impl<T> ReliableControlChannel<T>
119where
120 T: HandshakeTransport + Send,
121{
122 pub async fn send_reliable(
123 &mut self,
124 mut envelope: ControlEnvelope,
125 ) -> Result<Acknowledge, HandshakeError> {
126 self.seq = self.seq.wrapping_add(1);
127 envelope.seq = self.seq;
128
129 let mut attempt: u8 = 0;
130 loop {
131 attempt += 1;
132 self.transport
133 .send(HandshakeMessage::Control(envelope.clone()))
134 .await?;
135
136 let timeout = self
137 .base_timeout
138 .checked_mul(2u32.saturating_pow((attempt - 1) as u32))
139 .unwrap_or(self.base_timeout * 4);
140
141 match time::timeout(timeout, self.transport.recv()).await {
142 Ok(Ok(HandshakeMessage::Ack(ack))) => {
143 if ack.seq == envelope.seq && ack.ok {
144 return Ok(ack);
145 }
146 }
147 Ok(Ok(HandshakeMessage::Keepalive(_))) => {
148 attempt = 0;
150 }
151 _ => {
152 if attempt >= self.max_attempts || attempt >= self.drop_threshold {
153 return Err(HandshakeError::Transport(
154 "control channel retransmit limit exceeded".into(),
155 ));
156 }
157 }
158 }
159 }
160 }
161
162 pub fn next_seq(&mut self) -> u64 {
163 self.seq = self.seq.wrapping_add(1);
164 self.seq
165 }
166}