1use std::{collections::BTreeMap, rc::Rc};
2
3use anyhow::{Result, anyhow};
4use hugr_core::{
5 HugrView, Node, NodeIndex, PortIndex, Wire,
6 extension::prelude::{either_type, option_type},
7 ops::{ExtensionOp, FuncDecl, FuncDefn, constant::CustomConst},
8 types::Type,
9};
10use inkwell::{
11 basic_block::BasicBlock,
12 builder::Builder,
13 context::Context,
14 module::Module,
15 types::{BasicType, BasicTypeEnum, FunctionType},
16 values::{BasicValueEnum, FunctionValue, GlobalValue, IntValue},
17};
18use itertools::zip_eq;
19
20use crate::types::{HugrFuncType, HugrSumType, HugrType, TypingSession};
21use crate::{custom::CodegenExtsMap, types::LLVMSumType, utils::fat::FatNode};
22use delegate::delegate;
23
24use self::mailbox::ValueMailBox;
25
26use super::{EmissionSet, EmitModuleContext, EmitOpArgs};
27
28mod mailbox;
29pub use mailbox::{RowMailBox, RowPromise};
30
/// Per-function emission state.
///
/// Bundles the module-level [`EmitModuleContext`] with the LLVM function being
/// filled in, the builder used to emit into it, and the mailboxes created so
/// far for HUGR wires.
pub struct EmitFuncContext<'c, 'a, H>
where
    'a: 'c,
{
    /// Module-level context. The delegated query methods forward to it, and
    /// `finish` hands it back to the caller.
    emit_context: EmitModuleContext<'c, 'a, H>,
    /// Function definitions recorded through `push_todo_func`; returned by `finish`.
    todo: EmissionSet,
    /// The LLVM function that new basic blocks are appended to.
    func: FunctionValue<'c>,
    /// Mailbox created for each wire seen so far (populated by `map_wire`).
    env: BTreeMap<Wire, ValueMailBox<'c>>,
    /// Instruction builder; `new` leaves it positioned at the end of `launch_bb`.
    builder: Builder<'c>,
    /// First block appended by `new` ("alloca_block"); `finish` terminates it
    /// with an unconditional branch to `launch_bb`.
    prologue_bb: BasicBlock<'c>,
    /// Second block appended by `new` ("entry_block"); the builder starts here.
    launch_bb: BasicBlock<'c>,
}
59
60impl<'c, 'a, H: HugrView<Node = Node>> EmitFuncContext<'c, 'a, H> {
// Every method below is forwarded verbatim to `self.emit_context`.
delegate! {
    to self.emit_context {
        /// The inkwell [`Context`] used by this emission (e.g. to create basic blocks
        /// and builders).
        pub fn iw_context(&self) -> &'c Context;
        /// Shared handle to the codegen extension map consulted by
        /// `emit_custom_const` and `emit_extension_op`.
        pub fn extensions(&self) -> Rc<CodegenExtsMap<'a,H>>;
        /// A [`TypingSession`] obtained from the module context.
        pub fn typing_session(&self) -> TypingSession<'c, 'a>;
        /// The LLVM type corresponding to `hugr_type`.
        pub fn llvm_type(&self, hugr_type: &HugrType) -> Result<BasicTypeEnum<'c> >;
        /// The LLVM function type corresponding to `hugr_type`.
        pub fn llvm_func_type(&self, hugr_type: &HugrFuncType) -> Result<FunctionType<'c> >;
        /// The LLVM representation of the sum type `sum_type`.
        pub fn llvm_sum_type(&self, sum_type: HugrSumType) -> Result<LLVMSumType<'c>>;
        /// The LLVM function value for a `FuncDefn` node, as resolved by the module context.
        pub fn get_func_defn(&self, node: FatNode<FuncDefn, H>) -> Result<FunctionValue<'c>>;
        /// The LLVM function value for a `FuncDecl` node, as resolved by the module context.
        pub fn get_func_decl(&self, node: FatNode<FuncDecl, H>) -> Result<FunctionValue<'c>>;
        /// The LLVM function value for `symbol` with function type `typ`, as resolved by
        /// the module context.
        pub fn get_extern_func(&self, symbol: impl AsRef<str>, typ: FunctionType<'c>,) -> Result<FunctionValue<'c>>;
        /// The LLVM global for `symbol` with type `typ`, as resolved by the module context.
        ///
        /// NOTE(review): `constant` presumably marks the global as immutable; confirm
        /// against `EmitModuleContext::get_global`.
        pub fn get_global(&self, symbol: impl AsRef<str>, typ: impl BasicType<'c>, constant: bool) -> Result<GlobalValue<'c>>;
    }
}
105
106 pub fn push_todo_func(&mut self, node: FatNode<'_, FuncDefn, H>) {
109 self.todo.insert(node.node());
110 }
111
112 pub fn builder(&self) -> &Builder<'c> {
116 &self.builder
117 }
118
119 pub(crate) fn new_basic_block(
123 &mut self,
124 name: impl AsRef<str>,
125 before: Option<BasicBlock<'c>>,
126 ) -> BasicBlock<'c> {
127 if let Some(before) = before {
128 self.iw_context().prepend_basic_block(before, name.as_ref())
129 } else {
130 self.iw_context()
131 .append_basic_block(self.func, name.as_ref())
132 }
133 }
134
135 fn prologue_block(&self) -> BasicBlock<'c> {
136 self.func.get_first_basic_block().unwrap()
138 }
139
140 pub fn new(
148 emit_context: EmitModuleContext<'c, 'a, H>,
149 func: FunctionValue<'c>,
150 ) -> Result<EmitFuncContext<'c, 'a, H>> {
151 if func.get_first_basic_block().is_some() {
152 Err(anyhow!(
153 "EmitContext::new: Function already has a basic block: {:?}",
154 func.get_name()
155 ))?;
156 }
157 let prologue_bb = emit_context
158 .iw_context()
159 .append_basic_block(func, "alloca_block");
160 let launch_bb = emit_context
161 .iw_context()
162 .append_basic_block(func, "entry_block");
163 let builder = emit_context.iw_context().create_builder();
164 builder.position_at_end(launch_bb);
165 Ok(Self {
166 emit_context,
167 todo: Default::default(),
168 func,
169 env: Default::default(),
170 builder,
171 prologue_bb,
172 launch_bb,
173 })
174 }
175
176 fn new_value_mail_box(&mut self, t: &Type, name: impl AsRef<str>) -> Result<ValueMailBox<'c>> {
177 let bte = self.llvm_type(t)?;
178 let ptr = self.build_prologue(|builder| builder.build_alloca(bte, name.as_ref()))?;
179 Ok(ValueMailBox::new(bte, ptr, Some(name.as_ref().into())))
180 }
181
182 pub fn new_row_mail_box<'t>(
186 &mut self,
187 ts: impl IntoIterator<Item = &'t Type>,
188 name: impl AsRef<str>,
189 ) -> Result<RowMailBox<'c>> {
190 Ok(RowMailBox::new(
191 ts.into_iter()
192 .enumerate()
193 .map(|(i, t)| self.new_value_mail_box(t, format!("{i}")))
194 .collect::<Result<Vec<_>>>()?,
195 Some(name.as_ref().into()),
196 ))
197 }
198
199 fn build_prologue<T>(&mut self, f: impl FnOnce(&Builder<'c>) -> T) -> T {
200 let b = self.prologue_block();
201 self.build_positioned(b, |x| f(&x.builder))
202 }
203
204 pub fn build_positioned_new_block<T>(
209 &mut self,
210 name: impl AsRef<str>,
211 before: Option<BasicBlock<'c>>,
212 f: impl FnOnce(&mut Self, BasicBlock<'c>) -> T,
213 ) -> T {
214 let bb = self.new_basic_block(name, before);
215 self.build_positioned(bb, |s| f(s, bb))
216 }
217
218 pub fn build_positioned<T>(
222 &mut self,
223 block: BasicBlock<'c>,
224 f: impl FnOnce(&mut Self) -> T,
225 ) -> T {
226 let current = self.builder.get_insert_block().unwrap();
228 self.builder.position_at_end(block);
229 let r = f(self);
230 self.builder.position_at_end(current);
231 r
232 }
233
234 pub fn node_ins_rmb<'hugr, OT: 'hugr>(
237 &mut self,
238 node: FatNode<'hugr, OT, H>,
239 ) -> Result<RowMailBox<'c>> {
240 let r = node
241 .in_value_types()
242 .map(|(p, t)| {
243 let (slo_n, slo_p) = node
244 .single_linked_output(p)
245 .ok_or(anyhow!("No single linked output"))?;
246 self.map_wire(slo_n, slo_p, &t)
247 })
248 .collect::<Result<RowMailBox>>()?;
249
250 debug_assert!(
251 zip_eq(node.in_value_types(), r.get_types())
252 .all(|((_, t), lt)| self.llvm_type(&t).unwrap() == lt)
253 );
254 Ok(r)
255 }
256
257 pub fn node_outs_rmb<'hugr, OT: 'hugr>(
260 &mut self,
261 node: FatNode<'hugr, OT, H>,
262 ) -> Result<RowMailBox<'c>> {
263 let r = node
264 .out_value_types()
265 .map(|(port, hugr_type)| self.map_wire(node, port, &hugr_type))
266 .collect::<Result<RowMailBox>>()?;
267 debug_assert!(
268 zip_eq(node.out_value_types(), r.get_types())
269 .all(|((_, t), lt)| self.llvm_type(&t).unwrap() == lt)
270 );
271 Ok(r)
272 }
273
274 fn map_wire<'hugr, OT>(
275 &mut self,
276 node: FatNode<'hugr, OT, H>,
277 port: hugr_core::OutgoingPort,
278 hugr_type: &Type,
279 ) -> Result<ValueMailBox<'c>> {
280 let wire = Wire::new(node.node(), port);
281 if let Some(mb) = self.env.get(&wire) {
282 debug_assert_eq!(self.llvm_type(hugr_type).unwrap(), mb.get_type());
283 return Ok(mb.clone());
284 }
285 let mb = self.new_value_mail_box(
286 hugr_type,
287 format!("{}_{}", node.node().index(), port.index()),
288 )?;
289 self.env.insert(wire, mb.clone());
290 Ok(mb)
291 }
292
293 pub fn get_current_module(&self) -> &Module<'c> {
294 self.emit_context.module()
295 }
296
297 pub(crate) fn emit_custom_const(&mut self, v: &dyn CustomConst) -> Result<BasicValueEnum<'c>> {
298 let exts = self.extensions();
299 exts.as_ref()
300 .load_constant_handlers
301 .emit_load_constant(self, v)
302 }
303
304 pub(crate) fn emit_extension_op(
305 &mut self,
306 args: EmitOpArgs<'c, '_, ExtensionOp, H>,
307 ) -> Result<()> {
308 let exts = self.extensions();
309 exts.as_ref()
310 .extension_op_handlers
311 .emit_extension_op(self, args)
312 }
313
314 pub fn finish(self) -> Result<(EmitModuleContext<'c, 'a, H>, EmissionSet)> {
317 self.builder.position_at_end(self.prologue_bb);
318 self.builder.build_unconditional_branch(self.launch_bb)?;
319 Ok((self.emit_context, self.todo))
320 }
321}
322
323pub fn build_option<'c, H: HugrView<Node = Node>>(
325 ctx: &mut EmitFuncContext<'c, '_, H>,
326 is_some: IntValue<'c>,
327 some_value: BasicValueEnum<'c>,
328 hugr_ty: HugrType,
329) -> Result<BasicValueEnum<'c>> {
330 let option_ty = ctx.llvm_sum_type(option_type(hugr_ty))?;
331 let builder = ctx.builder();
332 let some = option_ty.build_tag(builder, 1, vec![some_value])?;
333 let none = option_ty.build_tag(builder, 0, vec![])?;
334 let option = builder.build_select(is_some, some, none, "")?;
335 Ok(option)
336}
337
338pub fn build_ok_or_else<'c, H: HugrView<Node = Node>>(
341 ctx: &mut EmitFuncContext<'c, '_, H>,
342 is_ok: IntValue<'c>,
343 ok_value: BasicValueEnum<'c>,
344 ok_hugr_ty: HugrType,
345 else_value: BasicValueEnum<'c>,
346 else_hugr_ty: HugrType,
347) -> Result<BasicValueEnum<'c>> {
348 let either_ty = ctx.llvm_sum_type(either_type(else_hugr_ty, ok_hugr_ty))?;
349 let builder = ctx.builder();
350 let left = either_ty.build_tag(builder, 0, vec![else_value])?;
351 let right = either_ty.build_tag(builder, 1, vec![ok_value])?;
352 let either = builder.build_select(is_ok, right, left, "")?;
353 Ok(either)
354}