// hashtree_cli/server/stun.rs

#[cfg(feature = "stun")]
mod enabled {
    //! STUN server for WebRTC NAT traversal
    //!
    //! Implements a simple STUN server that responds to binding requests
    //! with the client's reflexive transport address (XOR-MAPPED-ADDRESS).

    use std::net::{IpAddr, SocketAddr};
    use std::sync::Arc;
    use tokio::net::UdpSocket;
    use tracing::{debug, error, info};
    use webrtc_stun::message::{
        Message, MessageType, Setter, BINDING_REQUEST, CLASS_SUCCESS_RESPONSE, METHOD_BINDING,
    };
    use webrtc_stun::xoraddr::XORMappedAddress;

    /// Default STUN port (RFC 5389)
    pub const DEFAULT_STUN_PORT: u16 = 3478;

    /// Receive buffer size in bytes (one standard Ethernet MTU).
    const RECV_BUFFER_SIZE: usize = 1500;

    /// Handle to a running STUN server, used for graceful shutdown.
    ///
    /// Dropping the handle does NOT stop the server task; call
    /// [`StunServerHandle::shutdown`] explicitly.
    pub struct StunServerHandle {
        /// Address the server socket is actually bound to (resolves a requested port 0).
        pub addr: SocketAddr,
        shutdown: Arc<tokio::sync::Notify>,
    }

    impl StunServerHandle {
        /// Signal the server task to stop.
        ///
        /// `notify_one` stores a permit when the task is not currently waiting
        /// (for example while it is sending a response), so the signal is not lost.
        pub fn shutdown(&self) {
            self.shutdown.notify_one();
        }
    }

    /// Bind a UDP socket on `addr` and answer STUN binding requests on it from
    /// a background Tokio task.
    ///
    /// Pass port 0 to let the OS choose a free port. The address actually bound
    /// is available as the returned handle's `addr`.
    ///
    /// # Errors
    ///
    /// Returns an error if the socket cannot be bound or its local address
    /// cannot be read.
    pub async fn start_stun_server(addr: SocketAddr) -> anyhow::Result<StunServerHandle> {
        let socket = UdpSocket::bind(addr).await?;
        let bound_addr = socket.local_addr()?;
        let shutdown = Arc::new(tokio::sync::Notify::new());

        info!("STUN server listening on {}", bound_addr);

        tokio::spawn(run_stun_server(socket, Arc::clone(&shutdown)));

        Ok(StunServerHandle {
            addr: bound_addr,
            shutdown,
        })
    }

    /// Serve datagrams on `socket` until `shutdown` is notified.
    async fn run_stun_server(socket: UdpSocket, shutdown: Arc<tokio::sync::Notify>) {
        let mut buf = vec![0u8; RECV_BUFFER_SIZE];

        loop {
            tokio::select! {
                result = socket.recv_from(&mut buf) => {
                    match result {
                        Ok((len, src_addr)) => {
                            if let Err(e) = handle_stun_packet(&socket, &buf[..len], src_addr).await {
                                debug!("Error handling STUN packet from {}: {}", src_addr, e);
                            }
                        }
                        // A failed receive is logged but does not stop the server.
                        Err(e) => {
                            error!("Error receiving UDP packet: {}", e);
                        }
                    }
                }
                _ = shutdown.notified() => {
                    info!("STUN server shutting down");
                    break;
                }
            }
        }
    }

    /// Convert an IPv4-mapped IPv6 address (`::ffff:a.b.c.d`) to plain IPv4.
    ///
    /// A socket bound to `[::]` in dual-stack mode reports IPv4 peers in this
    /// form. Echoing it unchanged in XOR-MAPPED-ADDRESS would give an IPv4
    /// client an IPv6-family reflexive address. Other addresses are returned as-is.
    fn reflexive_ip(ip: IpAddr) -> IpAddr {
        match ip {
            IpAddr::V6(v6) => v6.to_ipv4_mapped().map_or(ip, IpAddr::V4),
            IpAddr::V4(_) => ip,
        }
    }

