mod config;
mod unit_generator;
mod unsupported_tracker;
mod value_tracker;
pub use config::{default_encoder_config, Tk1EncoderConfig};
use hugr::envelope::EnvelopeConfig;
use hugr::hugr::views::SiblingSubgraph;
use hugr::package::Package;
use hugr_core::hugr::internal::PortgraphNodeMap;
use tket_json_rs::clexpr::InputClRegister;
use tket_json_rs::opbox::BoxID;
pub use value_tracker::{
RegisterCount, TrackedBit, TrackedParam, TrackedQubit, TrackedValue, TrackedValues,
ValueTracker,
};
use hugr::ops::{OpTrait, OpType};
use hugr::types::EdgeKind;
use std::borrow::Cow;
use std::collections::{BTreeSet, HashMap, HashSet};
use std::sync::{Arc, RwLock};
use hugr::{HugrView, Wire};
use itertools::Itertools;
use tket_json_rs::circuit_json::{self, SerialCircuit};
use unsupported_tracker::UnsupportedTracker;
use super::{
OpConvertError, Tk1ConvertError, METADATA_B_OUTPUT_REGISTERS, METADATA_OPGROUP, METADATA_PHASE,
METADATA_Q_OUTPUT_REGISTERS, METADATA_Q_REGISTERS,
};
use crate::circuit::Circuit;
/// Encoder state used while translating a HUGR region into a pytket [`SerialCircuit`].
///
/// Accumulates the emitted pytket commands and tracks which pytket registers and
/// parameter expressions are carried by each HUGR wire.
#[derive(derive_more::Debug)]
#[debug(bounds(H: HugrView))]
pub struct Tk1EncoderContext<H: HugrView> {
    /// Name of the encoded circuit, taken from the region if it has one.
    name: Option<String>,
    /// Global phase expression of the circuit (defaults to `"0"`).
    phase: String,
    /// Commands emitted so far, in emission order.
    commands: Vec<circuit_json::Command>,
    /// Tracker mapping HUGR wires to pytket qubits, bits, and parameters.
    pub values: ValueTracker<H::Node>,
    /// Nodes that could not be encoded directly, pending emission as opaque barriers.
    unsupported: UnsupportedTracker<H::Node>,
    /// Encoder configuration, shared with nested encoders.
    config: Arc<Tk1EncoderConfig<H>>,
    /// Results of encoding function definitions, shared with nested encoders so
    /// each function is encoded at most once.
    function_cache: Arc<RwLock<HashMap<H::Node, CachedEncodedFunction>>>,
}
impl<H: HugrView> Tk1EncoderContext<H> {
    /// Create a new encoder context for `region` in `circ`.
    ///
    /// # Errors
    ///
    /// Returns an error if the region's values cannot be tracked, or if its global
    /// phase metadata is neither a string nor a number.
    pub(super) fn new(
        circ: &Circuit<H>,
        region: H::Node,
        config: Tk1EncoderConfig<H>,
    ) -> Result<Self, Tk1ConvertError<H::Node>> {
        Self::new_arc(circ, region, Arc::new(config))
    }

    /// Create a new encoder context that shares an already reference-counted
    /// configuration. Used for nested encoders (subcircuits and called functions).
    fn new_arc(
        circ: &Circuit<H>,
        region: H::Node,
        config: Arc<Tk1EncoderConfig<H>>,
    ) -> Result<Self, Tk1ConvertError<H::Node>> {
        let hugr = circ.hugr();
        let name = Circuit::new(hugr.with_entrypoint(region))
            .name()
            .map(str::to_string);
        // The phase is normally stored as a string expression. Numeric metadata is
        // accepted too; anything else is reported instead of panicking.
        let phase = match hugr.get_metadata(region, METADATA_PHASE) {
            None => "0".to_string(),
            Some(serde_json::Value::String(p)) => p.clone(),
            Some(serde_json::Value::Number(n)) => n.to_string(),
            Some(other) => {
                return Err(Tk1ConvertError::custom(format!(
                    "Invalid global phase metadata '{other}'. Expected a string expression."
                )))
            }
        };
        Ok(Self {
            name,
            phase,
            commands: vec![],
            values: ValueTracker::new(circ, region, &config)?,
            unsupported: UnsupportedTracker::new(circ),
            config,
            function_cache: Arc::new(RwLock::new(HashMap::new())),
        })
    }

