use hdl_cat::ir::{BinOp, HdlGraph, HdlGraphBuilder, Op, WireId, WireTy};
use hdl_cat::Error as HdlError;
use crate::category::arrow::CircuitArrow;
use crate::shape::Shape;
/// The result of lowering a [`CircuitArrow`]: a flat [`HdlGraph`] netlist
/// together with the wires that act as its input and output ports.
#[derive(Debug, Clone)]
pub struct Lowered {
    /// The lowered netlist.
    graph: HdlGraph,
    /// Wires acting as the circuit's inputs, in port order.
    inputs: Vec<WireId>,
    /// Wires acting as the circuit's outputs, in port order. These may be the
    /// very same wires as `inputs` when the arrow performs no computation.
    outputs: Vec<WireId>,
}

impl Lowered {
    /// Borrows the lowered netlist.
    #[must_use]
    pub fn graph(&self) -> &HdlGraph {
        &self.graph
    }

    /// Input port wires, in port order.
    #[must_use]
    pub fn inputs(&self) -> &[WireId] {
        &self.inputs
    }

    /// Output port wires, in port order.
    #[must_use]
    pub fn outputs(&self) -> &[WireId] {
        &self.outputs
    }

    /// Consumes `self`, returning `(graph, inputs, outputs)`.
    #[must_use]
    pub fn into_parts(self) -> (HdlGraph, Vec<WireId>, Vec<WireId>) {
        (self.graph, self.inputs, self.outputs)
    }
}
/// Lowers a [`CircuitArrow`] into a flat [`HdlGraph`] netlist.
///
/// Composite arrows (`Compose`, `Tensor`) are lowered recursively and their
/// sub-graphs spliced together; primitive arrows are emitted directly.
/// `IdAbstract` is lowered as the identity on a zero-bundle shape, which has
/// no wires at all.
///
/// # Errors
///
/// Propagates any [`HdlError`] raised while choosing wire types (e.g. a
/// width that does not fit in `u32`) or while adding instructions.
pub fn lower(arrow: &CircuitArrow) -> Result<Lowered, HdlError> {
    match arrow {
        CircuitArrow::Compose { first: head, second: tail } => lower_compose(head, tail),
        CircuitArrow::Tensor { left: lhs, right: rhs } => lower_tensor(lhs, rhs),
        CircuitArrow::Braid { left: lhs, right: rhs } => lower_braid(*lhs, *rhs),
        CircuitArrow::Csa3to2 { width: bits } => lower_csa_3to2(*bits),
        CircuitArrow::FullAdder => lower_full_adder(),
        CircuitArrow::IdAbstract => lower_identity(Shape::new(0, 0)),
        CircuitArrow::Id(s) | CircuitArrow::Passthrough(s) => lower_identity(*s),
    }
}
/// Chooses the wire type used for a bundle of `width` bits.
///
/// Widths `0` and `1` both map to a single [`WireTy::Bit`]; anything wider
/// maps to `WireTy::Bits(width)`.
/// NOTE(review): a zero-width bundle still gets a one-bit wire — confirm this
/// is intended.
///
/// # Errors
///
/// Returns `HdlError::Overflow` (reporting `u32::MAX` as the width) when
/// `width` does not fit in a `u32`.
fn wire_ty_for_width(width: usize) -> Result<WireTy, HdlError> {
    match width {
        0 | 1 => Ok(WireTy::Bit),
        wide => match u32::try_from(wide) {
            Ok(bits) => Ok(WireTy::Bits(bits)),
            Err(_) => Err(HdlError::Overflow {
                width: hdl_cat::Width::new(u32::MAX),
            }),
        },
    }
}
/// Lowers an identity-like arrow over `shape`.
///
/// One wire is allocated per bundle. Each wire serves as both an input and
/// the corresponding output, and no instructions are emitted.
///
/// # Errors
///
/// Returns an error if `shape.width()` does not fit in a `u32`.
fn lower_identity(shape: Shape) -> Result<Lowered, HdlError> {
    let ty = wire_ty_for_width(shape.width())?;
    // Push into the accumulator rather than re-collecting it on each step;
    // rebuilding the Vec every iteration made allocation quadratic.
    let (bld, wires) = (0..shape.bundles()).fold(
        (HdlGraphBuilder::new(), Vec::new()),
        |(bld, mut wires), _| {
            let (bld, id) = bld.with_wire(ty.clone());
            wires.push(id);
            (bld, wires)
        },
    );
    Ok(Lowered {
        graph: bld.build(),
        inputs: wires.clone(),
        outputs: wires,
    })
}
/// Lowers the one-bit full adder.
///
/// Input ports are `[a, b, cin]` and output ports are `[sum, cout]`, with
/// `sum = (a ^ b) ^ cin` and `cout = (a & b) | (cin & (a ^ b))`.
///
/// # Errors
///
/// Propagates any error returned by `HdlGraphBuilder::with_instruction`.
fn lower_full_adder() -> Result<Lowered, HdlError> {
    // Allocation order fixes the WireIds: three ports in, three internal
    // nets, then the two ports out.
    let (net, a) = HdlGraphBuilder::new().with_wire(WireTy::Bit);
    let (net, b) = net.with_wire(WireTy::Bit);
    let (net, carry_in) = net.with_wire(WireTy::Bit);
    let (net, propagate) = net.with_wire(WireTy::Bit);
    let (net, generate) = net.with_wire(WireTy::Bit);
    let (net, carried) = net.with_wire(WireTy::Bit);
    let (net, sum) = net.with_wire(WireTy::Bit);
    let (mut net, carry_out) = net.with_wire(WireTy::Bit);

    // Gates in emission order: (operator, [lhs, rhs], destination).
    let gates = [
        (BinOp::Xor, [a, b], propagate),
        (BinOp::And, [a, b], generate),
        (BinOp::And, [carry_in, propagate], carried),
        (BinOp::Xor, [propagate, carry_in], sum),
        (BinOp::Or, [generate, carried], carry_out),
    ];
    for (op, operands, dest) in gates {
        net = net.with_instruction(Op::Bin(op), operands.to_vec(), dest)?;
    }

    Ok(Lowered {
        graph: net.build(),
        inputs: vec![a, b, carry_in],
        outputs: vec![sum, carry_out],
    })
}
/// Lowers a `width`-bit 3:2 carry-save adder.
///
/// Input ports are `[a, b, cin]` and output ports are `[s, cout]`, computed
/// bitwise as `s = (a ^ b) ^ cin` and `cout = (a & b) | (cin & (a ^ b))`.
/// All eight wires share one type chosen from `width`.
/// NOTE(review): `cout` has the same width as the inputs and is not shifted
/// or widened here — confirm consumers account for the carry's weight.
///
/// # Errors
///
/// Returns an error if `width` does not fit in a `u32`, and propagates any
/// error returned by `HdlGraphBuilder::with_instruction`.
fn lower_csa_3to2(width: usize) -> Result<Lowered, HdlError> {
    let lane = wire_ty_for_width(width)?;
    // Allocation order fixes the WireIds and must stay stable.
    let (g, x) = HdlGraphBuilder::new().with_wire(lane.clone());
    let (g, y) = g.with_wire(lane.clone());
    let (g, z) = g.with_wire(lane.clone());
    let (g, half) = g.with_wire(lane.clone());
    let (g, both) = g.with_wire(lane.clone());
    let (g, chained) = g.with_wire(lane.clone());
    let (g, partial_sum) = g.with_wire(lane.clone());
    let (mut g, carry) = g.with_wire(lane);

    // (operator, lhs, rhs, destination), emitted in this order.
    let netlist = [
        (BinOp::Xor, x, y, half),
        (BinOp::And, x, y, both),
        (BinOp::And, z, half, chained),
        (BinOp::Xor, half, z, partial_sum),
        (BinOp::Or, both, chained, carry),
    ];
    for (op, lhs, rhs, dest) in netlist {
        g = g.with_instruction(Op::Bin(op), vec![lhs, rhs], dest)?;
    }

    Ok(Lowered {
        graph: g.build(),
        inputs: vec![x, y, z],
        outputs: vec![partial_sum, carry],
    })
}
/// Lowers a braid that exchanges two groups of bundles.
///
/// One wire is allocated per bundle: first the `left` group, then the `right`
/// group. Input ports are ordered left-then-right and output ports
/// right-then-left. No instructions are emitted; the outputs are the same
/// wires as the inputs, reordered.
///
/// # Errors
///
/// Returns an error if either shape's width does not fit in a `u32`.
fn lower_braid(left: Shape, right: Shape) -> Result<Lowered, HdlError> {
    let left_ty = wire_ty_for_width(left.width())?;
    let right_ty = wire_ty_for_width(right.width())?;
    // Both folds push into their accumulator instead of re-collecting it on
    // each step, which made allocation quadratic in the bundle count.
    let (bld, left_wires) = (0..left.bundles()).fold(
        (HdlGraphBuilder::new(), Vec::new()),
        |(bld, mut acc), _| {
            let (bld, id) = bld.with_wire(left_ty.clone());
            acc.push(id);
            (bld, acc)
        },
    );
    let (bld, right_wires) = (0..right.bundles()).fold(
        (bld, Vec::new()),
        |(bld, mut acc), _| {
            let (bld, id) = bld.with_wire(right_ty.clone());
            acc.push(id);
            (bld, acc)
        },
    );
    let inputs: Vec<WireId> = left_wires.iter().chain(&right_wires).copied().collect();
    let outputs: Vec<WireId> = right_wires.iter().chain(&left_wires).copied().collect();
    Ok(Lowered {
        graph: bld.build(),
        inputs,
        outputs,
    })
}
/// Lowers the tensor (side-by-side) product of two arrows.
///
/// Both sub-graphs are placed in one graph with no connections between them.
/// Every wire of `right` is relocated past all wires of `left`. Ports are
/// `left`'s followed by `right`'s, for both inputs and outputs.
///
/// # Errors
///
/// Propagates any error from lowering either operand or replaying its
/// instructions.
fn lower_tensor(left: &CircuitArrow, right: &CircuitArrow) -> Result<Lowered, HdlError> {
    let lhs = lower(left)?;
    let rhs = lower(right)?;

    // Right-hand wires are appended after every left-hand wire.
    let base = lhs.graph.wires().len();
    let relocate = move |w: WireId| WireId::new(w.index() + base);

    let builder = HdlGraphBuilder::new();
    let builder = append_wires(builder, lhs.graph.wires().iter().cloned());
    let builder = append_wires(builder, rhs.graph.wires().iter().cloned());
    let builder = replay_instructions(&lhs.graph, builder, |w| w)?;
    let builder = replay_instructions(&rhs.graph, builder, relocate)?;

    let mut inputs = lhs.inputs;
    inputs.extend(rhs.inputs.iter().copied().map(relocate));
    let mut outputs = lhs.outputs;
    outputs.extend(rhs.outputs.iter().copied().map(relocate));

    Ok(Lowered {
        graph: builder.build(),
        inputs,
        outputs,
    })
}
/// Lowers sequential composition: `first`'s outputs feed `second`'s inputs.
///
/// Both sub-graphs are placed in one graph, with `second`'s wires relocated
/// past all of `first`'s. Every reference in `second` to its i-th input wire
/// is rewritten to `first`'s i-th output wire. The result takes `first`'s
/// inputs and `second`'s (rewritten) outputs.
///
/// NOTE(review): pairing uses `zip`, so if the arities differ only the common
/// prefix is connected. Presumably `CircuitArrow::compose` rejects mismatched
/// shapes — confirm. The relocated input wires of `second` are still allocated
/// but, once substituted, no replayed instruction refers to them.
///
/// # Errors
///
/// Propagates any error from lowering either operand or replaying its
/// instructions.
fn lower_compose(first: &CircuitArrow, second: &CircuitArrow) -> Result<Lowered, HdlError> {
    let upstream = lower(first)?;
    let downstream = lower(second)?;
    let base = upstream.graph.wires().len();

    // (relocated downstream input, upstream output) pairs, in port order.
    let splice: Vec<(WireId, WireId)> = downstream
        .inputs
        .iter()
        .zip(&upstream.outputs)
        .map(|(&d_in, &u_out)| (WireId::new(d_in.index() + base), u_out))
        .collect();

    // Relocate a downstream wire, then redirect it if it was a spliced input.
    let rewire = |w: WireId| -> WireId {
        let moved = WireId::new(w.index() + base);
        splice
            .iter()
            .find(|(from, _)| *from == moved)
            .map_or(moved, |&(_, to)| to)
    };

    let builder = append_wires(HdlGraphBuilder::new(), upstream.graph.wires().iter().cloned());
    let builder = append_wires(builder, downstream.graph.wires().iter().cloned());
    let builder = replay_instructions(&upstream.graph, builder, |w| w)?;
    let builder = replay_instructions(&downstream.graph, builder, rewire)?;

    let outputs = downstream.outputs.iter().copied().map(rewire).collect();
    Ok(Lowered {
        graph: builder.build(),
        inputs: upstream.inputs,
        outputs,
    })
}
/// Appends one fresh wire to `bld` per type yielded by `tys`, in iteration
/// order, and returns the extended builder.
///
/// The newly assigned [`WireId`]s are discarded.
/// NOTE(review): callers that locate these wires by index offset presume the
/// builder assigns ids sequentially — confirm against `HdlGraphBuilder`.
fn append_wires(
    bld: HdlGraphBuilder,
    tys: impl IntoIterator<Item = WireTy>,
) -> HdlGraphBuilder {
    let mut builder = bld;
    for ty in tys {
        let (next, _id) = builder.with_wire(ty);
        builder = next;
    }
    builder
}
/// Copies every instruction of `graph`, in order, into `bld`.
///
/// Each input and output wire is passed through `remap`, and operators are
/// cloned unchanged. Stops at the first instruction the builder rejects.
///
/// # Errors
///
/// Returns the first error produced by `HdlGraphBuilder::with_instruction`.
fn replay_instructions(
    graph: &HdlGraph,
    bld: HdlGraphBuilder,
    remap: impl Fn(WireId) -> WireId,
) -> Result<HdlGraphBuilder, HdlError> {
    let mut builder = bld;
    for instr in graph.instructions().iter() {
        let operands: Vec<WireId> = instr.inputs().iter().map(|&w| remap(w)).collect();
        let dest = remap(instr.output());
        builder = builder.with_instruction(instr.op().clone(), operands, dest)?;
    }
    Ok(builder)
}
#[cfg(test)]
mod tests {
    use super::lower;
    use crate::category::arrow::CircuitArrow;
    use crate::shape::Shape;

