use crate::error::{Error, Result};
use std::fmt;
/// Result type returned by the fallible BTI operations in this file.
///
/// This is a plain alias for the single-parameter `Result` imported from `crate::error`.
/// Functions in this file build a `BtiError` and convert it with `.into()` before returning,
/// so callers see the crate-level error type rather than `BtiError` itself.
pub type BtiResult<T> = Result<T>;
/// Failure categories raised while handling BTI trie data.
///
/// Each variant carries the detail that is interpolated into its [`fmt::Display`] message.
#[derive(Debug, Clone)]
pub enum BtiError {
    /// Parsing failed; the string describes what went wrong.
    Parse(String),
    /// A node's shape is not acceptable; the string describes the violation.
    InvalidNodeStructure(String),
    /// Moving through the trie failed; the string describes why.
    NavigationError(String),
    /// An unrecognised node type byte was encountered (rendered as two hex digits).
    InvalidNodeType(u8),
    /// The trie depth limit was exceeded; carries the depth value that is reported.
    MaxDepthExceeded(usize),
    /// A byte-comparable key was rejected; the string is included in the message.
    InvalidByteComparableKey(String),
    /// The trie structure is corrupted; the string describes the corruption.
    CorruptedTrie(String),
    /// A required component is absent; the string names it.
    MissingComponent(String),
}

impl fmt::Display for BtiError {
    /// Renders a one-line, human-readable message for the error.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            // Variants with non-string payloads first.
            Self::InvalidNodeType(tag) => {
                write!(f, "Invalid BTI trie node type: 0x{:02X}", tag)
            }
            Self::MaxDepthExceeded(limit) => {
                write!(f, "BTI trie depth exceeded maximum: {}", limit)
            }
            // Variants that carry a free-form description.
            Self::Parse(detail) => write!(f, "BTI parse error: {}", detail),
            Self::InvalidNodeStructure(detail) => {
                write!(f, "Invalid BTI node structure: {}", detail)
            }
            Self::NavigationError(detail) => write!(f, "BTI navigation error: {}", detail),
            Self::InvalidByteComparableKey(detail) => {
                write!(f, "Invalid byte-comparable key: {}", detail)
            }
            Self::CorruptedTrie(detail) => write!(f, "Corrupted BTI trie structure: {}", detail),
            Self::MissingComponent(detail) => write!(f, "Missing BTI component: {}", detail),
        }
    }
}

impl std::error::Error for BtiError {}
impl From<BtiError> for Error {
fn from(err: BtiError) -> Self {
Error::Parse(format!("BTI error: {}", err))
}
}
/// The four node shapes a BTI trie node can take.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum BtiNodeType {
    /// Node with no children.
    PayloadOnly,
    /// Node with exactly one child.
    Single,
    /// Node with between 2 and 256 children.
    Sparse,
    /// Node with between 1 and 256 children.
    Dense,
}

impl BtiNodeType {
    /// Returns the inclusive `(minimum, maximum)` number of children allowed for this type.
    ///
    /// The maximum is always `Some(..)` for the current variants; the `Option` leaves room
    /// for an unbounded type.
    pub fn expected_children_range(&self) -> (usize, Option<usize>) {
        let (fewest, most) = match self {
            Self::PayloadOnly => (0, 0),
            Self::Single => (1, 1),
            Self::Sparse => (2, 256),
            Self::Dense => (1, 256),
        };
        (fewest, Some(most))
    }
}

impl fmt::Display for BtiNodeType {
    /// Writes the variant name exactly as spelled in the enum.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let label = match self {
            Self::PayloadOnly => "PayloadOnly",
            Self::Single => "Single",
            Self::Sparse => "Sparse",
            Self::Dense => "Dense",
        };
        f.write_str(label)
    }
}
/// An unsigned distance together with the number of bytes used to encode it.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct SizedPointer {
    /// The distance value.
    pub distance: u64,
    /// Encoded width in bytes; the encoder and decoder accept 1, 2, 4 or 8.
    pub size: u8,
}

impl SizedPointer {
    /// Creates a pointer using the narrowest of the 1/2/4/8-byte widths that can hold
    /// `distance`.
    pub fn new(distance: u64) -> Self {
        let size = match distance {
            0..=0xFF => 1,
            0x100..=0xFFFF => 2,
            0x1_0000..=0xFFFF_FFFF => 4,
            _ => 8,
        };
        Self { distance, size }
    }

