1use std::collections::HashMap;
3use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr};
4
5use rand::{thread_rng, Rng};
6
7use super::error::*;
8
/// Fixed value of the magic-cookie field present in every STUN header.
pub const MAGIC_COOKIE: u32 = 0x2112_A442;

/// Method value of a Binding transaction.
pub const METHOD_BINDING: u16 = 0x0001;

// The message class lives in two bits of the message type: C0 is bit 4, C1 is bit 8.
/// Class bits of a request (C1 = 0, C0 = 0).
pub const CLASS_REQUEST: u16 = 0;
/// Class bits of an indication (C1 = 0, C0 = 1).
pub const CLASS_INDICATION: u16 = 1 << 4;
/// Class bits of a success response (C1 = 1, C0 = 0).
pub const CLASS_SUCCESS_RESPONSE: u16 = 1 << 8;
/// Class bits of an error response (C1 = 1, C0 = 1).
pub const CLASS_ERROR_RESPONSE: u16 = (1 << 8) | (1 << 4);

/// Header size: message type (2) + message length (2) + magic cookie (4) + transaction ID (12).
pub const HEADER_BYTE_SIZE: usize = 2 + 2 + 4 + 12;

// Attribute types from the base STUN specification.
/// MAPPED-ADDRESS attribute type.
pub const ATTR_MAPPED_ADDRESS: u16 = 0x0001;
/// XOR-MAPPED-ADDRESS attribute type.
pub const ATTR_XOR_MAPPED_ADDRESS: u16 = 0x0020;
/// ERROR-CODE attribute type.
pub const ATTR_ERROR_CODE: u16 = 0x0009;
/// SOFTWARE attribute type.
pub const ATTR_SOFTWARE: u16 = 0x8022;

// Attribute types used for NAT behaviour discovery.
/// OTHER-ADDRESS attribute type.
pub const ATTR_OTHER_ADDRESS: u16 = 0x802C;
/// CHANGE-REQUEST attribute type.
pub const ATTR_CHANGE_REQUEST: u16 = 0x0003;
/// RESPONSE-ORIGIN attribute type.
pub const ATTR_RESPONSE_ORIGIN: u16 = 0x802B;

/// CHANGE-REQUEST flag asking the server to answer from a different IP address.
pub const CHANGE_REQUEST_IP_FLAG: u32 = 1 << 2;
/// CHANGE-REQUEST flag asking the server to answer from a different port.
pub const CHANGE_REQUEST_PORT_FLAG: u32 = 1 << 1;

