use std::collections::{BTreeSet, HashMap, VecDeque};
use hugr::extension::{ExtensionId, ExtensionSet};
use hugr::ops::{ExtensionOp, Value};
use hugr::types::{SumType, Type};
use crate::serialize::pytket::encoder::EncodeStatus;
use crate::serialize::pytket::extension::{PytketTypeTranslator, RegisterCount, set_bits_op};
use crate::serialize::pytket::{PytketEmitter, PytketEncodeError};
use super::super::encoder::{PytketEncoderContext, TrackedValues};
use super::TypeTranslatorSet;
use hugr::HugrView;
use itertools::Itertools;
/// Configuration for encoding HUGR extension operations, constants and types
/// into pytket.
///
/// Holds the registered [`PytketEmitter`]s, an index from extension ids to
/// the emitters that declared them, and the set of type translators.
#[derive(derive_more::Debug)]
#[debug(bounds(H: HugrView))]
pub struct PytketEncoderConfig<H: HugrView> {
/// All registered emitters, in registration order. The index fields below
/// refer to emitters by their position in this list.
// Skipped in `Debug` output (the boxed trait objects presumably have no
// `Debug` impl).
#[debug(skip)]
pub(super) emitters: Vec<Box<dyn PytketEmitter<H>>>,
/// For each extension, the indices into `emitters` of the emitters that
/// declared it.
// `Debug` output only lists the extension ids, not the indices.
#[debug("{:?}", extension_emitters.keys().collect_vec())]
extension_emitters: HashMap<ExtensionId, Vec<usize>>,
/// Indices into `emitters` of the emitters that did not declare an extension
/// list; they are consulted for every extension.
no_extension_emitters: Vec<usize>,
/// Translators used to compute the pytket register counts of HUGR types.
type_translators: TypeTranslatorSet,
}
impl<H: HugrView> PytketEncoderConfig<H> {
    /// Creates an empty configuration.
    ///
    /// Equivalent to [`Default::default`]: no emitters and a default
    /// [`TypeTranslatorSet`].
    pub fn new() -> Self {
        Self::default()
    }

    /// Registers an emitter.
    ///
    /// If [`PytketEmitter::extensions`] returns a list of extensions, the
    /// emitter is only consulted for operations and constants from those
    /// extensions. If it returns `None`, the emitter is consulted for every
    /// extension.
    pub fn add_emitter(&mut self, encoder: impl PytketEmitter<H> + 'static) {
        // Index the new emitter will occupy in `self.emitters`.
        let idx = self.emitters.len();
        match encoder.extensions() {
            Some(extensions) => {
                for ext in extensions {
                    let indices = self.extension_emitters.entry(ext).or_default();
                    // Guard against an extension being listed twice, which would
                    // otherwise make the same emitter be tried twice per operation.
                    if indices.last() != Some(&idx) {
                        indices.push(idx);
                    }
                }
            }
            None => self.no_extension_emitters.push(idx),
        }
        self.emitters.push(Box::new(encoder));
    }

    /// Registers a type translator, used by [`Self::type_to_pytket`].
    pub fn add_type_translator(
        &mut self,
        translator: impl PytketTypeTranslator + Send + Sync + 'static,
    ) {
        self.type_translators.add_type_translator(translator);
    }

    /// Returns the extensions for which at least one emitter was registered
    /// explicitly.
    ///
    /// Emitters registered without an extension list (which apply to every
    /// extension) are not reflected here.
    pub fn supported_extensions(&self) -> impl Iterator<Item = &ExtensionId> {
        self.extension_emitters.keys()
    }

    /// Encodes an extension operation using the registered emitters.
    ///
    /// Emitters registered for the operation's extension are tried first, in
    /// registration order, followed by the extension-agnostic emitters. The
    /// first emitter reporting [`EncodeStatus::Success`] wins and no further
    /// emitters are tried.
    ///
    /// Returns [`EncodeStatus::Unsupported`] if no emitter succeeds.
    ///
    /// # Errors
    ///
    /// Propagates the first error returned by an emitter; later emitters are
    /// not tried.
    pub fn op_to_pytket(
        &self,
        node: H::Node,
        op: &ExtensionOp,
        hugr: &H,
        encoder: &mut PytketEncoderContext<H>,
    ) -> Result<EncodeStatus, PytketEncodeError<H::Node>> {
        let extension = op.def().extension_id();
        for emitter in self.emitters_for_extension(extension) {
            if emitter.op_to_pytket(node, op, hugr, encoder)? == EncodeStatus::Success {
                return Ok(EncodeStatus::Success);
            }
        }
        Ok(EncodeStatus::Unsupported)
    }

