hightower_stun/
lib.rs

1pub mod client;
2pub mod server;
3
4use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr};
5
// STUN protocol constants (RFC 5389).

/// Fixed value carried in bytes 4..8 of every STUN message header.
pub const MAGIC_COOKIE: u32 = 0x2112_A442;
/// Message type of a Binding request.
pub const BINDING_REQUEST: u16 = 0x0001;
/// Message type of a Binding success response.
pub const BINDING_RESPONSE: u16 = 0x0101;

// Attribute type codes.

/// Attribute type of XOR-MAPPED-ADDRESS (address obfuscated with the magic cookie).
pub const XOR_MAPPED_ADDRESS: u16 = 0x0020;
/// Attribute type of the plain MAPPED-ADDRESS attribute.
pub const MAPPED_ADDRESS: u16 = 0x0001;

// Address family codes used inside (XOR-)MAPPED-ADDRESS values.

/// IPv4 address family.
pub const FAMILY_IPV4: u8 = 0x01;
/// IPv6 address family.
pub const FAMILY_IPV6: u8 = 0x02;
18
/// The 96-bit transaction ID that pairs a STUN request with its response.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct TransactionId([u8; 12]);

impl TransactionId {
    /// Generates a fresh transaction ID.
    ///
    /// RFC 5389 asks for transaction IDs chosen uniformly at random, so that
    /// responses can be matched reliably and are hard to spoof. Without an
    /// external RNG crate, the 96 bits are derived from std's randomly keyed
    /// SipHash (`RandomState`). The hash is fed with:
    /// - the wall clock,
    /// - the process ID,
    /// - a per-process counter.
    ///
    /// Because of the counter, two IDs from the same process can only collide
    /// by hash collision, never because the clock did not advance between
    /// calls. This is not a CSPRNG.
    pub fn new() -> Self {
        use std::collections::hash_map::RandomState;
        use std::hash::{BuildHasher, Hasher};
        use std::sync::atomic::{AtomicUsize, Ordering};
        use std::time::{SystemTime, UNIX_EPOCH};

        // Monotonic per-process sequence; distinguishes IDs created within one clock tick.
        static COUNTER: AtomicUsize = AtomicUsize::new(0);

        // A clock set before the Unix epoch must not abort ID generation; the
        // counter and random hash keys still make the inputs distinct.
        let nanos = SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .map(|elapsed| elapsed.as_nanos() as u64)
            .unwrap_or(0);
        let sequence = COUNTER.fetch_add(1, Ordering::Relaxed);
        let pid = std::process::id();
        let keys = RandomState::new();

        // One 64-bit hash per "lane"; two lanes cover the 12 output bytes.
        let mix = |lane: u8| -> [u8; 8] {
            let mut hasher = keys.build_hasher();
            hasher.write_u8(lane);
            hasher.write_u64(nanos);
            hasher.write_usize(sequence);
            hasher.write_u32(pid);
            hasher.finish().to_be_bytes()
        };

        let mut bytes = [0u8; 12];
        bytes[..8].copy_from_slice(&mix(0));
        bytes[8..].copy_from_slice(&mix(1)[..4]);
        TransactionId(bytes)
    }

    /// Returns the raw 12 bytes exactly as they appear on the wire.
    pub fn as_bytes(&self) -> &[u8; 12] {
        &self.0
    }
}

