1mod message;
6pub(crate) mod providers;
7
8pub use providers::{default_providers, provider_names};
9
10use crate::error::ProviderError;
11use crate::provider::Provider;
12use crate::types::{IpVersion, Protocol};
13use async_trait::async_trait;
14use message::{StunMessage, StunMethod};
15use std::net::{IpAddr, SocketAddr};
16use std::time::Duration;
17use tokio::net::UdpSocket;
18use tokio::time::timeout;
19use tracing::debug;
20
/// Maximum time to wait for a STUN server's response before giving up.
const STUN_TIMEOUT: Duration = Duration::from_millis(3_000);
23
/// A public-IP provider that asks a STUN server for this host's mapped address.
#[derive(Clone, Debug)]
pub struct StunProvider {
    /// Provider name, returned by `Provider::name` and attached to every error it produces.
    name: String,
    /// Hostname or IP literal of the STUN server.
    server: String,
    /// UDP port of the STUN server.
    port: u16,
}
31
32impl StunProvider {
33 pub fn new(name: impl Into<String>, server: impl Into<String>, port: u16) -> Self {
35 Self {
36 name: name.into(),
37 server: server.into(),
38 port,
39 }
40 }
41
42 async fn binding_request(&self, version: IpVersion) -> Result<IpAddr, ProviderError> {
44 let server_addr = format!("{}:{}", self.server, self.port);
46 let addrs: Vec<SocketAddr> = tokio::net::lookup_host(&server_addr)
47 .await
48 .map_err(|e| ProviderError::new(&self.name, e))?
49 .collect();
50
51 let addr = addrs
53 .iter()
54 .find(|a| match version {
55 IpVersion::V4 => a.is_ipv4(),
56 IpVersion::V6 => a.is_ipv6(),
57 IpVersion::Any => true,
58 })
59 .ok_or_else(|| {
60 ProviderError::message(&self.name, "no suitable address for IP version")
61 })?;
62
63 let local_addr = if addr.is_ipv4() {
65 SocketAddr::from(([0, 0, 0, 0], 0))
66 } else {
67 SocketAddr::from(([0u16; 8], 0))
68 };
69
70 let socket = UdpSocket::bind(local_addr)
71 .await
72 .map_err(|e| ProviderError::new(&self.name, e))?;
73
74 socket
75 .connect(addr)
76 .await
77 .map_err(|e| ProviderError::new(&self.name, e))?;
78
79 let request = StunMessage::new(StunMethod::Request);
81 let request_bytes = request.encode();
82
83 debug!(
84 server = %addr,
85 transaction_id = ?request.transaction_id(),
86 "sending STUN binding request"
87 );
88
89 socket
90 .send(&request_bytes)
91 .await
92 .map_err(|e| ProviderError::new(&self.name, e))?;
93
94 let mut buf = [0u8; 576]; let len = timeout(STUN_TIMEOUT, socket.recv(&mut buf))
97 .await
98 .map_err(|_| ProviderError::message(&self.name, "timeout"))?
99 .map_err(|e| ProviderError::new(&self.name, e))?;
100
101 let response =
103 StunMessage::decode(&buf[..len]).map_err(|e| ProviderError::message(&self.name, e))?;
104
105 if response.transaction_id() != request.transaction_id() {
107 return Err(ProviderError::message(
108 &self.name,
109 "transaction ID mismatch",
110 ));
111 }
112
113 response
115 .get_mapped_address()
116 .ok_or_else(|| ProviderError::message(&self.name, "no mapped address in response"))
117 }
118}
119
120#[async_trait]
121impl Provider for StunProvider {
122 fn name(&self) -> &str {
123 &self.name
124 }
125
126 fn protocol(&self) -> Protocol {
127 Protocol::Stun
128 }
129
130 fn supports_v4(&self) -> bool {
131 true
132 }
133
134 fn supports_v6(&self) -> bool {
135 true
136 }
137
138 async fn get_ip(&self, version: IpVersion) -> Result<IpAddr, ProviderError> {
139 self.binding_request(version).await
140 }
141}