    /// Encodes the distance as `size` big-endian bytes.
    ///
    /// If `size` is narrower than `distance` needs (possible because both fields are
    /// public), only the low-order `size` bytes are emitted.
    ///
    /// # Panics
    ///
    /// Panics if `size` is not 1, 2, 4 or 8.
    pub fn to_bytes(&self) -> Vec<u8> {
        let full = self.distance.to_be_bytes();
        match self.size {
            1 | 2 | 4 | 8 => {
                // The trailing `size` bytes of the 8-byte big-endian form are exactly the
                // big-endian encoding of the value truncated to that width.
                let width = usize::from(self.size);
                full[full.len() - width..].to_vec()
            }
            other => panic!("Invalid pointer size: {}", other),
        }
    }

    /// Decodes a pointer of the given `size` from the front of `data`.
    ///
    /// Bytes beyond the first `size` are ignored. The returned pointer keeps the supplied
    /// `size` even if the value would fit in fewer bytes.
    ///
    /// # Errors
    ///
    /// Returns a `BtiError::Parse` (converted into the crate error) when `size` is not
    /// 1, 2, 4 or 8, or when `data` holds fewer than `size` bytes.
    pub fn from_bytes(data: &[u8], size: u8) -> BtiResult<Self> {
        let encoded = match size {
            1 | 2 | 4 | 8 => data.get(..usize::from(size)),
            _ => None,
        };
        match encoded {
            Some(bytes) => {
                // Accumulate most-significant byte first (big-endian).
                let distance = bytes
                    .iter()
                    .fold(0u64, |acc, &b| (acc << 8) | u64::from(b));
                Ok(Self { distance, size })
            }
            None => Err(BtiError::Parse(format!(
                "Invalid pointer size {} or insufficient data",
                size
            ))
            .into()),
        }
    }
}
/// One labelled edge: the byte that selects it and the pointer to the child it leads to.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Transition {
    /// Byte value that selects this edge.
    pub byte: u8,
    /// Pointer to the child reached through this edge.
    pub child: SizedPointer,
}

impl Transition {
    /// Pairs `byte` with its `child` pointer.
    pub fn new(byte: u8, child: SizedPointer) -> Self {
        Transition { byte, child }
    }
}
/// Location of a payload: where it starts, how long it is, and an optional checksum.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct PayloadRef {
    /// Offset of the payload.
    pub offset: u64,
    /// Length of the payload.
    pub length: u32,
    /// Checksum value, if one has been attached.
    pub checksum: Option<u32>,
}

impl PayloadRef {
    /// Creates a reference with no checksum attached.
    pub fn new(offset: u64, length: u32) -> Self {
        PayloadRef {
            checksum: None,
            offset,
            length,
        }
    }

    /// Consumes the reference and returns it with `checksum` attached, replacing any
    /// previous value.
    pub fn with_checksum(self, checksum: u32) -> Self {
        Self {
            checksum: Some(checksum),
            ..self
        }
    }
}
/// A single trie node: its declared type tag, position metadata and type-specific contents.
///
/// `node_type` and `data` are independent public fields. The constructors on `BtiNode`
/// always set them to matching variants, but nothing in the struct itself enforces that.
#[derive(Debug, Clone)]
pub struct BtiNode {
/// Declared kind of this node.
pub node_type: BtiNodeType,
/// NOTE(review): presumably the depth of this node within the trie — confirm with callers.
pub level: u16,
/// NOTE(review): presumably the key bytes leading to this node — confirm with callers.
pub key_prefix: Vec<u8>,
/// Type-specific contents (payload reference or child pointers).
pub data: BtiNodeData,
}
/// Type-specific contents of a `BtiNode`, one variant per `BtiNodeType`.
#[derive(Debug, Clone)]
pub enum BtiNodeData {
/// No children; only a reference to a payload.
PayloadOnly { payload: PayloadRef },
/// Exactly one outgoing transition.
Single { transition: Transition },
/// Explicit list of transitions. Child lookup binary-searches this list by byte, so it
/// must be sorted by `Transition::byte`.
Sparse { transitions: Vec<Transition> },
/// Children for a contiguous byte range: `children[i]` is the child for byte
/// `start_byte + i`.
Dense {
/// First byte value covered by `children`.
start_byte: u8,
/// One pointer per consecutive byte value, starting at `start_byte`.
children: Vec<SizedPointer>,
},
}
impl BtiNode {
    /// Builds a childless node that only references a payload.
    pub fn payload_only(level: u16, key_prefix: Vec<u8>, payload: PayloadRef) -> Self {
        Self {
            node_type: BtiNodeType::PayloadOnly,
            level,
            key_prefix,
            data: BtiNodeData::PayloadOnly { payload },
        }
    }

