use std::borrow::Cow;
use std::collections::{BTreeMap, BTreeSet, HashMap, HashSet};
use hugr::core::HugrNode;
use hugr::ops::OpParent;
use hugr::{HugrView, Wire};
use itertools::Itertools;
use tket_json_rs::circuit_json;
use tket_json_rs::register::ElementId as RegisterUnit;
use crate::circuit::Circuit;
use crate::serialize::pytket::{
OpConvertError, RegisterHash, Tk1ConvertError, METADATA_B_REGISTERS, METADATA_INPUT_PARAMETERS,
};
use super::unit_generator::RegisterUnitGenerator;
use super::{
Tk1EncoderConfig, METADATA_B_OUTPUT_REGISTERS, METADATA_Q_OUTPUT_REGISTERS,
METADATA_Q_REGISTERS,
};
/// Bookkeeping for the values carried by the wires of a region while it is
/// being encoded.
///
/// Holds the qubit and bit register units, the parameter expressions, and a
/// map from each registered wire to the values associated with it.
#[derive(derive_more::Debug, Clone)]
#[debug(bounds(N: std::fmt::Debug))]
pub struct ValueTracker<N> {
/// Qubit register units, indexed by [`TrackedQubit`].
qubits: Vec<RegisterUnit>,
/// Bit register units, indexed by [`TrackedBit`].
bits: Vec<RegisterUnit>,
/// Parameter expressions, indexed by [`TrackedParam`].
params: Vec<String>,
/// Values associated with each wire that is currently registered.
wires: BTreeMap<Wire<N>, TrackedWire>,
/// Qubit registers read from the `METADATA_Q_OUTPUT_REGISTERS` metadata entry.
output_qubits: Vec<RegisterUnit>,
/// Bit registers read from the `METADATA_B_OUTPUT_REGISTERS` metadata entry.
// Populated on construction but not read by the code visible here, hence the `allow`.
#[allow(unused)]
output_bits: Vec<RegisterUnit>,
/// Tracked qubits that `new_qubit` may hand out. Registering a wire removes
/// its qubits from this set; unregistering the wire puts them back.
unused_qubits: BTreeSet<TrackedQubit>,
/// Tracked bits that `new_bit` may hand out. Registering a wire removes its
/// bits from this set; unregistering the wire puts them back.
unused_bits: BTreeSet<TrackedBit>,
/// Source of fresh qubit register units, used when `unused_qubits` is empty.
qubit_reg_generator: RegisterUnitGenerator,
/// Source of fresh bit register units, used when `unused_bits` is empty.
bit_reg_generator: RegisterUnitGenerator,
}
/// Identifier of a qubit known to a [`ValueTracker`].
///
/// Wraps the index of the qubit in the tracker's qubit register list.
/// Displayed as `qubit#<index>`.
#[derive(
Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, Default, derive_more::Display,
)]
#[display("qubit#{}", self.0)]
pub struct TrackedQubit(usize);
/// Identifier of a bit known to a [`ValueTracker`].
///
/// Wraps the index of the bit in the tracker's bit register list.
/// Displayed as `bit#<index>`.
#[derive(
Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, Default, derive_more::Display,
)]
#[display("bit#{}", self.0)]
pub struct TrackedBit(usize);
/// Identifier of a parameter expression known to a [`ValueTracker`].
///
/// Wraps the index of the expression in the tracker's parameter list.
/// Displayed as `param#<index>`.
#[derive(
Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, Default, derive_more::Display,
)]
#[display("param#{}", self.0)]
pub struct TrackedParam(usize);
/// A single value that can be associated with a wire: a qubit, a bit, or a
/// parameter.
///
/// The `derive_more::From` derive lets each of the three identifier types be
/// converted into this enum.
#[derive(
Debug,
Clone,
Copy,
PartialEq,
Eq,
PartialOrd,
Ord,
Hash,
derive_more::From,
derive_more::Display,
)]
#[non_exhaustive]
pub enum TrackedValue {
/// A tracked qubit.
Qubit(TrackedQubit),
/// A tracked bit.
Bit(TrackedBit),
/// A tracked parameter expression.
Param(TrackedParam),
}
/// A collection of tracked values, grouped by kind.
///
/// Iterating over it yields the qubits first, then the bits, then the
/// parameters.
#[derive(Debug, Clone, Default, PartialEq, Eq, Hash)]
#[non_exhaustive]
pub struct TrackedValues {
/// Tracked qubits, in order.
pub qubits: Vec<TrackedQubit>,
/// Tracked bits, in order.
pub bits: Vec<TrackedBit>,
/// Tracked parameters, in order.
pub params: Vec<TrackedParam>,
}
/// Per-wire entry stored by the [`ValueTracker`].
#[derive(Debug, Clone)]
struct TrackedWire {
/// Values associated with the wire.
///
/// NOTE(review): the code visible here always stores `Some`; `None`
/// presumably marks a wire without tracked values — confirm against callers.
pub(self) values: Option<Vec<TrackedValue>>,
/// Number of ports linked to the wire's source that have not yet consumed
/// its values. Initialised from the linked-port count when the wire is
/// registered and decremented each time the values are read via `wire_values`.
pub(self) unexplored_neighbours: usize,
}
/// Number of qubits, bits and parameters.
///
/// Displayed as `"<qubits> qubits, <bits> bits, <params> parameters"`.
/// Addition, subtraction and summation are provided by the `derive_more`
/// derives.
#[derive(
Clone,
Copy,
PartialEq,
Eq,
Hash,
Debug,
Default,
derive_more::Display,
derive_more::Add,
derive_more::Sub,
derive_more::Sum,
)]
#[display("{qubits} qubits, {bits} bits, {params} parameters")]
#[non_exhaustive]
pub struct RegisterCount {
/// Number of qubits.
pub qubits: usize,
/// Number of bits.
pub bits: usize,
/// Number of parameters.
pub params: usize,
}
/// Data produced when a [`ValueTracker`] is finished.
#[derive(Debug, Clone)]
pub struct ValueTrackerResult {
/// Qubit registers at the region's outputs, followed by any further qubit
/// registers added while computing the final permutation.
pub qubits: Vec<RegisterUnit>,
/// Bit registers found at the region's outputs.
pub bits: Vec<RegisterUnit>,
/// Parameter expressions found at the region's outputs.
pub params: Vec<String>,
/// Implicit qubit permutation computed from the output registers.
pub qubit_permutation: Vec<circuit_json::ImplicitPermutation>,
}
impl<N: HugrNode> ValueTracker<N> {
pub(super) fn new<H: HugrView<Node = N>>(
circ: &Circuit<H>,
region: N,
config: &Tk1EncoderConfig<H>,
) -> Result<Self, Tk1ConvertError<N>> {
let param_variable_names: Vec<String> =
read_metadata_json_list(circ, region, METADATA_INPUT_PARAMETERS);
let mut tracker = ValueTracker {
qubits: read_metadata_json_list(circ, region, METADATA_Q_REGISTERS),
bits: read_metadata_json_list(circ, region, METADATA_B_REGISTERS),
params: Vec::with_capacity(param_variable_names.len()),
wires: BTreeMap::new(),
output_qubits: read_metadata_json_list(circ, region, METADATA_Q_OUTPUT_REGISTERS),
output_bits: read_metadata_json_list(circ, region, METADATA_B_OUTPUT_REGISTERS),
unused_qubits: BTreeSet::new(),
unused_bits: BTreeSet::new(),
qubit_reg_generator: RegisterUnitGenerator::default(),
bit_reg_generator: RegisterUnitGenerator::default(),
};
if !tracker.output_qubits.is_empty() {
let inputs: HashSet<_> = tracker.qubits.iter().cloned().collect();
for q in &tracker.output_qubits {
if !inputs.contains(q) {
tracker.qubits.push(q.clone());
}
}
}
tracker.unused_qubits = (0..tracker.qubits.len()).map(TrackedQubit).collect();
tracker.unused_bits = (0..tracker.bits.len()).map(TrackedBit).collect();
tracker.qubit_reg_generator = RegisterUnitGenerator::new("q", tracker.qubits.iter());
tracker.bit_reg_generator = RegisterUnitGenerator::new("c", tracker.bits.iter());
let existing_param_vars: HashSet<String> = param_variable_names.iter().cloned().collect();
let mut param_gen = param_variable_names.into_iter().chain(
(0..)
.map(|i| format!("f{i}"))
.filter(|name| !existing_param_vars.contains(name)),
);
let region_optype = circ.hugr().get_optype(region);
let signature = region_optype.inner_function_type().ok_or_else(|| {
let optype = circ.hugr().get_optype(region).to_string();
Tk1ConvertError::NonDataflowRegion { region, optype }
})?;
let inp_node = circ.hugr().get_io(region).unwrap()[0];
for (port, typ) in circ
.hugr()
.node_outputs(inp_node)
.zip(signature.input().iter())
{
let wire = Wire::new(inp_node, port);
let Some(count) = config.type_to_pytket(typ)? else {
tracker.register_wire::<TrackedValue>(wire, [], circ)?;
continue;
};
let mut wire_values = Vec::with_capacity(count.total());
for _ in 0..count.qubits {
let qb = tracker.new_qubit();
wire_values.push(TrackedValue::Qubit(qb));
}
for _ in 0..count.bits {
let bit = tracker.new_bit();
wire_values.push(TrackedValue::Bit(bit));
}
for _ in 0..count.params {
let param = tracker.new_param(param_gen.next().unwrap());
wire_values.push(TrackedValue::Param(param));
}
tracker.register_wire(wire, wire_values, circ)?;
}
Ok(tracker)
}
/// Returns a qubit that is free to use.
///
/// The smallest entry of the unused-qubit set is taken if there is one;
/// otherwise a new register unit is generated and appended to the qubit list.
pub fn new_qubit(&mut self) -> TrackedQubit {
    if let Some(free) = self.unused_qubits.pop_first() {
        return free;
    }
    let register = self.qubit_reg_generator.next();
    self.qubits.push(register);
    TrackedQubit(self.qubits.len() - 1)
}
/// Returns a bit that is free to use.
///
/// The smallest entry of the unused-bit set is taken if there is one;
/// otherwise a new register unit is generated and appended to the bit list.
pub fn new_bit(&mut self) -> TrackedBit {
    if let Some(free) = self.unused_bits.pop_first() {
        return free;
    }
    let register = self.bit_reg_generator.next();
    self.bits.push(register);
    TrackedBit(self.bits.len() - 1)
}
/// Stores the string form of `expression` as a new parameter and returns its
/// identifier.
pub fn new_param(&mut self, expression: impl ToString) -> TrackedParam {
    let text = expression.to_string();
    let index = self.params.len();
    self.params.push(text);
    TrackedParam(index)
}
/// Associates `values` with `wire`.
///
/// Qubits and bits in `values` are removed from the unused sets. The wire
/// stays registered until its values have been read once per port linked to
/// the wire's source; a wire with no linked ports is unregistered
/// immediately, which returns its qubits and bits to the unused sets.
///
/// # Errors
///
/// Returns [`OpConvertError::WireAlreadyHasValues`] if `wire` is already
/// registered. In that case the tracker is left unchanged.
pub fn register_wire<Val: Into<TrackedValue>>(
    &mut self,
    wire: Wire<N>,
    values: impl IntoIterator<Item = Val>,
    circ: &Circuit<impl HugrView<Node = N>>,
) -> Result<(), OpConvertError<N>> {
    let values: Vec<TrackedValue> = values.into_iter().map(Into::into).collect();

    // Reject duplicates before touching any state, so a failed registration
    // neither clobbers the existing entry nor marks registers as used.
    if self.wires.contains_key(&wire) {
        return Err(OpConvertError::WireAlreadyHasValues { wire });
    }

    for value in &values {
        match value {
            TrackedValue::Qubit(qb) => {
                self.unused_qubits.remove(qb);
            }
            TrackedValue::Bit(bit) => {
                self.unused_bits.remove(bit);
            }
            TrackedValue::Param(_) => {}
        }
    }

    let unexplored_neighbours = circ.hugr().linked_ports(wire.node(), wire.source()).count();
    let tracked = TrackedWire {
        values: Some(values),
        unexplored_neighbours,
    };
    self.wires.insert(wire, tracked);

    if unexplored_neighbours == 0 {
        // Nobody will ever read this wire; release its registers right away.
        self.unregister_wire(wire)
            .expect("Wire should be registered in the tracker");
    }
    Ok(())
}
/// Reads the values associated with `wire`, counting the read against the
/// wire's remaining neighbours.
///
/// When exactly one neighbour is left, the wire is unregistered and its
/// values are returned owned. Otherwise the remaining-neighbour counter is
/// decremented and the values are returned borrowed.
///
/// Returns `None` if the wire is not registered or has no values.
pub(super) fn wire_values(&mut self, wire: Wire<N>) -> Option<Cow<'_, [TrackedValue]>> {
    let remaining = self.wires.get(&wire)?.unexplored_neighbours;
    if remaining == 1 {
        // Last reader: hand over ownership and drop the entry.
        return self.unregister_wire(wire).map(Cow::Owned);
    }
    // The entry was found just above, so the second lookup cannot fail.
    let tracked = self.wires.get_mut(&wire).unwrap();
    tracked.unexplored_neighbours -= 1;
    tracked.values.as_deref().map(Cow::Borrowed)
}
/// Returns the values associated with `wire` without changing the tracker.
///
/// Returns `None` if the wire is not registered or has no values.
pub(super) fn peek_wire_values(&self, wire: Wire<N>) -> Option<&[TrackedValue]> {
    self.wires.get(&wire)?.values.as_deref()
}
fn unregister_wire(&mut self, wire: Wire<N>) -> Option<Vec<TrackedValue>> {
let wire = self.wires.remove(&wire).unwrap();
let values = wire.values?;
for value in &values {
match value {
TrackedValue::Qubit(qb) => {
self.unused_qubits.insert(*qb);
}
TrackedValue::Bit(bit) => {
self.unused_bits.insert(*bit);
}
TrackedValue::Param(_) => {}
}
}
Some(values)
}
pub fn qubit_register(&self, qb: TrackedQubit) -> &RegisterUnit {
&self.qubits[qb.0]
}
pub fn bit_register(&self, bit: TrackedBit) -> &RegisterUnit {
&self.bits[bit.0]
}
pub fn param_expression(&self, param: TrackedParam) -> &str {
&self.params[param.0]
}
/// Consumes the tracker and collects the registers and parameter expressions
/// found on the wires connected to the output node of `region`.
///
/// If fewer qubits reach the output than were declared in the output-qubit
/// list, the difference is filled with unused qubits, lowest index first.
/// The qubit list and the implicit permutation are then produced by
/// `compute_final_permutation`.
///
/// # Errors
///
/// Returns [`OpConvertError::WireHasNoValues`] if a wire connected to the
/// output node has no values in the tracker.
///
/// # Panics
///
/// Panics if the region has no input/output nodes.
pub(super) fn finish(
    self,
    circ: &Circuit<impl HugrView<Node = N>>,
    region: N,
) -> Result<ValueTrackerResult, OpConvertError<N>> {
    let hugr = circ.hugr();
    let output_node = hugr.get_io(region).unwrap()[1];

    let mut qubits = Vec::with_capacity(self.qubits.len() - self.unused_qubits.len());
    let mut bits = Vec::with_capacity(self.bits.len() - self.unused_bits.len());
    let mut params = Vec::new();

    for (node, port) in hugr.all_linked_outputs(output_node) {
        let wire = Wire::new(node, port);
        let Some(values) = self.peek_wire_values(wire) else {
            return Err(OpConvertError::WireHasNoValues { wire });
        };
        for &value in values {
            match value {
                TrackedValue::Qubit(qb) => qubits.push(self.qubit_register(qb).clone()),
                TrackedValue::Bit(bit) => bits.push(self.bit_register(bit).clone()),
                TrackedValue::Param(param) => {
                    params.push(self.param_expression(param).to_string())
                }
            }
        }
    }

    // Pad with unused qubits up to the declared number of output qubits.
    let missing = self.output_qubits.len().saturating_sub(qubits.len());
    let padding = self
        .unused_qubits
        .iter()
        .take(missing)
        .map(|&qb| self.qubit_register(qb).clone());
    qubits.extend(padding);

    let (qubits, qubit_permutation) =
        compute_final_permutation(qubits, &self.qubits, &self.output_qubits);

    Ok(ValueTrackerResult {
        qubits,
        bits,
        params,
        qubit_permutation,
    })
}
}
impl TrackedValues {
    /// Creates a collection holding only the given qubits.
    pub fn new_qubits(qubits: impl IntoIterator<Item = TrackedQubit>) -> Self {
        Self {
            qubits: qubits.into_iter().collect(),
            bits: Vec::new(),
            params: Vec::new(),
        }
    }

