use crate::attributes::{StunAttribute, StunAttributeType};
use crate::common::check_buffer_boundaries;
use crate::error::{StunError, StunErrorType};
use crate::{Encode, TransactionId};
use byteorder::{BigEndian, ByteOrder};
use std::convert::{TryFrom, TryInto};
/// The type of a STUN message: a [`MessageMethod`] paired with a
/// [`MessageClass`].
///
/// The pair is packed into a single `u16` by [`MessageType::as_u16`] and
/// unpacked again by the `From<u16>` implementation.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct MessageType {
/// Method component of the message type.
method: MessageMethod,
/// Class component of the message type.
class: MessageClass,
}
impl MessageType {
    /// Creates a message type from its `method` and `class` components.
    pub fn new(method: MessageMethod, class: MessageClass) -> Self {
        Self { method, class }
    }

    /// Returns the class component of this message type.
    pub fn class(&self) -> MessageClass {
        self.class
    }

    /// Returns the method component of this message type.
    pub fn method(&self) -> MessageMethod {
        self.method
    }

    /// Packs the method and class into the 14-bit message type value.
    ///
    /// Bit layout of the result (bit 0 is the least significant bit):
    ///
    /// * bits 0-3: method bits 0-3
    /// * bit 4: class bit 0
    /// * bits 5-7: method bits 4-6
    /// * bit 8: class bit 1
    /// * bits 9-13: method bits 7-11
    ///
    /// Bits 14 and 15 of the result are always zero. Only the 12 low bits of
    /// the method and the 2 low bits of the class contribute, so this is the
    /// exact inverse of `MessageType::from(u16)`.
    pub fn as_u16(&self) -> u16 {
        let method = self.method.as_u16();
        let class = self.class.as_u16();

        // The high method group is bits 7-11 (mask 0x0F80). A wider mask
        // would let an out-of-range bit 12 spill into bit 14 of the result.
        ((method & 0x0F80) << 2)
            | ((method & 0x0070) << 1)
            | (method & 0x000F)
            | ((class & 0x0002) << 7)
            | ((class & 0x0001) << 4)
    }
}
impl From<u16> for MessageType {
    /// Unpacks a message type from its 16-bit representation.
    ///
    /// The two most significant bits of `value` are ignored; the remaining
    /// 14 bits are split into the class (bits 4 and 8) and the method (all
    /// other bits).
    fn from(value: u16) -> Self {
        let bits = value & 0x3FFF;

        // The class is two bits wide, so the narrowing and the class lookup
        // below cannot fail.
        let class_bits = ((bits & 0x0100) >> 7) | ((bits & 0x0010) >> 4);
        let class_raw = u8::try_from(class_bits).unwrap();
        let class = MessageClass::try_from(class_raw).unwrap();

        // The three method groups are re-joined into at most 12 bits, which
        // `MessageMethod::try_from` always accepts.
        let high = (bits & 0x3E00) >> 2;
        let middle = (bits & 0x00E0) >> 1;
        let low = bits & 0x000F;
        let method = MessageMethod::try_from(high | middle | low).unwrap();

        Self::new(method, class)
    }
}
impl From<&[u8; 2]> for MessageType {
    /// Reads the two bytes as a big-endian `u16` and decodes that value as a
    /// message type.
    fn from(value: &[u8; 2]) -> Self {
        let raw = BigEndian::read_u16(value);
        raw.into()
    }
}
impl Encode for MessageType {
    /// Writes the message type into the first two bytes of `buffer` as a
    /// big-endian `u16` and returns the number of bytes written (always 2).
    ///
    /// # Errors
    ///
    /// Propagates the error returned by `check_buffer_boundaries` when it
    /// rejects `buffer` for a 2-byte write.
    fn encode(&self, buffer: &mut [u8]) -> Result<usize, StunError> {
        let wire_value = self.as_u16();
        // Validate the buffer before writing to it.
        check_buffer_boundaries(buffer, 2)?;
        BigEndian::write_u16(buffer, wire_value);
        Ok(2)
    }
}
/// A STUN method number, stored as a raw `u16`.
///
/// The `TryFrom<u16>` implementation only accepts values whose four most
/// significant bits are zero (12-bit values). The inner field is
/// crate-visible, so code inside the crate can construct a value directly
/// without that check.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub struct MessageMethod(pub(crate) u16);
impl MessageMethod {
pub fn as_u16(&self) -> u16 {
self.0
}
pub fn is_valid(&self) -> bool {
(0x00..=0xff).contains(&self.0)
}
}
impl TryFrom<u16> for MessageMethod {
    type Error = StunError;

