1use std::{collections::BTreeMap, rc::Rc};
2
3use anyhow::{Result, anyhow};
4use hugr_core::{
5 HugrView, Node, NodeIndex, PortIndex, Wire,
6 extension::prelude::{either_type, option_type},
7 ops::{ExtensionOp, FuncDecl, FuncDefn, constant::CustomConst},
8 types::Type,
9};
10use inkwell::{
11 basic_block::BasicBlock,
12 builder::Builder,
13 context::Context,
14 module::Module,
15 types::{BasicType, BasicTypeEnum, FunctionType},
16 values::{BasicValueEnum, FunctionValue, GlobalValue, IntValue},
17};
18use itertools::zip_eq;
19
20use crate::types::{HugrFuncType, HugrSumType, HugrType, TypingSession};
21use crate::{custom::CodegenExtsMap, types::LLVMSumType, utils::fat::FatNode};
22use delegate::delegate;
23
24use self::mailbox::ValueMailBox;
25
26use super::{EmissionSet, EmitModuleContext, EmitOpArgs};
27
28mod mailbox;
29pub use mailbox::{RowMailBox, RowPromise};
30
31pub struct EmitFuncContext<'c, 'a, H>
48where
49 'a: 'c,
50{
51 emit_context: EmitModuleContext<'c, 'a, H>,
52 todo: EmissionSet,
53 func: FunctionValue<'c>,
54 env: BTreeMap<Wire, ValueMailBox<'c>>,
55 builder: Builder<'c>,
56 prologue_bb: BasicBlock<'c>,
57 launch_bb: BasicBlock<'c>,
58}
59
60impl<'c, 'a, H: HugrView<Node = Node>> EmitFuncContext<'c, 'a, H> {
61 delegate! {
62 to self.emit_context {
63 pub fn iw_context(&self) -> &'c Context;
65 pub fn extensions(&self) -> Rc<CodegenExtsMap<'a,H>>;
67 pub fn typing_session(&self) -> TypingSession<'c, 'a>;
69 pub fn llvm_type(&self, hugr_type: &HugrType) -> Result<BasicTypeEnum<'c> >;
71 pub fn llvm_func_type(&self, hugr_type: &HugrFuncType) -> Result<FunctionType<'c> >;
73 pub fn llvm_sum_type(&self, sum_type: HugrSumType) -> Result<LLVMSumType<'c>>;
75 pub fn get_func_defn(&self, node: FatNode<FuncDefn, H>) -> Result<FunctionValue<'c>>;
79 pub fn get_func_decl(&self, node: FatNode<FuncDecl, H>) -> Result<FunctionValue<'c>>;
83 pub fn get_extern_func(&self, symbol: impl AsRef<str>, typ: FunctionType<'c>,) -> Result<FunctionValue<'c>>;
95 pub fn get_global(&self, symbol: impl AsRef<str>, typ: impl BasicType<'c>, constant: bool) -> Result<GlobalValue<'c>>;
103 }
104 }
105
106 pub fn push_todo_func(&mut self, node: FatNode<'_, FuncDefn, H>) {
109 self.todo.insert(node.node());
110 }
111
112 pub fn func(&self) -> FunctionValue<'c> {
114 self.func
115 }
116
117 pub fn builder(&self) -> &Builder<'c> {
121 &self.builder
122 }
123
124 pub(crate) fn new_basic_block(
128 &mut self,
129 name: impl AsRef<str>,
130 before: Option<BasicBlock<'c>>,
131 ) -> BasicBlock<'c> {
132 if let Some(before) = before {
133 self.iw_context().prepend_basic_block(before, name.as_ref())
134 } else {
135 self.iw_context()
136 .append_basic_block(self.func, name.as_ref())
137 }
138 }
139
140 fn prologue_block(&self) -> BasicBlock<'c> {
141 self.func.get_first_basic_block().unwrap()
143 }
144
145 pub fn new(
153 emit_context: EmitModuleContext<'c, 'a, H>,
154 func: FunctionValue<'c>,
155 ) -> Result<EmitFuncContext<'c, 'a, H>> {
156 if func.get_first_basic_block().is_some() {
157 Err(anyhow!(
158 "EmitContext::new: Function already has a basic block: {:?}",
159 func.get_name()
160 ))?;
161 }
162 let prologue_bb = emit_context
163 .iw_context()
164 .append_basic_block(func, "alloca_block");
165 let launch_bb = emit_context
166 .iw_context()
167 .append_basic_block(func, "entry_block");
168 let builder = emit_context.iw_context().create_builder();
169 builder.position_at_end(launch_bb);
170 Ok(Self {
171 emit_context,
172 todo: Default::default(),
173 func,
174 env: Default::default(),
175 builder,
176 prologue_bb,
177 launch_bb,
178 })
179 }
180
181 fn new_value_mail_box(&mut self, t: &Type, name: impl AsRef<str>) -> Result<ValueMailBox<'c>> {
182 let bte = self.llvm_type(t)?;
183 let ptr = self.build_prologue(|builder| builder.build_alloca(bte, name.as_ref()))?;
184 Ok(ValueMailBox::new(bte, ptr, Some(name.as_ref().into())))
185 }
186
187 pub fn new_row_mail_box<'t>(
191 &mut self,
192 ts: impl IntoIterator<Item = &'t Type>,
193 name: impl AsRef<str>,
194 ) -> Result<RowMailBox<'c>> {
195 Ok(RowMailBox::new(
196 ts.into_iter()
197 .enumerate()
198 .map(|(i, t)| self.new_value_mail_box(t, format!("{i}")))
199 .collect::<Result<Vec<_>>>()?,
200 Some(name.as_ref().into()),
201 ))
202 }
203
204 fn build_prologue<T>(&mut self, f: impl FnOnce(&Builder<'c>) -> T) -> T {
205 let b = self.prologue_block();
206 self.build_positioned(b, |x| f(&x.builder))
207 }
208
209 pub fn build_positioned_new_block<T>(
214 &mut self,
215 name: impl AsRef<str>,
216 before: Option<BasicBlock<'c>>,
217 f: impl FnOnce(&mut Self, BasicBlock<'c>) -> T,
218 ) -> T {
219 let bb = self.new_basic_block(name, before);
220 self.build_positioned(bb, |s| f(s, bb))
221 }
222
223 pub fn build_positioned<T>(
227 &mut self,
228 block: BasicBlock<'c>,
229 f: impl FnOnce(&mut Self) -> T,
230 ) -> T {
231 let current = self.builder.get_insert_block().unwrap();
233 self.builder.position_at_end(block);
234 let r = f(self);
235 self.builder.position_at_end(current);
236 r
237 }
238
239 pub fn node_ins_rmb<'hugr, OT: 'hugr>(
242 &mut self,
243 node: FatNode<'hugr, OT, H>,
244 ) -> Result<RowMailBox<'c>> {
245 let r = node
246 .in_value_types()
247 .map(|(p, t)| {
248 let (slo_n, slo_p) = node
249 .single_linked_output(p)
250 .ok_or(anyhow!("No single linked output"))?;
251 self.map_wire(slo_n, slo_p, &t)
252 })
253 .collect::<Result<RowMailBox>>()?;
254
255 debug_assert!(
256 zip_eq(node.in_value_types(), r.get_types())
257 .all(|((_, t), lt)| self.llvm_type(&t).unwrap() == lt)
258 );
259 Ok(r)
260 }
261
262 pub fn node_outs_rmb<'hugr, OT: 'hugr>(
265 &mut self,
266 node: FatNode<'hugr, OT, H>,
267 ) -> Result<RowMailBox<'c>> {
268 let r = node
269 .out_value_types()
270 .map(|(port, hugr_type)| self.map_wire(node, port, &hugr_type))
271 .collect::<Result<RowMailBox>>()?;
272 debug_assert!(
273 zip_eq(node.out_value_types(), r.get_types())
274 .all(|((_, t), lt)| self.llvm_type(&t).unwrap() == lt)
275 );
276 Ok(r)
277 }
278
279 fn map_wire<'hugr, OT>(
280 &mut self,
281 node: FatNode<'hugr, OT, H>,
282 port: hugr_core::OutgoingPort,
283 hugr_type: &Type,
284 ) -> Result<ValueMailBox<'c>> {
285 let wire = Wire::new(node.node(), port);
286 if let Some(mb) = self.env.get(&wire) {
287 debug_assert_eq!(self.llvm_type(hugr_type).unwrap(), mb.get_type());
288 return Ok(mb.clone());
289 }
290 let mb = self.new_value_mail_box(
291 hugr_type,
292 format!("{}_{}", node.node().index(), port.index()),
293 )?;
294 self.env.insert(wire, mb.clone());
295 Ok(mb)
296 }
297
298 pub fn get_current_module(&self) -> &Module<'c> {
299 self.emit_context.module()
300 }
301
302 pub(crate) fn emit_custom_const(&mut self, v: &dyn CustomConst) -> Result<BasicValueEnum<'c>> {
303 let exts = self.extensions();
304 exts.as_ref()
305 .load_constant_handlers
306 .emit_load_constant(self, v)
307 }
308
309 pub(crate) fn emit_extension_op(
310 &mut self,
311 args: EmitOpArgs<'c, '_, ExtensionOp, H>,
312 ) -> Result<()> {
313 let exts = self.extensions();
314 exts.as_ref()
315 .extension_op_handlers
316 .emit_extension_op(self, args)
317 }
318
319 pub fn finish(self) -> Result<(EmitModuleContext<'c, 'a, H>, EmissionSet)> {
322 self.builder.position_at_end(self.prologue_bb);
323 self.builder.build_unconditional_branch(self.launch_bb)?;
324 Ok((self.emit_context, self.todo))
325 }
326}
327
328pub fn build_option<'c, H: HugrView<Node = Node>>(
330 ctx: &mut EmitFuncContext<'c, '_, H>,
331 is_some: IntValue<'c>,
332 some_value: BasicValueEnum<'c>,
333 hugr_ty: HugrType,
334) -> Result<BasicValueEnum<'c>> {
335 let option_ty = ctx.llvm_sum_type(option_type(hugr_ty))?;
336 let builder = ctx.builder();
337 let some = option_ty.build_tag(builder, 1, vec![some_value])?;
338 let none = option_ty.build_tag(builder, 0, vec![])?;
339 let option = builder.build_select(is_some, some, none, "")?;
340 Ok(option)
341}
342
343pub fn build_ok_or_else<'c, H: HugrView<Node = Node>>(
346 ctx: &mut EmitFuncContext<'c, '_, H>,
347 is_ok: IntValue<'c>,
348 ok_value: BasicValueEnum<'c>,
349 ok_hugr_ty: HugrType,
350 else_value: BasicValueEnum<'c>,
351 else_hugr_ty: HugrType,
352) -> Result<BasicValueEnum<'c>> {
353 let either_ty = ctx.llvm_sum_type(either_type(else_hugr_ty, ok_hugr_ty))?;
354 let builder = ctx.builder();
355 let left = either_ty.build_tag(builder, 0, vec![else_value])?;
356 let right = either_ty.build_tag(builder, 1, vec![ok_value])?;
357 let either = builder.build_select(is_ok, right, left, "")?;
358 Ok(either)
359}
360
#[cfg(test)]
mod tests {
    use super::EmitFuncContext;

    /// `func()` must return exactly the function the context was created for.
    #[test]
    fn test_func_getter() {
        let ctx = crate::test::test_ctx(-1);
        let module_ctx = ctx.get_emit_module_context();
        let void_fn_ty = module_ctx.iw_context().void_type().fn_type(&[], false);
        let expected = module_ctx
            .module()
            .add_function("test_func", void_fn_ty, None);
        let emit_func_ctx = EmitFuncContext::new(module_ctx, expected).unwrap();

        assert_eq!(emit_func_ctx.func(), expected);
    }
}