use crate::{error::parser::MappingParseError, primitives::Str};
use bytes::{BufMut, Bytes, BytesMut};
use hashbrown::{
hash_map::{IntoIter, Iter},
HashMap,
};
use nom::{
number::complete::{be_u16, be_u8},
Err, IResult,
};
use alloc::vec::Vec;
use core::{
fmt::{self, Debug},
num::NonZeroUsize,
};
/// Key-value mapping whose keys and values are both [`Str`].
///
/// Thin newtype around a [`HashMap`]; entries are therefore stored without any
/// particular order.
#[derive(Debug, PartialEq, Eq, Clone, Default)]
pub struct Mapping(HashMap<Str, Str>);
impl fmt::Display for Mapping {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
self.0.fmt(f)
}
}
impl Mapping {
    /// Serialize the mapping into a freshly allocated buffer.
    ///
    /// The layout is the one produced by [`Mapping::serialize_into`].
    pub fn serialize(&self) -> Bytes {
        self.serialize_into(BytesMut::with_capacity(2)).freeze()
    }

    /// Append the serialized mapping to `out` and return the buffer.
    ///
    /// Bytes already present in `out` are left untouched. The appended data consists of a
    /// two-byte big-endian length of the payload followed by the payload itself: for every
    /// entry the serialized key, `=`, the serialized value and `;`.
    ///
    /// Entries are written sorted by key so the output does not depend on the iteration
    /// order of the underlying hash map.
    ///
    /// Note: a payload longer than `u16::MAX` bytes is only caught by a debug assertion; in
    /// builds without debug assertions the length prefix is truncated.
    pub fn serialize_into(&self, mut out: BytesMut) -> BytesMut {
        // Remember where the length prefix lives so it can be patched once the payload
        // size is known.
        let prefix_pos = out.len();

        // `reserve` returns immediately if the spare capacity is already sufficient.
        out.reserve(2);
        out.put_u16(0);
        let data_start = out.len();

        let mut entries: Vec<_> = self.0.iter().collect();
        entries.sort_unstable_by(|a, b| a.0.cmp(b.0));

        for (key, value) in entries {
            let key = key.serialize();
            let value = value.serialize();

            // Two extra bytes for the `=` and `;` delimiters.
            out.reserve(key.len() + value.len() + 2);
            out.extend_from_slice(&key);
            out.put_u8(b'=');
            out.extend_from_slice(&value);
            out.put_u8(b';');
        }

        let data_len = out.len() - data_start;
        debug_assert!(data_len <= u16::MAX as usize);
        out[prefix_pos..prefix_pos + 2].copy_from_slice(&(data_len as u16).to_be_bytes());

        out
    }

    /// Parse a [`Mapping`] from the start of `input`.
    ///
    /// Returns the unconsumed tail of `input` together with the parsed mapping. If a key
    /// occurs more than once, the last occurrence wins.
    ///
    /// The single byte following each key and each value (written as `=` and `;` by
    /// [`Mapping::serialize_into`]) is consumed without being validated.
    ///
    /// # Errors
    ///
    /// Returns `nom::Err::Incomplete` if `input` is shorter than the length announced by
    /// the two-byte size prefix, and propagates any error reported while parsing the
    /// individual keys, values and delimiters.
    pub fn parse_frame(input: &[u8]) -> IResult<&[u8], Self, MappingParseError> {
        let (rest, size) = be_u16(input)?;
        let size = usize::from(size);

        let Some((mut data, rest)) = rest.split_at_checked(size) else {
            // The split only fails when `size > rest.len()`, hence `size` is non-zero here.
            // Fall back to `Needed::Unknown` instead of panicking should that ever change.
            let needed = NonZeroUsize::new(size).map_or(nom::Needed::Unknown, nom::Needed::Size);
            return Err(nom::Err::Incomplete(needed));
        };

        let mut mapping = Self::default();

        while !data.is_empty() {
            let (remaining, key) = Str::parse_frame(data).map_err(Err::convert)?;
            // Skip the delimiter between key and value.
            let (remaining, _) = be_u8(remaining)?;
            let (remaining, value) = Str::parse_frame(remaining).map_err(Err::convert)?;
            // Skip the delimiter terminating the entry.
            let (remaining, _) = be_u8(remaining)?;

            mapping.insert(key, value);
            data = remaining;
        }

        Ok((rest, mapping))
    }

    /// Parse a [`Mapping`] from `bytes`, ignoring any bytes that follow the mapping.
    ///
    /// # Errors
    ///
    /// Returns the [`MappingParseError`] obtained by converting the failure reported by
    /// [`Mapping::parse_frame`].
    pub fn parse(bytes: impl AsRef<[u8]>) -> Result<Mapping, MappingParseError> {
        let (_rest, mapping) = Self::parse_frame(bytes.as_ref())?;

        Ok(mapping)
    }

    /// Insert `value` under `key`, returning the value previously stored for `key`, if any.
    pub fn insert(&mut self, key: Str, value: Str) -> Option<Str> {
        self.0.insert(key, value)
    }

    /// Get a reference to the value stored under `key`, if any.
    pub fn get(&self, key: &Str) -> Option<&Str> {
        self.0.get(key)
    }

    /// Number of entries in the mapping.
    pub fn len(&self) -> usize {
        self.0.len()
    }

    /// Returns `true` if the mapping contains no entries.
    pub fn is_empty(&self) -> bool {
        self.0.is_empty()
    }

