use crate::bitstream::{BitReader, BitWriter};
use crate::error::Error;
use crate::varint::{read_varint, write_varint};
/// A single alternative inside a multipath group.
///
/// On the wire it is one flag bit for `hardened` followed by `value` in the
/// crate's varint format.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct Alternative {
    /// Whether this alternative is marked hardened; encoded as a single bit.
    pub hardened: bool,
    /// Numeric component; encoded with `write_varint`.
    pub value: u32,
}

impl Alternative {
    /// Appends this alternative to `w`: the hardened flag bit, then the varint value.
    ///
    /// # Errors
    /// Propagates any error returned by `write_varint`. The flag bit has already
    /// been appended to `w` at that point.
    pub fn write(&self, w: &mut BitWriter) -> Result<(), Error> {
        let Self { hardened, value } = *self;
        let flag_bit = u64::from(hardened);
        w.write_bits(flag_bit, 1);
        write_varint(w, value)?;
        Ok(())
    }

    /// Decodes an alternative in the layout produced by [`Alternative::write`].
    ///
    /// # Errors
    /// Propagates errors from reading the flag bit or from `read_varint`.
    pub fn read(r: &mut BitReader) -> Result<Self, Error> {
        let raw_flag = r.read_bits(1)?;
        let decoded_value = read_varint(r)?;
        Ok(Alternative {
            hardened: raw_flag != 0,
            value: decoded_value,
        })
    }
}
/// Smallest number of alternatives a multipath group may hold; `UseSitePath::write`
/// rejects a `Some` multipath with fewer entries.
pub const MIN_ALT_COUNT: usize = 2;
/// Largest number of alternatives a multipath group may hold. `UseSitePath::write`
/// stores the count as `len - MIN_ALT_COUNT` in a 3-bit field, so
/// `MAX_ALT_COUNT - MIN_ALT_COUNT` must stay at most 7.
pub const MAX_ALT_COUNT: usize = 9;
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct UseSitePath {
pub multipath: Option<Vec<Alternative>>,
pub wildcard_hardened: bool,
}
impl UseSitePath {
pub fn standard_multipath() -> Self {
Self {
multipath: Some(vec![
Alternative {
hardened: false,
value: 0,
},
Alternative {
hardened: false,
value: 1,
},
]),
wildcard_hardened: false,
}
}
pub fn write(&self, w: &mut BitWriter) -> Result<(), Error> {
if let Some(alts) = &self.multipath {
if !(MIN_ALT_COUNT..=MAX_ALT_COUNT).contains(&alts.len()) {
return Err(Error::AltCountOutOfRange { got: alts.len() });
}
w.write_bits(1, 1);
w.write_bits((alts.len() - MIN_ALT_COUNT) as u64, 3);
for a in alts {
a.write(w)?;
}
} else {
w.write_bits(0, 1);
}
w.write_bits(u64::from(self.wildcard_hardened), 1);
Ok(())
}
pub fn read(r: &mut BitReader) -> Result<Self, Error> {
let has_multipath = r.read_bits(1)? != 0;
let multipath = if has_multipath {
let alt_count = (r.read_bits(3)? as usize) + MIN_ALT_COUNT;
let mut alts = Vec::with_capacity(alt_count);
for _ in 0..alt_count {
alts.push(Alternative::read(r)?);
}
Some(alts)
} else {
None
};
let wildcard_hardened = r.read_bits(1)? != 0;
Ok(Self {
multipath,
wildcard_hardened,
})
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Serializes `path` into a fresh writer, panicking if encoding fails.
    fn encode(path: &UseSitePath) -> BitWriter {
        let mut writer = BitWriter::new();
        path.write(&mut writer).unwrap();
        writer
    }

    /// Serializes then deserializes `path` and checks the decoded value matches.
    fn assert_round_trip(path: &UseSitePath) {
        let bytes = encode(path).into_bytes();
        let mut reader = BitReader::new(&bytes);
        let decoded = UseSitePath::read(&mut reader).unwrap();
        assert_eq!(&decoded, path);
    }

    /// A path with no multipath group and the given wildcard flag.
    fn bare_star(wildcard_hardened: bool) -> UseSitePath {
        UseSitePath {
            multipath: None,
            wildcard_hardened,
        }
    }

    #[test]
    fn use_site_path_standard_round_trip() {
        assert_round_trip(&UseSitePath::standard_multipath());
    }

    #[test]
    fn use_site_path_standard_bit_cost() {
        let writer = encode(&UseSitePath::standard_multipath());
        assert_eq!(writer.bit_len(), 16);
    }

    #[test]
    fn use_site_path_bare_star_round_trip() {
        assert_round_trip(&bare_star(false));
    }

    #[test]
    fn use_site_path_bare_star_bit_cost() {
        let writer = encode(&bare_star(false));
        assert_eq!(writer.bit_len(), 2);
    }

    #[test]
    fn use_site_path_hardened_wildcard_round_trip() {
        assert_round_trip(&bare_star(true));
    }

    #[test]
    fn use_site_path_alt_count_too_small_rejected() {
        let single = Alternative {
            hardened: false,
            value: 0,
        };
        let path = UseSitePath {
            multipath: Some(vec![single]),
            wildcard_hardened: false,
        };
        let outcome = path.write(&mut BitWriter::new());
        assert!(matches!(outcome, Err(Error::AltCountOutOfRange { got: 1 })));
    }

    #[test]
    fn use_site_path_alt_count_too_large_rejected() {
        let alts: Vec<Alternative> = (0u32..10)
            .map(|value| Alternative {
                hardened: false,
                value,
            })
            .collect();
        let path = UseSitePath {
            multipath: Some(alts),
            wildcard_hardened: false,
        };
        let outcome = path.write(&mut BitWriter::new());
        assert!(matches!(outcome, Err(Error::AltCountOutOfRange { got: 10 })));
    }
}