use crate::attribute::{Attribute, AttributeType};
use crate::message::{Message, MessageEncoder};
use crate::net::{socket_addr_xor, SocketAddrDecoder, SocketAddrEncoder};
use crate::rfc5389::errors;
use bytecodec::bytes::{BytesEncoder, CopyableBytesDecoder, Utf8Decoder, Utf8Encoder};
use bytecodec::combinator::{Collect, PreEncode, Repeat};
use bytecodec::fixnum::{U16beDecoder, U16beEncoder, U32beDecoder, U32beEncoder};
use bytecodec::tuple::{TupleDecoder, TupleEncoder};
use bytecodec::{
ByteCount, Decode, Encode, EncodeExt, Eos, Error, ErrorKind, Result, SizedEncode,
TryTaggedDecode,
};
use byteorder::{BigEndian, ByteOrder};
use crc::crc32;
use hmacsha1::hmac_sha1;
use std::net::SocketAddr;
use std::vec;
/// Implements `Decode` and `TryTaggedDecode` for a single-field tuple-struct decoder.
///
/// * `$decoder` — the wrapper type; its field `.0` is the inner decoder doing the byte work.
/// * `$item` — the attribute type produced; its `CODEPOINT` constant is the only attribute
///   type tag this decoder accepts.
/// * `$finish` — a callable turning the inner decoder's item into `Result<$item>`.
macro_rules! impl_decode {
    ($decoder:ty, $item:ident, $finish:expr) => {
        impl TryTaggedDecode for $decoder {
            type Tag = AttributeType;

            fn try_start_decoding(&mut self, tag: Self::Tag) -> Result<bool> {
                let wanted = $item::CODEPOINT;
                Ok(tag.as_u16() == wanted)
            }
        }

        impl Decode for $decoder {
            type Item = $item;

            fn decode(&mut self, bytes: &[u8], eos: Eos) -> Result<usize> {
                track!(self.0.decode(bytes, eos))
            }

            fn finish_decoding(&mut self) -> Result<Self::Item> {
                // Finish the inner decoder first, then map (and possibly validate) its item.
                let inner = track!(self.0.finish_decoding());
                inner.and_then($finish)
            }

            fn is_idle(&self) -> bool {
                self.0.is_idle()
            }

            fn requiring_bytes(&self) -> ByteCount {
                self.0.requiring_bytes()
            }
        }
    };
}
/// Implements `Encode` and `SizedEncode` for a single-field tuple-struct encoder.
///
/// * `$encoder` — the wrapper type; its field `.0` is the inner encoder doing the byte work.
/// * `$item` — the attribute type accepted by `start_encoding`.
/// * `$to_inner` — a callable converting a `$item` into the inner encoder's item type.
macro_rules! impl_encode {
    ($encoder:ty, $item:ty, $to_inner:expr) => {
        impl SizedEncode for $encoder {
            fn exact_requiring_bytes(&self) -> u64 {
                self.0.exact_requiring_bytes()
            }
        }

        impl Encode for $encoder {
            type Item = $item;

            fn start_encoding(&mut self, item: Self::Item) -> Result<()> {
                let inner_item = $to_inner(item);
                track!(self.0.start_encoding(inner_item))
            }

            fn encode(&mut self, out: &mut [u8], eos: Eos) -> Result<usize> {
                track!(self.0.encode(out, eos))
            }

            fn is_idle(&self) -> bool {
                self.0.is_idle()
            }

            fn requiring_bytes(&self) -> ByteCount {
                self.0.requiring_bytes()
            }
        }
    };
}
/// `ALTERNATE-SERVER` attribute: carries a single socket address.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct AlternateServer(SocketAddr);
impl AlternateServer {
    /// The codepoint of this attribute type.
    pub const CODEPOINT: u16 = 0x8023;

    /// Makes a new `AlternateServer` instance wrapping `addr`.
    pub fn new(addr: SocketAddr) -> Self {
        Self(addr)
    }