    /// Identity allocates one wire per bundle, reuses each wire as both input
    /// and output, and emits no instructions.
    #[test]
    fn lower_identity() -> Result<(), hdl_cat::Error> {
        let id = CircuitArrow::identity(Shape::new(3, 8));
        let lowered = lower(&id)?;
        assert_eq!(lowered.inputs().len(), 3);
        assert_eq!(lowered.outputs().len(), 3);
        assert_eq!(lowered.inputs(), lowered.outputs());
        assert_eq!(lowered.graph().instructions().len(), 0);
        Ok(())
    }

    #[test]
    fn lower_full_adder() -> Result<(), hdl_cat::Error> {
        let fa = CircuitArrow::full_adder();
        let lowered = lower(&fa)?;
        assert_eq!(lowered.inputs().len(), 3);
        assert_eq!(lowered.outputs().len(), 2);
        assert_eq!(lowered.graph().instructions().len(), 5);
        Ok(())
    }

    #[test]
    fn lower_csa_3to2() -> Result<(), hdl_cat::Error> {
        let csa = CircuitArrow::csa_3to2(16);
        let lowered = lower(&csa)?;
        assert_eq!(lowered.inputs().len(), 3);
        assert_eq!(lowered.outputs().len(), 2);
        assert_eq!(lowered.graph().instructions().len(), 5);
        Ok(())
    }

