use std::borrow::Cow;
use std::collections::{BTreeMap, BTreeSet, HashMap, HashSet};
use hugr::core::HugrNode;
use hugr::ops::OpParent;
use hugr::{HugrView, Wire};
use hugr_core::metadata::Metadata;
use itertools::Itertools;
use tket_json_rs::circuit_json;
use tket_json_rs::register::ElementId as RegisterUnit;
use crate::metadata;
use crate::serialize::pytket::circuit::StraightThroughWire;
use crate::serialize::pytket::extension::RegisterCount;
use crate::serialize::pytket::{PytketEncodeError, PytketEncodeOpError, RegisterHash};
use super::PytketEncoderConfig;
use super::unit_generator::RegisterUnitGenerator;
/// Tracks the pytket values (qubits, bits and parameters) attached to the
/// wires of a region.
///
/// Qubits, bits and parameters are stored in flat lists and referred to by
/// index through [`TrackedQubit`], [`TrackedBit`] and [`TrackedParam`].
#[derive(derive_more::Debug, Clone)]
#[debug(bounds(N: std::fmt::Debug))]
pub struct ValueTracker<N> {
/// Register unit of every qubit known to the tracker, indexed by
/// [`TrackedQubit`]. Seeded from the region metadata and extended by
/// `new_qubit`.
qubits: Vec<RegisterUnit>,
/// Register unit of every bit known to the tracker, indexed by
/// [`TrackedBit`]. Seeded from the region metadata and extended by `new_bit`.
bits: Vec<RegisterUnit>,
/// Expression string of every parameter, indexed by [`TrackedParam`].
params: Vec<String>,
/// Values currently registered for each wire, together with the number of
/// linked ports that have not consumed them yet.
wires: BTreeMap<Wire<N>, TrackedWire>,
/// Qubits that are available for reuse. `new_qubit` takes from this set,
/// `register_wire` removes the qubits it is given, and `free_qubit` adds to it.
unused_qubits: BTreeSet<TrackedQubit>,
/// Bits that are available for reuse; maintained like `unused_qubits`.
unused_bits: BTreeSet<TrackedBit>,
/// Source of fresh qubit register units, created with the `"q"` prefix.
qubit_reg_generator: RegisterUnitGenerator,
/// Source of fresh bit register units, created with the `"c"` prefix.
bit_reg_generator: RegisterUnitGenerator,
/// Names given to the parameters found at the region's input node, in the
/// order they were assigned.
input_params: Vec<String>,
}
/// Identifier of a qubit inside a [`ValueTracker`].
///
/// Wraps the index of the qubit in the tracker's qubit list and is displayed
/// as `qubit#<index>`.
#[derive(
Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, Default, derive_more::Display,
)]
#[display("qubit#{}", self.0)]
pub struct TrackedQubit(usize);
/// Identifier of a bit inside a [`ValueTracker`].
///
/// Wraps the index of the bit in the tracker's bit list and is displayed as
/// `bit#<index>`.
#[derive(
Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, Default, derive_more::Display,
)]
#[display("bit#{}", self.0)]
pub struct TrackedBit(usize);
/// Identifier of a parameter inside a [`ValueTracker`].
///
/// Wraps the index of the parameter in the tracker's expression list and is
/// displayed as `param#<index>`.
#[derive(
Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, Default, derive_more::Display,
)]
#[display("param#{}", self.0)]
pub struct TrackedParam(usize);
/// A single value that can be attached to a wire: a qubit, a bit or a
/// parameter.
///
/// The `derive_more::From` derive provides conversions from each of the
/// wrapped identifier types.
#[derive(
Debug,
Clone,
Copy,
PartialEq,
Eq,
PartialOrd,
Ord,
Hash,
derive_more::From,
derive_more::Display,
)]
#[non_exhaustive]
pub enum TrackedValue {
/// A tracked qubit.
Qubit(TrackedQubit),
/// A tracked bit.
Bit(TrackedBit),
/// A tracked parameter.
Param(TrackedParam),
}
/// A collection of tracked values, split into one list per kind.
#[derive(Debug, Clone, Default, PartialEq, Eq, Hash)]
#[non_exhaustive]
pub struct TrackedValues {
/// The tracked qubits, in insertion order.
pub qubits: Vec<TrackedQubit>,
/// The tracked bits, in insertion order.
pub bits: Vec<TrackedBit>,
/// The tracked parameters, in insertion order.
pub params: Vec<TrackedParam>,
}
/// Appends a stream of mixed values, routing each one to the list that
/// matches its kind. The relative order within each kind is preserved.
impl Extend<TrackedValue> for TrackedValues {
    fn extend<T: IntoIterator<Item = TrackedValue>>(&mut self, iter: T) {
        iter.into_iter().for_each(|value| match value {
            TrackedValue::Qubit(qubit) => self.qubits.push(qubit),
            TrackedValue::Bit(bit) => self.bits.push(bit),
            TrackedValue::Param(param) => self.params.push(param),
        });
    }
}
/// Bookkeeping entry stored for each wire registered in a [`ValueTracker`].
#[derive(Debug, Clone)]
struct TrackedWire {
/// The values attached to the wire. `register_wire` always stores `Some`.
pub(self) values: Option<Vec<TrackedValue>>,
/// Number of ports linked to the wire's source that have not yet read the
/// values through `wire_values`. Initialised from `hugr.linked_ports(..)`.
pub(self) unexplored_neighbours: usize,
}
/// Summary produced by `ValueTracker::finish` once a region has been processed.
#[derive(Debug, Clone)]
pub struct ValueTrackerResult {
/// Register units of all the qubits known to the tracker.
pub qubits: Vec<RegisterUnit>,
/// Register units of all the bits known to the tracker.
pub bits: Vec<RegisterUnit>,
/// Expressions of the parameters found on the wires linked to the region's
/// output node.
pub params: Vec<String>,
/// Register units of the qubits found on the wires linked to the region's
/// output node, in port order.
pub qubit_outputs: Vec<RegisterUnit>,
/// Register units of the bits found on the wires linked to the region's
/// output node, in port order.
pub bit_outputs: Vec<RegisterUnit>,
/// Permutation computed by `compute_final_permutation` from `qubit_outputs`
/// and `qubits`.
pub qubit_permutation: Vec<circuit_json::ImplicitPermutation>,
/// Names given to the parameters found at the region's input node.
pub input_params: Vec<String>,
/// Untracked wires that connect the region's input node directly to its
/// output node.
pub straight_through_wires: Vec<StraightThroughWire>,
}
impl<N: HugrNode> ValueTracker<N> {
pub(super) fn new<H: HugrView<Node = N>>(
hugr: &H,
region: N,
config: &PytketEncoderConfig<H>,
) -> Result<Self, PytketEncodeError<N>> {
let param_variable_names: Vec<String> =
read_metadata_json_list::<_, _, metadata::InputParameters>(hugr, region);
let mut tracker = ValueTracker {
qubits: read_metadata_json_list::<_, _, metadata::QubitRegisters>(hugr, region)
.into_iter()
.map(|q| q.id)
.collect_vec(),
bits: read_metadata_json_list::<_, _, metadata::BitRegisters>(hugr, region)
.into_iter()
.map(|b| b.id)
.collect_vec(),
params: Vec::with_capacity(param_variable_names.len()),
wires: BTreeMap::new(),
unused_qubits: BTreeSet::new(),
unused_bits: BTreeSet::new(),
qubit_reg_generator: RegisterUnitGenerator::default(),
bit_reg_generator: RegisterUnitGenerator::default(),
input_params: Vec::with_capacity(param_variable_names.len()),
};
tracker.unused_qubits = (0..tracker.qubits.len()).map(TrackedQubit).collect();
tracker.unused_bits = (0..tracker.bits.len()).map(TrackedBit).collect();
tracker.qubit_reg_generator = RegisterUnitGenerator::new("q", tracker.qubits.iter());
tracker.bit_reg_generator = RegisterUnitGenerator::new("c", tracker.bits.iter());
let existing_param_vars: HashSet<String> = param_variable_names.iter().cloned().collect();
let mut param_gen = param_variable_names.into_iter().chain(
(0..)
.map(|i| format!("f{i}"))
.filter(|name| !existing_param_vars.contains(name)),
);
let region_optype = hugr.get_optype(region);
let signature = region_optype.inner_function_type().ok_or_else(|| {
let optype = hugr.get_optype(region).to_string();
PytketEncodeError::NonDataflowRegion { region, optype }
})?;
let inp_node = hugr.get_io(region).unwrap()[0];
for (port, typ) in hugr.node_outputs(inp_node).zip(signature.input().iter()) {
let wire = Wire::new(inp_node, port);
let Some(count) = config.type_to_pytket(typ) else {
continue;
};
let mut wire_values = Vec::with_capacity(count.total());
for _ in 0..count.qubits {
let qb = tracker.new_qubit();
wire_values.push(TrackedValue::Qubit(qb));
}
for _ in 0..count.bits {
let bit = tracker.new_bit();
wire_values.push(TrackedValue::Bit(bit));
}
for _ in 0..count.params {
let param_name = param_gen.next().unwrap();
tracker.input_params.push(param_name.clone());
let param = tracker.new_param(param_name);
wire_values.push(TrackedValue::Param(param));
}
tracker.register_wire(wire, wire_values, hugr)?;
}
Ok(tracker)
}
/// Returns a qubit that is not currently in use.
///
/// The smallest entry of the unused set is taken when there is one. Otherwise
/// a fresh register unit is generated and appended to the qubit list.
pub fn new_qubit(&mut self) -> TrackedQubit {
    if let Some(recycled) = self.unused_qubits.pop_first() {
        return recycled;
    }
    let index = self.qubits.len();
    self.qubits.push(self.qubit_reg_generator.next());
    TrackedQubit(index)
}
/// Returns a bit that is not currently in use.
///
/// The smallest entry of the unused set is taken when there is one. Otherwise
/// a fresh register unit is generated and appended to the bit list.
pub fn new_bit(&mut self) -> TrackedBit {
    if let Some(recycled) = self.unused_bits.pop_first() {
        return recycled;
    }
    let index = self.bits.len();
    self.bits.push(self.bit_reg_generator.next());
    TrackedBit(index)
}
/// Marks `qb` as unused so that a later `new_qubit` call may hand it out again.
pub fn free_qubit(&mut self, qb: TrackedQubit) {
    // Freeing a qubit that is already unused is a no-op, so the flag returned
    // by the set is deliberately dropped.
    let _ = self.unused_qubits.insert(qb);
}
/// Marks `bit` as unused so that a later `new_bit` call may hand it out again.
pub fn free_bit(&mut self, bit: TrackedBit) {
    // Freeing a bit that is already unused is a no-op, so the flag returned by
    // the set is deliberately dropped.
    let _ = self.unused_bits.insert(bit);
}
/// Stores the string form of `expression` as a new parameter and returns its
/// identifier. Parameters are never deduplicated or reused.
pub fn new_param(&mut self, expression: impl ToString) -> TrackedParam {
    let index = self.params.len();
    self.params.push(expression.to_string());
    TrackedParam(index)
}
/// Attaches `values` to `wire`.
///
/// Any qubit or bit in `values` is removed from the unused sets. The wire is
/// stored together with the number of ports linked to its source; each of
/// those ports is expected to read the values once through `wire_values`.
/// A wire with no linked ports is dropped from the tracker again right away
/// (its qubits and bits still count as used).
///
/// # Errors
///
/// Returns [`PytketEncodeOpError::WireAlreadyHasValues`] if the wire is already
/// registered. In that case the tracker is left untouched.
pub fn register_wire<Val: Into<TrackedValue>>(
    &mut self,
    wire: Wire<N>,
    values: impl IntoIterator<Item = Val>,
    hugr: &impl HugrView<Node = N>,
) -> Result<(), PytketEncodeOpError<N>> {
    // Check for duplicates before mutating anything, so a failed call neither
    // overwrites the existing entry nor alters the unused sets.
    if self.wires.contains_key(&wire) {
        return Err(PytketEncodeOpError::WireAlreadyHasValues { wire });
    }

    let values: Vec<TrackedValue> = values.into_iter().map(Into::into).collect();

    // Anything carried by a wire is in use.
    for value in &values {
        match value {
            TrackedValue::Qubit(qb) => {
                self.unused_qubits.remove(qb);
            }
            TrackedValue::Bit(bit) => {
                self.unused_bits.remove(bit);
            }
            TrackedValue::Param(_) => {}
        }
    }

    let unexplored_neighbours = hugr.linked_ports(wire.node(), wire.source()).count();
    self.wires.insert(
        wire,
        TrackedWire {
            values: Some(values),
            unexplored_neighbours,
        },
    );

    // Nobody will ever read this wire, so there is no point in keeping it.
    if unexplored_neighbours == 0 {
        self.unregister_wire(wire)
            .expect("Wire should be registered in the tracker");
    }
    Ok(())
}
/// Returns the values attached to `wire` and counts one of its linked ports as
/// explored.
///
/// When the last unexplored port reads the wire, the entry is removed from the
/// tracker and the values are returned owned; otherwise they are borrowed.
/// Use `peek_wire_values` to read the values without this bookkeeping.
///
/// Returns `None` if the wire is not (or no longer) registered.
pub(super) fn wire_values(&mut self, wire: Wire<N>) -> Option<Cow<'_, [TrackedValue]>> {
    let remaining = self.wires.get(&wire)?.unexplored_neighbours;
    if remaining == 1 {
        // Last reader: hand over ownership and forget the wire.
        return self.unregister_wire(wire).map(Cow::Owned);
    }