    /// Returns the carried address.
    pub fn address(&self) -> SocketAddr {
        let AlternateServer(addr) = *self;
        addr
    }
}
impl Attribute for AlternateServer {
    type Decoder = AlternateServerDecoder;
    type Encoder = AlternateServerEncoder;

    fn get_type(&self) -> AttributeType {
        let codepoint = Self::CODEPOINT;
        AttributeType::new(codepoint)
    }
}
/// Decoder for [`AlternateServer`]; delegates the address bytes to `SocketAddrDecoder`.
#[derive(Debug, Default)]
pub struct AlternateServerDecoder(SocketAddrDecoder);
impl AlternateServerDecoder {
    /// Makes a new `AlternateServerDecoder` instance.
    pub fn new() -> Self {
        AlternateServerDecoder::default()
    }
}
impl_decode!(AlternateServerDecoder, AlternateServer, |addr| {
    Ok(AlternateServer(addr))
});

/// Encoder for [`AlternateServer`]; delegates the address bytes to `SocketAddrEncoder`.
#[derive(Debug, Default)]
pub struct AlternateServerEncoder(SocketAddrEncoder);
impl AlternateServerEncoder {
    /// Makes a new `AlternateServerEncoder` instance.
    pub fn new() -> Self {
        AlternateServerEncoder::default()
    }
}
impl_encode!(AlternateServerEncoder, AlternateServer, |attr: Self::Item| {
    let AlternateServer(addr) = attr;
    addr
});
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct ErrorCode {
code: u16,
reason_phrase: String,
}
impl ErrorCode {
pub const CODEPOINT: u16 = 0x0009;
pub fn new(code: u16, reason_phrase: String) -> Result<Self> {
track_assert!((300..600).contains(&code), ErrorKind::InvalidInput; code, reason_phrase);
Ok(ErrorCode {
code,
reason_phrase,
})
}
pub fn code(&self) -> u16 {
self.code
}
pub fn reason_phrase(&self) -> &str {
&self.reason_phrase
}
}
impl Attribute for ErrorCode {
type Decoder = ErrorCodeDecoder;
type Encoder = ErrorCodeEncoder;
fn get_type(&self) -> AttributeType {
AttributeType::new(Self::CODEPOINT)
}
}
impl From<Error> for ErrorCode {
fn from(f: Error) -> Self {
match *f.kind() {
ErrorKind::InvalidInput => errors::BadRequest.into(),
_ => errors::ServerError.into(),
}
}
}
/// Decoder for [`ErrorCode`].
///
/// Reads a big-endian `u32` followed by a UTF-8 reason phrase. Bits 8..=10 of the word are
/// the class (hundreds digit) and the low 8 bits are the number (`0..100`); any other bits
/// are ignored. The reason phrase is not length-checked.
#[derive(Debug, Default)]
pub struct ErrorCodeDecoder(TupleDecoder<(U32beDecoder, Utf8Decoder)>);
impl ErrorCodeDecoder {
    /// Makes a new `ErrorCodeDecoder` instance.
    pub fn new() -> Self {
        ErrorCodeDecoder::default()
    }
}
impl_decode!(ErrorCodeDecoder, ErrorCode, |(value, reason_phrase): (u32, _)| {
    let class = (value >> 8) & 0b111;
    let number = value & 0b1111_1111;
    track_assert!((3..6).contains(&class), ErrorKind::InvalidInput);
    track_assert!(number < 100, ErrorKind::InvalidInput);
    // class is in 3..6 and number < 100, so the code is in 300..600 and fits in u16.
    Ok(ErrorCode {
        code: (class * 100 + number) as u16,
        reason_phrase,
    })
});

