use crate::bitstream::{BitReader, BitWriter};
use crate::error::Error;
use crate::varint::{read_varint, write_varint};
/// One step of an origin path, e.g. the `84'` in the BIP-84 path `84'/0'/0'`.
///
/// On the wire this is a single hardened-flag bit followed by `value` as a varint.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub struct PathComponent {
    /// Whether this step is marked hardened (encoded as one bit).
    pub hardened: bool,
    /// The index for this step (encoded as a varint).
    pub value: u32,
}
impl PathComponent {
    /// Serializes this component: one hardened-flag bit, then `value` as a varint.
    ///
    /// # Errors
    /// Propagates any error returned by `write_varint`.
    pub fn write(&self, w: &mut BitWriter) -> Result<(), Error> {
        let flag_bit = u64::from(self.hardened);
        w.write_bits(flag_bit, 1);
        write_varint(w, self.value)?;
        Ok(())
    }

    /// Deserializes a component in the layout produced by [`PathComponent::write`].
    ///
    /// Any non-zero flag bit is treated as hardened.
    ///
    /// # Errors
    /// Propagates errors from `BitReader::read_bits` and `read_varint`.
    pub fn read(r: &mut BitReader) -> Result<Self, Error> {
        // Struct-literal fields are evaluated in source order, so the flag bit is
        // consumed before the varint, matching the write order.
        Ok(Self {
            hardened: r.read_bits(1)? != 0,
            value: read_varint(r)?,
        })
    }
}
/// Deepest [`OriginPath`] that can be encoded: the component count is stored in a
/// 4-bit field, so the largest representable depth is `2^4 - 1`.
pub const MAX_PATH_COMPONENTS: usize = (1 << 4) - 1;
/// An ordered sequence of [`PathComponent`]s, e.g. `84'/0'/0'`.
///
/// Encoding requires at most [`MAX_PATH_COMPONENTS`] components; the limit is checked
/// when writing, not at construction.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct OriginPath {
    /// Components in path order (first element is the outermost step).
    pub components: Vec<PathComponent>,
}
impl OriginPath {
    /// Serializes the path as a 4-bit component count followed by each component.
    ///
    /// # Errors
    /// Returns [`Error::PathDepthExceeded`] before writing anything when the path has
    /// more than [`MAX_PATH_COMPONENTS`] components; otherwise propagates the first
    /// component write error.
    pub fn write(&self, w: &mut BitWriter) -> Result<(), Error> {
        let depth = self.components.len();
        if depth > MAX_PATH_COMPONENTS {
            return Err(Error::PathDepthExceeded {
                got: depth,
                max: MAX_PATH_COMPONENTS,
            });
        }
        // depth <= MAX_PATH_COMPONENTS here, so the cast is lossless and fits 4 bits.
        w.write_bits(depth as u64, 4);
        self.components.iter().try_for_each(|c| c.write(w))
    }

