hugr_llvm/
emit.rs

1use anyhow::{Result, anyhow};
2use delegate::delegate;
3use hugr_core::{
4    HugrView, Node, Visibility,
5    ops::{FuncDecl, FuncDefn, OpType},
6    types::PolyFuncType,
7};
8use inkwell::{
9    builder::Builder,
10    context::Context,
11    intrinsics::Intrinsic,
12    module::{Linkage, Module},
13    types::{AnyType, BasicType, BasicTypeEnum, FunctionType},
14    values::{BasicValueEnum, CallSiteValue, FunctionValue, GlobalValue},
15};
16use std::{collections::HashSet, rc::Rc};
17
18use crate::types::{HugrFuncType, HugrSumType, HugrType, TypingSession};
19
20use crate::{custom::CodegenExtsMap, types::LLVMSumType, utils::fat::FatNode};
21
22pub mod args;
23pub mod func;
24pub mod libc;
25pub mod namer;
26pub mod ops;
27
28pub use args::EmitOpArgs;
29pub use func::{EmitFuncContext, RowPromise};
30pub use namer::Namer;
31pub use ops::emit_value;
32
33/// A context holding data required for emitting HUGRs into an LLVM module.
34/// This includes the module itself, a set of extensions for lowering custom
35/// elements, and policy for naming various HUGR elements.
36///
37/// `'c` names the lifetime of the LLVM context, while `'a` names the lifetime
38/// of other internal references.
39pub struct EmitModuleContext<'c, 'a, H>
40where
41    'a: 'c,
42{
43    iw_context: &'c Context,
44    module: Module<'c>,
45    extensions: Rc<CodegenExtsMap<'a, H>>,
46    namer: Rc<Namer>,
47}
48
49impl<'c, 'a, H> EmitModuleContext<'c, 'a, H> {
50    delegate! {
51        to self.typing_session() {
52            /// Convert a [HugrType] into an LLVM [Type](BasicTypeEnum).
53            pub fn llvm_type(&self, hugr_type: &HugrType) -> Result<BasicTypeEnum<'c>>;
54            /// Convert a [HugrFuncType] into an LLVM [FunctionType].
55            pub fn llvm_func_type(&self, hugr_type: &HugrFuncType) -> Result<FunctionType<'c>>;
56            /// Convert a hugr [HugrSumType] into an LLVM [LLVMSumType].
57            pub fn llvm_sum_type(&self, sum_type: HugrSumType) -> Result<LLVMSumType<'c>>;
58        }
59
60        to self.namer {
61            /// Mangle the name of a [FuncDefn]  or a [FuncDecl].
62            pub fn name_func(&self, name: impl AsRef<str>, node: Node) -> String;
63        }
64    }
65
66    pub fn iw_context(&self) -> &'c Context {
67        self.iw_context
68    }
69
70    /// Creates a new  `EmitModuleContext`. We take ownership of the [Module],
71    /// and return it in [`EmitModuleContext::finish`].
72    pub fn new(
73        iw_context: &'c Context,
74        module: Module<'c>,
75        namer: Rc<Namer>,
76        extensions: Rc<CodegenExtsMap<'a, H>>,
77    ) -> Self {
78        Self {
79            iw_context,
80            module,
81            extensions,
82            namer,
83        }
84    }
85
86    /// Returns a reference to the inner [Module]. Note that this type has
87    /// "interior mutability", and this reference can be used to add functions
88    /// and globals to the [Module].
89    pub fn module(&self) -> &Module<'c> {
90        &self.module
91    }
92
93    /// Returns a reference to the inner [`CodegenExtsMap`].
94    pub fn extensions(&self) -> Rc<CodegenExtsMap<'a, H>> {
95        self.extensions.clone()
96    }
97
98    /// Returns a [`TypingSession`] constructed from it's members.
99    pub fn typing_session(&self) -> TypingSession<'c, 'a> {
100        self.extensions
101            .type_converter
102            .clone()
103            .session(self.iw_context)
104    }
105
106    fn get_func_impl(
107        &self,
108        name: impl AsRef<str>,
109        func_ty: FunctionType<'c>,
110        linkage: Option<Linkage>,
111    ) -> Result<FunctionValue<'c>> {
112        let func = self
113            .module()
114            .get_function(name.as_ref())
115            .unwrap_or_else(|| self.module.add_function(name.as_ref(), func_ty, linkage));
116        if func.get_type() != func_ty {
117            Err(anyhow!(
118                "Function '{}' has wrong type: expected: {func_ty} actual: {}",
119                name.as_ref(),
120                func.get_type()
121            ))?;
122        }
123        Ok(func)
124    }
125
126    fn get_hugr_func_impl(
127        &self,
128        name: impl AsRef<str>,
129        node: Node,
130        func_ty: &PolyFuncType,
131        visibility: &Visibility,
132    ) -> Result<FunctionValue<'c>> {
133        let func_ty = (func_ty.params().is_empty())
134            .then_some(func_ty.body())
135            .ok_or(anyhow!("function has type params"))?;
136        let llvm_func_ty = self.llvm_func_type(func_ty)?;
137        let name = self.name_func(name, node);
138        match visibility {
139            Visibility::Public => self.get_func_impl(name, llvm_func_ty, Some(Linkage::External)),
140            Visibility::Private => self.get_func_impl(name, llvm_func_ty, Some(Linkage::Private)),
141            _ => self.get_func_impl(name, llvm_func_ty, None),
142        }
143    }
144
145    /// Adds or gets the [`FunctionValue`] in the [Module] corresponding to the given [`FuncDefn`].
146    ///
147    /// The name of the result is mangled by [`EmitModuleContext::name_func`].
148    pub fn get_func_defn<'hugr>(
149        &self,
150        node: FatNode<'hugr, FuncDefn, H>,
151    ) -> Result<FunctionValue<'c>>
152    where
153        H: HugrView<Node = Node>,
154    {
155        self.get_hugr_func_impl(
156            node.func_name(),
157            node.node(),
158            node.signature(),
159            node.visibility(),
160        )
161    }
162
163    /// Adds or gets the [`FunctionValue`] in the [Module] corresponding to the given [`FuncDecl`].
164    ///
165    /// The name of the result is mangled by [`EmitModuleContext::name_func`].
166    pub fn get_func_decl<'hugr>(
167        &self,
168        node: FatNode<'hugr, FuncDecl, H>,
169    ) -> Result<FunctionValue<'c>>
170    where
171        H: HugrView<Node = Node>,
172    {
173        self.get_hugr_func_impl(
174            node.func_name(),
175            node.node(),
176            node.signature(),
177            node.visibility(),
178        )
179    }
180
181    /// Adds or get the [`FunctionValue`] in the [Module] with the given symbol
182    /// and function type.
183    ///
184    /// The name undergoes no mangling. The [`FunctionValue`] will have
185    /// [`Linkage::External`].
186    ///
187    /// If this function is called multiple times with the same arguments it
188    /// will return the same [`FunctionValue`].
189    ///
190    /// If a function with the given name exists but the type does not match
191    /// then an Error is returned.
192    pub fn get_extern_func(
193        &self,
194        symbol: impl AsRef<str>,
195        typ: FunctionType<'c>,
196    ) -> Result<FunctionValue<'c>> {
197        self.get_func_impl(symbol, typ, Some(Linkage::External))
198    }
199
200    /// Adds or gets the [`GlobalValue`] in the [Module] corresponding to the
201    /// given symbol and LLVM type.
202    ///
203    /// The name will not be mangled.
204    ///
205    /// If a global with the given name exists but the type or constant-ness
206    /// does not match then an error will be returned.
207    pub fn get_global(
208        &self,
209        symbol: impl AsRef<str>,
210        typ: impl BasicType<'c>,
211        constant: bool,
212    ) -> Result<GlobalValue<'c>> {
213        let symbol = symbol.as_ref();
214        let typ = typ.as_basic_type_enum();
215        if let Some(global) = self.module().get_global(symbol) {
216            let global_type = {
217                // TODO This is exposed as `get_value_type` on the master branch
218                // of inkwell, will be in the next release. When it's released
219                // use `get_value_type`.
220                use inkwell::types::AnyTypeEnum;
221                use inkwell::values::AsValueRef;
222                unsafe {
223                    AnyTypeEnum::new(inkwell::llvm_sys::core::LLVMGlobalGetValueType(
224                        global.as_value_ref(),
225                    ))
226                }
227            };
228            if global_type != typ.as_any_type_enum() {
229                Err(anyhow!(
230                    "Global '{symbol}' has wrong type: expected: {typ} actual: {global_type}"
231                ))?;
232            }
233            if global.is_constant() != constant {
234                Err(anyhow!(
235                    "Global '{symbol}' has wrong constant-ness: expected: {constant} actual: {}",
236                    global.is_constant()
237                ))?;
238            }
239            Ok(global)
240        } else {
241            let global = self.module().add_global(typ, None, symbol.as_ref());
242            global.set_constant(constant);
243            Ok(global)
244        }
245    }
246
247    /// Consumes the `EmitModuleContext` and returns the internal [Module].
248    pub fn finish(self) -> Module<'c> {
249        self.module
250    }
251}
252
253type EmissionSet = HashSet<Node>;
254
255/// Emits [`HugrView`]s into an LLVM [Module].
256pub struct EmitHugr<'c, 'a, H>
257where
258    'a: 'c,
259{
260    emitted: EmissionSet,
261    module_context: EmitModuleContext<'c, 'a, H>,
262}
263
264impl<'c, 'a, H: HugrView<Node = Node>> EmitHugr<'c, 'a, H> {
265    delegate! {
266        to self.module_context {
267            /// Returns a reference to the inner [Context].
268            pub fn iw_context(&self) -> &'c Context;
269            /// Returns a reference to the inner [Module]. Note that this type has
270            /// "interior mutability", and this reference can be used to add functions
271            /// and globals to the [Module].
272            pub fn module(&self) -> &Module<'c>;
273        }
274    }
275
276    /// Creates a new  `EmitHugr`. We take ownership of the [Module], and return it in [`Self::finish`].
277    pub fn new(
278        iw_context: &'c Context,
279        module: Module<'c>,
280        namer: Rc<Namer>,
281        extensions: Rc<CodegenExtsMap<'a, H>>,
282    ) -> Self {
283        assert_eq!(iw_context, &module.get_context());
284        Self {
285            emitted: Default::default(),
286            module_context: EmitModuleContext::new(iw_context, module, namer, extensions),
287        }
288    }
289
290    /// Emits a `FuncDefn` into the inner [Module].
291    ///
292    /// `node` need not be a child of a hugr [Module](hugr_core::ops::Module), but it will
293    /// be emitted as a top-level function in the inner [Module]. Indeed, there
294    /// are only top-level functions in LLVM IR.
295    ///
296    /// Any child [`FuncDefn`] will also be emitted.
297    ///
298    /// It is safe to emit the same node multiple times: the second and further
299    /// emissions will be no-ops.
300    ///
301    /// If any LLVM IR declaration which is to be emitted already exists in the
302    /// [Module] and it differs from what would be emitted, then we fail.
303    pub fn emit_func(mut self, node: FatNode<'_, FuncDefn, H>) -> Result<Self> {
304        let mut worklist: EmissionSet = [node.node()].into_iter().collect();
305        let pop = |wl: &mut EmissionSet| wl.iter().next().copied().map(|x| wl.take(&x).unwrap());
306
307        while let Some(next_node) = pop(&mut worklist) {
308            use crate::utils::fat::FatExt as _;
309            let Some(func) = node.hugr().try_fat(next_node) else {
310                panic!(
311                    "emit_func: node in worklist was not a FuncDefn: {:?}",
312                    node.hugr().get_optype(next_node)
313                )
314            };
315            let (new_self, new_tasks) = self.emit_func_impl(func)?;
316            self = new_self;
317            worklist.extend(new_tasks.into_iter());
318        }
319        Ok(self)
320    }
321
322    /// Emits all children of a hugr [Module](hugr_core::ops::Module).
323    ///
324    /// Note that type aliases are not supported, and that [`hugr_core::ops::Const`]
325    /// and [`hugr_core::ops::FuncDecl`] nodes are not emitted directly, but instead by
326    /// emission of ops with static edges from them. So [`FuncDefn`] are the only
327    /// interesting children.
328    pub fn emit_module(mut self, node: FatNode<'_, hugr_core::ops::Module, H>) -> Result<Self> {
329        for c in node.children() {
330            match c.as_ref() {
331                OpType::FuncDefn(fd) => {
332                    let fat_ot = c.into_ot(fd);
333                    self = self.emit_func(fat_ot)?;
334                }
335                // FuncDecls are allowed, but we don't need to do anything here.
336                OpType::FuncDecl(_) => (),
337                // Consts are allowed, but we don't need to do anything here.
338                OpType::Const(_) => (),
339                _ => Err(anyhow!("Module has invalid child: {c}"))?,
340            }
341        }
342        Ok(self)
343    }
344
345    fn emit_func_impl(mut self, node: FatNode<'_, FuncDefn, H>) -> Result<(Self, EmissionSet)> {
346        if !self.emitted.insert(node.node()) {
347            return Ok((self, EmissionSet::default()));
348        }
349        let func = self.module_context.get_func_defn(node)?;
350        let mut func_ctx = EmitFuncContext::new(self.module_context, func)?;
351        let ret_rmb = func_ctx.new_row_mail_box(node.signature().body().output.iter(), "ret")?;
352        ops::emit_dataflow_parent(
353            &mut func_ctx,
354            EmitOpArgs {
355                node,
356                inputs: func.get_params(),
357                outputs: ret_rmb.promise(),
358            },
359        )?;
360        let builder = func_ctx.builder();
361        match &ret_rmb.read::<Vec<_>>(builder, [])?[..] {
362            [] => builder.build_return(None)?,
363            [x] => builder.build_return(Some(x))?,
364            xs => builder.build_aggregate_return(xs)?,
365        };
366        let (mctx, todos) = func_ctx.finish()?;
367        self.module_context = mctx;
368        Ok((self, todos))
369    }
370
371    /// Consumes the `EmitHugr` and returns the internal [Module].
372    pub fn finish(self) -> Module<'c> {
373        self.module_context.finish()
374    }
375}
376
377/// Extract all return values from the result of a `call`.
378///
379/// LLVM only supports functions with exactly zero or one return value.
380/// For functions with multiple return values, we return a struct containing
381/// all the return values.
382///
383/// `inkwell` provides a helper [`Builder::build_aggregate_return`] to construct
384/// the return value, see `EmitHugr::emit_func_impl`. This function performs the
385/// inverse.
386pub fn deaggregate_call_result<'c>(
387    builder: &Builder<'c>,
388    call_result: CallSiteValue<'c>,
389    num_results: usize,
390) -> Result<Vec<BasicValueEnum<'c>>> {
391    let call_result = call_result.try_as_basic_value();
392    Ok(match num_results as u32 {
393        0 => {
394            call_result.expect_right("void");
395            vec![]
396        }
397        1 => vec![call_result.expect_left("non-void")],
398        n => {
399            let return_struct = call_result.expect_left("non-void").into_struct_value();
400            (0..n)
401                .map(|i| builder.build_extract_value(return_struct, i, ""))
402                .collect::<Result<Vec<_>, _>>()?
403        }
404    })
405}
406
407pub fn get_intrinsic<'c>(
408    module: &Module<'c>,
409    name: impl AsRef<str>,
410    args: impl AsRef<[BasicTypeEnum<'c>]>,
411) -> Result<FunctionValue<'c>> {
412    let (name, args) = (name.as_ref(), args.as_ref());
413    let intrinsic = Intrinsic::find(name).ok_or(anyhow!("Failed to find intrinsic: '{name}'"))?;
414    intrinsic
415        .get_declaration(module, args.as_ref())
416        .ok_or(anyhow!(
417            "failed to get_declaration for intrinsic '{name}' with args '{args:?}'"
418        ))
419}
420
421#[cfg(any(test, feature = "test-utils"))]
422pub mod test;