    /// Creates a method from a raw value.
    ///
    /// # Errors
    ///
    /// Returns an `InvalidParam` error when any of the four most significant
    /// bits of `value` is set, i.e. when the value does not fit in 12 bits.
    fn try_from(value: u16) -> Result<Self, Self::Error> {
        if (value & 0xF000) == 0 {
            Ok(MessageMethod(value))
        } else {
            Err(StunError::new(
                StunErrorType::InvalidParam,
                format!("Value '{:#02x}' is not a valid MessageMethod", value),
            ))
        }
    }
}
/// The class of a STUN message.
///
/// Each variant corresponds to a two-bit value: see `MessageClass::as_u16`
/// and the `TryFrom<u8>` implementation.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum MessageClass {
/// Request, encoded as `0b00`.
Request,
/// Indication, encoded as `0b01`.
Indication,
/// Success response, encoded as `0b10`.
SuccessResponse,
/// Error response, encoded as `0b11`.
ErrorResponse,
}
impl MessageClass {
fn as_u16(&self) -> u16 {
match self {
MessageClass::Request => 0b00,
MessageClass::Indication => 0b01,
MessageClass::SuccessResponse => 0b10,
MessageClass::ErrorResponse => 0b11,
}
}
}
impl TryFrom<u8> for MessageClass {
    type Error = StunError;

    /// Converts a two-bit class value into a [`MessageClass`].
    ///
    /// # Errors
    ///
    /// Returns an `InvalidParam` error for any value greater than `0b11`.
    fn try_from(value: u8) -> Result<Self, Self::Error> {
        match value {
            0b00 => Ok(MessageClass::Request),
            0b01 => Ok(MessageClass::Indication),
            0b10 => Ok(MessageClass::SuccessResponse),
            0b11 => Ok(MessageClass::ErrorResponse),
            _ => Err(StunError::new(
                StunErrorType::InvalidParam,
                format!("Value '{:#02x}' is not a valid MessageClass", value),
            )),
        }
    }
}
/// Values collected by [`StunMessageBuilder`] before the message is built.
#[derive(Debug)]
struct StunMessageParameters {
/// Method of the message under construction.
method: MessageMethod,
/// Class of the message under construction.
class: MessageClass,
/// Transaction id, if one was set; `build` falls back to
/// `TransactionId::default()` when this is `None`.
transaction_id: Option<TransactionId>,
/// Attributes in the order they were added.
attributes: Vec<StunAttribute>,
}
/// Builder for [`StunMessage`].
///
/// Wraps the parameters gathered so far; `build` consumes the builder and
/// produces the message.
#[derive(Debug)]
pub struct StunMessageBuilder(StunMessageParameters);
impl StunMessageBuilder {
pub fn new(method: MessageMethod, class: MessageClass) -> StunMessageBuilder {
Self(StunMessageParameters {
method,
class,
transaction_id: None,
attributes: Vec::new(),
})
}
pub fn with_transaction_id(mut self, transaction_id: TransactionId) -> Self {
self.0.transaction_id = Some(transaction_id);
self
}
pub fn with_attribute<T>(mut self, attribute: T) -> Self
where
T: Into<StunAttribute>,
{
self.0.attributes.push(attribute.into());
self
}
pub fn build(self) -> StunMessage {
StunMessage {
method: self.0.method,
class: self.0.class,
transaction_id: self.0.transaction_id.unwrap_or_default(),
attributes: self.0.attributes,
}
}
}
/// A STUN message: method, class, transaction id and a list of attributes.
///
/// Instances are produced by `StunMessageBuilder::build`.
#[derive(Debug)]
pub struct StunMessage {
/// Method of the message.
method: MessageMethod,
/// Class of the message.
class: MessageClass,
/// Transaction id of the message.
transaction_id: TransactionId,
/// Attributes, in the order they were added to the builder.
attributes: Vec<StunAttribute>,
}
impl StunMessage {
pub fn method(&self) -> MessageMethod {
self.method
}
pub fn class(&self) -> MessageClass {
self.class
}
pub fn transaction_id(&self) -> &TransactionId {
&self.transaction_id
}
pub fn attributes(&self) -> &[StunAttribute] {
&self.attributes
}
pub fn get<A>(&self) -> Option<&StunAttribute>
where
A: StunAttributeType,
{
self.attributes
.iter()
.find(|&attr| attr.attribute_type() == A::get_type())
}
}
#[cfg(test)]
mod tests {
    use crate::{message::*, methods::BINDING};