    let tracked = self.wires.get_mut(&wire)?;
    tracked.unexplored_neighbours -= 1;
    tracked.values.as_deref().map(Cow::Borrowed)
}
/// Returns the values attached to `wire` without marking any port as explored.
///
/// Returns `None` if the wire is not registered or has no values.
pub(super) fn peek_wire_values(&self, wire: Wire<N>) -> Option<&[TrackedValue]> {
    self.wires.get(&wire)?.values.as_deref()
}
/// Removes `wire` from the tracker and returns the values that were attached
/// to it.
///
/// Returns `None` if the wire is not registered or has no values. An unknown
/// wire is reported through the return value and not by panicking, so callers
/// decide how to treat it.
fn unregister_wire(&mut self, wire: Wire<N>) -> Option<Vec<TrackedValue>> {
    self.wires.remove(&wire)?.values
}
pub fn qubit_register(&self, qb: TrackedQubit) -> &RegisterUnit {
&self.qubits[qb.0]
}
pub fn bit_register(&self, bit: TrackedBit) -> &RegisterUnit {
&self.bits[bit.0]
}
pub fn param_expression(&self, param: TrackedParam) -> &str {
&self.params[param.0]
}
/// Consumes the tracker and summarises what reaches the output of `region`.
///
/// Every wire linked to an input port of the region's output node is visited
/// in port order. For wires that still have tracked values, the register units
/// of their qubits and bits and the expressions of their parameters are
/// collected. Untracked wires that start at the region's input node are
/// reported as straight-through wires; other untracked wires are ignored.
///
/// # Errors
///
/// The visible implementation never returns an error.
///
/// # Panics
///
/// Panics if `hugr.get_io(region)` returns `None`.
pub(super) fn finish(
    self,
    hugr: &impl HugrView<Node = N>,
    region: N,
) -> Result<ValueTrackerResult, PytketEncodeOpError<N>> {
    let [input_node, output_node] = hugr
        .get_io(region)
        .expect("region should have input and output nodes");

    let mut straight_through_wires = Vec::new();
    // Capacity hints only. `free_qubit`/`free_bit` accept arbitrary indices, so
    // the unused sets may be larger than the register lists; saturate so that
    // the subtraction cannot underflow.
    let mut qubit_outputs =
        Vec::with_capacity(self.qubits.len().saturating_sub(self.unused_qubits.len()));
    let mut bit_outputs =
        Vec::with_capacity(self.bits.len().saturating_sub(self.unused_bits.len()));
    let mut param_outputs = Vec::new();

    for tgt_port in hugr.node_inputs(output_node) {
        for (src_node, src_port) in hugr.linked_outputs(output_node, tgt_port) {
            let wire = Wire::new(src_node, src_port);
            let Some(values) = self.peek_wire_values(wire) else {
                // No pytket values on this wire. If it comes straight from the
                // input node, record the input/output port pair.
                if src_node == input_node {
                    straight_through_wires.push(StraightThroughWire {
                        input_source: src_port,
                        output_target: tgt_port,
                    });
                }
                continue;
            };
            for value in values {
                match value {
                    TrackedValue::Qubit(qb) => {
                        qubit_outputs.push(self.qubit_register(*qb).clone())
                    }
                    TrackedValue::Bit(bit) => bit_outputs.push(self.bit_register(*bit).clone()),
                    TrackedValue::Param(param) => {
                        param_outputs.push(self.param_expression(*param).to_string())
                    }
                }
            }
        }
    }

    let qubit_permutation = compute_final_permutation(qubit_outputs.clone(), &self.qubits);

    Ok(ValueTrackerResult {
        qubits: self.qubits,
        bits: self.bits,
        params: param_outputs,
        qubit_outputs,
        bit_outputs,
        qubit_permutation,
        input_params: self.input_params,
        straight_through_wires,
    })
}
}
impl TrackedValues {
    /// Creates a collection that holds only the given qubits.
    pub fn new_qubits(qubits: impl IntoIterator<Item = TrackedQubit>) -> Self {
        Self {
            qubits: qubits.into_iter().collect(),
            bits: Vec::new(),
            params: Vec::new(),
        }
    }

