use sawp::error::Result;
use sawp::parser::{Direction, Parse};
use sawp::probe::Probe;
use sawp::protocol::Protocol;
use nom::bytes::streaming::tag;
use nom::bytes::streaming::take;
use nom::combinator;
use nom::error::ErrorKind;
use nom::multi::many0;
use nom::number::streaming::{be_u24, be_u32, be_u8};
use nom::IResult;
/// Parser for the Diameter protocol.
///
/// The struct carries no fields: all parsing state lives in the input slice,
/// so a single value can be reused for any number of `parse`/`probe` calls
/// (both take `&self`).
#[derive(Debug)]
pub struct Diameter {}
/// Fixed 20-byte Diameter message header, as decoded by `Header::parse`.
#[derive(Debug, PartialEq)]
pub struct Header {
    /// Protocol version byte (`Header::parse` only accepts 1).
    version: u8,
    /// Message Length field: total message size in bytes, header included.
    length: u32,
    /// Command flags byte (see the `*_FLAG` / `RESERVED_MASK` constants).
    flags: u8,
    /// Command code, read as a 24-bit big-endian value.
    code: u32,
    /// Application ID.
    app_id: u32,
    /// Hop-by-hop identifier.
    hop_id: u32,
    /// End-to-end identifier.
    end_id: u32,
}
/// A single Attribute-Value Pair, as decoded by `AVP::parse`.
#[derive(Debug, PartialEq)]
pub struct AVP {
    /// AVP code.
    code: u32,
    /// AVP flags byte (see the `*_FLAG` / `RESERVED_MASK` constants).
    flags: u8,
    /// AVP Length field: AVP header (8 bytes, or 12 with a Vendor-ID) plus
    /// data, excluding the trailing padding.
    length: u32,
    /// Vendor-ID, present only when the vendor-specific flag is set.
    vendor_id: Option<u32>,
    /// Raw data bytes, uninterpreted.
    data: Vec<u8>,
    /// Padding bytes that round the data up to a multiple of 4 bytes; their
    /// contents are kept as read and not validated.
    padding: Vec<u8>,
}
/// A complete Diameter message: the fixed header followed by its AVPs.
#[derive(Debug, PartialEq)]
pub struct Message {
    /// Fixed 20-byte message header.
    header: Header,
    /// AVPs decoded from the bytes following the header, in wire order.
    avps: Vec<AVP>,
}

impl Message {
    /// Returns the message's fixed header.
    pub fn header(&self) -> &Header {
        &self.header
    }

