csa-rhdl 0.1.0

Carry-save adder compressor trees composed via comp-cat-rs, with hdl-cat backend
Documentation
//! [`CircuitArrow`]: a concrete AST for circuit morphisms.
//!
//! This enum is the Hom type of the categorical layer.  Every
//! morphism the crate builds is a `CircuitArrow` value; composition
//! and tensor product construct `Compose` and `Tensor` nodes.
//!
//! Interpretation is a later phase (e.g. RHDL codegen); the enum
//! itself is pure data with no effects.

use crate::error::CsaError;
use crate::shape::Shape;

/// A combinational morphism in the circuit category.
///
/// - [`Id`](Self::Id) is the concrete identity on a known shape.
/// - [`IdAbstract`](Self::IdAbstract) is the categorical identity with
///   shape unknown at this level (provided by the
///   [`Category`](comp_cat_rs::foundation::Category) impl, which cannot
///   recover the generic object type).
/// - [`FullAdder`](Self::FullAdder) is the 1-bit full adder cell.
/// - [`Csa3to2`](Self::Csa3to2) is the N-wide three-to-two compressor.
/// - [`Braid`](Self::Braid) swaps two adjacent shape blocks.
/// - [`Passthrough`](Self::Passthrough) wires a shape through unchanged
///   (used when a level has a remainder bundle).
/// - [`Tensor`](Self::Tensor) composes two arrows in parallel.
/// - [`Compose`](Self::Compose) composes two arrows in series.
///
/// Equality (`PartialEq`) is *structural*: two arrows compare equal only
/// when their trees are identical node-for-node.  It does not apply
/// categorical laws, so e.g. `Compose { Id(s), f }` is not equal to `f`
/// even though both denote the same morphism.
#[derive(Debug, Clone, PartialEq)]
pub enum CircuitArrow {
    /// Concrete identity on a known shape.
    Id(Shape),
    /// Categorical identity with shape erased (from trait impl).
    IdAbstract,
    /// 1-bit full adder: `(a, b, cin) -> (s, cout)`.
    FullAdder,
    /// N-wide carry-save 3-to-2 compressor.
    Csa3to2 {
        /// The operand width.
        width: usize,
    },
    /// Wire-swap of two adjacent shape blocks.
    Braid {
        /// The left operand shape.
        left: Shape,
        /// The right operand shape.
        right: Shape,
    },
    /// Pass a shape through unchanged.
    Passthrough(Shape),
    /// Parallel composition.
    Tensor {
        /// Left branch.
        left: Box<CircuitArrow>,
        /// Right branch.
        right: Box<CircuitArrow>,
    },
    /// Series composition: `first` then `second`.
    Compose {
        /// Upstream.
        first: Box<CircuitArrow>,
        /// Downstream.
        second: Box<CircuitArrow>,
    },
}

impl CircuitArrow {
    /// Concrete identity on a shape.
    #[must_use]
    pub const fn identity(shape: Shape) -> Self {
        Self::Id(shape)
    }

    /// The 1-bit full adder.
    #[must_use]
    pub const fn full_adder() -> Self {
        Self::FullAdder
    }

    /// The N-wide 3-to-2 carry-save compressor.
    #[must_use]
    pub const fn csa_3to2(width: usize) -> Self {
        Self::Csa3to2 { width }
    }

    /// A structural wire swap.
    #[must_use]
    pub const fn braid(left: Shape, right: Shape) -> Self {
        Self::Braid { left, right }
    }

    /// Pass a shape through unchanged.
    #[must_use]
    pub const fn passthrough(shape: Shape) -> Self {
        Self::Passthrough(shape)
    }

    /// Parallel tensor product.
    #[must_use]
    pub fn tensor(left: Self, right: Self) -> Self {
        Self::Tensor {
            left: Box::new(left),
            right: Box::new(right),
        }
    }

    /// Series composition without shape checking.
    ///
    /// Prefer [`CircuitArrow::compose`] which validates that the
    /// target of `first` matches the source of `second`.
    #[must_use]
    pub fn compose_unchecked(first: Self, second: Self) -> Self {
        Self::Compose {
            first: Box::new(first),
            second: Box::new(second),
        }
    }

    /// Series composition with shape checking.
    ///
    /// # Errors
    ///
    /// Returns [`CsaError::ShapeMismatch`] if `first.target()` does
    /// not equal `second.source()`.  When either of those shapes is
    /// unknown (`None`, e.g. because of an
    /// [`IdAbstract`](Self::IdAbstract) whose shape cannot be
    /// recovered), the check is skipped and the composition succeeds.
    pub fn compose(first: Self, second: Self) -> Result<Self, CsaError> {
        match (first.target(), second.source()) {
            (Some(t), Some(s)) if t != s => Err(CsaError::ShapeMismatch { left: t, right: s }),
            _ => Ok(Self::compose_unchecked(first, second)),
        }
    }

    /// The source (input) shape, if known.
    ///
    /// Returns `None` for [`IdAbstract`](Self::IdAbstract), whose
    /// shape is erased by the generic [`Category::id`](comp_cat_rs::foundation::Category::id)
    /// signature, and for any composite whose input boundary depends on
    /// such an erased shape (e.g. a tensor with an abstract-identity
    /// branch).  A series composition whose upstream arrow is purely an
    /// abstract identity is transparent: its source is that of the
    /// downstream arrow.
    #[must_use]
    pub fn source(&self) -> Option<Shape> {
        match self {
            Self::IdAbstract => None,
            Self::Id(s) | Self::Passthrough(s) => Some(*s),
            Self::FullAdder => Some(Shape::new(3, 1)),
            Self::Csa3to2 { width } => Some(Shape::new(3, *width)),
            Self::Braid { left, right } => Some(concat_shapes(*left, *right)),
            Self::Tensor { left, right } => left
                .source()
                .zip(right.source())
                .map(|(l, r)| concat_shapes(l, r)),
            // `id ; g` has the same source as `g`; recovering it keeps
            // `compose` able to check shapes further up the tree.
            Self::Compose { first, second } => first.source().or_else(|| {
                if first.is_abstract_identity() {
                    second.source()
                } else {
                    None
                }
            }),
        }
    }