impl Default for TransactionId {
    /// Equivalent to [`TransactionId::new`]: every default ID is freshly generated.
    fn default() -> Self {
        Self::new()
    }
}
49
50#[derive(Debug)]
51pub struct StunMessage {
52    pub message_type: u16,
53    pub length: u16,
54    pub transaction_id: TransactionId,
55    pub attributes: Vec<Attribute>,
56}
57
58impl StunMessage {
59    pub fn new_binding_request() -> Self {
60        StunMessage {
61            message_type: BINDING_REQUEST,
62            length: 0,
63            transaction_id: TransactionId::new(),
64            attributes: Vec::new(),
65        }
66    }
67
68    pub fn encode(&self) -> Vec<u8> {
69        let mut buffer = Vec::with_capacity(20 + self.length as usize);
70
71        // Message type
72        buffer.extend_from_slice(&self.message_type.to_be_bytes());
73
74        // Message length
75        buffer.extend_from_slice(&self.length.to_be_bytes());
76
77        // Magic cookie
78        buffer.extend_from_slice(&MAGIC_COOKIE.to_be_bytes());
79
80        // Transaction ID
81        buffer.extend_from_slice(self.transaction_id.as_bytes());
82
83        // Attributes
84        for attr in &self.attributes {
85            buffer.extend_from_slice(&attr.encode());
86        }
87
88        buffer
89    }
90
91    pub fn decode(data: &[u8]) -> Result<Self, StunError> {
92        if data.len() < 20 {
93            return Err(StunError::MessageTooShort);
94        }
95
96        // Check first two bits are 0
97        if data[0] & 0xC0 != 0 {
98            return Err(StunError::InvalidMessage);
99        }
100
101        let message_type = u16::from_be_bytes([data[0], data[1]]);
102        let length = u16::from_be_bytes([data[2], data[3]]);
103        let magic = u32::from_be_bytes([data[4], data[5], data[6], data[7]]);
104
105        if magic != MAGIC_COOKIE {
106            return Err(StunError::InvalidMagicCookie);
107        }
108
109        let mut transaction_id_bytes = [0u8; 12];
110        transaction_id_bytes.copy_from_slice(&data[8..20]);
111        let transaction_id = TransactionId(transaction_id_bytes);
112
113        if data.len() < 20 + length as usize {
114            return Err(StunError::InvalidLength);
115        }
116
117        let mut attributes = Vec::new();
118        let mut offset = 20;
119
120        while offset < 20 + length as usize {
121            let attr = Attribute::decode(&data[offset..])?;
122            let attr_len = 4 + attr.length as usize;
123            let padded_len = (attr_len + 3) & !3; // Round up to multiple of 4
124            offset += padded_len;
125            attributes.push(attr);
126        }
127
128        Ok(StunMessage {
129            message_type,
130            length,
131            transaction_id,
132            attributes,
133        })
134    }
135}
136
137#[derive(Debug)]
138pub struct Attribute {
139    pub attr_type: u16,
140    pub length: u16,
141    pub value: Vec<u8>,
142}
143
144impl Attribute {
145    pub fn encode(&self) -> Vec<u8> {
146        let mut buffer = Vec::new();
147        buffer.extend_from_slice(&self.attr_type.to_be_bytes());
148        buffer.extend_from_slice(&self.length.to_be_bytes());
149        buffer.extend_from_slice(&self.value);
150
151        // Pad to multiple of 4 bytes
152        let padding = (4 - (self.length % 4)) % 4;
153        buffer.resize(buffer.len() + padding as usize, 0);
154
155        buffer
156    }
157
158    pub fn decode(data: &[u8]) -> Result<Self, StunError> {
159        if data.len() < 4 {
160            return Err(StunError::AttributeTooShort);
161        }
162
163        let attr_type = u16::from_be_bytes([data[0], data[1]]);
164        let length = u16::from_be_bytes([data[2], data[3]]);
165
166        if data.len() < 4 + length as usize {
167            return Err(StunError::InvalidLength);
168        }
169
170        let value = data[4..4 + length as usize].to_vec();
171
172        Ok(Attribute {
173            attr_type,
174            length,
175            value,
176        })
177    }
178
179    pub fn decode_xor_mapped_address(
180        &self,
181        transaction_id: &TransactionId,
182    ) -> Result<SocketAddr, StunError> {
183        if self.attr_type != XOR_MAPPED_ADDRESS {
184            return Err(StunError::WrongAttributeType);
185        }
186
187        if self.value.len() < 4 {
188            return Err(StunError::InvalidLength);
189        }
190
191        let family = self.value[1];
192        let x_port = u16::from_be_bytes([self.value[2], self.value[3]]);
193
194        // XOR port with most significant 16 bits of magic cookie
195        let port = x_port ^ (MAGIC_COOKIE >> 16) as u16;
196
197        match family {
198            FAMILY_IPV4 => {
199                if self.value.len() < 8 {
200                    return Err(StunError::InvalidLength);
201                }
202
203                let x_addr = u32::from_be_bytes([
204                    self.value[4],
205                    self.value[5],
206                    self.value[6],
207                    self.value[7],
208                ]);
209
210                // XOR address with magic cookie
211                let addr = x_addr ^ MAGIC_COOKIE;
212                let ip = Ipv4Addr::from(addr);
213
214                Ok(SocketAddr::new(IpAddr::V4(ip), port))
215            }
216            FAMILY_IPV6 => {
217                if self.value.len() < 20 {
218                    return Err(StunError::InvalidLength);
219                }
220
221                // XOR address with magic cookie + transaction ID
222                let mut xor_key = [0u8; 16];
223                xor_key[0..4].copy_from_slice(&MAGIC_COOKIE.to_be_bytes());
224                xor_key[4..16].copy_from_slice(transaction_id.as_bytes());
225
226                let mut addr_bytes = [0u8; 16];
227                for i in 0..16 {
228                    addr_bytes[i] = self.value[4 + i] ^ xor_key[i];
229                }
230
231                let ip = Ipv6Addr::from(addr_bytes);
232                Ok(SocketAddr::new(IpAddr::V6(ip), port))
233            }
234            _ => Err(StunError::UnsupportedAddressFamily),
235        }
236    }
237
238    pub fn decode_mapped_address(&self) -> Result<SocketAddr, StunError> {
239        if self.attr_type != MAPPED_ADDRESS {
240            return Err(StunError::WrongAttributeType);
241        }
242
243        if self.value.len() < 4 {
244            return Err(StunError::InvalidLength);
245        }
246
247        let family = self.value[1];
248        let port = u16::from_be_bytes([self.value[2], self.value[3]]);
249
250        match family {
251            FAMILY_IPV4 => {
252                if self.value.len() < 8 {
253                    return Err(StunError::InvalidLength);
254                }
255
256                let addr = u32::from_be_bytes([
257                    self.value[4],
258                    self.value[5],
259                    self.value[6],
260                    self.value[7],
261                ]);
262
263                let ip = Ipv4Addr::from(addr);
264                Ok(SocketAddr::new(IpAddr::V4(ip), port))
265            }
266            FAMILY_IPV6 => {
267                if self.value.len() < 20 {
268                    return Err(StunError::InvalidLength);
269                }
270
271                let mut addr_bytes = [0u8; 16];
272                addr_bytes.copy_from_slice(&self.value[4..20]);
273
274                let ip = Ipv6Addr::from(addr_bytes);
275                Ok(SocketAddr::new(IpAddr::V6(ip), port))
276            }
277            _ => Err(StunError::UnsupportedAddressFamily),
278        }
279    }
280}
281
/// Errors produced while encoding, decoding or exchanging STUN messages.
#[derive(Debug, PartialEq)]
pub enum StunError {
    /// Fewer bytes than the fixed STUN header.
    MessageTooShort,
    /// Header bits that a STUN message must not have set.
    InvalidMessage,
    /// Header bytes 4..8 did not hold the magic cookie.
    InvalidMagicCookie,
    /// A declared length did not fit the available data.
    InvalidLength,
    /// Not enough bytes for an attribute header.
    AttributeTooShort,
    /// An attribute decoder was called on an attribute of another type.
    WrongAttributeType,
    /// An address family code other than IPv4 / IPv6.
    UnsupportedAddressFamily,
    /// A response carried no usable mapped address.
    NoMappedAddress,
    /// An I/O failure, stored as its rendered message.
    IoError(String),
}

