use hugr::builder::DFGBuilder;
use hugr::types::Type;
use hugr::{Hugr, Wire};
use itertools::Itertools;
use std::collections::HashMap;
use crate::serialize::pytket::PytketDecodeError;
use crate::serialize::pytket::decoder::{
DecodeStatus, LoadedParameter, PytketDecoderContext, TrackedBit, TrackedQubit,
};
use crate::serialize::pytket::extension::{PytketDecoder, PytketTypeTranslator, RegisterCount};
use super::TypeTranslatorSet;
/// Configuration for decoding pytket circuits into HUGRs.
///
/// Holds the set of registered operation decoders, an index from pytket
/// op types to the decoders that declared support for them, and the set of
/// type translators used when converting between HUGR types and pytket registers.
#[derive(derive_more::Debug, Default)]
pub struct PytketDecoderConfig {
    /// Registered decoders, in registration order.
    ///
    /// Omitted from the `Debug` output since trait objects are not `Debug`.
    #[debug(skip)]
    pub(super) decoders: Vec<Box<dyn PytketDecoder + Send + Sync>>,
    /// For each pytket op type, the indices into `decoders` of the decoders
    /// registered for it.
    ///
    /// Only the op types (the map keys) are shown in the `Debug` output.
    #[debug("{:?}", optype_decoders.keys().collect_vec())]
    optype_decoders: HashMap<tket_json_rs::OpType, Vec<usize>>,
    /// Translators between HUGR types and pytket register counts.
    type_translators: TypeTranslatorSet,
}
impl PytketDecoderConfig {
pub fn new() -> Self {
Self {
decoders: vec![],
optype_decoders: HashMap::new(),
type_translators: TypeTranslatorSet::default(),
}
}
pub fn add_decoder(&mut self, decoder: impl PytketDecoder + Send + Sync + 'static) {
let idx = self.decoders.len();
for optype in decoder.op_types() {
self.optype_decoders.entry(optype).or_default().push(idx);
}
self.decoders.push(Box::new(decoder));
}
pub fn add_type_translator(
&mut self,
translator: impl PytketTypeTranslator + Send + Sync + 'static,
) {
self.type_translators.add_type_translator(translator);
}
pub(in crate::serialize::pytket) fn op_to_hugr<'a>(
&self,
op: &tket_json_rs::circuit_json::Operation,
qubits: &[TrackedQubit],
bits: &[TrackedBit],
params: &[LoadedParameter],
opgroup: &Option<String>,
decoder: &mut PytketDecoderContext<'a>,
) -> Result<DecodeStatus, PytketDecodeError> {
let mut result = DecodeStatus::Unsupported;
let opgroup = opgroup.as_deref();
for enc in self.decoders_for_optype(&op.op_type) {
result = enc.op_to_hugr(op, qubits, bits, params, opgroup, decoder)?;
if result == DecodeStatus::Success {
break;
}
}
Ok(result)
}
fn decoders_for_optype(
&self,
optype: &tket_json_rs::OpType,
) -> impl Iterator<Item = &Box<dyn PytketDecoder + Send + Sync>> + use<'_> {
self.optype_decoders
.get(optype)
.into_iter()
.flatten()
.map(move |idx| &self.decoders[*idx])
}
pub fn type_to_pytket(&self, typ: &Type) -> Option<RegisterCount> {
self.type_translators.type_to_pytket(typ)
}
pub fn types_are_isomorphic(&self, typ1: &Type, typ2: &Type) -> bool {
self.type_translators.types_are_isomorphic(typ1, typ2)
}
pub(in crate::serialize::pytket) fn transform_typed_value(
&self,
wire: Wire,
initial_type: &Type,
target_type: &Type,
builder: &mut DFGBuilder<&mut Hugr>,
) -> Result<Wire, PytketDecodeError> {
self.type_translators
.transform_typed_value(wire, initial_type, target_type, builder)
}
}