csa-rhdl 0.1.0

Carry-save adder compressor trees composed via comp-cat-rs, with hdl-cat backend
Documentation
//! Recursive `M → 2` compressor built from the [`TreeGraph`].

use comp_cat_rs::collapse::free_category::{Edge, Graph, Path, Vertex};

use crate::category::arrow::CircuitArrow;
use crate::category::graph::TreeGraph;
use crate::category::interpret::interpret_tree;
use crate::error::CsaError;
use crate::tree::sum_terms::sum_terms_passthrough;

/// Build one [`Path`] that traverses every edge of `graph` in index order.
///
/// The walk begins with the identity path on vertex `0` and appends the
/// single-edge path for edge `0`, then edge `1`, and so on up to
/// `graph.edge_count() - 1`. When the graph has no edges (i.e.
/// `m ∈ {0, 1, 2}`) nothing is appended and the identity path is
/// returned unchanged.
///
/// # Errors
///
/// Returns [`CsaError::UpstreamGraph`] if [`comp_cat_rs`] rejects an
/// edge lookup or a composition step.
pub fn build_full_path(graph: &TreeGraph) -> Result<Path, CsaError> {
    let mut route = Path::identity(Vertex::new(0));
    for index in 0..graph.edge_count() {
        // Lift the edge into a one-step path, then extend the route with it.
        // `?` converts the upstream error into `CsaError` via `From`.
        let step = Path::singleton(graph, Edge::new(index))?;
        route = route.compose(step)?;
    }
    Ok(route)
}

/// Build the combinational `m → 2` compressor tree for operands of width `w`.
///
/// * `m == 0` is rejected.
/// * `m == 1` and `m == 2` need no compression and are delegated to
///   [`sum_terms_passthrough`].
/// * Any larger `m` builds a [`TreeGraph`], concatenates all of its edges
///   with [`build_full_path`], and interprets that path as a circuit via
///   [`interpret_tree`].
///
/// # Errors
///
/// Returns [`CsaError::TreeSizeZero`] if `m == 0`, and
/// [`CsaError::UpstreamGraph`] on any underlying free-category error.
pub fn compressor_tree(m: usize, w: usize) -> Result<CircuitArrow, CsaError> {
    if m == 0 {
        return Err(CsaError::TreeSizeZero);
    }
    if m <= 2 {
        // One or two operands are already in "two-term" form.
        return Ok(sum_terms_passthrough(m, w));
    }

    let tree = TreeGraph::build(m, w);
    build_full_path(&tree).map(|route| interpret_tree(&tree, &route))
}

/// Compile-time-parameterised entry point for [`compressor_tree`].
///
/// `M` is forwarded as the operand count `m` and `W` as the width `w`;
/// no additional checking is performed here.
///
/// # Errors
///
/// Fails exactly when [`compressor_tree`] fails for `(M, W)`.
pub fn compressor_tree_const<const M: usize, const W: usize>() -> Result<CircuitArrow, CsaError> {
    compressor_tree(M, W)
}

#[cfg(test)]
mod tests {
    use super::{build_full_path, compressor_tree};
    use crate::category::graph::TreeGraph;
    use crate::error::CsaError;
    use crate::shape::Shape;

    /// Zero operands must be rejected with `TreeSizeZero`.
    #[test]
    fn compressor_tree_zero_is_error() {
        assert!(matches!(
            compressor_tree(0, 8),
            Err(CsaError::TreeSizeZero)
        ));
    }

    /// One- and two-operand trees keep their shape at the checked endpoint.
    #[test]
    fn compressor_tree_small_is_passthrough() -> Result<(), CsaError> {
        let one = compressor_tree(1, 8)?;
        assert_eq!(one.source(), Some(Shape::new(1, 8)));
        let two = compressor_tree(2, 8)?;
        assert_eq!(two.target(), Some(Shape::new(2, 8)));
        Ok(())
    }

    /// A 9-operand, 16-bit tree goes from shape (9, 16) to shape (2, 20).
    #[test]
    fn compressor_tree_9_16_has_expected_endpoints() -> Result<(), CsaError> {
        let arrow = compressor_tree(9, 16)?;
        assert_eq!(arrow.source(), Some(Shape::new(9, 16)));
        assert_eq!(arrow.target(), Some(Shape::new(2, 20)));
        Ok(())
    }

    /// The graph built for two operands yields the identity path.
    ///
    /// Returns `CsaError` directly, like the other tests in this module, so
    /// a failure reports the real error instead of a stand-in value.
    #[test]
    fn full_path_for_two_is_identity() -> Result<(), CsaError> {
        let g = TreeGraph::build(2, 8);
        let p = build_full_path(&g)?;
        assert!(p.is_identity());
        Ok(())
    }

    /// The full path for a 9-operand, 16-bit graph has length four.
    #[test]
    fn full_path_for_nine_has_four_edges() -> Result<(), CsaError> {
        let g = TreeGraph::build(9, 16);
        let p = build_full_path(&g)?;
        assert_eq!(p.len(), 4);
        Ok(())
    }
}