    /// Iterate over the entries of the mapping by reference.
    ///
    /// The iteration order is that of the underlying hash map.
    pub fn iter(&self) -> Iter<'_, Str, Str> {
        self.0.iter()
    }

    /// Remove `key` from the mapping, returning its value if it was present.
    pub fn remove(&mut self, key: &Str) -> Option<Str> {
        self.0.remove(key)
    }

    /// Keep only the entries for which `predicate` returns `true`.
    ///
    /// The predicate may mutate the values it inspects and, being `FnMut`, may also carry
    /// mutable state between invocations.
    pub fn retain(&mut self, predicate: impl FnMut(&Str, &mut Str) -> bool) {
        self.0.retain(predicate);
    }
}
impl IntoIterator for Mapping {
type Item = (Str, Str);
type IntoIter = IntoIter<Str, Str>;
fn into_iter(self) -> Self::IntoIter {
self.0.into_iter()
}
}
impl FromIterator<(Str, Str)> for Mapping {
fn from_iter<T: IntoIterator<Item = (Str, Str)>>(iter: T) -> Self {
Self(HashMap::from_iter(iter))
}
}
impl From<HashMap<Str, Str>> for Mapping {
fn from(value: HashMap<Str, Str>) -> Self {
Mapping(value)
}
}
#[cfg(test)]
mod tests {
    use crate::{
        crypto::base64_encode,
        primitives::RouterId,
        runtime::{mock::MockRuntime, Runtime},
    };

    use super::*;

    // An empty mapping is just a zero length prefix.
    #[test]
    fn empty_mapping() {
        assert_eq!(Mapping::parse(b"\0\0"), Ok(Mapping::default()));
    }

    // Serializing and parsing a single entry round-trips.
    #[test]
    fn valid_mapping() {
        let mut mapping = Mapping::default();
        mapping.insert("hello".into(), "world".into());

        let ser = mapping.serialize();

        assert_eq!(Mapping::parse(ser), Ok(mapping));
    }

    // `parse` ignores bytes that follow the mapping.
    #[test]
    fn valid_string_with_extra_end_bytes() {
        let mut mapping = Mapping::default();
        mapping.insert("hello".into(), "world".into());

        let mut ser = mapping.serialize().to_vec();
        ser.push(1);
        ser.push(2);
        ser.push(3);
        ser.push(4);

        assert_eq!(Mapping::parse(ser), Ok(mapping));
    }

    // `serialize_into` appends to the buffer and leaves existing bytes intact.
    #[test]
    fn valid_string_with_extra_start_bytes() {
        let mut mapping = Mapping::default();
        mapping.insert("hello".into(), "world".into());

        const PREFIX: &[u8] = b"prefix";
        let buf = BytesMut::from(PREFIX);
        let ser = mapping.serialize_into(buf).to_vec();

        assert_eq!(&ser[..PREFIX.len()], b"prefix");
        assert_eq!(Mapping::parse(&ser[PREFIX.len()..]), Ok(mapping));
    }

    // `parse_frame` hands back the bytes that follow the mapping.
    #[test]
    fn extra_bytes_returned() {
        let mut mapping = Mapping::default();
        mapping.insert("hello".into(), "world".into());

        let mut ser = mapping.serialize().to_vec();
        ser.push(1);
        ser.push(2);
        ser.push(3);
        ser.push(4);

        let (rest, parsed_mapping) = Mapping::parse_frame(&ser).unwrap();

        assert_eq!(parsed_mapping, mapping);
        assert_eq!(rest, [1, 2, 3, 4]);
    }

    // Several entries parse correctly and re-serialize to the same, key-sorted bytes.
    #[test]
    fn multiple_mappings() {
        let expected_ser = b"\x00\x19\x01a=\x01b;\x01c=\x01d;\x01e=\x01f;\x02zz=\x01z;";
        let mapping = Mapping::parse(expected_ser).expect("to be valid");

        assert_eq!(mapping.get(&"a".into()), Some(&Str::from("b")));
        assert_eq!(mapping.get(&"c".into()), Some(&Str::from("d")));
        assert_eq!(mapping.get(&"e".into()), Some(&Str::from("f")));
        assert_eq!(mapping.get(&"zz".into()), Some(&Str::from("z")));

        assert_eq!(mapping.serialize().to_vec(), expected_ser);
    }

    // A size prefix larger than the available data is rejected.
    #[test]
    fn over_sized() {
        let ser = b"\x01\x00\x01a=\x01b;\x01c=\x01d;\x01e=\x01f;";

        assert_eq!(
            Mapping::parse(ser).unwrap_err(),
            MappingParseError::InvalidBitstream
        );
    }

    // `retain` drops exactly the entries rejected by the predicate.
    #[test]
    fn retain_works() {
        let mut options = Mapping::from_iter([
            (Str::from("host"), Str::from("127.0.0.1")),
            (Str::from("port"), Str::from("8888")),
        ]);

        for i in 0..3 {
            options.insert(
                Str::from(format!("iexp{i}")),
                Str::from((MockRuntime::time_since_epoch().as_secs() + 1337).to_string()),
            );
            options.insert(
                Str::from(format!("ih{i}")),
                Str::from(base64_encode(RouterId::random().to_vec())),
            );
            options.insert(
                Str::from(format!("itag{i}")),
                Str::from(format!("{}", 1337 + i)),
            );
        }
        assert_eq!(options.len(), 11);

        options.retain(|key, _| {
            !(key.starts_with("iexp") || key.starts_with("itag") || key.starts_with("ih"))
        });

        // Without these checks the test would pass even if `retain` removed nothing.
        assert_eq!(options.len(), 2);
        assert_eq!(options.get(&Str::from("itag0")), None);

        assert_eq!(
            options.get(&Str::from("host")),
            Some(&Str::from("127.0.0.1"))
        );
        assert_eq!(options.get(&Str::from("port")), Some(&Str::from("8888")));
    }
}