hugr_llvm/custom/
extension_op.rs

1use std::{collections::HashMap, rc::Rc};
2
3use hugr_core::{
4    extension::{
5        simple_op::{MakeExtensionOp, MakeOpDef},
6        ExtensionId,
7    },
8    ops::{ExtensionOp, OpName},
9    HugrView, Node,
10};
11
12use anyhow::{bail, Result};
13
14use strum::IntoEnumIterator;
15
16use crate::emit::{EmitFuncContext, EmitOpArgs};
17
18/// A helper trait for describing the callback used for emitting [ExtensionOp]s,
19/// and for hanging documentation. We have the appropriate `Fn` as a supertrait,
20/// and there is a blanket impl for that `Fn`. We do not intend users to impl
21/// this trait.
22///
23/// `ExtensionOpFn` callbacks are registered against a fully qualified [OpName],
24/// i.e. including it's [ExtensionId]. Callbacks can assume that the provided
25/// [EmitOpArgs::node] holds an op matching that fully qualified name, and that
26/// the signature of that op determinies the length and types of
27/// [EmitOpArgs::inputs], and [EmitOpArgs::outputs] via
28/// [EmitFuncContext::llvm_type].
29///
30/// Callbacks should use the supplied [EmitFuncContext] to emit LLVM to match
31/// the desired semantics of the op. If a callback returns success then the callback must:
32///  - ensure that [crate::emit::func::RowPromise::finish] has been called on the outputs.
33///  - ensure that the contexts [inkwell::builder::Builder] is positioned at the end of a basic
34///    block, logically after the execution of the just-emitted op.
35///
36/// Callbacks may hold references with lifetimes older than `'a`.
37pub trait ExtensionOpFn<'a, H>:
38    for<'c> Fn(&mut EmitFuncContext<'c, 'a, H>, EmitOpArgs<'c, '_, ExtensionOp, H>) -> Result<()> + 'a
39{
40}
41
42impl<
43        'a,
44        H,
45        F: for<'c> Fn(
46                &mut EmitFuncContext<'c, 'a, H>,
47                EmitOpArgs<'c, '_, ExtensionOp, H>,
48            ) -> Result<()>
49            + ?Sized
50            + 'a,
51    > ExtensionOpFn<'a, H> for F
52{
53}
54
55/// A collection of [ExtensionOpFn] callbacks keyed the fully qualified [OpName].
56///
57/// Those callbacks may hold references with lifetimes older than `'a`.
58#[derive(Default)]
59pub struct ExtensionOpMap<'a, H>(HashMap<(ExtensionId, OpName), Box<dyn ExtensionOpFn<'a, H>>>);
60
61impl<'a, H: HugrView<Node = Node>> ExtensionOpMap<'a, H> {
62    /// Register a callback to emit a [ExtensionOp], keyed by fully
63    /// qualified [OpName].
64    pub fn extension_op(
65        &mut self,
66        extension: ExtensionId,
67        op: OpName,
68        handler: impl ExtensionOpFn<'a, H>,
69    ) {
70        self.0.insert((extension, op), Box::new(handler));
71    }
72
73    /// Register callbacks to emit [ExtensionOp]s that match the
74    /// definitions generated by `Op`s impl of [strum::IntoEnumIterator]>
75    pub fn simple_extension_op<Op: MakeOpDef + IntoEnumIterator>(
76        &mut self,
77        handler: impl 'a
78            + for<'c> Fn(
79                &mut EmitFuncContext<'c, 'a, H>,
80                EmitOpArgs<'c, '_, ExtensionOp, H>,
81                Op,
82            ) -> Result<()>,
83    ) {
84        let handler = Rc::new(handler);
85        for op in Op::iter() {
86            let handler = handler.clone();
87            self.extension_op(op.extension(), op.name().clone(), move |context, args| {
88                let op = Op::from_extension_op(&args.node())?;
89                handler(context, args, op)
90            });
91        }
92    }
93
94    /// Emit an [ExtensionOp]  by delegating to the collected callbacks.
95    ///
96    /// If no handler is registered for the op an error will be returned.
97    pub fn emit_extension_op<'c>(
98        &self,
99        context: &mut EmitFuncContext<'c, 'a, H>,
100        args: EmitOpArgs<'c, '_, ExtensionOp, H>,
101    ) -> Result<()> {
102        let node = args.node();
103        let key = (node.def().extension_id().clone(), node.def().name().clone());
104        let Some(handler) = self.0.get(&key) else {
105            bail!("No extension could emit extension op: {key:?}")
106        };
107        (handler.as_ref())(context, args)
108    }
109}