1pub mod client;
2pub mod server;
3
4use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr};
5
/// Fixed value carried in bytes 4..8 of every STUN header; it is also the XOR mask applied
/// to XOR-MAPPED-ADDRESS ports and addresses.
pub const MAGIC_COOKIE: u32 = 0x2112_A442;

/// Message type of a Binding Request.
pub const BINDING_REQUEST: u16 = 0x0001;
/// Message type of a Binding success Response.
pub const BINDING_RESPONSE: u16 = 0x0101;

/// Attribute type of XOR-MAPPED-ADDRESS.
pub const XOR_MAPPED_ADDRESS: u16 = 0x0020;
/// Attribute type of MAPPED-ADDRESS.
pub const MAPPED_ADDRESS: u16 = 0x0001;

/// Address-family code for IPv4 in (XOR-)MAPPED-ADDRESS values.
pub const FAMILY_IPV4: u8 = 0x01;
/// Address-family code for IPv6 in (XOR-)MAPPED-ADDRESS values.
pub const FAMILY_IPV6: u8 = 0x02;
18
/// A 96-bit STUN transaction identifier.
///
/// It is used to match responses to the requests that produced them. It is also part of the
/// XOR key for IPv6 XOR-MAPPED-ADDRESS values.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct TransactionId([u8; 12]);

impl TransactionId {
    /// Generates a fresh transaction ID.
    ///
    /// The standard library offers no cryptographic RNG, so the 12 bytes are built from two
    /// sources:
    ///
    /// - bytes `0..8`: a randomly keyed SipHash (`std::collections::hash_map::RandomState`)
    ///   over the wall-clock time, the process id and a sequence number. This is much harder
    ///   to guess than a raw timestamp, but it is **not** cryptographically strong.
    /// - bytes `8..12`: a process-wide sequence number. IDs generated within one process
    ///   therefore never repeat (until the 32-bit counter wraps), even when many are created
    ///   within the same clock tick.
    ///
    /// This never panics. A system clock set before the Unix epoch simply contributes `0`.
    pub fn new() -> Self {
        use std::collections::hash_map::RandomState;
        use std::hash::{BuildHasher, Hasher};
        use std::sync::atomic::{AtomicU32, Ordering};
        use std::time::{SystemTime, UNIX_EPOCH};

        static SEQUENCE: AtomicU32 = AtomicU32::new(0);

        // Uniqueness only needs the atomicity of fetch_add, not any cross-thread ordering.
        let seq = SEQUENCE.fetch_add(1, Ordering::Relaxed);
        let nanos = SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .map_or(0, |elapsed| elapsed.as_nanos());

        let mut hasher = RandomState::new().build_hasher();
        hasher.write_u128(nanos);
        hasher.write_u32(std::process::id());
        hasher.write_u32(seq);

        let mut bytes = [0u8; 12];
        bytes[..8].copy_from_slice(&hasher.finish().to_be_bytes());
        bytes[8..].copy_from_slice(&seq.to_be_bytes());
        TransactionId(bytes)
    }

    /// Returns the raw 12 bytes of the identifier, in wire order.
    pub fn as_bytes(&self) -> &[u8; 12] {
        &self.0
    }
}

