use crate::num::NonZeroPow2Usize;
use miniscript::iter::{Tree, TreeLike};
/// Borrowed view of a slice whose elements are treated as the leaves of a
/// binary tree (the tree shape is given by this type's `TreeLike` impl).
#[derive(Debug, Copy, Clone, Eq, PartialEq, Hash)]
pub struct BTreeSlice<'a, A>(&'a [A]);

impl<'a, A> BTreeSlice<'a, A> {
    /// Wraps `slice` without copying its elements.
    pub fn from_slice(slice: &'a [A]) -> Self {
        BTreeSlice(slice)
    }
}
impl<A: Clone> BTreeSlice<'_, A> {
    /// Combines all elements of the slice into one value with `f`, following the
    /// shape of the tree produced by this type's `TreeLike::as_node`.
    ///
    /// Each leaf contributes a clone of its single element; each inner node is
    /// replaced by `f(left, right)`. Returns `None` if the slice is empty.
    pub fn fold<F>(self, f: F) -> Option<A>
    where
        F: Fn(A, A) -> A,
    {
        if self.0.is_empty() {
            return None;
        }

        let mut stack: Vec<A> = Vec::new();
        for data in self.post_order_iter() {
            if data.child_indices.len() == 2 {
                // Post-order: the right child's result was pushed last.
                let right = stack.pop().unwrap();
                let left = stack.pop().unwrap();
                stack.push(f(left, right));
            } else {
                debug_assert_eq!(data.child_indices.len(), 0);
                debug_assert_eq!(data.node.0.len(), 1);
                stack.push(data.node.0[0].clone());
            }
        }

        debug_assert_eq!(stack.len(), 1);
        stack.pop()
    }
}
impl<A: Clone> TreeLike for BTreeSlice<'_, A> {
    /// Slices with fewer than two elements are leaves. Longer slices split into two
    /// non-empty halves: the right half has the largest power of two strictly smaller
    /// than the length, and the left half has the remaining elements.
    fn as_node(&self) -> Tree<Self> {
        let len = self.0.len();
        if len < 2 {
            return Tree::Nullary;
        }

        let right_len = len.next_power_of_two() / 2;
        debug_assert!(0 < right_len);
        let left_len = len - right_len;
        debug_assert!(0 < left_len);

        let (left, right) = self.0.split_at(left_len);
        Tree::Binary(BTreeSlice::from_slice(left), BTreeSlice::from_slice(right))
    }
}
/// Splits a value that was folded from `n` leaves back into those leaves.
///
/// The split sizes mirror `BTreeSlice`: a node with `n >= 2` leaves has a right
/// child with the largest power of two strictly smaller than `n` leaves and a left
/// child with the remainder.
#[derive(Debug, Copy, Clone, Eq, PartialEq, Hash)]
pub struct Unfolder<A> {
    tree: A,
    n: usize,
}

impl<A: Clone> Unfolder<A> {
    /// Creates an unfolder for `tree`, which is expected to contain `n` leaves.
    pub fn new(tree: A, n: usize) -> Self {
        Self { tree, n }
    }

    /// Unfolds the tree into its `n` leaves, in left-to-right order.
    ///
    /// `f` splits an inner node into its `(left, right)` children. It is only
    /// called on nodes that cover at least two leaves. For `n == 0` the tree is not
    /// inspected and an empty vector is returned.
    ///
    /// Returns `None` as soon as `f` returns `None`, which means the tree's
    /// shape does not match `n`.
    pub fn unfold<F>(self, mut f: F) -> Option<Vec<A>>
    where
        F: FnMut(A) -> Option<(A, A)>,
    {
        let total = self.n;
        let mut stack = vec![self];
        let mut output = Vec::with_capacity(total);

        while let Some(Self { tree, n }) = stack.pop() {
            match n {
                0 => {}
                1 => output.push(tree),
                _ => {
                    // `tree` is owned here, so it is moved into `f` instead of cloned.
                    let (left, right) = f(tree)?;
                    // Largest power of two strictly below `n`; always `< n` for `n >= 2`.
                    let right_n = n.next_power_of_two() / 2;
                    let left_n = n - right_n;
                    // LIFO stack: push right first so the left subtree is emitted first.
                    stack.push(Self::new(right, right_n));
                    stack.push(Self::new(left, left_n));
                }
            }
        }

        debug_assert_eq!(output.len(), total);
        Some(output)
    }
}
/// A slice of fewer than `bound` elements, viewed as a binary tree of blocks.
///
/// Through this type's `TreeLike` impl, a `Parent` with bound `2k` has two children:
/// a leaf block with room for `k` elements and a partition of the remaining elements
/// with bound `k`. At bound 2 the partition becomes a leaf block of size 1.
#[derive(Debug, Copy, Clone, Eq, PartialEq, Hash)]
pub enum Partition<'a, A> {
    /// A block with room for `size` elements. Leaves produced by `from_slice` and the
    /// tree traversal hold either exactly `size` elements or none.
    Leaf {
        slice: &'a [A],
        size: usize,
    },
    /// A slice that is split further; its length is strictly less than `bound`.
    Parent {
        slice: &'a [A],
        bound: NonZeroPow2Usize,
    },
}

