hugr_llvm/custom/extension_op.rs
1use std::{collections::BTreeMap, rc::Rc};
2
3use hugr_core::{
4 HugrView, Node,
5 extension::{
6 ExtensionId,
7 simple_op::{MakeExtensionOp, MakeOpDef},
8 },
9 ops::{ExtensionOp, OpName},
10};
11
12use anyhow::{Result, bail};
13
14use strum::IntoEnumIterator;
15
16use crate::emit::{EmitFuncContext, EmitOpArgs};
17
18/// A helper trait for describing the callback used for emitting [`ExtensionOp`]s,
19/// and for hanging documentation.
20///
21/// We have the appropriate `Fn` as a supertrait,
22/// and there is a blanket impl for that `Fn`. We do not intend users to impl
23/// this trait.
24///
25/// `ExtensionOpFn` callbacks are registered against a fully qualified [`OpName`],
26/// i.e. including it's [`ExtensionId`]. Callbacks can assume that the provided
27/// [`EmitOpArgs::node`] holds an op matching that fully qualified name, and that
28/// the signature of that op determinies the length and types of
29/// [`EmitOpArgs::inputs`], and [`EmitOpArgs::outputs`] via
30/// [`EmitFuncContext::llvm_type`].
31///
32/// Callbacks should use the supplied [`EmitFuncContext`] to emit LLVM to match
33/// the desired semantics of the op. If a callback returns success then the callback must:
34/// - ensure that [`crate::emit::func::RowPromise::finish`] has been called on the outputs.
35/// - ensure that the contexts [`inkwell::builder::Builder`] is positioned at the end of a basic
36/// block, logically after the execution of the just-emitted op.
37///
38/// Callbacks may hold references with lifetimes older than `'a`.
39pub trait ExtensionOpFn<'a, H>:
40 for<'c> Fn(&mut EmitFuncContext<'c, 'a, H>, EmitOpArgs<'c, '_, ExtensionOp, H>) -> Result<()> + 'a
41{
42}
43
44impl<
45 'a,
46 H,
47 F: for<'c> Fn(
48 &mut EmitFuncContext<'c, 'a, H>,
49 EmitOpArgs<'c, '_, ExtensionOp, H>,
50 ) -> Result<()>
51 + ?Sized
52 + 'a,
53> ExtensionOpFn<'a, H> for F
54{
55}
56
57/// A collection of [`ExtensionOpFn`] callbacks keyed the fully qualified [`OpName`].
58///
59/// Those callbacks may hold references with lifetimes older than `'a`.
60#[derive(Default)]
61pub struct ExtensionOpMap<'a, H>(BTreeMap<(ExtensionId, OpName), Box<dyn ExtensionOpFn<'a, H>>>);
62
63impl<'a, H: HugrView<Node = Node>> ExtensionOpMap<'a, H> {
64 /// Register a callback to emit a [`ExtensionOp`], keyed by fully
65 /// qualified [`OpName`].
66 pub fn extension_op(
67 &mut self,
68 extension: ExtensionId,
69 op: OpName,
70 handler: impl ExtensionOpFn<'a, H>,
71 ) {
72 self.0.insert((extension, op), Box::new(handler));
73 }
74
75 /// Register callbacks to emit [`ExtensionOp`]s that match the
76 /// definitions generated by `Op`s impl of [`strum::IntoEnumIterator`]>
77 pub fn simple_extension_op<Op: MakeOpDef + IntoEnumIterator>(
78 &mut self,
79 handler: impl 'a
80 + for<'c> Fn(
81 &mut EmitFuncContext<'c, 'a, H>,
82 EmitOpArgs<'c, '_, ExtensionOp, H>,
83 Op,
84 ) -> Result<()>,
85 ) {
86 let handler = Rc::new(handler);
87 for op in Op::iter() {
88 let handler = handler.clone();
89 self.extension_op(
90 op.extension(),
91 op.opdef_id().clone(),
92 move |context, args| {
93 let op = Op::from_extension_op(&args.node())?;
94 handler(context, args, op)
95 },
96 );
97 }
98 }
99
100 /// Emit an [`ExtensionOp`] by delegating to the collected callbacks.
101 ///
102 /// If no handler is registered for the op an error will be returned.
103 pub fn emit_extension_op<'c>(
104 &self,
105 context: &mut EmitFuncContext<'c, 'a, H>,
106 args: EmitOpArgs<'c, '_, ExtensionOp, H>,
107 ) -> Result<()> {
108 let node = args.node();
109 let key = (
110 node.def().extension_id().clone(),
111 node.unqualified_id().into(),
112 );
113 let Some(handler) = self.0.get(&key) else {
114 bail!("No extension could emit extension op: {key:?}")
115 };
116 (handler.as_ref())(context, args)
117 }
118}