use alloc::vec::Vec;
use core::convert::TryFrom;
use crate::message::{IntegrityKey, StunParseError, StunWriteError};
use super::{
Attribute, AttributeFromRaw, AttributeStaticType, AttributeType, AttributeWrite,
AttributeWriteExt, RawAttribute,
};
/// The MESSAGE-INTEGRITY attribute, carrying a fixed 20-byte HMAC-SHA1 value.
///
/// See [`MessageIntegrity::compute`] and [`MessageIntegrity::verify`] for producing and
/// checking the value.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct MessageIntegrity {
    /// The 20 HMAC bytes exactly as they appear in the attribute value.
    hmac: [u8; 20],
}
impl AttributeStaticType for MessageIntegrity {
    /// Attribute type code 0x0008 (MESSAGE-INTEGRITY).
    const TYPE: AttributeType = AttributeType(0x08);
}
impl Attribute for MessageIntegrity {
fn get_type(&self) -> AttributeType {
Self::TYPE
}
fn length(&self) -> u16 {
20
}
}
impl AttributeWrite for MessageIntegrity {
    /// Borrows this attribute as a [`RawAttribute`] whose value is the 20 HMAC bytes.
    fn to_raw(&self) -> RawAttribute<'_> {
        RawAttribute::new(Self::TYPE, &self.hmac[..])
    }

    /// Writes the attribute header followed by the HMAC value into `dest`.
    ///
    /// The value starts at offset 4, directly after the header. `dest` must hold at
    /// least `4 + 20` bytes; a shorter buffer makes the slice indexing panic.
    fn write_into_unchecked(&self, dest: &mut [u8]) {
        self.write_header_unchecked(dest);
        let value_end = 4 + self.hmac.len();
        dest[4..value_end].copy_from_slice(&self.hmac);
    }
}
impl AttributeFromRaw<'_> for MessageIntegrity {
    /// Parses the attribute from a borrowed [`RawAttribute`].
    ///
    /// This delegates to the `TryFrom<&RawAttribute>` implementation, so it reports the
    /// same errors.
    fn from_raw_ref(raw: &RawAttribute) -> Result<Self, StunParseError>
    where
        Self: Sized,
    {
        raw.try_into()
    }
}
impl TryFrom<&RawAttribute<'_>> for MessageIntegrity {
type Error = StunParseError;
fn try_from(raw: &RawAttribute) -> Result<Self, Self::Error> {
raw.check_type_and_len(Self::TYPE, 20..=20)?;
let hmac: [u8; 20] = (&*raw.value).try_into().unwrap();
Ok(Self { hmac })
}
}
impl MessageIntegrity {
pub fn new(hmac: [u8; 20]) -> Self {
Self { hmac }
}
pub fn hmac(&self) -> &[u8; 20] {
&self.hmac
}
#[tracing::instrument(
name = "MessageIntegrity::compute",
level = "trace",
err,
skip(data, key)
)]
pub fn compute(data: &[&[u8]], key: &IntegrityKey) -> Result<[u8; 20], StunWriteError> {
Ok(key.compute_sha1(data).into_bytes().into())
}
#[tracing::instrument(
name = "MessageIntegrity::verify",
level = "debug",
skip(data, key, expected)
)]
pub fn verify(data: &[&[u8]], key: &IntegrityKey, expected: &[u8; 20]) -> bool {
key.verify_sha1(data, expected)
}
}
impl core::fmt::Display for MessageIntegrity {
    /// Formats as `<type>: 0x<hex>`, with two lowercase hex digits per HMAC byte.
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        write!(f, "{}: 0x", Self::TYPE)?;
        self.hmac
            .iter()
            .try_for_each(|byte| write!(f, "{byte:02x}"))
    }
}
/// The MESSAGE-INTEGRITY-SHA256 attribute, carrying a possibly truncated HMAC-SHA256 value.
///
/// The stored value is always 16 to 32 bytes long and a multiple of 4 bytes. Both
/// [`MessageIntegritySha256::new`] and the `TryFrom<&RawAttribute>` implementation
/// enforce this.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct MessageIntegritySha256 {
    /// The HMAC bytes exactly as carried in the attribute value.
    hmac: Vec<u8>,
}
impl AttributeStaticType for MessageIntegritySha256 {
    /// Attribute type code 0x001C (MESSAGE-INTEGRITY-SHA256).
    const TYPE: AttributeType = AttributeType(0x1c);
}
impl Attribute for MessageIntegritySha256 {
fn get_type(&self) -> AttributeType {
Self::TYPE
}
fn length(&self) -> u16 {
self.hmac.len() as u16
}
}
impl AttributeWrite for MessageIntegritySha256 {
    /// Borrows this attribute as a [`RawAttribute`] whose value is the stored HMAC bytes.
    fn to_raw(&self) -> RawAttribute<'_> {
        RawAttribute::new(Self::TYPE, &self.hmac[..])
    }

    /// Writes the attribute header followed by the HMAC value into `dest`.
    ///
    /// The value starts at offset 4, directly after the header. The stored length is
    /// always a multiple of 4, so nothing is written after the value. `dest` must hold at
    /// least `4 + self.hmac.len()` bytes; a shorter buffer makes the slice indexing panic.
    fn write_into_unchecked(&self, dest: &mut [u8]) {
        self.write_header_unchecked(dest);
        let value: &[u8] = &self.hmac;
        dest[4..4 + value.len()].copy_from_slice(value);
    }
}
impl AttributeFromRaw<'_> for MessageIntegritySha256 {
    /// Parses the attribute from a borrowed [`RawAttribute`].
    ///
    /// This delegates to the `TryFrom<&RawAttribute>` implementation, so it reports the
    /// same errors.
    fn from_raw_ref(raw: &RawAttribute) -> Result<Self, StunParseError>
    where
        Self: Sized,
    {
        raw.try_into()
    }
}
impl TryFrom<&RawAttribute<'_>> for MessageIntegritySha256 {
    type Error = StunParseError;

    /// Parses a MESSAGE-INTEGRITY-SHA256 attribute from its raw form.
    ///
    /// # Errors
    ///
    /// - Whatever `RawAttribute::check_type_and_len` reports for a wrong attribute type or
    ///   a value outside 16..=32 bytes.
    /// - [`StunParseError::InvalidAttributeData`] if the value length is not a multiple
    ///   of 4.
    fn try_from(raw: &RawAttribute) -> Result<Self, Self::Error> {
        raw.check_type_and_len(Self::TYPE, 16..=32)?;
        let value: &[u8] = &*raw.value;
        if value.len() % 4 == 0 {
            Ok(Self {
                hmac: value.to_vec(),
            })
        } else {
            Err(StunParseError::InvalidAttributeData)
        }
    }
}
impl MessageIntegritySha256 {
    /// Builds the attribute from an HMAC-SHA256 value, which may be truncated.
    ///
    /// # Errors
    ///
    /// - [`StunWriteError::TooSmall`] if `hmac` is shorter than 16 bytes.
    /// - [`StunWriteError::TooLarge`] if `hmac` is longer than 32 bytes.
    /// - [`StunWriteError::IntegrityFailed`] if the length is within bounds but not a
    ///   multiple of 4.
    pub fn new(hmac: &[u8]) -> Result<Self, StunWriteError> {
        // Arm order mirrors the precedence: lower bound, then upper bound, then alignment.
        match hmac.len() {
            actual @ 0..=15 => Err(StunWriteError::TooSmall {
                expected: 16,
                actual,
            }),
            actual @ 33..=usize::MAX => Err(StunWriteError::TooLarge {
                expected: 32,
                actual,
            }),
            actual if actual % 4 != 0 => Err(StunWriteError::IntegrityFailed),
            _ => Ok(Self {
                hmac: hmac.to_vec(),
            }),
        }
    }

