use crate::error::CsaError;
use crate::shape::Shape;
/// A circuit arrow: either a primitive building block or a combination of
/// other arrows via `Tensor` / `Compose`.
///
/// Every arrow reports an optional source and target [`Shape`] through
/// `CircuitArrow::source` and `CircuitArrow::target`; `None` means the shape
/// is not known for that arrow.
#[derive(Debug, Clone)]
pub enum CircuitArrow {
/// Identity on a concrete shape: source and target are both the wrapped shape.
Id(Shape),
/// Identity without a concrete shape: source and target are both `None`,
/// so `CircuitArrow::compose` performs no shape check against it.
IdAbstract,
/// Full adder primitive: source is `Shape::new(3, 1)`, target is `Shape::new(2, 1)`.
FullAdder,
/// 3-to-2 primitive parameterised by `width`: source is `Shape::new(3, width)`,
/// target is `Shape::new(2, width + 1)`.
Csa3to2 {
/// Width used for the source shape; the target shape uses `width + 1`.
width: usize,
},
/// Braid of two shapes: the source concatenates `left` then `right`, the
/// target concatenates `right` then `left`.
Braid {
/// Shape that comes first in the source and second in the target.
left: Shape,
/// Shape that comes second in the source and first in the target.
right: Shape,
},
/// Arrow whose source and target are both the wrapped shape (reported exactly
/// like `Id` by `source` / `target`).
Passthrough(Shape),
/// Two arrows placed side by side: source/target are the concatenation of the
/// operands' shapes, or `None` when either operand's shape is `None`.
Tensor {
/// First operand of the concatenation.
left: Box<CircuitArrow>,
/// Second operand of the concatenation.
right: Box<CircuitArrow>,
},
/// Sequential composition: the source is taken from `first`, the target from
/// `second`.
Compose {
/// Arrow supplying the source shape.
first: Box<CircuitArrow>,
/// Arrow supplying the target shape.
second: Box<CircuitArrow>,
},
}
impl CircuitArrow {
    /// Builds the [`CircuitArrow::Id`] arrow on `shape`.
    #[must_use]
    pub const fn identity(shape: Shape) -> Self {
        Self::Id(shape)
    }

    /// Builds the [`CircuitArrow::FullAdder`] primitive.
    #[must_use]
    pub const fn full_adder() -> Self {
        Self::FullAdder
    }

    /// Builds the [`CircuitArrow::Csa3to2`] primitive for the given `width`.
    #[must_use]
    pub const fn csa_3to2(width: usize) -> Self {
        Self::Csa3to2 { width }
    }

    /// Builds the [`CircuitArrow::Braid`] arrow over `left` and `right`.
    #[must_use]
    pub const fn braid(left: Shape, right: Shape) -> Self {
        Self::Braid { left, right }
    }

    /// Builds the [`CircuitArrow::Passthrough`] arrow on `shape`.
    #[must_use]
    pub const fn passthrough(shape: Shape) -> Self {
        Self::Passthrough(shape)
    }

    /// Places `left` and `right` side by side in a [`CircuitArrow::Tensor`].
    ///
    /// No shape validation is performed.
    #[must_use]
    pub fn tensor(left: Self, right: Self) -> Self {
        let left = Box::new(left);
        let right = Box::new(right);
        Self::Tensor { left, right }
    }

    /// Chains `first` then `second` in a [`CircuitArrow::Compose`] without
    /// comparing their shapes.
    ///
    /// Use [`CircuitArrow::compose`] for the validated variant.
    #[must_use]
    pub fn compose_unchecked(first: Self, second: Self) -> Self {
        let first = Box::new(first);
        let second = Box::new(second);
        Self::Compose { first, second }
    }

