Skip to main content

hashtree_cli/server/
stun.rs

#[cfg(feature = "stun")]
mod enabled {
    //! STUN server for WebRTC NAT traversal
    //!
    //! Implements a minimal STUN server that answers Binding requests with the
    //! client's server-reflexive transport address (XOR-MAPPED-ADDRESS). Any
    //! datagram that is not a STUN Binding request is ignored.

    use std::net::{IpAddr, SocketAddr};
    use std::sync::Arc;
    use tokio::net::UdpSocket;
    use tracing::{debug, error, info};
    use webrtc_stun::message::{
        Message, MessageType, Setter, BINDING_REQUEST, CLASS_SUCCESS_RESPONSE, METHOD_BINDING,
    };
    use webrtc_stun::xoraddr::XORMappedAddress;

    /// Default STUN port (RFC 5389)
    pub const DEFAULT_STUN_PORT: u16 = 3478;

    /// Size of the receive buffer (one standard Ethernet MTU). Datagrams
    /// longer than this may be truncated by `recv_from`; STUN Binding
    /// requests are far smaller.
    const RECV_BUFFER_SIZE: usize = 1500;

    /// STUN server handle for graceful shutdown.
    ///
    /// Dropping the handle does not stop the server. The serving task keeps
    /// running until [`StunServerHandle::shutdown`] is called.
    pub struct StunServerHandle {
        /// Address the server socket is actually bound to. If port `0` was
        /// requested, this holds the OS-assigned port.
        pub addr: SocketAddr,
        shutdown: Arc<tokio::sync::Notify>,
    }

    impl StunServerHandle {
        /// Signal the server task to shut down.
        ///
        /// `Notify::notify_one` stores a permit when no task is currently
        /// waiting (e.g. while the server is busy answering a packet), so a
        /// shutdown request is not lost.
        pub fn shutdown(&self) {
            self.shutdown.notify_one();
        }
    }

    /// Start a STUN server on the specified address.
    ///
    /// Binds a UDP socket to `addr` and serves requests on a detached Tokio
    /// task. Pass port `0` to let the OS choose a port; the resulting address
    /// is available as [`StunServerHandle::addr`]. Must be called from within
    /// a Tokio runtime because it uses `tokio::spawn`.
    ///
    /// # Errors
    /// Returns an error if the socket cannot be bound or its local address
    /// cannot be queried.
    pub async fn start_stun_server(addr: SocketAddr) -> anyhow::Result<StunServerHandle> {
        let socket = UdpSocket::bind(addr).await?;
        let bound_addr = socket.local_addr()?;
        let shutdown = Arc::new(tokio::sync::Notify::new());

        info!("STUN server listening on {}", bound_addr);

        // The JoinHandle is intentionally discarded; the task ends only when
        // `shutdown` is notified.
        tokio::spawn(run_stun_server(socket, Arc::clone(&shutdown)));

        Ok(StunServerHandle {
            addr: bound_addr,
            shutdown,
        })
    }

    /// Receive loop: answer datagrams one at a time until `shutdown` fires.
    async fn run_stun_server(socket: UdpSocket, shutdown: Arc<tokio::sync::Notify>) {
        let mut buf = vec![0u8; RECV_BUFFER_SIZE];

        loop {
            tokio::select! {
                result = socket.recv_from(&mut buf) => {
                    match result {
                        Ok((len, src_addr)) => {
                            if let Err(e) = handle_stun_packet(&socket, &buf[..len], src_addr).await {
                                debug!("Error handling STUN packet from {}: {}", src_addr, e);
                            }
                        }
                        // Receive errors are logged but do not stop the server.
                        Err(e) => {
                            error!("Error receiving UDP packet: {}", e);
                        }
                    }
                }
                _ = shutdown.notified() => {
                    info!("STUN server shutting down");
                    break;
                }
            }
        }
    }

    /// Handle one datagram from `src_addr`.
    ///
    /// If `data` decodes as a STUN Binding request, this sends back a Binding
    /// success response that echoes the transaction ID and carries the
    /// sender's address in XOR-MAPPED-ADDRESS. Undecodable datagrams and
    /// other STUN message types are ignored and yield `Ok(())`.
    ///
    /// # Errors
    /// Returns an error if the XOR-MAPPED-ADDRESS attribute cannot be added
    /// or the response cannot be sent.
    async fn handle_stun_packet(
        socket: &UdpSocket,
        data: &[u8],
        src_addr: SocketAddr,
    ) -> anyhow::Result<()> {
        // Parse the incoming message
        let mut msg = Message::new();
        msg.raw = data.to_vec();

        if let Err(e) = msg.decode() {
            debug!("Failed to decode STUN message from {}: {}", src_addr, e);
            return Ok(()); // Silently ignore non-STUN packets
        }

        // Only Binding requests are served.
        if msg.typ != BINDING_REQUEST {
            debug!(
                "Received non-binding STUN message type {:?} from {}",
                msg.typ, src_addr
            );
            return Ok(());
        }

        debug!("STUN binding request from {}", src_addr);

        // On a dual-stack socket IPv4 peers show up as `::ffff:a.b.c.d`;
        // report them in the IPv4 family, which is what the client actually has.
        let mapped_addr = SocketAddr::new(canonical_ip(src_addr.ip()), src_addr.port());

        // Build binding success response; the transaction ID must match the
        // request so the client can pair them.
        let mut response = Message::new();
        response.typ = MessageType {
            method: METHOD_BINDING,
            class: CLASS_SUCCESS_RESPONSE,
        };
        response.transaction_id = msg.transaction_id;

        let xor_addr = XORMappedAddress {
            ip: mapped_addr.ip(),
            port: mapped_addr.port(),
        };
        xor_addr.add_to(&mut response)?;

        response.encode();
        // Reply to the source address exactly as received. A converted IPv4
        // address may be rejected by an IPv6 socket.
        socket.send_to(&response.raw, src_addr).await?;

        debug!(
            "STUN binding response sent to {} (mapped: {})",
            src_addr, mapped_addr
        );

        Ok(())
    }

