Skip to main content

ip_discovery/stun/
mod.rs

1//! STUN protocol implementation for public IP detection
2//!
3//! Implements a minimal RFC 5389 STUN client for detecting public IP addresses.
4
5mod message;
6pub(crate) mod providers;
7
8pub use providers::{default_providers, provider_names};
9
10use crate::error::ProviderError;
11use crate::provider::Provider;
12use crate::types::{IpVersion, Protocol};
13use async_trait::async_trait;
14use message::{StunMessage, StunMethod};
15use std::net::{IpAddr, SocketAddr};
16use std::time::Duration;
17use tokio::net::UdpSocket;
18use tokio::time::timeout;
19use tracing::debug;
20
21/// Default STUN timeout
22const STUN_TIMEOUT: Duration = Duration::from_secs(3);
23
24/// STUN provider for IP detection
25#[derive(Debug, Clone)]
26pub struct StunProvider {
27    name: String,
28    server: String,
29    port: u16,
30}
31
32impl StunProvider {
33    /// Create a new STUN provider
34    pub fn new(name: impl Into<String>, server: impl Into<String>, port: u16) -> Self {
35        Self {
36            name: name.into(),
37            server: server.into(),
38            port,
39        }
40    }
41
42    /// Perform STUN binding request
43    async fn binding_request(&self, version: IpVersion) -> Result<IpAddr, ProviderError> {
44        // Resolve server address
45        let server_addr = format!("{}:{}", self.server, self.port);
46        let addrs: Vec<SocketAddr> = tokio::net::lookup_host(&server_addr)
47            .await
48            .map_err(|e| ProviderError::new(&self.name, e))?
49            .collect();
50
51        // Filter by IP version
52        let addr = addrs
53            .iter()
54            .find(|a| match version {
55                IpVersion::V4 => a.is_ipv4(),
56                IpVersion::V6 => a.is_ipv6(),
57                IpVersion::Any => true,
58            })
59            .ok_or_else(|| {
60                ProviderError::message(&self.name, "no suitable address for IP version")
61            })?;
62
63        // Create local socket
64        let local_addr = if addr.is_ipv4() {
65            SocketAddr::from(([0, 0, 0, 0], 0))
66        } else {
67            SocketAddr::from(([0u16; 8], 0))
68        };
69
70        let socket = UdpSocket::bind(local_addr)
71            .await
72            .map_err(|e| ProviderError::new(&self.name, e))?;
73
74        socket
75            .connect(addr)
76            .await
77            .map_err(|e| ProviderError::new(&self.name, e))?;
78
79        // Build and send STUN binding request
80        let request = StunMessage::new(StunMethod::Request);
81        let request_bytes = request.encode();
82
83        debug!(
84            server = %addr,
85            transaction_id = ?request.transaction_id(),
86            "sending STUN binding request"
87        );
88
89        socket
90            .send(&request_bytes)
91            .await
92            .map_err(|e| ProviderError::new(&self.name, e))?;
93
94        // Receive response
95        let mut buf = [0u8; 576]; // Minimum MTU
96        let len = timeout(STUN_TIMEOUT, socket.recv(&mut buf))
97            .await
98            .map_err(|_| ProviderError::message(&self.name, "timeout"))?
99            .map_err(|e| ProviderError::new(&self.name, e))?;
100
101        // Parse response
102        let response =
103            StunMessage::decode(&buf[..len]).map_err(|e| ProviderError::message(&self.name, e))?;
104
105        // Verify transaction ID
106        if response.transaction_id() != request.transaction_id() {
107            return Err(ProviderError::message(
108                &self.name,
109                "transaction ID mismatch",
110            ));
111        }
112
113        // Extract mapped address
114        response
115            .get_mapped_address()
116            .ok_or_else(|| ProviderError::message(&self.name, "no mapped address in response"))
117    }
118}
119
120#[async_trait]
121impl Provider for StunProvider {
122    fn name(&self) -> &str {
123        &self.name
124    }
125
126    fn protocol(&self) -> Protocol {
127        Protocol::Stun
128    }
129
130    fn supports_v4(&self) -> bool {
131        true
132    }
133
134    fn supports_v6(&self) -> bool {
135        true
136    }
137
138    async fn get_ip(&self, version: IpVersion) -> Result<IpAddr, ProviderError> {
139        self.binding_request(version).await
140    }
141}