impl Default for TransactionId {
    /// Equivalent to [`TransactionId::new`]: every default value is a fresh ID.
    fn default() -> Self {
        Self::new()
    }
}
49
50#[derive(Debug)]
51pub struct StunMessage {
52 pub message_type: u16,
53 pub length: u16,
54 pub transaction_id: TransactionId,
55 pub attributes: Vec<Attribute>,
56}
57
58impl StunMessage {
59 pub fn new_binding_request() -> Self {
60 StunMessage {
61 message_type: BINDING_REQUEST,
62 length: 0,
63 transaction_id: TransactionId::new(),
64 attributes: Vec::new(),
65 }
66 }
67
68 pub fn encode(&self) -> Vec<u8> {
69 let mut buffer = Vec::with_capacity(20 + self.length as usize);
70
71 buffer.extend_from_slice(&self.message_type.to_be_bytes());
73
74 buffer.extend_from_slice(&self.length.to_be_bytes());
76
77 buffer.extend_from_slice(&MAGIC_COOKIE.to_be_bytes());
79
80 buffer.extend_from_slice(self.transaction_id.as_bytes());
82
83 for attr in &self.attributes {
85 buffer.extend_from_slice(&attr.encode());
86 }
87
88 buffer
89 }
90
91 pub fn decode(data: &[u8]) -> Result<Self, StunError> {
92 if data.len() < 20 {
93 return Err(StunError::MessageTooShort);
94 }
95
96 if data[0] & 0xC0 != 0 {
98 return Err(StunError::InvalidMessage);
99 }
100
101 let message_type = u16::from_be_bytes([data[0], data[1]]);
102 let length = u16::from_be_bytes([data[2], data[3]]);
103 let magic = u32::from_be_bytes([data[4], data[5], data[6], data[7]]);
104
105 if magic != MAGIC_COOKIE {
106 return Err(StunError::InvalidMagicCookie);
107 }
108
109 let mut transaction_id_bytes = [0u8; 12];
110 transaction_id_bytes.copy_from_slice(&data[8..20]);
111 let transaction_id = TransactionId(transaction_id_bytes);
112
113 if data.len() < 20 + length as usize {
114 return Err(StunError::InvalidLength);
115 }
116
117 let mut attributes = Vec::new();
118 let mut offset = 20;
119
120 while offset < 20 + length as usize {
121 let attr = Attribute::decode(&data[offset..])?;
122 let attr_len = 4 + attr.length as usize;
123 let padded_len = (attr_len + 3) & !3; offset += padded_len;
125 attributes.push(attr);
126 }
127
128 Ok(StunMessage {
129 message_type,
130 length,
131 transaction_id,
132 attributes,
133 })
134 }
135}
136
137#[derive(Debug)]
138pub struct Attribute {
139 pub attr_type: u16,
140 pub length: u16,
141 pub value: Vec<u8>,
142}
143
144impl Attribute {
145 pub fn encode(&self) -> Vec<u8> {
146 let mut buffer = Vec::new();
147 buffer.extend_from_slice(&self.attr_type.to_be_bytes());
148 buffer.extend_from_slice(&self.length.to_be_bytes());
149 buffer.extend_from_slice(&self.value);
150
151 let padding = (4 - (self.length % 4)) % 4;
153 buffer.resize(buffer.len() + padding as usize, 0);
154
155 buffer
156 }
157
158 pub fn decode(data: &[u8]) -> Result<Self, StunError> {
159 if data.len() < 4 {
160 return Err(StunError::AttributeTooShort);
161 }
162
163 let attr_type = u16::from_be_bytes([data[0], data[1]]);
164 let length = u16::from_be_bytes([data[2], data[3]]);
165
166 if data.len() < 4 + length as usize {
167 return Err(StunError::InvalidLength);
168 }
169
170 let value = data[4..4 + length as usize].to_vec();
171
172 Ok(Attribute {
173 attr_type,
174 length,
175 value,
176 })
177 }
178
179 pub fn decode_xor_mapped_address(
180 &self,
181 transaction_id: &TransactionId,
182 ) -> Result<SocketAddr, StunError> {
183 if self.attr_type != XOR_MAPPED_ADDRESS {
184 return Err(StunError::WrongAttributeType);
185 }
186
187 if self.value.len() < 4 {
188 return Err(StunError::InvalidLength);
189 }
190
191 let family = self.value[1];
192 let x_port = u16::from_be_bytes([self.value[2], self.value[3]]);
193
194 let port = x_port ^ (MAGIC_COOKIE >> 16) as u16;
196
197 match family {
198 FAMILY_IPV4 => {
199 if self.value.len() < 8 {
200 return Err(StunError::InvalidLength);
201 }
202
203 let x_addr = u32::from_be_bytes([
204 self.value[4],
205 self.value[5],
206 self.value[6],
207 self.value[7],
208 ]);
209
210 let addr = x_addr ^ MAGIC_COOKIE;
212 let ip = Ipv4Addr::from(addr);
213
214 Ok(SocketAddr::new(IpAddr::V4(ip), port))
215 }
216 FAMILY_IPV6 => {
217 if self.value.len() < 20 {
218 return Err(StunError::InvalidLength);
219 }
220
221 let mut xor_key = [0u8; 16];
223 xor_key[0..4].copy_from_slice(&MAGIC_COOKIE.to_be_bytes());
224 xor_key[4..16].copy_from_slice(transaction_id.as_bytes());
225
226 let mut addr_bytes = [0u8; 16];
227 for i in 0..16 {
228 addr_bytes[i] = self.value[4 + i] ^ xor_key[i];
229 }
230
231 let ip = Ipv6Addr::from(addr_bytes);
232 Ok(SocketAddr::new(IpAddr::V6(ip), port))
233 }
234 _ => Err(StunError::UnsupportedAddressFamily),
235 }
236 }
237
238 pub fn decode_mapped_address(&self) -> Result<SocketAddr, StunError> {
239 if self.attr_type != MAPPED_ADDRESS {
240 return Err(StunError::WrongAttributeType);
241 }
242
243 if self.value.len() < 4 {
244 return Err(StunError::InvalidLength);
245 }
246
247 let family = self.value[1];
248 let port = u16::from_be_bytes([self.value[2], self.value[3]]);
249
250 match family {
251 FAMILY_IPV4 => {
252 if self.value.len() < 8 {
253 return Err(StunError::InvalidLength);
254 }
255
256 let addr = u32::from_be_bytes([
257 self.value[4],
258 self.value[5],
259 self.value[6],
260 self.value[7],
261 ]);
262
263 let ip = Ipv4Addr::from(addr);
264 Ok(SocketAddr::new(IpAddr::V4(ip), port))
265 }
266 FAMILY_IPV6 => {
267 if self.value.len() < 20 {
268 return Err(StunError::InvalidLength);
269 }
270
271 let mut addr_bytes = [0u8; 16];
272 addr_bytes.copy_from_slice(&self.value[4..20]);
273
274 let ip = Ipv6Addr::from(addr_bytes);
275 Ok(SocketAddr::new(IpAddr::V6(ip), port))
276 }
277 _ => Err(StunError::UnsupportedAddressFamily),
278 }
279 }
280}
281
/// Errors produced while encoding, decoding or exchanging STUN messages.
#[derive(Debug, PartialEq)]
pub enum StunError {
    /// The input is too short to hold a STUN message header.
    MessageTooShort,
    /// The input is not a well-formed STUN message (e.g. the leading two bits are not zero).
    InvalidMessage,
    /// The header's magic-cookie field did not hold the expected value.
    InvalidMagicCookie,
    /// A declared length does not fit the available data or its contents.
    InvalidLength,
    /// Not enough bytes remained for an attribute's 4-byte header.
    AttributeTooShort,
    /// An attribute decoder was applied to an attribute of a different type.
    WrongAttributeType,
    /// An address attribute used a family other than IPv4 or IPv6.
    UnsupportedAddressFamily,
    /// A response carried no mapped address.
    NoMappedAddress,
    /// An underlying I/O operation failed; holds the error's rendered message.
    IoError(String),
}

