hugr_llvm/emit/
func.rs

1use std::{collections::BTreeMap, rc::Rc};
2
3use anyhow::{Result, anyhow};
4use hugr_core::{
5    HugrView, Node, NodeIndex, PortIndex, Wire,
6    extension::prelude::{either_type, option_type},
7    ops::{ExtensionOp, FuncDecl, FuncDefn, constant::CustomConst},
8    types::Type,
9};
10use inkwell::{
11    basic_block::BasicBlock,
12    builder::Builder,
13    context::Context,
14    module::Module,
15    types::{BasicType, BasicTypeEnum, FunctionType},
16    values::{BasicValueEnum, FunctionValue, GlobalValue, IntValue},
17};
18use itertools::zip_eq;
19
20use crate::types::{HugrFuncType, HugrSumType, HugrType, TypingSession};
21use crate::{custom::CodegenExtsMap, types::LLVMSumType, utils::fat::FatNode};
22use delegate::delegate;
23
24use self::mailbox::ValueMailBox;
25
26use super::{EmissionSet, EmitModuleContext, EmitOpArgs};
27
28mod mailbox;
29pub use mailbox::{RowMailBox, RowPromise};
30
/// A context for emitting an LLVM function.
///
/// One of the primary interfaces for implementing codegen extensions.
/// We have methods for:
///   * Converting from hugr [Type]s to LLVM [Type](BasicTypeEnum)s;
///   * Maintaining [`MailBox`](RowMailBox) for each [Wire] in the [`FuncDefn`];
///   * Accessing the [`CodegenExtsMap`];
///   * Accessing an internal [Builder].
///
/// The internal [Builder] must always be positioned at the end of a
/// [`BasicBlock`]. This invariant is not checked when the builder is accessed
/// through [`EmitFuncContext::builder`].
///
/// [`MailBox`](RowMailBox)es are stack allocations that are `alloca`ed in the
/// first basic block of the function, read from to get the input values of each
/// node, and written to with the output values of each node.
pub struct EmitFuncContext<'c, 'a, H>
where
    'a: 'c,
{
    // NOTE: field declaration order is also drop order; keep it stable.
    /// The module-level context this function context was created from; it is
    /// handed back to the caller by `finish`.
    emit_context: EmitModuleContext<'c, 'a, H>,
    /// Nodes of scoped function definitions encountered while emitting
    /// (see `push_todo_func`); returned by `finish`.
    todo: EmissionSet,
    /// The LLVM function whose body is being emitted.
    func: FunctionValue<'c>,
    /// Mailbox (stack slot) allocated for each hugr [Wire] seen so far.
    env: BTreeMap<Wire, ValueMailBox<'c>>,
    /// The internal builder; see the positioning invariant above.
    builder: Builder<'c>,
    /// Block created first in `new` ("alloca_block"). Mailbox `alloca`s are
    /// emitted here, and `finish` terminates it with a branch to `launch_bb`.
    prologue_bb: BasicBlock<'c>,
    /// Block created second in `new` ("entry_block"); the builder starts
    /// positioned at its end.
    launch_bb: BasicBlock<'c>,
}
59
60impl<'c, 'a, H: HugrView<Node = Node>> EmitFuncContext<'c, 'a, H> {
61    delegate! {
62        to self.emit_context {
63            /// Returns the inkwell [Context].
64            pub fn iw_context(&self) ->  &'c Context;
65            /// Returns the internal [CodegenExtsMap] .
66            pub fn extensions(&self) ->  Rc<CodegenExtsMap<'a,H>>;
67            /// Returns a new [TypingSession].
68            pub fn typing_session(&self) -> TypingSession<'c, 'a>;
69            /// Convert hugr [HugrType] into an LLVM [Type](BasicTypeEnum).
70            pub fn llvm_type(&self, hugr_type: &HugrType) -> Result<BasicTypeEnum<'c> >;
71            /// Convert a [HugrFuncType] into an LLVM [FunctionType].
72            pub fn llvm_func_type(&self, hugr_type: &HugrFuncType) -> Result<FunctionType<'c> >;
73            /// Convert a hugr [HugrSumType] into an LLVM [LLVMSumType].
74            pub fn llvm_sum_type(&self, sum_type: HugrSumType) -> Result<LLVMSumType<'c>>;
75            /// Adds or gets the [FunctionValue] in the [inkwell::module::Module] corresponding to the given [FuncDefn].
76            ///
77            /// The name of the result may have been mangled.
78            pub fn get_func_defn(&self, node: FatNode<FuncDefn, H>) -> Result<FunctionValue<'c>>;
79            /// Adds or gets the [FunctionValue] in the [inkwell::module::Module] corresponding to the given [FuncDecl].
80            ///
81            /// The name of the result may have been mangled.
82            pub fn get_func_decl(&self, node: FatNode<FuncDecl, H>) -> Result<FunctionValue<'c>>;
83            /// Adds or get the [FunctionValue] in the [inkwell::module::Module] with the given symbol
84            /// and function type.
85            ///
86            /// The name undergoes no mangling. The [FunctionValue] will have
87            /// [inkwell::module::Linkage::External].
88            ///
89            /// If this function is called multiple times with the same arguments it
90            /// will return the same [FunctionValue].
91            ///
92            /// If a function with the given name exists but the type does not match
93            /// then an Error is returned.
94            pub fn get_extern_func(&self, symbol: impl AsRef<str>, typ: FunctionType<'c>,) -> Result<FunctionValue<'c>>;
95            /// Adds or gets the [GlobalValue] in the [inkwell::module::Module] corresponding to the
96            /// given symbol and LLVM type.
97            ///
98            /// The name will not be mangled.
99            ///
100            /// If a global with the given name exists but the type or constant-ness
101            /// does not match then an error will be returned.
102            pub fn get_global(&self, symbol: impl AsRef<str>, typ: impl BasicType<'c>, constant: bool) -> Result<GlobalValue<'c>>;
103        }
104    }
105
106    /// Used when emitters encounter a scoped definition. `node` will be
107    /// returned from [`EmitFuncContext::finish`].
108    pub fn push_todo_func(&mut self, node: FatNode<'_, FuncDefn, H>) {
109        self.todo.insert(node.node());
110    }
111
112    /// Returns the current [FunctionValue] being emitted.
113    pub fn func(&self) -> FunctionValue<'c> {
114        self.func
115    }
116
117    /// Returns the internal [Builder]. Callers must ensure that it is
118    /// positioned at the end of a basic block. This invariant is not checked(it
119    /// doesn't seem possible to check it).
120    pub fn builder(&self) -> &Builder<'c> {
121        &self.builder
122    }
123
124    /// Create a new basic block. When `before` is `Some` the block will be
125    /// created immediately before that block, otherwise at the end of the
126    /// function.
127    pub(crate) fn new_basic_block(
128        &mut self,
129        name: impl AsRef<str>,
130        before: Option<BasicBlock<'c>>,
131    ) -> BasicBlock<'c> {
132        if let Some(before) = before {
133            self.iw_context().prepend_basic_block(before, name.as_ref())
134        } else {
135            self.iw_context()
136                .append_basic_block(self.func, name.as_ref())
137        }
138    }
139
140    fn prologue_block(&self) -> BasicBlock<'c> {
141        // guaranteed to exist because we create it in new
142        self.func.get_first_basic_block().unwrap()
143    }
144
145    /// Creates a new `EmitFuncContext` for `func`, taking ownership of
146    /// `emit_context`. `emit_context` will be returned in
147    /// [`EmitFuncContext::finish`].
148    ///
149    /// If `func` has any existing [`BasicBlock`]s we will fail.
150    ///
151    /// TODO on failure return `emit_context`
152    pub fn new(
153        emit_context: EmitModuleContext<'c, 'a, H>,
154        func: FunctionValue<'c>,
155    ) -> Result<EmitFuncContext<'c, 'a, H>> {
156        if func.get_first_basic_block().is_some() {
157            Err(anyhow!(
158                "EmitContext::new: Function already has a basic block: {:?}",
159                func.get_name()
160            ))?;
161        }
162        let prologue_bb = emit_context
163            .iw_context()
164            .append_basic_block(func, "alloca_block");
165        let launch_bb = emit_context
166            .iw_context()
167            .append_basic_block(func, "entry_block");
168        let builder = emit_context.iw_context().create_builder();
169        builder.position_at_end(launch_bb);
170        Ok(Self {
171            emit_context,
172            todo: Default::default(),
173            func,
174            env: Default::default(),
175            builder,
176            prologue_bb,
177            launch_bb,
178        })
179    }
180
181    fn new_value_mail_box(&mut self, t: &Type, name: impl AsRef<str>) -> Result<ValueMailBox<'c>> {
182        let bte = self.llvm_type(t)?;
183        let ptr = self.build_prologue(|builder| builder.build_alloca(bte, name.as_ref()))?;
184        Ok(ValueMailBox::new(bte, ptr, Some(name.as_ref().into())))
185    }
186
187    /// Create a new anonymous [`RowMailBox`]. This mailbox is not mapped to any
188    /// [Wire]s, and so will not interact with any mailboxes returned from
189    /// [`EmitFuncContext::node_ins_rmb`] or [`EmitFuncContext::node_outs_rmb`].
190    pub fn new_row_mail_box<'t>(
191        &mut self,
192        ts: impl IntoIterator<Item = &'t Type>,
193        name: impl AsRef<str>,
194    ) -> Result<RowMailBox<'c>> {
195        Ok(RowMailBox::new(
196            ts.into_iter()
197                .enumerate()
198                .map(|(i, t)| self.new_value_mail_box(t, format!("{i}")))
199                .collect::<Result<Vec<_>>>()?,
200            Some(name.as_ref().into()),
201        ))
202    }
203
204    fn build_prologue<T>(&mut self, f: impl FnOnce(&Builder<'c>) -> T) -> T {
205        let b = self.prologue_block();
206        self.build_positioned(b, |x| f(&x.builder))
207    }
208
209    /// Creates a new [`BasicBlock`] and calls `f` with the internal builder
210    /// positioned at the start of that [`BasicBlock`]. The builder will be
211    /// repositioned back to it's location before `f` before this function
212    /// returns.
213    pub fn build_positioned_new_block<T>(
214        &mut self,
215        name: impl AsRef<str>,
216        before: Option<BasicBlock<'c>>,
217        f: impl FnOnce(&mut Self, BasicBlock<'c>) -> T,
218    ) -> T {
219        let bb = self.new_basic_block(name, before);
220        self.build_positioned(bb, |s| f(s, bb))
221    }
222
223    /// Positions the internal builder at the end of `block` and calls `f`.  The
224    /// builder will be repositioned back to it's location before `f` before
225    /// this function returns.
226    pub fn build_positioned<T>(
227        &mut self,
228        block: BasicBlock<'c>,
229        f: impl FnOnce(&mut Self) -> T,
230    ) -> T {
231        // safe because our builder is always positioned
232        let current = self.builder.get_insert_block().unwrap();
233        self.builder.position_at_end(block);
234        let r = f(self);
235        self.builder.position_at_end(current);
236        r
237    }
238
239    /// Returns a [`RowMailBox`] mapped to the input wires of `node`. When emitting a node
240    /// input values are from this mailbox.
241    pub fn node_ins_rmb<'hugr, OT: 'hugr>(
242        &mut self,
243        node: FatNode<'hugr, OT, H>,
244    ) -> Result<RowMailBox<'c>> {
245        let r = node
246            .in_value_types()
247            .map(|(p, t)| {
248                let (slo_n, slo_p) = node
249                    .single_linked_output(p)
250                    .ok_or(anyhow!("No single linked output"))?;
251                self.map_wire(slo_n, slo_p, &t)
252            })
253            .collect::<Result<RowMailBox>>()?;
254
255        debug_assert!(
256            zip_eq(node.in_value_types(), r.get_types())
257                .all(|((_, t), lt)| self.llvm_type(&t).unwrap() == lt)
258        );
259        Ok(r)
260    }
261
262    /// Returns a [`RowMailBox`] mapped to the output wires of `node`. When emitting a node
263    /// output values are written to this mailbox.
264    pub fn node_outs_rmb<'hugr, OT: 'hugr>(
265        &mut self,
266        node: FatNode<'hugr, OT, H>,
267    ) -> Result<RowMailBox<'c>> {
268        let r = node
269            .out_value_types()
270            .map(|(port, hugr_type)| self.map_wire(node, port, &hugr_type))
271            .collect::<Result<RowMailBox>>()?;
272        debug_assert!(
273            zip_eq(node.out_value_types(), r.get_types())
274                .all(|((_, t), lt)| self.llvm_type(&t).unwrap() == lt)
275        );
276        Ok(r)
277    }
278
279    fn map_wire<'hugr, OT>(
280        &mut self,
281        node: FatNode<'hugr, OT, H>,
282        port: hugr_core::OutgoingPort,
283        hugr_type: &Type,
284    ) -> Result<ValueMailBox<'c>> {
285        let wire = Wire::new(node.node(), port);
286        if let Some(mb) = self.env.get(&wire) {
287            debug_assert_eq!(self.llvm_type(hugr_type).unwrap(), mb.get_type());
288            return Ok(mb.clone());
289        }
290        let mb = self.new_value_mail_box(
291            hugr_type,
292            format!("{}_{}", node.node().index(), port.index()),
293        )?;
294        self.env.insert(wire, mb.clone());
295        Ok(mb)
296    }
297
298    pub fn get_current_module(&self) -> &Module<'c> {
299        self.emit_context.module()
300    }
301
302    pub(crate) fn emit_custom_const(&mut self, v: &dyn CustomConst) -> Result<BasicValueEnum<'c>> {
303        let exts = self.extensions();
304        exts.as_ref()
305            .load_constant_handlers
306            .emit_load_constant(self, v)
307    }
308
309    pub(crate) fn emit_extension_op(
310        &mut self,
311        args: EmitOpArgs<'c, '_, ExtensionOp, H>,
312    ) -> Result<()> {
313        let exts = self.extensions();
314        exts.as_ref()
315            .extension_op_handlers
316            .emit_extension_op(self, args)
317    }
318
319    /// Consumes the `EmitFuncContext` and returns both the inner
320    /// [`EmitModuleContext`] and the scoped [`FuncDefn`]s that were encountered.
321    pub fn finish(self) -> Result<(EmitModuleContext<'c, 'a, H>, EmissionSet)> {
322        self.builder.position_at_end(self.prologue_bb);
323        self.builder.build_unconditional_branch(self.launch_bb)?;
324        Ok((self.emit_context, self.todo))
325    }
326}
327
328/// Builds an optional value wrapping `some_value` conditioned on the provided `is_some` flag.
329pub fn build_option<'c, H: HugrView<Node = Node>>(
330    ctx: &mut EmitFuncContext<'c, '_, H>,
331    is_some: IntValue<'c>,
332    some_value: BasicValueEnum<'c>,
333    hugr_ty: HugrType,
334) -> Result<BasicValueEnum<'c>> {
335    let option_ty = ctx.llvm_sum_type(option_type(hugr_ty))?;
336    let builder = ctx.builder();
337    let some = option_ty.build_tag(builder, 1, vec![some_value])?;
338    let none = option_ty.build_tag(builder, 0, vec![])?;
339    let option = builder.build_select(is_some, some, none, "")?;
340    Ok(option)
341}
342
343/// Builds a result value wrapping either `ok_value` or `else_value` depending on the provided
344/// `is_ok` flag.
345pub fn build_ok_or_else<'c, H: HugrView<Node = Node>>(
346    ctx: &mut EmitFuncContext<'c, '_, H>,
347    is_ok: IntValue<'c>,
348    ok_value: BasicValueEnum<'c>,
349    ok_hugr_ty: HugrType,
350    else_value: BasicValueEnum<'c>,
351    else_hugr_ty: HugrType,
352) -> Result<BasicValueEnum<'c>> {
353    let either_ty = ctx.llvm_sum_type(either_type(else_hugr_ty, ok_hugr_ty))?;
354    let builder = ctx.builder();
355    let left = either_ty.build_tag(builder, 0, vec![else_value])?;
356    let right = either_ty.build_tag(builder, 1, vec![ok_value])?;
357    let either = builder.build_select(is_ok, right, left, "")?;
358    Ok(either)
359}
360
#[cfg(test)]
mod tests {
    use super::EmitFuncContext;

    /// `EmitFuncContext::func` must return the function the context was created for.
    #[test]
    fn test_func_getter() {
        // Use TestContext for consistent test setup
        let test_ctx = crate::test::test_ctx(-1);
        let module_ctx = test_ctx.get_emit_module_context();
        let void_fn_ty = module_ctx.iw_context().void_type().fn_type(&[], false);
        let expected = module_ctx
            .module()
            .add_function("test_func", void_fn_ty, None);

        let func_ctx = EmitFuncContext::new(module_ctx, expected).unwrap();
        assert_eq!(func_ctx.func(), expected);
    }
}