    /// Builds a node with exactly one outgoing transition.
    pub fn single(level: u16, key_prefix: Vec<u8>, transition: Transition) -> Self {
        Self {
            node_type: BtiNodeType::Single,
            level,
            key_prefix,
            data: BtiNodeData::Single { transition },
        }
    }

    /// Builds a sparse node from an explicit list of transitions.
    ///
    /// The transitions are sorted by byte here because [`BtiNode::find_child`] relies on
    /// binary search. Duplicate bytes are kept (and later rejected by
    /// [`BtiNode::validate`]).
    pub fn sparse(level: u16, key_prefix: Vec<u8>, mut transitions: Vec<Transition>) -> Self {
        transitions.sort_by_key(|t| t.byte);
        Self {
            node_type: BtiNodeType::Sparse,
            level,
            key_prefix,
            data: BtiNodeData::Sparse { transitions },
        }
    }

    /// Builds a dense node whose `children[i]` is the child for byte `start_byte + i`.
    pub fn dense(
        level: u16,
        key_prefix: Vec<u8>,
        start_byte: u8,
        children: Vec<SizedPointer>,
    ) -> Self {
        Self {
            node_type: BtiNodeType::Dense,
            level,
            key_prefix,
            data: BtiNodeData::Dense {
                start_byte,
                children,
            },
        }
    }

    /// Returns the pointer to the child selected by `byte`, or `None` if there is none.
    ///
    /// For sparse nodes this binary-searches the transition list, so a list that is not
    /// sorted by byte (possible when `data` is built by hand) may miss existing entries.
    pub fn find_child(&self, byte: u8) -> Option<&SizedPointer> {
        match &self.data {
            BtiNodeData::PayloadOnly { .. } => None,
            BtiNodeData::Single { transition } => {
                if transition.byte == byte {
                    Some(&transition.child)
                } else {
                    None
                }
            }
            BtiNodeData::Sparse { transitions } => transitions
                .binary_search_by_key(&byte, |t| t.byte)
                .ok()
                .map(|idx| &transitions[idx].child),
            // `checked_sub` rejects bytes below the range; `get` rejects bytes above it.
            BtiNodeData::Dense {
                start_byte,
                children,
            } => byte
                .checked_sub(*start_byte)
                .and_then(|slot| children.get(usize::from(slot))),
        }
    }

    /// Returns borrowed transitions for single and sparse nodes.
    ///
    /// Payload-only nodes have none. Dense nodes also yield an empty vector: they store
    /// bare pointers rather than `Transition` values, so there is nothing to borrow. Use
    /// [`BtiNode::find_child`] or [`BtiNode::child_count`] for dense nodes.
    pub fn get_transitions(&self) -> Vec<&Transition> {
        match &self.data {
            BtiNodeData::Single { transition } => vec![transition],
            BtiNodeData::Sparse { transitions } => transitions.iter().collect(),
            BtiNodeData::PayloadOnly { .. } | BtiNodeData::Dense { .. } => Vec::new(),
        }
    }

    /// Returns the payload reference for payload-only nodes, `None` otherwise.
    pub fn get_payload(&self) -> Option<&PayloadRef> {
        match &self.data {
            BtiNodeData::PayloadOnly { payload } => Some(payload),
            _ => None,
        }
    }

    /// Returns `true` when the node's data is the payload-only variant.
    pub fn is_leaf(&self) -> bool {
        matches!(self.data, BtiNodeData::PayloadOnly { .. })
    }

    /// Returns the number of children stored in the node's data.
    pub fn child_count(&self) -> usize {
        match &self.data {
            BtiNodeData::PayloadOnly { .. } => 0,
            BtiNodeData::Single { .. } => 1,
            BtiNodeData::Sparse { transitions } => transitions.len(),
            BtiNodeData::Dense { children, .. } => children.len(),
        }
    }