    /// Encodes a constant value as pytket values.
    ///
    /// - Booleans (`SumType::new_unary(2)`) become a fresh bit. A command
    ///   built with `set_bits_op` is emitted to set it when the value is
    ///   `true`.
    /// - Tuples are flattened recursively, depth-first and in element order,
    ///   so a nested tuple yields the same ordering as its flattened form.
    /// - Opaque extension values are delegated to the emitters registered for
    ///   the extensions used by the value's type (plus the extension-agnostic
    ///   emitters). The first emitter returning `Some` is used.
    ///
    /// Returns `Ok(None)` if any part of the value cannot be encoded. This
    /// happens for a sum that is neither a boolean nor a tuple, or for an
    /// extension value that no emitter accepts.
    ///
    /// NOTE(review): commands already emitted for earlier elements are not
    /// rolled back when a later element makes this return `Ok(None)`.
    ///
    /// # Errors
    ///
    /// Propagates errors returned by the emitters.
    ///
    /// # Panics
    ///
    /// Panics if the extensions used by an opaque value's type cannot be
    /// resolved (`used_extensions` returns an error).
    pub fn const_to_pytket(
        &self,
        value: &Value,
        encoder: &mut PytketEncoderContext<H>,
    ) -> Result<Option<TrackedValues>, PytketEncodeError<H::Node>> {
        let mut values = TrackedValues::default();
        // Values still to be encoded. Tuple elements are pushed to the front
        // in reverse so they are popped in order. This gives a depth-first
        // traversal; a FIFO traversal would reorder nested tuple elements.
        let mut queue = VecDeque::from([value]);
        while let Some(value) = queue.pop_front() {
            match value {
                Value::Sum(sum) if sum.sum_type == SumType::new_unary(2) => {
                    let new_bit = encoder.values.new_bit();
                    // Only `true` needs an explicit command; a fresh bit is
                    // presumably `false` already.
                    if value == &Value::true_val() {
                        let op = set_bits_op(&[true]);
                        encoder.emit_command(op, &[], &[new_bit], None);
                    }
                    // Accumulate rather than return, so that booleans nested in
                    // a tuple don't discard the tuple's other elements.
                    values.append(TrackedValues::new_bits([new_bit]));
                }
                Value::Sum(sum) if sum.sum_type.as_tuple().is_some() => {
                    for element in sum.values.iter().rev() {
                        queue.push_front(element);
                    }
                }
                // Any other sum (several variants, not a boolean) has no
                // encoding here. Report it as unsupported instead of silently
                // dropping it from the result.
                Value::Sum(_) => return Ok(None),
                Value::Extension { e: opaque } => {
                    let typ = opaque.value().get_type();
                    let type_exts = typ.used_extensions().unwrap_or_else(|e| {
                        panic!("Tried to encode a type with partially initialized extension. {e}");
                    });
                    let exts_set = ExtensionSet::from_iter(type_exts.ids().cloned());
                    let mut encoded = None;
                    for emitter in self.emitters_for_extensions(&exts_set) {
                        if let Some(vs) = emitter.const_to_pytket(opaque, encoder)? {
                            encoded = Some(vs);
                            break;
                        }
                    }
                    let Some(vs) = encoded else {
                        return Ok(None);
                    };
                    values.append(vs);
                }
            }
        }
        Ok(Some(values))
    }

    /// Returns the pytket register count for `typ`, as computed by the
    /// registered type translators.
    ///
    /// `None` presumably means no translator supports the type.
    pub fn type_to_pytket(&self, typ: &Type) -> Option<RegisterCount> {
        self.type_translators.type_to_pytket(typ)
    }

    /// Returns the emitters to try for an operation from extension `ext`.
    ///
    /// Yields the emitters registered for `ext` in registration order,
    /// followed by every extension-agnostic emitter.
    fn emitters_for_extension(
        &self,
        ext: &ExtensionId,
    ) -> impl Iterator<Item = &Box<dyn PytketEmitter<H>>> + use<'_, H> {
        self.extension_emitters
            .get(ext)
            .into_iter()
            .flatten()
            .chain(self.no_extension_emitters.iter())
            .map(move |idx| &self.emitters[*idx])
    }

    /// Returns the emitters to try for a value whose type uses `exts`.
    ///
    /// Yields every emitter registered for any extension in `exts`, plus every
    /// extension-agnostic emitter. There are no duplicates, and emitters are
    /// ordered by registration index. Unlike
    /// [`Self::emitters_for_extension`], extension-specific emitters are not
    /// tried before extension-agnostic ones.
    fn emitters_for_extensions(
        &self,
        exts: &ExtensionSet,
    ) -> impl Iterator<Item = &Box<dyn PytketEmitter<H>>> + use<'_, H> {
        // A `BTreeSet` both removes emitters that declared several of `exts`
        // and sorts the indices into registration order.
        let emitter_ids: BTreeSet<usize> = exts
            .iter()
            .flat_map(|ext| self.extension_emitters.get(ext).into_iter().flatten())
            .chain(self.no_extension_emitters.iter())
            .copied()
            .collect();
        emitter_ids.into_iter().map(move |idx| &self.emitters[idx])
    }
}
impl<H: HugrView> Default for PytketEncoderConfig<H> {
fn default() -> Self {
Self {
emitters: Default::default(),
extension_emitters: Default::default(),
no_extension_emitters: Default::default(),
type_translators: TypeTranslatorSet::default(),
}
}
}