/// Encoder for [`ErrorCode`]; writes the inverse of the layout read by [`ErrorCodeDecoder`].
#[derive(Debug, Default)]
pub struct ErrorCodeEncoder(TupleEncoder<(U32beEncoder, Utf8Encoder)>);
impl ErrorCodeEncoder {
    /// Makes a new `ErrorCodeEncoder` instance.
    pub fn new() -> Self {
        ErrorCodeEncoder::default()
    }
}
impl_encode!(ErrorCodeEncoder, ErrorCode, |item: Self::Item| {
    let ErrorCode {
        code,
        reason_phrase,
    } = item;
    let (class, number) = (u32::from(code / 100), u32::from(code % 100));
    ((class << 8) | number, reason_phrase)
});
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct Fingerprint {
crc32: u32,
}
impl Fingerprint {
pub const CODEPOINT: u16 = 0x8028;
pub fn new<A: Attribute>(message: &Message<A>) -> Result<Self> {
let mut bytes = track!(MessageEncoder::default().encode_into_bytes(message.clone()))?;
let final_len = bytes.len() as u16 - 20 + 8;
BigEndian::write_u16(&mut bytes[2..4], final_len);
let crc32 = crc32::checksum_ieee(&bytes[..]) ^ 0x5354_554e;
Ok(Fingerprint { crc32 })
}
pub fn crc32(&self) -> u32 {
self.crc32
}
}
impl Attribute for Fingerprint {
type Decoder = FingerprintDecoder;
type Encoder = FingerprintEncoder;
fn get_type(&self) -> AttributeType {
AttributeType::new(Self::CODEPOINT)
}
fn after_decode<A: Attribute>(&mut self, message: &Message<A>) -> Result<()> {
let actual = track!(Self::new(message))?;
track_assert_eq!(actual.crc32, self.crc32, ErrorKind::InvalidInput);
Ok(())
}
}
/// Decoder for [`Fingerprint`]: a single big-endian `u32`.
#[derive(Debug, Default)]
pub struct FingerprintDecoder(U32beDecoder);
impl FingerprintDecoder {
    /// Makes a new `FingerprintDecoder` instance.
    pub fn new() -> Self {
        FingerprintDecoder::default()
    }
}
impl_decode!(FingerprintDecoder, Fingerprint, |value: u32| {
    Ok(Fingerprint { crc32: value })
});

/// Encoder for [`Fingerprint`]: a single big-endian `u32`.
#[derive(Debug, Default)]
pub struct FingerprintEncoder(U32beEncoder);
impl FingerprintEncoder {
    /// Makes a new `FingerprintEncoder` instance.
    pub fn new() -> Self {
        FingerprintEncoder::default()
    }
}
impl_encode!(FingerprintEncoder, Fingerprint, |item: Self::Item| {
    let Fingerprint { crc32: value } = item;
    value
});
/// `MAPPED-ADDRESS` attribute: carries a single socket address as-is (no XOR obfuscation).
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct MappedAddress(SocketAddr);
impl MappedAddress {
    /// The codepoint of this attribute type.
    pub const CODEPOINT: u16 = 0x0001;

    /// Makes a new `MappedAddress` instance wrapping `addr`.
    pub fn new(addr: SocketAddr) -> Self {
        Self(addr)
    }

    /// Returns the carried address.
    pub fn address(&self) -> SocketAddr {
        let MappedAddress(addr) = *self;
        addr
    }
}
impl Attribute for MappedAddress {
    type Decoder = MappedAddressDecoder;
    type Encoder = MappedAddressEncoder;

    fn get_type(&self) -> AttributeType {
        let codepoint = Self::CODEPOINT;
        AttributeType::new(codepoint)
    }
}
/// Decoder for [`MappedAddress`]; delegates the address bytes to `SocketAddrDecoder`.
#[derive(Debug, Default)]
pub struct MappedAddressDecoder(SocketAddrDecoder);
impl MappedAddressDecoder {
    /// Makes a new `MappedAddressDecoder` instance.
    pub fn new() -> Self {
        MappedAddressDecoder::default()
    }
}
impl_decode!(MappedAddressDecoder, MappedAddress, |addr| {
    Ok(MappedAddress(addr))
});