    /// Creates a collection holding only the given bits.
    pub fn new_bits(bits: impl IntoIterator<Item = TrackedBit>) -> Self {
        Self {
            qubits: Vec::new(),
            bits: bits.into_iter().collect(),
            params: Vec::new(),
        }
    }

    /// Creates a collection holding only the given parameters.
    pub fn new_params(params: impl IntoIterator<Item = TrackedParam>) -> Self {
        Self {
            qubits: Vec::new(),
            bits: Vec::new(),
            params: params.into_iter().collect(),
        }
    }

    /// Returns how many qubits, bits and parameters the collection holds.
    pub fn count(&self) -> RegisterCount {
        RegisterCount::new(self.qubits.len(), self.bits.len(), self.params.len())
    }

    /// Iterates over all values: qubits first, then bits, then parameters.
    pub fn iter(&self) -> impl Iterator<Item = TrackedValue> + '_ {
        let qubits = self.qubits.iter().copied().map(TrackedValue::Qubit);
        let bits = self.bits.iter().copied().map(TrackedValue::Bit);
        let params = self.params.iter().copied().map(TrackedValue::Param);
        qubits.chain(bits).chain(params)
    }

    /// Appends the values of `other`, kind by kind, after the existing ones.
    pub fn append(&mut self, other: TrackedValues) {
        self.qubits.extend(other.qubits);
        self.bits.extend(other.bits);
        self.params.extend(other.params);
    }
}
/// Consuming iteration: qubits first, then bits, then parameters, each
/// converted into a [`TrackedValue`].
impl IntoIterator for TrackedValues {
    type Item = TrackedValue;
    type IntoIter = std::iter::Chain<
        std::iter::Chain<
            itertools::MapInto<std::vec::IntoIter<TrackedQubit>, TrackedValue>,
            itertools::MapInto<std::vec::IntoIter<TrackedBit>, TrackedValue>,
        >,
        itertools::MapInto<std::vec::IntoIter<TrackedParam>, TrackedValue>,
    >;

