1use std::net::SocketAddr;
2use std::time::Duration;
3
4use async_trait::async_trait;
5use hex;
6use serde_cbor::value::Value as CborValue;
7use std::sync::Arc;
8use tokio::net::UdpSocket;
9use tokio::time;
10use tracing::{debug, info, trace};
11
12use super::{HandshakeError, HandshakeMessage, HandshakeTransport};
13use crate::messages::{Acknowledge, ControlEnvelope};
14
15#[derive(Debug)]
17pub struct CborUdpTransport {
18 socket: Arc<UdpSocket>,
19 peer: SocketAddr,
20 max_size: usize,
21 debug_cbor: bool,
22}
23
24impl CborUdpTransport {
25 pub async fn bind(
26 local: SocketAddr,
27 peer: SocketAddr,
28 max_size: usize,
29 debug_cbor: bool,
30 ) -> Result<Self, HandshakeError> {
31 let socket = UdpSocket::bind(local)
32 .await
33 .map_err(|e| HandshakeError::Transport(e.to_string()))?;
34 socket
35 .connect(peer)
36 .await
37 .map_err(|e| HandshakeError::Transport(e.to_string()))?;
38 let bound = socket
39 .local_addr()
40 .unwrap_or_else(|_| SocketAddr::from(([0, 0, 0, 0], 0)));
41 info!(
42 "[ALPINE][HANDSHAKE][SOCKET] UDP transport bound local_addr={} peer={} max_size={}",
43 bound, peer, max_size
44 );
45 Ok(Self {
46 socket: Arc::new(socket),
47 peer,
48 max_size,
49 debug_cbor,
50 })
51 }
52
53 pub fn from_socket(
54 socket: UdpSocket,
55 peer: SocketAddr,
56 max_size: usize,
57 debug_cbor: bool,
58 ) -> Result<Self, HandshakeError> {
59 let bound = socket
60 .local_addr()
61 .unwrap_or_else(|_| SocketAddr::from(([0, 0, 0, 0], 0)));
62 info!(
63 "[ALPINE][HANDSHAKE][SOCKET] UDP transport using provided socket local_addr={} peer={} max_size={}",
64 bound, peer, max_size
65 );
66 Ok(Self {
67 socket: Arc::new(socket),
68 peer,
69 max_size,
70 debug_cbor,
71 })
72 }
73
74 pub fn local_addr(&self) -> SocketAddr {
75 self.socket
76 .local_addr()
77 .unwrap_or_else(|_| SocketAddr::from(([0, 0, 0, 0], 0)))
78 }
79
80 pub fn peer_addr(&self) -> SocketAddr {
81 self.peer
82 }
83
84 pub fn socket(&self) -> Arc<UdpSocket> {
85 self.socket.clone()
86 }
87}
88
89#[async_trait]
90impl HandshakeTransport for CborUdpTransport {
91 async fn send(&mut self, msg: HandshakeMessage) -> Result<(), HandshakeError> {
92 let bytes = serde_cbor::to_vec(&msg)
93 .map_err(|e| HandshakeError::Transport(format!("encode: {}", e)))?;
94 let local_addr = self.local_addr();
95 info!(
96 "[ALPINE][TX] msg_type={} local_addr={} remote_addr={} bytes={}",
97 message_label(&msg),
98 local_addr,
99 self.peer,
100 bytes.len()
101 );
102 trace!(peer=%self.peer, len=%bytes.len(), message=?msg, "handshake send");
103 self.socket
104 .send_to(&bytes, self.peer)
105 .await
106 .map_err(|e| HandshakeError::Transport(e.to_string()))?;
107 Ok(())
108 }
109
110 async fn recv(&mut self) -> Result<HandshakeMessage, HandshakeError> {
111 let mut buf = vec![0u8; self.max_size];
112 let (len, from) = self
113 .socket
114 .recv_from(&mut buf)
115 .await
116 .map_err(|e| HandshakeError::Transport(e.to_string()))?;
117 let local_addr = self.local_addr();
118 let preview_len = len.min(32);
119 let preview = hex::encode(&buf[..preview_len]);
120 let tail_len = len.min(16);
121 let tail = hex::encode(&buf[len.saturating_sub(tail_len)..len]);
122 info!(
123 "[ALPINE][RX] raw packet received local_addr={} from={} bytes={} buf_cap={} first32={} last16={}",
124 local_addr,
125 from,
126 len,
127 self.max_size,
128 preview,
129 tail
130 );
131 trace!(peer=%from, len=%len, "handshake raw recv");
132 if self.debug_cbor {
133 debug!(
134 "[ALPINE][HANDSHAKE][DEBUG_CBOR] raw_hex={}",
135 hex::encode(&buf[..len])
136 );
137 log_cbor_structure(&buf[..len]);
138 }
139 let mut msg = serde_cbor::from_slice(&buf[..len])
140 .map_err(|e| HandshakeError::Transport(format!("decode: {}", e)));
141 if let Err(_) = &msg {
142 if !buf.is_empty() && (buf[0] & 0xE0) == 0xA0 {
143 debug!(
144 "[ALPINE][HANDSHAKE][RX] attempting truncated CBOR map repair len={} cap={} first_byte=0x{:x}",
145 len,
146 self.max_size,
147 buf[0]
148 );
149 let mut repaired = buf[..len].to_vec();
150 repaired[0] = 0xBF; repaired.push(0xFF); msg = serde_cbor::from_slice(&repaired)
153 .map_err(|e| HandshakeError::Transport(format!("decode(repaired): {}", e)));
154 }
155 }
156 if let Ok(parsed) = &msg {
157 info!(
158 "[ALPINE][HANDSHAKE][RX][parsed] variant={} local_addr={} from={} fields={}",
159 message_label(parsed),
160 local_addr,
161 from,
162 describe_fields(parsed)
163 );
164 }
165 trace!(peer=%from, result=?msg, "handshake parsed message");
166 msg
167 }
168}
169
/// Transport wrapper that limits how long each `recv` on the inner transport
/// may wait.
#[derive(Debug)]
pub struct TimeoutTransport<T> {
    /// Transport that performs the actual I/O.
    inner: T,
    /// Maximum time a single `recv` may wait.
    recv_timeout: Duration,
}