/// Encoder for [`MappedAddress`]; delegates the address bytes to `SocketAddrEncoder`.
#[derive(Debug, Default)]
pub struct MappedAddressEncoder(SocketAddrEncoder);
impl MappedAddressEncoder {
    /// Makes a new `MappedAddressEncoder` instance.
    pub fn new() -> Self {
        MappedAddressEncoder::default()
    }
}
impl_encode!(MappedAddressEncoder, MappedAddress, |attr: Self::Item| {
    let MappedAddress(addr) = attr;
    addr
});
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct MessageIntegrity {
hmac_sha1: [u8; 20],
preceding_message_bytes: Vec<u8>,
}
impl MessageIntegrity {
pub const CODEPOINT: u16 = 0x0008;
pub fn new_short_term_credential<A>(message: &Message<A>, password: &str) -> Result<Self>
where
A: Attribute,
{
let key = password.as_bytes();
let preceding_message_bytes = track!(Self::message_into_bytes(message.clone()))?;
let hmac_sha1 = hmac_sha1(key, &preceding_message_bytes);
Ok(MessageIntegrity {
hmac_sha1,
preceding_message_bytes,
})
}
pub fn new_long_term_credential<A>(
message: &Message<A>,
username: &Username,
realm: &Realm,
password: &str,
) -> Result<Self>
where
A: Attribute,
{
let key =
md5::compute(format!("{}:{}:{}", username.name(), realm.text(), password).as_bytes());
let preceding_message_bytes = track!(Self::message_into_bytes(message.clone()))?;
let hmac_sha1 = hmac_sha1(&key.0[..], &preceding_message_bytes);
Ok(MessageIntegrity {
hmac_sha1,
preceding_message_bytes,
})
}
pub fn check_short_term_credential(
&self,
password: &str,
) -> std::result::Result<(), ErrorCode> {
let key = password.as_bytes();
let expected = hmac_sha1(key, &self.preceding_message_bytes);
if self.hmac_sha1 == expected {
Ok(())
} else {
Err(errors::Unauthorized.into())
}
}
pub fn check_long_term_credential(
&self,
username: &Username,
realm: &Realm,
password: &str,
) -> std::result::Result<(), ErrorCode> {
let key =
md5::compute(format!("{}:{}:{}", username.name(), realm.text(), password).as_bytes());
let expected = hmac_sha1(&key.0[..], &self.preceding_message_bytes);
if self.hmac_sha1 == expected {
Ok(())
} else {
Err(errors::Unauthorized.into())
}
}
pub fn hmac_sha1(&self) -> [u8; 20] {
self.hmac_sha1
}
fn message_into_bytes<A: Attribute>(message: Message<A>) -> Result<Vec<u8>> {
let mut bytes = track!(MessageEncoder::default().encode_into_bytes(message))?;
let adjusted_len = bytes.len() - 20 + 4 + 20 ;
BigEndian::write_u16(&mut bytes[2..4], adjusted_len as u16);
Ok(bytes)
}
}
impl Attribute for MessageIntegrity {
type Decoder = MessageIntegrityDecoder;
type Encoder = MessageIntegrityEncoder;
fn get_type(&self) -> AttributeType {
AttributeType::new(Self::CODEPOINT)
}
fn after_decode<A: Attribute>(&mut self, message: &Message<A>) -> Result<()> {
self.preceding_message_bytes = track!(Self::message_into_bytes(message.clone()))?;
Ok(())
}
}
/// Decoder for [`MessageIntegrity`]: reads the raw 20-byte HMAC.
#[derive(Debug, Default)]
pub struct MessageIntegrityDecoder(CopyableBytesDecoder<[u8; 20]>);
impl MessageIntegrityDecoder {
    /// Makes a new `MessageIntegrityDecoder` instance.
    pub fn new() -> Self {
        MessageIntegrityDecoder::default()
    }
}
impl_decode!(MessageIntegrityDecoder, MessageIntegrity, |mac: [u8; 20]| {
    // The attribute decoder only sees its own bytes. The preceding message bytes start
    // empty and are supplied later by `MessageIntegrity::after_decode`.
    Ok(MessageIntegrity {
        hmac_sha1: mac,
        preceding_message_bytes: Vec::new(),
    })
});

