use crate::bitstream::{BitReader, BitWriter, re_emit_bits};
use crate::error::Error;
use crate::origin_path::OriginPath;
use crate::use_site_path::UseSitePath;
use crate::varint::{read_varint, write_varint};
/// Tag of the entry carrying per-placeholder use-site path overrides.
/// Tags are written as 5-bit fields on the wire.
pub const TLV_USE_SITE_PATH_OVERRIDES: u8 = 0;
/// Tag of the entry carrying per-placeholder 4-byte fingerprints.
pub const TLV_FINGERPRINTS: u8 = 1;
/// Tag of the entry carrying per-placeholder 65-byte key records.
pub const TLV_PUBKEYS: u8 = 2;
/// Tag of the entry carrying per-placeholder origin path overrides.
pub const TLV_ORIGIN_PATH_OVERRIDES: u8 = 3;
/// Decoded TLV section: optional sparse lists of per-placeholder records,
/// plus any entries whose tag this module does not recognise.
///
/// Each known field is `None` when its entry is absent. A `Some` list is
/// keyed by placeholder index, must be non-empty and strictly ascending by
/// index: `write` rejects lists that violate this and `read` never yields them.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct TlvSection {
/// Entry `TLV_USE_SITE_PATH_OVERRIDES`: `(placeholder index, use-site path)`.
pub use_site_path_overrides: Option<Vec<(u8, UseSitePath)>>,
/// Entry `TLV_FINGERPRINTS`: `(placeholder index, 4 raw bytes)`.
/// NOTE(review): presumably BIP32 key fingerprints — confirm.
pub fingerprints: Option<Vec<(u8, [u8; 4])>>,
/// Entry `TLV_PUBKEYS`: `(placeholder index, 65 raw bytes)`.
/// NOTE(review): 65 bytes looks like chain code + compressed pubkey — confirm.
pub pubkeys: Option<Vec<(u8, [u8; 65])>>,
/// Entry `TLV_ORIGIN_PATH_OVERRIDES`: `(placeholder index, origin path)`.
pub origin_path_overrides: Option<Vec<(u8, OriginPath)>>,
/// Entries with unrecognised tags, kept as `(tag, payload bytes, payload
/// length in bits)` so `write` can re-emit them unchanged.
pub unknown: Vec<(u8, Vec<u8>, usize)>,
}
impl TlvSection {
pub fn new_empty() -> Self {
Self {
use_site_path_overrides: None,
fingerprints: None,
pubkeys: None,
origin_path_overrides: None,
unknown: Vec::new(),
}
}
pub fn is_empty(&self) -> bool {
let Self {
use_site_path_overrides,
fingerprints,
pubkeys,
origin_path_overrides,
unknown,
} = self;
use_site_path_overrides.is_none()
&& fingerprints.is_none()
&& pubkeys.is_none()
&& origin_path_overrides.is_none()
&& unknown.is_empty()
}
pub fn write(&self, w: &mut BitWriter, key_index_width: u8) -> Result<(), Error> {
let Self {
use_site_path_overrides,
fingerprints,
pubkeys,
origin_path_overrides,
unknown,
} = self;
let mut entries: Vec<(u8, Vec<u8>, usize)> = Vec::new();
if let Some(overrides) = use_site_path_overrides {
if overrides.is_empty() {
return Err(Error::EmptyTlvEntry {
tag: TLV_USE_SITE_PATH_OVERRIDES,
});
}
let mut sub = BitWriter::new();
let mut last_idx: Option<u8> = None;
for (idx, path) in overrides {
if let Some(prev) = last_idx {
if *idx <= prev {
return Err(Error::OverrideOrderViolation {
prev,
current: *idx,
});
}
}
last_idx = Some(*idx);
sub.write_bits(u64::from(*idx), key_index_width as usize);
path.write(&mut sub)?;
}
let bit_len = sub.bit_len();
entries.push((TLV_USE_SITE_PATH_OVERRIDES, sub.into_bytes(), bit_len));
}
if let Some(fps) = fingerprints {
if fps.is_empty() {
return Err(Error::EmptyTlvEntry {
tag: TLV_FINGERPRINTS,
});
}
let mut sub = BitWriter::new();
let mut last_idx: Option<u8> = None;
for (idx, fp) in fps {
if let Some(prev) = last_idx {
if *idx <= prev {
return Err(Error::OverrideOrderViolation {
prev,
current: *idx,
});
}
}
last_idx = Some(*idx);
sub.write_bits(u64::from(*idx), key_index_width as usize);
for b in fp {
sub.write_bits(u64::from(*b), 8);
}
}
let bit_len = sub.bit_len();
entries.push((TLV_FINGERPRINTS, sub.into_bytes(), bit_len));
}
if let Some(pks) = pubkeys {
if pks.is_empty() {
return Err(Error::EmptyTlvEntry { tag: TLV_PUBKEYS });
}
let mut sub = BitWriter::new();
let mut last_idx: Option<u8> = None;
for (idx, xpub) in pks {
if let Some(prev) = last_idx {
if *idx <= prev {
return Err(Error::OverrideOrderViolation {
prev,
current: *idx,
});
}
}
last_idx = Some(*idx);
sub.write_bits(u64::from(*idx), key_index_width as usize);
for b in xpub {
sub.write_bits(u64::from(*b), 8);
}
}
let bit_len = sub.bit_len();
entries.push((TLV_PUBKEYS, sub.into_bytes(), bit_len));
}
if let Some(paths) = origin_path_overrides {
if paths.is_empty() {
return Err(Error::EmptyTlvEntry {
tag: TLV_ORIGIN_PATH_OVERRIDES,
});
}
let mut sub = BitWriter::new();
let mut last_idx: Option<u8> = None;
for (idx, path) in paths {
if let Some(prev) = last_idx {
if *idx <= prev {
return Err(Error::OverrideOrderViolation {
prev,
current: *idx,
});
}
}
last_idx = Some(*idx);
sub.write_bits(u64::from(*idx), key_index_width as usize);
path.write(&mut sub)?;
}
let bit_len = sub.bit_len();
entries.push((TLV_ORIGIN_PATH_OVERRIDES, sub.into_bytes(), bit_len));
}
for (tag, payload, bit_len) in unknown {
entries.push((*tag, payload.clone(), *bit_len));
}
entries.sort_by_key(|(t, _, _)| *t);
for (tag, payload, bit_len) in entries {
w.write_bits(u64::from(tag), 5);
write_varint(w, bit_len as u32)?;
re_emit_bits(w, &payload, bit_len)?;
}
Ok(())
}
pub fn read(r: &mut BitReader, key_index_width: u8, n: u8) -> Result<Self, Error> {
let mut section = Self::new_empty();
let mut last_tag: Option<u8> = None;
loop {
let entry_start = r.save_position();
if r.remaining_bits() < 5 {
break; }
let parse_result: Result<(), Error> = (|| {
let tag = r.read_bits(5)? as u8;
if let Some(prev) = last_tag {
if tag <= prev {
return Err(Error::TlvOrderingViolation { prev, current: tag });
}
}
let bit_len = read_varint(r)? as usize;
if bit_len > r.remaining_bits() {
return Err(Error::TlvLengthExceedsRemaining {
length: bit_len,
remaining: r.remaining_bits(),
});
}
if bit_len == 0 {
return Err(Error::EmptyTlvEntry { tag });
}
match tag {
TLV_USE_SITE_PATH_OVERRIDES => {
let entry = read_use_site_overrides(r, bit_len, key_index_width, n)?;
section.use_site_path_overrides = Some(entry);
}
TLV_FINGERPRINTS => {
let entry = read_fingerprints(r, bit_len, key_index_width, n)?;
section.fingerprints = Some(entry);
}
TLV_PUBKEYS => {
let entry = read_pubkeys(r, bit_len, key_index_width, n)?;
section.pubkeys = Some(entry);
}
TLV_ORIGIN_PATH_OVERRIDES => {
let entry = read_origin_path_overrides(r, bit_len, key_index_width, n)?;
section.origin_path_overrides = Some(entry);
}
_ => {
let mut sub = BitWriter::new();
let mut remaining = bit_len;
while remaining > 0 {
let chunk = remaining.min(8);
let bits = r.read_bits(chunk)?;
sub.write_bits(bits, chunk);
remaining -= chunk;
}
let payload = sub.into_bytes();
section.unknown.push((tag, payload, bit_len));
}
}
last_tag = Some(tag);
Ok(())
})();
match parse_result {
Ok(()) => continue,
Err(e) => {
r.restore_position(entry_start);
let remaining_at_entry_start = r.remaining_bits();
if remaining_at_entry_start <= 7 {
break;
}
return Err(e);
}
}
}
Ok(section)
}
}
fn read_sparse_tlv_idx(
r: &mut BitReader,
key_index_width: u8,
n: u8,
last_idx: Option<u8>,
) -> Result<u8, Error> {
let idx = r.read_bits(key_index_width as usize)? as u8;
if idx >= n {
return Err(Error::PlaceholderIndexOutOfRange { idx, n });
}
if let Some(prev) = last_idx {
if idx <= prev {
return Err(Error::OverrideOrderViolation { prev, current: idx });
}
}
Ok(idx)
}
fn read_sparse_tlv_body<T, F>(
r: &mut BitReader,
bit_len: usize,
tag: u8,
key_index_width: u8,
n: u8,
mut read_value: F,
) -> Result<Vec<(u8, T)>, Error>
where
F: FnMut(&mut BitReader) -> Result<T, Error>,
{
let start = r.bit_position();
let saved_limit = r.save_bit_limit();
r.set_bit_limit_for_scope(start + bit_len);
let mut entries: Vec<(u8, T)> = Vec::new();
let mut last_idx: Option<u8> = None;
let result = (|| -> Result<(), Error> {
while r.bit_position() - start < bit_len {
let idx = read_sparse_tlv_idx(r, key_index_width, n, last_idx)?;
let value = read_value(r)?;
last_idx = Some(idx);
entries.push((idx, value));
}
Ok(())
})();
r.restore_bit_limit(saved_limit);
result?;
if entries.is_empty() {
return Err(Error::EmptyTlvEntry { tag });
}
Ok(entries)
}
/// Decodes the payload of a `TLV_USE_SITE_PATH_OVERRIDES` entry into
/// `(placeholder index, use-site path)` pairs.
fn read_use_site_overrides(
    r: &mut BitReader,
    bit_len: usize,
    key_index_width: u8,
    n: u8,
) -> Result<Vec<(u8, UseSitePath)>, Error> {
    let tag = TLV_USE_SITE_PATH_OVERRIDES;
    read_sparse_tlv_body(r, bit_len, tag, key_index_width, n, UseSitePath::read)
}
/// Decodes the payload of a `TLV_FINGERPRINTS` entry into
/// `(placeholder index, 4-byte fingerprint)` pairs, each byte read as 8 bits.
fn read_fingerprints(
    r: &mut BitReader,
    bit_len: usize,
    key_index_width: u8,
    n: u8,
) -> Result<Vec<(u8, [u8; 4])>, Error> {
    let tag = TLV_FINGERPRINTS;
    read_sparse_tlv_body(r, bit_len, tag, key_index_width, n, |reader| {
        let mut fingerprint = [0u8; 4];
        for slot in fingerprint.iter_mut() {
            *slot = reader.read_bits(8)? as u8;
        }
        Ok(fingerprint)
    })
}
/// Decodes the payload of a `TLV_PUBKEYS` entry into
/// `(placeholder index, 65-byte key record)` pairs, each byte read as 8 bits.
fn read_pubkeys(
    r: &mut BitReader,
    bit_len: usize,
    key_index_width: u8,
    n: u8,
) -> Result<Vec<(u8, [u8; 65])>, Error> {
    let tag = TLV_PUBKEYS;
    read_sparse_tlv_body(r, bit_len, tag, key_index_width, n, |reader| {
        let mut record = [0u8; 65];
        for slot in record.iter_mut() {
            *slot = reader.read_bits(8)? as u8;
        }
        Ok(record)
    })
}
/// Decodes the payload of a `TLV_ORIGIN_PATH_OVERRIDES` entry into
/// `(placeholder index, origin path)` pairs.
fn read_origin_path_overrides(
    r: &mut BitReader,
    bit_len: usize,
    key_index_width: u8,
    n: u8,
) -> Result<Vec<(u8, OriginPath)>, Error> {
    let tag = TLV_ORIGIN_PATH_OVERRIDES;
    read_sparse_tlv_body(r, bit_len, tag, key_index_width, n, OriginPath::read)
}
#[cfg(test)]
mod tests {
    // Round-trip, ordering and malformed-wire tests for `TlvSection`.
    use super::*;
    use crate::origin_path::PathComponent;

