csa-rhdl 0.1.0

Carry-save adder compressor trees composed via comp-cat-rs, with hdl-cat backend
Documentation
//! [`TreeGraph`]: the free-category encoding of a compressor tree.
//!
//! A [`TreeGraph`] for an `M → 2` reduction over width `W` has one
//! vertex per level shape (the input shape, every intermediate shape,
//! and the final shape) and one edge per level transition.
//! [`TreeMorphism`] interprets each edge as the tensor product of
//! `Csa3to2` arrows (one per triple) plus a `Passthrough` for the one or
//! two remainder bundles, if any.

use comp_cat_rs::collapse::free_category::{
    Edge, FreeCategoryError, Graph, GraphMorphism, Vertex,
};

use crate::category::arrow::CircuitArrow;
use crate::shape::Shape;

/// What kind of transition an edge represents.
///
/// This is a plain value type, so it is comparable and hashable.  That
/// lets callers and tests check edge kinds with `==` or key maps on them.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum LevelKind {
    /// A carry-save reduction of `m` bundles at width `w`.
    Csa {
        /// Input bundle count.
        m: usize,
        /// Input width.
        w: usize,
    },
}

/// One directed level transition in the compressor-tree graph.
///
/// The edge runs from the vertex of one level's shape to the vertex of
/// the next level's shape. `kind` records the reduction performed.
#[derive(Clone, Copy, Debug)]
pub struct LevelEdge {
    /// Vertex the transition starts from.
    pub source: Vertex,
    /// Vertex the transition ends at.
    pub target: Vertex,
    /// The reduction this transition performs.
    pub kind: LevelKind,
}

/// The free-category graph for an `M → 2` compressor tree.
///
/// Obtain one from [`TreeGraph::build`].  Every vertex is the [`Shape`]
/// of one reduction level; every edge is the transition from one level
/// to the next.
#[derive(Clone, Debug)]
pub struct TreeGraph {
    /// Per-level shapes, indexed by vertex number (index 0 is the input).
    vertices: Vec<Shape>,
    /// Per-level transitions, indexed by edge number.
    edges: Vec<LevelEdge>,
}

impl TreeGraph {
    /// Build a `TreeGraph` for reducing `m` bundles of width `w`.
    ///
    /// A level of `n > 2` bundles at width `w` is followed by a level of
    /// `2 * (n / 3) + n % 3` bundles at width `w + 1`.  That is, each
    /// triple yields two bundles and the `n % 3` leftovers carry over.
    /// Levels are generated until the bundle count is `<= 2`.  For
    /// `m <= 2` the result is a single-vertex graph with no edges.
    ///
    /// Edge `i` always runs from vertex `i` to vertex `i + 1`. Its
    /// [`LevelKind::Csa`] records the bundle count and width of vertex `i`.
    #[must_use]
    pub fn build(m: usize, w: usize) -> Self {
        // Each step strictly shrinks the bundle count once it exceeds 2
        // (`n - n / 3 < n` for `n >= 3`), so this terminates.
        let vertices: Vec<Shape> =
            std::iter::successors(Some(Shape::new(m, w)), |level| match level.bundles() {
                0..=2 => None,
                n => Some(Shape::new((n / 3) * 2 + n % 3, level.width() + 1)),
            })
            .collect();

        // `successors` only yields a next shape for levels with more than
        // two bundles.  So every vertex except the last is the source of
        // exactly one CSA level, and pairing each vertex with its successor
        // enumerates the edges.
        let edges: Vec<LevelEdge> = vertices
            .iter()
            .zip(vertices.iter().skip(1))
            .enumerate()
            .map(|(idx, (current, _next))| LevelEdge {
                source: Vertex::new(idx),
                target: Vertex::new(idx + 1),
                kind: LevelKind::Csa {
                    m: current.bundles(),
                    w: current.width(),
                },
            })
            .collect();

        Self { vertices, edges }
    }

    /// All vertices (level shapes), in level order.
    #[must_use]
    pub fn vertices(&self) -> &[Shape] {
        &self.vertices
    }

    /// All edges, in level order.
    #[must_use]
    pub fn edges_slice(&self) -> &[LevelEdge] {
        &self.edges
    }
}

impl Graph for TreeGraph {
    /// Number of level shapes.
    fn vertex_count(&self) -> usize {
        self.vertices().len()
    }

    /// Number of level transitions.
    fn edge_count(&self) -> usize {
        self.edges_slice().len()
    }

    /// Source vertex of `edge`.
    ///
    /// Fails with `FreeCategoryError::EdgeOutOfBounds` when `edge` indexes
    /// past the stored edges.
    fn source(&self, edge: Edge) -> Result<Vertex, FreeCategoryError> {
        match self.edges.get(edge.index()) {
            Some(level) => Ok(level.source),
            None => Err(FreeCategoryError::EdgeOutOfBounds {
                edge,
                count: self.edge_count(),
            }),
        }
    }

