1use crate::attribute::{
2 Attribute, LosslessAttribute, LosslessAttributeDecoder, LosslessAttributeEncoder, RawAttribute,
3};
4use crate::constants::MAGIC_COOKIE;
5use crate::convert::TryAsRef;
6use crate::{Method, TransactionId};
7use bytecodec::bytes::{BytesEncoder, CopyableBytesDecoder};
8use bytecodec::combinator::{Collect, Length, Peekable, PreEncode, Repeat};
9use bytecodec::fixnum::{U16beDecoder, U16beEncoder, U32beDecoder, U32beEncoder};
10use bytecodec::{ByteCount, Decode, Encode, Eos, Error, ErrorKind, Result, SizedEncode};
11use std::{fmt, vec};
12use trackable::error::ErrorKindExt;
13
/// Result of decoding a STUN message with [`MessageDecoder`].
///
/// `Ok` holds a fully decoded [`Message`]. `Err` holds a [`BrokenMessage`]: the header was
/// decoded, but decoding or post-processing of the attributes failed.
pub type DecodedMessage<A> = std::result::Result<Message<A>, BrokenMessage>;
16
/// The class of a STUN message (the two class bits of the message type field).
///
/// The discriminants are the 2-bit wire values. `Type::as_u16` encodes a class with an
/// `as u16` cast and `MessageClass::from_u8` decodes the same values, so the discriminants
/// are spelled out rather than left to declaration order.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum MessageClass {
    /// Request (`0b00`).
    Request = 0b00,
    /// Indication (`0b01`).
    Indication = 0b01,
    /// Success response (`0b10`).
    SuccessResponse = 0b10,
    /// Error response (`0b11`).
    ErrorResponse = 0b11,
}
impl MessageClass {
    /// Converts a 2-bit class value into a `MessageClass`.
    ///
    /// Returns `None` when `value` is greater than `0b11`.
    fn from_u8(value: u8) -> Option<Self> {
        match value {
            0b00 => Some(MessageClass::Request),
            0b01 => Some(MessageClass::Indication),
            0b10 => Some(MessageClass::SuccessResponse),
            0b11 => Some(MessageClass::ErrorResponse),
            _ => None,
        }
    }
}

