use std::sync::Arc;
use comp_cat_rs::effect::io::Io;
use comp_cat_rs::effect::stream::Stream;
use crate::category::arrow::CircuitArrow;
use crate::error::CsaError;
use crate::shape::Shape;
use crate::tree::level::next_bundles;
/// Input shape of one reduction level, as emitted by [`tree_reduce`].
///
/// Each descriptor records how many operand bundles enter the level and the
/// width of every one of those bundles.
///
/// The type is plain `Copy` data, so it also derives equality and hashing. This
/// lets callers compare whole descriptors or use them as map keys instead of
/// checking fields one by one.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct LevelDescriptor {
    /// Number of bundles entering this level.
    pub in_bundles: usize,
    /// Width of every bundle entering this level (the `w` passed to `Shape::new`).
    pub in_width: usize,
}
/// Private cursor threaded through the unfold in [`tree_reduce`].
///
/// It holds the bundle count and per-bundle width that the next reduction
/// level will receive.
#[derive(Debug)]
#[derive(Clone, Copy)]
struct LevelState {
    /// Bundles still waiting to be reduced.
    bundles: usize,
    /// Width of each of those bundles.
    width: usize,
}
/// Streams one [`LevelDescriptor`] per level of a carry-save reduction tree.
///
/// The stream starts from `initial_bundles` bundles of width `initial_width`.
/// Each step works as follows:
///
/// * It emits a descriptor of the current level's input.
/// * It advances to `next_bundles(n)` bundles whose width is one greater.
///
/// The stream ends as soon as two or fewer bundles remain. An initial count of
/// 0, 1 or 2 therefore yields an empty stream.
#[must_use]
pub fn tree_reduce(initial_bundles: usize, initial_width: usize)
    -> Stream<CsaError, LevelDescriptor>
{
    let seed = LevelState {
        bundles: initial_bundles,
        width: initial_width,
    };
    Stream::unfold(
        seed,
        Arc::new(|current: LevelState| {
            // Two or fewer bundles means the tree is fully reduced: stop.
            let step = if current.bundles <= 2 {
                None
            } else {
                let emitted = LevelDescriptor {
                    in_bundles: current.bundles,
                    in_width: current.width,
                };
                let successor = LevelState {
                    bundles: next_bundles(current.bundles),
                    width: current.width + 1,
                };
                Some((emitted, successor))
            };
            Io::pure(step)
        }),
    )
}
/// Compiles a complete carry-save reduction tree for `m` bundles of width `w`.
///
/// The result starts from the identity arrow on `Shape::new(m, w)`. The arrow for
/// every level produced by [`tree_reduce`] is then composed onto it, in order.
/// For `m` of 1 or 2 no level is emitted, so the result is that identity.
///
/// # Errors
///
/// When run, the returned effect fails in two cases:
///
/// * With [`CsaError::TreeSizeZero`] if `m == 0`.
/// * With the error reported by [`CircuitArrow::compose`] if composing a level
///   onto the accumulated arrow fails. No later levels are composed after the
///   first failure.
#[must_use]
pub fn compile_tree(m: usize, w: usize) -> Io<CsaError, CircuitArrow> {
    if m == 0 {
        return Io::suspend(|| Err(CsaError::TreeSizeZero));
    }
    let seed: Result<CircuitArrow, CsaError> =
        Ok(CircuitArrow::identity(Shape::new(m, w)));
    // Thread a `Result` through the fold so the first composition error
    // short-circuits every remaining level.
    let folded = tree_reduce(m, w).fold(
        seed,
        Arc::new(|so_far, level| so_far.and_then(|arrow| compose_level(arrow, level))),
    );
    // Lift the folded `Result` back into the `Io` error channel.
    folded.flat_map(|outcome| match outcome {
        Ok(compiled) => Io::pure(compiled),
        Err(e) => Io::suspend(move || Err(e)),
    })
}
fn compose_level(prev: CircuitArrow, desc: LevelDescriptor) -> Result<CircuitArrow, CsaError> {
CircuitArrow::compose(prev, build_level_arrow(desc.in_bundles, desc.in_width))
}
/// Builds the arrow for a single carry-save reduction level of `m` bundles,
/// each of width `w`.
///
/// * The bundles are split into `m / 3` complete triples. Each triple gets a
///   [`CircuitArrow::csa_3to2`] compressor, and the compressors are placed side
///   by side with [`CircuitArrow::tensor`], left to right.
/// * The `m % 3` leftover bundles go through a [`CircuitArrow::passthrough`] of
///   `Shape::new(m % 3, w)`, tensored after the compressors.
///
/// Edge cases:
/// * `m == 0` returns [`CircuitArrow::IdAbstract`].
/// * `m` of 1 or 2 (no complete triple) returns only the passthrough.
#[must_use]
pub fn build_level_arrow(m: usize, w: usize) -> CircuitArrow {
    let triples = m / 3;
    let remainder = m % 3;
    // Compressors are only constructed when at least one full triple exists.
    // For m < 3 no compressor arrow is built and then thrown away.
    let triple_core = (triples > 0).then(|| {
        (1..triples).fold(CircuitArrow::csa_3to2(w), |acc, _| {
            CircuitArrow::tensor(acc, CircuitArrow::csa_3to2(w))
        })
    });
    match (triple_core, remainder) {
        (None, 0) => CircuitArrow::IdAbstract,
        (None, rem) => CircuitArrow::passthrough(Shape::new(rem, w)),
        (Some(core), 0) => core,
        (Some(core), rem) => CircuitArrow::tensor(
            core,
            CircuitArrow::passthrough(Shape::new(rem, w)),
        ),
    }
}
#[cfg(test)]
mod tests {
    use super::{build_level_arrow, compile_tree, tree_reduce, LevelDescriptor};
    use crate::category::arrow::CircuitArrow;
    use crate::error::CsaError;
    use crate::shape::Shape;
    use std::sync::Arc;