    /// Target vertex of `edge`.
    ///
    /// Fails with `FreeCategoryError::EdgeOutOfBounds` when `edge` indexes
    /// past the stored edges.
    fn target(&self, edge: Edge) -> Result<Vertex, FreeCategoryError> {
        match self.edges.get(edge.index()) {
            Some(level) => Ok(level.target),
            None => Err(FreeCategoryError::EdgeOutOfBounds {
                edge,
                count: self.edge_count(),
            }),
        }
    }
}

/// A [`GraphMorphism`] interpreting [`TreeGraph`] edges as
/// tensor products of `Csa3to2` arrows.
#[derive(Debug)]
pub struct TreeMorphism<'g> {
    graph: &'g TreeGraph,
}

impl<'g> TreeMorphism<'g> {
    /// Create a morphism viewing the given graph.
    #[must_use]
    pub const fn new(graph: &'g TreeGraph) -> Self {
        Self { graph }
    }
}

impl GraphMorphism<TreeGraph> for TreeMorphism<'_> {
    type Object = Shape;
    type Morphism = CircuitArrow;

    /// Map a vertex to the level [`Shape`] stored for it.
    ///
    /// A vertex index past the end of the graph maps to `Shape::new(0, 0)`.
    fn map_vertex(&self, v: Vertex) -> Self::Object {
        match self.graph.vertices.get(v.index()) {
            Some(shape) => *shape,
            None => Shape::new(0, 0),
        }
    }

    /// Map an edge to the circuit arrow implementing its level.
    ///
    /// An edge index past the end of the graph maps to
    /// `CircuitArrow::IdAbstract`.
    fn map_edge(&self, e: Edge) -> Self::Morphism {
        match self.graph.edges.get(e.index()) {
            None => CircuitArrow::IdAbstract,
            Some(level) => match level.kind {
                LevelKind::Csa { m, w } => build_csa_level(m, w),
            },
        }
    }
}

/// Build one level's morphism for `m` bundles at width `w`.
///
/// The result is a left-nested tensor of `m / 3` `Csa3to2<w>` arrows.
/// When `m % 3 != 0` it is tensored on the right with a `Passthrough` of
/// the leftover `m % 3` bundles at width `w`.  Degenerate inputs are
/// handled as follows:
/// - `m` in `1..=2` yields just the passthrough;
/// - `m == 0` yields `CircuitArrow::IdAbstract`.
#[must_use]
fn build_csa_level(m: usize, w: usize) -> CircuitArrow {
    let triples = m / 3;
    let remainder = m % 3;

    // Only construct the CSA core when there is at least one triple; the
    // previous form always built one `csa_3to2(w)` and discarded it for
    // `m < 3`.
    let triple_core = (triples > 0).then(|| {
        (1..triples).fold(CircuitArrow::csa_3to2(w), |acc, _| {
            CircuitArrow::tensor(acc, CircuitArrow::csa_3to2(w))
        })
    });
    let leftover = (remainder > 0).then(|| CircuitArrow::passthrough(Shape::new(remainder, w)));

    match (triple_core, leftover) {
        (None, None) => CircuitArrow::IdAbstract,
        (None, Some(pass)) => pass,
        (Some(core), None) => core,
        (Some(core), Some(pass)) => CircuitArrow::tensor(core, pass),
    }
}

#[cfg(test)]
mod tests {
    use super::{LevelKind, TreeGraph, TreeMorphism};
    use crate::category::arrow::CircuitArrow;
    use crate::shape::Shape;
    use comp_cat_rs::collapse::free_category::{
        Edge, FreeCategoryError, Graph, GraphMorphism, Vertex,
    };

    #[test]
    fn build_9_16_has_five_vertices_four_edges() {
        let g = TreeGraph::build(9, 16);
        assert_eq!(g.vertex_count(), 5);
        assert_eq!(g.edge_count(), 4);
        assert_eq!(g.vertices()[0], Shape::new(9, 16));
        assert_eq!(g.vertices()[1], Shape::new(6, 17));
        assert_eq!(g.vertices()[2], Shape::new(4, 18));
        assert_eq!(g.vertices()[3], Shape::new(3, 19));
        assert_eq!(g.vertices()[4], Shape::new(2, 20));
    }

    #[test]
    fn build_2_or_fewer_has_no_edges() {
        // `build` documents a single-vertex, edgeless graph for `m <= 2`.
        [0usize, 1, 2].iter().for_each(|m| {
            let g = TreeGraph::build(*m, 8);
            assert_eq!(g.edge_count(), 0, "m={m}");
            assert_eq!(g.vertex_count(), 1, "m={m}");
            assert_eq!(g.vertices()[0], Shape::new(*m, 8), "m={m}");
        });
    }

