hugr_llvm/custom/
extension_op.rs

1use std::{collections::BTreeMap, rc::Rc};
2
3use hugr_core::{
4    HugrView, Node,
5    extension::{
6        ExtensionId,
7        simple_op::{MakeExtensionOp, MakeOpDef},
8    },
9    ops::{ExtensionOp, OpName},
10};
11
12use anyhow::{Result, bail};
13
14use strum::IntoEnumIterator;
15
16use crate::emit::{EmitFuncContext, EmitOpArgs};
17
18/// A helper trait for describing the callback used for emitting [`ExtensionOp`]s,
19/// and for hanging documentation.
20///
21/// We have the appropriate `Fn` as a supertrait,
22/// and there is a blanket impl for that `Fn`. We do not intend users to impl
23/// this trait.
24///
25/// `ExtensionOpFn` callbacks are registered against a fully qualified [`OpName`],
26/// i.e. including it's [`ExtensionId`]. Callbacks can assume that the provided
27/// [`EmitOpArgs::node`] holds an op matching that fully qualified name, and that
28/// the signature of that op determinies the length and types of
29/// [`EmitOpArgs::inputs`], and [`EmitOpArgs::outputs`] via
30/// [`EmitFuncContext::llvm_type`].
31///
32/// Callbacks should use the supplied [`EmitFuncContext`] to emit LLVM to match
33/// the desired semantics of the op. If a callback returns success then the callback must:
34///  - ensure that [`crate::emit::func::RowPromise::finish`] has been called on the outputs.
35///  - ensure that the contexts [`inkwell::builder::Builder`] is positioned at the end of a basic
36///    block, logically after the execution of the just-emitted op.
37///
38/// Callbacks may hold references with lifetimes older than `'a`.
39pub trait ExtensionOpFn<'a, H>:
40    for<'c> Fn(&mut EmitFuncContext<'c, 'a, H>, EmitOpArgs<'c, '_, ExtensionOp, H>) -> Result<()> + 'a
41{
42}
43
44impl<
45    'a,
46    H,
47    F: for<'c> Fn(
48            &mut EmitFuncContext<'c, 'a, H>,
49            EmitOpArgs<'c, '_, ExtensionOp, H>,
50        ) -> Result<()>
51        + ?Sized
52        + 'a,
53> ExtensionOpFn<'a, H> for F
54{
55}
56
57/// A collection of [`ExtensionOpFn`] callbacks keyed the fully qualified [`OpName`].
58///
59/// Those callbacks may hold references with lifetimes older than `'a`.
60#[derive(Default)]
61pub struct ExtensionOpMap<'a, H>(BTreeMap<(ExtensionId, OpName), Box<dyn ExtensionOpFn<'a, H>>>);
62
63impl<'a, H: HugrView<Node = Node>> ExtensionOpMap<'a, H> {
64    /// Register a callback to emit a [`ExtensionOp`], keyed by fully
65    /// qualified [`OpName`].
66    pub fn extension_op(
67        &mut self,
68        extension: ExtensionId,
69        op: OpName,
70        handler: impl ExtensionOpFn<'a, H>,
71    ) {
72        self.0.insert((extension, op), Box::new(handler));
73    }
74
75    /// Register callbacks to emit [`ExtensionOp`]s that match the
76    /// definitions generated by `Op`s impl of [`strum::IntoEnumIterator`]>
77    pub fn simple_extension_op<Op: MakeOpDef + IntoEnumIterator>(
78        &mut self,
79        handler: impl 'a
80        + for<'c> Fn(
81            &mut EmitFuncContext<'c, 'a, H>,
82            EmitOpArgs<'c, '_, ExtensionOp, H>,
83            Op,
84        ) -> Result<()>,
85    ) {
86        let handler = Rc::new(handler);
87        for op in Op::iter() {
88            let handler = handler.clone();
89            self.extension_op(
90                op.extension(),
91                op.opdef_id().clone(),
92                move |context, args| {
93                    let op = Op::from_extension_op(&args.node())?;
94                    handler(context, args, op)
95                },
96            );
97        }
98    }
99
100    /// Emit an [`ExtensionOp`]  by delegating to the collected callbacks.
101    ///
102    /// If no handler is registered for the op an error will be returned.
103    pub fn emit_extension_op<'c>(
104        &self,
105        context: &mut EmitFuncContext<'c, 'a, H>,
106        args: EmitOpArgs<'c, '_, ExtensionOp, H>,
107    ) -> Result<()> {
108        let node = args.node();
109        let key = (
110            node.def().extension_id().clone(),
111            node.unqualified_id().into(),
112        );
113        let Some(handler) = self.0.get(&key) else {
114            bail!("No extension could emit extension op: {key:?}")
115        };
116        (handler.as_ref())(context, args)
117    }
118}