    #[test]
    fn empty_tlv_section_round_trip() {
        let s = TlvSection::new_empty();
        assert!(s.is_empty());
        let mut w = BitWriter::new();
        s.write(&mut w, 2).unwrap();
        assert_eq!(w.bit_len(), 0);
    }

    #[test]
    fn use_site_path_override_round_trip() {
        let mut s = TlvSection::new_empty();
        s.use_site_path_overrides = Some(vec![(
            1u8,
            UseSitePath {
                multipath: None,
                wildcard_hardened: true,
            },
        )]);
        let mut w = BitWriter::new();
        s.write(&mut w, 2).unwrap();
        let bit_len = w.bit_len();
        let bytes = w.into_bytes();
        let mut r = BitReader::new(&bytes);
        let s2 = TlvSection::read(&mut r, 2, 3).unwrap();
        assert_eq!(s2, s);
        assert_eq!(r.bit_position(), bit_len);
    }

    #[test]
    fn fingerprint_round_trip() {
        let mut s = TlvSection::new_empty();
        s.fingerprints = Some(vec![
            (0u8, [0xaa, 0xbb, 0xcc, 0xdd]),
            (2u8, [0x11, 0x22, 0x33, 0x44]),
        ]);
        let mut w = BitWriter::new();
        s.write(&mut w, 2).unwrap();
        let bit_len = w.bit_len();
        let bytes = w.into_bytes();
        let mut r = BitReader::new(&bytes);
        let s2 = TlvSection::read(&mut r, 2, 3).unwrap();
        assert_eq!(s2, s);
        // Like the other round-trip tests: the reader must stop exactly at the
        // end of the encoded section, not inside the padding.
        assert_eq!(r.bit_position(), bit_len);
    }