    fn into_iter(self) -> Self::IntoIter {
        let Self {
            qubits,
            bits,
            params,
        } = self;
        let qubits = qubits.into_iter().map_into::<TrackedValue>();
        let bits = bits.into_iter().map_into::<TrackedValue>();
        let params = params.into_iter().map_into::<TrackedValue>();
        qubits.chain(bits).chain(params)
    }
}
impl RegisterCount {
pub const fn new(qubits: usize, bits: usize, params: usize) -> Self {
RegisterCount {
qubits,
bits,
params,
}
}
pub const fn only_qubits(qubits: usize) -> Self {
RegisterCount {
qubits,
bits: 0,
params: 0,
}
}
pub const fn only_bits(bits: usize) -> Self {
RegisterCount {
qubits: 0,
bits,
params: 0,
}
}
pub const fn only_params(params: usize) -> Self {
RegisterCount {
qubits: 0,
bits: 0,
params,
}
}
pub const fn total(&self) -> usize {
self.qubits + self.bits + self.params
}
}
/// Reads the metadata entry `metadata_key` of `region` as a JSON list of `T`.
///
/// Returns an empty vector when the entry is absent or when it cannot be
/// deserialized as a list of `T`.
fn read_metadata_json_list<T: serde::de::DeserializeOwned, H: HugrView>(
    circ: &Circuit<H>,
    region: H::Node,
    metadata_key: &str,
) -> Vec<T> {
    circ.hugr()
        .get_metadata(region, metadata_key)
        .and_then(|raw| serde_json::from_value::<Vec<T>>(raw.clone()).ok())
        .unwrap_or_default()
}
/// Completes the list of output registers and computes the implicit
/// permutation between actual and declared outputs.
///
/// Both output lists are first extended with every register of `all_inputs`
/// they do not already contain, in input order. Then, for each register
/// `reg` of the extended `actual_outputs`, the index `i` of `reg` in
/// `all_inputs` is looked up and `reg` is paired with entry `i` of the
/// extended declared outputs.
///
/// Returns the extended `actual_outputs` together with one
/// `ImplicitPermutation(reg, declared)` pair per entry.
///
/// # Panics
///
/// Panics if `actual_outputs` contains a register that is not in
/// `all_inputs`, or if the looked-up index is out of range for the extended
/// declared outputs (possible when `all_inputs` contains repeated registers).
pub(super) fn compute_final_permutation(
    mut actual_outputs: Vec<RegisterUnit>,
    all_inputs: &[RegisterUnit],
    declared_outputs: &[RegisterUnit],
) -> (Vec<RegisterUnit>, Vec<circuit_json::ImplicitPermutation>) {
    let mut declared_outputs: Vec<&RegisterUnit> = declared_outputs.iter().collect();
    let mut declared_outputs_hashes: HashSet<RegisterHash> = declared_outputs
        .iter()
        .map(|&reg| RegisterHash::from(reg))
        .collect();
    let mut actual_outputs_hashes: HashSet<RegisterHash> =
        actual_outputs.iter().map(RegisterHash::from).collect();

    // Position of each register in `all_inputs`; a repeated register keeps
    // its last position.
    let mut input_hashes: HashMap<RegisterHash, usize> = HashMap::default();
    for (i, inp) in all_inputs.iter().enumerate() {
        let hash = RegisterHash::from(inp);
        input_hashes.insert(hash, i);
        // Inputs that were not declared as outputs go after the declared ones.
        if declared_outputs_hashes.insert(hash) {
            declared_outputs.push(inp);
        }
    }

    // Inputs that never reached the output go after the actual outputs.
    for reg in all_inputs {
        if actual_outputs_hashes.insert(RegisterHash::from(reg)) {
            actual_outputs.push(reg.clone());
        }
    }

    let permutation = actual_outputs
        .iter()
        .map(|reg| {
            let hash = RegisterHash::from(reg);
            let i = *input_hashes
                .get(&hash)
                .expect("every output register should be one of the input registers");
            let out = declared_outputs[i].clone();
            circuit_json::ImplicitPermutation(
                tket_json_rs::register::Qubit { id: reg.clone() },
                tket_json_rs::register::Qubit { id: out },
            )
        })
        .collect_vec();

    (actual_outputs, permutation)
}