hugr_llvm/custom/extension_op.rs
1use std::{collections::HashMap, rc::Rc};
2
3use hugr_core::{
4 extension::{
5 simple_op::{MakeExtensionOp, MakeOpDef},
6 ExtensionId,
7 },
8 ops::{ExtensionOp, OpName},
9 HugrView, Node,
10};
11
12use anyhow::{bail, Result};
13
14use strum::IntoEnumIterator;
15
16use crate::emit::{EmitFuncContext, EmitOpArgs};
17
18/// A helper trait for describing the callback used for emitting [ExtensionOp]s,
19/// and for hanging documentation. We have the appropriate `Fn` as a supertrait,
20/// and there is a blanket impl for that `Fn`. We do not intend users to impl
21/// this trait.
22///
23/// `ExtensionOpFn` callbacks are registered against a fully qualified [OpName],
24/// i.e. including it's [ExtensionId]. Callbacks can assume that the provided
25/// [EmitOpArgs::node] holds an op matching that fully qualified name, and that
26/// the signature of that op determinies the length and types of
27/// [EmitOpArgs::inputs], and [EmitOpArgs::outputs] via
28/// [EmitFuncContext::llvm_type].
29///
30/// Callbacks should use the supplied [EmitFuncContext] to emit LLVM to match
31/// the desired semantics of the op. If a callback returns success then the callback must:
32/// - ensure that [crate::emit::func::RowPromise::finish] has been called on the outputs.
33/// - ensure that the contexts [inkwell::builder::Builder] is positioned at the end of a basic
34/// block, logically after the execution of the just-emitted op.
35///
36/// Callbacks may hold references with lifetimes older than `'a`.
37pub trait ExtensionOpFn<'a, H>:
38 for<'c> Fn(&mut EmitFuncContext<'c, 'a, H>, EmitOpArgs<'c, '_, ExtensionOp, H>) -> Result<()> + 'a
39{
40}
41
42impl<
43 'a,
44 H,
45 F: for<'c> Fn(
46 &mut EmitFuncContext<'c, 'a, H>,
47 EmitOpArgs<'c, '_, ExtensionOp, H>,
48 ) -> Result<()>
49 + ?Sized
50 + 'a,
51 > ExtensionOpFn<'a, H> for F
52{
53}
54
55/// A collection of [ExtensionOpFn] callbacks keyed the fully qualified [OpName].
56///
57/// Those callbacks may hold references with lifetimes older than `'a`.
58#[derive(Default)]
59pub struct ExtensionOpMap<'a, H>(HashMap<(ExtensionId, OpName), Box<dyn ExtensionOpFn<'a, H>>>);
60
61impl<'a, H: HugrView<Node = Node>> ExtensionOpMap<'a, H> {
62 /// Register a callback to emit a [ExtensionOp], keyed by fully
63 /// qualified [OpName].
64 pub fn extension_op(
65 &mut self,
66 extension: ExtensionId,
67 op: OpName,
68 handler: impl ExtensionOpFn<'a, H>,
69 ) {
70 self.0.insert((extension, op), Box::new(handler));
71 }
72
73 /// Register callbacks to emit [ExtensionOp]s that match the
74 /// definitions generated by `Op`s impl of [strum::IntoEnumIterator]>
75 pub fn simple_extension_op<Op: MakeOpDef + IntoEnumIterator>(
76 &mut self,
77 handler: impl 'a
78 + for<'c> Fn(
79 &mut EmitFuncContext<'c, 'a, H>,
80 EmitOpArgs<'c, '_, ExtensionOp, H>,
81 Op,
82 ) -> Result<()>,
83 ) {
84 let handler = Rc::new(handler);
85 for op in Op::iter() {
86 let handler = handler.clone();
87 self.extension_op(op.extension(), op.name().clone(), move |context, args| {
88 let op = Op::from_extension_op(&args.node())?;
89 handler(context, args, op)
90 });
91 }
92 }
93
94 /// Emit an [ExtensionOp] by delegating to the collected callbacks.
95 ///
96 /// If no handler is registered for the op an error will be returned.
97 pub fn emit_extension_op<'c>(
98 &self,
99 context: &mut EmitFuncContext<'c, 'a, H>,
100 args: EmitOpArgs<'c, '_, ExtensionOp, H>,
101 ) -> Result<()> {
102 let node = args.node();
103 let key = (node.def().extension_id().clone(), node.def().name().clone());
104 let Some(handler) = self.0.get(&key) else {
105 bail!("No extension could emit extension op: {key:?}")
106 };
107 (handler.as_ref())(context, args)
108 }
109}