use std::collections::HashMap;
use std::mem;
use std::ops::{Index, IndexMut};
use derive_more::From;
use hugr::builder::{Container, FunctionBuilder};
use hugr::hugr::hugrmut::HugrMut;
use hugr::hugr::views::sibling_subgraph::TopoConvexChecker;
use hugr::hugr::views::SiblingSubgraph;
use hugr::hugr::{HugrError, NodeMetadataMap};
use hugr::ops::handle::DataflowParentID;
use hugr::ops::OpType;
use hugr::types::Signature;
use hugr::{Hugr, HugrView, IncomingPort, Node, OutgoingPort, PortIndex, Wire};
use hugr_core::hugr::internal::{HugrInternals, HugrMutInternals as _};
use itertools::Itertools;
use rayon::iter::{IntoParallelIterator, IntoParallelRefMutIterator, ParallelIterator};
use rayon::slice::ParallelSliceMut;
use crate::Circuit;
use crate::circuit::cost::{CircuitCost, CostDelta};
/// Identifies a wire that crosses a chunk boundary.
///
/// The connection is represented by the [`Wire`] (source node and outgoing
/// port) it corresponds to in the original circuit, so the same value names
/// both the output of the chunk producing it and the inputs of the chunks (or
/// circuit output) consuming it.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, From)]
pub struct ChunkConnection(Wire);
/// A piece of a circuit, extracted as a standalone [`Circuit`] together with
/// the connections that link it to the rest of the original circuit.
#[derive(Debug, Clone)]
pub struct Chunk {
/// The extracted circuit. It may be replaced or modified before reassembly.
pub circ: Circuit,
/// The connection feeding each chunk input, paired positionally with the
/// outgoing ports of the chunk's Input node.
inputs: Vec<ChunkConnection>,
/// The connection produced by each chunk output, paired positionally with
/// the incoming ports of the chunk's Output node.
outputs: Vec<ChunkConnection>,
}
impl Chunk {
pub(self) fn extract(
circ: &Circuit,
nodes: impl IntoIterator<Item = Node>,
checker: &TopoConvexChecker<'_, Hugr>,
) -> Self {
let subgraph = SiblingSubgraph::try_from_nodes_with_checker(
nodes.into_iter().collect_vec(),
circ.hugr(),
checker,
)
.expect("Failed to define the chunk subgraph");
let extracted = subgraph.extract_subgraph(circ.hugr(), "Chunk").into();
let inputs = subgraph
.incoming_ports()
.iter()
.map(|wires| {
let (inp_node, inp_port) = wires[0];
let (out_node, out_port) = circ
.hugr()
.linked_outputs(inp_node, inp_port)
.exactly_one()
.ok()
.unwrap();
Wire::new(out_node, out_port).into()
})
.collect();
let outputs = subgraph
.outgoing_ports()
.iter()
.map(|&(node, port)| Wire::new(node, port).into())
.collect();
Self {
circ: extracted,
inputs,
outputs,
}
}
pub(self) fn insert(
&self,
circ: &mut impl HugrMut<Node = Node>,
root: Node,
) -> ChunkInsertResult {
let chunk = self.circ.hugr();
let chunk_root = chunk.entrypoint();
if chunk.children(self.circ.parent()).nth(2).is_none() {
return self.empty_chunk_insert_result();
}
let [chunk_inp, chunk_out] = chunk.get_io(chunk_root).unwrap();
let subgraph =
SiblingSubgraph::<Node>::try_new_dataflow_subgraph::<_, DataflowParentID>(&chunk)
.unwrap_or_else(|e| panic!("The chunk circuit is no longer a dataflow graph: {e}"));
let node_map = circ.insert_subgraph(root, &chunk, &subgraph);
let mut input_map = HashMap::with_capacity(self.inputs.len());
let mut output_map = HashMap::with_capacity(self.outputs.len());
for (&connection, chunk_inp_port) in self.inputs.iter().zip(chunk.node_outputs(chunk_inp)) {
let connection_targets: Vec<ConnectionTarget> = chunk
.linked_inputs(chunk_inp, chunk_inp_port)
.map(|(node, port)| {
if node == chunk_out {
let output_connection = self.outputs[port.index()];
ConnectionTarget::TransitiveConnection(output_connection)
} else {
ConnectionTarget::InsertedInput(*node_map.get(&node).unwrap(), port)
}
})
.collect();
input_map.insert(connection, connection_targets);
}
for (&wire, chunk_out_port) in self.outputs.iter().zip(chunk.node_inputs(chunk_out)) {
let (node, port) = chunk
.linked_outputs(chunk_out, chunk_out_port)
.exactly_one()
.ok()
.unwrap();
let target = if node == chunk_inp {
let input_connection = self.inputs[port.index()];
ConnectionTarget::TransitiveConnection(input_connection)
} else {
ConnectionTarget::InsertedOutput(*node_map.get(&node).unwrap(), port)
};
output_map.insert(wire, target);
}
ChunkInsertResult {
incoming_connections: input_map,
outgoing_connections: output_map,
}
}
fn empty_chunk_insert_result(&self) -> ChunkInsertResult {
let hugr = self.circ.hugr();
let [chunk_inp, chunk_out] = self.circ.io_nodes();
let mut input_map = HashMap::with_capacity(self.inputs.len());
let mut output_map = HashMap::with_capacity(self.outputs.len());
for (&connection, chunk_inp_port) in self.inputs.iter().zip(hugr.node_outputs(chunk_inp)) {
let connection_targets: Vec<ConnectionTarget> = hugr
.linked_ports(chunk_inp, chunk_inp_port)
.map(|(node, port)| {
assert_eq!(node, chunk_out);
let output_connection = self.outputs[port.index()];
ConnectionTarget::TransitiveConnection(output_connection)
})
.collect();
input_map.insert(connection, connection_targets);
}
for (&wire, chunk_out_port) in self.outputs.iter().zip(hugr.node_inputs(chunk_out)) {
let (node, port) = hugr
.linked_ports(chunk_out, chunk_out_port)
.exactly_one()
.ok()
.unwrap();
assert_eq!(node, chunk_inp);
let input_connection = self.inputs[port.index()];
output_map.insert(
wire,
ConnectionTarget::TransitiveConnection(input_connection),
);
}
ChunkInsertResult {
incoming_connections: input_map,
outgoing_connections: output_map,
}
}
}
/// Where the boundary connections of a chunk ended up after inserting it into
/// the reassembled circuit (see `Chunk::insert`).
#[derive(Debug, Clone)]
struct ChunkInsertResult {
/// For each connection entering the chunk, the targets it must be wired to.
/// Built only from `InsertedInput` and `TransitiveConnection` targets.
pub incoming_connections: HashMap<ChunkConnection, Vec<ConnectionTarget>>,
/// For each connection leaving the chunk, the port that now produces it.
/// Built only from `InsertedOutput` and `TransitiveConnection` targets.
pub outgoing_connections: HashMap<ChunkConnection, ConnectionTarget>,
}
/// The endpoint of a chunk boundary connection in the reassembled circuit.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
enum ConnectionTarget {
/// An input port of a node copied from the chunk.
InsertedInput(Node, IncomingPort),
/// An output port of a node copied from the chunk.
InsertedOutput(Node, OutgoingPort),
/// The chunk wires this boundary directly to its other side, so the
/// connection is identified with the given connection on the opposite side.
TransitiveConnection(ChunkConnection),
}
/// A circuit split into chunks that can be modified independently (e.g. in
/// parallel through `par_iter_mut`) and then joined back with `reassemble`.
#[derive(Debug, Clone)]
pub struct CircuitChunks {
/// Signature of the original circuit, reused for the reassembled one.
signature: Signature,
/// Metadata of the original circuit's parent node, restored on reassembly.
root_meta: Option<NodeMetadataMap>,
/// One connection per outgoing port of the original circuit's Input node.
input_connections: Vec<ChunkConnection>,
/// The connection feeding each input port of the original circuit's Output node.
output_connections: Vec<ChunkConnection>,
/// The chunks, in the order their operations appear in the circuit's commands.
chunks: Vec<Chunk>,
}
impl CircuitChunks {
pub fn split(circ: &Circuit, max_size: usize) -> Self {
Self::split_with_cost(circ, max_size.saturating_sub(1), |_| 1)
}
pub fn split_with_cost<C: CircuitCost>(
circ: &Circuit,
max_cost: C,
op_cost: impl Fn(&OpType) -> C,
) -> Self {
let hugr = circ.hugr();
let root_meta = hugr.node_metadata_map(circ.parent()).clone();
let signature = circ.circuit_signature().clone();
let [circ_input, circ_output] = circ.io_nodes();
let input_connections = hugr
.node_outputs(circ_input)
.map(|port| Wire::new(circ_input, port).into())
.collect();
let output_connections = hugr
.node_inputs(circ_output)
.flat_map(|p| hugr.linked_outputs(circ_output, p))
.map(|(n, p)| Wire::new(n, p).into())
.collect();
let mut chunks = Vec::new();
let convex_checker = TopoConvexChecker::new(circ.hugr(), circ.parent());
let mut running_cost = C::default();
let mut current_group = 0;
for (_, commands) in &circ.commands().map(|cmd| cmd.node()).chunk_by(|&node| {
let new_cost = running_cost.clone() + op_cost(hugr.get_optype(node));
if new_cost.sub_cost(&max_cost).as_isize() > 0 {
running_cost = C::default();
current_group += 1;
} else {
running_cost = new_cost;
}
current_group
}) {
chunks.push(Chunk::extract(circ, commands, &convex_checker));
}
Self {
signature: signature.into_owned(),
root_meta: Some(root_meta),
input_connections,
output_connections,
chunks,
}
}
pub fn reassemble(self) -> Result<Circuit, HugrError> {
let name = self
.root_meta
.as_ref()
.and_then(|map| map.get("name"))
.and_then(|s| s.as_str())
.unwrap_or("");
let mut builder = FunctionBuilder::new(name, self.signature).unwrap();
let mut reassembled = mem::take(builder.hugr_mut());
let root = reassembled.entrypoint();
let [reassembled_input, reassembled_output] = reassembled.get_io(root).unwrap();
let mut sources: HashMap<ChunkConnection, (Node, OutgoingPort)> = HashMap::new();
let mut targets: HashMap<ChunkConnection, Vec<(Node, IncomingPort)>> = HashMap::new();
let mut transitive_connections: HashMap<ChunkConnection, ChunkConnection> = HashMap::new();
let get_merged_connection = |transitive_connections: &HashMap<_, _>, connection| {
transitive_connections
.get(&connection)
.copied()
.unwrap_or(connection)
};
for (&connection, port) in self
.input_connections
.iter()
.zip(reassembled.node_outputs(reassembled_input))
{
sources.insert(connection, (reassembled_input, port));
}
for chunk in self.chunks {
let ChunkInsertResult {
incoming_connections,
outgoing_connections,
} = chunk.insert(&mut reassembled, root);
for (connection, conn_target) in outgoing_connections {
match conn_target {
ConnectionTarget::InsertedOutput(node, port) => {
sources.insert(connection, (node, port));
}
ConnectionTarget::TransitiveConnection(merged_connection) => {
let merged_connection =
get_merged_connection(&transitive_connections, merged_connection);
transitive_connections.insert(connection, merged_connection);
}
_ => panic!("Unexpected connection target"),
}
}
for (connection, conn_targets) in incoming_connections {
let connection = get_merged_connection(&transitive_connections, connection);
for tgt in conn_targets {
match tgt {
ConnectionTarget::InsertedInput(node, port) => {
targets.entry(connection).or_default().push((node, port));
}
ConnectionTarget::TransitiveConnection(_merged_connection) => {
}
_ => panic!("Unexpected connection target"),
}
}
}
}
for (&connection, port) in self
.output_connections
.iter()
.zip(reassembled.node_inputs(reassembled_output))
{
let connection = get_merged_connection(&transitive_connections, connection);
targets
.entry(connection)
.or_default()
.push((reassembled_output, port));
}
for (connection, (source, source_port)) in sources {
let Some(tgts) = targets.remove(&connection) else {
continue;
};
for (target, target_port) in tgts {
reassembled.connect(source, source_port, target, target_port);
}
}
*reassembled.node_metadata_map_mut(root) = self.root_meta.unwrap_or_default();
Ok(reassembled.into())
}
pub fn iter(&self) -> impl Iterator<Item = &Circuit> {
self.chunks.iter().map(|chunk| &chunk.circ)
}
pub fn iter_mut(&mut self) -> impl Iterator<Item = &mut Circuit> {
self.chunks.iter_mut().map(|chunk| &mut chunk.circ)
}
pub fn len(&self) -> usize {
self.chunks.len()
}
pub fn is_empty(&self) -> bool {
self.chunks.is_empty()
}
fn par_iter_mut(
&mut self,
) -> rayon::iter::Map<
rayon::slice::IterMut<'_, Chunk>,
for<'a> fn(&'a mut Chunk) -> &'a mut Circuit,
> {
self.chunks
.as_parallel_slice_mut()
.into_par_iter()
.map(|chunk| &mut chunk.circ)
}
}
impl Index<usize> for CircuitChunks {
    type Output = Circuit;