    /// Returns the message's AVPs in the order they appeared on the wire.
    pub fn avps(&self) -> &[AVP] {
        &self.avps
    }
}
fn length(read: usize) -> impl Fn(&[u8]) -> IResult<&[u8], u32> {
move |input: &[u8]| {
let (input, length) = be_u24(input)?;
let len = length as usize;
if len < read {
Err(nom::Err::Error((input, ErrorKind::LengthValue)))
} else if len > (input.len() + read) {
Err(nom::Err::Incomplete(nom::Needed::Size(
len - (input.len() + read),
)))
} else {
Ok((input, length))
}
}
}
impl Header {
const SIZE: usize = 20;
const PRE_LENGTH_SIZE: usize = 4;
pub const REQUEST_FLAG: u8 = 0b1000_0000;
pub const PROXIABLE_FLAG: u8 = 0b0100_0000;
pub const ERROR_FLAG: u8 = 0b0010_0000;
pub const POTENTIALLY_RETRANSMITTED_FLAG: u8 = 0b0001_0000;
pub const RESERVED_MASK: u8 = 0b0000_1111;
pub fn is_request(&self) -> bool {
self.flags & Self::REQUEST_FLAG != 0
}
pub fn is_proxiable(&self) -> bool {
self.flags & Self::PROXIABLE_FLAG != 0
}
pub fn is_error(&self) -> bool {
self.flags & Self::ERROR_FLAG != 0
}
pub fn is_potentially_retransmitted(&self) -> bool {
self.flags & Self::POTENTIALLY_RETRANSMITTED_FLAG != 0
}
pub fn reserved_set(&self) -> bool {
self.get_reserved() != 0
}
pub fn get_reserved(&self) -> u8 {
self.flags & Self::RESERVED_MASK
}
pub fn length(&self) -> usize {
(self.length as usize) - Self::SIZE
}
pub fn parse(input: &[u8]) -> IResult<&[u8], Self> {
let (input, version) = tag(&[1u8])(input)?;
let (input, length) = length(Self::PRE_LENGTH_SIZE)(input)?;
if (length as usize) < Self::SIZE {
return Err(nom::Err::Error((input, ErrorKind::LengthValue)));
}
let (input, flags) = be_u8(input)?;
let (input, code) = be_u24(input)?;
let (input, app_id) = be_u32(input)?;
let (input, hop_id) = be_u32(input)?;
let (input, end_id) = be_u32(input)?;
Ok((
input,
Self {
version: version[0],
length,
flags,
code,
app_id,
hop_id,
end_id,
},
))
}
}
impl AVP {
const PRE_LENGTH_SIZE: usize = 8;
pub const VENDOR_SPECIFIC_FLAG: u8 = 0b1000_0000;
pub const MANDATORY_FLAG: u8 = 0b0100_0000;
pub const PROTECTED_FLAG: u8 = 0b0010_0000;
pub const RESERVED_MASK: u8 = 0b0001_1111;
fn vendor_specific_flag(flags: u8) -> bool {
flags & Self::VENDOR_SPECIFIC_FLAG != 0
}
fn padding(length: usize) -> usize {
match length % 4 {
0 => 0,
n => 4 - n,
}
}
pub fn is_vendor_specific(&self) -> bool {
Self::vendor_specific_flag(self.flags)
}
pub fn is_mandatory(&self) -> bool {
self.flags & Self::MANDATORY_FLAG != 0
}
pub fn is_protected(&self) -> bool {
self.flags & Self::PROTECTED_FLAG != 0
}
pub fn reserved_set(&self) -> bool {
self.get_reserved() != 0
}
pub fn get_reserved(&self) -> u8 {
self.flags & Self::RESERVED_MASK
}
pub fn parse(input: &[u8]) -> IResult<&[u8], Self> {
let (input, code) = be_u32(input)?;
let (input, flags) = be_u8(input)?;
let (input, length) = length(Self::PRE_LENGTH_SIZE)(input)?;
let header_size = if Self::vendor_specific_flag(flags) {
Self::PRE_LENGTH_SIZE + 4
} else {
Self::PRE_LENGTH_SIZE
};
if (length as usize) < header_size {
return Err(nom::Err::Error((input, ErrorKind::LengthValue)));
}
let data_length = (length as usize) - header_size;
let (input, vendor_id) = if Self::vendor_specific_flag(flags) {
let (input, v) = be_u32(input)?;
(input, Some(v))
} else {
(input, None)
};
let (input, data) = take(data_length)(input)?;
let (input, padding) = take(Self::padding(data_length))(input)?;
Ok((
input,
Self {
code,
flags,
length,
vendor_id,
data: data.into(),
padding: padding.into(),
},
))
}
}
/// Registers `Diameter` with sawp, producing [`Message`] values.
impl<'a> Protocol<'a> for Diameter {
    type Message = Message;

    /// Short protocol identifier: `"diameter"`.
    fn name() -> &'static str {
        "diameter"
    }
}
impl<'a> Parse<'a> for Diameter {
fn parse(
&self,
input: &'a [u8],
_direction: Direction,
) -> Result<(&'a [u8], Option<Self::Message>)> {
let (input, header) = Header::parse(input)?;
let (input, avps_input) = combinator::complete(take(header.length()))(input)?;
let (rest, avps) = many0(combinator::complete(AVP::parse))(avps_input)?;
if !rest.is_empty() {
Err(nom::Err::Error((avps_input, ErrorKind::Many0)).into())
} else {
Ok((input, Some(Message { header, avps })))
}
}
}
// Probing relies entirely on the default methods of `sawp::probe::Probe`.
// NOTE(review): the default presumably classifies input by running
// `Parse::parse`; `test_probe` in this file expects `Incomplete` for empty
// input, `Unrecognized` for non-Diameter bytes and `Recognized` for a bare
// valid header.
impl<'a> Probe<'a> for Diameter {}
#[cfg(test)]
mod tests {
    use super::*;
    use rstest::rstest;
    use sawp::error;
    use sawp::probe::Status;

    #[test]
    fn test_name() {
        assert_eq!(Diameter::name(), "diameter");
    }

