use std::collections::{BTreeSet, HashMap, VecDeque};
use hugr::extension::{ExtensionId, ExtensionSet};
use hugr::ops::{ExtensionOp, Value};
use hugr::types::{SumType, Type, TypeEnum};
use crate::serialize::pytket::encoder::EncodeStatus;
use crate::serialize::pytket::extension::{
set_bits_op, BoolEmitter, FloatEmitter, PreludeEmitter, RotationEmitter, Tk1Emitter, Tk2Emitter,
};
use crate::serialize::pytket::{PytketEmitter, Tk1ConvertError};
use crate::Circuit;
use super::value_tracker::RegisterCount;
use super::{Tk1EncoderContext, TrackedValues};
use hugr::extension::prelude::bool_t;
use hugr::HugrView;
use itertools::Itertools;
pub fn default_encoder_config<H: HugrView>() -> Tk1EncoderConfig<H> {
let mut config = Tk1EncoderConfig::new();
config.add_emitter(PreludeEmitter);
config.add_emitter(BoolEmitter);
config.add_emitter(FloatEmitter);
config.add_emitter(RotationEmitter);
config.add_emitter(Tk1Emitter);
config.add_emitter(Tk2Emitter);
config
}
#[derive(derive_more::Debug)]
#[debug(bounds(H: HugrView))]
pub struct Tk1EncoderConfig<H: HugrView> {
#[debug(skip)]
pub(super) emitters: Vec<Box<dyn PytketEmitter<H>>>,
#[debug("{:?}", extension_emitters.keys().collect_vec())]
extension_emitters: HashMap<ExtensionId, Vec<usize>>,
no_extension_emitters: Vec<usize>,
}
impl<H: HugrView> Tk1EncoderConfig<H> {
pub fn new() -> Self {
Self {
emitters: vec![],
extension_emitters: HashMap::new(),
no_extension_emitters: vec![],
}
}
pub fn add_emitter(&mut self, encoder: impl PytketEmitter<H> + 'static) {
let idx = self.emitters.len();
match encoder.extensions() {
Some(extensions) => {
for ext in extensions {
self.extension_emitters.entry(ext).or_default().push(idx);
}
}
None => self.no_extension_emitters.push(idx),
}
self.emitters.push(Box::new(encoder));
}
pub fn supported_extensions(&self) -> impl Iterator<Item = &ExtensionId> {
self.extension_emitters.keys()
}
pub(super) fn op_to_pytket(
&self,
node: H::Node,
op: &ExtensionOp,
circ: &Circuit<H>,
encoder: &mut Tk1EncoderContext<H>,
) -> Result<EncodeStatus, Tk1ConvertError<H::Node>> {
let mut result = EncodeStatus::Unsupported;
let extension = op.def().extension_id();
for enc in self.emitters_for_extension(extension) {
if enc.op_to_pytket(node, op, circ, encoder)? == EncodeStatus::Success {
result = EncodeStatus::Success;
break;
}
}
Ok(result)
}
pub fn type_to_pytket(
&self,
typ: &Type,
) -> Result<Option<RegisterCount>, Tk1ConvertError<H::Node>> {
match typ.as_type_enum() {
TypeEnum::Sum(sum) => {
if typ == &bool_t() {
return Ok(Some(RegisterCount {
qubits: 0,
bits: 1,
params: 0,
}));
}
if let Some(tuple) = sum.as_tuple() {
let count: Result<Option<RegisterCount>, Tk1ConvertError<H::Node>> = tuple
.iter()
.map(|ty| {
match ty.clone().try_into() {
Ok(ty) => Ok(self.type_to_pytket(&ty)?),
Err(_) => Ok(None),
}
})
.sum();
return count;
}
}
TypeEnum::Extension(custom) => {
let type_ext = custom.extension();
for encoder in self.emitters_for_extension(type_ext) {
if let Some(count) = encoder.type_to_pytket(custom)? {
return Ok(Some(count));
}
}
}
_ => {}
}
Ok(None)
}
pub(super) fn const_to_pytket(
&self,
value: &Value,
encoder: &mut Tk1EncoderContext<H>,
) -> Result<Option<TrackedValues>, Tk1ConvertError<H::Node>> {
let mut values = TrackedValues::default();
let mut queue = VecDeque::from([value]);
while let Some(value) = queue.pop_front() {
match value {
Value::Sum(sum) => {
if sum.sum_type == SumType::new_unary(2) {
let new_bit = encoder.values.new_bit();
if value == &Value::true_val() {
let op = set_bits_op(&[true]);
encoder.emit_command(op, &[], &[new_bit], None);
}
return Ok(Some(TrackedValues::new_bits([new_bit])));
}
if sum.sum_type.as_tuple().is_some() {
for v in sum.values.iter() {
queue.push_back(v);
}
}
}
Value::Extension { e: opaque } => {
let typ = opaque.value().get_type();
let type_exts = typ.used_extensions().unwrap_or_else(|e| {
panic!("Tried to encode a type with partially initialized extension. {e}");
});
let exts_set = ExtensionSet::from_iter(type_exts.ids().cloned());
let mut encoded = false;
for e in self.emitters_for_extensions(&exts_set) {
if let Some(vs) = e.const_to_pytket(opaque, encoder)? {
values.append(vs);
encoded = true;
break;
}
}
if !encoded {
return Ok(None);
}
}
_ => return Ok(None),
}
}
Ok(Some(values))
}
fn emitters_for_extension(
&self,
ext: &ExtensionId,
) -> impl Iterator<Item = &Box<dyn PytketEmitter<H>>> {
self.extension_emitters
.get(ext)
.into_iter()
.flatten()
.chain(self.no_extension_emitters.iter())
.map(move |idx| &self.emitters[*idx])
}
fn emitters_for_extensions(
&self,
exts: &ExtensionSet,
) -> impl Iterator<Item = &Box<dyn PytketEmitter<H>>> {
let emitter_ids: BTreeSet<usize> = exts
.iter()
.flat_map(|ext| self.extension_emitters.get(ext).into_iter().flatten())
.chain(self.no_extension_emitters.iter())
.copied()
.collect();
emitter_ids.into_iter().map(move |idx| &self.emitters[idx])
}
}
impl<H: HugrView> Default for Tk1EncoderConfig<H> {
fn default() -> Self {
Self {
emitters: Default::default(),
extension_emitters: Default::default(),
no_extension_emitters: Default::default(),
}
}
}