    /// Deserializes a path in the layout produced by [`OriginPath::write`].
    ///
    /// A 4-bit count cannot exceed [`MAX_PATH_COMPONENTS`], so no extra depth check is
    /// needed on this side.
    ///
    /// # Errors
    /// Propagates the first error from reading the count or any component.
    pub fn read(r: &mut BitReader) -> Result<Self, Error> {
        let depth = r.read_bits(4)? as usize;
        let components = (0..depth)
            .map(|_| PathComponent::read(r))
            .collect::<Result<Vec<_>, _>>()?;
        Ok(Self { components })
    }
}
/// A declaration of `n` keys together with their origin path(s).
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct PathDecl {
    /// Number of keys; encodable range is `1..=32` (checked when writing).
    pub n: u8,
    /// Either one path shared by all keys, or one path per key.
    pub paths: PathDeclPaths,
}
/// How the origin paths of a [`PathDecl`] are laid out.
///
/// The variant itself is not written to the stream by `PathDecl::write`; the reader is
/// told which layout to expect through `PathDecl::read`'s `divergent_mode` argument.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum PathDeclPaths {
    /// A single path that applies to every key.
    Shared(OriginPath),
    /// One path per key; must contain exactly `n` entries to be written.
    Divergent(Vec<OriginPath>),
}
impl PathDecl {
pub fn write(&self, w: &mut BitWriter) -> Result<(), Error> {
if !(1..=32).contains(&(self.n as u32)) {
return Err(Error::KeyCountOutOfRange { n: self.n });
}
w.write_bits((self.n - 1) as u64, 5);
match &self.paths {
PathDeclPaths::Shared(p) => p.write(w)?,
PathDeclPaths::Divergent(paths) => {
if paths.len() != self.n as usize {
return Err(Error::DivergentPathCountMismatch {
n: self.n,
got: paths.len(),
});
}
for p in paths {
p.write(w)?;
}
}
}
Ok(())
}
pub fn read(r: &mut BitReader, divergent_mode: bool) -> Result<Self, Error> {
let n = (r.read_bits(5)? + 1) as u8;
let paths = if divergent_mode {
let mut paths = Vec::with_capacity(n as usize);
for _ in 0..n {
paths.push(OriginPath::read(r)?);
}
PathDeclPaths::Divergent(paths)
} else {
PathDeclPaths::Shared(OriginPath::read(r)?)
};
Ok(Self { n, paths })
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// The BIP-84 account path `84'/0'/0'`: three hardened components.
    fn bip84() -> OriginPath {
        let components = [84, 0, 0]
            .iter()
            .map(|&value| PathComponent { hardened: true, value })
            .collect();
        OriginPath { components }
    }

    #[test]
    fn origin_path_round_trip_bip84() {
        let original = bip84();
        let mut writer = BitWriter::new();
        original.write(&mut writer).unwrap();
        let encoded = writer.into_bytes();
        let mut reader = BitReader::new(&encoded);
        let decoded = OriginPath::read(&mut reader).unwrap();
        assert_eq!(decoded, original);
    }

    #[test]
    fn origin_path_bit_cost_bip84() {
        let mut writer = BitWriter::new();
        bip84().write(&mut writer).unwrap();
        assert_eq!(writer.bit_len(), 26);
    }

    #[test]
    fn origin_path_rejects_depth_too_large() {
        // One component more than the 4-bit count field can represent.
        let too_deep = OriginPath {
            components: vec![
                PathComponent {
                    hardened: false,
                    value: 0,
                };
                MAX_PATH_COMPONENTS + 1
            ],
        };
        let mut writer = BitWriter::new();
        let outcome = too_deep.write(&mut writer);
        assert!(matches!(
            outcome,
            Err(Error::PathDepthExceeded { got: 16, max: 15 })
        ));
    }
}
#[cfg(test)]
mod path_decl_tests {
    use super::*;

    /// A single hardened derivation step with the given index.
    fn hardened(value: u32) -> PathComponent {
        PathComponent {
            hardened: true,
            value,
        }
    }

    /// A one-key declaration sharing the BIP-84 path `84'/0'/0'`.
    fn single_key_bip84() -> PathDecl {
        PathDecl {
            n: 1,
            paths: PathDeclPaths::Shared(OriginPath {
                components: vec![hardened(84), hardened(0), hardened(0)],
            }),
        }
    }

    #[test]
    fn path_decl_shared_round_trip() {
        let decl = single_key_bip84();
        let mut writer = BitWriter::new();
        decl.write(&mut writer).unwrap();
        let encoded = writer.into_bytes();
        let mut reader = BitReader::new(&encoded);
        let decoded = PathDecl::read(&mut reader, false).unwrap();
        assert_eq!(decoded, decl);
    }

    #[test]
    fn path_decl_shared_bit_cost_bip84() {
        let mut writer = BitWriter::new();
        single_key_bip84().write(&mut writer).unwrap();
        // 5-bit key count plus the 26-bit BIP-84 origin path.
        assert_eq!(writer.bit_len(), 31);
    }

    #[test]
    fn path_decl_divergent_round_trip() {
        let single = |value| OriginPath {
            components: vec![hardened(value)],
        };
        let decl = PathDecl {
            n: 2,
            paths: PathDeclPaths::Divergent(vec![single(84), single(86)]),
        };
        let mut writer = BitWriter::new();
        decl.write(&mut writer).unwrap();
        let encoded = writer.into_bytes();
        let mut reader = BitReader::new(&encoded);
        let decoded = PathDecl::read(&mut reader, true).unwrap();
        assert_eq!(decoded, decl);
    }

    #[test]
    fn path_decl_n_zero_rejected() {
        let decl = PathDecl {
            n: 0,
            paths: PathDeclPaths::Shared(OriginPath {
                components: Vec::new(),
            }),
        };
        let mut writer = BitWriter::new();
        let outcome = decl.write(&mut writer);
        assert!(matches!(outcome, Err(Error::KeyCountOutOfRange { n: 0 })));
    }
}