    /// Answer one datagram if it is a STUN binding request.
    ///
    /// Datagrams that fail to decode as STUN, and STUN messages other than
    /// binding requests, are logged at debug level and ignored (`Ok(())`).
    ///
    /// # Errors
    ///
    /// Returns an error if the XOR-MAPPED-ADDRESS attribute cannot be added or
    /// the response cannot be sent.
    async fn handle_stun_packet(
        socket: &UdpSocket,
        data: &[u8],
        src_addr: SocketAddr,
    ) -> anyhow::Result<()> {
        // Parse the incoming message
        let mut msg = Message::new();
        msg.raw = data.to_vec();

        if let Err(e) = msg.decode() {
            debug!("Failed to decode STUN message from {}: {}", src_addr, e);
            return Ok(()); // Silently ignore non-STUN packets
        }

        if msg.typ != BINDING_REQUEST {
            debug!("Received non-binding STUN message type {:?} from {}", msg.typ, src_addr);
            return Ok(());
        }

        debug!("STUN binding request from {}", src_addr);

        // Build binding success response
        let mut response = Message::new();
        response.typ = MessageType {
            method: METHOD_BINDING,
            class: CLASS_SUCCESS_RESPONSE,
        };
        // Echo the request's transaction ID. RFC 5389 also uses it in the XOR
        // key for IPv6 addresses, so set it before adding XOR-MAPPED-ADDRESS.
        response.transaction_id = msg.transaction_id;

        // Report the client's reflexive address with IPv4-mapped IPv6 unwrapped.
        let mapped = SocketAddr::new(reflexive_ip(src_addr.ip()), src_addr.port());
        let xor_addr = XORMappedAddress {
            ip: mapped.ip(),
            port: mapped.port(),
        };
        xor_addr.add_to(&mut response)?;

        // Reply to the address exactly as received so it matches the socket's family.
        response.encode();
        socket.send_to(&response.raw, src_addr).await?;

        debug!("STUN binding response sent to {} (mapped: {})", src_addr, mapped);

        Ok(())
    }

    #[cfg(test)]
    mod tests {
        use super::*;
        use std::time::Duration;
        use webrtc_stun::message::Getter;

        #[tokio::test]
        async fn test_stun_server_responds_to_binding_request() {
            // Start STUN server on random port. The socket is already bound when
            // start_stun_server returns, so no start-up delay is needed.
            let server_addr: SocketAddr = "127.0.0.1:0".parse().unwrap();
            let handle = start_stun_server(server_addr).await.unwrap();
            let server_addr = handle.addr;

            // Create client socket
            let client = UdpSocket::bind("127.0.0.1:0").await.unwrap();
            let client_addr = client.local_addr().unwrap();

            // Build binding request
            let mut request = Message::new();
            request.typ = BINDING_REQUEST;
            request.new_transaction_id().expect("Failed to generate transaction ID");
            request.encode();

            // Send request
            client.send_to(&request.raw, server_addr).await.unwrap();

            // Receive response
            let mut buf = vec![0u8; 1500];
            let result =
                tokio::time::timeout(Duration::from_secs(2), client.recv_from(&mut buf)).await;

            let (len, _) = result.expect("Timeout waiting for response").unwrap();

            // Parse response
            let mut response = Message::new();
            response.raw = buf[..len].to_vec();
            response.decode().expect("Failed to decode response");

            // Verify response type
            assert_eq!(response.typ.method, METHOD_BINDING);
            assert_eq!(response.typ.class, CLASS_SUCCESS_RESPONSE);
            assert_eq!(response.transaction_id, request.transaction_id);

            // Extract XOR-MAPPED-ADDRESS
            let mut xor_addr = XORMappedAddress::default();
            xor_addr.get_from(&response).expect("Failed to get XOR-MAPPED-ADDRESS");

            // The mapped address should match our client's address
            assert_eq!(xor_addr.ip, client_addr.ip());
            assert_eq!(xor_addr.port, client_addr.port());

            // Cleanup
            handle.shutdown();
        }

        #[test]
        fn test_reflexive_ip_unwraps_ipv4_mapped_addresses() {
            let mapped: IpAddr = "::ffff:192.0.2.7".parse().unwrap();
            let expected: IpAddr = "192.0.2.7".parse().unwrap();
            assert_eq!(reflexive_ip(mapped), expected);

            let v6: IpAddr = "2001:db8::1".parse().unwrap();
            assert_eq!(reflexive_ip(v6), v6);

            let v4: IpAddr = "127.0.0.1".parse().unwrap();
            assert_eq!(reflexive_ip(v4), v4);
        }
    }
}

#[cfg(feature = "stun")]
pub use enabled::*;
188
189#[cfg(not(feature = "stun"))]
190mod disabled {
191    use std::net::SocketAddr;
192
193    /// Default STUN port (RFC 5389)
194    pub const DEFAULT_STUN_PORT: u16 = 3478;
195
196    /// Placeholder handle when STUN server support is disabled.
197    pub struct StunServerHandle {
198        pub addr: SocketAddr,
199    }
200
201    impl StunServerHandle {
202        pub fn shutdown(&self) {}
203    }
204
205    pub async fn start_stun_server(_addr: SocketAddr) -> anyhow::Result<StunServerHandle> {
206        anyhow::bail!("STUN server support disabled; rebuild with --features stun")
207    }
208}
209
210#[cfg(not(feature = "stun"))]
211pub use disabled::*;