    /// Creates a collection that holds only the given bits.
    pub fn new_bits(bits: impl IntoIterator<Item = TrackedBit>) -> Self {
        Self {
            qubits: Vec::new(),
            bits: bits.into_iter().collect(),
            params: Vec::new(),
        }
    }

    /// Creates a collection that holds only the given parameters.
    pub fn new_params(params: impl IntoIterator<Item = TrackedParam>) -> Self {
        Self {
            qubits: Vec::new(),
            bits: Vec::new(),
            params: params.into_iter().collect(),
        }
    }

    /// Returns how many qubits, bits and parameters the collection holds.
    pub fn count(&self) -> RegisterCount {
        RegisterCount::new(self.qubits.len(), self.bits.len(), self.params.len())
    }

    /// Iterates over all the values: qubits first, then bits, then parameters.
    pub fn iter(&self) -> impl Iterator<Item = TrackedValue> + '_ {
        // The identifiers are `Copy`, so each one is copied out and wrapped in
        // the matching `TrackedValue` variant.
        self.qubits
            .iter()
            .copied()
            .map(TrackedValue::Qubit)
            .chain(self.bits.iter().copied().map(TrackedValue::Bit))
            .chain(self.params.iter().copied().map(TrackedValue::Param))
    }

    /// Moves every value of `other` to the end of the matching list in `self`.
    pub fn append(&mut self, other: TrackedValues) {
        self.qubits.extend(other.qubits);
        self.bits.extend(other.bits);
        self.params.extend(other.params);
    }
}
/// Consumes the collection and yields its values: qubits first, then bits,
/// then parameters.
impl IntoIterator for TrackedValues {
    type Item = TrackedValue;
    type IntoIter = std::iter::Chain<
        std::iter::Chain<
            itertools::MapInto<std::vec::IntoIter<TrackedQubit>, TrackedValue>,
            itertools::MapInto<std::vec::IntoIter<TrackedBit>, TrackedValue>,
        >,
        itertools::MapInto<std::vec::IntoIter<TrackedParam>, TrackedValue>,
    >;

