mod op;
mod param;
use std::collections::{HashMap, HashSet};
use hugr::builder::{Container, Dataflow, DataflowHugr, FunctionBuilder};
use hugr::extension::prelude::{bool_t, qb_t};
use hugr::ops::handle::NodeHandle;
use hugr::ops::{OpType, Value};
use hugr::std_extensions::arithmetic::float_types::ConstF64;
use hugr::types::Signature;
use hugr::{Hugr, Wire};
use derive_more::Display;
use indexmap::IndexMap;
use itertools::{EitherOrBoth, Itertools};
use serde_json::json;
use tket_json_rs::circuit_json;
use tket_json_rs::circuit_json::SerialCircuit;
use tket_json_rs::register;
use super::{
OpConvertError, RegisterHash, Tk1ConvertError, METADATA_B_OUTPUT_REGISTERS,
METADATA_B_REGISTERS, METADATA_OPGROUP, METADATA_PHASE, METADATA_Q_OUTPUT_REGISTERS,
METADATA_Q_REGISTERS,
};
use crate::extension::rotation::{rotation_type, RotationOp};
use crate::serialize::pytket::METADATA_INPUT_PARAMETERS;
use crate::symbolic_constant_op;
use op::Tk1Op;
use param::{parse_pytket_param, PytketParam};
/// State used while decoding a pytket [`SerialCircuit`] into a HUGR function.
///
/// Tracks, for every qubit/bit register of the circuit, the wire that
/// currently carries its value, so that decoded commands can be chained.
#[derive(Debug, Clone)]
pub(super) struct Tk1DecoderContext {
    /// Builder for the HUGR function being constructed.
    pub hugr: FunctionBuilder<Hugr>,
    /// The wire currently carrying the value of each register.
    /// Entries are overwritten whenever a decoded command produces a new
    /// output for that register.
    register_wires: HashMap<RegisterHash, Wire>,
    /// All registers, qubits first and then bits, in the order of the
    /// function's inputs and outputs.
    ordered_registers: Vec<RegisterHash>,
    /// The registers that hold qubits (every other register is a bit).
    qubit_registers: HashSet<RegisterHash>,
    /// Free-variable parameters that were added as extra function inputs,
    /// keyed by variable name, in the order they were first encountered.
    parameters: IndexMap<String, LoadedParameter>,
}
impl Tk1DecoderContext {
    /// Initialise a decoder context for a serialised pytket circuit.
    ///
    /// The function being built has an endomorphic signature. It has one qubit
    /// per entry of `serialcirc.qubits`, followed by one `bool` per entry of
    /// `serialcirc.bits`.
    ///
    /// The following are stored as metadata on the function:
    /// - the circuit's phase;
    /// - its qubit and bit registers;
    /// - the output registers derived from `implicit_permutation`.
    ///
    /// Every register is initially associated with its corresponding input wire.
    ///
    /// # Errors
    ///
    /// Returns [`Tk1ConvertError::MultiIndexedRegister`] if a qubit or bit
    /// register does not have exactly one index.
    pub fn try_new(serialcirc: &SerialCircuit) -> Result<Self, Tk1ConvertError> {
        let num_qubits = serialcirc.qubits.len();
        let num_bits = serialcirc.bits.len();
        let sig =
            Signature::new_endo([vec![qb_t(); num_qubits], vec![bool_t(); num_bits]].concat());

        let name = serialcirc.name.clone().unwrap_or_default();
        let mut dfg = FunctionBuilder::new(name, sig).unwrap();
        let dangling_wires = dfg.input_wires().collect::<Vec<_>>();

        dfg.set_metadata(METADATA_PHASE, json!(serialcirc.phase));
        dfg.set_metadata(METADATA_Q_REGISTERS, json!(serialcirc.qubits));
        dfg.set_metadata(METADATA_B_REGISTERS, json!(serialcirc.bits));

        // Lookup from the second element of each permutation pair to the first.
        // Registers not mentioned in the permutation map to themselves below.
        // Only the ids are cloned, not the whole permutation elements.
        let output_to_input: HashMap<register::ElementId, register::ElementId> = serialcirc
            .implicit_permutation
            .iter()
            .map(|p| (p.1.id.clone(), p.0.id.clone()))
            .collect();
        let output_qubits = serialcirc
            .qubits
            .iter()
            .map(|qubit| output_to_input.get(&qubit.id).unwrap_or(&qubit.id).clone())
            .collect_vec();
        let output_bits = serialcirc
            .bits
            .iter()
            .map(|bit| output_to_input.get(&bit.id).unwrap_or(&bit.id).clone())
            .collect_vec();
        dfg.set_metadata(METADATA_Q_OUTPUT_REGISTERS, json!(output_qubits));
        dfg.set_metadata(METADATA_B_OUTPUT_REGISTERS, json!(output_bits));

        let qubit_registers = serialcirc.qubits.iter().map(RegisterHash::from).collect();

        // Qubits first, then bits: this matches the order of the signature above.
        let ordered_registers = serialcirc
            .qubits
            .iter()
            .map(|qb| &qb.id)
            .chain(serialcirc.bits.iter().map(|bit| &bit.id))
            .map(|reg| {
                check_register(reg)?;
                Ok(RegisterHash::from(reg))
            })
            .collect::<Result<Vec<RegisterHash>, Tk1ConvertError>>()?;

        // Each register starts out carried by its function input wire.
        let register_wires: HashMap<RegisterHash, Wire> = ordered_registers
            .iter()
            .copied()
            .zip(dangling_wires)
            .collect();

        Ok(Tk1DecoderContext {
            hugr: dfg,
            register_wires,
            ordered_registers,
            qubit_registers,
            parameters: IndexMap::new(),
        })
    }