    /// Returns the circuit of the `index`-th chunk.
    ///
    /// # Panics
    ///
    /// If `index` is out of bounds.
    fn index(&self, index: usize) -> &Self::Output {
        let chunk = &self.chunks[index];
        &chunk.circ
    }
}
impl IndexMut<usize> for CircuitChunks {
    /// Returns a mutable reference to the circuit of the `index`-th chunk,
    /// e.g. to replace it before reassembly.
    ///
    /// # Panics
    ///
    /// If `index` is out of bounds.
    fn index_mut(&mut self, index: usize) -> &mut Self::Output {
        let chunk = &mut self.chunks[index];
        &mut chunk.circ
    }
}
impl<'data> IntoParallelRefMutIterator<'data> for CircuitChunks {
type Item = &'data mut Circuit;
type Iter = rayon::iter::Map<
rayon::slice::IterMut<'data, Chunk>,
for<'a> fn(&'a mut Chunk) -> &'a mut Circuit,
>;
fn par_iter_mut(&'data mut self) -> Self::Iter {
self.par_iter_mut()
}
}
#[cfg(test)]
mod test {
use crate::circuit::CircuitHash;
use crate::utils::build_simple_circuit;
use crate::Tk2Op;
use super::*;
/// Splitting a 7-operation circuit into chunks of at most 3 operations and
/// reassembling it unchanged must give a valid circuit with the same hash.
#[test]
fn split_reassemble() {
let circ = build_simple_circuit(2, |circ| {
circ.append(Tk2Op::H, [0])?;
circ.append(Tk2Op::CX, [0, 1])?;
circ.append(Tk2Op::T, [1])?;
circ.append(Tk2Op::H, [0])?;
circ.append(Tk2Op::CX, [0, 1])?;
circ.append(Tk2Op::H, [0])?;
circ.append(Tk2Op::CX, [0, 1])?;
Ok(())
})
.unwrap();
let chunks = CircuitChunks::split(&circ, 3);
assert_eq!(chunks.len(), 3);
let mut reassembled = chunks.reassemble().unwrap();
reassembled.hugr_mut().validate().unwrap();
assert_eq!(
circ.circuit_hash(circ.parent()),
reassembled.circuit_hash(reassembled.parent())
);
}
/// Replacing chunks with (near-)identity circuits must produce wires that go
/// straight from the reassembled circuit's input to its output.
#[test]
fn reassemble_empty() {
let circ = build_simple_circuit(3, |circ| {
circ.append(Tk2Op::CX, [0, 1])?;
circ.append(Tk2Op::H, [0])?;
circ.append(Tk2Op::H, [1])?;
Ok(())
})
.unwrap();
let circ_1q_id = build_simple_circuit(1, |_| Ok(())).unwrap();
let circ_2q_id_h = build_simple_circuit(2, |circ| {
circ.append(Tk2Op::H, [0])?;
Ok(())
})
.unwrap();
// One operation per chunk. Chunk 0 is presumably the CX (first in command
// order), hence the 2-qubit replacement; the other chunks are 1-qubit Hs.
let mut chunks = CircuitChunks::split(&circ, 1);
chunks[0] = circ_2q_id_h.clone();
chunks[1] = circ_1q_id.clone();
chunks[2] = circ_1q_id.clone();
let mut reassembled = chunks.reassemble().unwrap();
reassembled.hugr_mut().validate().unwrap();
// Only the H from the first replacement survives.
assert_eq!(reassembled.commands().count(), 1);
let h = reassembled.commands().next().unwrap().node();
let [inp, out] = reassembled.io_nodes();
// Qubit 0 goes through the H; qubits 1 and 2 are wired straight to the output.
assert_eq!(
&reassembled.hugr().output_neighbours(inp).collect_vec(),
&[h, out, out]
);
}
}