/// Encoder for [`MessageIntegrity`]: writes the raw 20-byte HMAC.
#[derive(Debug, Default)]
pub struct MessageIntegrityEncoder(BytesEncoder<[u8; 20]>);
impl MessageIntegrityEncoder {
    /// Makes a new `MessageIntegrityEncoder` instance.
    pub fn new() -> Self {
        MessageIntegrityEncoder::default()
    }
}
impl_encode!(MessageIntegrityEncoder, MessageIntegrity, |item: Self::Item| {
    let MessageIntegrity { hmac_sha1: mac, .. } = item;
    mac
});
/// `NONCE` attribute: a UTF-8 string of fewer than 128 characters.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct Nonce {
    value: String,
}
impl Nonce {
    /// The codepoint of this attribute type.
    pub const CODEPOINT: u16 = 0x0015;

    /// Makes a new `Nonce` instance.
    ///
    /// # Errors
    ///
    /// Returns an `ErrorKind::InvalidInput` error if `value` has 128 or more `char`s.
    pub fn new(value: String) -> Result<Self> {
        track_assert!(value.chars().count() < 128, ErrorKind::InvalidInput; value);
        Ok(Self { value })
    }

    /// Returns the nonce string.
    pub fn value(&self) -> &str {
        self.value.as_str()
    }
}
impl Attribute for Nonce {
    type Decoder = NonceDecoder;
    type Encoder = NonceEncoder;

    fn get_type(&self) -> AttributeType {
        let codepoint = Self::CODEPOINT;
        AttributeType::new(codepoint)
    }
}
/// Decoder for [`Nonce`]; the decoded string is validated by `Nonce::new`.
#[derive(Debug, Default)]
pub struct NonceDecoder(Utf8Decoder);
impl NonceDecoder {
    /// Makes a new `NonceDecoder` instance.
    pub fn new() -> Self {
        NonceDecoder::default()
    }
}
impl_decode!(NonceDecoder, Nonce, |text: String| Nonce::new(text));

/// Encoder for [`Nonce`]: writes the string as UTF-8.
#[derive(Debug, Default)]
pub struct NonceEncoder(Utf8Encoder);
impl NonceEncoder {
    /// Makes a new `NonceEncoder` instance.
    pub fn new() -> Self {
        NonceEncoder::default()
    }
}
impl_encode!(NonceEncoder, Nonce, |nonce: Self::Item| {
    let Nonce { value } = nonce;
    value
});
/// `REALM` attribute: a UTF-8 string of fewer than 128 characters.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct Realm {
    text: String,
}
impl Realm {
    /// The codepoint of this attribute type.
    pub const CODEPOINT: u16 = 0x0014;

    /// Makes a new `Realm` instance.
    ///
    /// # Errors
    ///
    /// Returns an `ErrorKind::InvalidInput` error if `text` has 128 or more `char`s.
    pub fn new(text: String) -> Result<Self> {
        track_assert!(text.chars().count() < 128, ErrorKind::InvalidInput; text);
        Ok(Self { text })
    }

    /// Returns the realm string.
    pub fn text(&self) -> &str {
        self.text.as_str()
    }
}
impl Attribute for Realm {
    type Decoder = RealmDecoder;
    type Encoder = RealmEncoder;

    fn get_type(&self) -> AttributeType {
        let codepoint = Self::CODEPOINT;
        AttributeType::new(codepoint)
    }
}
/// Decoder for [`Realm`]; the decoded string is validated by `Realm::new`.
#[derive(Debug, Default)]
pub struct RealmDecoder(Utf8Decoder);
impl RealmDecoder {
    /// Makes a new `RealmDecoder` instance.
    pub fn new() -> Self {
        RealmDecoder::default()
    }
}
impl_decode!(RealmDecoder, Realm, |text: String| Realm::new(text));