    #[test]
    fn pubkeys_round_trip() {
        let mut xpub_a = [0u8; 65];
        for (i, b) in xpub_a.iter_mut().enumerate() {
            *b = i as u8;
        }
        let mut xpub_b = [0u8; 65];
        for (i, b) in xpub_b.iter_mut().enumerate() {
            *b = (0xff - i as u8) ^ 0x5a;
        }
        let mut s = TlvSection::new_empty();
        s.pubkeys = Some(vec![(0u8, xpub_a), (2u8, xpub_b)]);
        let mut w = BitWriter::new();
        s.write(&mut w, 2).unwrap();
        let bit_len = w.bit_len();
        let bytes = w.into_bytes();
        let mut r = BitReader::new(&bytes);
        let s2 = TlvSection::read(&mut r, 2, 3).unwrap();
        assert_eq!(s2, s);
        assert_eq!(r.bit_position(), bit_len);
    }

    #[test]
    fn origin_path_overrides_round_trip() {
        let bip84 = OriginPath {
            components: vec![
                PathComponent {
                    hardened: true,
                    value: 84,
                },
                PathComponent {
                    hardened: true,
                    value: 0,
                },
                PathComponent {
                    hardened: true,
                    value: 5,
                },
            ],
        };
        let bip48 = OriginPath {
            components: vec![
                PathComponent {
                    hardened: true,
                    value: 48,
                },
                PathComponent {
                    hardened: true,
                    value: 0,
                },
                PathComponent {
                    hardened: true,
                    value: 0,
                },
                PathComponent {
                    hardened: true,
                    value: 2,
                },
            ],
        };
        let mut s = TlvSection::new_empty();
        s.origin_path_overrides = Some(vec![(0u8, bip84), (1u8, bip48)]);
        let mut w = BitWriter::new();
        s.write(&mut w, 2).unwrap();
        let bit_len = w.bit_len();
        let bytes = w.into_bytes();
        let mut r = BitReader::new(&bytes);
        let s2 = TlvSection::read(&mut r, 2, 3).unwrap();
        assert_eq!(s2, s);
        assert_eq!(r.bit_position(), bit_len);
    }