    /// Chains `first` then `second`, verifying that the target of `first`
    /// equals the source of `second` whenever both shapes are known.
    ///
    /// When either shape is `None` the check is skipped and the composition
    /// is built as by [`CircuitArrow::compose_unchecked`].
    ///
    /// # Errors
    ///
    /// Returns [`CsaError::ShapeMismatch`] when both shapes are known and
    /// differ; `left` carries the target of `first` and `right` the source of
    /// `second`.
    pub fn compose(first: Self, second: Self) -> Result<Self, CsaError> {
        match (first.target(), second.source()) {
            (Some(produced), Some(expected)) if produced != expected => {
                Err(CsaError::ShapeMismatch {
                    left: produced,
                    right: expected,
                })
            }
            // Shapes agree, or at least one side has no known shape to compare.
            _ => Ok(Self::compose_unchecked(first, second)),
        }
    }

    /// Shape consumed by this arrow, or `None` when it is not known.
    ///
    /// `IdAbstract` has no source; a `Tensor` has one only when both operands
    /// do; a `Compose` reports the source of its first arrow.
    #[must_use]
    pub fn source(&self) -> Option<Shape> {
        let shape = match self {
            Self::IdAbstract => return None,
            Self::Compose { first, .. } => return first.source(),
            Self::Id(shape) | Self::Passthrough(shape) => *shape,
            Self::FullAdder => Shape::new(3, 1),
            Self::Csa3to2 { width } => Shape::new(3, *width),
            Self::Braid { left, right } => concat_shapes(*left, *right),
            // `?` bails out with `None` as soon as one operand has no shape,
            // checking the left operand before the right one.
            Self::Tensor { left, right } => concat_shapes(left.source()?, right.source()?),
        };
        Some(shape)
    }

    /// Shape produced by this arrow, or `None` when it is not known.
    ///
    /// `IdAbstract` has no target; a `Tensor` has one only when both operands
    /// do; a `Compose` reports the target of its second arrow.
    #[must_use]
    pub fn target(&self) -> Option<Shape> {
        let shape = match self {
            Self::IdAbstract => return None,
            Self::Compose { second, .. } => return second.target(),
            Self::Id(shape) | Self::Passthrough(shape) => *shape,
            Self::FullAdder => Shape::new(2, 1),
            Self::Csa3to2 { width } => Shape::new(2, *width + 1),
            // The braid swaps its operands, so the target lists `right` first.
            Self::Braid { left, right } => concat_shapes(*right, *left),
            Self::Tensor { left, right } => concat_shapes(left.target()?, right.target()?),
        };
        Some(shape)
    }
}
/// Combines two shapes placed side by side: the bundle counts are added and
/// the width is the larger of the two widths.
#[must_use]
fn concat_shapes(a: Shape, b: Shape) -> Shape {
    let bundles = a.bundles() + b.bundles();
    let width = a.width().max(b.width());
    Shape::new(bundles, width)
}
#[cfg(test)]
mod tests {
    //! Unit tests for the source/target shapes reported by `CircuitArrow` and
    //! for the shape validation performed by `CircuitArrow::compose`.

    use super::CircuitArrow;
    use crate::shape::Shape;

    /// The full adder reports `Shape::new(3, 1)` in and `Shape::new(2, 1)` out.
    #[test]
    fn full_adder_shapes() {
        let fa = CircuitArrow::full_adder();
        assert_eq!(fa.source(), Some(Shape::new(3, 1)));
        assert_eq!(fa.target(), Some(Shape::new(2, 1)));
    }

    /// The 3-to-2 primitive keeps `width` on the source and uses `width + 1`
    /// on the target.
    #[test]
    fn csa_3to2_shapes() {
        let csa = CircuitArrow::csa_3to2(16);
        assert_eq!(csa.source(), Some(Shape::new(3, 16)));
        assert_eq!(csa.target(), Some(Shape::new(2, 17)));
    }

