use hugr::hugr::hugrmut::HugrMut;
use hugr::hugr::NodeMetadata;
use itertools::Itertools;
use crate::Circuit;
use super::CircuitRewrite;
/// Metadata key under which the circuit's rewrite trace is stored on its
/// parent node.
///
/// When tracing is active the value stored under this key is a JSON array
/// with one entry per recorded rewrite.
pub const METADATA_REWRITES: &str = "TKET2.rewrites";

/// Compile-time switch for rewrite tracing.
///
/// This is `true` only when the crate is built with the `rewrite-tracing`
/// feature. When it is `false`, all tracing methods on `Circuit` short-circuit
/// without touching the metadata.
pub const REWRITE_TRACING_ENABLED: bool = cfg!(feature = "rewrite-tracing");
/// A single entry in a circuit's rewrite trace.
///
/// The entry is stored in the circuit metadata as a plain integer. It is
/// `#[repr(transparent)]`, so it has the same layout as its only field.
#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
#[repr(transparent)]
pub struct RewriteTrace {
    /// Count of individual matches this entry stands for. A trace built from
    /// a single `CircuitRewrite` uses `1`.
    individual_matches: u16,
}
impl From<&CircuitRewrite> for RewriteTrace {
    /// Builds the trace entry for one applied [`CircuitRewrite`].
    ///
    /// The rewrite's contents are not inspected. Every such entry counts as
    /// exactly one individual match.
    #[inline]
    fn from(_: &CircuitRewrite) -> Self {
        let individual_matches = 1;
        Self { individual_matches }
    }
}
impl RewriteTrace {
#[inline]
pub fn new(individual_matches: u16) -> Self {
Self { individual_matches }
}
}
impl From<&serde_json::Value> for RewriteTrace {
    /// Decodes a trace entry from its JSON metadata representation.
    ///
    /// The representation is an unsigned integer holding the number of
    /// individual matches, as produced by the
    /// `From<RewriteTrace> for serde_json::Value` conversion.
    ///
    /// # Panics
    ///
    /// Panics if `value` is not an unsigned integer, or if it does not fit in
    /// a `u16`. Metadata may come from a deserialized file, so out-of-range
    /// values are rejected loudly. A lossy `as` cast would silently wrap
    /// them, so that e.g. `65536` became `0`.
    #[inline]
    fn from(value: &serde_json::Value) -> Self {
        let raw = value
            .as_u64()
            .expect("rewrite trace metadata entries must be unsigned integers");
        let individual_matches = u16::try_from(raw)
            .expect("rewrite trace metadata entry does not fit in a u16");
        Self { individual_matches }
    }
}
impl From<RewriteTrace> for serde_json::Value {
#[inline]
fn from(trace: RewriteTrace) -> Self {
serde_json::Value::from(trace.individual_matches)
}
}
impl<T: HugrMut> Circuit<T> {
    /// Turns on rewrite tracing for this circuit.
    ///
    /// This does nothing unless [`REWRITE_TRACING_ENABLED`] is `true`.
    /// Otherwise, if the [`METADATA_REWRITES`] entry on the circuit's parent
    /// node currently reads as `Null`, an empty array is stored there.
    /// Later calls to [`Circuit::add_rewrite_trace`] append to that array.
    /// Any existing non-null entry is left untouched.
    #[inline]
    pub fn enable_rewrite_tracing(&mut self) {
        if REWRITE_TRACING_ENABLED {
            let root = self.parent();
            let slot = self.hugr_mut().get_metadata_mut(root, METADATA_REWRITES);
            if matches!(slot, NodeMetadata::Null) {
                *slot = NodeMetadata::Array(Vec::new());
            }
        }
    }

    /// Appends `rewrite` to the circuit's rewrite trace.
    ///
    /// Returns `true` when the entry was recorded. Returns `false` in two
    /// cases:
    ///
    /// - [`REWRITE_TRACING_ENABLED`] is `false`.
    /// - The [`METADATA_REWRITES`] entry on the parent node is not a JSON
    ///   array, presumably because [`Circuit::enable_rewrite_tracing`] was
    ///   never called.
    #[inline]
    pub fn add_rewrite_trace(&mut self, rewrite: impl Into<RewriteTrace>) -> bool {
        if !REWRITE_TRACING_ENABLED {
            return false;
        }
        let root = self.parent();
        let slot = self.hugr_mut().get_metadata_mut(root, METADATA_REWRITES);
        if let Some(traces) = slot.as_array_mut() {
            let trace: RewriteTrace = rewrite.into();
            traces.push(trace.into());
            true
        } else {
            false
        }
    }

    /// Returns an iterator over the rewrite traces recorded on this circuit.
    ///
    /// Returns `None` in any of these cases:
    ///
    /// - [`REWRITE_TRACING_ENABLED`] is `false`.
    /// - The parent node has no [`METADATA_REWRITES`] entry.
    /// - That entry is not a JSON array.
    ///
    /// Each array element is decoded with `RewriteTrace`'s
    /// `From<&serde_json::Value>` conversion. That conversion panics on
    /// elements that are not unsigned integers.
    #[inline]
    pub fn rewrite_trace(&self) -> Option<impl Iterator<Item = RewriteTrace> + '_> {
        if !REWRITE_TRACING_ENABLED {
            return None;
        }
        let root = self.parent();
        self.hugr()
            .get_metadata(root, METADATA_REWRITES)
            .and_then(|meta| meta.as_array())
            .map(|traces| traces.iter().map_into::<RewriteTrace>())
    }
}