    fn into_iter(self) -> Self::IntoIter {
        let Self {
            qubits,
            bits,
            params,
        } = self;
        let qubits = qubits.into_iter().map_into::<TrackedValue>();
        let bits = bits.into_iter().map_into::<TrackedValue>();
        let params = params.into_iter().map_into::<TrackedValue>();
        qubits.chain(bits).chain(params)
    }
}
/// Reads the metadata entry `K` of `region` and converts it into a `Vec<T>`.
///
/// Returns an empty vector when the node has no such entry.
fn read_metadata_json_list<T: serde::de::DeserializeOwned, H: HugrView, K: Metadata>(
    hugr: &H,
    region: H::Node,
) -> Vec<T>
where
    for<'hugr> K::Type<'hugr>: Into<Vec<T>>,
{
    match hugr.get_metadata::<K>(region) {
        Some(stored) => stored.into(),
        None => Vec::new(),
    }
}
/// Builds the list of implicit permutation entries for the output registers.
///
/// `actual_outputs` is first extended with every register of `all_inputs` that
/// it does not contain yet (compared by [`RegisterHash`]), so that each input
/// register gets an entry. One entry is then produced per output register,
/// pairing it with the register stored in `all_inputs` at the position where
/// the output register's hash is found.
///
/// NOTE(review): because the second element is looked up in `all_inputs` by the
/// hash of the first, each entry appears to pair a register with itself —
/// confirm this is the intended permutation.
///
/// # Panics
///
/// Panics if `actual_outputs` contains a register that is not in `all_inputs`.
pub(super) fn compute_final_permutation(
    mut actual_outputs: Vec<RegisterUnit>,
    all_inputs: &[RegisterUnit],
) -> Vec<circuit_json::ImplicitPermutation> {
    // Position of each input register, keyed by hash. When two inputs share a
    // hash the later position wins.
    let input_positions: HashMap<RegisterHash, usize> = all_inputs
        .iter()
        .enumerate()
        .map(|(position, unit)| (RegisterHash::from(unit), position))
        .collect();

    // Append the input registers that never showed up at the output.
    let mut seen_outputs: HashSet<RegisterHash> =
        actual_outputs.iter().map(RegisterHash::from).collect();
    for unit in all_inputs {
        if seen_outputs.insert(RegisterHash::from(unit)) {
            actual_outputs.push(unit.clone());
        }
    }

    actual_outputs
        .iter()
        .map(|unit| {
            let position = *input_positions.get(&RegisterHash::from(unit)).unwrap();
            circuit_json::ImplicitPermutation(
                tket_json_rs::register::Qubit { id: unit.clone() },
                tket_json_rs::register::Qubit {
                    id: all_inputs[position].clone(),
                },
            )
        })
        .collect()
}