    #[test]
    fn every_edge_records_its_source_level_and_reduction_ends_at_two() {
        (0usize..=40).for_each(|m| {
            let g = TreeGraph::build(m, 8);
            assert_eq!(g.vertex_count(), g.edge_count() + 1, "m={m}");
            g.edges_slice().iter().enumerate().for_each(|(i, edge)| {
                let LevelKind::Csa {
                    m: level_m,
                    w: level_w,
                } = edge.kind;
                let source = g.vertices()[i];
                let target = g.vertices()[i + 1];
                assert_eq!(level_m, source.bundles(), "m={m} edge={i}");
                assert_eq!(level_w, source.width(), "m={m} edge={i}");
                assert_eq!(target.width(), source.width() + 1, "m={m} edge={i}");
            });
            // From any count >= 3 the reduction lands on exactly 2 bundles;
            // counts <= 2 are left untouched.
            assert!(
                matches!(g.vertices().last(), Some(s) if s.bundles() == m.min(2)),
                "m={m}"
            );
        });
    }

    #[test]
    fn edge_endpoints_match() -> Result<(), FreeCategoryError> {
        let g = TreeGraph::build(9, 16);
        (0..g.edge_count()).try_for_each(|i| {
            let e = Edge::new(i);
            let s = g.source(e)?;
            let t = g.target(e)?;
            assert_eq!(s, Vertex::new(i));
            assert_eq!(t, Vertex::new(i + 1));
            Ok(())
        })
    }

    #[test]
    fn out_of_range_edge_lookups_report_edge_out_of_bounds() {
        let g = TreeGraph::build(9, 16);
        let past_end = Edge::new(g.edge_count());
        assert!(matches!(
            g.source(past_end),
            Err(FreeCategoryError::EdgeOutOfBounds { count: 4, .. })
        ));
        assert!(matches!(
            g.target(past_end),
            Err(FreeCategoryError::EdgeOutOfBounds { count: 4, .. })
        ));
    }

    #[test]
    fn morphism_maps_vertices_to_level_shapes() {
        let g = TreeGraph::build(9, 16);
        let m = TreeMorphism::new(&g);
        (0..g.vertex_count()).for_each(|i| {
            assert_eq!(m.map_vertex(Vertex::new(i)), g.vertices()[i], "vertex={i}");
        });
    }

    #[test]
    fn morphism_falls_back_for_out_of_range_indices() {
        let g = TreeGraph::build(9, 16);
        let m = TreeMorphism::new(&g);
        assert_eq!(m.map_vertex(Vertex::new(g.vertex_count())), Shape::new(0, 0));
        assert!(matches!(
            m.map_edge(Edge::new(g.edge_count())),
            CircuitArrow::IdAbstract
        ));
    }

    #[test]
    fn morphism_interprets_first_edge_for_nine_operands() {
        let g = TreeGraph::build(9, 16);
        let m = TreeMorphism::new(&g);
        let arrow = m.map_edge(Edge::new(0));
        assert_eq!(arrow.source(), Some(Shape::new(9, 16)));
        assert_eq!(arrow.target(), Some(Shape::new(6, 17)));
    }

    #[test]
    fn morphism_interprets_level_with_remainder() {
        // stage 2 (M=4, W=18): one triple + 1 passthrough
        let g = TreeGraph::build(9, 16);
        assert!(matches!(
            g.edges_slice()[2].kind,
            LevelKind::Csa { m: 4, w: 18 }
        ));
        let m = TreeMorphism::new(&g);
        let arrow = m.map_edge(Edge::new(2));
        match arrow {
            CircuitArrow::Tensor { .. } => {}
            CircuitArrow::Id(_)
            | CircuitArrow::IdAbstract
            | CircuitArrow::FullAdder
            | CircuitArrow::Csa3to2 { .. }
            | CircuitArrow::Braid { .. }
            | CircuitArrow::Passthrough(_)
            | CircuitArrow::Compose { .. } => panic!("expected Tensor"),
        }
    }

    #[test]
    fn morphism_interprets_level_with_two_remainder_bundles() {
        // stage 0 (M=5, W=8): one triple + 2 passthrough bundles
        let g = TreeGraph::build(5, 8);
        assert!(matches!(
            g.edges_slice()[0].kind,
            LevelKind::Csa { m: 5, w: 8 }
        ));
        let arrow = TreeMorphism::new(&g).map_edge(Edge::new(0));
        assert!(matches!(arrow, CircuitArrow::Tensor { .. }));
    }
}