hashtree_cli/server/
stun.rs

1//! STUN server for WebRTC NAT traversal
2//!
3//! Implements a simple STUN server that responds to binding requests
4//! with the client's reflexive transport address (XOR-MAPPED-ADDRESS).
5
6use std::net::SocketAddr;
7use std::sync::Arc;
8use tokio::net::UdpSocket;
9use tracing::{debug, error, info};
10use webrtc_stun::message::{
11    Message, MessageType, Setter, BINDING_REQUEST,
12    CLASS_SUCCESS_RESPONSE, METHOD_BINDING,
13};
14use webrtc_stun::xoraddr::XORMappedAddress;
15
/// Well-known port on which STUN servers listen (assigned for STUN by RFC 5389).
pub const DEFAULT_STUN_PORT: u16 = 3478;
18
19/// STUN server handle for graceful shutdown
20pub struct StunServerHandle {
21    pub addr: SocketAddr,
22    shutdown: Arc<tokio::sync::Notify>,
23}
24
25impl StunServerHandle {
26    /// Signal the server to shutdown
27    pub fn shutdown(&self) {
28        self.shutdown.notify_one();
29    }
30}
31
32/// Start a STUN server on the specified address
33pub async fn start_stun_server(addr: SocketAddr) -> anyhow::Result<StunServerHandle> {
34    let socket = UdpSocket::bind(addr).await?;
35    let bound_addr = socket.local_addr()?;
36    let shutdown = Arc::new(tokio::sync::Notify::new());
37    let shutdown_clone = shutdown.clone();
38
39    info!("STUN server listening on {}", bound_addr);
40
41    tokio::spawn(async move {
42        run_stun_server(socket, shutdown_clone).await;
43    });
44
45    Ok(StunServerHandle {
46        addr: bound_addr,
47        shutdown,
48    })
49}
50
51async fn run_stun_server(socket: UdpSocket, shutdown: Arc<tokio::sync::Notify>) {
52    let mut buf = vec![0u8; 1500]; // Standard MTU size
53
54    loop {
55        tokio::select! {
56            result = socket.recv_from(&mut buf) => {
57                match result {
58                    Ok((len, src_addr)) => {
59                        if let Err(e) = handle_stun_packet(&socket, &buf[..len], src_addr).await {
60                            debug!("Error handling STUN packet from {}: {}", src_addr, e);
61                        }
62                    }
63                    Err(e) => {
64                        error!("Error receiving UDP packet: {}", e);
65                    }
66                }
67            }
68            _ = shutdown.notified() => {
69                info!("STUN server shutting down");
70                break;
71            }
72        }
73    }
74}
75
76async fn handle_stun_packet(
77    socket: &UdpSocket,
78    data: &[u8],
79    src_addr: SocketAddr,
80) -> anyhow::Result<()> {
81    // Parse the incoming message
82    let mut msg = Message::new();
83    msg.raw = data.to_vec();
84
85    if let Err(e) = msg.decode() {
86        debug!("Failed to decode STUN message from {}: {}", src_addr, e);
87        return Ok(()); // Silently ignore non-STUN packets
88    }
89
90    // Check if it's a binding request
91    if msg.typ != BINDING_REQUEST {
92        debug!("Received non-binding STUN message type {:?} from {}", msg.typ, src_addr);
93        return Ok(());
94    }
95
96    debug!("STUN binding request from {}", src_addr);
97
98    // Build binding success response
99    let mut response = Message::new();
100    response.typ = MessageType {
101        method: METHOD_BINDING,
102        class: CLASS_SUCCESS_RESPONSE,
103    };
104    response.transaction_id = msg.transaction_id;
105
106    // Add XOR-MAPPED-ADDRESS with the client's reflexive address
107    let xor_addr = XORMappedAddress {
108        ip: src_addr.ip(),
109        port: src_addr.port(),
110    };
111    xor_addr.add_to(&mut response)?;
112
113    // Encode and send the response
114    response.encode();
115    socket.send_to(&response.raw, src_addr).await?;
116
117    debug!("STUN binding response sent to {} (mapped: {})", src_addr, src_addr);
118
119    Ok(())
120}
121
#[cfg(test)]
mod tests {
    use super::*;
    use std::time::Duration;
    use webrtc_stun::message::Getter;

    /// A Binding request sent over loopback gets a Binding success response,
    /// from the server's address, that echoes the transaction ID and reports
    /// the client's own address in XOR-MAPPED-ADDRESS.
    #[tokio::test]
    async fn test_stun_server_responds_to_binding_request() {
        // Port 0 lets the OS pick a free port; the real one is in `handle.addr`.
        let server_addr: SocketAddr = "127.0.0.1:0".parse().unwrap();
        let handle = start_stun_server(server_addr)
            .await
            .expect("failed to start STUN server");
        let server_addr = handle.addr;

        // No start-up delay is needed: the server socket is already bound when
        // `start_stun_server` returns, so the OS queues the request even if the
        // server task has not been polled yet.
        let client = UdpSocket::bind("127.0.0.1:0")
            .await
            .expect("failed to bind client socket");
        let client_addr = client.local_addr().unwrap();

        // Build binding request
        let mut request = Message::new();
        request.typ = BINDING_REQUEST;
        request.new_transaction_id().expect("Failed to generate transaction ID");
        request.encode();

        client
            .send_to(&request.raw, server_addr)
            .await
            .expect("failed to send binding request");

        // Receive response, bounded so a broken server fails instead of hanging.
        let mut buf = vec![0u8; 1500];
        let (len, from) = tokio::time::timeout(Duration::from_secs(2), client.recv_from(&mut buf))
            .await
            .expect("Timeout waiting for response")
            .expect("failed to receive response");
        assert_eq!(from, server_addr, "response must come from the STUN server");

        // Parse response
        let mut response = Message::new();
        response.raw = buf[..len].to_vec();
        response.decode().expect("Failed to decode response");

        // Verify response type and transaction correlation
        assert_eq!(response.typ.method, METHOD_BINDING);
        assert_eq!(response.typ.class, CLASS_SUCCESS_RESPONSE);
        assert_eq!(response.transaction_id, request.transaction_id);

        // The mapped address should match the client's address
        let mut xor_addr = XORMappedAddress::default();
        xor_addr.get_from(&response).expect("Failed to get XOR-MAPPED-ADDRESS");
        assert_eq!(xor_addr.ip, client_addr.ip());
        assert_eq!(xor_addr.port, client_addr.port());

        // Cleanup
        handle.shutdown();
    }
}