impl<'a, A> Partition<'a, A> {
    /// Partitions `slice` with respect to `bound`.
    ///
    /// For bound 2 this is a single leaf block of size 1; otherwise a `Parent`.
    ///
    /// # Panics
    ///
    /// Panics if `slice.len() >= bound`.
    pub fn from_slice(slice: &'a [A], bound: NonZeroPow2Usize) -> Self {
        assert!(
            slice.len() < bound.get(),
            "The slice must be shorter than the given bound"
        );
        if bound == NonZeroPow2Usize::TWO {
            Self::Leaf { slice, size: 1 }
        } else {
            Self::Parent { slice, bound }
        }
    }
}
impl<A: Clone> Partition<'_, A> {
    /// Returns `true` if the partition holds as many elements as it can:
    /// a leaf holds exactly `size` elements, and a parent holds `bound - 1` elements.
    pub fn is_complete(&self) -> bool {
        match *self {
            Self::Leaf { slice, size } => size == slice.len(),
            Self::Parent { slice, bound } => bound.get() == slice.len() + 1,
        }
    }

    /// Folds the partition bottom-up.
    ///
    /// `f` maps every leaf block (its elements and its declared size) to a value.
    /// `g` combines the values of a parent's left and right children, in that order.
    pub fn fold<B, F, G>(self, f: F, g: G) -> B
    where
        F: Fn(&[A], usize) -> B,
        G: Fn(B, B) -> B,
    {
        let mut results: Vec<B> = Vec::new();
        for data in self.post_order_iter() {
            let value = match data.node {
                Partition::Leaf { slice, size } => f(slice, size),
                Partition::Parent { .. } => {
                    // Post-order: the right child's value sits on top.
                    let right = results.pop().unwrap();
                    let left = results.pop().unwrap();
                    g(left, right)
                }
            };
            results.push(value);
        }
        debug_assert_eq!(results.len(), 1);
        results.pop().unwrap()
    }
}
#[rustfmt::skip]
impl<A: Clone> TreeLike for Partition<'_, A> {
    /// Leaves have no children. A parent with bound `2k` splits into:
    /// - a left leaf block of size `k`, which takes the first `k` elements if the
    ///   slice has at least `k` of them and is empty otherwise; and
    /// - a right partition of the remaining elements with bound `k`.
    fn as_node(&self) -> Tree<Self> {
        let (slice, bound) = match self {
            Self::Leaf { .. } => return Tree::Nullary,
            Self::Parent { slice, bound } => (*slice, *bound),
        };
        debug_assert!(NonZeroPow2Usize::TWO < bound);
        let half = bound.checked_div2().unwrap();
        let block_len = if slice.len() < half.get() { 0 } else { half.get() };
        let (block, rest) = slice.split_at(block_len);
        Tree::Binary(
            Self::Leaf { slice: block, size: half.get() },
            Self::from_slice(rest, half),
        )
    }
}
/// Reassembles the flat element list of a value folded from a [`Partition`].
///
/// `partition` is the folded value and `bound` is the bound the partition was created with.
#[derive(Debug, Copy, Clone, Eq, PartialEq, Hash)]
pub struct Combiner<A> {
    partition: A,
    bound: NonZeroPow2Usize,
}