    /// Method value used by the encode/decode tests. It has bits set in each
    /// of the three method bit groups (bits 0-3, 4-6 and 7-11).
    const SAMPLE_METHOD: u16 = 0x08D8;

    /// Every two-bit value maps to a class and back; larger values are rejected.
    #[test]
    fn message_class() {
        for raw in 0u8..=3 {
            let cls = MessageClass::try_from(raw).expect("Can not create MessageClass");
            assert_eq!(cls.as_u16(), u16::from(raw));
        }
        MessageClass::try_from(4u8).expect_err("MessageClass should not be created");
    }

    /// 12-bit values are accepted unchanged; the first 13-bit value is rejected.
    #[test]
    fn message_method() {
        for &raw in &[0x0000u16, 0x0001, 0x0FFF] {
            let m = MessageMethod::try_from(raw).expect("Can not create MessageMethod");
            assert_eq!(m.as_u16(), raw);
        }
        MessageMethod::try_from(0x1000u16).expect_err("MessageMethod should not be created");
    }

    /// Accessors return the components and the simplest type encodes to `0x0001`.
    #[test]
    fn message_type() {
        let cls = MessageClass::Request;
        let method = MessageMethod::try_from(0x0001u16).expect("Can not create MessageMethod");
        let msg_type = MessageType::new(method, cls);
        assert_eq!(msg_type.class(), cls);
        assert_eq!(msg_type.method(), method);

        let mut buffer: [u8; 2] = [0; 2];
        assert_eq!(msg_type.encode(&mut buffer), Ok(2));
        assert_eq!(buffer, [0x00, 0x01]);
    }

    /// Encoding interleaves the class bits with the method bits.
    #[test]
    fn encode_message_type() {
        let method =
            MessageMethod::try_from(SAMPLE_METHOD).expect("Can not create MessageMethod");
        let cases: [(MessageClass, [u8; 2]); 4] = [
            (MessageClass::Request, [0x22, 0xA8]),
            (MessageClass::Indication, [0x22, 0xB8]),
            (MessageClass::SuccessResponse, [0x23, 0xA8]),
            (MessageClass::ErrorResponse, [0x23, 0xB8]),
        ];

        for &(cls, expected) in &cases {
            let msg_type = MessageType::new(method, cls);
            let mut buffer: [u8; 2] = [0; 2];
            assert_eq!(msg_type.encode(&mut buffer), Ok(2));
            assert_eq!(buffer, expected);
        }

        // A buffer shorter than two bytes must be rejected.
        let msg_type = MessageType::new(method, MessageClass::ErrorResponse);
        let mut buffer: [u8; 1] = [0; 1];
        assert_eq!(
            msg_type.encode(&mut buffer).expect_err("Error expected"),
            StunErrorType::SmallBuffer
        );
    }

    /// Decoding recovers the class and the method from the wire bytes.
    #[test]
    fn message_type_from() {
        let method =
            MessageMethod::try_from(SAMPLE_METHOD).expect("Can not create MessageMethod");
        let cases: [([u8; 2], MessageClass); 5] = [
            ([0x22, 0xA8], MessageClass::Request),
            ([0x22, 0xB8], MessageClass::Indication),
            ([0x23, 0xA8], MessageClass::SuccessResponse),
            ([0x23, 0xB8], MessageClass::ErrorResponse),
            // Same as the previous case with the two most significant bits
            // set: they must be ignored by the decoder.
            ([0xE3, 0xB8], MessageClass::ErrorResponse),
        ];

        for &(buffer, cls) in &cases {
            let msg_type = MessageType::from(&buffer);
            assert_eq!(msg_type.class(), cls);
            assert_eq!(msg_type.method(), method);
        }
    }

    /// The `Debug` implementations can be formatted without panicking.
    #[test]
    fn fmt() {
        let cls = MessageClass::Request;
        let method = MessageMethod::try_from(0x0001u16).expect("Can not create MessageMethod");
        let msg_type = MessageType::new(method, cls);
        let _val = format!("{:?}", msg_type);

        let builder = StunMessageBuilder::new(BINDING, MessageClass::Request);
        let _val = format!("{:?}", builder);

        let msg = builder.build();
        let _val = format!("{:?}", msg);
    }
}