hugr_llvm/emit/
func.rs

1use std::{collections::BTreeMap, rc::Rc};
2
3use anyhow::{Result, anyhow};
4use hugr_core::{
5    HugrView, Node, NodeIndex, PortIndex, Wire,
6    extension::prelude::{either_type, option_type},
7    ops::{ExtensionOp, FuncDecl, FuncDefn, constant::CustomConst},
8    types::Type,
9};
10use inkwell::{
11    basic_block::BasicBlock,
12    builder::Builder,
13    context::Context,
14    module::Module,
15    types::{BasicType, BasicTypeEnum, FunctionType},
16    values::{BasicValueEnum, FunctionValue, GlobalValue, IntValue},
17};
18use itertools::zip_eq;
19
20use crate::types::{HugrFuncType, HugrSumType, HugrType, TypingSession};
21use crate::{custom::CodegenExtsMap, types::LLVMSumType, utils::fat::FatNode};
22use delegate::delegate;
23
24use self::mailbox::ValueMailBox;
25
26use super::{EmissionSet, EmitModuleContext, EmitOpArgs};
27
28mod mailbox;
29pub use mailbox::{RowMailBox, RowPromise};
30
31/// A context for emitting an LLVM function.
32///
33/// One of the primary interfaces for implementing codegen extensions.
34/// We have methods for:
35///   * Converting from hugr [Type]s to LLVM [Type](BasicTypeEnum)s;
36///   * Maintaining [`MailBox`](RowMailBox) for each [Wire] in the [`FuncDefn`];
37///   * Accessing the [`CodegenExtsMap`];
38///   * Accessing an in internal [Builder].
39///
40/// The internal [Builder] must always be positioned at the end of a
41/// [`BasicBlock`]. This invariant is not checked when the builder is accessed
42/// through [`EmitFuncContext::builder`].
43///
44/// [`MailBox`](RowMailBox)es are stack allocations that are `alloca`ed in the
45/// first basic block of the function, read from to get the input values of each
46/// node, and written to with the output values of each node.
47pub struct EmitFuncContext<'c, 'a, H>
48where
49    'a: 'c,
50{
51    emit_context: EmitModuleContext<'c, 'a, H>,
52    todo: EmissionSet,
53    func: FunctionValue<'c>,
54    env: BTreeMap<Wire, ValueMailBox<'c>>,
55    builder: Builder<'c>,
56    prologue_bb: BasicBlock<'c>,
57    launch_bb: BasicBlock<'c>,
58}
59
60impl<'c, 'a, H: HugrView<Node = Node>> EmitFuncContext<'c, 'a, H> {
61    delegate! {
62        to self.emit_context {
63            /// Returns the inkwell [Context].
64            pub fn iw_context(&self) ->  &'c Context;
65            /// Returns the internal [CodegenExtsMap] .
66            pub fn extensions(&self) ->  Rc<CodegenExtsMap<'a,H>>;
67            /// Returns a new [TypingSession].
68            pub fn typing_session(&self) -> TypingSession<'c, 'a>;
69            /// Convert hugr [HugrType] into an LLVM [Type](BasicTypeEnum).
70            pub fn llvm_type(&self, hugr_type: &HugrType) -> Result<BasicTypeEnum<'c> >;
71            /// Convert a [HugrFuncType] into an LLVM [FunctionType].
72            pub fn llvm_func_type(&self, hugr_type: &HugrFuncType) -> Result<FunctionType<'c> >;
73            /// Convert a hugr [HugrSumType] into an LLVM [LLVMSumType].
74            pub fn llvm_sum_type(&self, sum_type: HugrSumType) -> Result<LLVMSumType<'c>>;
75            /// Adds or gets the [FunctionValue] in the [inkwell::module::Module] corresponding to the given [FuncDefn].
76            ///
77            /// The name of the result may have been mangled.
78            pub fn get_func_defn(&self, node: FatNode<FuncDefn, H>) -> Result<FunctionValue<'c>>;
79            /// Adds or gets the [FunctionValue] in the [inkwell::module::Module] corresponding to the given [FuncDecl].
80            ///
81            /// The name of the result may have been mangled.
82            pub fn get_func_decl(&self, node: FatNode<FuncDecl, H>) -> Result<FunctionValue<'c>>;
83            /// Adds or get the [FunctionValue] in the [inkwell::module::Module] with the given symbol
84            /// and function type.
85            ///
86            /// The name undergoes no mangling. The [FunctionValue] will have
87            /// [inkwell::module::Linkage::External].
88            ///
89            /// If this function is called multiple times with the same arguments it
90            /// will return the same [FunctionValue].
91            ///
92            /// If a function with the given name exists but the type does not match
93            /// then an Error is returned.
94            pub fn get_extern_func(&self, symbol: impl AsRef<str>, typ: FunctionType<'c>,) -> Result<FunctionValue<'c>>;
95            /// Adds or gets the [GlobalValue] in the [inkwell::module::Module] corresponding to the
96            /// given symbol and LLVM type.
97            ///
98            /// The name will not be mangled.
99            ///
100            /// If a global with the given name exists but the type or constant-ness
101            /// does not match then an error will be returned.
102            pub fn get_global(&self, symbol: impl AsRef<str>, typ: impl BasicType<'c>, constant: bool) -> Result<GlobalValue<'c>>;
103        }
104    }
105
106    /// Used when emitters encounter a scoped definition. `node` will be
107    /// returned from [`EmitFuncContext::finish`].
108    pub fn push_todo_func(&mut self, node: FatNode<'_, FuncDefn, H>) {
109        self.todo.insert(node.node());
110    }
111
112    /// Returns the internal [Builder]. Callers must ensure that it is
113    /// positioned at the end of a basic block. This invariant is not checked(it
114    /// doesn't seem possible to check it).
115    pub fn builder(&self) -> &Builder<'c> {
116        &self.builder
117    }
118
119    /// Create a new basic block. When `before` is `Some` the block will be
120    /// created immediately before that block, otherwise at the end of the
121    /// function.
122    pub(crate) fn new_basic_block(
123        &mut self,
124        name: impl AsRef<str>,
125        before: Option<BasicBlock<'c>>,
126    ) -> BasicBlock<'c> {
127        if let Some(before) = before {
128            self.iw_context().prepend_basic_block(before, name.as_ref())
129        } else {
130            self.iw_context()
131                .append_basic_block(self.func, name.as_ref())
132        }
133    }
134
135    fn prologue_block(&self) -> BasicBlock<'c> {
136        // guaranteed to exist because we create it in new
137        self.func.get_first_basic_block().unwrap()
138    }
139
140    /// Creates a new `EmitFuncContext` for `func`, taking ownership of
141    /// `emit_context`. `emit_context` will be returned in
142    /// [`EmitFuncContext::finish`].
143    ///
144    /// If `func` has any existing [`BasicBlock`]s we will fail.
145    ///
146    /// TODO on failure return `emit_context`
147    pub fn new(
148        emit_context: EmitModuleContext<'c, 'a, H>,
149        func: FunctionValue<'c>,
150    ) -> Result<EmitFuncContext<'c, 'a, H>> {
151        if func.get_first_basic_block().is_some() {
152            Err(anyhow!(
153                "EmitContext::new: Function already has a basic block: {:?}",
154                func.get_name()
155            ))?;
156        }
157        let prologue_bb = emit_context
158            .iw_context()
159            .append_basic_block(func, "alloca_block");
160        let launch_bb = emit_context
161            .iw_context()
162            .append_basic_block(func, "entry_block");
163        let builder = emit_context.iw_context().create_builder();
164        builder.position_at_end(launch_bb);
165        Ok(Self {
166            emit_context,
167            todo: Default::default(),
168            func,
169            env: Default::default(),
170            builder,
171            prologue_bb,
172            launch_bb,
173        })
174    }
175
176    fn new_value_mail_box(&mut self, t: &Type, name: impl AsRef<str>) -> Result<ValueMailBox<'c>> {
177        let bte = self.llvm_type(t)?;
178        let ptr = self.build_prologue(|builder| builder.build_alloca(bte, name.as_ref()))?;
179        Ok(ValueMailBox::new(bte, ptr, Some(name.as_ref().into())))
180    }
181
182    /// Create a new anonymous [`RowMailBox`]. This mailbox is not mapped to any
183    /// [Wire]s, and so will not interact with any mailboxes returned from
184    /// [`EmitFuncContext::node_ins_rmb`] or [`EmitFuncContext::node_outs_rmb`].
185    pub fn new_row_mail_box<'t>(
186        &mut self,
187        ts: impl IntoIterator<Item = &'t Type>,
188        name: impl AsRef<str>,
189    ) -> Result<RowMailBox<'c>> {
190        Ok(RowMailBox::new(
191            ts.into_iter()
192                .enumerate()
193                .map(|(i, t)| self.new_value_mail_box(t, format!("{i}")))
194                .collect::<Result<Vec<_>>>()?,
195            Some(name.as_ref().into()),
196        ))
197    }
198
199    fn build_prologue<T>(&mut self, f: impl FnOnce(&Builder<'c>) -> T) -> T {
200        let b = self.prologue_block();
201        self.build_positioned(b, |x| f(&x.builder))
202    }
203
204    /// Creates a new [`BasicBlock`] and calls `f` with the internal builder
205    /// positioned at the start of that [`BasicBlock`]. The builder will be
206    /// repositioned back to it's location before `f` before this function
207    /// returns.
208    pub fn build_positioned_new_block<T>(
209        &mut self,
210        name: impl AsRef<str>,
211        before: Option<BasicBlock<'c>>,
212        f: impl FnOnce(&mut Self, BasicBlock<'c>) -> T,
213    ) -> T {
214        let bb = self.new_basic_block(name, before);
215        self.build_positioned(bb, |s| f(s, bb))
216    }
217
218    /// Positions the internal builder at the end of `block` and calls `f`.  The
219    /// builder will be repositioned back to it's location before `f` before
220    /// this function returns.
221    pub fn build_positioned<T>(
222        &mut self,
223        block: BasicBlock<'c>,
224        f: impl FnOnce(&mut Self) -> T,
225    ) -> T {
226        // safe because our builder is always positioned
227        let current = self.builder.get_insert_block().unwrap();
228        self.builder.position_at_end(block);
229        let r = f(self);
230        self.builder.position_at_end(current);
231        r
232    }
233
234    /// Returns a [`RowMailBox`] mapped to the input wires of `node`. When emitting a node
235    /// input values are from this mailbox.
236    pub fn node_ins_rmb<'hugr, OT: 'hugr>(
237        &mut self,
238        node: FatNode<'hugr, OT, H>,
239    ) -> Result<RowMailBox<'c>> {
240        let r = node
241            .in_value_types()
242            .map(|(p, t)| {
243                let (slo_n, slo_p) = node
244                    .single_linked_output(p)
245                    .ok_or(anyhow!("No single linked output"))?;
246                self.map_wire(slo_n, slo_p, &t)
247            })
248            .collect::<Result<RowMailBox>>()?;
249
250        debug_assert!(
251            zip_eq(node.in_value_types(), r.get_types())
252                .all(|((_, t), lt)| self.llvm_type(&t).unwrap() == lt)
253        );
254        Ok(r)
255    }
256
257    /// Returns a [`RowMailBox`] mapped to the output wires of `node`. When emitting a node
258    /// output values are written to this mailbox.
259    pub fn node_outs_rmb<'hugr, OT: 'hugr>(
260        &mut self,
261        node: FatNode<'hugr, OT, H>,
262    ) -> Result<RowMailBox<'c>> {
263        let r = node
264            .out_value_types()
265            .map(|(port, hugr_type)| self.map_wire(node, port, &hugr_type))
266            .collect::<Result<RowMailBox>>()?;
267        debug_assert!(
268            zip_eq(node.out_value_types(), r.get_types())
269                .all(|((_, t), lt)| self.llvm_type(&t).unwrap() == lt)
270        );
271        Ok(r)
272    }
273
274    fn map_wire<'hugr, OT>(
275        &mut self,
276        node: FatNode<'hugr, OT, H>,
277        port: hugr_core::OutgoingPort,
278        hugr_type: &Type,
279    ) -> Result<ValueMailBox<'c>> {
280        let wire = Wire::new(node.node(), port);
281        if let Some(mb) = self.env.get(&wire) {
282            debug_assert_eq!(self.llvm_type(hugr_type).unwrap(), mb.get_type());
283            return Ok(mb.clone());
284        }
285        let mb = self.new_value_mail_box(
286            hugr_type,
287            format!("{}_{}", node.node().index(), port.index()),
288        )?;
289        self.env.insert(wire, mb.clone());
290        Ok(mb)
291    }
292
293    pub fn get_current_module(&self) -> &Module<'c> {
294        self.emit_context.module()
295    }
296
297    pub(crate) fn emit_custom_const(&mut self, v: &dyn CustomConst) -> Result<BasicValueEnum<'c>> {
298        let exts = self.extensions();
299        exts.as_ref()
300            .load_constant_handlers
301            .emit_load_constant(self, v)
302    }
303
304    pub(crate) fn emit_extension_op(
305        &mut self,
306        args: EmitOpArgs<'c, '_, ExtensionOp, H>,
307    ) -> Result<()> {
308        let exts = self.extensions();
309        exts.as_ref()
310            .extension_op_handlers
311            .emit_extension_op(self, args)
312    }
313
314    /// Consumes the `EmitFuncContext` and returns both the inner
315    /// [`EmitModuleContext`] and the scoped [`FuncDefn`]s that were encountered.
316    pub fn finish(self) -> Result<(EmitModuleContext<'c, 'a, H>, EmissionSet)> {
317        self.builder.position_at_end(self.prologue_bb);
318        self.builder.build_unconditional_branch(self.launch_bb)?;
319        Ok((self.emit_context, self.todo))
320    }
321}
322
323/// Builds an optional value wrapping `some_value` conditioned on the provided `is_some` flag.
324pub fn build_option<'c, H: HugrView<Node = Node>>(
325    ctx: &mut EmitFuncContext<'c, '_, H>,
326    is_some: IntValue<'c>,
327    some_value: BasicValueEnum<'c>,
328    hugr_ty: HugrType,
329) -> Result<BasicValueEnum<'c>> {
330    let option_ty = ctx.llvm_sum_type(option_type(hugr_ty))?;
331    let builder = ctx.builder();
332    let some = option_ty.build_tag(builder, 1, vec![some_value])?;
333    let none = option_ty.build_tag(builder, 0, vec![])?;
334    let option = builder.build_select(is_some, some, none, "")?;
335    Ok(option)
336}
337
338/// Builds a result value wrapping either `ok_value` or `else_value` depending on the provided
339/// `is_ok` flag.
340pub fn build_ok_or_else<'c, H: HugrView<Node = Node>>(
341    ctx: &mut EmitFuncContext<'c, '_, H>,
342    is_ok: IntValue<'c>,
343    ok_value: BasicValueEnum<'c>,
344    ok_hugr_ty: HugrType,
345    else_value: BasicValueEnum<'c>,
346    else_hugr_ty: HugrType,
347) -> Result<BasicValueEnum<'c>> {
348    let either_ty = ctx.llvm_sum_type(either_type(else_hugr_ty, ok_hugr_ty))?;
349    let builder = ctx.builder();
350    let left = either_ty.build_tag(builder, 0, vec![else_value])?;
351    let right = either_ty.build_tag(builder, 1, vec![ok_value])?;
352    let either = builder.build_select(is_ok, right, left, "")?;
353    Ok(either)
354}