    /// Finish building the function and return the resulting [`Hugr`].
    ///
    /// The current wire of every register is connected to the function's
    /// outputs, in register order. If any free-variable parameters were added
    /// as inputs, their names are recorded under
    /// `METADATA_INPUT_PARAMETERS`.
    ///
    /// # Panics
    ///
    /// Panics if a register has no associated wire, or if the builder fails to
    /// finish the function.
    pub fn finish(mut self) -> Hugr {
        let mut outputs = Vec::with_capacity(self.ordered_registers.len());
        for register in self.ordered_registers {
            let wire = self.register_wires.remove(&register).unwrap();
            outputs.push(wire);
        }
        debug_assert!(
            self.register_wires.is_empty(),
            "Some output wires were not associated with a register."
        );

        if !self.parameters.is_empty() {
            let params = self.parameters.keys().cloned().collect_vec();
            self.hugr
                .set_metadata(METADATA_INPUT_PARAMETERS, json!(params));
        }

        self.hugr.finish_hugr_with_outputs(outputs).unwrap()
    }

    /// Decode a single pytket command and append the operation to the function.
    ///
    /// The leading arguments that are known qubit registers are counted as
    /// qubit arguments. The remaining arguments are counted as bits. Once the
    /// operation is added, each output register is re-associated with the
    /// operation's corresponding output wire. If the command has an
    /// `opgroup`, it is stored as metadata on the new node.
    ///
    /// # Errors
    ///
    /// Returns an [`OpConvertError`] if the command is missing arguments or
    /// parameters required by the operation.
    ///
    /// # Panics
    ///
    /// Panics if the operation cannot be added to the builder, or if the
    /// number of output wires does not match the number of output registers.
    pub fn add_command(&mut self, command: circuit_json::Command) -> Result<(), OpConvertError> {
        let circuit_json::Command {
            op, args, opgroup, ..
        } = command;
        // Cloned because `op` is consumed by `from_serialised_op` below.
        let op_params = op.params.clone().unwrap_or_default();

        let num_qubits = args
            .iter()
            .take_while(|&arg| self.is_qubit_register(arg))
            .count();
        let num_input_bits = args.len() - num_qubits;
        let tk1op = Tk1Op::from_serialised_op(op, num_qubits, num_input_bits);

        let (input_wires, output_registers) = self.get_op_wires(&tk1op, &args, op_params)?;
        let op: OpType = (&tk1op).into();

        let new_op = self.hugr.add_dataflow_op(op, input_wires).unwrap();
        let wires = new_op.outputs();

        if let Some(opgroup) = opgroup {
            self.hugr
                .set_child_metadata(new_op.node(), METADATA_OPGROUP, json!(opgroup));
        }

        for (register, wire) in output_registers.into_iter().zip_eq(wires) {
            self.set_register_wire(register, wire);
        }
        Ok(())
    }

