hugr_llvm/
emit.rs

1use anyhow::{anyhow, Result};
2use delegate::delegate;
3use hugr_core::{
4    ops::{FuncDecl, FuncDefn, OpType},
5    types::PolyFuncType,
6    HugrView, Node,
7};
8use inkwell::{
9    builder::Builder,
10    context::Context,
11    intrinsics::Intrinsic,
12    module::{Linkage, Module},
13    types::{AnyType, BasicType, BasicTypeEnum, FunctionType},
14    values::{BasicValueEnum, CallSiteValue, FunctionValue, GlobalValue},
15};
16use std::{collections::HashSet, rc::Rc};
17
18use crate::types::{HugrFuncType, HugrSumType, HugrType, TypingSession};
19
20use crate::{custom::CodegenExtsMap, types::LLVMSumType, utils::fat::FatNode};
21
22pub mod args;
23pub mod func;
24pub mod libc;
25pub mod namer;
26pub mod ops;
27
28pub use args::EmitOpArgs;
29pub use func::{EmitFuncContext, RowPromise};
30pub use namer::Namer;
31pub use ops::emit_value;
32
33/// A context holding data required for emitting HUGRs into an LLVM module.
34/// This includes the module itself, a set of extensions for lowering custom
35/// elements, and policy for naming various HUGR elements.
36///
37/// `'c` names the lifetime of the LLVM context, while `'a` names the lifetime
38/// of other internal references.
39pub struct EmitModuleContext<'c, 'a, H>
40where
41    'a: 'c,
42{
43    iw_context: &'c Context,
44    module: Module<'c>,
45    extensions: Rc<CodegenExtsMap<'a, H>>,
46    namer: Rc<Namer>,
47}
48
49impl<'c, 'a, H> EmitModuleContext<'c, 'a, H> {
50    delegate! {
51        to self.typing_session() {
52            /// Convert a [HugrType] into an LLVM [Type](BasicTypeEnum).
53            pub fn llvm_type(&self, hugr_type: &HugrType) -> Result<BasicTypeEnum<'c>>;
54            /// Convert a [HugrFuncType] into an LLVM [FunctionType].
55            pub fn llvm_func_type(&self, hugr_type: &HugrFuncType) -> Result<FunctionType<'c>>;
56            /// Convert a hugr [HugrSumType] into an LLVM [LLVMSumType].
57            pub fn llvm_sum_type(&self, sum_type: HugrSumType) -> Result<LLVMSumType<'c>>;
58        }
59
60        to self.namer {
61            /// Mangle the name of a [FuncDefn]  or a [FuncDecl].
62            pub fn name_func(&self, name: impl AsRef<str>, node: Node) -> String;
63        }
64    }
65
66    pub fn iw_context(&self) -> &'c Context {
67        self.iw_context
68    }
69
70    /// Creates a new  `EmitModuleContext`. We take ownership of the [Module],
71    /// and return it in [EmitModuleContext::finish].
72    pub fn new(
73        iw_context: &'c Context,
74        module: Module<'c>,
75        namer: Rc<Namer>,
76        extensions: Rc<CodegenExtsMap<'a, H>>,
77    ) -> Self {
78        Self {
79            iw_context,
80            module,
81            namer,
82            extensions,
83        }
84    }
85
86    /// Returns a reference to the inner [Module]. Note that this type has
87    /// "interior mutability", and this reference can be used to add functions
88    /// and globals to the [Module].
89    pub fn module(&self) -> &Module<'c> {
90        &self.module
91    }
92
93    /// Returns a reference to the inner [CodegenExtsMap].
94    pub fn extensions(&self) -> Rc<CodegenExtsMap<'a, H>> {
95        self.extensions.clone()
96    }
97
98    /// Returns a [TypingSession] constructed from it's members.
99    pub fn typing_session(&self) -> TypingSession<'c, 'a> {
100        self.extensions
101            .type_converter
102            .clone()
103            .session(self.iw_context)
104    }
105
106    fn get_func_impl(
107        &self,
108        name: impl AsRef<str>,
109        func_ty: FunctionType<'c>,
110        linkage: Option<Linkage>,
111    ) -> Result<FunctionValue<'c>> {
112        let func = self
113            .module()
114            .get_function(name.as_ref())
115            .unwrap_or_else(|| self.module.add_function(name.as_ref(), func_ty, linkage));
116        if func.get_type() != func_ty {
117            Err(anyhow!(
118                "Function '{}' has wrong type: expected: {func_ty} actual: {}",
119                name.as_ref(),
120                func.get_type()
121            ))?
122        }
123        Ok(func)
124    }
125
126    fn get_hugr_func_impl(
127        &self,
128        name: impl AsRef<str>,
129        node: Node,
130        func_ty: &PolyFuncType,
131    ) -> Result<FunctionValue<'c>> {
132        let func_ty = (func_ty.params().is_empty())
133            .then_some(func_ty.body())
134            .ok_or(anyhow!("function has type params"))?;
135        let llvm_func_ty = self.llvm_func_type(func_ty)?;
136        let name = self.name_func(name, node);
137        self.get_func_impl(name, llvm_func_ty, None)
138    }
139
140    /// Adds or gets the [FunctionValue] in the [Module] corresponding to the given [FuncDefn].
141    ///
142    /// The name of the result is mangled by [EmitModuleContext::name_func].
143    pub fn get_func_defn<'hugr>(
144        &self,
145        node: FatNode<'hugr, FuncDefn, H>,
146    ) -> Result<FunctionValue<'c>>
147    where
148        H: HugrView<Node = Node>,
149    {
150        self.get_hugr_func_impl(&node.name, node.node(), &node.signature)
151    }
152
153    /// Adds or gets the [FunctionValue] in the [Module] corresponding to the given [FuncDecl].
154    ///
155    /// The name of the result is mangled by [EmitModuleContext::name_func].
156    pub fn get_func_decl<'hugr>(
157        &self,
158        node: FatNode<'hugr, FuncDecl, H>,
159    ) -> Result<FunctionValue<'c>>
160    where
161        H: HugrView<Node = Node>,
162    {
163        self.get_hugr_func_impl(&node.name, node.node(), &node.signature)
164    }
165
166    /// Adds or get the [FunctionValue] in the [Module] with the given symbol
167    /// and function type.
168    ///
169    /// The name undergoes no mangling. The [FunctionValue] will have
170    /// [Linkage::External].
171    ///
172    /// If this function is called multiple times with the same arguments it
173    /// will return the same [FunctionValue].
174    ///
175    /// If a function with the given name exists but the type does not match
176    /// then an Error is returned.
177    pub fn get_extern_func(
178        &self,
179        symbol: impl AsRef<str>,
180        typ: FunctionType<'c>,
181    ) -> Result<FunctionValue<'c>> {
182        self.get_func_impl(symbol, typ, Some(Linkage::External))
183    }
184
185    /// Adds or gets the [GlobalValue] in the [Module] corresponding to the
186    /// given symbol and LLVM type.
187    ///
188    /// The name will not be mangled.
189    ///
190    /// If a global with the given name exists but the type or constant-ness
191    /// does not match then an error will be returned.
192    pub fn get_global(
193        &self,
194        symbol: impl AsRef<str>,
195        typ: impl BasicType<'c>,
196        constant: bool,
197    ) -> Result<GlobalValue<'c>> {
198        let symbol = symbol.as_ref();
199        let typ = typ.as_basic_type_enum();
200        if let Some(global) = self.module().get_global(symbol) {
201            let global_type = {
202                // TODO This is exposed as `get_value_type` on the master branch
203                // of inkwell, will be in the next release. When it's released
204                // use `get_value_type`.
205                use inkwell::types::AnyTypeEnum;
206                use inkwell::values::AsValueRef;
207                unsafe {
208                    AnyTypeEnum::new(inkwell::llvm_sys::core::LLVMGlobalGetValueType(
209                        global.as_value_ref(),
210                    ))
211                }
212            };
213            if global_type != typ.as_any_type_enum() {
214                Err(anyhow!(
215                    "Global '{symbol}' has wrong type: expected: {typ} actual: {global_type}"
216                ))?
217            }
218            if global.is_constant() != constant {
219                Err(anyhow!(
220                    "Global '{symbol}' has wrong constant-ness: expected: {constant} actual: {}",
221                    global.is_constant()
222                ))?
223            }
224            Ok(global)
225        } else {
226            let global = self.module().add_global(typ, None, symbol.as_ref());
227            global.set_constant(constant);
228            Ok(global)
229        }
230    }
231
232    /// Consumes the `EmitModuleContext` and returns the internal [Module].
233    pub fn finish(self) -> Module<'c> {
234        self.module
235    }
236}
237
238type EmissionSet = HashSet<Node>;
239
240/// Emits [HugrView]s into an LLVM [Module].
241pub struct EmitHugr<'c, 'a, H>
242where
243    'a: 'c,
244{
245    emitted: EmissionSet,
246    module_context: EmitModuleContext<'c, 'a, H>,
247}
248
249impl<'c, 'a, H: HugrView<Node = Node>> EmitHugr<'c, 'a, H> {
250    delegate! {
251        to self.module_context {
252            /// Returns a reference to the inner [Context].
253            pub fn iw_context(&self) -> &'c Context;
254            /// Returns a reference to the inner [Module]. Note that this type has
255            /// "interior mutability", and this reference can be used to add functions
256            /// and globals to the [Module].
257            pub fn module(&self) -> &Module<'c>;
258        }
259    }
260
261    /// Creates a new  `EmitHugr`. We take ownership of the [Module], and return it in [Self::finish].
262    pub fn new(
263        iw_context: &'c Context,
264        module: Module<'c>,
265        namer: Rc<Namer>,
266        extensions: Rc<CodegenExtsMap<'a, H>>,
267    ) -> Self {
268        assert_eq!(iw_context, &module.get_context());
269        Self {
270            emitted: Default::default(),
271            module_context: EmitModuleContext::new(iw_context, module, namer, extensions),
272        }
273    }
274
275    /// Emits a FuncDefn into the inner [Module].
276    ///
277    /// `node` need not be a child of a hugr [Module](hugr_core::ops::Module), but it will
278    /// be emitted as a top-level function in the inner [Module]. Indeed, there
279    /// are only top-level functions in LLVM IR.
280    ///
281    /// Any child [FuncDefn] will also be emitted.
282    ///
283    /// It is safe to emit the same node multiple times: the second and further
284    /// emissions will be no-ops.
285    ///
286    /// If any LLVM IR declaration which is to be emitted already exists in the
287    /// [Module] and it differs from what would be emitted, then we fail.
288    pub fn emit_func(mut self, node: FatNode<'_, FuncDefn, H>) -> Result<Self> {
289        let mut worklist: EmissionSet = [node.node()].into_iter().collect();
290        let pop = |wl: &mut EmissionSet| wl.iter().next().cloned().map(|x| wl.take(&x).unwrap());
291
292        while let Some(next_node) = pop(&mut worklist) {
293            use crate::utils::fat::FatExt as _;
294            let Some(func) = node.hugr().try_fat(next_node) else {
295                panic!(
296                    "emit_func: node in worklist was not a FuncDefn: {:?}",
297                    node.hugr().get_optype(next_node)
298                )
299            };
300            let (new_self, new_tasks) = self.emit_func_impl(func)?;
301            self = new_self;
302            worklist.extend(new_tasks.into_iter());
303        }
304        Ok(self)
305    }
306
307    /// Emits all children of a hugr [Module](hugr_core::ops::Module).
308    ///
309    /// Note that type aliases are not supported, and that [hugr_core::ops::Const]
310    /// and [hugr_core::ops::FuncDecl] nodes are not emitted directly, but instead by
311    /// emission of ops with static edges from them. So [FuncDefn] are the only
312    /// interesting children.
313    pub fn emit_module(mut self, node: FatNode<'_, hugr_core::ops::Module, H>) -> Result<Self> {
314        for c in node.children() {
315            match c.as_ref() {
316                OpType::FuncDefn(ref fd) => {
317                    let fat_ot = c.into_ot(fd);
318                    self = self.emit_func(fat_ot)?;
319                }
320                // FuncDecls are allowed, but we don't need to do anything here.
321                OpType::FuncDecl(_) => (),
322                // Consts are allowed, but we don't need to do anything here.
323                OpType::Const(_) => (),
324                _ => Err(anyhow!("Module has invalid child: {c}"))?,
325            }
326        }
327        Ok(self)
328    }
329
330    fn emit_func_impl(mut self, node: FatNode<'_, FuncDefn, H>) -> Result<(Self, EmissionSet)> {
331        if !self.emitted.insert(node.node()) {
332            return Ok((self, EmissionSet::default()));
333        }
334        let func = self.module_context.get_func_defn(node)?;
335        let mut func_ctx = EmitFuncContext::new(self.module_context, func)?;
336        let ret_rmb = func_ctx.new_row_mail_box(node.signature.body().output.iter(), "ret")?;
337        ops::emit_dataflow_parent(
338            &mut func_ctx,
339            EmitOpArgs {
340                node,
341                inputs: func.get_params(),
342                outputs: ret_rmb.promise(),
343            },
344        )?;
345        let builder = func_ctx.builder();
346        match &ret_rmb.read::<Vec<_>>(builder, [])?[..] {
347            [] => builder.build_return(None)?,
348            [x] => builder.build_return(Some(x))?,
349            xs => builder.build_aggregate_return(xs)?,
350        };
351        let (mctx, todos) = func_ctx.finish()?;
352        self.module_context = mctx;
353        Ok((self, todos))
354    }
355
356    /// Consumes the `EmitHugr` and returns the internal [Module].
357    pub fn finish(self) -> Module<'c> {
358        self.module_context.finish()
359    }
360}
361
362/// Extract all return values from the result of a `call`.
363///
364/// LLVM only supports functions with exactly zero or one return value.
365/// For functions with multiple return values, we return a struct containing
366/// all the return values.
367///
368/// `inkwell` provides a helper [Builder::build_aggregate_return] to construct
369/// the return value, see `EmitHugr::emit_func_impl`. This function performs the
370/// inverse.
371pub fn deaggregate_call_result<'c>(
372    builder: &Builder<'c>,
373    call_result: CallSiteValue<'c>,
374    num_results: usize,
375) -> Result<Vec<BasicValueEnum<'c>>> {
376    let call_result = call_result.try_as_basic_value();
377    Ok(match num_results as u32 {
378        0 => {
379            call_result.expect_right("void");
380            vec![]
381        }
382        1 => vec![call_result.expect_left("non-void")],
383        n => {
384            let return_struct = call_result.expect_left("non-void").into_struct_value();
385            (0..n)
386                .map(|i| builder.build_extract_value(return_struct, i, ""))
387                .collect::<Result<Vec<_>, _>>()?
388        }
389    })
390}
391
392pub fn get_intrinsic<'c>(
393    module: &Module<'c>,
394    name: impl AsRef<str>,
395    args: impl AsRef<[BasicTypeEnum<'c>]>,
396) -> Result<FunctionValue<'c>> {
397    let (name, args) = (name.as_ref(), args.as_ref());
398    let intrinsic = Intrinsic::find(name).ok_or(anyhow!("Failed to find intrinsic: '{name}'"))?;
399    intrinsic
400        .get_declaration(module, args.as_ref())
401        .ok_or(anyhow!(
402            "failed to get_declaration for intrinsic '{name}' with args '{args:?}'"
403        ))
404}
405
406#[cfg(any(test, feature = "test-utils"))]
407pub mod test;