    /// A braid reorders its wires: outputs are a non-trivial permutation of
    /// the inputs and no logic is emitted.
    #[test]
    fn lower_braid() -> Result<(), hdl_cat::Error> {
        let b = CircuitArrow::braid(Shape::new(1, 4), Shape::new(2, 4));
        let lowered = lower(&b)?;
        assert_eq!(lowered.inputs().len(), 3);
        assert_eq!(lowered.outputs().len(), 3);
        assert_eq!(lowered.graph().instructions().len(), 0);
        let (ins, outs) = (lowered.inputs(), lowered.outputs());
        assert!(outs.iter().all(|w| ins.contains(w)));
        assert!(ins.iter().all(|w| outs.contains(w)));
        assert_ne!(ins, outs);
        Ok(())
    }

    #[test]
    fn lower_tensor() -> Result<(), hdl_cat::Error> {
        let left = CircuitArrow::csa_3to2(8);
        let right = CircuitArrow::passthrough(Shape::new(1, 8));
        let t = CircuitArrow::tensor(left, right);
        let lowered = lower(&t)?;
        assert_eq!(lowered.inputs().len(), 4);
        assert_eq!(lowered.outputs().len(), 3);
        assert_eq!(lowered.graph().instructions().len(), 5);
        Ok(())
    }