    /// Compute the input wires and output registers of an operation.
    ///
    /// Arguments are consumed in order:
    /// 1. One argument per qubit input; each is both an input and an output.
    /// 2. Bit arguments, where the input and output bit counts are aligned
    ///    pairwise:
    ///    - a bit that is both input and output contributes a wire and a register;
    ///    - an input-only bit contributes a wire;
    ///    - an output-only bit contributes a register.
    ///
    /// Finally, the operation's parameters are loaded as wires and appended
    /// to the inputs.
    ///
    /// # Errors
    ///
    /// - [`OpConvertError::MissingSerialisedArguments`] if `args` is too short.
    /// - [`OpConvertError::MissingSerialisedParams`] if fewer parameters than
    ///   `tk1op.num_params()` are given.
    ///
    /// # Panics
    ///
    /// - Panics if the operation has different numbers of qubit inputs and
    ///   qubit outputs.
    /// - Panics if an argument refers to a register without an associated wire.
    fn get_op_wires(
        &mut self,
        tk1op: &Tk1Op,
        args: &[register::ElementId],
        params: Vec<String>,
    ) -> Result<(Vec<Wire>, Vec<RegisterHash>), OpConvertError> {
        let mut inputs: Vec<Wire> = Vec::with_capacity(args.len() + params.len());
        let mut outputs: Vec<RegisterHash> =
            Vec::with_capacity(tk1op.qubit_outputs() + tk1op.bit_outputs());

        // Yields the next unconsumed argument, or an error once they run out.
        let mut current_arg = 0;
        let mut next_arg = || {
            if args.len() <= current_arg {
                return Err(OpConvertError::MissingSerialisedArguments {
                    optype: tk1op.optype(),
                    expected_qubits: tk1op.qubit_inputs(),
                    expected_bits: tk1op.bit_inputs(),
                    args: args.to_owned(),
                });
            }
            current_arg += 1;
            Ok(&args[current_arg - 1])
        };

        assert_eq!(
            tk1op.qubit_inputs(),
            tk1op.qubit_outputs(),
            "Operations with different numbers of input and output qubits are not currently supported."
        );

        for _ in 0..tk1op.qubit_inputs() {
            let reg = next_arg()?;
            inputs.push(self.register_wire(reg));
            outputs.push(reg.into());
        }

        for zip in (0..tk1op.bit_inputs()).zip_longest(0..tk1op.bit_outputs()) {
            let reg = next_arg()?;
            match zip {
                EitherOrBoth::Both(_inp, _out) => {
                    inputs.push(self.register_wire(reg));
                    outputs.push(reg.into());
                }
                EitherOrBoth::Left(_inp) => {
                    inputs.push(self.register_wire(reg));
                }
                EitherOrBoth::Right(_out) => {
                    outputs.push(reg.into());
                }
            }
        }

        if tk1op.num_params() > params.len() {
            return Err(OpConvertError::MissingSerialisedParams {
                optype: tk1op.optype(),
                expected: tk1op.num_params(),
                params,
            });
        }
        // Surplus parameters beyond the op's parameter ports are ignored by `zip`.
        inputs.extend(
            tk1op
                .param_ports()
                .zip(params)
                .map(|(_port, param)| self.load_parameter(param)),
        );

        Ok((inputs, outputs))
    }

    /// Decode a pytket parameter string and return a rotation-typed wire for it.
    ///
    /// The parsed parameter is lowered as follows:
    /// - numeric constants become loaded `ConstF64` values;
    /// - the variable `pi` becomes a loaded `ConstF64` of [`std::f64::consts::PI`];
    /// - any other free variable becomes a new rotation-typed function input,
    ///   reused if the same name is seen again;
    /// - sympy expressions become a symbolic-constant operation;
    /// - operations are added with their arguments recursively lowered and
    ///   converted to floats.
    ///
    /// The final result is converted to a rotation if it is not already one.
    ///
    /// # Panics
    ///
    /// Panics if a decoded operation cannot be added to the builder, or if it
    /// does not have exactly one output.
    fn load_parameter(&mut self, param: String) -> Wire {
        /// Recursively lower a parsed parameter into the builder.
        fn process(
            hugr: &mut FunctionBuilder<Hugr>,
            input_params: &mut IndexMap<String, LoadedParameter>,
            parsed: PytketParam,
            param: &str,
        ) -> LoadedParameter {
            match parsed {
                PytketParam::Constant(half_turns) => {
                    let value: Value = ConstF64::new(half_turns).into();
                    let wire = hugr.add_load_const(value);
                    LoadedParameter::float(wire)
                }
                PytketParam::Sympy(expr) => {
                    let symb_op = symbolic_constant_op(expr.to_string());
                    let wire = hugr.add_dataflow_op(symb_op, []).unwrap().out_wire(0);
                    LoadedParameter::rotation(wire)
                }
                PytketParam::InputVariable { name } => {
                    if name == "pi" {
                        let value: Value = ConstF64::new(std::f64::consts::PI).into();
                        let wire = hugr.add_load_const(value);
                        return LoadedParameter::float(wire);
                    }
                    // Add a new function input the first time a variable is seen.
                    *input_params.entry(name.to_string()).or_insert_with(|| {
                        let wire = hugr.add_input(rotation_type());
                        LoadedParameter::rotation(wire)
                    })
                }
                PytketParam::Operation { op, args } => {
                    let input_wires = args
                        .into_iter()
                        .map(|arg| process(hugr, input_params, arg, param).as_float(hugr).wire)
                        .collect_vec();
                    let res = hugr.add_dataflow_op(op, input_wires).unwrap_or_else(|e| {
                        panic!("Error while decoding pytket operation parameter \"{param}\". {e}",)
                    });
                    assert_eq!(
                        res.num_value_outputs(),
                        1,
                        "An operation decoded from the pytket op parameter \"{param}\" had {} outputs",
                        res.num_value_outputs()
                    );
                    LoadedParameter::float(res.out_wire(0))
                }
            }
        }

        let parsed = parse_pytket_param(&param);
        process(&mut self.hugr, &mut self.parameters, parsed, &param)
            .as_rotation(&mut self.hugr)
            .wire
    }