    #[test]
    fn unknown_tag_round_trip_preserved() {
        // A tag this module does not understand must survive read -> write
        // unchanged, interleaved in tag order with known entries.
        let mut payload = BitWriter::new();
        payload.write_bits(0xAB, 8);
        payload.write_bits(0x5, 4);
        let payload_bits = payload.bit_len();
        let payload_bytes = payload.into_bytes();

        let mut s = TlvSection::new_empty();
        s.fingerprints = Some(vec![(0u8, [0x01, 0x02, 0x03, 0x04])]);
        s.unknown = vec![(0x10, payload_bytes, payload_bits)];
        assert!(!s.is_empty());

        let mut w = BitWriter::new();
        s.write(&mut w, 2).unwrap();
        let bit_len = w.bit_len();
        let bytes = w.into_bytes();
        let mut r = BitReader::new(&bytes);
        let s2 = TlvSection::read(&mut r, 2, 3).unwrap();
        assert_eq!(s2, s);
        assert_eq!(r.bit_position(), bit_len);
    }

    #[test]
    fn ascending_tag_order_enforced_in_encoder() {
        let mut s = TlvSection::new_empty();
        s.use_site_path_overrides = Some(vec![(
            0,
            UseSitePath {
                multipath: None,
                wildcard_hardened: false,
            },
        )]);
        s.fingerprints = Some(vec![(0, [0u8; 4])]);
        s.pubkeys = Some(vec![(0, [0u8; 65])]);
        s.origin_path_overrides = Some(vec![(
            0,
            OriginPath {
                components: vec![PathComponent {
                    hardened: true,
                    value: 84,
                }],
            },
        )]);
        let mut w = BitWriter::new();
        s.write(&mut w, 2).unwrap();
        let bytes = w.into_bytes();
        // The first 5 bits of the stream (MSB-first) are the first entry's tag.
        let first_tag = (bytes[0] >> 3) & 0x1F;
        assert_eq!(first_tag, TLV_USE_SITE_PATH_OVERRIDES);
    }