impl<T> TimeoutTransport<T> {
    /// Wraps `inner` so that each receive gives up after `recv_timeout`.
    pub fn new(inner: T, recv_timeout: Duration) -> Self {
        TimeoutTransport {
            inner,
            recv_timeout,
        }
    }
}
185
186#[async_trait]
187impl<T> HandshakeTransport for TimeoutTransport<T>
188where
189 T: HandshakeTransport + Send,
190{
191 async fn send(&mut self, msg: HandshakeMessage) -> Result<(), HandshakeError> {
192 self.inner.send(msg).await
193 }
194
195 async fn recv(&mut self) -> Result<HandshakeMessage, HandshakeError> {
196 debug!(
197 "[ALPINE][HANDSHAKE] recv with timeout_ms={}",
198 self.recv_timeout.as_millis()
199 );
200 match time::timeout(self.recv_timeout, self.inner.recv()).await {
201 Ok(res) => res,
202 Err(_) => Err(HandshakeError::Transport("recv timeout".into())),
203 }
204 }
205}
206
/// Request/acknowledge channel that sends `ControlEnvelope`s over a
/// `HandshakeTransport` and retransmits them until they are acknowledged.
///
/// Each `send_reliable` call stamps the envelope with the next sequence
/// number. It then waits for a matching `Ack`, and the wait doubles with
/// every attempt.
#[derive(Debug)]
pub struct ReliableControlChannel<T> {
    /// Transport used for both sending envelopes and receiving acks.
    transport: T,
    /// Last sequence number handed out (advanced before each use).
    seq: u64,
    /// Attempt limit for a single envelope.
    max_attempts: u8,
    /// Wait after the first send; doubled for each subsequent attempt.
    base_timeout: Duration,
    /// Second attempt limit. The effective limit is the smaller of this and
    /// `max_attempts`.
    drop_threshold: u8,
}