    #[rstest(
        input,
        expected,
        case::empty(b"", Err(nom::Err::Incomplete(nom::Needed::Size(1)))),
        case::hello_world(
            b"hello world",
            Err(nom::Err::Error((b"hello world" as &[u8], ErrorKind::Tag)))
        ),
        // Message Length 12 is below the 20-byte fixed header size.
        case::invalid_length(
            &[
                0x01,
                0x00, 0x00, 0x0c,
                0x80,
                0x00, 0x01, 0x01,
                0x00, 0x00, 0x00, 0x00,
                0x53, 0xca, 0xfe, 0x6a,
                0x7d, 0xc0, 0xa1, 0x1b,
            ],
            Err(nom::Err::Error((
                &[
                    0x80_u8,
                    0x00, 0x01, 0x01,
                    0x00, 0x00, 0x00, 0x00,
                    0x53, 0xca, 0xfe, 0x6a,
                    0x7d, 0xc0, 0xa1, 0x1b,
                ] as &[u8],
                ErrorKind::LengthValue
            )))
        ),
        case::diagnostic(
            &[
                0x01,
                0x00, 0x00, 0x14,
                0x80,
                0x00, 0x01, 0x01,
                0x00, 0x00, 0x00, 0x00,
                0x53, 0xca, 0xfe, 0x6a,
                0x7d, 0xc0, 0xa1, 0x1b,
            ],
            Ok((
                &[] as &[u8],
                Header {
                    version: 1,
                    length: 20,
                    flags: 128,
                    code: 257,
                    app_id: 0,
                    hop_id: 0x53ca_fe6a,
                    end_id: 0x7dc0_a11b,
                }
            ))
        ),
        // Message Length 24 needs 4 more bytes than the 20 provided.
        case::incomplete(
            &[
                0x01,
                0x00, 0x00, 0x18,
                0x80,
                0x00, 0x01, 0x01,
                0x00, 0x00, 0x00, 0x00,
                0x53, 0xca, 0xfe, 0x6a,
                0x7d, 0xc0, 0xa1, 0x1b,
            ],
            Err(nom::Err::Incomplete(nom::Needed::Size(4)))
        ),
    )]
    fn test_header(input: &[u8], expected: IResult<&[u8], Header>) {
        assert_eq!(Header::parse(input), expected);
    }

    #[rstest(
        input,
        expected,
        case::empty(b"", Err(nom::Err::Incomplete(nom::Needed::Size(4)))),
        case::diagnostic(
            &[
                0x00, 0x00, 0x01, 0x08,
                0x40,
                0x00, 0x00, 0x1f,
                0x62, 0x61, 0x63, 0x6b, 0x65, 0x6e, 0x64, 0x2e,
                0x65, 0x61, 0x70, 0x2e, 0x74, 0x65, 0x73, 0x74,
                0x62, 0x65, 0x64, 0x2e, 0x61, 0x61, 0x61,
                0x00,
            ],
            Ok((
                &[] as &[u8],
                AVP {
                    code: 264,
                    flags: 0x40,
                    length: 31,
                    vendor_id: None,
                    data: (b"backend.eap.testbed.aaa" as &[u8]).into(),
                    padding: vec![0x00],
                }
            ))
        ),
        case::diagnostic_vendor_id(
            &[
                0x00, 0x00, 0x01, 0x08,
                0x80,
                0x00, 0x00, 0x0c,
                0x49, 0x96, 0x02, 0xd2,
            ],
            Ok((
                &[] as &[u8],
                AVP {
                    code: 264,
                    flags: 0x80,
                    length: 12,
                    vendor_id: Some(1_234_567_890u32),
                    data: Vec::new(),
                    padding: Vec::new(),
                }
            ))
        ),
        // AVP Length 4 is shorter than the 8 bytes read up to the length field.
        case::length_too_short(
            &[
                0x00, 0x00, 0x01, 0x08,
                0x40,
                0x00, 0x00, 0x04,
            ],
            Err(nom::Err::Error((&[] as &[u8], ErrorKind::LengthValue)))
        ),
        // Vendor-specific flag set, but AVP Length 8 leaves no room for the Vendor-ID.
        case::vendor_length_too_short(
            &[
                0x00, 0x00, 0x01, 0x08,
                0x80,
                0x00, 0x00, 0x08,
            ],
            Err(nom::Err::Error((&[] as &[u8], ErrorKind::LengthValue)))
        ),
    )]
    fn test_avp(input: &[u8], expected: IResult<&[u8], AVP>) {
        assert_eq!(AVP::parse(input), expected);
    }

    #[rstest(
        input,
        expected,
        case::empty(b"", Err(error::Error::incomplete_needed(1))),
        case::header(
            &[
                0x01,
                0x00, 0x00, 0x14,
                0x80,
                0x00, 0x01, 0x01,
                0x00, 0x00, 0x00, 0x00,
                0x53, 0xca, 0xfe, 0x6a,
                0x7d, 0xc0, 0xa1, 0x1b,
            ],
            Ok((
                &[] as &[u8],
                Some(Message {
                    header: Header {
                        version: 1,
                        length: 20,
                        flags: 128,
                        code: 257,
                        app_id: 0,
                        hop_id: 0x53ca_fe6a,
                        end_id: 0x7dc0_a11b,
                    },
                    avps: Vec::new(),
                })
            ))
        ),
        case::full_message(
            &[
                0x01,
                0x00, 0x00, 0x40,
                0x80,
                0x00, 0x01, 0x01,
                0x00, 0x00, 0x00, 0x00,
                0x53, 0xca, 0xfe, 0x6a,
                0x7d, 0xc0, 0xa1, 0x1b,
                0x00, 0x00, 0x01, 0x08,
                0x40,
                0x00, 0x00, 0x1f,
                0x62, 0x61, 0x63, 0x6b, 0x65, 0x6e, 0x64, 0x2e,
                0x65, 0x61, 0x70, 0x2e, 0x74, 0x65, 0x73, 0x74,
                0x62, 0x65, 0x64, 0x2e, 0x61, 0x61, 0x61,
                0x00,
                0x00, 0x00, 0x01, 0x08,
                0x80,
                0x00, 0x00, 0x0c,
                0x49, 0x96, 0x02, 0xd2,
            ],
            Ok((
                &[] as &[u8],
                Some(Message {
                    header: Header {
                        version: 1,
                        length: 64,
                        flags: 128,
                        code: 257,
                        app_id: 0,
                        hop_id: 0x53ca_fe6a,
                        end_id: 0x7dc0_a11b,
                    },
                    avps: vec![
                        AVP {
                            code: 264,
                            flags: 0x40,
                            length: 31,
                            vendor_id: None,
                            data: (b"backend.eap.testbed.aaa" as &[u8]).into(),
                            padding: vec![0x00],
                        },
                        AVP {
                            code: 264,
                            flags: 0x80,
                            length: 12,
                            vendor_id: Some(1_234_567_890u32),
                            data: Vec::new(),
                            padding: Vec::new(),
                        },
                    ],
                })
            ))
        ),
        // Message Length 66 needs 2 more bytes than the 64 provided.
        case::incomplete(
            &[
                0x01,
                0x00, 0x00, 0x42,
                0x80,
                0x00, 0x01, 0x01,
                0x00, 0x00, 0x00, 0x00,
                0x53, 0xca, 0xfe, 0x6a,
                0x7d, 0xc0, 0xa1, 0x1b,
                0x00, 0x00, 0x01, 0x08,
                0x40,
                0x00, 0x00, 0x1f,
                0x62, 0x61, 0x63, 0x6b, 0x65, 0x6e, 0x64, 0x2e,
                0x65, 0x61, 0x70, 0x2e, 0x74, 0x65, 0x73, 0x74,
                0x62, 0x65, 0x64, 0x2e, 0x61, 0x61, 0x61,
                0x00,
                0x00, 0x00, 0x01, 0x08,
                0x80,
                0x00, 0x00, 0x0e,
                0x49, 0x96, 0x02, 0xd2,
            ],
            Err(error::Error::incomplete_needed(2))
        ),
        // The second AVP's length (14, plus 2 padding) overruns the 64-byte message.
        case::invalid_avp(
            &[
                0x01,
                0x00, 0x00, 0x40,
                0x80,
                0x00, 0x01, 0x01,
                0x00, 0x00, 0x00, 0x00,
                0x53, 0xca, 0xfe, 0x6a,
                0x7d, 0xc0, 0xa1, 0x1b,
                0x00, 0x00, 0x01, 0x08,
                0x40,
                0x00, 0x00, 0x1f,
                0x62, 0x61, 0x63, 0x6b, 0x65, 0x6e, 0x64, 0x2e,
                0x65, 0x61, 0x70, 0x2e, 0x74, 0x65, 0x73, 0x74,
                0x62, 0x65, 0x64, 0x2e, 0x61, 0x61, 0x61,
                0x00,
                0x00, 0x00, 0x01, 0x08,
                0x80,
                0x00, 0x00, 0x0e,
                0x49, 0x96, 0x02, 0xd2,
            ],
            Err(error::Error::parse(Some("Many0".to_string()))),
        ),
    )]
    fn test_parse(input: &[u8], expected: Result<(&[u8], Option<Message>)>) {
        let diameter = Diameter {};
        assert_eq!(diameter.parse(input, Direction::Unknown), expected);
    }

    #[rstest(
        input,
        expected,
        case::empty(b"", Status::Incomplete),
        case::hello_world(b"hello world", Status::Unrecognized),
        case::header(
            &[
                0x01,
                0x00, 0x00, 0x14,
                0x80,
                0x00, 0x01, 0x01,
                0x00, 0x00, 0x00, 0x00,
                0x53, 0xca, 0xfe, 0x6a,
                0x7d, 0xc0, 0xa1, 0x1b,
            ],
            Status::Recognized
        ),
    )]
    fn test_probe(input: &[u8], expected: Status) {
        let diameter = Diameter {};
        assert_eq!(diameter.probe(input, Direction::Unknown), expected);
    }
}