// hashtree_cli/server/stun.rs
#[cfg(feature = "stun")]
2mod enabled {
3 use std::net::SocketAddr;
9 use std::sync::Arc;
10 use tokio::net::UdpSocket;
11 use tracing::{debug, error, info};
12 use webrtc_stun::message::{
13 Message, MessageType, Setter, BINDING_REQUEST, CLASS_SUCCESS_RESPONSE, METHOD_BINDING,
14 };
15 use webrtc_stun::xoraddr::XORMappedAddress;
16
17 pub const DEFAULT_STUN_PORT: u16 = 3478;
19
20 pub struct StunServerHandle {
22 pub addr: SocketAddr,
23 shutdown: Arc<tokio::sync::Notify>,
24 }
25
26 impl StunServerHandle {
27 pub fn shutdown(&self) {
29 self.shutdown.notify_one();
30 }
31 }
32
33 pub async fn start_stun_server(addr: SocketAddr) -> anyhow::Result<StunServerHandle> {
35 let socket = UdpSocket::bind(addr).await?;
36 let bound_addr = socket.local_addr()?;
37 let shutdown = Arc::new(tokio::sync::Notify::new());
38 let shutdown_clone = shutdown.clone();
39
40 info!("STUN server listening on {}", bound_addr);
41
42 tokio::spawn(async move {
43 run_stun_server(socket, shutdown_clone).await;
44 });
45
46 Ok(StunServerHandle {
47 addr: bound_addr,
48 shutdown,
49 })
50 }
51
52 async fn run_stun_server(socket: UdpSocket, shutdown: Arc<tokio::sync::Notify>) {
53 let mut buf = vec![0u8; 1500]; loop {
56 tokio::select! {
57 result = socket.recv_from(&mut buf) => {
58 match result {
59 Ok((len, src_addr)) => {
60 if let Err(e) = handle_stun_packet(&socket, &buf[..len], src_addr).await {
61 debug!("Error handling STUN packet from {}: {}", src_addr, e);
62 }
63 }
64 Err(e) => {
65 error!("Error receiving UDP packet: {}", e);
66 }
67 }
68 }
69 _ = shutdown.notified() => {
70 info!("STUN server shutting down");
71 break;
72 }
73 }
74 }
75 }
76
77 async fn handle_stun_packet(
78 socket: &UdpSocket,
79 data: &[u8],
80 src_addr: SocketAddr,
81 ) -> anyhow::Result<()> {
82 let mut msg = Message::new();
84 msg.raw = data.to_vec();
85
86 if let Err(e) = msg.decode() {
87 debug!("Failed to decode STUN message from {}: {}", src_addr, e);
88 return Ok(()); }
90
91 if msg.typ != BINDING_REQUEST {
93 debug!(
94 "Received non-binding STUN message type {:?} from {}",
95 msg.typ, src_addr
96 );
97 return Ok(());
98 }
99
100 debug!("STUN binding request from {}", src_addr);
101
102 let mut response = Message::new();
104 response.typ = MessageType {
105 method: METHOD_BINDING,
106 class: CLASS_SUCCESS_RESPONSE,
107 };
108 response.transaction_id = msg.transaction_id;
109
110 let xor_addr = XORMappedAddress {
112 ip: src_addr.ip(),
113 port: src_addr.port(),
114 };
115 xor_addr.add_to(&mut response)?;
116
117 response.encode();
119 socket.send_to(&response.raw, src_addr).await?;
120
121 debug!(
122 "STUN binding response sent to {} (mapped: {})",
123 src_addr, src_addr
124 );
125
126 Ok(())
127 }
128
#[cfg(test)]
mod tests {
    use super::*;
    use std::time::Duration;
    use webrtc_stun::message::Getter;

    /// End-to-end check over loopback: a binding request sent to a freshly
    /// started server must yield a binding success response with the same
    /// transaction id and the client's own address in XOR-MAPPED-ADDRESS.
    #[tokio::test]
    async fn test_stun_server_responds_to_binding_request() {
        // Port 0 lets the OS pick a free port; the handle reports the real one.
        let handle = start_stun_server("127.0.0.1:0".parse().unwrap())
            .await
            .unwrap();
        let server_addr = handle.addr;

        tokio::time::sleep(Duration::from_millis(50)).await;

        let client = UdpSocket::bind("127.0.0.1:0").await.unwrap();
        let client_addr = client.local_addr().unwrap();

        // Build and send the binding request.
        let mut request = Message::new();
        request.typ = BINDING_REQUEST;
        request
            .new_transaction_id()
            .expect("Failed to generate transaction ID");
        request.encode();
        client.send_to(&request.raw, server_addr).await.unwrap();

        // Wait (bounded) for the reply.
        let mut reply_buf = vec![0u8; 1500];
        let (reply_len, _) =
            tokio::time::timeout(Duration::from_secs(2), client.recv_from(&mut reply_buf))
                .await
                .expect("Timeout waiting for response")
                .unwrap();

        let mut response = Message::new();
        response.raw = reply_buf[..reply_len].to_vec();
        response.decode().expect("Failed to decode response");

        assert_eq!(
            response.typ,
            MessageType {
                method: METHOD_BINDING,
                class: CLASS_SUCCESS_RESPONSE,
            }
        );
        assert_eq!(response.transaction_id, request.transaction_id);

        // The reflected address must be the client's own socket address.
        let mut reflected = XORMappedAddress::default();
        reflected
            .get_from(&response)
            .expect("Failed to get XOR-MAPPED-ADDRESS");
        assert_eq!(
            (reflected.ip, reflected.port),
            (client_addr.ip(), client_addr.port())
        );

        handle.shutdown();
    }
}
191}
192
193#[cfg(feature = "stun")]
194pub use enabled::*;
195
196#[cfg(not(feature = "stun"))]
197mod disabled {
198 use std::net::SocketAddr;
199
200 pub const DEFAULT_STUN_PORT: u16 = 3478;
202
203 pub struct StunServerHandle {
205 pub addr: SocketAddr,
206 }
207
208 impl StunServerHandle {
209 pub fn shutdown(&self) {}
210 }
211
212 pub async fn start_stun_server(_addr: SocketAddr) -> anyhow::Result<StunServerHandle> {
213 anyhow::bail!("STUN server support disabled; rebuild with --features stun")
214 }
215}
216
217#[cfg(not(feature = "stun"))]
218pub use disabled::*;