impl<A: Clone> Combiner<A> {
    /// Creates a combiner for `partition`, which was folded with the given `bound`.
    pub fn new(partition: A, bound: NonZeroPow2Usize) -> Self {
        Self { partition, bound }
    }

    /// Unfolds the partition into its elements, in order.
    ///
    /// While the bound can be halved, `g` splits the current value into
    /// `(block, rest)`. `f` unpacks `block` given its size (half the bound), and
    /// `rest` continues with the halved bound. When the bound can no longer be
    /// halved, the remaining value is passed to `f` as a block of size 1.
    ///
    /// Each block is expected to yield either no elements or exactly its size
    /// (checked with `debug_assert!`). Returns `None` as soon as `f` or `g` does.
    pub fn unfold<F, G, B>(self, f: F, g: G) -> Option<Vec<B>>
    where
        F: Fn(A, usize) -> Option<Vec<B>>,
        G: Fn(A) -> Option<(A, A)>,
    {
        let Self { mut partition, mut bound } = self;
        let mut output = Vec::new();
        loop {
            match bound.checked_div2() {
                Some(half) => {
                    let (block, rest) = g(partition)?;
                    let elements = f(block, half.get())?;
                    debug_assert!(elements.is_empty() || elements.len() == half.get());
                    output.extend(elements);
                    partition = rest;
                    bound = half;
                }
                None => {
                    let elements = f(partition, 1)?;
                    debug_assert!(elements.is_empty() || elements.len() == 1);
                    output.extend(elements);
                    return Some(output);
                }
            }
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::pattern::{BasePattern, Pattern};
    use crate::str::Identifier;

    /// Small tree type used to check that folding and unfolding round-trip.
    #[derive(Debug, Clone, PartialEq, Eq)]
    enum Tree {
        Leaf(u16),
        Product(Box<Self>, Box<Self>),
        Array(Box<[Self]>),
    }

    impl Tree {
        fn product(left: Self, right: Self) -> Self {
            Self::Product(Box::new(left), Box::new(right))
        }

        fn array<I: IntoIterator<Item = Self>>(iter: I) -> Self {
            let elements: Box<[Self]> = iter.into_iter().collect();
            Self::Array(elements)
        }

        fn as_product(&self) -> Option<(&Self, &Self)> {
            if let Self::Product(left, right) = self {
                Some((left, right))
            } else {
                None
            }
        }

        fn as_array(&self) -> Option<&[Self]> {
            if let Self::Array(elements) = self {
                Some(elements.as_ref())
            } else {
                None
            }
        }
    }

    #[test]
    #[rustfmt::skip]
    fn fold_btree_slice() {
        let cases: [(&[&str], &str); 9] = [
            (&[], ""),
            (&["a"], "a"),
            (&["a", "b"], "(ab)"),
            (&["a", "b", "c"], "(a(bc))"),
            (&["a", "b", "c", "d"], "((ab)(cd))"),
            (&["a", "b", "c", "d", "e"], "(a((bc)(de)))"),
            (&["a", "b", "c", "d", "e", "f"], "((ab)((cd)(ef)))"),
            (&["a", "b", "c", "d", "e", "f", "g"], "((a(bc))((de)(fg)))"),
            (&["a", "b", "c", "d", "e", "f", "g", "h"], "(((ab)(cd))((ef)(gh)))"),
        ];
        let concat = |a: String, b: String| format!("({a}{b})");
        for (words, expected) in cases {
            let owned: Vec<String> = words.iter().map(|w| w.to_string()).collect();
            let folded = BTreeSlice::from_slice(&owned).fold(concat).unwrap_or_default();
            assert_eq!(&folded, expected);
        }
    }

    #[test]
    fn unfold_btree_slice() {
        let leaves: Vec<Tree> = (0..255).map(Tree::Leaf).collect();
        for len in 0..255 {
            let prefix = &leaves[0..len];
            let folded = BTreeSlice::from_slice(prefix)
                .fold(Tree::product)
                .unwrap_or(Tree::Leaf(1337));
            let unfolded = Unfolder::new(&folded, len)
                .unfold(Tree::as_product)
                .unwrap();
            assert_eq!(unfolded.len(), len);
            for (expected, actual) in prefix.iter().zip(unfolded) {
                assert_eq!(expected, actual);
            }
        }
    }

    #[test]
    #[rustfmt::skip]
    fn fold_partition() {
        let cases: [(&[&str], usize, &str); 14] = [
            (&[], 2, ""),
            (&["a"], 2, "a"),
            (&[], 4, "(:)"),
            (&["a"], 4, "(:a)"),
            (&["a", "b"], 4, "(ab:)"),
            (&["a", "b", "c"], 4, "(ab:c)"),
            (&[], 8, "(:(:))"),
            (&["a"], 8, "(:(:a))"),
            (&["a", "b"], 8, "(:(ab:))"),
            (&["a", "b", "c"], 8, "(:(ab:c))"),
            (&["a", "b", "c", "d"], 8, "(abcd:(:))"),
            (&["a", "b", "c", "d", "e"], 8, "(abcd:(:e))"),
            (&["a", "b", "c", "d", "e", "f"], 8, "(abcd:(ef:))"),
            (&["a", "b", "c", "d", "e", "f", "g"], 8, "(abcd:(ef:g))"),
        ];
        let process = |block: &[String], _| block.join("");
        let join = |a: String, b: String| format!("({a}:{b})");
        for (words, bound, expected) in cases {
            let owned: Vec<String> = words.iter().map(|w| w.to_string()).collect();
            let bound = NonZeroPow2Usize::new_unchecked(bound);
            let folded = Partition::from_slice(&owned, bound).fold(process, join);
            assert_eq!(&folded, expected);
        }
    }

    #[test]
    fn unfold_partition() {
        let leaves: Vec<Tree> = (0..255).map(Tree::Leaf).collect();
        let bound = NonZeroPow2Usize::new_unchecked(256);
        let pack_block = |block: &[Tree], _size: usize| Tree::array(block.iter().cloned());
        let unpack_block = |block: &Tree, _size: usize| block.as_array().map(<[Tree]>::to_vec);
        for len in 0..255 {
            let prefix = &leaves[0..len];
            let folded = Partition::from_slice(prefix, bound).fold(pack_block, Tree::product);
            let unfolded = Combiner::new(&folded, bound)
                .unfold(unpack_block, Tree::as_product)
                .unwrap();
            assert_eq!(unfolded.len(), len);
            for (expected, actual) in prefix.iter().zip(&unfolded) {
                assert_eq!(expected, actual);
            }
        }
    }

    #[test]
    fn base_pattern() {
        let a = Pattern::Identifier(Identifier::from_str_unchecked("a"));
        let b = Pattern::Identifier(Identifier::from_str_unchecked("b"));
        let c = Pattern::Identifier(Identifier::from_str_unchecked("c"));
        let d = Pattern::Identifier(Identifier::from_str_unchecked("d"));
        let a_ = BasePattern::Identifier(Identifier::from_str_unchecked("a"));
        let b_ = BasePattern::Identifier(Identifier::from_str_unchecked("b"));
        let c_ = BasePattern::Identifier(Identifier::from_str_unchecked("c"));
        let d_ = BasePattern::Identifier(Identifier::from_str_unchecked("d"));
        let cases = [
            (a.clone(), a_.clone()),
            (
                Pattern::product(a.clone(), b.clone()),
                BasePattern::product(a_.clone(), b_.clone()),
            ),
            (Pattern::array([a.clone()]), a_.clone()),
            (Pattern::array([Pattern::array([a.clone()])]), a_.clone()),
            (
                Pattern::array([a.clone(), b.clone()]),
                BasePattern::product(a_.clone(), b_.clone()),
            ),
            (
                Pattern::array([a.clone(), b.clone(), c.clone()]),
                BasePattern::product(a_.clone(), BasePattern::product(b_.clone(), c_.clone())),
            ),
            (
                Pattern::array([
                    Pattern::array([a.clone(), b.clone()]),
                    Pattern::array([c.clone(), d.clone()]),
                ]),
                BasePattern::product(BasePattern::product(a_, b_), BasePattern::product(c_, d_)),
            ),
        ];
        for (input, expected) in cases {
            let actual = BasePattern::from(&input);
            assert_eq!(expected, actual);
        }
    }
}