1use anyhow::{anyhow, Result};
2use delegate::delegate;
3use hugr_core::{
4 ops::{FuncDecl, FuncDefn, OpType},
5 types::PolyFuncType,
6 HugrView, Node,
7};
8use inkwell::{
9 builder::Builder,
10 context::Context,
11 intrinsics::Intrinsic,
12 module::{Linkage, Module},
13 types::{AnyType, BasicType, BasicTypeEnum, FunctionType},
14 values::{BasicValueEnum, CallSiteValue, FunctionValue, GlobalValue},
15};
16use std::{collections::HashSet, rc::Rc};
17
18use crate::types::{HugrFuncType, HugrSumType, HugrType, TypingSession};
19
20use crate::{custom::CodegenExtsMap, types::LLVMSumType, utils::fat::FatNode};
21
22pub mod args;
23pub mod func;
24pub mod libc;
25pub mod namer;
26pub mod ops;
27
28pub use args::EmitOpArgs;
29pub use func::{EmitFuncContext, RowPromise};
30pub use namer::Namer;
31pub use ops::emit_value;
32
33pub struct EmitModuleContext<'c, 'a, H>
40where
41 'a: 'c,
42{
43 iw_context: &'c Context,
44 module: Module<'c>,
45 extensions: Rc<CodegenExtsMap<'a, H>>,
46 namer: Rc<Namer>,
47}
48
49impl<'c, 'a, H> EmitModuleContext<'c, 'a, H> {
50 delegate! {
51 to self.typing_session() {
52 pub fn llvm_type(&self, hugr_type: &HugrType) -> Result<BasicTypeEnum<'c>>;
54 pub fn llvm_func_type(&self, hugr_type: &HugrFuncType) -> Result<FunctionType<'c>>;
56 pub fn llvm_sum_type(&self, sum_type: HugrSumType) -> Result<LLVMSumType<'c>>;
58 }
59
60 to self.namer {
61 pub fn name_func(&self, name: impl AsRef<str>, node: Node) -> String;
63 }
64 }
65
66 pub fn iw_context(&self) -> &'c Context {
67 self.iw_context
68 }
69
70 pub fn new(
73 iw_context: &'c Context,
74 module: Module<'c>,
75 namer: Rc<Namer>,
76 extensions: Rc<CodegenExtsMap<'a, H>>,
77 ) -> Self {
78 Self {
79 iw_context,
80 module,
81 namer,
82 extensions,
83 }
84 }
85
86 pub fn module(&self) -> &Module<'c> {
90 &self.module
91 }
92
93 pub fn extensions(&self) -> Rc<CodegenExtsMap<'a, H>> {
95 self.extensions.clone()
96 }
97
98 pub fn typing_session(&self) -> TypingSession<'c, 'a> {
100 self.extensions
101 .type_converter
102 .clone()
103 .session(self.iw_context)
104 }
105
106 fn get_func_impl(
107 &self,
108 name: impl AsRef<str>,
109 func_ty: FunctionType<'c>,
110 linkage: Option<Linkage>,
111 ) -> Result<FunctionValue<'c>> {
112 let func = self
113 .module()
114 .get_function(name.as_ref())
115 .unwrap_or_else(|| self.module.add_function(name.as_ref(), func_ty, linkage));
116 if func.get_type() != func_ty {
117 Err(anyhow!(
118 "Function '{}' has wrong type: expected: {func_ty} actual: {}",
119 name.as_ref(),
120 func.get_type()
121 ))?
122 }
123 Ok(func)
124 }
125
126 fn get_hugr_func_impl(
127 &self,
128 name: impl AsRef<str>,
129 node: Node,
130 func_ty: &PolyFuncType,
131 ) -> Result<FunctionValue<'c>> {
132 let func_ty = (func_ty.params().is_empty())
133 .then_some(func_ty.body())
134 .ok_or(anyhow!("function has type params"))?;
135 let llvm_func_ty = self.llvm_func_type(func_ty)?;
136 let name = self.name_func(name, node);
137 self.get_func_impl(name, llvm_func_ty, None)
138 }
139
140 pub fn get_func_defn<'hugr>(
144 &self,
145 node: FatNode<'hugr, FuncDefn, H>,
146 ) -> Result<FunctionValue<'c>>
147 where
148 H: HugrView<Node = Node>,
149 {
150 self.get_hugr_func_impl(&node.name, node.node(), &node.signature)
151 }
152
153 pub fn get_func_decl<'hugr>(
157 &self,
158 node: FatNode<'hugr, FuncDecl, H>,
159 ) -> Result<FunctionValue<'c>>
160 where
161 H: HugrView<Node = Node>,
162 {
163 self.get_hugr_func_impl(&node.name, node.node(), &node.signature)
164 }
165
166 pub fn get_extern_func(
178 &self,
179 symbol: impl AsRef<str>,
180 typ: FunctionType<'c>,
181 ) -> Result<FunctionValue<'c>> {
182 self.get_func_impl(symbol, typ, Some(Linkage::External))
183 }
184
185 pub fn get_global(
193 &self,
194 symbol: impl AsRef<str>,
195 typ: impl BasicType<'c>,
196 constant: bool,
197 ) -> Result<GlobalValue<'c>> {
198 let symbol = symbol.as_ref();
199 let typ = typ.as_basic_type_enum();
200 if let Some(global) = self.module().get_global(symbol) {
201 let global_type = {
202 use inkwell::types::AnyTypeEnum;
206 use inkwell::values::AsValueRef;
207 unsafe {
208 AnyTypeEnum::new(inkwell::llvm_sys::core::LLVMGlobalGetValueType(
209 global.as_value_ref(),
210 ))
211 }
212 };
213 if global_type != typ.as_any_type_enum() {
214 Err(anyhow!(
215 "Global '{symbol}' has wrong type: expected: {typ} actual: {global_type}"
216 ))?
217 }
218 if global.is_constant() != constant {
219 Err(anyhow!(
220 "Global '{symbol}' has wrong constant-ness: expected: {constant} actual: {}",
221 global.is_constant()
222 ))?
223 }
224 Ok(global)
225 } else {
226 let global = self.module().add_global(typ, None, symbol.as_ref());
227 global.set_constant(constant);
228 Ok(global)
229 }
230 }
231
232 pub fn finish(self) -> Module<'c> {
234 self.module
235 }
236}
237
238type EmissionSet = HashSet<Node>;
239
240pub struct EmitHugr<'c, 'a, H>
242where
243 'a: 'c,
244{
245 emitted: EmissionSet,
246 module_context: EmitModuleContext<'c, 'a, H>,
247}
248
249impl<'c, 'a, H: HugrView<Node = Node>> EmitHugr<'c, 'a, H> {
250 delegate! {
251 to self.module_context {
252 pub fn iw_context(&self) -> &'c Context;
254 pub fn module(&self) -> &Module<'c>;
258 }
259 }
260
261 pub fn new(
263 iw_context: &'c Context,
264 module: Module<'c>,
265 namer: Rc<Namer>,
266 extensions: Rc<CodegenExtsMap<'a, H>>,
267 ) -> Self {
268 assert_eq!(iw_context, &module.get_context());
269 Self {
270 emitted: Default::default(),
271 module_context: EmitModuleContext::new(iw_context, module, namer, extensions),
272 }
273 }
274
275 pub fn emit_func(mut self, node: FatNode<'_, FuncDefn, H>) -> Result<Self> {
289 let mut worklist: EmissionSet = [node.node()].into_iter().collect();
290 let pop = |wl: &mut EmissionSet| wl.iter().next().cloned().map(|x| wl.take(&x).unwrap());
291
292 while let Some(next_node) = pop(&mut worklist) {
293 use crate::utils::fat::FatExt as _;
294 let Some(func) = node.hugr().try_fat(next_node) else {
295 panic!(
296 "emit_func: node in worklist was not a FuncDefn: {:?}",
297 node.hugr().get_optype(next_node)
298 )
299 };
300 let (new_self, new_tasks) = self.emit_func_impl(func)?;
301 self = new_self;
302 worklist.extend(new_tasks.into_iter());
303 }
304 Ok(self)
305 }
306
307 pub fn emit_module(mut self, node: FatNode<'_, hugr_core::ops::Module, H>) -> Result<Self> {
314 for c in node.children() {
315 match c.as_ref() {
316 OpType::FuncDefn(ref fd) => {
317 let fat_ot = c.into_ot(fd);
318 self = self.emit_func(fat_ot)?;
319 }
320 OpType::FuncDecl(_) => (),
322 OpType::Const(_) => (),
324 _ => Err(anyhow!("Module has invalid child: {c}"))?,
325 }
326 }
327 Ok(self)
328 }
329
330 fn emit_func_impl(mut self, node: FatNode<'_, FuncDefn, H>) -> Result<(Self, EmissionSet)> {
331 if !self.emitted.insert(node.node()) {
332 return Ok((self, EmissionSet::default()));
333 }
334 let func = self.module_context.get_func_defn(node)?;
335 let mut func_ctx = EmitFuncContext::new(self.module_context, func)?;
336 let ret_rmb = func_ctx.new_row_mail_box(node.signature.body().output.iter(), "ret")?;
337 ops::emit_dataflow_parent(
338 &mut func_ctx,
339 EmitOpArgs {
340 node,
341 inputs: func.get_params(),
342 outputs: ret_rmb.promise(),
343 },
344 )?;
345 let builder = func_ctx.builder();
346 match &ret_rmb.read::<Vec<_>>(builder, [])?[..] {
347 [] => builder.build_return(None)?,
348 [x] => builder.build_return(Some(x))?,
349 xs => builder.build_aggregate_return(xs)?,
350 };
351 let (mctx, todos) = func_ctx.finish()?;
352 self.module_context = mctx;
353 Ok((self, todos))
354 }
355
356 pub fn finish(self) -> Module<'c> {
358 self.module_context.finish()
359 }
360}
361
362pub fn deaggregate_call_result<'c>(
372 builder: &Builder<'c>,
373 call_result: CallSiteValue<'c>,
374 num_results: usize,
375) -> Result<Vec<BasicValueEnum<'c>>> {
376 let call_result = call_result.try_as_basic_value();
377 Ok(match num_results as u32 {
378 0 => {
379 call_result.expect_right("void");
380 vec![]
381 }
382 1 => vec![call_result.expect_left("non-void")],
383 n => {
384 let return_struct = call_result.expect_left("non-void").into_struct_value();
385 (0..n)
386 .map(|i| builder.build_extract_value(return_struct, i, ""))
387 .collect::<Result<Vec<_>, _>>()?
388 }
389 })
390}
391
392pub fn get_intrinsic<'c>(
393 module: &Module<'c>,
394 name: impl AsRef<str>,
395 args: impl AsRef<[BasicTypeEnum<'c>]>,
396) -> Result<FunctionValue<'c>> {
397 let (name, args) = (name.as_ref(), args.as_ref());
398 let intrinsic = Intrinsic::find(name).ok_or(anyhow!("Failed to find intrinsic: '{name}'"))?;
399 intrinsic
400 .get_declaration(module, args.as_ref())
401 .ok_or(anyhow!(
402 "failed to get_declaration for intrinsic '{name}' with args '{args:?}'"
403 ))
404}
405
406#[cfg(any(test, feature = "test-utils"))]
407pub mod test;