csa-rhdl 0.1.0

Carry-save adder compressor trees composed via comp-cat-rs, with hdl-cat backend
Documentation
//! Stream-based catamorphism: the delay-run tree reducer.
//!
//! [`tree_reduce`] produces a lazy [`Stream`] of [`LevelDescriptor`]s
//! describing each stage of the compressor tree.  [`compile_tree`]
//! folds that stream into a composed [`CircuitArrow`], returning an
//! [`Io`] that the caller invokes `.run()` on exactly once at the
//! boundary (verilog codegen, simulation, shape inspection).

use std::sync::Arc;

use comp_cat_rs::effect::io::Io;
use comp_cat_rs::effect::stream::Stream;

use crate::category::arrow::CircuitArrow;
use crate::error::CsaError;
use crate::shape::Shape;
use crate::tree::level::next_bundles;

/// One stage of the compressor tree.
///
/// A level takes `in_bundles` bundles at width `in_width` and reduces them
/// to `next_bundles(in_bundles)` bundles at width `in_width + 1`.
///
/// The descriptor is a plain value: it is `Copy` and compares field by
/// field, so two descriptors can be checked with `==` / `assert_eq!`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct LevelDescriptor {
    /// Input bundle count.
    pub in_bundles: usize,
    /// Input bit width.
    pub in_width: usize,
}

/// Seed/state value threaded through the unfold in `tree_reduce`.
#[derive(Clone, Copy, Debug)]
struct LevelState {
    /// Bundle count at the current level.
    bundles: usize,
    /// Bit width at the current level.
    width: usize,
}

/// Describe, level by level, how `initial_bundles` bundles of width
/// `initial_width` are compressed.
///
/// The stream is unfolded from a `LevelState`: while more than two bundles
/// remain it yields a [`LevelDescriptor`] for the current level and moves on
/// to `next_bundles(bundles)` bundles with one extra bit of width.
///
/// When `initial_bundles` is 0, 1 or 2 the stream yields nothing.
#[must_use]
pub fn tree_reduce(initial_bundles: usize, initial_width: usize)
    -> Stream<CsaError, LevelDescriptor>
{
    let seed = LevelState {
        bundles: initial_bundles,
        width: initial_width,
    };
    Stream::unfold(
        seed,
        Arc::new(|LevelState { bundles, width }: LevelState| {
            // Two or fewer bundles left: nothing to compress, so end the stream.
            let step = (bundles > 2).then(|| {
                let emitted = LevelDescriptor {
                    in_bundles: bundles,
                    in_width: width,
                };
                let following = LevelState {
                    bundles: next_bundles(bundles),
                    width: width + 1,
                };
                (emitted, following)
            });
            Io::pure(step)
        }),
    )
}

/// Compile an `m`-bundle, `w`-bit compressor tree into one circuit arrow.
///
/// Starting from the identity arrow on `Shape::new(m, w)`, each descriptor
/// produced by [`tree_reduce`] is appended with `compose_level`.  The
/// accumulator is a `Result`, so once a composition fails the remaining
/// levels are skipped and that error is carried through to the returned
/// [`Io`].
///
/// The result is packaged in an [`Io`]; the caller is expected to invoke
/// `.run()` exactly once at the boundary.
///
/// For `m == 0` the returned [`Io`] fails with `CsaError::TreeSizeZero`.
#[must_use]
pub fn compile_tree(m: usize, w: usize) -> Io<CsaError, CircuitArrow> {
    if m == 0 {
        return Io::suspend(|| Err(CsaError::TreeSizeZero));
    }

    let seed: Result<CircuitArrow, CsaError> = Ok(CircuitArrow::identity(Shape::new(m, w)));
    let folded = tree_reduce(m, w).fold(
        seed,
        Arc::new(|so_far, level| so_far.and_then(|arrow| compose_level(arrow, level))),
    );
    // Lift the folded `Result` into the `Io` error channel.
    folded.flat_map(|outcome| match outcome {
        Err(failure) => Io::suspend(move || Err(failure)),
        Ok(composed) => Io::pure(composed),
    })
}

fn compose_level(prev: CircuitArrow, desc: LevelDescriptor) -> Result<CircuitArrow, CsaError> {
    CircuitArrow::compose(prev, build_level_arrow(desc.in_bundles, desc.in_width))
}