    /// Try to encode every node of `region` in topological order.
    ///
    /// The region's own input and output nodes are not encoded. Nodes that cannot
    /// be encoded are recorded as unsupported and emitted later.
    pub(super) fn run_encoder(
        &mut self,
        circ: &Circuit<H>,
        region: H::Node,
    ) -> Result<(), Tk1ConvertError<H::Node>> {
        // Use the IO nodes of the region being encoded. `circ.io_nodes()` refers to
        // the circuit's entrypoint, which is a different region when encoding a
        // nested DFG or a called function.
        let io_nodes = circ.hugr().get_io(region);
        let (graph, node_map) = circ.hugr().region_portgraph(region);
        let mut topo = petgraph::visit::Topo::new(&graph);
        while let Some(pg_node) = topo.next(&graph) {
            let node = node_map.from_portgraph(pg_node);
            if io_nodes.is_some_and(|io| io.contains(&node)) {
                continue;
            }
            self.try_encode_node(node, circ)?;
        }
        Ok(())
    }

    /// Emit all remaining unsupported nodes and build the final pytket circuit.
    ///
    /// Returns the serialized circuit together with the output parameter
    /// expressions reported by the value tracker.
    pub(super) fn finish(
        mut self,
        circ: &Circuit<H>,
        region: H::Node,
    ) -> Result<(SerialCircuit, Vec<String>), Tk1ConvertError<H::Node>> {
        // `while let` cannot be used here: the iterator temporary would keep
        // `self.unsupported` borrowed during the loop body.
        while !self.unsupported.is_empty() {
            let node = self.unsupported.iter().next().unwrap();
            let unsupported_subgraph = self.unsupported.extract_component(node);
            self.emit_unsupported(unsupported_subgraph, circ)?;
        }
        let final_values = self.values.finish(circ, region)?;
        let mut ser = SerialCircuit::new(self.name, self.phase);
        ser.commands = self.commands;
        ser.qubits = final_values.qubits.into_iter().map_into().collect();
        ser.bits = final_values.bits.into_iter().map_into().collect();
        ser.implicit_permutation = final_values.qubit_permutation;
        ser.number_of_ws = None;
        Ok((ser, final_values.params))
    }

    /// The encoder configuration.
    pub fn config(&self) -> &Tk1EncoderConfig<H> {
        &self.config
    }