impl std::fmt::Display for StunError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let text = match self {
            Self::MessageTooShort => "STUN message too short",
            Self::InvalidMessage => "Invalid STUN message",
            Self::InvalidMagicCookie => "Invalid magic cookie",
            Self::InvalidLength => "Invalid length",
            Self::AttributeTooShort => "Attribute too short",
            Self::WrongAttributeType => "Wrong attribute type",
            Self::UnsupportedAddressFamily => "Unsupported address family",
            Self::NoMappedAddress => "No mapped address in response",
            Self::IoError(detail) => return write!(f, "I/O error: {}", detail),
        };
        f.write_str(text)
    }
}

impl std::error::Error for StunError {}

impl From<std::io::Error> for StunError {
    /// Captures the I/O error's message; the original error value is not retained.
    fn from(source: std::io::Error) -> Self {
        Self::IoError(source.to_string())
    }
}
318
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an XOR-MAPPED-ADDRESS value for an IPv4 address, XORing as RFC 5389 specifies.
    fn xor_ipv4_value(ip: Ipv4Addr, port: u16) -> Vec<u8> {
        let x_port = port ^ (MAGIC_COOKIE >> 16) as u16;
        let x_addr = u32::from(ip) ^ MAGIC_COOKIE;

        let mut value = vec![0u8, FAMILY_IPV4];
        value.extend_from_slice(&x_port.to_be_bytes());
        value.extend_from_slice(&x_addr.to_be_bytes());
        value
    }

    #[test]
    fn test_binding_request_encode() {
        let msg = StunMessage::new_binding_request();
        let encoded = msg.encode();

        assert_eq!(encoded.len(), 20);
        assert_eq!(&encoded[0..2], &BINDING_REQUEST.to_be_bytes());
        assert_eq!(&encoded[2..4], &0u16.to_be_bytes());
        assert_eq!(&encoded[4..8], &MAGIC_COOKIE.to_be_bytes());
    }

    #[test]
    fn test_binding_request_decode() {
        let msg = StunMessage::new_binding_request();
        let encoded = msg.encode();
        let decoded = StunMessage::decode(&encoded).unwrap();

        assert_eq!(decoded.message_type, BINDING_REQUEST);
        assert_eq!(decoded.length, 0);
        assert_eq!(decoded.transaction_id, msg.transaction_id);
    }

    #[test]
    fn test_decode_rejects_short_input() {
        assert_eq!(
            StunMessage::decode(&[0u8; 19]).unwrap_err(),
            StunError::MessageTooShort
        );
    }

    #[test]
    fn test_decode_rejects_bad_magic_cookie() {
        let mut encoded = StunMessage::new_binding_request().encode();
        encoded[4] ^= 0xFF;
        assert_eq!(
            StunMessage::decode(&encoded).unwrap_err(),
            StunError::InvalidMagicCookie
        );
    }

    #[test]
    fn test_xor_mapped_address_ipv4() {
        let transaction_id = TransactionId([0; 12]);

        // 192.0.2.1:32768 (documentation address range).
        let test_ip = Ipv4Addr::new(192, 0, 2, 1);
        let test_port = 32768u16;
        let value = xor_ipv4_value(test_ip, test_port);

        let attr = Attribute {
            attr_type: XOR_MAPPED_ADDRESS,
            length: value.len() as u16,
            value,
        };

        let addr = attr.decode_xor_mapped_address(&transaction_id).unwrap();
        assert_eq!(addr.ip(), IpAddr::V4(test_ip));
        assert_eq!(addr.port(), test_port);
    }

    #[test]
    fn test_xor_mapped_address_ipv6() {
        let transaction_id = TransactionId([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]);
        let test_ip = Ipv6Addr::new(0x2001, 0x0db8, 0, 0, 0, 0, 0, 1);
        let test_port = 3478u16;

        // IPv6 is XORed with the magic cookie followed by the transaction ID.
        let mut key = [0u8; 16];
        key[..4].copy_from_slice(&MAGIC_COOKIE.to_be_bytes());
        key[4..].copy_from_slice(transaction_id.as_bytes());

        let x_port = test_port ^ (MAGIC_COOKIE >> 16) as u16;
        let mut value = vec![0u8, FAMILY_IPV6];
        value.extend_from_slice(&x_port.to_be_bytes());
        value.extend(test_ip.octets().iter().zip(key.iter()).map(|(a, k)| a ^ k));

        let attr = Attribute {
            attr_type: XOR_MAPPED_ADDRESS,
            length: value.len() as u16,
            value,
        };

        let addr = attr.decode_xor_mapped_address(&transaction_id).unwrap();
        assert_eq!(addr, SocketAddr::new(IpAddr::V6(test_ip), test_port));
    }

    #[test]
    fn test_mapped_address_ipv4() {
        let mut value = vec![0u8, FAMILY_IPV4];
        value.extend_from_slice(&3478u16.to_be_bytes());
        value.extend_from_slice(&[203, 0, 113, 7]);

        let attr = Attribute {
            attr_type: MAPPED_ADDRESS,
            length: value.len() as u16,
            value,
        };

        let addr = attr.decode_mapped_address().unwrap();
        assert_eq!(addr, SocketAddr::new(IpAddr::V4(Ipv4Addr::new(203, 0, 113, 7)), 3478));
    }

    #[test]
    fn test_wrong_attribute_type_is_rejected() {
        let attr = Attribute {
            attr_type: MAPPED_ADDRESS,
            length: 8,
            value: xor_ipv4_value(Ipv4Addr::new(192, 0, 2, 1), 1),
        };
        assert_eq!(
            attr.decode_xor_mapped_address(&TransactionId([0; 12]))
                .unwrap_err(),
            StunError::WrongAttributeType
        );
    }

    #[test]
    fn test_attribute_encode_pads_to_four_bytes() {
        let attr = Attribute {
            attr_type: 0x8022,
            length: 5,
            value: b"hello".to_vec(),
        };
        let encoded = attr.encode();

        assert_eq!(encoded.len(), 12);
        assert_eq!(&encoded[0..2], &0x8022u16.to_be_bytes());
        assert_eq!(&encoded[2..4], &5u16.to_be_bytes());
        assert_eq!(&encoded[4..9], b"hello");
        assert_eq!(&encoded[9..], &[0, 0, 0]);
    }

    #[test]
    fn test_message_with_attribute_roundtrip() {
        let test_ip = Ipv4Addr::new(192, 0, 2, 1);
        let test_port = 32768u16;
        let value = xor_ipv4_value(test_ip, test_port);

        let mut msg = StunMessage::new_binding_request();
        msg.attributes.push(Attribute {
            attr_type: XOR_MAPPED_ADDRESS,
            length: value.len() as u16,
            value,
        });
        // 4-byte attribute header + 8-byte value.
        msg.length = 12;

        let encoded = msg.encode();
        assert_eq!(encoded.len(), 32);

        let decoded = StunMessage::decode(&encoded).unwrap();
        assert_eq!(decoded.length, 12);
        assert_eq!(decoded.attributes.len(), 1);
        let addr = decoded.attributes[0]
            .decode_xor_mapped_address(&decoded.transaction_id)
            .unwrap();
        assert_eq!(addr, SocketAddr::new(IpAddr::V4(test_ip), test_port));
    }
}