impl<T> ReliableControlChannel<T> {
    /// Creates a channel over `transport` with sequence counter 0, 5 attempts,
    /// a 200 ms base timeout and a drop threshold of 5.
    pub fn new(transport: T) -> Self {
        Self {
            transport,
            seq: 0,
            max_attempts: 5,
            base_timeout: Duration::from_millis(200),
            drop_threshold: 5,
        }
    }
}
227
228impl<T> ReliableControlChannel<T>
229where
230 T: HandshakeTransport + Send,
231{
232 pub async fn send_reliable(
233 &mut self,
234 mut envelope: ControlEnvelope,
235 ) -> Result<Acknowledge, HandshakeError> {
236 self.seq = self.seq.wrapping_add(1);
237 envelope.seq = self.seq;
238
239 let mut attempt: u8 = 0;
240 loop {
241 attempt += 1;
242 self.transport
243 .send(HandshakeMessage::Control(envelope.clone()))
244 .await?;
245
246 let timeout = self
247 .base_timeout
248 .checked_mul(2u32.saturating_pow((attempt - 1) as u32))
249 .unwrap_or(self.base_timeout * 4);
250
251 match time::timeout(timeout, self.transport.recv()).await {
252 Ok(Ok(HandshakeMessage::Ack(ack))) => {
253 if ack.seq == envelope.seq && ack.ok {
254 return Ok(ack);
255 }
256 }
257 Ok(Ok(HandshakeMessage::Keepalive(_))) => {
258 attempt = 0;
260 }
261 _ => {
262 if attempt >= self.max_attempts || attempt >= self.drop_threshold {
263 return Err(HandshakeError::Transport(
264 "control channel retransmit limit exceeded".into(),
265 ));
266 }
267 }
268 }
269 }
270 }
271
272 pub fn next_seq(&mut self) -> u64 {
273 self.seq = self.seq.wrapping_add(1);
274 self.seq
275 }
276}
277
278fn message_label(msg: &HandshakeMessage) -> &'static str {
279 match msg {
280 HandshakeMessage::SessionInit(_) => "SessionInit",
281 HandshakeMessage::SessionAck(_) => "SessionAck",
282 HandshakeMessage::SessionReady(_) => "SessionReady",
283 HandshakeMessage::SessionComplete(_) => "SessionComplete",
284 HandshakeMessage::SessionEstablished(_) => "SessionEstablished",
285 HandshakeMessage::Keepalive(_) => "Keepalive",
286 HandshakeMessage::Control(_) => "Control",
287 HandshakeMessage::Ack(_) => "Ack",
288 }
289}
290
291fn log_cbor_structure(bytes: &[u8]) {
292 if let Ok(value) = serde_cbor::from_slice::<CborValue>(bytes) {
293 match value {
294 CborValue::Map(map) => {
295 debug!(
296 "[ALPINE][HANDSHAKE][DEBUG_CBOR] map_len={} entries={}",
297 map.len(),
298 map.len()
299 );
300 for (idx, (key, val)) in map.iter().enumerate() {
301 debug!(
302 "[ALPINE][HANDSHAKE][DEBUG_CBOR] entry={} key_type={} value_type={}",
303 idx,
304 describe_value(key),
305 describe_value(val)
306 );
307 }
308 }
309 other => {
310 debug!(
311 "[ALPINE][HANDSHAKE][DEBUG_CBOR] non-map top-level type={}",
312 describe_value(&other)
313 );
314 }
315 }
316 } else {
317 debug!("[ALPINE][HANDSHAKE][DEBUG_CBOR] decode failed");
318 }
319}
320
321fn describe_value(val: &CborValue) -> &'static str {
322 match val {
323 CborValue::Null => "null",
324 CborValue::Bool(_) => "bool",
325 CborValue::Integer(_) => "integer",
326 CborValue::Bytes(_) => "bytes",
327 CborValue::Text(_) => "text",
328 CborValue::Array(_) => "array",
329 CborValue::Map(_) => "map",
330 CborValue::Tag(_, _) => "tag",
331 CborValue::Float(_) => "float",
332 _ => "other",
333 }
334}
335
336fn describe_fields(msg: &HandshakeMessage) -> String {
337 match msg {
338 HandshakeMessage::SessionInit(init) => format!(
339 "session_id={} controller_nonce_len={} controller_pubkey_len={} requested={:?}",
340 init.session_id,
341 init.controller_nonce.len(),
342 init.controller_pubkey.len(),
343 init.requested
344 ),
345 HandshakeMessage::SessionAck(ack) => format!(
346 "session_id={} device_nonce_len={} device_pubkey_len={} device_identity_pubkey_len={} device_id={}",
347 ack.session_id,
348 ack.device_nonce.len(),
349 ack.device_pubkey.len(),
350 ack.device_identity_pubkey.len(),
351 ack.device_identity.device_id
352 ),
353 HandshakeMessage::SessionReady(ready) => format!(
354 "session_id={} mac_len={}",
355 ready.session_id,
356 ready.mac.len()
357 ),
358 HandshakeMessage::SessionComplete(comp) => format!(
359 "session_id={} ok={} error={:?}",
360 comp.session_id, comp.ok, comp.error
361 ),
362 HandshakeMessage::SessionEstablished(est) => format!(
363 "session_id={} controller_nonce_len={} device_nonce_len={} device_id={}",
364 est.session_id,
365 est.controller_nonce.len(),
366 est.device_nonce.len(),
367 est.device_identity.device_id
368 ),
369 HandshakeMessage::Keepalive(k) => {
370 format!("session_id={} tick_ms={}", k.session_id, k.tick_ms)
371 }
372 HandshakeMessage::Control(ctrl) => format!(
373 "session_id={} seq={} op={:?} mac_len={}",
374 ctrl.session_id,
375 ctrl.seq,
376 ctrl.op,
377 ctrl.mac.len()
378 ),
379 HandshakeMessage::Ack(ack) => format!(
380 "session_id={} seq={} ok={} detail_present={} payload_present={}",
381 ack.session_id,
382 ack.seq,
383 ack.ok,
384 ack.detail.is_some(),
385 ack.payload.is_some()
386 ),
387 }
388}