    /// Checks the node's structural invariants.
    ///
    /// Verified, in order:
    /// 1. `node_type` names the same variant as `data` (both fields are public, so they
    ///    can be set inconsistently by hand);
    /// 2. the child count lies within `node_type.expected_children_range()`;
    /// 3. sparse transitions are strictly increasing by byte (sorted, no duplicates);
    /// 4. a dense node's byte range does not extend past 255.
    ///
    /// # Errors
    ///
    /// Returns `BtiError::InvalidNodeStructure` (converted into the crate error) describing
    /// the first violated invariant.
    pub fn validate(&self) -> BtiResult<()> {
        let invalid = |msg: String| -> BtiResult<()> {
            Err(BtiError::InvalidNodeStructure(msg).into())
        };

        // Without this check a node tagged e.g. `Dense` but carrying `Single` data would
        // pass every test below and be reported as valid.
        let tag_matches_data = matches!(
            (self.node_type, &self.data),
            (BtiNodeType::PayloadOnly, BtiNodeData::PayloadOnly { .. })
                | (BtiNodeType::Single, BtiNodeData::Single { .. })
                | (BtiNodeType::Sparse, BtiNodeData::Sparse { .. })
                | (BtiNodeType::Dense, BtiNodeData::Dense { .. })
        );
        if !tag_matches_data {
            return invalid(format!(
                "Node type {} does not match its node data",
                self.node_type
            ));
        }

        let (min_children, max_children) = self.node_type.expected_children_range();
        let child_count = self.child_count();
        if child_count < min_children {
            return invalid(format!(
                "Node type {} has {} children, expected at least {}",
                self.node_type, child_count, min_children
            ));
        }
        if let Some(max) = max_children {
            if child_count > max {
                return invalid(format!(
                    "Node type {} has {} children, expected at most {}",
                    self.node_type, child_count, max
                ));
            }
        }

        match &self.data {
            BtiNodeData::Sparse { transitions } => {
                // `>=` also rejects duplicate bytes, which binary search cannot tolerate.
                if transitions
                    .windows(2)
                    .any(|pair| pair[0].byte >= pair[1].byte)
                {
                    return invalid("Sparse node transitions not sorted".to_string());
                }
            }
            BtiNodeData::Dense {
                start_byte,
                children,
            } => {
                // One past the last covered byte must not exceed 256.
                if usize::from(*start_byte) + children.len() > 256 {
                    return invalid("Dense node range overflows byte values".to_string());
                }
            }
            BtiNodeData::PayloadOnly { .. } | BtiNodeData::Single { .. } => {}
        }
        Ok(())
    }
}
/// Tracks a walk through the trie: the current offset, the bytes followed so far, and the
/// offsets already left behind (used for cycle detection).
#[derive(Debug, Clone)]
pub struct TrieNavigator {
    /// Offset of the node the navigator is currently positioned on.
    pub current_offset: u64,
    /// Bytes of the transitions followed since construction or the last reset.
    pub path: Vec<u8>,
    /// Offsets the navigator has moved away from; the current offset is not included.
    pub visited_offsets: std::collections::HashSet<u64>,
}

impl TrieNavigator {
    /// Creates a navigator positioned at `root_offset` with an empty path.
    pub fn new(root_offset: u64) -> Self {
        Self {
            current_offset: root_offset,
            path: Vec::new(),
            visited_offsets: std::collections::HashSet::new(),
        }
    }

    /// Follows `child_pointer` from the current node and records `byte` on the path.
    ///
    /// The target offset is `current_offset + child_pointer.distance`. On failure the
    /// navigator is left unchanged.
    ///
    /// # Errors
    ///
    /// Returns `BtiError::NavigationError` (converted into the crate error) when
    /// - the target offset does not fit in a `u64`, or
    /// - the target is the current node itself or an offset that was already visited.
    pub fn navigate_to_child(&mut self, byte: u8, child_pointer: &SizedPointer) -> BtiResult<()> {
        // A plain `+` would panic in debug builds and silently wrap in release builds.
        let target_offset = self
            .current_offset
            .checked_add(child_pointer.distance)
            .ok_or_else(|| {
                BtiError::NavigationError(format!(
                    "Child offset overflows u64: {} + {}",
                    self.current_offset, child_pointer.distance
                ))
            })?;

        // The current offset only joins `visited_offsets` after the move, so a zero-distance
        // pointer (self-loop) must be checked explicitly.
        let revisits = target_offset == self.current_offset
            || self.visited_offsets.contains(&target_offset);
        if revisits {
            return Err(
                BtiError::NavigationError("Cycle detected in trie navigation".to_string()).into(),
            );
        }

        self.visited_offsets.insert(self.current_offset);
        self.current_offset = target_offset;
        self.path.push(byte);
        Ok(())
    }