impl fmt::Display for MessageClass {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let name = match self {
            MessageClass::Request => "request",
            MessageClass::Indication => "indication",
            MessageClass::SuccessResponse => "success response",
            MessageClass::ErrorResponse => "error response",
        };
        f.write_str(name)
    }
}
48
49#[derive(Debug, Clone)]
170pub struct Message<A> {
171 class: MessageClass,
172 method: Method,
173 transaction_id: TransactionId,
174 attributes: Vec<LosslessAttribute<A>>,
175}
176impl<A: Attribute> Message<A> {
177 pub fn new(class: MessageClass, method: Method, transaction_id: TransactionId) -> Self {
179 Message {
180 class,
181 method,
182 transaction_id,
183 attributes: Vec::new(),
184 }
185 }
186
187 pub fn class(&self) -> MessageClass {
189 self.class
190 }
191
192 pub fn method(&self) -> Method {
194 self.method
195 }
196
197 pub fn transaction_id(&self) -> TransactionId {
199 self.transaction_id
200 }
201
202 pub fn get_attribute<T>(&self) -> Option<&T>
206 where
207 T: Attribute,
208 A: TryAsRef<T>,
209 {
210 self.attributes().filter_map(|a| a.try_as_ref()).next()
211 }
212
213 pub fn attributes(&self) -> impl Iterator<Item = &A> {
215 self.attributes.iter().filter_map(|a| a.as_known())
216 }
217
218 pub fn unknown_attributes(&self) -> impl Iterator<Item = &RawAttribute> {
223 self.attributes.iter().filter_map(|a| a.as_unknown())
224 }
225
226 pub fn add_attribute(&mut self, attribute: impl Into<A>) {
228 self.attributes
229 .push(LosslessAttribute::new(attribute.into()));
230 }
231}
232
233#[allow(missing_docs)]
235#[derive(Debug, Clone)]
236pub struct BrokenMessage {
237 method: Method,
238 class: MessageClass,
239 transaction_id: TransactionId,
240 error: Error,
241}
242impl BrokenMessage {
243 pub fn class(&self) -> MessageClass {
245 self.class
246 }
247
248 pub fn method(&self) -> Method {
250 self.method
251 }
252
253 pub fn transaction_id(&self) -> TransactionId {
255 self.transaction_id
256 }
257
258 pub fn error(&self) -> &Error {
260 &self.error
261 }
262}
263impl From<BrokenMessage> for Error {
264 fn from(f: BrokenMessage) -> Self {
265 ErrorKind::InvalidInput.cause(format!("{f:?}")).into()
266 }
267}
268
269#[derive(Debug, Default)]
270struct MessageHeaderDecoder {
271 message_type: U16beDecoder,
272 message_len: U16beDecoder,
273 magic_cookie: U32beDecoder,
274 transaction_id: CopyableBytesDecoder<[u8; 12]>,
275}
276impl MessageHeaderDecoder {
277 fn check_magic_cookie(&self, magic_cookie: u32) -> Result<()> {
278 track_assert_eq!(
279 magic_cookie,
280 MAGIC_COOKIE,
281 ErrorKind::InvalidInput,
282 "Unexpected MAGIC_COOKIE: actual=0x{:08x}, expected=0x{:08x}",
283 magic_cookie,
284 MAGIC_COOKIE,
285 );
286 Ok(())
287 }
288}
289impl Decode for MessageHeaderDecoder {
290 type Item = (Type, u16, TransactionId);
291
292 fn decode(&mut self, buf: &[u8], eos: Eos) -> Result<usize> {
293 let mut offset = 0;
294 bytecodec_try_decode!(self.message_type, offset, buf, eos);
295 bytecodec_try_decode!(self.message_len, offset, buf, eos);
296 bytecodec_try_decode!(self.magic_cookie, offset, buf, eos);
297 bytecodec_try_decode!(self.transaction_id, offset, buf, eos);
298 Ok(offset)
299 }
300
301 fn finish_decoding(&mut self) -> Result<Self::Item> {
302 let message_type = track!(self.message_type.finish_decoding())?;
303 let message_type = track!(Type::from_u16(message_type))?;
304 let message_len = track!(self.message_len.finish_decoding())?;
305 let magic_cookie = track!(self.magic_cookie.finish_decoding())?;
306 let transaction_id = TransactionId::new(track!(self.transaction_id.finish_decoding())?);
307 track!(self.check_magic_cookie(magic_cookie); message_type, message_len, transaction_id)?;
308 Ok((message_type, message_len, transaction_id))
309 }
310
311 fn requiring_bytes(&self) -> ByteCount {
312 self.message_type
313 .requiring_bytes()
314 .add_for_decoding(self.message_len.requiring_bytes())
315 .add_for_decoding(self.magic_cookie.requiring_bytes())
316 .add_for_decoding(self.transaction_id.requiring_bytes())
317 }
318
319 fn is_idle(&self) -> bool {
320 self.transaction_id.is_idle()
321 }
322}
323
324#[derive(Debug)]
325struct AttributesDecoder<A: Attribute> {
326 inner: Collect<LosslessAttributeDecoder<A>, Vec<LosslessAttribute<A>>>,
327 last_error: Option<Error>,
328 is_eos: bool,
329}
330impl<A: Attribute> Default for AttributesDecoder<A> {
331 fn default() -> Self {
332 AttributesDecoder {
333 inner: Default::default(),
334 last_error: None,
335 is_eos: false,
336 }
337 }
338}
339impl<A: Attribute> Decode for AttributesDecoder<A> {
340 type Item = Vec<LosslessAttribute<A>>;
341
342 fn decode(&mut self, buf: &[u8], eos: Eos) -> Result<usize> {
343 if self.last_error.is_none() {
344 match track!(self.inner.decode(buf, eos)) {
345 Err(e) => {
346 self.last_error = Some(e);
347 }
348 Ok(size) => return Ok(size),
349 }
350 }
351
352 self.is_eos = eos.is_reached();
354 Ok(buf.len())
355 }
356
357 fn finish_decoding(&mut self) -> Result<Self::Item> {
358 self.is_eos = false;
359 if let Some(e) = self.last_error.take() {
360 return Err(track!(e));
361 }
362 track!(self.inner.finish_decoding())
363 }
364
365 fn requiring_bytes(&self) -> ByteCount {
366 if self.last_error.is_none() {
367 self.inner.requiring_bytes()
368 } else if self.is_eos {
369 ByteCount::Finite(0)
370 } else {
371 ByteCount::Unknown
372 }
373 }
374
375 fn is_idle(&self) -> bool {
376 if self.last_error.is_none() {
377 self.inner.is_idle()
378 } else {
379 self.is_eos
380 }
381 }
382}
383
384#[derive(Debug)]
386pub struct MessageDecoder<A: Attribute> {
387 header: Peekable<MessageHeaderDecoder>,
388 attributes: Length<AttributesDecoder<A>>,
389}
390impl<A: Attribute> MessageDecoder<A> {
391 pub fn new() -> Self {
393 Self::default()
394 }
395
396 fn finish_decoding_with_header(
397 &mut self,
398 method: Method,
399 class: MessageClass,
400 transaction_id: TransactionId,
401 ) -> Result<Message<A>> {
402 let attributes = track!(self.attributes.finish_decoding())?;
403 let mut message = Message {
404 class,
405 method,
406 transaction_id,
407 attributes,
408 };
409
410 let attributes_len = message.attributes.len();
411 for i in 0..attributes_len {
412 unsafe {
413 let message_mut = &mut *(&mut message as *mut Message<A>);
414 let attr = message_mut.attributes.get_unchecked_mut(i);
415 message.attributes.set_len(i);
416 let decode_result = track!(attr.after_decode(&message));
417 message.attributes.set_len(attributes_len);
418 decode_result?;
419 }
420 }
421 Ok(message)
422 }
423}
424impl<A: Attribute> Default for MessageDecoder<A> {
425 fn default() -> Self {
426 MessageDecoder {
427 header: Default::default(),
428 attributes: Default::default(),
429 }
430 }
431}
432impl<A: Attribute> Decode for MessageDecoder<A> {
433 type Item = DecodedMessage<A>;
434
435 fn decode(&mut self, buf: &[u8], eos: Eos) -> Result<usize> {
436 let mut offset = 0;
437 if !self.header.is_idle() {
438 bytecodec_try_decode!(self.header, offset, buf, eos);
439
440 let message_len = self.header.peek().expect("never fails").1;
441 track!(self.attributes.set_expected_bytes(u64::from(message_len)))?;
442 }
443 bytecodec_try_decode!(self.attributes, offset, buf, eos);
444 Ok(offset)
445 }
446
447 fn finish_decoding(&mut self) -> Result<Self::Item> {
448 let (Type { method, class }, _, transaction_id) = track!(self.header.finish_decoding())?;
449 match self.finish_decoding_with_header(method, class, transaction_id) {
450 Err(error) => Ok(Err(BrokenMessage {
451 method,
452 class,
453 transaction_id,
454 error,
455 })),
456 Ok(message) => Ok(Ok(message)),
457 }
458 }
459
460 fn requiring_bytes(&self) -> ByteCount {
461 self.header
462 .requiring_bytes()
463 .add_for_decoding(self.attributes.requiring_bytes())
464 }
465
466 fn is_idle(&self) -> bool {
467 self.header.is_idle() && self.attributes.is_idle()
468 }
469}
470
471#[derive(Debug)]
473pub struct MessageEncoder<A: Attribute> {
474 message_type: U16beEncoder,
475 message_len: U16beEncoder,
476 magic_cookie: U32beEncoder,
477 transaction_id: BytesEncoder<TransactionId>,
478 attributes: PreEncode<Repeat<LosslessAttributeEncoder<A>, vec::IntoIter<LosslessAttribute<A>>>>,
479}
480impl<A: Attribute> MessageEncoder<A> {
481 pub fn new() -> Self {
483 Self::default()
484 }
485}
486impl<A: Attribute> Default for MessageEncoder<A> {
487 fn default() -> Self {
488 MessageEncoder {
489 message_type: Default::default(),
490 message_len: Default::default(),
491 magic_cookie: Default::default(),
492 transaction_id: Default::default(),
493 attributes: Default::default(),
494 }
495 }
496}
497impl<A: Attribute> Encode for MessageEncoder<A> {
498 type Item = Message<A>;
499
500 fn encode(&mut self, buf: &mut [u8], eos: Eos) -> Result<usize> {
501 let mut offset = 0;
502 bytecodec_try_encode!(self.message_type, offset, buf, eos);
503 bytecodec_try_encode!(self.message_len, offset, buf, eos);
504 bytecodec_try_encode!(self.magic_cookie, offset, buf, eos);
505 bytecodec_try_encode!(self.transaction_id, offset, buf, eos);
506 bytecodec_try_encode!(self.attributes, offset, buf, eos);
507 Ok(offset)
508 }
509
510 fn start_encoding(&mut self, mut item: Self::Item) -> Result<()> {
511 let attributes_len = item.attributes.len();
512 for i in 0..attributes_len {
513 unsafe {
514 let item_mut = &mut *(&mut item as *mut Message<A>);
515 let attr = item_mut.attributes.get_unchecked_mut(i);
516 item.attributes.set_len(i);
517 let encode_result = track!(attr.before_encode(&item));
518 item.attributes.set_len(attributes_len);
519 encode_result?;
520 }
521 }
522
523 let message_type = Type {
524 class: item.class,
525 method: item.method,
526 };
527 track!(self.message_type.start_encoding(message_type.as_u16()))?;
528 track!(self.magic_cookie.start_encoding(MAGIC_COOKIE))?;
529 track!(self.transaction_id.start_encoding(item.transaction_id))?;
530 track!(self.attributes.start_encoding(item.attributes.into_iter()))?;
531
532 let message_len = self.attributes.exact_requiring_bytes();
533 track_assert!(
534 message_len < 0x10000,
535 ErrorKind::InvalidInput,
536 "Too large message length: actual={}, limit=0xFFFF",
537 message_len
538 );
539 track!(self.message_len.start_encoding(message_len as u16))?;
540 Ok(())
541 }
542
543 fn requiring_bytes(&self) -> ByteCount {
544 ByteCount::Finite(self.exact_requiring_bytes())
545 }
546
547 fn is_idle(&self) -> bool {
548 self.transaction_id.is_idle() && self.attributes.is_idle()
549 }
550}
551impl<A: Attribute> SizedEncode for MessageEncoder<A> {
552 fn exact_requiring_bytes(&self) -> u64 {
553 self.message_type.exact_requiring_bytes()
554 + self.message_len.exact_requiring_bytes()
555 + self.magic_cookie.exact_requiring_bytes()
556 + self.transaction_id.exact_requiring_bytes()
557 + self.attributes.exact_requiring_bytes()
558 }
559}
560
561#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
562struct Type {
563 class: MessageClass,
564 method: Method,
565}
566impl Type {
567 fn as_u16(self) -> u16 {
568 let class = self.class as u16;
569 let method = self.method.as_u16();
570 (method & 0b0000_0000_1111)
571 | ((class & 0b01) << 4)
572 | ((method & 0b0000_0111_0000) << 5)
573 | ((class & 0b10) << 7)
574 | ((method & 0b1111_1000_0000) << 9)
575 }
576
577 fn from_u16(value: u16) -> Result<Self> {
578 track_assert!(
579 value >> 14 == 0,
580 ErrorKind::InvalidInput,
581 "First two-bits of STUN message must be 0"
582 );
583 let class = ((value >> 4) & 0b01) | ((value >> 7) & 0b10);
584 let class = MessageClass::from_u8(class as u8).unwrap();
585 let method = (value & 0b0000_0000_1111)
586 | ((value >> 1) & 0b0000_0111_0000)
587 | ((value >> 2) & 0b1111_1000_0000);
588 let method = Method(method);
589 Ok(Type { class, method })
590 }
591}
592
#[cfg(test)]
mod tests {
    use super::*;
    use crate::rfc5389::attributes::MappedAddress;
    use crate::rfc5389::methods::BINDING;
    use crate::{MessageClass, TransactionId};
    use bytecodec::DecodeExt;
    use trackable::result::TestResult;