/// Encoder for [`Realm`]: writes the string as UTF-8.
#[derive(Debug, Default)]
pub struct RealmEncoder(Utf8Encoder);
impl RealmEncoder {
    /// Makes a new `RealmEncoder` instance.
    pub fn new() -> Self {
        RealmEncoder::default()
    }
}
impl_encode!(RealmEncoder, Realm, |realm: Self::Item| {
    let Realm { text } = realm;
    text
});
/// `SOFTWARE` attribute: a textual description of fewer than 128 characters.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct Software {
    description: String,
}
impl Software {
    /// The codepoint of this attribute type.
    pub const CODEPOINT: u16 = 0x8022;

    /// Makes a new `Software` instance.
    ///
    /// # Errors
    ///
    /// Returns an `ErrorKind::InvalidInput` error if `description` has 128 or more `char`s.
    pub fn new(description: String) -> Result<Self> {
        track_assert!(description.chars().count() < 128, ErrorKind::InvalidInput; description);
        Ok(Self { description })
    }

    /// Returns the description string.
    pub fn description(&self) -> &str {
        self.description.as_str()
    }
}
impl Attribute for Software {
    type Decoder = SoftwareDecoder;
    type Encoder = SoftwareEncoder;

    fn get_type(&self) -> AttributeType {
        let codepoint = Self::CODEPOINT;
        AttributeType::new(codepoint)
    }
}
/// Decoder for [`Software`]; the decoded string is validated by `Software::new`.
#[derive(Debug, Default)]
pub struct SoftwareDecoder(Utf8Decoder);
impl SoftwareDecoder {
    /// Makes a new `SoftwareDecoder` instance.
    pub fn new() -> Self {
        SoftwareDecoder::default()
    }
}
impl_decode!(SoftwareDecoder, Software, |text: String| Software::new(text));

/// Encoder for [`Software`]: writes the description as UTF-8.
#[derive(Debug, Default)]
pub struct SoftwareEncoder(Utf8Encoder);
impl SoftwareEncoder {
    /// Makes a new `SoftwareEncoder` instance.
    pub fn new() -> Self {
        SoftwareEncoder::default()
    }
}
impl_encode!(SoftwareEncoder, Software, |software: Self::Item| {
    let Software { description } = software;
    description
});
/// `UNKNOWN-ATTRIBUTES` attribute: a list of attribute types.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct UnknownAttributes {
    unknowns: Vec<AttributeType>,
}
impl UnknownAttributes {
    /// The codepoint of this attribute type.
    pub const CODEPOINT: u16 = 0x000A;

    /// Makes a new `UnknownAttributes` instance listing `unknowns`.
    pub fn new(unknowns: Vec<AttributeType>) -> Self {
        Self { unknowns }
    }

    /// Returns the listed attribute types.
    pub fn unknowns(&self) -> &[AttributeType] {
        self.unknowns.as_slice()
    }
}
impl Attribute for UnknownAttributes {
    type Decoder = UnknownAttributesDecoder;
    type Encoder = UnknownAttributesEncoder;

    fn get_type(&self) -> AttributeType {
        let codepoint = Self::CODEPOINT;
        AttributeType::new(codepoint)
    }
}
/// Decoder for [`UnknownAttributes`]: collects big-endian `u16` codepoints until the
/// attribute value ends.
#[derive(Debug, Default)]
pub struct UnknownAttributesDecoder(Collect<U16beDecoder, Vec<u16>>);
impl UnknownAttributesDecoder {
    /// Makes a new `UnknownAttributesDecoder` instance.
    pub fn new() -> Self {
        UnknownAttributesDecoder::default()
    }
}
impl_decode!(UnknownAttributesDecoder, UnknownAttributes, |codepoints: Vec<u16>| {
    let unknowns = codepoints.into_iter().map(AttributeType::new).collect();
    Ok(UnknownAttributes { unknowns })
});

