1use std::error::Error;
2use std::net::SocketAddr;
3
4use async_trait::async_trait;
5use serde_cbor;
6use tokio::net::UdpSocket;
7
8use crate::crypto::X25519KeyExchange;
9use crate::handshake::{HandshakeContext, HandshakeError, HandshakeMessage, HandshakeTransport};
10use crate::messages::{CapabilitySet, DeviceIdentity};
11use crate::session::{AlnpSession, StaticKeyAuthenticator};
12use uuid::Uuid;
13
14struct UdpHandshakeTransport {
15 socket: UdpSocket,
16 peer: SocketAddr,
17 buf_size: usize,
18}
19
20impl UdpHandshakeTransport {
21 fn new(socket: UdpSocket, peer: SocketAddr, buf_size: usize) -> Self {
22 Self {
23 socket,
24 peer,
25 buf_size,
26 }
27 }
28}
29
30#[async_trait]
31impl HandshakeTransport for UdpHandshakeTransport {
32 async fn send(&mut self, msg: HandshakeMessage) -> Result<(), HandshakeError> {
33 let bytes = serde_cbor::to_vec(&msg)
34 .map_err(|e| HandshakeError::Protocol(format!("encode: {}", e)))?;
35 self.socket
36 .send_to(&bytes, self.peer)
37 .await
38 .map_err(|e| HandshakeError::Transport(e.to_string()))?;
39 Ok(())
40 }
41
42 async fn recv(&mut self) -> Result<HandshakeMessage, HandshakeError> {
43 let mut buf = vec![0u8; self.buf_size];
44 let (len, _) = self
45 .socket
46 .recv_from(&mut buf)
47 .await
48 .map_err(|e| HandshakeError::Transport(e.to_string()))?;
49 serde_cbor::from_slice(&buf[..len])
50 .map_err(|e| HandshakeError::Protocol(format!("decode: {}", e)))
51 }
52}
53
54pub fn make_identity(prefix: &str) -> DeviceIdentity {
55 DeviceIdentity {
56 device_id: Uuid::new_v4().to_string(),
57 manufacturer_id: format!("{prefix}-manu"),
58 model_id: format!("{prefix}-model"),
59 hardware_rev: "rev1".into(),
60 firmware_rev: "1.0.11".into(),
61 }
62}
63
64pub async fn run_udp_handshake() -> Result<(AlnpSession, AlnpSession), Box<dyn Error>> {
65 let controller_socket = UdpSocket::bind(("127.0.0.1", 0)).await?;
66 let node_socket = UdpSocket::bind(("127.0.0.1", 0)).await?;
67 let controller_addr = controller_socket.local_addr()?;
68 let node_addr = node_socket.local_addr()?;
69
70 let controller_task = tokio::spawn(async move {
71 let mut transport = UdpHandshakeTransport::new(controller_socket, node_addr, 4096);
72 AlnpSession::connect(
73 make_identity("controller"),
74 CapabilitySet::default(),
75 StaticKeyAuthenticator::default(),
76 X25519KeyExchange::new(),
77 HandshakeContext::default(),
78 &mut transport,
79 )
80 .await
81 });
82
83 let node_task = tokio::spawn(async move {
84 let mut transport = UdpHandshakeTransport::new(node_socket, controller_addr, 4096);
85 AlnpSession::accept(
86 make_identity("node"),
87 CapabilitySet::default(),
88 StaticKeyAuthenticator::default(),
89 X25519KeyExchange::new(),
90 HandshakeContext::default(),
91 &mut transport,
92 )
93 .await
94 });
95
96 let (controller_res, node_res) = tokio::join!(controller_task, node_task);
97 let controller_session = controller_res??;
98 let node_session = node_res??;
99 Ok((controller_session, node_session))
100}