impl std::fmt::Display for StunError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let description = match self {
            StunError::IoError(detail) => return write!(f, "I/O error: {}", detail),
            StunError::MessageTooShort => "STUN message too short",
            StunError::InvalidMessage => "Invalid STUN message",
            StunError::InvalidMagicCookie => "Invalid magic cookie",
            StunError::InvalidLength => "Invalid length",
            StunError::AttributeTooShort => "Attribute too short",
            StunError::WrongAttributeType => "Wrong attribute type",
            StunError::UnsupportedAddressFamily => "Unsupported address family",
            StunError::NoMappedAddress => "No mapped address in response",
        };
        f.write_str(description)
    }
}

impl std::error::Error for StunError {}

impl From<std::io::Error> for StunError {
    /// Captures the I/O error's display text; the original error value is not retained.
    fn from(err: std::io::Error) -> Self {
        Self::IoError(err.to_string())
    }
}
318
#[cfg(test)]
mod tests {
    use super::*;

    /// A freshly built Binding Request has no attributes, so it encodes to a bare header.
    #[test]
    fn test_binding_request_encode() {
        let wire = StunMessage::new_binding_request().encode();
        assert_eq!(wire.len(), 20);

        let (msg_type, rest) = wire.split_at(2);
        let (body_len, rest) = rest.split_at(2);
        let (cookie, _txid) = rest.split_at(4);

        assert_eq!(msg_type, &BINDING_REQUEST.to_be_bytes());
        assert_eq!(body_len, &0u16.to_be_bytes());
        assert_eq!(cookie, &MAGIC_COOKIE.to_be_bytes());
    }

    /// Encoding then decoding a Binding Request preserves type, length and transaction ID.
    #[test]
    fn test_binding_request_decode() {
        let original = StunMessage::new_binding_request();
        let parsed = StunMessage::decode(&original.encode()).expect("round-trip decode");

        assert_eq!(parsed.message_type, BINDING_REQUEST);
        assert_eq!(parsed.length, 0);
        assert_eq!(parsed.transaction_id, original.transaction_id);
    }

    /// An IPv4 XOR-MAPPED-ADDRESS built by hand decodes back to the original address and port.
    #[test]
    fn test_xor_mapped_address_ipv4() {
        let txid = TransactionId([0; 12]);
        let expected_ip = Ipv4Addr::new(192, 0, 2, 1);
        let expected_port: u16 = 32768;

        let masked_port = expected_port ^ (MAGIC_COOKIE >> 16) as u16;
        let masked_ip = u32::from(expected_ip) ^ MAGIC_COOKIE;

        let value: Vec<u8> = [0u8, FAMILY_IPV4]
            .iter()
            .copied()
            .chain(masked_port.to_be_bytes().iter().copied())
            .chain(masked_ip.to_be_bytes().iter().copied())
            .collect();

        let attr = Attribute {
            attr_type: XOR_MAPPED_ADDRESS,
            length: value.len() as u16,
            value,
        };

        let decoded = attr.decode_xor_mapped_address(&txid).unwrap();
        assert_eq!(decoded.ip(), IpAddr::V4(expected_ip));
        assert_eq!(decoded.port(), expected_port);
    }
}
371}