// csa-rhdl 0.1.0
//
// Carry-save adder compressor trees composed via comp-cat-rs, with an hdl-cat backend.
// Documentation
//! Concrete hdl-cat circuit implementations of the CSA gate hierarchy.
//!
//! - [`full_adder`]: 1-bit full adder as an hdl-cat circuit arrow.
//! - [`csa_3to2`]: N-wide 3-to-2 carry-save compressor.
//!
//! Each function is a circuit *constructor* returning an hdl-cat
//! [`CircuitArrow`](hdl_cat::circuit::CircuitArrow) wrapped in
//! `Result`.  To simulate the resulting circuit, use the hdl-cat
//! simulator (`hdl_cat::sim::interp::interpret`).

use hdl_cat::circuit::{CircuitArrow, CircuitTensor, Obj};
use hdl_cat::ir::{BinOp, HdlGraphBuilder, Op, WireTy};
use hdl_cat::bits::Bits;
use hdl_cat::Error;

/// A full-adder arrow: `((bool ⊗ bool) ⊗ bool) -> (bool ⊗ bool)`.
///
/// Inputs: `((a, b), carry_in)`.  Outputs: `(sum, carry_out)`.
///
/// The input object is a left-nested tensor of three single-bit
/// objects; the output object is a tensor of two.  [`full_adder`]
/// builds a value of this type by handing the wire lists
/// `[a, b, cin]` and `[sum, cout]` (in that order) to
/// `CircuitArrow::from_raw_parts`.
pub type FullAdderArrow = CircuitArrow<
    // Domain: ((a ⊗ b) ⊗ carry_in), each component a single `bool`.
    CircuitTensor<CircuitTensor<Obj<bool>, Obj<bool>>, Obj<bool>>,
    // Codomain: (sum ⊗ carry_out).
    CircuitTensor<Obj<bool>, Obj<bool>>,
>;

/// A 3-to-2 carry-save compressor arrow:
/// `((Bits<N> ⊗ Bits<N>) ⊗ Bits<N>) -> (Bits<N> ⊗ Bits<N>)`.
///
/// Inputs: `((a, b), cin)`.  Outputs: `(sum, carry_out)`.
///
/// `N` is the bit width shared by all five operands: [`csa_3to2`]
/// converts it to `u32` and allocates every wire as `WireTy::Bits`
/// of that width.  The arrow is built from the wire lists
/// `[a, b, cin]` and `[s, cout]`, in that order.
pub type Csa3to2Arrow<const N: usize> = CircuitArrow<
    // Domain: ((a ⊗ b) ⊗ cin), each component `Bits<N>`.
    CircuitTensor<CircuitTensor<Obj<Bits<N>>, Obj<Bits<N>>>, Obj<Bits<N>>>,
    // Codomain: (sum ⊗ carry_out), each component `Bits<N>`.
    CircuitTensor<Obj<Bits<N>>, Obj<Bits<N>>>,
>;

/// Build the circuit arrow for a single-bit full adder.
///
/// The graph realises `S = A ⊕ B ⊕ Cin` and
/// `Cout = (A · B) + (Cin · (A ⊕ B))` with five two-input gates:
/// two XOR, two AND and one OR.  The resulting arrow lists its input
/// wires as `[a, b, carry_in]` and its output wires as
/// `[sum, carry_out]`.
///
/// # Errors
///
/// Infallible in practice; an [`Error`] surfaces only if the IR
/// builder rejects one of the five instructions.
pub fn full_adder() -> Result<FullAdderArrow, Error> {
    // Allocate wires in a fixed order: inputs, intermediates, outputs.
    let builder = HdlGraphBuilder::new();
    let (builder, in_a) = builder.with_wire(WireTy::Bit);
    let (builder, in_b) = builder.with_wire(WireTy::Bit);
    let (builder, carry_in) = builder.with_wire(WireTy::Bit);
    let (builder, half_sum) = builder.with_wire(WireTy::Bit);
    let (builder, gen_carry) = builder.with_wire(WireTy::Bit);
    let (builder, prop_carry) = builder.with_wire(WireTy::Bit);
    let (builder, sum_out) = builder.with_wire(WireTy::Bit);
    let (builder, carry_out) = builder.with_wire(WireTy::Bit);

    let graph = builder
        // half_sum = a ⊕ b
        .with_instruction(Op::Bin(BinOp::Xor), vec![in_a, in_b], half_sum)?
        // gen_carry = a · b
        .with_instruction(Op::Bin(BinOp::And), vec![in_a, in_b], gen_carry)?
        // prop_carry = cin · (a ⊕ b)
        .with_instruction(Op::Bin(BinOp::And), vec![carry_in, half_sum], prop_carry)?
        // sum = (a ⊕ b) ⊕ cin
        .with_instruction(Op::Bin(BinOp::Xor), vec![half_sum, carry_in], sum_out)?
        // cout = gen_carry + prop_carry
        .with_instruction(Op::Bin(BinOp::Or), vec![gen_carry, prop_carry], carry_out)?
        .build();

    let inputs = vec![in_a, in_b, carry_in];
    let outputs = vec![sum_out, carry_out];
    Ok(CircuitArrow::from_raw_parts(graph, inputs, outputs))
}

