use buffertk::{Packable, Unpackable};
use prototk::{FieldNumber, Tag, WireType};
use crate::Error;
use crate::bit_vector::BitVector as BitVectorTrait;
use crate::bit_vector::rrr::BitVector;
use crate::builder::{Builder, Helper, parse_one_field_bytes};
use crate::encoder::Encoder;
use super::WaveletTree as WaveletTreeTrait;
/// Field number of the length-delimited field that wraps an entire serialized wavelet tree;
/// written by `construct` and verified by `unpack`.
const CONTAINER_TAG: u32 = 1;
/// Field number of each length-delimited `Node` record appended inside the container;
/// written by `construct_recursive` and verified by `load_node`.
const NODE_TAG: u32 = 2;
// Trailer message that `construct` appends last. `load_root` unpacks it from the final nine
// bytes of the container in order to find where the `Root` message begins.
#[derive(Clone, Debug, Default, prototk_derive::Message)]
struct Capstone {
// Value of the builder's `relative_len()` taken just before the `Root` message was appended.
#[prototk(1, fixed64)]
root_offset: u64,
}
// Top-level description of a serialized tree: where the packed encoder lives, how many symbols
// were encoded, and where the root node is.
#[derive(Clone, Debug, Default, prototk_derive::Message)]
struct Root {
// Builder `relative_len()` just before the packed encoder was appended.
#[prototk(3, uint64)]
encoder_start: u64,
// Builder `relative_len()` just after the packed encoder was appended.
#[prototk(4, uint64)]
encoder_limit: u64,
// Number of symbols that were encoded into the tree.
#[prototk(5, uint64)]
length: u64,
// As serialized: the offset returned by `construct_recursive` for the root node.
// After `load_nodes` runs: the root's 1-based index into `WaveletTree::nodes`, or 0 when the
// nodes could not be loaded.
#[prototk(6, uint64)]
tree: u64,
}
// One node of the tree: the location of its bit vector plus links to its two children.
#[derive(Clone, Debug, Default, prototk_derive::Message)]
struct Node {
// Number of symbols routed through this node (one bit per symbol in its bit vector).
#[prototk(7, uint64)]
length: u64,
// Builder `relative_len()` just before this node's bit vector was written.
#[prototk(8, uint64)]
start: u64,
// Builder `relative_len()` just after this node's bit vector was written.
#[prototk(9, uint64)]
limit: u64,
// Child for symbols whose current code bit is 0; 0 means there is no child.
// Holds a serialized offset on disk and a 1-based index into `WaveletTree::nodes` once loaded.
#[prototk(10, uint64)]
left: u64,
// Child for symbols whose current code bit is 1; same conventions as `left`.
#[prototk(11, uint64)]
right: u64,
}
/// Reports whether every `(code, size)` pair yielded by `iter` has a non-zero size.
///
/// An empty iterator yields `true`. Iteration stops at the first zero-sized pair.
fn all_sized(iter: impl Iterator<Item = (u32, u8)>) -> bool {
    for (_, size) in iter {
        if size == 0 {
            return false;
        }
    }
    true
}
/// Reports whether every `(code, size)` pair yielded by `iter` has a size of zero.
///
/// An empty iterator yields `true`. Iteration stops at the first pair with a non-zero size.
fn all_zero(mut iter: impl Iterator<Item = (u32, u8)>) -> bool {
    let found_sized = iter.any(|(_, size)| size != 0);
    !found_sized
}
/// A wavelet tree viewed over a serialized buffer.
///
/// Instances are produced by `Unpackable::unpack`, which decodes the encoder and root from the
/// container bytes and then loads every node together with its parsed bit vector.
pub struct WaveletTree<'a, E: Encoder> {
// Encoder unpacked from the container; translates symbols to `(code, size)` pairs and back.
encoder: E,
// Decoded root message; `root.tree` is rewritten by `load_nodes` to an index into `nodes`.
root: Root,
// Payload bytes of the container field that holds the serialized tree.
tree: &'a [u8],
// Loaded nodes paired with their bit vectors; referenced elsewhere by 1-based index.
nodes: Vec<(Node, BitVector<'a>)>,
}
impl<'a, E: Encoder> WaveletTree<'a, E> {
/// Locates and decodes the `Root` message of a serialized tree.
///
/// The final nine bytes of `tree` are unpacked as a `Capstone`; its `root_offset` marks where
/// the `Root` message starts, and the root occupies everything from there up to the capstone.
///
/// Returns `None` when the buffer is shorter than the capstone, when either message fails to
/// unpack, or when `root_offset` does not lie at or before the start of the capstone.
fn load_root(tree: &[u8]) -> Option<Root> {
    const CAPSTONE_LEN: usize = 9;
    let capstone_start = tree.len().checked_sub(CAPSTONE_LEN)?;
    let capstone = Capstone::unpack(&tree[capstone_start..]).ok()?.0;
    let root_offset: usize = capstone.root_offset.try_into().ok()?;
    // Checked slicing: an offset that lands inside (or past) the capstone would otherwise form
    // an inverted range and panic on malformed input.
    let root_bytes = tree.get(root_offset..capstone_start)?;
    Some(Root::unpack(root_bytes).ok()?.0)
}
/// Fetches the node identified by `offset`.
///
/// `offset` is first tried as a 1-based index into the already-loaded `nodes`; if no such entry
/// exists it is treated as a byte offset into `tree`, where a length-delimited `NODE_TAG` field
/// is expected.
///
/// Returns `None` for offset 0, for offsets outside the buffer, for a field with the wrong tag,
/// or when the node fails to unpack.
fn load_node(&self, offset: u64) -> Option<Node> {
    if offset == 0 {
        return None;
    }
    match self.load_node_and_bit_vector(offset) {
        Some((cached, _)) => Some(cached.clone()),
        None => {
            let position: usize = offset.try_into().ok()?;
            if position >= self.tree.len() {
                return None;
            }
            let (tag, value, _) = parse_one_field_bytes(&self.tree[position..])?;
            let expected = Tag {
                field_number: FieldNumber::must(NODE_TAG),
                wire_type: WireType::LengthDelimited,
            };
            if tag != expected {
                return None;
            }
            Node::unpack(value).ok().map(|(node, _)| node)
        }
    }
}
/// Looks up a loaded node and its bit vector by 1-based index.
///
/// Returns `None` when `offset` is 0 or does not name an entry of `nodes`.
fn load_node_and_bit_vector(&self, offset: u64) -> Option<(&Node, &BitVector<'a>)> {
    // Index 0 is reserved for "no node", so stored indices are shifted up by one.
    let zero_based = offset.checked_sub(1)?;
    let index: usize = zero_based.try_into().ok()?;
    let (node, bit_vector) = self.nodes.get(index)?;
    Some((node, bit_vector))
}
/// Loads every node reachable from `root.tree` and rewrites `root.tree` to the root's 1-based
/// index into `nodes`. If loading fails, `root.tree` is set to 0.
fn load_nodes(&mut self) {
    let root_index = match self.load_nodes_recursive(self.root.tree) {
        Some(index) => index,
        None => 0,
    };
    self.root.tree = root_index;
}
/// Loads the node at `offset` together with its whole subtree.
///
/// Children are loaded first and the node's `left`/`right` links are rewritten to 1-based
/// indices into `nodes`; the node's bit vector is then parsed from `tree[start..limit]` and the
/// pair is pushed onto `nodes`.
///
/// Returns the 1-based index of the pushed node, or `None` if any node or bit vector in the
/// subtree cannot be loaded or the subtree is deeper than any tree `construct` can emit.
fn load_nodes_recursive(&mut self, offset: u64) -> Option<u64> {
    self.load_nodes_bounded(offset, 0)
}

/// Depth-tracking worker for `load_nodes_recursive`; `depth` is the number of ancestors above
/// the node at `offset`.
fn load_nodes_bounded(&mut self, offset: u64, depth: usize) -> Option<u64> {
    // Every level consumes one bit of a code whose size is a `u8`, so a tree written by
    // `construct` has at most `u8::MAX` levels. Enforcing that bound stops a malformed buffer
    // whose child link points back at itself or an ancestor from recursing until the stack
    // overflows.
    const MAX_DEPTH: usize = u8::MAX as usize;
    if depth >= MAX_DEPTH {
        return None;
    }
    let mut node = self.load_node(offset)?;
    if node.left != 0 {
        node.left = self.load_nodes_bounded(node.left, depth + 1)?;
    }
    if node.right != 0 {
        node.right = self.load_nodes_bounded(node.right, depth + 1)?;
    }
    let start: usize = node.start.try_into().ok()?;
    let limit: usize = node.limit.try_into().ok()?;
    // Checked slicing covers both an inverted range and a range past the end of the buffer.
    let bytes = self.tree.get(start..limit)?;
    let bv = BitVector::parse(bytes).ok().map(|parsed| parsed.0)?;
    self.nodes.push((node, bv));
    Some(self.nodes.len() as u64)
}
/// Builds the subtree for the `(code, size)` pairs produced by `iter`.
///
/// * Non-empty with every size non-zero: the pairs are copied into `intermediate` (cleared
///   first, used as scratch space) and a node is built via `construct_recursive`; its result is
///   returned.
/// * Every size zero (including an empty iterator): no node is written and `Ok(0)` is returned.
/// * Anything else is a mix of sized and zero-sized pairs and yields `Error::LogicError`.
fn construct_from_iter<H: Helper>(
    builder: &mut Builder<H>,
    intermediate: &mut Vec<(u32, u8)>,
    iter: impl Iterator<Item = (u32, u8)> + Clone,
) -> Result<u64, Error> {
    let has_items = iter.clone().next().is_some();
    if has_items && all_sized(iter.clone()) {
        intermediate.clear();
        intermediate.extend(iter);
        return Self::construct_recursive(builder, intermediate);
    }
    if all_zero(iter) {
        return Ok(0);
    }
    Err(Error::LogicError(
        "wavelet tree should be all zero or all sized",
    ))
}
/// Serializes the node for `symbols` and, before it, the subtrees beneath it.
///
/// Symbols are split on the lowest bit of their code: bit 0 goes to the left child and bit 1 to
/// the right, each with the code shifted right by one and the size reduced by one. When every
/// size is already zero no children are built. The node's bit vector (one bit per symbol, set
/// when the lowest code bit is 1) is written after the children, followed by the `Node` record.
///
/// Returns the builder's `relative_len()` taken immediately before the `Node` record was
/// appended; parents and the root store this value as the node's offset.
fn construct_recursive<H: Helper>(
    builder: &mut Builder<H>,
    symbols: &[(u32, u8)],
) -> Result<u64, Error> {
    let mut left: u64 = 0;
    let mut right: u64 = 0;
    if symbols.iter().any(|&(_, size)| size != 0) {
        let zeros = symbols
            .iter()
            .filter(|&&(code, _)| code & 1 == 0)
            .map(|&(code, size)| (code >> 1, size - 1));
        let ones = symbols
            .iter()
            .filter(|&&(code, _)| code & 1 == 1)
            .map(|&(code, size)| (code >> 1, size - 1));
        // One scratch buffer is shared by both children; each call clears it before use.
        let mut scratch = Vec::with_capacity(symbols.len());
        left = Self::construct_from_iter(builder, &mut scratch, zeros)?;
        right = Self::construct_from_iter(builder, &mut scratch, ones)?;
    }
    let bits: Vec<bool> = symbols.iter().map(|&(code, _)| code & 1 == 1).collect();
    let start = builder.relative_len() as u64;
    BitVector::construct(&bits, builder)?;
    let limit = builder.relative_len() as u64;
    let record = Node {
        length: symbols.len() as u64,
        start,
        limit,
        left,
        right,
    };
    builder.append_packable(FieldNumber::must(NODE_TAG), &record);
    Ok(limit)
}
/// Walks down from `node_offset`, rebuilding the code of the symbol at position `x`.
///
/// At each node the bit at `x` selects the child (set goes right, clear goes left) and `x` is
/// replaced by its rank among equal bits. `e` accumulates the bits seen so far, lowest bit
/// first, and `sz` counts them. When `node_offset` reaches 0 the accumulated `(e, sz)` is
/// handed to the encoder's `decode`.
///
/// Returns `None` if a node is missing, a bit-vector query fails, the path is too deep for the
/// code to fit in `(u32, u8)`, or the encoder cannot decode the result.
fn recursive_access(&self, mut e: u32, mut sz: u8, node_offset: u64, x: usize) -> Option<u32> {
    if node_offset == 0 {
        return self.encoder.decode(e, sz);
    }
    let (node, bv) = self.load_node_and_bit_vector(node_offset)?;
    let (x, node_offset) = if bv.access(x)? {
        // A deserialized tree may set a bit at depth 32 or more; a plain `1 << sz` would panic
        // in debug builds and silently wrap in release builds.
        e |= 1u32.checked_shl(u32::from(sz))?;
        (bv.rank(x)?, node.right)
    } else {
        (bv.rank0(x)?, node.left)
    };
    // Likewise refuse to overflow the bit counter on an over-deep tree.
    sz = sz.checked_add(1)?;
    self.recursive_access(e, sz, node_offset, x)
}
/// Computes the rank of the code `(e, sz)` at position `x`, starting from `node`/`bv`.
///
/// The lowest bit of `e` selects which rank to take at this node (set bits for 1, clear bits
/// for 0) and which child to descend into; the resulting rank becomes the position for the next
/// level. The rank at the final level (`sz == 1`) is the answer.
///
/// Returns `None` when `sz` is 0, when a bit-vector query fails, or when the code has bits left
/// but the required child is absent.
fn recursive_rank(
    &self,
    e: u32,
    sz: u8,
    node: &Node,
    bv: &BitVector<'a>,
    x: usize,
) -> Option<usize> {
    if sz == 0 {
        return None;
    }
    let ones = bv.rank(x)?;
    let (this_rank, next_node_offset) = if e & 1 != 0 {
        (ones, node.right)
    } else {
        (x - ones, node.left)
    };
    match (sz, next_node_offset) {
        (1, _) => Some(this_rank),
        (_, 0) => None,
        _ => {
            let (child, child_bv) = self.load_node_and_bit_vector(next_node_offset)?;
            self.recursive_rank(e >> 1, sz - 1, child, child_bv, this_rank)
        }
    }
}
/// Computes the select of the code `(e, sz)` for argument `x`, starting from `node`/`bv`.
///
/// When more than one bit of the code remains, the child chosen by the lowest bit of `e` is
/// resolved first and its answer becomes the argument at this level. This level then applies
/// `select` (bit set) or `select0` (bit clear) on its own bit vector.
///
/// Returns `None` when `sz` is 0, when a required child is absent, or when a bit-vector query
/// fails.
fn recursive_select(
    &self,
    e: u32,
    sz: u8,
    node: &Node,
    bv: &BitVector<'a>,
    x: usize,
) -> Option<usize> {
    if sz == 0 {
        return None;
    }
    let bit_is_set = e & 1 != 0;
    let position = match sz {
        1 => x,
        _ => {
            let child_offset = if bit_is_set { node.right } else { node.left };
            let (child, child_bv) = self.load_node_and_bit_vector(child_offset)?;
            self.recursive_select(e >> 1, sz - 1, child, child_bv, x)?
        }
    };
    if bit_is_set {
        bv.select(position)
    } else {
        bv.select0(position)
    }
}
}
impl<E: Encoder + Packable> WaveletTreeTrait for WaveletTree<'_, E> {
fn construct<H: Helper>(symbols: &[u32], builder: &mut Builder<'_, H>) -> Result<(), Error> {
let mut builder = builder.sub(FieldNumber::must(CONTAINER_TAG));
let enc = E::construct(symbols);
let encoder_start = builder.relative_len() as u64;
builder.append_raw_packable(&enc);
let encoder_limit = builder.relative_len() as u64;
let mut encoded: Vec<(u32, u8)> = Vec::with_capacity(symbols.len());
for sym in symbols.iter() {
encoded.push(enc.encode(*sym).ok_or(Error::InvalidEncoder)?);
}
let length = encoded.len() as u64;
drop(enc);
let tree = Self::construct_recursive(&mut builder, &encoded)?;
let root = Root {
encoder_start,
encoder_limit,
length,
tree,
};
let root_offset: u64 = builder.relative_len() as u64;
builder.append_raw_packable(&root);
let capstone = Capstone { root_offset };
builder.append_raw_packable(&capstone);
Ok(())
}
/// Number of symbols stored in the tree, as recorded in the root message.
fn len(&self) -> usize {
    let length = self.root.length;
    length as usize
}
/// Returns the symbol stored at position `x`, or `None` if it cannot be resolved.
fn access(&self, x: usize) -> Option<u32> {
    // Start at the root with an empty code: no bits accumulated, size zero.
    let root_index = self.root.tree;
    self.recursive_access(0, 0, root_index, x)
}
/// Rank query for symbol `q` at position `x`.
///
/// Returns `None` if the root node is unavailable, the encoder cannot encode `q`, or the
/// descent through the tree fails.
fn rank_q(&self, q: u32, x: usize) -> Option<usize> {
    let (root_node, root_bv) = self.load_node_and_bit_vector(self.root.tree)?;
    self.encoder
        .encode(q)
        .and_then(|(code, size)| self.recursive_rank(code, size, root_node, root_bv, x))
}
/// Select query for symbol `q` with argument `x`.
///
/// Returns `None` if the root node is unavailable, the encoder cannot encode `q`, or the
/// descent through the tree fails.
fn select_q(&self, q: u32, x: usize) -> Option<usize> {
    let (root_node, root_bv) = self.load_node_and_bit_vector(self.root.tree)?;
    self.encoder
        .encode(q)
        .and_then(|(code, size)| self.recursive_select(code, size, root_node, root_bv, x))
}
}
/// Compact debug output: the encoder's `symbols()` value and the byte length of the serialized
/// tree rather than its full contents.
impl<E: Encoder + std::fmt::Debug> std::fmt::Debug for WaveletTree<'_, E> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> Result<(), std::fmt::Error> {
        let mut out = f.debug_struct("WaveletTree");
        out.field("encoder", &self.encoder.symbols());
        out.field("tree", &self.tree.len());
        out.finish()
    }
}
impl<'a, E: Encoder + Unpackable<'a>> Unpackable<'a> for WaveletTree<'a, E> {
type Error = Error;
fn unpack<'b: 'a>(buf: &'b [u8]) -> Result<(Self, &'b [u8]), Self::Error> {
let (tag, value, remain) = parse_one_field_bytes(buf).ok_or(Error::InvalidWaveletTree)?;
if tag
!= (Tag {
field_number: FieldNumber::must(CONTAINER_TAG),
wire_type: WireType::LengthDelimited,
})
{
return Err(Error::InvalidWaveletTree);
}
let root = Self::load_root(value).ok_or(Error::InvalidWaveletTree)?;
if root.encoder_start > root.encoder_limit || root.encoder_limit > value.len() as u64 {
return Err(Error::InvalidWaveletTree);
}
let encoder_start: usize = root.encoder_start.try_into()?;
let encoder_limit: usize = root.encoder_limit.try_into()?;
let encoder = E::unpack(&value[encoder_start..encoder_limit])
.map_err(|_| Error::InvalidEncoder)?
.0;
let tree = value;
let nodes = vec![];
let mut wt = WaveletTree {
encoder,
root,
tree,
nodes,
};
wt.load_nodes();
Ok((wt, remain))
}
}