    /// Composing with a passthrough keeps the CSA's ports and adds no logic.
    #[test]
    fn lower_compose() -> Result<(), hdl_cat::Error> {
        let csa = CircuitArrow::csa_3to2(4);
        let pt = CircuitArrow::passthrough(Shape::new(2, 5));
        // The compose error type is not an hdl_cat error; map it so `?` works.
        let composed = CircuitArrow::compose(csa, pt)
            .map_err(|_| hdl_cat::Error::Overflow {
                width: hdl_cat::Width::new(0),
            })?;
        let lowered = lower(&composed)?;
        assert_eq!(lowered.inputs().len(), 3);
        assert_eq!(lowered.outputs().len(), 2);
        assert_eq!(lowered.graph().instructions().len(), 5);
        Ok(())
    }

    #[test]
    fn lower_nine_op_tree() -> Result<(), hdl_cat::Error> {
        // The tree builder's error type is not an hdl_cat error; map it so `?` works.
        let tree = crate::tree::compressor_tree(9, 16)
            .map_err(|_| hdl_cat::Error::Overflow {
                width: hdl_cat::Width::new(0),
            })?;
        let lowered = lower(&tree)?;
        assert_eq!(lowered.inputs().len(), 9);
        assert_eq!(lowered.outputs().len(), 2);
        assert!(!lowered.graph().instructions().is_empty());
        Ok(())
    }
}