1mod message;
17pub(crate) mod providers;
18
19pub use providers::{default_providers, provider_names};
20
21use crate::error::ProviderError;
22use crate::provider::Provider;
23use crate::types::{IpVersion, Protocol};
24use message::{StunMessage, StunMethod};
25use std::future::Future;
26use std::net::{IpAddr, SocketAddr};
27use std::pin::Pin;
28use tokio::net::UdpSocket;
29
/// An IP provider that discovers this host's address as reported by a STUN server.
///
/// Each lookup sends a single STUN Binding request over UDP to `server:port`.
#[derive(Clone, Debug)]
pub struct StunProvider {
    /// Identifier of this provider; also attached to every error it produces.
    name: String,
    /// STUN server host: a hostname or an IP literal.
    server: String,
    /// UDP port of the STUN server.
    port: u16,
}
37
38impl StunProvider {
39 pub fn new(name: impl Into<String>, server: impl Into<String>, port: u16) -> Self {
41 Self {
42 name: name.into(),
43 server: server.into(),
44 port,
45 }
46 }
47
48 async fn binding_request(&self, version: IpVersion) -> Result<IpAddr, ProviderError> {
50 let server_addr = format!("{}:{}", self.server, self.port);
52 let addrs: Vec<SocketAddr> = tokio::net::lookup_host(&server_addr)
53 .await
54 .map_err(|e| ProviderError::new(&self.name, e))?
55 .collect();
56
57 let addr = addrs
59 .iter()
60 .find(|a| match version {
61 IpVersion::V4 => a.is_ipv4(),
62 IpVersion::V6 => a.is_ipv6(),
63 IpVersion::Any => true,
64 })
65 .ok_or_else(|| {
66 ProviderError::message(&self.name, "no suitable address for IP version")
67 })?;
68
69 let local_addr = if addr.is_ipv4() {
71 SocketAddr::from(([0, 0, 0, 0], 0))
72 } else {
73 SocketAddr::from(([0u16; 8], 0))
74 };
75
76 let socket = UdpSocket::bind(local_addr)
77 .await
78 .map_err(|e| ProviderError::new(&self.name, e))?;
79
80 socket
81 .connect(addr)
82 .await
83 .map_err(|e| ProviderError::new(&self.name, e))?;
84
85 let request = StunMessage::new(StunMethod::Request);
87 let request_bytes = request.encode();
88
89 socket
90 .send(&request_bytes)
91 .await
92 .map_err(|e| ProviderError::new(&self.name, e))?;
93
94 let mut buf = [0u8; 576]; let len = socket
97 .recv(&mut buf)
98 .await
99 .map_err(|e| ProviderError::new(&self.name, e))?;
100
101 let response =
103 StunMessage::decode(&buf[..len]).map_err(|e| ProviderError::message(&self.name, e))?;
104
105 if response.transaction_id() != request.transaction_id() {
107 return Err(ProviderError::message(
108 &self.name,
109 "transaction ID mismatch",
110 ));
111 }
112
113 response
115 .get_mapped_address()
116 .ok_or_else(|| ProviderError::message(&self.name, "no mapped address in response"))
117 }
118}
119
120impl Provider for StunProvider {
121 fn name(&self) -> &str {
122 &self.name
123 }
124
125 fn protocol(&self) -> Protocol {
126 Protocol::Stun
127 }
128
129 fn supports_v4(&self) -> bool {
130 true
131 }
132
133 fn supports_v6(&self) -> bool {
134 true
135 }
136
137 fn get_ip(
138 &self,
139 version: IpVersion,
140 ) -> Pin<Box<dyn Future<Output = Result<IpAddr, ProviderError>> + Send + '_>> {
141 Box::pin(self.binding_request(version))
142 }
143}