    /// Borrows the HMAC bytes held by this attribute.
    pub fn hmac(&self) -> &[u8] {
        &self.hmac
    }

    /// Computes the full 32-byte HMAC-SHA256 over the byte chunks in `data` with `key`.
    ///
    /// The work is delegated to `IntegrityKey::compute_sha256`.
    ///
    /// # Errors
    ///
    /// This implementation always returns `Ok`. The `Result` is part of the public
    /// signature.
    #[tracing::instrument(
        name = "MessageIntegritySha256::compute",
        level = "trace",
        err,
        skip(data, key)
    )]
    pub fn compute(data: &[&[u8]], key: &IntegrityKey) -> Result<[u8; 32], StunWriteError> {
        let digest = key.compute_sha256(data);
        Ok(digest)
    }

    /// Checks `expected` against the HMAC-SHA256 of `data` under `key`.
    ///
    /// Returns `true` only when they match. The comparison is performed by
    /// `IntegrityKey::verify_sha256`.
    ///
    /// NOTE(review): presumably that helper accepts truncated `expected` values
    /// (16..=32 bytes); confirm.
    #[tracing::instrument(
        name = "MessageIntegritySha256::verify",
        level = "debug",
        skip(data, key, expected)
    )]
    pub fn verify(data: &[&[u8]], key: &IntegrityKey, expected: &[u8]) -> bool {
        key.verify_sha256(data, expected)
    }
}
impl core::fmt::Display for MessageIntegritySha256 {
    /// Formats as `<type>: 0x<hex>`, with two lowercase hex digits per HMAC byte.
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        write!(f, "{}: 0x", Self::TYPE)?;
        self.hmac
            .iter()
            .try_for_each(|byte| write!(f, "{byte:02x}"))
    }
}
// Unit tests for MESSAGE-INTEGRITY (SHA-1) and MESSAGE-INTEGRITY-SHA256. They cover:
// - construction and accessors;
// - round-trips through RawAttribute;
// - rejection of malformed raw input;
// - write_into serialisation;
// - HMAC verification against fixed known-answer vectors.
#[cfg(test)]
mod tests {
use crate::{
message::{LongTermKeyCredentials, MessageIntegrityCredentials, ShortTermCredentials},
prelude::AttributeExt,
};
use alloc::string::ToString;
use alloc::vec;
use super::*;
use byteorder::{BigEndian, ByteOrder};
use tracing::trace;
// `new` stores the value verbatim and `length()` reports the fixed 20 bytes.
// The trace! call also exercises the Display impl.
#[test]
fn message_integrity() {
let _log = crate::tests::test_init_log();
let val = [1; 20];
let attr = MessageIntegrity::new(val);
trace!("{attr}");
assert_eq!(attr.hmac(), &val);
assert_eq!(attr.length(), 20);
}
// Round-trip: attribute -> RawAttribute -> attribute preserves the type and the HMAC.
#[test]
fn message_integrity_raw() {
let _log = crate::tests::test_init_log();
let val = [1; 20];
let attr = MessageIntegrity::new(val);
let raw = RawAttribute::from(&attr);
trace!("{raw}");
assert_eq!(raw.get_type(), MessageIntegrity::TYPE);
let mapped2 = MessageIntegrity::try_from(&raw).unwrap();
assert_eq!(mapped2.hmac(), &val);
}
// Rewrites the encoded length field (bytes 2..4) to one less than the real value length
// and drops the final byte. This yields a well-formed raw attribute with a 19-byte value,
// which must be rejected as Truncated.
#[test]
fn message_integrity_raw_short() {
let _log = crate::tests::test_init_log();
let val = [1; 20];
let attr = MessageIntegrity::new(val);
let raw = RawAttribute::from(&attr);
let mut data: Vec<_> = raw.clone().into();
let len = data.len();
BigEndian::write_u16(&mut data[2..4], len as u16 - 4 - 1);
assert!(matches!(
MessageIntegrity::try_from(
&RawAttribute::from_bytes(data[..len - 1].as_ref()).unwrap()
),
Err(StunParseError::Truncated {
expected: 20,
actual: 19
})
));
}
// Overwrites the attribute type field (bytes 0..2) with 0. Parsing must then report
// WrongAttributeImplementation.
#[test]
fn message_integrity_wrong_type() {
let _log = crate::tests::test_init_log();
let val = [1; 20];
let attr = MessageIntegrity::new(val);
let raw = RawAttribute::from(&attr);
let mut data: Vec<_> = raw.into();
BigEndian::write_u16(&mut data[0..2], 0);
assert!(matches!(
MessageIntegrity::try_from(&RawAttribute::from_bytes(data.as_ref()).unwrap()),
Err(StunParseError::WrongAttributeImplementation)
));
}
// write_into must produce exactly the same bytes as the RawAttribute encoding.
#[test]
fn message_integrity_write_into() {
let _log = crate::tests::test_init_log();
let val = [1; 20];
let attr = MessageIntegrity::new(val);
let raw = RawAttribute::from(&attr);
let mut out = vec![0; raw.padded_len()];
attr.write_into(&mut out).unwrap();
assert_eq!(out, raw.to_bytes());
}
// Known-answer check: short-term password "pass", SHA-1 key, 30 bytes of 0x0a.
// Corrupting the first expected byte must make verification fail.
#[test]
fn message_integrity_verify_fixed_value() {
let credentials = ShortTermCredentials::new("pass".to_string());
let key = MessageIntegrityCredentials::from(credentials)
.make_key(crate::message::IntegrityAlgorithm::Sha1);
let data = [10; 30];
let mut expected = [
92, 91, 148, 243, 28, 168, 16, 154, 137, 179, 250, 169, 153, 222, 37, 127, 210, 148,
222, 119,
];
assert!(MessageIntegrity::verify(&[&data], &key, &expected),);
expected[0] = 0;
assert!(!MessageIntegrity::verify(&[&data], &key, &expected),);
}
// A key made for SHA-256 must never verify through the SHA-1 path. This holds even when
// `expected` equals the first 20 bytes of the SHA-256 vector used in
// message_integrity_sha256_verify_key_long.
#[test]
fn message_integrity_verify_key_long_wrong_type() {
let credentials = LongTermKeyCredentials::new(
"user".to_string(),
"pass".to_string(),
"realm".to_string(),
);
let key = MessageIntegrityCredentials::from(credentials)
.make_key(crate::message::IntegrityAlgorithm::Sha256);
let data = [10; 30];
let mut expected = [
6, 162, 255, 56, 215, 134, 145, 90, 154, 49, 51, 6, 22, 49, 202, 8, 176, 159, 24,
93,
];
assert!(!MessageIntegrity::verify(&[&data], &key, &expected),);
expected[0] = 0;
assert!(!MessageIntegrity::verify(&[&data], &key, &expected),);
}
// `new` accepts a full 32-byte value, stores it verbatim and reports length 32.
#[test]
fn message_integrity_sha256() {
let _log = crate::tests::test_init_log();
let val = [1; 32];
let attr = MessageIntegritySha256::new(&val).unwrap();
trace!("{attr}");
assert_eq!(attr.hmac(), &val);
assert_eq!(attr.length(), 32);
}
// Round-trip: attribute -> RawAttribute -> attribute preserves the type and the HMAC.
#[test]
fn message_integrity_sha256_raw() {
let _log = crate::tests::test_init_log();
let val = [1; 32];
let attr = MessageIntegritySha256::new(&val).unwrap();
let raw = RawAttribute::from(&attr);
trace!("{raw}");
assert_eq!(raw.get_type(), MessageIntegritySha256::TYPE);
let mapped2 = MessageIntegritySha256::try_from(&raw).unwrap();
assert_eq!(mapped2.hmac(), &val);
}
// Shrinks the value to 31 bytes, the same way as message_integrity_raw_short. 31 is
// inside 16..=32 but not a multiple of 4, so parsing fails with InvalidAttributeData
// rather than Truncated.
#[test]
fn message_integrity_sha256_raw_short() {
let _log = crate::tests::test_init_log();
let val = [1; 32];
let attr = MessageIntegritySha256::new(&val).unwrap();
let raw = RawAttribute::from(&attr);
let mut data: Vec<_> = raw.clone().into();
let len = data.len();
BigEndian::write_u16(&mut data[2..4], len as u16 - 4 - 1);
assert!(matches!(
MessageIntegritySha256::try_from(
&RawAttribute::from_bytes(data[..len - 1].as_ref()).unwrap()
),
Err(StunParseError::InvalidAttributeData)
));
}
// Overwrites the attribute type field (bytes 0..2) with 0. Parsing must then report
// WrongAttributeImplementation.
#[test]
fn message_integrity_sha256_raw_wrong_type() {
let _log = crate::tests::test_init_log();
let val = [1; 32];
let attr = MessageIntegritySha256::new(&val).unwrap();
let raw = RawAttribute::from(&attr);
let mut data: Vec<_> = raw.into();
BigEndian::write_u16(&mut data[0..2], 0);
assert!(matches!(
MessageIntegritySha256::try_from(&RawAttribute::from_bytes(data.as_ref()).unwrap()),
Err(StunParseError::WrongAttributeImplementation)
));
}
// write_into must produce exactly the same bytes as the RawAttribute encoding.
#[test]
fn message_integrity_sha256_write_into() {
let _log = crate::tests::test_init_log();
let val = [1; 32];
let attr = MessageIntegritySha256::new(&val).unwrap();
let raw = RawAttribute::from(&attr);
let mut out = vec![0; raw.padded_len()];
attr.write_into(&mut out).unwrap();
assert_eq!(out, raw.to_bytes());
}
// Known-answer check: short-term password "pass", SHA-256 key, 30 bytes of 0x0a.
// Corrupting the first expected byte must make verification fail.
#[test]
fn message_integrity_sha256_verify_fixed_value() {
let credentials = ShortTermCredentials::new("pass".to_string());
let key = MessageIntegrityCredentials::from(credentials)
.make_key(crate::message::IntegrityAlgorithm::Sha256);
let data = [10; 30];
let mut expected = [
16, 175, 53, 195, 18, 50, 153, 148, 7, 247, 27, 185, 195, 171, 22, 197, 22, 180, 244,
67, 190, 185, 71, 34, 150, 194, 108, 18, 75, 94, 221, 185,
];
assert!(MessageIntegritySha256::verify(&[&data], &key, &expected),);
expected[0] = 0;
assert!(!MessageIntegritySha256::verify(&[&data], &key, &expected),);
}
// Known-answer check: long-term credentials (user/pass/realm), SHA-256 key, 30 bytes of
// 0x0a. Corrupting the first expected byte must make verification fail.
#[test]
fn message_integrity_sha256_verify_key_long() {
let credentials = LongTermKeyCredentials::new(
"user".to_string(),
"pass".to_string(),
"realm".to_string(),
);
let key = MessageIntegrityCredentials::from(credentials)
.make_key(crate::message::IntegrityAlgorithm::Sha256);
let data = [10; 30];
let mut expected = [
6, 162, 255, 56, 215, 134, 145, 90, 154, 49, 51, 6, 22, 49, 202, 8, 176, 159, 24, 93,
161, 160, 22, 105, 211, 138, 184, 39, 172, 103, 186, 106,
];
assert!(MessageIntegritySha256::verify(&[&data], &key, &expected),);
expected[0] = 0;
assert!(!MessageIntegritySha256::verify(&[&data], &key, &expected),);
}
// A key made for SHA-1 must never verify through the SHA-256 path, even with the vector
// that is valid for the SHA-256 long-term key above.
#[test]
fn message_integrity_sha256_verify_key_long_wrong_type() {
let credentials = LongTermKeyCredentials::new(
"user".to_string(),
"pass".to_string(),
"realm".to_string(),
);
let key = MessageIntegrityCredentials::from(credentials)
.make_key(crate::message::IntegrityAlgorithm::Sha1);
let data = [10; 30];
let mut expected = [
6, 162, 255, 56, 215, 134, 145, 90, 154, 49, 51, 6, 22, 49, 202, 8, 176, 159, 24, 93,
161, 160, 22, 105, 211, 138, 184, 39, 172, 103, 186, 106,
];
assert!(!MessageIntegritySha256::verify(&[&data], &key, &expected),);
expected[0] = 0;
assert!(!MessageIntegritySha256::verify(&[&data], &key, &expected),);
}
// 33 bytes exceeds the 32-byte maximum, so `new` returns TooLarge.
#[test]
fn message_integrity_sha256_new_too_large() {
let _log = crate::tests::test_init_log();
let val = [1; 33];
assert!(matches!(
MessageIntegritySha256::new(&val),
Err(StunWriteError::TooLarge {
expected: 32,
actual: 33
})
));
}
// 15 bytes is below the 16-byte minimum, so `new` returns TooSmall.
#[test]
fn message_integrity_sha256_new_too_small() {
let _log = crate::tests::test_init_log();
let val = [1; 15];
assert!(matches!(
MessageIntegritySha256::new(&val),
Err(StunWriteError::TooSmall {
expected: 16,
actual: 15
})
));
}
// 19 bytes is within bounds but not a multiple of 4. `new` reports this as
// IntegrityFailed.
#[test]
fn message_integrity_sha256_new_not_multiple_of_4() {
let _log = crate::tests::test_init_log();
let val = [1; 19];
assert!(matches!(
MessageIntegritySha256::new(&val),
Err(StunWriteError::IntegrityFailed)
));
}
}