hashtree_cli/server/
stun.rs1#[cfg(feature = "stun")]
2mod enabled {
3 use std::net::SocketAddr;
9 use std::sync::Arc;
10 use tokio::net::UdpSocket;
11 use tracing::{debug, error, info};
12 use webrtc_stun::message::{
13 Message, MessageType, Setter, BINDING_REQUEST,
14 CLASS_SUCCESS_RESPONSE, METHOD_BINDING,
15 };
16 use webrtc_stun::xoraddr::XORMappedAddress;
17
18 pub const DEFAULT_STUN_PORT: u16 = 3478;
20
21 pub struct StunServerHandle {
23 pub addr: SocketAddr,
24 shutdown: Arc<tokio::sync::Notify>,
25 }
26
27 impl StunServerHandle {
28 pub fn shutdown(&self) {
30 self.shutdown.notify_one();
31 }
32 }
33
34 pub async fn start_stun_server(addr: SocketAddr) -> anyhow::Result<StunServerHandle> {
36 let socket = UdpSocket::bind(addr).await?;
37 let bound_addr = socket.local_addr()?;
38 let shutdown = Arc::new(tokio::sync::Notify::new());
39 let shutdown_clone = shutdown.clone();
40
41 info!("STUN server listening on {}", bound_addr);
42
43 tokio::spawn(async move {
44 run_stun_server(socket, shutdown_clone).await;
45 });
46
47 Ok(StunServerHandle {
48 addr: bound_addr,
49 shutdown,
50 })
51 }
52
53 async fn run_stun_server(socket: UdpSocket, shutdown: Arc<tokio::sync::Notify>) {
54 let mut buf = vec![0u8; 1500]; loop {
57 tokio::select! {
58 result = socket.recv_from(&mut buf) => {
59 match result {
60 Ok((len, src_addr)) => {
61 if let Err(e) = handle_stun_packet(&socket, &buf[..len], src_addr).await {
62 debug!("Error handling STUN packet from {}: {}", src_addr, e);
63 }
64 }
65 Err(e) => {
66 error!("Error receiving UDP packet: {}", e);
67 }
68 }
69 }
70 _ = shutdown.notified() => {
71 info!("STUN server shutting down");
72 break;
73 }
74 }
75 }
76 }
77
78 async fn handle_stun_packet(
79 socket: &UdpSocket,
80 data: &[u8],
81 src_addr: SocketAddr,
82 ) -> anyhow::Result<()> {
83 let mut msg = Message::new();
85 msg.raw = data.to_vec();
86
87 if let Err(e) = msg.decode() {
88 debug!("Failed to decode STUN message from {}: {}", src_addr, e);
89 return Ok(()); }
91
92 if msg.typ != BINDING_REQUEST {
94 debug!("Received non-binding STUN message type {:?} from {}", msg.typ, src_addr);
95 return Ok(());
96 }
97
98 debug!("STUN binding request from {}", src_addr);
99
100 let mut response = Message::new();
102 response.typ = MessageType {
103 method: METHOD_BINDING,
104 class: CLASS_SUCCESS_RESPONSE,
105 };
106 response.transaction_id = msg.transaction_id;
107
108 let xor_addr = XORMappedAddress {
110 ip: src_addr.ip(),
111 port: src_addr.port(),
112 };
113 xor_addr.add_to(&mut response)?;
114
115 response.encode();
117 socket.send_to(&response.raw, src_addr).await?;
118
119 debug!("STUN binding response sent to {} (mapped: {})", src_addr, src_addr);
120
121 Ok(())
122 }
123
124 #[cfg(test)]
125 mod tests {
126 use super::*;
127 use std::time::Duration;
128 use webrtc_stun::message::Getter;
129
130 #[tokio::test]
131 async fn test_stun_server_responds_to_binding_request() {
132 let server_addr: SocketAddr = "127.0.0.1:0".parse().unwrap();
134 let handle = start_stun_server(server_addr).await.unwrap();
135 let server_addr = handle.addr;
136
137 tokio::time::sleep(Duration::from_millis(50)).await;
139
140 let client = UdpSocket::bind("127.0.0.1:0").await.unwrap();
142 let client_addr = client.local_addr().unwrap();
143
144 let mut request = Message::new();
146 request.typ = BINDING_REQUEST;
147 request.new_transaction_id().expect("Failed to generate transaction ID");
148 request.encode();
149
150 client.send_to(&request.raw, server_addr).await.unwrap();
152
153 let mut buf = vec![0u8; 1500];
155 let result = tokio::time::timeout(
156 Duration::from_secs(2),
157 client.recv_from(&mut buf)
158 ).await;
159
160 let (len, _) = result.expect("Timeout waiting for response").unwrap();
161
162 let mut response = Message::new();
164 response.raw = buf[..len].to_vec();
165 response.decode().expect("Failed to decode response");
166
167 assert_eq!(response.typ.method, METHOD_BINDING);
169 assert_eq!(response.typ.class, CLASS_SUCCESS_RESPONSE);
170 assert_eq!(response.transaction_id, request.transaction_id);
171
172 let mut xor_addr = XORMappedAddress::default();
174 xor_addr.get_from(&response).expect("Failed to get XOR-MAPPED-ADDRESS");
175
176 assert_eq!(xor_addr.ip, client_addr.ip());
178 assert_eq!(xor_addr.port, client_addr.port());
179
180 handle.shutdown();
182 }
183 }
184}
185
186#[cfg(feature = "stun")]
187pub use enabled::*;
188
189#[cfg(not(feature = "stun"))]
190mod disabled {
191 use std::net::SocketAddr;
192
193 pub const DEFAULT_STUN_PORT: u16 = 3478;
195
196 pub struct StunServerHandle {
198 pub addr: SocketAddr,
199 }
200
201 impl StunServerHandle {
202 pub fn shutdown(&self) {}
203 }
204
205 pub async fn start_stun_server(_addr: SocketAddr) -> anyhow::Result<StunServerHandle> {
206 anyhow::bail!("STUN server support disabled; rebuild with --features stun")
207 }
208}
209
210#[cfg(not(feature = "stun"))]
211pub use disabled::*;