    /// Return the tracked values carried by `wire`.
    ///
    /// If the wire comes from a node pending as unsupported, its whole unsupported
    /// component is emitted first so that the wire gets values.
    ///
    /// # Errors
    ///
    /// Returns [`OpConvertError::WireHasNoValues`] if no values are tracked for the wire.
    pub fn get_wire_values(
        &mut self,
        wire: Wire<H::Node>,
        circ: &Circuit<H>,
    ) -> Result<Cow<'_, [TrackedValue]>, Tk1ConvertError<H::Node>> {
        // Peek first: returning the borrow conditionally would not pass the borrow checker.
        if self.values.peek_wire_values(wire).is_some() {
            return Ok(self.values.wire_values(wire).unwrap());
        }
        if self.unsupported.is_unsupported(wire.node()) {
            let unsupported_subgraph = self.unsupported.extract_component(wire.node());
            self.emit_unsupported(unsupported_subgraph, circ)?;
            debug_assert!(!self.unsupported.is_unsupported(wire.node()));
            return self.get_wire_values(wire, circ);
        }
        Err(OpConvertError::WireHasNoValues { wire }.into())
    }

    /// Collect the values on all input wires of `node`, grouped by kind, in port order.
    pub fn get_input_values(
        &mut self,
        node: H::Node,
        circ: &Circuit<H>,
    ) -> Result<TrackedValues, Tk1ConvertError<H::Node>> {
        self.get_input_values_internal(node, circ, |_| true)
    }

    /// Like [`Self::get_input_values`], but only considers incoming wires accepted by `wire_filter`.
    fn get_input_values_internal(
        &mut self,
        node: H::Node,
        circ: &Circuit<H>,
        wire_filter: impl Fn(Wire<H::Node>) -> bool,
    ) -> Result<TrackedValues, Tk1ConvertError<H::Node>> {
        let mut qubits: Vec<TrackedQubit> = Vec::new();
        let mut bits: Vec<TrackedBit> = Vec::new();
        let mut params: Vec<TrackedParam> = Vec::new();
        let optype = circ.hugr().get_optype(node);
        let other_input_port = optype.other_input_port();
        for input in circ.hugr().node_inputs(node) {
            // Ignore order edges.
            if Some(input) == other_input_port {
                continue;
            }
            // Static references to functions (e.g. the target of a `Call`) carry no
            // qubits, bits, or parameters, and their source is outside the region.
            if matches!(optype.port_kind(input), Some(EdgeKind::Function(_))) {
                continue;
            }
            let Some((neigh, neigh_out)) = circ.hugr().single_linked_output(node, input) else {
                return Err(
                    OpConvertError::UnsupportedOpSerialization { op: optype.clone() }.into(),
                );
            };
            let wire = Wire::new(neigh, neigh_out);
            if !wire_filter(wire) {
                continue;
            }
            for value in self.get_wire_values(wire, circ)?.iter() {
                match value {
                    TrackedValue::Qubit(qb) => qubits.push(*qb),
                    TrackedValue::Bit(b) => bits.push(*b),
                    TrackedValue::Param(p) => params.push(*p),
                }
            }
        }
        Ok(TrackedValues {
            qubits,
            bits,
            params,
        })
    }

    /// Emit a command for `node` with operation type `tk1_optype` and no output parameters.
    pub fn emit_node(
        &mut self,
        tk1_optype: tket_json_rs::OpType,
        node: H::Node,
        circ: &Circuit<H>,
    ) -> Result<(), Tk1ConvertError<H::Node>> {
        self.emit_node_with_out_params(tk1_optype, node, circ, |_| Vec::new())
    }

    /// Emit a command for `node` with operation type `tk1_optype`.
    ///
    /// `output_params` computes the parameter expressions of the node's outputs.
    pub fn emit_node_with_out_params(
        &mut self,
        tk1_optype: tket_json_rs::OpType,
        node: H::Node,
        circ: &Circuit<H>,
        output_params: impl FnOnce(OutputParamArgs<'_>) -> Vec<String>,
    ) -> Result<(), Tk1ConvertError<H::Node>> {
        self.emit_node_command(node, circ, output_params, move |inputs| {
            make_tk1_operation(tk1_optype, inputs)
        })
    }

    /// Emit a command for `node` whose operation is built by `make_operation`.
    ///
    /// Input qubits are reused for the qubit outputs. Any extra qubit outputs and
    /// all bit outputs get fresh registers. The command arguments are the input
    /// qubits, then the new qubits, then the input bits, then the new bits.
    pub fn emit_node_command(
        &mut self,
        node: H::Node,
        circ: &Circuit<H>,
        output_params: impl FnOnce(OutputParamArgs<'_>) -> Vec<String>,
        make_operation: impl FnOnce(MakeOperationArgs<'_>) -> tket_json_rs::circuit_json::Operation,
    ) -> Result<(), Tk1ConvertError<H::Node>> {
        let TrackedValues {
            mut qubits,
            mut bits,
            params,
        } = self.get_input_values(node, circ)?;
        let params: Vec<String> = params
            .into_iter()
            .map(|p| self.values.param_expression(p).to_owned())
            .collect();
        let mut qubit_iterator = qubits.iter().copied();
        let new_outputs = self.register_node_outputs(
            node,
            circ,
            &mut qubit_iterator,
            &params,
            output_params,
            |_| true,
        )?;
        qubits.extend(new_outputs.qubits);
        bits.extend(new_outputs.bits);
        let opgroup: Option<String> = circ
            .hugr()
            .get_metadata(node, METADATA_OPGROUP)
            .and_then(serde_json::Value::as_str)
            .map(ToString::to_string);
        let args = MakeOperationArgs {
            num_qubits: qubits.len(),
            num_bits: bits.len(),
            params: &params,
        };
        let op = make_operation(args);
        self.emit_command(op, &qubits, &bits, opgroup);
        Ok(())
    }

    /// Register the outputs of `node` as carrying its input qubits and bits
    /// unchanged, without emitting any command.
    ///
    /// Output parameters are computed by `output_params`.
    ///
    /// # Errors
    ///
    /// Returns an error if the numbers of input and output qubits or bits differ,
    /// or if `output_params` returns the wrong number of expressions.
    pub fn emit_transparent_node(
        &mut self,
        node: H::Node,
        circ: &Circuit<H>,
        output_params: impl FnOnce(OutputParamArgs<'_>) -> Vec<String>,
    ) -> Result<(), Tk1ConvertError<H::Node>> {
        let input_values = self.get_input_values(node, circ)?;
        let output_counts = self.node_output_values(node, circ)?;
        let total_out_count: RegisterCount = output_counts.iter().map(|(_, c)| *c).sum();
        let input_params: Vec<String> = input_values
            .params
            .into_iter()
            .map(|p| self.values.param_expression(p).to_owned())
            .collect_vec();
        let out_params = output_params(OutputParamArgs {
            expected_count: total_out_count.params,
            input_params: &input_params,
        });
        if input_values.qubits.len() != total_out_count.qubits {
            return Err(Tk1ConvertError::custom(format!(
                "Mismatched number of input and output qubits while trying to emit a transparent operation for {}. We have {} inputs but {} outputs.",
                circ.hugr().get_optype(node),
                input_values.qubits.len(),
                total_out_count.qubits,
            )));
        }
        if input_values.bits.len() != total_out_count.bits {
            return Err(Tk1ConvertError::custom(format!(
                "Mismatched number of input and output bits while trying to emit a transparent operation for {}. We have {} inputs but {} outputs.",
                circ.hugr().get_optype(node),
                input_values.bits.len(),
                total_out_count.bits,
            )));
        }
        if out_params.len() != total_out_count.params {
            return Err(Tk1ConvertError::custom(format!(
                "Expected {} parameters in the input values for a {}, but got {}.",
                total_out_count.params,
                circ.hugr().get_optype(node),
                out_params.len()
            )));
        }
        let mut qubits = input_values.qubits.into_iter();
        let mut bits = input_values.bits.into_iter();
        let mut params = out_params.into_iter();
        for (wire, count) in output_counts {
            let mut values: Vec<TrackedValue> = Vec::with_capacity(count.total());
            values.extend(qubits.by_ref().take(count.qubits).map(TrackedValue::Qubit));
            values.extend(bits.by_ref().take(count.bits).map(TrackedValue::Bit));
            for p in params.by_ref().take(count.params) {
                values.push(self.values.new_param(p).into());
            }
            self.values.register_wire(wire, values, circ)?;
        }
        Ok(())
    }

    /// Emit a connected set of unsupported nodes as one opaque `Barrier` command.
    ///
    /// The command's `data` holds the extracted subgraph serialized as a HUGR
    /// package. Its arguments are the qubits and bits entering the subgraph, then
    /// any new registers allocated for its outputs.
    fn emit_unsupported(
        &mut self,
        unsupported_nodes: BTreeSet<H::Node>,
        circ: &Circuit<H>,
    ) -> Result<(), Tk1ConvertError<H::Node>> {
        let subcircuit_id = format!("tk{}", unsupported_nodes.iter().min().unwrap());
        let subgraph = SiblingSubgraph::try_from_nodes(
            unsupported_nodes.iter().cloned().collect_vec(),
            circ.hugr(),
        )
        .unwrap_or_else(|_| {
            panic!(
                "Failed to create subgraph from unsupported nodes [{}]",
                unsupported_nodes.iter().join(", ")
            )
        });
        // Boundary nodes, deduplicated in the order of the subgraph's boundary ports.
        // Iterating a `HashSet` instead would make the command's register order and
        // the qubit assignment of the outputs depend on the hash seed.
        let mut seen_inputs = HashSet::new();
        let input_nodes = subgraph
            .incoming_ports()
            .iter()
            .flat_map(|inp| inp.iter().map(|(n, _)| *n))
            .filter(|n| seen_inputs.insert(*n))
            .collect_vec();
        let mut seen_outputs = HashSet::new();
        let output_nodes = subgraph
            .outgoing_ports()
            .iter()
            .map(|(n, _)| *n)
            .filter(|n| seen_outputs.insert(*n))
            .collect_vec();
        let unsupported_hugr = subgraph.extract_subgraph(circ.hugr(), &subcircuit_id);
        let payload = Package::from_hugr(unsupported_hugr)
            .store_str(EnvelopeConfig::text())
            .unwrap();

        // Collect the values entering the subgraph from outside. Wires between nodes
        // of the subgraph were never encoded and carry no tracked values.
        let mut op_values = TrackedValues::default();
        for &node in &input_nodes {
            let node_vals = self.get_input_values_internal(node, circ, |w| {
                !unsupported_nodes.contains(&w.node())
            })?;
            op_values.append(node_vals);
        }
        let input_param_exprs: Vec<String> = std::mem::take(&mut op_values.params)
            .into_iter()
            .map(|p| self.values.param_expression(p).to_owned())
            .collect();

        // Output qubits reuse the input qubits before fresh ones are allocated.
        // Output parameters are named after the subcircuit id and numbered across
        // all output nodes, so that the names stay unique.
        let mut input_qubits = op_values.qubits.clone().into_iter();
        let mut next_param = 0;
        for &node in &output_nodes {
            let first_param = next_param;
            let new_outputs = self.register_node_outputs(
                node,
                circ,
                &mut input_qubits,
                &[],
                |p| {
                    (first_param..first_param + p.expected_count)
                        .map(|i| format!("{subcircuit_id}_out{i}"))
                        .collect_vec()
                },
                |_| true,
            )?;
            next_param += new_outputs.params.len();
            op_values.append(new_outputs);
        }
        let args = MakeOperationArgs {
            num_qubits: op_values.qubits.len(),
            num_bits: op_values.bits.len(),
            params: &input_param_exprs,
        };
        let mut tk1_op = make_tk1_operation(tket_json_rs::OpType::Barrier, args);
        tk1_op.data = Some(payload);
        let opgroup = Some("tket2".to_string());
        self.emit_command(tk1_op, &op_values.qubits, &op_values.bits, opgroup);
        Ok(())
    }

    /// Append a command applying `tk1_operation` to the registers of `qubits`,
    /// followed by the registers of `bits`.
    pub fn emit_command(
        &mut self,
        tk1_operation: circuit_json::Operation,
        qubits: &[TrackedQubit],
        bits: &[TrackedBit],
        opgroup: Option<String>,
    ) {
        let qubit_regs = qubits.iter().map(|&qb| self.values.qubit_register(qb));
        let bit_regs = bits.iter().map(|&b| self.values.bit_register(b));
        let command = circuit_json::Command {
            op: tk1_operation,
            args: qubit_regs.chain(bit_regs).cloned().collect(),
            opgroup,
        };
        self.commands.push(command);
    }

    /// Encode the DFG `node` as a pytket `CircBox` using a nested encoder.
    ///
    /// Returns [`EncodeStatus::Unsupported`] without modifying this context if the
    /// nested circuit has output parameters.
    fn emit_subcircuit(
        &mut self,
        node: H::Node,
        circ: &Circuit<H>,
    ) -> Result<EncodeStatus, Tk1ConvertError<H::Node>> {
        let config = Arc::clone(&self.config);
        let mut subencoder = Tk1EncoderContext::new_arc(circ, node, config)?;
        subencoder.function_cache = Arc::clone(&self.function_cache);
        subencoder.run_encoder(circ, node)?;
        let (serial_subcirc, output_params) = subencoder.finish(circ, node)?;
        if !output_params.is_empty() {
            return Ok(EncodeStatus::Unsupported);
        }
        self.emit_circ_box(node, serial_subcirc, circ)?;
        Ok(EncodeStatus::Success)
    }

    /// Encode the call `node` to `function` as a pytket `CircBox`.
    ///
    /// Encoded function bodies are cached, so each function definition is encoded
    /// at most once. A poisoned cache lock is treated as a cache miss.
    fn emit_function_call(
        &mut self,
        node: H::Node,
        function: H::Node,
        circ: &Circuit<H>,
    ) -> Result<EncodeStatus, Tk1ConvertError<H::Node>> {
        // The read guard is a temporary of this statement, so the lock is released
        // before any encoding below.
        let cached = self
            .function_cache
            .read()
            .ok()
            .and_then(|cache| cache.get(&function).cloned());
        match cached {
            Some(CachedEncodedFunction::Encoded { serial_circuit }) => {
                self.emit_circ_box(node, serial_circuit, circ)?;
                return Ok(EncodeStatus::Success);
            }
            Some(CachedEncodedFunction::Unsupported) => return Ok(EncodeStatus::Unsupported),
            None => {}
        }

        // Cache miss: encode the function body with a nested encoder sharing the cache.
        let config = Arc::clone(&self.config);
        let mut subencoder = Tk1EncoderContext::new_arc(circ, function, config)?;
        subencoder.function_cache = Arc::clone(&self.function_cache);
        subencoder.run_encoder(circ, function)?;
        let (serial_subcirc, output_params) = subencoder.finish(circ, function)?;

        // Functions with output parameters are not encoded as a CircBox.
        let (result, cached_fn) = if output_params.is_empty() {
            (
                EncodeStatus::Success,
                CachedEncodedFunction::Encoded {
                    serial_circuit: serial_subcirc.clone(),
                },
            )
        } else {
            (
                EncodeStatus::Unsupported,
                CachedEncodedFunction::Unsupported,
            )
        };
        if let Ok(mut cache) = self.function_cache.write() {
            cache.insert(function, cached_fn);
        }
        if result == EncodeStatus::Success {
            self.emit_circ_box(node, serial_subcirc, circ)?;
        }
        Ok(result)
    }

    /// Emit a `CircBox` command for `node` that wraps `boxed_circuit`.
    fn emit_circ_box(
        &mut self,
        node: H::Node,
        boxed_circuit: SerialCircuit,
        circ: &Circuit<H>,
    ) -> Result<(), Tk1ConvertError<H::Node>> {
        self.emit_node_command(
            node,
            circ,
            |args| {
                debug_assert!(args.expected_count == 0);
                Vec::new()
            },
            |args| {
                let mut pytket_op = make_tk1_operation(tket_json_rs::OpType::CircBox, args);
                pytket_op.op_box = Some(tket_json_rs::opbox::OpBox::CircBox {
                    id: BoxID::new(),
                    circuit: boxed_circuit,
                });
                pytket_op
            },
        )?;
        Ok(())
    }

    /// Try to encode a single node.
    ///
    /// Any node that is not encoded successfully, including DFGs and calls that
    /// cannot be boxed, is recorded as unsupported. Its outputs are then produced
    /// later by an opaque barrier.
    fn try_encode_node(
        &mut self,
        node: H::Node,
        circ: &Circuit<H>,
    ) -> Result<EncodeStatus, Tk1ConvertError<H::Node>> {
        let optype = circ.hugr().get_optype(node);
        match optype {
            OpType::ExtensionOp(op) => {
                let config = Arc::clone(&self.config);
                if config.op_to_pytket(node, op, circ, self)? == EncodeStatus::Success {
                    return Ok(EncodeStatus::Success);
                }
            }
            OpType::LoadConstant(_) => {
                self.emit_transparent_node(node, circ, |ps| ps.input_params.to_owned())?;
                return Ok(EncodeStatus::Success);
            }
            OpType::Const(op) => {
                let config = Arc::clone(&self.config);
                if let Some(values) = config.const_to_pytket(&op.value, self)? {
                    let wire = Wire::new(node, 0);
                    self.values.register_wire(wire, values.into_iter(), circ)?;
                    return Ok(EncodeStatus::Success);
                }
            }
            OpType::DFG(_) => {
                // Fall through on failure, so the node is recorded as unsupported
                // and its outputs still get values.
                if self.emit_subcircuit(node, circ)? == EncodeStatus::Success {
                    return Ok(EncodeStatus::Success);
                }
            }
            OpType::Call(call) => {
                let (fn_node, _) = circ
                    .hugr()
                    .single_linked_output(node, call.called_function_port())
                    .expect("Function call must be linked to a function");
                if self.emit_function_call(node, fn_node, circ)? == EncodeStatus::Success {
                    return Ok(EncodeStatus::Success);
                }
            }
            _ => {}
        }
        self.unsupported.record_node(node, circ);
        Ok(EncodeStatus::Unsupported)
    }

    /// Register values for the output wires of `node` accepted by `wire_filter`.
    ///
    /// Qubit outputs are taken from `qubits` while it yields values; the remaining
    /// qubit outputs and all bit outputs get fresh registers. Parameter outputs use
    /// the expressions returned by `output_params`. Returns only the newly created
    /// values.
    ///
    /// # Errors
    ///
    /// Returns an error if `output_params` returns the wrong number of expressions,
    /// or if an output type cannot be encoded.
    fn register_node_outputs(
        &mut self,
        node: H::Node,
        circ: &Circuit<H>,
        qubits: &mut impl Iterator<Item = TrackedQubit>,
        input_params: &[String],
        output_params: impl FnOnce(OutputParamArgs<'_>) -> Vec<String>,
        wire_filter: impl Fn(Wire<H::Node>) -> bool,
    ) -> Result<TrackedValues, Tk1ConvertError<H::Node>> {
        let output_counts = self.node_output_values(node, circ)?;
        let total_out_count: RegisterCount = output_counts.iter().map(|(_, c)| *c).sum();
        let out_params = output_params(OutputParamArgs {
            expected_count: total_out_count.params,
            input_params,
        });
        if out_params.len() != total_out_count.params {
            return Err(Tk1ConvertError::custom(format!(
                "Expected {} parameters in the input values for a {}, but got {}.",
                total_out_count.params,
                circ.hugr().get_optype(node),
                out_params.len()
            )));
        }
        let mut params = out_params.into_iter();
        let mut new_outputs = TrackedValues::default();
        for (wire, count) in output_counts {
            if !wire_filter(wire) {
                continue;
            }
            let mut out_wire_values = Vec::with_capacity(count.total());
            out_wire_values.extend(qubits.by_ref().take(count.qubits).map(TrackedValue::Qubit));
            // Allocate fresh qubits for any output not covered by the input qubits.
            for _ in out_wire_values.len()..count.qubits {
                let qb = self.values.new_qubit();
                new_outputs.qubits.push(qb);
                out_wire_values.push(TrackedValue::Qubit(qb));
            }
            for _ in 0..count.bits {
                let b = self.values.new_bit();
                new_outputs.bits.push(b);
                out_wire_values.push(TrackedValue::Bit(b));
            }
            for expr in params.by_ref().take(count.params) {
                let p = self.values.new_param(expr);
                new_outputs.params.push(p);
                out_wire_values.push(p.into());
            }
            self.values.register_wire(wire, out_wire_values, circ)?;
        }
        Ok(new_outputs)
    }

    /// Compute the register counts carried by each output wire of `node`,
    /// including its static output, and excluding order edges.
    ///
    /// # Errors
    ///
    /// Returns an error if an output type cannot be encoded in pytket, or if the
    /// node lacks a dataflow signature or a constant static output.
    // The return type is a plain list of (wire, count) pairs; an alias would not add clarity.
    #[allow(clippy::type_complexity)]
    fn node_output_values(
        &self,
        node: H::Node,
        circ: &Circuit<H>,
    ) -> Result<Vec<(Wire<H::Node>, RegisterCount)>, Tk1ConvertError<H::Node>> {
        let op = circ.hugr().get_optype(node);
        let signature = op.dataflow_signature();
        let static_output = op.static_output_port();
        let other_output = op.other_output_port();
        let mut wire_counts = Vec::with_capacity(circ.hugr().num_outputs(node));
        for out_port in circ.hugr().node_outputs(node) {
            let ty = if Some(out_port) == other_output {
                continue;
            } else if Some(out_port) == static_output {
                let EdgeKind::Const(ty) = op.static_output().unwrap() else {
                    return Err(Tk1ConvertError::custom(format!(
                        "Cannot emit a static output for a {op}."
                    )));
                };
                ty
            } else {
                let Some(ty) = signature
                    .as_ref()
                    .and_then(|s| s.out_port_type(out_port).cloned())
                else {
                    return Err(Tk1ConvertError::custom(
                        "Cannot emit a transparent node without a dataflow signature.",
                    ));
                };
                ty
            };
            let wire = hugr::Wire::new(node, out_port);
            let Some(count) = self.config().type_to_pytket(&ty)? else {
                return Err(Tk1ConvertError::custom(format!(
                    "Found an unsupported type while encoding a {op}."
                )));
            };
            wire_counts.push((wire, count));
        }
        Ok(wire_counts)
    }
}
/// Outcome of trying to encode a HUGR node as pytket commands.
#[derive(Clone, Copy, Debug, Eq, PartialEq, Hash, derive_more::Display)]
pub enum EncodeStatus {
    /// The node was encoded.
    Success,
    /// The node could not be encoded directly; the encoder may record it as
    /// unsupported and emit it later.
    Unsupported,
}
/// Arguments given to callbacks that compute the output parameter expressions of a node.
#[derive(Clone, Copy, Debug)]
pub struct OutputParamArgs<'a> {
    /// Number of parameter expressions the callback is expected to return.
    pub expected_count: usize,
    /// Parameter expressions of the node's inputs.
    pub input_params: &'a [String],
}
/// Arguments given to callbacks that build the pytket operation of an emitted command.
#[derive(Clone, Copy, Debug)]
pub struct MakeOperationArgs<'a> {
    /// Number of qubit arguments of the command.
    pub num_qubits: usize,
    /// Number of bit arguments of the command.
    pub num_bits: usize,
    /// Parameter expressions of the operation.
    pub params: &'a [String],
}
/// Cached result of encoding a function definition, reused by later calls to it.
#[derive(Clone, Debug)]
enum CachedEncodedFunction {
    /// The function body was encoded into a pytket circuit.
    Encoded {
        /// The encoded function body.
        serial_circuit: SerialCircuit,
    },
    /// The function could not be encoded as a pytket circuit box.
    Unsupported,
}
pub fn make_tk1_operation(
tk1_optype: tket_json_rs::OpType,
inputs: MakeOperationArgs<'_>,
) -> circuit_json::Operation {
let mut op = circuit_json::Operation::default();
op.op_type = tk1_optype;
op.n_qb = Some(inputs.num_qubits as u32);
op.params = match inputs.params.is_empty() {
false => Some(inputs.params.to_owned()),
true => None,
};
op.signature = Some(
[
vec!["Q".into(); inputs.num_qubits],
vec!["B".into(); inputs.num_bits],
]
.concat(),
);
op
}
/// Build a classical pytket operation of type `tk1_optype` acting on `bit_count` bits.
///
/// The operation has no qubits and no parameters; `classical` is attached as its
/// classical payload.
pub fn make_tk1_classical_operation(
    tk1_optype: tket_json_rs::OpType,
    bit_count: usize,
    classical: tket_json_rs::circuit_json::Classical,
) -> tket_json_rs::circuit_json::Operation {
    let mut operation = make_tk1_operation(
        tk1_optype,
        MakeOperationArgs {
            num_qubits: 0,
            num_bits: bit_count,
            params: &[],
        },
    );
    operation.classical = Some(Box::new(classical));
    operation
}
/// Build a pytket `ClExpr` operation acting on `bit_count` bits.
///
/// The bit position map is the identity pair `(i, i)` for each of the
/// `bit_count` bits. `registers` is stored as the register positions,
/// `output_bits` as the output positions, and `expression` as the expression.
pub fn make_tk1_classical_expression(
    bit_count: usize,
    output_bits: &[u32],
    registers: &[InputClRegister],
    expression: tket_json_rs::clexpr::operator::ClOperator,
) -> tket_json_rs::circuit_json::Operation {
    let mut operation = make_tk1_operation(
        tket_json_rs::OpType::ClExpr,
        MakeOperationArgs {
            num_qubits: 0,
            num_bits: bit_count,
            params: &[],
        },
    );

    let mut cl_expr = tket_json_rs::clexpr::ClExpr::default();
    cl_expr.expr = expression;
    cl_expr.bit_posn = (0..bit_count as u32).map(|idx| (idx, idx)).collect();
    cl_expr.reg_posn = registers.to_vec();
    cl_expr.output_posn = tket_json_rs::clexpr::ClRegisterBits(output_bits.to_vec());

    operation.classical_expr = Some(cl_expr);
    operation
}