    /// Tensoring adds the bundle counts of both operands.
    #[test]
    fn tensor_accumulates_bundles() {
        let left = CircuitArrow::passthrough(Shape::new(1, 8));
        let right = CircuitArrow::passthrough(Shape::new(2, 8));
        let t = CircuitArrow::tensor(left, right);
        assert_eq!(t.source(), Some(Shape::new(3, 8)));
        assert_eq!(t.target(), Some(Shape::new(3, 8)));
    }

    /// With operands of different widths the tensor reports the larger width.
    #[test]
    fn tensor_uses_widest_operand() {
        let narrow = CircuitArrow::passthrough(Shape::new(1, 4));
        let wide = CircuitArrow::passthrough(Shape::new(2, 8));
        let t = CircuitArrow::tensor(narrow, wide);
        assert_eq!(t.source(), Some(Shape::new(3, 8)));
        assert_eq!(t.target(), Some(Shape::new(3, 8)));
    }

    /// A tensor has no shape as soon as one operand has none.
    #[test]
    fn tensor_with_abstract_side_has_no_shape() {
        let t = CircuitArrow::tensor(
            CircuitArrow::IdAbstract,
            CircuitArrow::passthrough(Shape::new(2, 8)),
        );
        assert!(t.source().is_none());
        assert!(t.target().is_none());
    }

    /// Matching target/source shapes compose successfully.
    #[test]
    fn compose_succeeds_on_match() -> Result<(), crate::error::CsaError> {
        let csa = CircuitArrow::csa_3to2(4);
        let pt = CircuitArrow::passthrough(Shape::new(2, 5));
        let composed = CircuitArrow::compose(csa, pt)?;
        assert_eq!(composed.source(), Some(Shape::new(3, 4)));
        assert_eq!(composed.target(), Some(Shape::new(2, 5)));
        Ok(())
    }

    /// Differing target/source shapes are rejected.
    #[test]
    fn compose_errors_on_mismatch() {
        let csa = CircuitArrow::csa_3to2(4);
        let wrong = CircuitArrow::passthrough(Shape::new(3, 5));
        assert!(CircuitArrow::compose(csa, wrong).is_err());
    }

    /// When either side has no known shape, `compose` skips the check and the
    /// composite reports `None` on that side.
    #[test]
    fn compose_skips_check_for_abstract_identity() -> Result<(), crate::error::CsaError> {
        let abstract_first =
            CircuitArrow::compose(CircuitArrow::IdAbstract, CircuitArrow::csa_3to2(4))?;
        assert_eq!(abstract_first.source(), None);
        assert_eq!(abstract_first.target(), Some(Shape::new(2, 5)));

        let abstract_second =
            CircuitArrow::compose(CircuitArrow::full_adder(), CircuitArrow::IdAbstract)?;
        assert_eq!(abstract_second.source(), Some(Shape::new(3, 1)));
        assert_eq!(abstract_second.target(), None);
        Ok(())
    }

    /// `compose_unchecked` builds the composite even when the shapes disagree.
    #[test]
    fn compose_unchecked_skips_shape_validation() {
        let composed = CircuitArrow::compose_unchecked(
            CircuitArrow::csa_3to2(4),
            CircuitArrow::passthrough(Shape::new(3, 5)),
        );
        assert_eq!(composed.source(), Some(Shape::new(3, 4)));
        assert_eq!(composed.target(), Some(Shape::new(3, 5)));
    }

    /// The abstract identity reports neither a source nor a target.
    #[test]
    fn id_abstract_has_no_shape() {
        let id = CircuitArrow::IdAbstract;
        assert!(id.source().is_none());
        assert!(id.target().is_none());
    }

    /// A braid lists its operands in opposite order on source and target.
    #[test]
    fn braid_swaps_shapes() {
        let b = CircuitArrow::braid(Shape::new(1, 4), Shape::new(2, 4));
        // Both sides coincide: bundle counts are summed and widths are maxed,
        // so the concatenated shape does not depend on operand order.
        assert_eq!(b.source(), Some(Shape::new(3, 4)));
        assert_eq!(b.target(), Some(Shape::new(3, 4)));
    }
}