    #[test]
    fn tag_ordering_violation_rejected_at_decoder() {
        // Two well-formed unknown entries whose tags descend: 0x10 then 0x08.
        let mut w = BitWriter::new();
        for tag in &[0x10u8, 0x08u8] {
            w.write_bits(u64::from(*tag), 5);
            write_varint(&mut w, 8).unwrap();
            w.write_bits(0xFF, 8);
        }
        let total_bit_len = w.bit_len();
        let bytes = w.into_bytes();
        let mut r = BitReader::with_bit_limit(&bytes, total_bit_len);
        let result = TlvSection::read(&mut r, 2, 3);
        assert!(matches!(
            result,
            Err(Error::TlvOrderingViolation {
                prev: 0x10,
                current: 0x08
            })
        ));
    }

    #[test]
    fn pubkeys_ordering_violation_rejected_at_encoder() {
        let mut s = TlvSection::new_empty();
        s.pubkeys = Some(vec![(1u8, [0u8; 65]), (0u8, [0u8; 65])]);
        let mut w = BitWriter::new();
        let result = s.write(&mut w, 2);
        assert!(matches!(
            result,
            Err(Error::OverrideOrderViolation {
                prev: 1,
                current: 0
            })
        ));
    }

    #[test]
    fn pubkeys_ordering_violation_rejected_at_decoder() {
        // Hand-built payload with a repeated placeholder index (1, then 1).
        let mut sub = BitWriter::new();
        sub.write_bits(1, 2);
        for _ in 0..65 {
            sub.write_bits(0, 8);
        }
        sub.write_bits(1, 2);
        for _ in 0..65 {
            sub.write_bits(0, 8);
        }
        let bit_len = sub.bit_len();
        let payload_bytes = sub.into_bytes();
        let mut w = BitWriter::new();
        w.write_bits(u64::from(TLV_PUBKEYS), 5);
        write_varint(&mut w, bit_len as u32).unwrap();
        re_emit_bits(&mut w, &payload_bytes, bit_len).unwrap();
        let total_bit_len = w.bit_len();
        let bytes = w.into_bytes();
        let mut r = BitReader::with_bit_limit(&bytes, total_bit_len);
        let result = TlvSection::read(&mut r, 2, 3);
        assert!(matches!(
            result,
            Err(Error::OverrideOrderViolation {
                prev: 1,
                current: 1
            })
        ));
    }

    #[test]
    fn read_sparse_tlv_idx_out_of_range() {
        let mut sub = BitWriter::new();
        sub.write_bits(3, 2);
        let bit_len = sub.bit_len();
        let bytes = sub.into_bytes();
        let mut r = BitReader::with_bit_limit(&bytes, bit_len);
        let result = read_sparse_tlv_idx(&mut r, 2, 2, None);
        assert!(matches!(
            result,
            Err(Error::PlaceholderIndexOutOfRange { idx: 3, n: 2 })
        ));
    }

    #[test]
    fn read_sparse_tlv_idx_non_ascending() {
        let mut sub = BitWriter::new();
        sub.write_bits(0, 2);
        let bit_len = sub.bit_len();
        let bytes = sub.into_bytes();
        let mut r = BitReader::with_bit_limit(&bytes, bit_len);
        let result = read_sparse_tlv_idx(&mut r, 2, 3, Some(1));
        assert!(matches!(
            result,
            Err(Error::OverrideOrderViolation {
                prev: 1,
                current: 0
            })
        ));
    }