    /// Returns the bytes followed so far.
    pub fn current_path(&self) -> &[u8] {
        &self.path
    }

    /// Repositions the navigator at `root_offset` and forgets the path and visited offsets.
    pub fn reset(&mut self, root_offset: u64) {
        self.current_offset = root_offset;
        self.path.clear();
        self.visited_offsets.clear();
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// `SizedPointer::new` picks the narrowest width and encodes big-endian.
    #[test]
    fn test_sized_pointer() {
        let small = SizedPointer::new(100);
        assert_eq!(small.size, 1);
        assert_eq!(small.to_bytes(), vec![100]);

        let large = SizedPointer::new(0x10000);
        assert_eq!(large.size, 4);
        assert_eq!(large.to_bytes(), vec![0x00, 0x01, 0x00, 0x00]);
    }

    /// A payload-only node is a leaf with no children.
    #[test]
    fn test_node_creation() {
        let payload = PayloadRef::new(1000, 50);
        let node = BtiNode::payload_only(0, b"test".to_vec(), payload);
        assert_eq!(node.node_type, BtiNodeType::PayloadOnly);
        assert_eq!(node.level, 0);
        assert_eq!(node.key_prefix, b"test");
        assert!(node.is_leaf());
        assert_eq!(node.child_count(), 0);
    }

    /// Sparse lookup finds exactly the bytes that were registered.
    #[test]
    fn test_sparse_node_search() {
        let transitions = vec![
            Transition::new(b'a', SizedPointer::new(100)),
            Transition::new(b'm', SizedPointer::new(200)),
            Transition::new(b'z', SizedPointer::new(300)),
        ];
        let node = BtiNode::sparse(1, Vec::new(), transitions);
        assert!(node.find_child(b'a').is_some());
        assert!(node.find_child(b'm').is_some());
        assert!(node.find_child(b'z').is_some());
        assert!(node.find_child(b'b').is_none());
        assert_eq!(node.child_count(), 3);
    }

    /// Dense lookup covers `start_byte..start_byte + children.len()` and nothing else.
    #[test]
    fn test_dense_node_lookup() {
        let children = vec![
            SizedPointer::new(100),
            SizedPointer::new(200),
            SizedPointer::new(300),
        ];
        let node = BtiNode::dense(1, Vec::new(), b'a', children);
        assert!(node.find_child(b'a').is_some());
        assert!(node.find_child(b'b').is_some());
        assert!(node.find_child(b'c').is_some());
        assert!(node.find_child(b'd').is_none());
        // '@' (0x40) sits below the start byte 'a' (0x61).
        assert!(node.find_child(b'@').is_none());
    }

    /// Validation accepts a well-formed leaf and rejects a sparse node with one child.
    #[test]
    fn test_node_validation() {
        let payload_node = BtiNode::payload_only(0, Vec::new(), PayloadRef::new(0, 10));
        assert!(payload_node.validate().is_ok());

        // Sparse nodes need at least two transitions, so this one must be rejected.
        let invalid_sparse = BtiNode::sparse(
            1,
            Vec::new(),
            vec![Transition::new(b'a', SizedPointer::new(100))],
        );
        assert!(invalid_sparse.validate().is_err());
    }

    /// Navigating adds the pointer distance to the offset and records the byte.
    #[test]
    fn test_trie_navigator() {
        let mut nav = TrieNavigator::new(1000);
        assert_eq!(nav.current_offset, 1000);
        assert_eq!(nav.current_path(), &[] as &[u8]);

        let pointer = SizedPointer::new(100);
        nav.navigate_to_child(b'a', &pointer).unwrap();
        assert_eq!(nav.current_offset, 1100);
        assert_eq!(nav.current_path(), b"a");
    }
}