    /// Nine bundles reduce 9 -> 6 -> 4 -> 3, growing one unit of width per level.
    #[test]
    fn tree_reduce_emits_four_descriptors_for_nine() -> Result<(), CsaError> {
        let levels = tree_reduce(9, 16)
            .fold(
                Vec::<LevelDescriptor>::new(),
                Arc::new(|seen: Vec<LevelDescriptor>, level| {
                    seen.into_iter().chain(std::iter::once(level)).collect()
                }),
            )
            .run()?;
        let shapes: Vec<(usize, usize)> = levels
            .iter()
            .map(|level| (level.in_bundles, level.in_width))
            .collect();
        assert_eq!(shapes, vec![(9, 16), (6, 17), (4, 18), (3, 19)]);
        Ok(())
    }

    /// Two or fewer bundles need no reduction, so nothing is emitted.
    #[test]
    fn tree_reduce_empty_for_small_inputs() -> Result<(), CsaError> {
        (0usize..=2).try_for_each(|m| {
            let emitted = tree_reduce(m, 8)
                .fold(0usize, Arc::new(|count, _| count + 1))
                .run()?;
            assert_eq!(emitted, 0, "m={m}");
            Ok::<(), CsaError>(())
        })
    }

    #[test]
    fn compile_tree_nine_sixteen_endpoints() -> Result<(), CsaError> {
        let compiled = compile_tree(9, 16).run()?;
        assert_eq!(
            (compiled.source(), compiled.target()),
            (Some(Shape::new(9, 16)), Some(Shape::new(2, 20)))
        );
        Ok(())
    }

    #[test]
    fn compile_tree_zero_errors() {
        let outcome = compile_tree(0, 8).run();
        assert!(matches!(outcome, Err(CsaError::TreeSizeZero)));
    }

    #[test]
    fn compile_tree_small_is_identity() -> Result<(), CsaError> {
        (1usize..=2).try_for_each(|m| {
            let compiled = compile_tree(m, 8).run()?;
            let expected = Some(Shape::new(m, 8));
            assert_eq!(compiled.source(), expected);
            assert_eq!(compiled.target(), expected);
            Ok::<(), CsaError>(())
        })
    }

    #[test]
    fn build_level_arrow_pure_triples() {
        let arrow = build_level_arrow(6, 4);
        assert_eq!(
            (arrow.source(), arrow.target()),
            (Some(Shape::new(6, 4)), Some(Shape::new(4, 5)))
        );
    }

    #[test]
    fn build_level_arrow_with_remainder() {
        let arrow = build_level_arrow(4, 8);
        assert_eq!(arrow.source(), Some(Shape::new(4, 8)));
        // Exhaustive on purpose: adding a variant should force this test to be revisited.
        let is_tensor = match arrow {
            CircuitArrow::Tensor { .. } => true,
            CircuitArrow::Braid { .. }
            | CircuitArrow::Compose { .. }
            | CircuitArrow::Csa3to2 { .. }
            | CircuitArrow::FullAdder
            | CircuitArrow::Id(_)
            | CircuitArrow::IdAbstract
            | CircuitArrow::Passthrough(_) => false,
        };
        assert!(is_tensor, "expected Tensor");
    }
}