/// Encoder for [`UnknownAttributes`]: writes each attribute type as a big-endian `u16`.
#[derive(Debug, Default)]
pub struct UnknownAttributesEncoder(PreEncode<Repeat<U16beEncoder, vec::IntoIter<u16>>>);
impl UnknownAttributesEncoder {
    /// Makes a new `UnknownAttributesEncoder` instance.
    pub fn new() -> Self {
        UnknownAttributesEncoder::default()
    }
}
impl_encode!(UnknownAttributesEncoder, UnknownAttributes, |item: Self::Item| {
    // The inner `Repeat` encoder is typed over `vec::IntoIter<u16>`, so the codepoints are
    // materialised into a `Vec` before being handed over.
    let codepoints: Vec<u16> = item.unknowns.into_iter().map(|ty| ty.as_u16()).collect();
    codepoints.into_iter()
});
/// `USERNAME` attribute: a UTF-8 string shorter than 513 bytes.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct Username {
    name: String,
}
impl Username {
    /// The codepoint of this attribute type.
    pub const CODEPOINT: u16 = 0x0006;

    /// Makes a new `Username` instance.
    ///
    /// # Errors
    ///
    /// Returns an `ErrorKind::InvalidInput` error if `name` is 513 bytes or longer.
    /// The limit counts bytes, not `char`s.
    pub fn new(name: String) -> Result<Self> {
        track_assert!(name.len() < 513, ErrorKind::InvalidInput; name);
        Ok(Self { name })
    }

    /// Returns the user name.
    pub fn name(&self) -> &str {
        self.name.as_str()
    }
}
impl Attribute for Username {
    type Decoder = UsernameDecoder;
    type Encoder = UsernameEncoder;

    fn get_type(&self) -> AttributeType {
        let codepoint = Self::CODEPOINT;
        AttributeType::new(codepoint)
    }
}
/// Decoder for [`Username`]; the decoded string is validated by `Username::new`.
#[derive(Debug, Default)]
pub struct UsernameDecoder(Utf8Decoder);
impl UsernameDecoder {
    /// Makes a new `UsernameDecoder` instance.
    pub fn new() -> Self {
        UsernameDecoder::default()
    }
}
impl_decode!(UsernameDecoder, Username, |text: String| Username::new(text));

/// Encoder for [`Username`]: writes the name as UTF-8.
#[derive(Debug, Default)]
pub struct UsernameEncoder(Utf8Encoder);
impl UsernameEncoder {
    /// Makes a new `UsernameEncoder` instance.
    pub fn new() -> Self {
        UsernameEncoder::default()
    }
}
impl_encode!(UsernameEncoder, Username, |user: Self::Item| {
    let Username { name } = user;
    name
});
/// `XOR-MAPPED-ADDRESS` attribute: a socket address that is passed through
/// `socket_addr_xor` with the message's transaction ID just before encoding and right
/// after decoding.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct XorMappedAddress(SocketAddr);
impl XorMappedAddress {
    /// The codepoint of this attribute type.
    pub const CODEPOINT: u16 = 0x0020;

    /// Makes a new `XorMappedAddress` instance wrapping `addr`.
    pub fn new(addr: SocketAddr) -> Self {
        Self(addr)
    }

    /// Returns the carried address.
    pub fn address(&self) -> SocketAddr {
        let XorMappedAddress(addr) = *self;
        addr
    }

    /// Replaces the stored address with `socket_addr_xor(address, transaction_id)`.
    /// The same step is used on both the encode and decode paths, which presumes the
    /// transform is its own inverse (NOTE(review): confirm in `crate::net`).
    fn xor_with_transaction_id<A: Attribute>(&mut self, message: &Message<A>) {
        let transaction_id = message.transaction_id();
        self.0 = socket_addr_xor(self.0, transaction_id);
    }
}
impl Attribute for XorMappedAddress {
    type Decoder = XorMappedAddressDecoder;
    type Encoder = XorMappedAddressEncoder;

    fn get_type(&self) -> AttributeType {
        AttributeType::new(Self::CODEPOINT)
    }

    fn before_encode<A: Attribute>(&mut self, message: &Message<A>) -> Result<()> {
        self.xor_with_transaction_id(message);
        Ok(())
    }