/// Address-family byte of an IPv4 address attribute.
pub const FAMILY_IPV4: u8 = 1;
/// Address-family byte of an IPv6 address attribute.
pub const FAMILY_IPV6: u8 = 2;
54
55#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
57pub enum Method {
58 Binding,
59 Unknown(u16),
60}
61
62impl Method {
63 pub fn from_u16(method: u16) -> Self {
65 match method {
66 METHOD_BINDING => Self::Binding,
67 _ => Self::Unknown(method),
68 }
69 }
70
71 pub fn to_u16(&self) -> u16 {
73 match self {
74 Self::Binding => METHOD_BINDING,
75 Self::Unknown(method) => method.clone(),
76 }
77 }
78}
79
80#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
82pub enum Class {
83 Request,
84 Indication,
85 SuccessResponse,
86 ErrorResponse,
87 Unknown(u16),
88}
89
90impl Class {
91 pub fn from_u16(class: u16) -> Self {
93 match class {
94 CLASS_REQUEST => Self::Request,
95 CLASS_INDICATION => Self::Indication,
96 CLASS_SUCCESS_RESPONSE => Self::SuccessResponse,
97 CLASS_ERROR_RESPONSE => Self::ErrorResponse,
98 _ => Self::Unknown(class),
99 }
100 }
101
102 pub fn to_u16(&self) -> u16 {
104 match self {
105 Self::Request => CLASS_REQUEST,
106 Self::Indication => CLASS_INDICATION,
107 Self::SuccessResponse => CLASS_SUCCESS_RESPONSE,
108 Self::ErrorResponse => CLASS_ERROR_RESPONSE,
109 Self::Unknown(class) => class.clone(),
110 }
111 }
112}
113
114#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
116pub enum Attribute {
117 MappedAddress,
118 XORMappedAddress,
119 Software,
120 OtherAddress,
121 ChangeRequest,
122 ResponseOrigin,
123 ErrorCode,
124 Unknown(u16),
125}
126
127impl Attribute {
128 pub fn from_u16(attribute: u16) -> Self {
130 match attribute {
131 ATTR_MAPPED_ADDRESS => Self::MappedAddress,
132 ATTR_XOR_MAPPED_ADDRESS => Self::XORMappedAddress,
133 ATTR_SOFTWARE => Self::Software,
134 ATTR_OTHER_ADDRESS => Self::OtherAddress,
135 ATTR_CHANGE_REQUEST => Self::ChangeRequest,
136 ATTR_RESPONSE_ORIGIN => Self::ResponseOrigin,
137 ATTR_ERROR_CODE => Self::ErrorCode,
138 _ => Self::Unknown(attribute),
139 }
140 }
141
142 pub fn to_u16(&self) -> u16 {
144 match self {
145 Self::MappedAddress => ATTR_MAPPED_ADDRESS,
146 Self::XORMappedAddress => ATTR_XOR_MAPPED_ADDRESS,
147 Self::Software => ATTR_SOFTWARE,
148 Self::OtherAddress => ATTR_OTHER_ADDRESS,
149 Self::ChangeRequest => ATTR_CHANGE_REQUEST,
150 Self::ResponseOrigin => ATTR_RESPONSE_ORIGIN,
151 Self::ErrorCode => ATTR_ERROR_CODE,
152 Self::Unknown(attribute) => attribute.clone(),
153 }
154 }
155
156 pub fn get_mapped_address(message: &Message) -> Option<SocketAddr> {
158 Self::decode_simple_address_attribute(message, Self::MappedAddress)
159 }
160
161 pub fn get_xor_mapped_address(message: &Message) -> Option<SocketAddr> {
163 let attr_value = message.get_raw_attr_value(Self::XORMappedAddress)?;
164 let family = attr_value[1];
165 let mc_bytes = MAGIC_COOKIE.to_be_bytes();
167 let port = u16::from_be_bytes([attr_value[2], attr_value[3]])
168 ^ u16::from_be_bytes([mc_bytes[0], mc_bytes[1]]);
169 match family {
170 FAMILY_IPV4 => {
171 let encoded_ip = &attr_value[4..];
173 let b: Vec<u8> = encoded_ip
174 .iter()
175 .zip(&MAGIC_COOKIE.to_be_bytes())
176 .map(|(b, m)| b ^ m)
177 .collect();
178 let ip_addr = bytes_to_ip_addr(family, b)?;
179 Some(SocketAddr::new(ip_addr, port))
180 }
181 FAMILY_IPV6 => {
182 let encoded_ip = &attr_value[4..];
184 let mut mc_ti: Vec<u8> = vec![];
185 mc_ti.extend(&MAGIC_COOKIE.to_be_bytes());
186 mc_ti.extend(&message.header.transaction_id);
187 let b: Vec<u8> = encoded_ip.iter().zip(&mc_ti).map(|(b, m)| b ^ m).collect();
188 let ip_addr = bytes_to_ip_addr(family, b)?;
189 Some(SocketAddr::new(ip_addr, port))
190 }
191 _ => None,
192 }
193 }
194
195 pub fn get_software(message: &Message) -> Option<String> {
197 let attr_value = message.get_raw_attr_value(Self::Software)?;
198 String::from_utf8(attr_value).ok()
199 }
200
201 pub fn get_error_code(message: &Message) -> Option<ErrorCode> {
203 let attr_value = message.get_raw_attr_value(Self::ErrorCode)?;
204 let class = (attr_value[2] as u16) * 100;
205 let number = attr_value[3] as u16;
206 let code = class + number;
207 let reason = String::from_utf8(attr_value[4..].to_vec())
208 .unwrap_or(String::from("cannot parse error reason"));
209 Some(ErrorCode::from(code, reason))
210 }
211
212 pub fn get_other_address(message: &Message) -> Option<SocketAddr> {
214 Self::decode_simple_address_attribute(message, Self::OtherAddress)
217 }
218
219 pub fn get_response_origin(message: &Message) -> Option<SocketAddr> {
221 Self::decode_simple_address_attribute(message, Self::ResponseOrigin)
222 }
223
224 pub fn generate_change_request_value(change_ip: bool, change_port: bool) -> Vec<u8> {
226 let mut value: u32 = 0;
227 if change_ip {
228 value |= CHANGE_REQUEST_IP_FLAG;
229 }
230
231 if change_port {
232 value |= CHANGE_REQUEST_PORT_FLAG;
233 }
234
235 value.to_be_bytes().to_vec()
236 }
237
238 pub fn decode_simple_address_attribute(message: &Message, attr: Self) -> Option<SocketAddr> {
239 let attr_value = message.get_raw_attr_value(attr)?;
240 let family = attr_value[1];
241 let port = u16::from_be_bytes([attr_value[2], attr_value[3]]);
242 let ip_addr = bytes_to_ip_addr(family, attr_value[4..].to_vec())?;
243 Some(SocketAddr::new(ip_addr, port))
244 }
245}
246
247#[derive(Debug, Eq, PartialEq)]
249pub struct Message {
250 header: Header,
251 attributes: Option<HashMap<Attribute, Vec<u8>>>,
252}
253
254impl Message {
255 pub fn new(
257 method: Method,
258 class: Class,
259 attributes: Option<HashMap<Attribute, Vec<u8>>>,
260 ) -> Message {
261 let attr_type_byte_size = 2;
262 let attr_length_byte_size = 2;
263 let length: u16 = if let Some(attributes) = &attributes {
264 attributes
265 .iter()
266 .map(|e| attr_type_byte_size + attr_length_byte_size + e.1.len() as u16)
267 .sum()
268 } else {
269 0
270 };
271
272 let transaction_id: Vec<u8> = thread_rng().gen::<[u8; 12]>().to_vec();
273
274 Message {
275 header: Header::new(method, class, length, transaction_id),
276 attributes: attributes,
277 }
278 }
279
280 pub fn from_raw(buf: &[u8]) -> Result<Message, STUNClientError> {
282 if buf.len() < HEADER_BYTE_SIZE {
283 return Err(STUNClientError::ParseError());
284 }
285
286 let header = Header::from_raw(&buf[..HEADER_BYTE_SIZE])?;
287 let mut attrs = None;
288 if buf.len() > HEADER_BYTE_SIZE {
289 attrs = Some(Message::decode_attrs(&buf[HEADER_BYTE_SIZE..])?);
290 }
291
292 Ok(Message {
293 header: header,
294 attributes: attrs,
295 })
296 }
297
298 pub fn to_raw(&self) -> Vec<u8> {
300 let mut bytes = self.header.to_raw();
301 if let Some(attributes) = &self.attributes {
302 for (k, v) in attributes.iter() {
303 bytes.extend(&k.to_u16().to_be_bytes());
304 bytes.extend(&(v.len() as u16).to_be_bytes());
305 bytes.extend(v);
306 }
307 }
308
309 bytes
310 }
311
312 pub fn get_method(&self) -> Method {
314 self.header.method
315 }
316
317 pub fn get_class(&self) -> Class {
319 self.header.class
320 }
321
322 pub fn get_raw_attr_value(&self, attr: Attribute) -> Option<Vec<u8>> {
324 self.attributes
325 .as_ref()?
326 .get(&attr)
327 .and_then(|v| Some(v.clone()))
328 }
329
330 pub fn get_transaction_id(&self) -> Vec<u8> {
332 self.header.transaction_id.clone()
333 }
334
335 fn decode_attrs(attrs_buf: &[u8]) -> Result<HashMap<Attribute, Vec<u8>>, STUNClientError> {
336 let mut attrs_buf = attrs_buf.to_vec();
337 let mut attributes = HashMap::new();
338
339 if attrs_buf.is_empty() {
340 return Err(STUNClientError::ParseError());
341 }
342
343 while !attrs_buf.is_empty() {
344 if attrs_buf.len() < 4 {
345 break;
346 }
347
348 let attribute_type = Attribute::from_u16(u16::from_be_bytes([
349 attrs_buf.remove(0),
350 attrs_buf.remove(0),
351 ]));
352 let length =
353 u16::from_be_bytes([attrs_buf.remove(0), attrs_buf.remove(0)]) as usize;
354 if attrs_buf.len() < length {
355 return Err(STUNClientError::ParseError());
356 }
357
358 let value: Vec<u8> = attrs_buf.drain(..length).collect();
359 attributes.insert(attribute_type, value);
360 }
361
362 Ok(attributes)
363 }
364}
365
366#[derive(Debug, Eq, PartialEq)]
368pub struct Header {
369 method: Method,
370 class: Class,
371 length: u16,
372 transaction_id: Vec<u8>,
373}
374
375impl Header {
376 pub fn new(method: Method, class: Class, length: u16, transaction_id: Vec<u8>) -> Header {
378 Header {
379 class: class,
380 method: method,
381 length: length,
382 transaction_id: transaction_id,
383 }
384 }
385
386 pub fn from_raw(buf: &[u8]) -> Result<Header, STUNClientError> {
388 let mut buf = buf.to_vec();
389 if buf.len() < HEADER_BYTE_SIZE {
390 return Err(STUNClientError::ParseError());
391 }
392
393 let message_type = u16::from_be_bytes([buf.remove(0), buf.remove(0)]);
394 let class = Header::decode_class(message_type);
395 let method = Header::decode_method(message_type);
396 let length = u16::from_be_bytes([buf.remove(0), buf.remove(0)]);
397
398 Ok(Header {
399 class: class,
400 method: method,
401 length: length,
402 transaction_id: buf[4..].to_vec(),
404 })
405 }
406
407 pub fn to_raw(&self) -> Vec<u8> {
409 let message_type = self.message_type();
410 let mut bytes = vec![];
411 bytes.extend(&message_type.to_be_bytes());
412 bytes.extend(&self.length.to_be_bytes());
413 bytes.extend(&MAGIC_COOKIE.to_be_bytes());
414 bytes.extend(&self.transaction_id);
415 bytes
416 }
417
418 fn message_type(&self) -> u16 {
419 self.class.to_u16() | self.method.to_u16()
420 }
421
422 fn decode_method(message_type: u16) -> Method {
423 Method::from_u16(message_type & 0x3EEF)
425 }
426
427 fn decode_class(message_type: u16) -> Class {
428 Class::from_u16(message_type & 0x0110)
430 }
431}
432
/// Converts raw address bytes into an `IpAddr` for the given address-family byte.
///
/// Only the first 4 (IPv4) or 16 (IPv6) bytes of `b` are used. Returns `None` for an unknown
/// family, or when `b` is too short. Previously a short `b` panicked on indexing, and `b` can come
/// from untrusted network data.
fn bytes_to_ip_addr(family: u8, b: Vec<u8>) -> Option<IpAddr> {
    match family {
        FAMILY_IPV4 => {
            let mut octets = [0u8; 4];
            octets.copy_from_slice(b.get(..4)?);
            Some(IpAddr::V4(Ipv4Addr::from(octets)))
        }
        FAMILY_IPV6 => {
            let mut octets = [0u8; 16];
            octets.copy_from_slice(b.get(..16)?);
            Some(IpAddr::V6(Ipv6Addr::from(octets)))
        }
        _ => None,
    }
}
443
/// Error reported by a STUN server in an ERROR-CODE attribute.
///
/// Every variant carries the server's reason phrase. Codes without a dedicated variant become
/// `Unknown`, which keeps only the reason phrase.
#[derive(Clone, Debug, Eq, Hash, PartialEq)]
pub enum ErrorCode {
    TryAlternate(String),
    BadRequest(String),
    Unauthorized(String),
    UnknownAttribute(String),
    StaleNonce(String),
    ServerError(String),
    Unknown(String),
}

impl ErrorCode {
    /// Classifies a numeric error `code` (e.g. 401 becomes `Unauthorized`) and attaches `reason`.
    pub fn from(code: u16, reason: String) -> Self {
        // Pick the variant constructor first, then apply it to the reason phrase once.
        let make: fn(String) -> Self = match code {
            300 => Self::TryAlternate,
            400 => Self::BadRequest,
            401 => Self::Unauthorized,
            420 => Self::UnknownAttribute,
            438 => Self::StaleNonce,
            500 => Self::ServerError,
            _ => Self::Unknown,
        };
        make(reason)
    }
}
469
#[cfg(test)]
mod tests {
    use super::*;

    /// Encoding a freshly built request and decoding the bytes again must yield an equal message.
    #[test]
    fn message_new_and_message_from_raw_are_equivalent() {
        let change_request = Attribute::generate_change_request_value(true, false);
        let attrs: HashMap<Attribute, Vec<u8>> = vec![(Attribute::ChangeRequest, change_request)]
            .into_iter()
            .collect();

        let original = Message::new(Method::Binding, Class::Request, Some(attrs));
        let encoded = original.to_raw();
        let decoded = Message::from_raw(&encoded).unwrap();

        assert_eq!(original, decoded);
    }
}
485}