    /// Convert an IPv4-mapped IPv6 address (`::ffff:a.b.c.d`) to plain IPv4.
    ///
    /// Only the mapped form is converted. Other IPv6 addresses, including
    /// IPv4-compatible ones such as `::1`, are returned unchanged, as are
    /// IPv4 addresses.
    fn canonical_ip(ip: IpAddr) -> IpAddr {
        match ip {
            IpAddr::V6(v6) => v6.to_ipv4_mapped().map_or(IpAddr::V6(v6), IpAddr::V4),
            v4 @ IpAddr::V4(_) => v4,
        }
    }

    #[cfg(test)]
    mod tests {
        use super::*;
        use std::net::{Ipv4Addr, Ipv6Addr};
        use std::time::Duration;
        use webrtc_stun::message::Getter;

        #[test]
        fn canonical_ip_unwraps_only_ipv4_mapped_addresses() {
            let v4 = Ipv4Addr::new(192, 0, 2, 1);
            assert_eq!(canonical_ip(IpAddr::V6(v4.to_ipv6_mapped())), IpAddr::V4(v4));
            assert_eq!(
                canonical_ip(IpAddr::V6(Ipv6Addr::LOCALHOST)),
                IpAddr::V6(Ipv6Addr::LOCALHOST)
            );
            assert_eq!(canonical_ip(IpAddr::V4(v4)), IpAddr::V4(v4));
        }

        #[tokio::test]
        async fn test_stun_server_responds_to_binding_request() {
            // Start STUN server on random port. The socket is already bound
            // when this returns, so no start-up delay is needed: datagrams
            // sent now are queued by the OS until the task reads them.
            let server_addr: SocketAddr = "127.0.0.1:0".parse().unwrap();
            let handle = start_stun_server(server_addr).await.unwrap();
            let server_addr = handle.addr;

            // Create client socket
            let client = UdpSocket::bind("127.0.0.1:0").await.unwrap();
            let client_addr = client.local_addr().unwrap();

            // Build binding request
            let mut request = Message::new();
            request.typ = BINDING_REQUEST;
            request
                .new_transaction_id()
                .expect("Failed to generate transaction ID");
            request.encode();

            // Send request
            client.send_to(&request.raw, server_addr).await.unwrap();

            // Receive response
            let mut buf = vec![0u8; RECV_BUFFER_SIZE];
            let result =
                tokio::time::timeout(Duration::from_secs(2), client.recv_from(&mut buf)).await;

            let (len, _) = result.expect("Timeout waiting for response").unwrap();

            // Parse response
            let mut response = Message::new();
            response.raw = buf[..len].to_vec();
            response.decode().expect("Failed to decode response");

            // Verify response type
            assert_eq!(response.typ.method, METHOD_BINDING);
            assert_eq!(response.typ.class, CLASS_SUCCESS_RESPONSE);
            assert_eq!(response.transaction_id, request.transaction_id);

            // Extract XOR-MAPPED-ADDRESS
            let mut xor_addr = XORMappedAddress::default();
            xor_addr
                .get_from(&response)
                .expect("Failed to get XOR-MAPPED-ADDRESS");

            // The mapped address should match our client's address
            assert_eq!(xor_addr.ip, client_addr.ip());
            assert_eq!(xor_addr.port, client_addr.port());

            // Cleanup
            handle.shutdown();
        }
    }
}

// With the `stun` feature on, expose the real server API at this module's root.
#[cfg(feature = "stun")]
pub use self::enabled::*;

196#[cfg(not(feature = "stun"))]
197mod disabled {
198    use std::net::SocketAddr;
199
200    /// Default STUN port (RFC 5389)
201    pub const DEFAULT_STUN_PORT: u16 = 3478;
202
203    /// Placeholder handle when STUN server support is disabled.
204    pub struct StunServerHandle {
205        pub addr: SocketAddr,
206    }
207
208    impl StunServerHandle {
209        pub fn shutdown(&self) {}
210    }
211
212    pub async fn start_stun_server(_addr: SocketAddr) -> anyhow::Result<StunServerHandle> {
213        anyhow::bail!("STUN server support disabled; rebuild with --features stun")
214    }
215}
216
217#[cfg(not(feature = "stun"))]
218pub use disabled::*;