    fn after_decode<A: Attribute>(&mut self, message: &Message<A>) -> Result<()> {
        self.xor_with_transaction_id(message);
        Ok(())
    }
}
/// Decoder for [`XorMappedAddress`]. It yields the address exactly as read; un-XOR-ing
/// happens later in `after_decode`.
#[derive(Debug, Default)]
pub struct XorMappedAddressDecoder(SocketAddrDecoder);
impl XorMappedAddressDecoder {
    /// Makes a new `XorMappedAddressDecoder` instance.
    pub fn new() -> Self {
        XorMappedAddressDecoder::default()
    }
}
impl_decode!(XorMappedAddressDecoder, XorMappedAddress, |addr| {
    Ok(XorMappedAddress(addr))
});

/// Encoder for [`XorMappedAddress`]. It writes the stored address as-is; XOR-ing happens
/// earlier in `before_encode`.
#[derive(Debug, Default)]
pub struct XorMappedAddressEncoder(SocketAddrEncoder);
impl XorMappedAddressEncoder {
    /// Makes a new `XorMappedAddressEncoder` instance.
    pub fn new() -> Self {
        XorMappedAddressEncoder::default()
    }
}
impl_encode!(XorMappedAddressEncoder, XorMappedAddress, |attr: Self::Item| {
    let XorMappedAddress(addr) = attr;
    addr
});
/// Same behaviour as `XorMappedAddress` but under codepoint `0x8020`.
/// NOTE(review): presumably an alternate/legacy codepoint kept for interoperability; confirm.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct XorMappedAddress2(SocketAddr);
impl XorMappedAddress2 {
    /// The codepoint of this attribute type.
    pub const CODEPOINT: u16 = 0x8020;

    /// Makes a new `XorMappedAddress2` instance wrapping `addr`.
    pub fn new(addr: SocketAddr) -> Self {
        Self(addr)
    }

    /// Returns the carried address.
    pub fn address(&self) -> SocketAddr {
        let XorMappedAddress2(addr) = *self;
        addr
    }

    /// Replaces the stored address with `socket_addr_xor(address, transaction_id)`.
    /// The same step is applied on both the encode and decode paths.
    fn xor_with_transaction_id<A: Attribute>(&mut self, message: &Message<A>) {
        let transaction_id = message.transaction_id();
        self.0 = socket_addr_xor(self.0, transaction_id);
    }
}
impl Attribute for XorMappedAddress2 {
    type Decoder = XorMappedAddress2Decoder;
    type Encoder = XorMappedAddress2Encoder;

    fn get_type(&self) -> AttributeType {
        AttributeType::new(Self::CODEPOINT)
    }

    fn before_encode<A: Attribute>(&mut self, message: &Message<A>) -> Result<()> {
        self.xor_with_transaction_id(message);
        Ok(())
    }

    fn after_decode<A: Attribute>(&mut self, message: &Message<A>) -> Result<()> {
        self.xor_with_transaction_id(message);
        Ok(())
    }
}
/// Decoder for [`XorMappedAddress2`]. It yields the address exactly as read; un-XOR-ing
/// happens later in `after_decode`.
#[derive(Debug, Default)]
pub struct XorMappedAddress2Decoder(SocketAddrDecoder);
impl XorMappedAddress2Decoder {
    /// Makes a new `XorMappedAddress2Decoder` instance.
    pub fn new() -> Self {
        XorMappedAddress2Decoder::default()
    }
}
impl_decode!(XorMappedAddress2Decoder, XorMappedAddress2, |addr| {
    Ok(XorMappedAddress2(addr))
});

/// Encoder for [`XorMappedAddress2`]. It writes the stored address as-is; XOR-ing happens
/// earlier in `before_encode`.
#[derive(Debug, Default)]
pub struct XorMappedAddress2Encoder(SocketAddrEncoder);
impl XorMappedAddress2Encoder {
    /// Makes a new `XorMappedAddress2Encoder` instance.
    pub fn new() -> Self {
        XorMappedAddress2Encoder::default()
    }
}
impl_encode!(XorMappedAddress2Encoder, XorMappedAddress2, |attr: Self::Item| {
    let XorMappedAddress2(addr) = attr;
    addr
});