/// Construct an N-wide 3-to-2 carry-save compressor circuit.
///
/// The N-fold bit-parallel tensor power of [`full_adder`]: each bit
/// position runs an independent full-adder cell.  The invariant
/// `A + B + Cin == S + (Cout << 1)` holds as natural numbers.
///
/// # Errors
///
/// Returns [`Error::Overflow`] when `N` exceeds `u32::MAX`.
pub fn csa_3to2<const N: usize>() -> Result<Csa3to2Arrow<N>, Error> {
    let w = u32::try_from(N).map_err(|_| Error::Overflow {
        width: hdl_cat::Width::new(u32::MAX),
    })?;
    let wire_ty = WireTy::Bits(w);
    let (bld, a) = HdlGraphBuilder::new().with_wire(wire_ty.clone());
    let (bld, b) = bld.with_wire(wire_ty.clone());
    let (bld, cin) = bld.with_wire(wire_ty.clone());
    let (bld, ab) = bld.with_wire(wire_ty.clone());
    let (bld, ab_and) = bld.with_wire(wire_ty.clone());
    let (bld, c_and) = bld.with_wire(wire_ty.clone());
    let (bld, s) = bld.with_wire(wire_ty.clone());
    let (bld, cout) = bld.with_wire(wire_ty);
    let bld = bld.with_instruction(Op::Bin(BinOp::Xor), vec![a, b], ab)?;
    let bld = bld.with_instruction(Op::Bin(BinOp::And), vec![a, b], ab_and)?;
    let bld = bld.with_instruction(Op::Bin(BinOp::And), vec![cin, ab], c_and)?;
    let bld = bld.with_instruction(Op::Bin(BinOp::Xor), vec![ab, cin], s)?;
    let bld = bld.with_instruction(Op::Bin(BinOp::Or), vec![ab_and, c_and], cout)?;
    Ok(CircuitArrow::from_raw_parts(
        bld.build(),
        vec![a, b, cin],
        vec![s, cout],
    ))
}

#[cfg(test)]
mod tests {
    use super::{csa_3to2, full_adder};

    // Structural smoke tests: they check the shape of the generated
    // graphs (port, instruction and wire counts), not simulated values.

    #[test]
    fn full_adder_builds() -> Result<(), hdl_cat::Error> {
        let adder = full_adder()?;
        // Three input ports (a, b, cin) and two output ports (sum, cout).
        assert_eq!(adder.inputs().len(), 3);
        assert_eq!(adder.outputs().len(), 2);
        // Two XOR, two AND and one OR gate.
        let gate_count = adder.graph().instructions().len();
        assert_eq!(gate_count, 5);
        Ok(())
    }

    #[test]
    fn csa_3to2_builds_w8() -> Result<(), hdl_cat::Error> {
        let compressor = csa_3to2::<8>()?;
        assert_eq!(compressor.inputs().len(), 3);
        assert_eq!(compressor.outputs().len(), 2);
        // Same five-gate netlist as the 1-bit adder, on 8-bit wires.
        let gate_count = compressor.graph().instructions().len();
        assert_eq!(gate_count, 5);
        Ok(())
    }

    #[test]
    fn csa_3to2_builds_w16() -> Result<(), hdl_cat::Error> {
        let compressor = csa_3to2::<16>()?;
        assert_eq!(compressor.inputs().len(), 3);
        assert_eq!(compressor.outputs().len(), 2);
        // Three inputs, three intermediates and two outputs.
        let wire_count = compressor.graph().wires().len();
        assert_eq!(wire_count, 8);
        Ok(())
    }
}