/// Build the arrow for a single level of `m` bundles at width `w`.
///
/// The level consists of `m / 3` `csa_3to2(w)` cells tensored together
/// (left-nested), in parallel with a passthrough over
/// `Shape::new(m % 3, w)` for the leftover bundles:
///
/// * no triples and no remainder (`m == 0`): `CircuitArrow::IdAbstract`;
/// * no triples (`m` is 1 or 2): only the passthrough;
/// * no remainder: only the tensored CSA cells;
/// * otherwise: the CSA cells tensored with the passthrough.
#[must_use]
pub fn build_level_arrow(m: usize, w: usize) -> CircuitArrow {
    let triples = m / 3;
    let remainder = m % 3;
    // Built on demand: only the arms with at least one triple need the CSA
    // chain, so no throwaway `csa_3to2` arrow is constructed for `m < 3`.
    let triple_core = || {
        (1..triples).fold(CircuitArrow::csa_3to2(w), |acc, _| {
            CircuitArrow::tensor(acc, CircuitArrow::csa_3to2(w))
        })
    };
    match (triples, remainder) {
        (0, 0) => CircuitArrow::IdAbstract,
        (0, rem) => CircuitArrow::passthrough(Shape::new(rem, w)),
        (_, 0) => triple_core(),
        (_, rem) => CircuitArrow::tensor(
            triple_core(),
            CircuitArrow::passthrough(Shape::new(rem, w)),
        ),
    }
}

#[cfg(test)]
mod tests {
    use super::{build_level_arrow, compile_tree, tree_reduce, LevelDescriptor};
    use crate::category::arrow::CircuitArrow;
    use crate::error::CsaError;
    use crate::shape::Shape;
    use std::sync::Arc;

    /// Nine bundles at width 16 go 9 -> 6 -> 4 -> 3, gaining one bit per level.
    #[test]
    fn tree_reduce_emits_four_descriptors_for_nine() -> Result<(), CsaError> {
        let levels = tree_reduce(9, 16)
            .fold(
                Vec::<(usize, usize)>::new(),
                Arc::new(|seen: Vec<(usize, usize)>, d: LevelDescriptor| {
                    seen.into_iter()
                        .chain(std::iter::once((d.in_bundles, d.in_width)))
                        .collect()
                }),
            )
            .run()?;
        assert_eq!(
            levels,
            vec![(9usize, 16usize), (6, 17), (4, 18), (3, 19)]
        );
        Ok(())
    }

    /// Zero, one or two bundles need no compression levels at all.
    #[test]
    fn tree_reduce_empty_for_small_inputs() -> Result<(), CsaError> {
        for m in [0usize, 1, 2] {
            let count = tree_reduce(m, 8)
                .fold(0usize, Arc::new(|a, _| a + 1))
                .run()?;
            assert_eq!(count, 0, "m={m}");
        }
        Ok(())
    }

    /// The compiled 9x16 tree maps nine 16-bit bundles to two 20-bit bundles.
    #[test]
    fn compile_tree_nine_sixteen_endpoints() -> Result<(), CsaError> {
        let compiled = compile_tree(9, 16).run()?;
        assert_eq!(compiled.source(), Some(Shape::new(9, 16)));
        assert_eq!(compiled.target(), Some(Shape::new(2, 20)));
        Ok(())
    }

    /// An empty tree is rejected when the `Io` is run.
    #[test]
    fn compile_tree_zero_errors() {
        let outcome = compile_tree(0, 8).run();
        assert!(matches!(outcome, Err(CsaError::TreeSizeZero)));
    }

    /// With one or two bundles the compiled arrow keeps the input shape.
    #[test]
    fn compile_tree_small_is_identity() -> Result<(), CsaError> {
        for m in [1usize, 2] {
            let compiled = compile_tree(m, 8).run()?;
            let shape = Shape::new(m, 8);
            assert_eq!(compiled.source(), Some(shape));
            assert_eq!(compiled.target(), Some(Shape::new(m, 8)));
        }
        Ok(())
    }

    /// Six bundles form two full triples: 6 x 4 bits in, 4 x 5 bits out.
    #[test]
    fn build_level_arrow_pure_triples() {
        let level = build_level_arrow(6, 4);
        assert_eq!(level.source(), Some(Shape::new(6, 4)));
        assert_eq!(level.target(), Some(Shape::new(4, 5)));
    }

    /// Four bundles are one triple plus one leftover, so the top node is a tensor.
    #[test]
    fn build_level_arrow_with_remainder() {
        let level = build_level_arrow(4, 8);
        assert_eq!(level.source(), Some(Shape::new(4, 8)));
        assert!(
            matches!(level, CircuitArrow::Tensor { .. }),
            "expected Tensor"
        );
    }
}