    #[test]
    fn message_class_from_u8_works() {
        // A valid 2-bit value maps to a class; anything larger is rejected.
        assert_eq!(MessageClass::from_u8(0), Some(MessageClass::Request));
        assert_eq!(MessageClass::from_u8(9), None);
    }

    #[test]
    fn decoder_fails_when_decoding_attributes() -> TestResult {
        // Header (20 bytes): type 0x0001, length 12, magic cookie 0x2112A442
        // (33, 18, 164, 66), and a transaction ID of twelve 3s.
        // Body (11 bytes, one short of the declared length 12): attribute header with
        // type 0x0001 and value length 8, followed by only 7 value bytes. These are
        // presumably reserved/family/port plus a truncated IPv4 address.
        let bytes = [
            0, 1, 0, 12, 33, 18, 164, 66, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 0, 1, 0, 8, 0, 1, 0,
            80, 127, 0, 1,
        ];

        // The header is intact, so the decoder reports a `BrokenMessage` (not a hard
        // error) that still carries the header fields.
        let mut decoder = MessageDecoder::<MappedAddress>::new();
        let broken_message = decoder.decode_from_bytes(&bytes)?.err().unwrap();
        assert_eq!(broken_message.method, BINDING);
        assert_eq!(broken_message.class, MessageClass::Request);
        assert_eq!(broken_message.transaction_id, TransactionId::new([3; 12]));

        Ok(())
    }
}