use std::{collections::BTreeMap, rc::Rc};
use hugr_core::{
HugrView, Node,
extension::{
ExtensionId,
simple_op::{MakeExtensionOp, MakeOpDef},
},
ops::{ExtensionOp, OpName},
};
use anyhow::{Result, bail};
use strum::IntoEnumIterator;
use crate::emit::{EmitFuncContext, EmitOpArgs};
/// A callable that emits code for a single [ExtensionOp] node.
///
/// It is invoked with the [EmitFuncContext] of the function currently being
/// emitted and the [EmitOpArgs] for the op node, and reports failure through
/// `anyhow`'s [Result]. The callable must outlive `'a`, the lifetime shared
/// with the emission context.
///
/// Any type with this call signature implements the trait automatically via
/// the blanket impl that follows, so plain closures can be registered.
pub trait ExtensionOpFn<'a, H>:
for<'c> Fn(&mut EmitFuncContext<'c, 'a, H>, EmitOpArgs<'c, '_, ExtensionOp, H>) -> Result<()> + 'a
{
}
/// Blanket implementation: any (possibly unsized) callable that accepts an
/// emission context plus the op's arguments, returns `Result<()>`, and lives
/// at least as long as `'a` is an [ExtensionOpFn].
impl<'a, H, F> ExtensionOpFn<'a, H> for F
where
    F: ?Sized
        + 'a
        + for<'c> Fn(
            &mut EmitFuncContext<'c, 'a, H>,
            EmitOpArgs<'c, '_, ExtensionOp, H>,
        ) -> Result<()>,
{
}
/// Registry of [ExtensionOpFn] handlers keyed by `(extension id, op name)`.
///
/// Handlers are looked up by [ExtensionOpMap::emit_extension_op] using the
/// node's extension id together with its unqualified op name.
pub struct ExtensionOpMap<'a, H>(BTreeMap<(ExtensionId, OpName), Box<dyn ExtensionOpFn<'a, H>>>);

// Written by hand rather than derived: `#[derive(Default)]` would require
// `H: Default`, even though no `H` value is stored and an empty map can be
// built for every `H` (e.g. borrowed views such as `&Hugr`).
impl<H> Default for ExtensionOpMap<'_, H> {
    fn default() -> Self {
        Self(BTreeMap::new())
    }
}
impl<'a, H: HugrView<Node = Node>> ExtensionOpMap<'a, H> {
    /// Registers `handler` as the emitter for the op named `op` in `extension`.
    ///
    /// A handler previously registered under the same `(extension, op)` pair
    /// is replaced.
    pub fn extension_op(
        &mut self,
        extension: ExtensionId,
        op: OpName,
        handler: impl ExtensionOpFn<'a, H>,
    ) {
        let key = (extension, op);
        self.0.insert(key, Box::new(handler));
    }

    /// Registers a single `handler` for every op of the enum `Op`.
    ///
    /// For each value yielded by `Op::iter()`, an entry keyed by
    /// `(op.extension(), op.opdef_id())` is installed. When invoked, that entry
    /// decodes the node back into an `Op` via
    /// [MakeExtensionOp::from_extension_op] (returning the decoding error on
    /// failure) and forwards it to `handler`.
    pub fn simple_extension_op<Op: MakeOpDef + IntoEnumIterator>(
        &mut self,
        handler: impl 'a
            + for<'c> Fn(
                &mut EmitFuncContext<'c, 'a, H>,
                EmitOpArgs<'c, '_, ExtensionOp, H>,
                Op,
            ) -> Result<()>,
    ) {
        // Every registered entry shares ownership of the one user handler.
        let shared = Rc::new(handler);
        for variant in Op::iter() {
            let forward_to = Rc::clone(&shared);
            let extension = variant.extension();
            let name = variant.opdef_id().clone();
            self.extension_op(extension, name, move |context, args| {
                let decoded = Op::from_extension_op(&args.node())?;
                forward_to(context, args, decoded)
            });
        }
    }

    /// Emits `args` by dispatching to the handler registered for the node's
    /// extension id and unqualified op name.
    ///
    /// # Errors
    ///
    /// Fails if no handler is registered for that `(extension, op)` pair;
    /// otherwise returns whatever the selected handler returns.
    pub fn emit_extension_op<'c>(
        &self,
        context: &mut EmitFuncContext<'c, 'a, H>,
        args: EmitOpArgs<'c, '_, ExtensionOp, H>,
    ) -> Result<()> {
        let op_node = args.node();
        let extension_id = op_node.def().extension_id().clone();
        let op_name: OpName = op_node.unqualified_id().into();
        let key = (extension_id, op_name);
        match self.0.get(&key) {
            Some(emit) => emit.as_ref()(context, args),
            None => bail!("No extension could emit extension op: {key:?}"),
        }
    }
}