hashtree_cli/server/
stun.rs1use std::net::SocketAddr;
7use std::sync::Arc;
8use tokio::net::UdpSocket;
9use tracing::{debug, error, info};
10use webrtc_stun::message::{
11 Message, MessageType, Setter, BINDING_REQUEST,
12 CLASS_SUCCESS_RESPONSE, METHOD_BINDING,
13};
14use webrtc_stun::xoraddr::XORMappedAddress;
15
/// Default port for STUN servers (3478, per RFC 5389).
pub const DEFAULT_STUN_PORT: u16 = 3_478;
18
19pub struct StunServerHandle {
21 pub addr: SocketAddr,
22 shutdown: Arc<tokio::sync::Notify>,
23}
24
25impl StunServerHandle {
26 pub fn shutdown(&self) {
28 self.shutdown.notify_one();
29 }
30}
31
32pub async fn start_stun_server(addr: SocketAddr) -> anyhow::Result<StunServerHandle> {
34 let socket = UdpSocket::bind(addr).await?;
35 let bound_addr = socket.local_addr()?;
36 let shutdown = Arc::new(tokio::sync::Notify::new());
37 let shutdown_clone = shutdown.clone();
38
39 info!("STUN server listening on {}", bound_addr);
40
41 tokio::spawn(async move {
42 run_stun_server(socket, shutdown_clone).await;
43 });
44
45 Ok(StunServerHandle {
46 addr: bound_addr,
47 shutdown,
48 })
49}
50
51async fn run_stun_server(socket: UdpSocket, shutdown: Arc<tokio::sync::Notify>) {
52 let mut buf = vec![0u8; 1500]; loop {
55 tokio::select! {
56 result = socket.recv_from(&mut buf) => {
57 match result {
58 Ok((len, src_addr)) => {
59 if let Err(e) = handle_stun_packet(&socket, &buf[..len], src_addr).await {
60 debug!("Error handling STUN packet from {}: {}", src_addr, e);
61 }
62 }
63 Err(e) => {
64 error!("Error receiving UDP packet: {}", e);
65 }
66 }
67 }
68 _ = shutdown.notified() => {
69 info!("STUN server shutting down");
70 break;
71 }
72 }
73 }
74}
75
76async fn handle_stun_packet(
77 socket: &UdpSocket,
78 data: &[u8],
79 src_addr: SocketAddr,
80) -> anyhow::Result<()> {
81 let mut msg = Message::new();
83 msg.raw = data.to_vec();
84
85 if let Err(e) = msg.decode() {
86 debug!("Failed to decode STUN message from {}: {}", src_addr, e);
87 return Ok(()); }
89
90 if msg.typ != BINDING_REQUEST {
92 debug!("Received non-binding STUN message type {:?} from {}", msg.typ, src_addr);
93 return Ok(());
94 }
95
96 debug!("STUN binding request from {}", src_addr);
97
98 let mut response = Message::new();
100 response.typ = MessageType {
101 method: METHOD_BINDING,
102 class: CLASS_SUCCESS_RESPONSE,
103 };
104 response.transaction_id = msg.transaction_id;
105
106 let xor_addr = XORMappedAddress {
108 ip: src_addr.ip(),
109 port: src_addr.port(),
110 };
111 xor_addr.add_to(&mut response)?;
112
113 response.encode();
115 socket.send_to(&response.raw, src_addr).await?;
116
117 debug!("STUN binding response sent to {} (mapped: {})", src_addr, src_addr);
118
119 Ok(())
120}
121
#[cfg(test)]
mod tests {
    use super::*;
    use std::time::Duration;
    use webrtc_stun::message::Getter;

    /// Sends a fresh Binding request from `client` to `server_addr` and returns
    /// `(request, decoded_response)`, failing if no reply arrives within 2 s.
    async fn binding_round_trip(client: &UdpSocket, server_addr: SocketAddr) -> (Message, Message) {
        let mut request = Message::new();
        request.typ = BINDING_REQUEST;
        request
            .new_transaction_id()
            .expect("Failed to generate transaction ID");
        request.encode();

        client.send_to(&request.raw, server_addr).await.unwrap();

        let mut buf = vec![0u8; 1500];
        let (len, _) = tokio::time::timeout(Duration::from_secs(2), client.recv_from(&mut buf))
            .await
            .expect("Timeout waiting for response")
            .unwrap();

        let mut response = Message::new();
        response.raw = buf[..len].to_vec();
        response.decode().expect("Failed to decode response");

        (request, response)
    }

    #[tokio::test]
    async fn test_stun_server_responds_to_binding_request() {
        // No startup delay is needed: the socket is already bound when
        // `start_stun_server` returns, so datagrams sent now are queued by the
        // OS until the server task reads them.
        let handle = start_stun_server("127.0.0.1:0".parse().unwrap()).await.unwrap();

        let client = UdpSocket::bind("127.0.0.1:0").await.unwrap();
        let client_addr = client.local_addr().unwrap();

        let (request, response) = binding_round_trip(&client, handle.addr).await;

        assert_eq!(response.typ.method, METHOD_BINDING);
        assert_eq!(response.typ.class, CLASS_SUCCESS_RESPONSE);
        assert_eq!(response.transaction_id, request.transaction_id);

        let mut xor_addr = XORMappedAddress::default();
        xor_addr
            .get_from(&response)
            .expect("Failed to get XOR-MAPPED-ADDRESS");

        assert_eq!(xor_addr.ip, client_addr.ip());
        assert_eq!(xor_addr.port, client_addr.port());

        handle.shutdown();
    }

    #[tokio::test]
    async fn test_stun_server_ignores_garbage_and_keeps_serving() {
        let handle = start_stun_server("127.0.0.1:0".parse().unwrap()).await.unwrap();
        let client = UdpSocket::bind("127.0.0.1:0").await.unwrap();

        // Too short to be a STUN header: the server must drop it without
        // replying and without leaving its receive loop.
        client
            .send_to(b"not a stun message", handle.addr)
            .await
            .unwrap();

        let (request, response) = binding_round_trip(&client, handle.addr).await;

        assert_eq!(response.typ.method, METHOD_BINDING);
        assert_eq!(response.typ.class, CLASS_SUCCESS_RESPONSE);
        assert_eq!(response.transaction_id, request.transaction_id);

        handle.shutdown();
    }
}