    #[test]
    fn empty_pubkeys_vec_rejected_at_encoder() {
        let mut s = TlvSection::new_empty();
        s.pubkeys = Some(Vec::new());
        let mut w = BitWriter::new();
        let result = s.write(&mut w, 2);
        assert!(matches!(
            result,
            Err(Error::EmptyTlvEntry { tag }) if tag == TLV_PUBKEYS
        ));
    }

    #[test]
    fn empty_origin_path_overrides_vec_rejected_at_encoder() {
        let mut s = TlvSection::new_empty();
        s.origin_path_overrides = Some(Vec::new());
        let mut w = BitWriter::new();
        let result = s.write(&mut w, 2);
        assert!(matches!(
            result,
            Err(Error::EmptyTlvEntry { tag }) if tag == TLV_ORIGIN_PATH_OVERRIDES
        ));
    }

    /// Builds a single-record TLV entry whose declared length is
    /// `slack_bits` longer than the record it actually contains. The slack is
    /// filled with zero bits. Returns the bytes and the exact bit length.
    fn craft_inflated_tlv_wire(
        tag: u8,
        idx: u8,
        idx_width: u8,
        record_payload_bits: &[(u64, usize)],
        slack_bits: usize,
    ) -> (Vec<u8>, usize) {
        let mut w = BitWriter::new();
        w.write_bits(u64::from(tag), 5);
        let actual_record_bits: usize =
            (idx_width as usize) + record_payload_bits.iter().map(|(_, n)| n).sum::<usize>();
        let declared_bit_len = actual_record_bits + slack_bits;
        write_varint(&mut w, declared_bit_len as u32).unwrap();
        w.write_bits(u64::from(idx), idx_width as usize);
        for (val, bits) in record_payload_bits {
            w.write_bits(*val, *bits);
        }
        for _ in 0..slack_bits {
            w.write_bits(0, 1);
        }
        let bit_len = w.bit_len();
        (w.into_bytes(), bit_len)
    }

    #[test]
    fn fingerprints_with_trailing_slack_rejected() {
        let (bytes, total_bits) =
            craft_inflated_tlv_wire(TLV_FINGERPRINTS, 0, 1, &[(0xDEAD_BEEF, 32)], 4);
        let mut r = BitReader::with_bit_limit(&bytes, total_bits);
        let result = TlvSection::read(&mut r, 1, 1);
        assert!(
            result.is_err(),
            "trailing slack must be rejected, got {:?}",
            result
        );
    }

    #[test]
    fn pubkeys_with_trailing_slack_rejected() {
        let payload: Vec<(u64, usize)> = (0..65).map(|_i| (0x42u64, 8)).collect();
        let (bytes, total_bits) = craft_inflated_tlv_wire(TLV_PUBKEYS, 0, 1, &payload, 3);
        let mut r = BitReader::with_bit_limit(&bytes, total_bits);
        let result = TlvSection::read(&mut r, 1, 1);
        assert!(
            result.is_err(),
            "trailing slack must be rejected, got {:?}",
            result
        );
    }

    #[test]
    fn use_site_path_overrides_with_trailing_slack_rejected() {
        let mut path_w = BitWriter::new();
        UseSitePath::standard_multipath()
            .write(&mut path_w)
            .unwrap();
        let path_bit_len = path_w.bit_len();
        let path_bytes = path_w.into_bytes();
        // Re-split the encoded path into <=8-bit chunks for the wire builder.
        let mut path_record: Vec<(u64, usize)> = Vec::new();
        let mut br = BitReader::new(&path_bytes);
        let mut consumed = 0;
        while consumed < path_bit_len {
            let chunk = (path_bit_len - consumed).min(8);
            path_record.push((br.read_bits(chunk).unwrap(), chunk));
            consumed += chunk;
        }
        let (bytes, total_bits) =
            craft_inflated_tlv_wire(TLV_USE_SITE_PATH_OVERRIDES, 0, 1, &path_record, 2);
        let mut r = BitReader::with_bit_limit(&bytes, total_bits);
        let result = TlvSection::read(&mut r, 1, 1);
        assert!(
            result.is_err(),
            "trailing slack must be rejected, got {:?}",
            result
        );
    }

    #[test]
    fn origin_path_overrides_with_trailing_slack_rejected() {
        let (bytes, total_bits) =
            craft_inflated_tlv_wire(TLV_ORIGIN_PATH_OVERRIDES, 0, 1, &[(0, 4)], 5);
        let mut r = BitReader::with_bit_limit(&bytes, total_bits);
        let result = TlvSection::read(&mut r, 1, 1);
        assert!(
            result.is_err(),
            "trailing slack must be rejected, got {:?}",
            result
        );
    }
}