    /// Return the wire currently carrying the value of `register`.
    ///
    /// # Panics
    ///
    /// Panics if the register is unknown to this context.
    fn register_wire(&self, register: impl Into<RegisterHash>) -> Wire {
        self.register_wires[&register.into()]
    }

    /// Associate `register` with `unit`, replacing any previous wire.
    fn set_register_wire(&mut self, register: impl Into<RegisterHash>, unit: Wire) {
        self.register_wires.insert(register.into(), unit);
    }

    /// Whether `register` is one of the circuit's qubit registers.
    fn is_qubit_register(&self, register: impl Into<RegisterHash>) -> bool {
        self.qubit_registers.contains(&register.into())
    }
}
fn check_register(register: ®ister::ElementId) -> Result<(), Tk1ConvertError> {
if register.1.len() != 1 {
Err(Tk1ConvertError::MultiIndexedRegister {
register: register.0.clone(),
})
} else {
Ok(())
}
}
/// The HUGR type of the wire carrying a decoded pytket parameter.
#[derive(Debug, Display, Clone, Copy, Hash, PartialEq, Eq)]
enum LoadedParameterType {
    /// A float-typed wire, e.g. a loaded `ConstF64` or the result of `RotationOp::to_halfturns`.
    Float,
    /// A wire of the rotation extension's rotation type.
    Rotation,
}
/// A decoded parameter: a wire in the function being built, tagged with its type.
#[derive(Debug, Clone, Copy, Hash, PartialEq, Eq)]
struct LoadedParameter {
    /// The type of value carried by `wire`.
    pub typ: LoadedParameterType,
    /// The wire carrying the parameter value.
    pub wire: Wire,
}
impl LoadedParameter {
    /// Wrap a wire that carries a float value.
    pub fn float(wire: Wire) -> LoadedParameter {
        LoadedParameter {
            typ: LoadedParameterType::Float,
            wire,
        }
    }

    /// Wrap a wire that carries a rotation value.
    pub fn rotation(wire: Wire) -> LoadedParameter {
        LoadedParameter {
            typ: LoadedParameterType::Rotation,
            wire,
        }
    }

    /// Return this parameter as a wire of type `typ`.
    ///
    /// If the parameter already has type `typ`, it is returned unchanged.
    /// Otherwise a conversion operation is added to `hugr`:
    /// - `RotationOp::from_halfturns_unchecked` for float → rotation;
    /// - `RotationOp::to_halfturns` for rotation → float.
    ///
    /// # Panics
    ///
    /// Panics if the conversion operation cannot be added to the builder.
    pub fn as_type(
        &self,
        typ: LoadedParameterType,
        hugr: &mut FunctionBuilder<Hugr>,
    ) -> LoadedParameter {
        match (self.typ, typ) {
            (LoadedParameterType::Float, LoadedParameterType::Rotation) => {
                let wire = hugr
                    .add_dataflow_op(RotationOp::from_halfturns_unchecked, [self.wire])
                    .unwrap()
                    .out_wire(0);
                LoadedParameter::rotation(wire)
            }
            (LoadedParameterType::Rotation, LoadedParameterType::Float) => {
                let wire = hugr
                    .add_dataflow_op(RotationOp::to_halfturns, [self.wire])
                    .unwrap()
                    .out_wire(0);
                LoadedParameter::float(wire)
            }
            // Already the requested type. The pairs are listed explicitly, not
            // with `_`, so a new variant forces this conversion table to be
            // revisited at compile time.
            (LoadedParameterType::Float, LoadedParameterType::Float)
            | (LoadedParameterType::Rotation, LoadedParameterType::Rotation) => *self,
        }
    }

    /// Return this parameter as a float wire, converting it if needed.
    pub fn as_float(&self, hugr: &mut FunctionBuilder<Hugr>) -> LoadedParameter {
        self.as_type(LoadedParameterType::Float, hugr)
    }

    /// Return this parameter as a rotation wire, converting it if needed.
    pub fn as_rotation(&self, hugr: &mut FunctionBuilder<Hugr>) -> LoadedParameter {
        self.as_type(LoadedParameterType::Rotation, hugr)
    }
}