1use anyhow::{Result, anyhow};
2use delegate::delegate;
3use hugr_core::{
4 HugrView, Node, Visibility,
5 ops::{FuncDecl, FuncDefn, OpType},
6 types::PolyFuncType,
7};
8use inkwell::{
9 builder::Builder,
10 context::Context,
11 intrinsics::Intrinsic,
12 module::{Linkage, Module},
13 types::{AnyType, BasicType, BasicTypeEnum, FunctionType},
14 values::{BasicValueEnum, CallSiteValue, FunctionValue, GlobalValue},
15};
16use std::{collections::HashSet, rc::Rc};
17
18use crate::types::{HugrFuncType, HugrSumType, HugrType, TypingSession};
19
20use crate::{custom::CodegenExtsMap, types::LLVMSumType, utils::fat::FatNode};
21
22pub mod args;
23pub mod func;
24pub mod libc;
25pub mod namer;
26pub mod ops;
27
28pub use args::EmitOpArgs;
29pub use func::{EmitFuncContext, RowPromise};
30pub use namer::Namer;
31pub use ops::emit_value;
32
33pub struct EmitModuleContext<'c, 'a, H>
40where
41 'a: 'c,
42{
43 iw_context: &'c Context,
44 module: Module<'c>,
45 extensions: Rc<CodegenExtsMap<'a, H>>,
46 namer: Rc<Namer>,
47}
48
49impl<'c, 'a, H> EmitModuleContext<'c, 'a, H> {
50 delegate! {
51 to self.typing_session() {
52 pub fn llvm_type(&self, hugr_type: &HugrType) -> Result<BasicTypeEnum<'c>>;
54 pub fn llvm_func_type(&self, hugr_type: &HugrFuncType) -> Result<FunctionType<'c>>;
56 pub fn llvm_sum_type(&self, sum_type: HugrSumType) -> Result<LLVMSumType<'c>>;
58 }
59
60 to self.namer {
61 pub fn name_func(&self, name: impl AsRef<str>, node: Node) -> String;
63 }
64 }
65
66 pub fn iw_context(&self) -> &'c Context {
67 self.iw_context
68 }
69
70 pub fn new(
73 iw_context: &'c Context,
74 module: Module<'c>,
75 namer: Rc<Namer>,
76 extensions: Rc<CodegenExtsMap<'a, H>>,
77 ) -> Self {
78 Self {
79 iw_context,
80 module,
81 extensions,
82 namer,
83 }
84 }
85
86 pub fn module(&self) -> &Module<'c> {
90 &self.module
91 }
92
93 pub fn extensions(&self) -> Rc<CodegenExtsMap<'a, H>> {
95 self.extensions.clone()
96 }
97
98 pub fn typing_session(&self) -> TypingSession<'c, 'a> {
100 self.extensions
101 .type_converter
102 .clone()
103 .session(self.iw_context)
104 }
105
106 fn get_func_impl(
107 &self,
108 name: impl AsRef<str>,
109 func_ty: FunctionType<'c>,
110 linkage: Option<Linkage>,
111 ) -> Result<FunctionValue<'c>> {
112 let func = self
113 .module()
114 .get_function(name.as_ref())
115 .unwrap_or_else(|| self.module.add_function(name.as_ref(), func_ty, linkage));
116 if func.get_type() != func_ty {
117 Err(anyhow!(
118 "Function '{}' has wrong type: expected: {func_ty} actual: {}",
119 name.as_ref(),
120 func.get_type()
121 ))?;
122 }
123 Ok(func)
124 }
125
126 fn get_hugr_func_impl(
127 &self,
128 name: impl AsRef<str>,
129 node: Node,
130 func_ty: &PolyFuncType,
131 visibility: &Visibility,
132 ) -> Result<FunctionValue<'c>> {
133 let func_ty = (func_ty.params().is_empty())
134 .then_some(func_ty.body())
135 .ok_or(anyhow!("function has type params"))?;
136 let llvm_func_ty = self.llvm_func_type(func_ty)?;
137 let name = self.name_func(name, node);
138 match visibility {
139 Visibility::Public => self.get_func_impl(name, llvm_func_ty, Some(Linkage::External)),
140 Visibility::Private => self.get_func_impl(name, llvm_func_ty, Some(Linkage::Private)),
141 _ => self.get_func_impl(name, llvm_func_ty, None),
142 }
143 }
144
145 pub fn get_func_defn<'hugr>(
149 &self,
150 node: FatNode<'hugr, FuncDefn, H>,
151 ) -> Result<FunctionValue<'c>>
152 where
153 H: HugrView<Node = Node>,
154 {
155 self.get_hugr_func_impl(
156 node.func_name(),
157 node.node(),
158 node.signature(),
159 node.visibility(),
160 )
161 }
162
163 pub fn get_func_decl<'hugr>(
167 &self,
168 node: FatNode<'hugr, FuncDecl, H>,
169 ) -> Result<FunctionValue<'c>>
170 where
171 H: HugrView<Node = Node>,
172 {
173 self.get_hugr_func_impl(
174 node.func_name(),
175 node.node(),
176 node.signature(),
177 node.visibility(),
178 )
179 }
180
181 pub fn get_extern_func(
193 &self,
194 symbol: impl AsRef<str>,
195 typ: FunctionType<'c>,
196 ) -> Result<FunctionValue<'c>> {
197 self.get_func_impl(symbol, typ, Some(Linkage::External))
198 }
199
200 pub fn get_global(
208 &self,
209 symbol: impl AsRef<str>,
210 typ: impl BasicType<'c>,
211 constant: bool,
212 ) -> Result<GlobalValue<'c>> {
213 let symbol = symbol.as_ref();
214 let typ = typ.as_basic_type_enum();
215 if let Some(global) = self.module().get_global(symbol) {
216 let global_type = {
217 use inkwell::types::AnyTypeEnum;
221 use inkwell::values::AsValueRef;
222 unsafe {
223 AnyTypeEnum::new(inkwell::llvm_sys::core::LLVMGlobalGetValueType(
224 global.as_value_ref(),
225 ))
226 }
227 };
228 if global_type != typ.as_any_type_enum() {
229 Err(anyhow!(
230 "Global '{symbol}' has wrong type: expected: {typ} actual: {global_type}"
231 ))?;
232 }
233 if global.is_constant() != constant {
234 Err(anyhow!(
235 "Global '{symbol}' has wrong constant-ness: expected: {constant} actual: {}",
236 global.is_constant()
237 ))?;
238 }
239 Ok(global)
240 } else {
241 let global = self.module().add_global(typ, None, symbol.as_ref());
242 global.set_constant(constant);
243 Ok(global)
244 }
245 }
246
247 pub fn finish(self) -> Module<'c> {
249 self.module
250 }
251}
252
/// A set of HUGR nodes. It records the function definitions already emitted
/// and the ones still queued for emission.
type EmissionSet = HashSet<Node>;
254
255pub struct EmitHugr<'c, 'a, H>
257where
258 'a: 'c,
259{
260 emitted: EmissionSet,
261 module_context: EmitModuleContext<'c, 'a, H>,
262}
263
264impl<'c, 'a, H: HugrView<Node = Node>> EmitHugr<'c, 'a, H> {
265 delegate! {
266 to self.module_context {
267 pub fn iw_context(&self) -> &'c Context;
269 pub fn module(&self) -> &Module<'c>;
273 }
274 }
275
276 pub fn new(
278 iw_context: &'c Context,
279 module: Module<'c>,
280 namer: Rc<Namer>,
281 extensions: Rc<CodegenExtsMap<'a, H>>,
282 ) -> Self {
283 assert_eq!(iw_context, &module.get_context());
284 Self {
285 emitted: Default::default(),
286 module_context: EmitModuleContext::new(iw_context, module, namer, extensions),
287 }
288 }
289
290 pub fn emit_func(mut self, node: FatNode<'_, FuncDefn, H>) -> Result<Self> {
304 let mut worklist: EmissionSet = [node.node()].into_iter().collect();
305 let pop = |wl: &mut EmissionSet| wl.iter().next().copied().map(|x| wl.take(&x).unwrap());
306
307 while let Some(next_node) = pop(&mut worklist) {
308 use crate::utils::fat::FatExt as _;
309 let Some(func) = node.hugr().try_fat(next_node) else {
310 panic!(
311 "emit_func: node in worklist was not a FuncDefn: {:?}",
312 node.hugr().get_optype(next_node)
313 )
314 };
315 let (new_self, new_tasks) = self.emit_func_impl(func)?;
316 self = new_self;
317 worklist.extend(new_tasks.into_iter());
318 }
319 Ok(self)
320 }
321
322 pub fn emit_module(mut self, node: FatNode<'_, hugr_core::ops::Module, H>) -> Result<Self> {
329 for c in node.children() {
330 match c.as_ref() {
331 OpType::FuncDefn(fd) => {
332 let fat_ot = c.into_ot(fd);
333 self = self.emit_func(fat_ot)?;
334 }
335 OpType::FuncDecl(_) => (),
337 OpType::Const(_) => (),
339 _ => Err(anyhow!("Module has invalid child: {c}"))?,
340 }
341 }
342 Ok(self)
343 }
344
345 fn emit_func_impl(mut self, node: FatNode<'_, FuncDefn, H>) -> Result<(Self, EmissionSet)> {
346 if !self.emitted.insert(node.node()) {
347 return Ok((self, EmissionSet::default()));
348 }
349 let func = self.module_context.get_func_defn(node)?;
350 let mut func_ctx = EmitFuncContext::new(self.module_context, func)?;
351 let ret_rmb = func_ctx.new_row_mail_box(node.signature().body().output.iter(), "ret")?;
352 ops::emit_dataflow_parent(
353 &mut func_ctx,
354 EmitOpArgs {
355 node,
356 inputs: func.get_params(),
357 outputs: ret_rmb.promise(),
358 },
359 )?;
360 let builder = func_ctx.builder();
361 match &ret_rmb.read::<Vec<_>>(builder, [])?[..] {
362 [] => builder.build_return(None)?,
363 [x] => builder.build_return(Some(x))?,
364 xs => builder.build_aggregate_return(xs)?,
365 };
366 let (mctx, todos) = func_ctx.finish()?;
367 self.module_context = mctx;
368 Ok((self, todos))
369 }
370
371 pub fn finish(self) -> Module<'c> {
373 self.module_context.finish()
374 }
375}
376
377pub fn deaggregate_call_result<'c>(
387 builder: &Builder<'c>,
388 call_result: CallSiteValue<'c>,
389 num_results: usize,
390) -> Result<Vec<BasicValueEnum<'c>>> {
391 let call_result = call_result.try_as_basic_value();
392 Ok(match num_results as u32 {
393 0 => {
394 call_result.expect_right("void");
395 vec![]
396 }
397 1 => vec![call_result.expect_left("non-void")],
398 n => {
399 let return_struct = call_result.expect_left("non-void").into_struct_value();
400 (0..n)
401 .map(|i| builder.build_extract_value(return_struct, i, ""))
402 .collect::<Result<Vec<_>, _>>()?
403 }
404 })
405}
406
407pub fn get_intrinsic<'c>(
408 module: &Module<'c>,
409 name: impl AsRef<str>,
410 args: impl AsRef<[BasicTypeEnum<'c>]>,
411) -> Result<FunctionValue<'c>> {
412 let (name, args) = (name.as_ref(), args.as_ref());
413 let intrinsic = Intrinsic::find(name).ok_or(anyhow!("Failed to find intrinsic: '{name}'"))?;
414 intrinsic
415 .get_declaration(module, args.as_ref())
416 .ok_or(anyhow!(
417 "failed to get_declaration for intrinsic '{name}' with args '{args:?}'"
418 ))
419}
420
421#[cfg(any(test, feature = "test-utils"))]
422pub mod test;