    /// The target (output) shape, if known.
    ///
    /// Mirrors [`source`](Self::source): `None` when the output boundary
    /// depends on an erased [`IdAbstract`](Self::IdAbstract) shape.  A
    /// series composition whose downstream arrow is purely an abstract
    /// identity reports the target of the upstream arrow.
    #[must_use]
    pub fn target(&self) -> Option<Shape> {
        match self {
            Self::IdAbstract => None,
            Self::Id(s) | Self::Passthrough(s) => Some(*s),
            Self::FullAdder => Some(Shape::new(2, 1)),
            Self::Csa3to2 { width } => Some(Shape::new(2, width + 1)),
            Self::Braid { left, right } => Some(concat_shapes(*right, *left)),
            Self::Tensor { left, right } => left
                .target()
                .zip(right.target())
                .map(|(l, r)| concat_shapes(l, r)),
            // `f ; id` has the same target as `f`.
            Self::Compose { first, second } => second.target().or_else(|| {
                if second.is_abstract_identity() {
                    first.target()
                } else {
                    None
                }
            }),
        }
    }

    /// Whether this arrow is built solely from
    /// [`IdAbstract`](Self::IdAbstract) nodes via series or parallel
    /// composition.  Such an arrow is an identity whose shape is unknown
    /// here, so it is transparent when composed in series.
    fn is_abstract_identity(&self) -> bool {
        match self {
            Self::IdAbstract => true,
            Self::Compose { first, second }
            | Self::Tensor {
                left: first,
                right: second,
            } => first.is_abstract_identity() && second.is_abstract_identity(),
            Self::Id(_)
            | Self::FullAdder
            | Self::Csa3to2 { .. }
            | Self::Braid { .. }
            | Self::Passthrough(_) => false,
        }
    }
}

/// Concatenate two shapes by summing bundle counts.
///
/// The bundles of `a` and `b` are stacked side by side, so the result
/// has `a.bundles() + b.bundles()` bundles.  Its width is the *larger*
/// of the two widths.
///
/// Widths can legitimately differ.  For example, a
/// [`Csa3to2`](CircuitArrow::Csa3to2) produces `width + 1`-bit outputs,
/// which may sit beside a `width`-bit remainder bundle wired through by
/// [`Passthrough`](CircuitArrow::Passthrough).  The resulting shape is
/// therefore a summary: it does not record that one side is narrower.
#[must_use]
fn concat_shapes(a: Shape, b: Shape) -> Shape {
    Shape::new(a.bundles() + b.bundles(), a.width().max(b.width()))
}

#[cfg(test)]
mod tests {
    use super::CircuitArrow;
    use crate::shape::Shape;

    /// The `(source, target)` boundary of an arrow, for compact assertions.
    fn boundary(arrow: &CircuitArrow) -> (Option<Shape>, Option<Shape>) {
        (arrow.source(), arrow.target())
    }

    #[test]
    fn full_adder_shapes() {
        // Three 1-bit inputs (a, b, cin) reduce to two 1-bit outputs.
        assert_eq!(
            boundary(&CircuitArrow::full_adder()),
            (Some(Shape::new(3, 1)), Some(Shape::new(2, 1)))
        );
    }

    #[test]
    fn csa_3to2_shapes() {
        // Three 16-bit operands compress to two 17-bit outputs.
        assert_eq!(
            boundary(&CircuitArrow::csa_3to2(16)),
            (Some(Shape::new(3, 16)), Some(Shape::new(2, 17)))
        );
    }

    #[test]
    fn tensor_accumulates_bundles() {
        let parallel = CircuitArrow::tensor(
            CircuitArrow::passthrough(Shape::new(1, 8)),
            CircuitArrow::passthrough(Shape::new(2, 8)),
        );
        let expected = Some(Shape::new(3, 8));
        assert_eq!(boundary(&parallel), (expected, expected));
    }

    #[test]
    fn compose_succeeds_on_match() -> Result<(), crate::error::CsaError> {
        // The compressor's (2, 5) output feeds a (2, 5) passthrough.
        let series = CircuitArrow::compose(
            CircuitArrow::csa_3to2(4),
            CircuitArrow::passthrough(Shape::new(2, 5)),
        )?;
        assert_eq!(
            boundary(&series),
            (Some(Shape::new(3, 4)), Some(Shape::new(2, 5)))
        );
        Ok(())
    }

    #[test]
    fn compose_errors_on_mismatch() {
        // (2, 5) output cannot feed a (3, 5) input.
        let outcome = CircuitArrow::compose(
            CircuitArrow::csa_3to2(4),
            CircuitArrow::passthrough(Shape::new(3, 5)),
        );
        assert!(outcome.is_err());
    }

    #[test]
    fn id_abstract_has_no_shape() {
        assert_eq!(boundary(&CircuitArrow::IdAbstract), (None, None));
    }

    #[test]
    fn braid_swaps_shapes() {
        // Swapping equal-width blocks preserves the total bundle count.
        let swap = CircuitArrow::braid(Shape::new(1, 4), Shape::new(2, 4));
        let expected = Some(Shape::new(3, 4));
        assert_eq!(boundary(&swap), (expected, expected));
    }
}