1use std::{collections::HashMap, rc::Rc};
2
3use anyhow::{anyhow, Result};
4use hugr_core::{
5 extension::prelude::{either_type, option_type},
6 ops::{constant::CustomConst, ExtensionOp, FuncDecl, FuncDefn},
7 types::Type,
8 HugrView, Node, NodeIndex, PortIndex, Wire,
9};
10use inkwell::{
11 basic_block::BasicBlock,
12 builder::Builder,
13 context::Context,
14 module::Module,
15 types::{BasicType, BasicTypeEnum, FunctionType},
16 values::{BasicValueEnum, FunctionValue, GlobalValue, IntValue},
17};
18use itertools::zip_eq;
19
20use crate::types::{HugrFuncType, HugrSumType, HugrType, TypingSession};
21use crate::{custom::CodegenExtsMap, types::LLVMSumType, utils::fat::FatNode};
22use delegate::delegate;
23
24use self::mailbox::ValueMailBox;
25
26use super::{EmissionSet, EmitModuleContext, EmitOpArgs};
27
28mod mailbox;
29pub use mailbox::{RowMailBox, RowPromise};
30
31pub struct EmitFuncContext<'c, 'a, H>
48where
49 'a: 'c,
50{
51 emit_context: EmitModuleContext<'c, 'a, H>,
52 todo: EmissionSet,
53 func: FunctionValue<'c>,
54 env: HashMap<Wire, ValueMailBox<'c>>,
55 builder: Builder<'c>,
56 prologue_bb: BasicBlock<'c>,
57 launch_bb: BasicBlock<'c>,
58}
59
60impl<'c, 'a, H: HugrView<Node = Node>> EmitFuncContext<'c, 'a, H> {
61 delegate! {
62 to self.emit_context {
63 pub fn iw_context(&self) -> &'c Context;
65 pub fn extensions(&self) -> Rc<CodegenExtsMap<'a,H>>;
67 pub fn typing_session(&self) -> TypingSession<'c, 'a>;
69 pub fn llvm_type(&self, hugr_type: &HugrType) -> Result<BasicTypeEnum<'c> >;
71 pub fn llvm_func_type(&self, hugr_type: &HugrFuncType) -> Result<FunctionType<'c> >;
73 pub fn llvm_sum_type(&self, sum_type: HugrSumType) -> Result<LLVMSumType<'c>>;
75 pub fn get_func_defn(&self, node: FatNode<FuncDefn, H>) -> Result<FunctionValue<'c>>;
79 pub fn get_func_decl(&self, node: FatNode<FuncDecl, H>) -> Result<FunctionValue<'c>>;
83 pub fn get_extern_func(&self, symbol: impl AsRef<str>, typ: FunctionType<'c>,) -> Result<FunctionValue<'c>>;
95 pub fn get_global(&self, symbol: impl AsRef<str>, typ: impl BasicType<'c>, constant: bool) -> Result<GlobalValue<'c>>;
103 }
104 }
105
106 pub fn push_todo_func(&mut self, node: FatNode<'_, FuncDefn, H>) {
109 self.todo.insert(node.node());
110 }
111
112 pub fn builder(&self) -> &Builder<'c> {
116 &self.builder
117 }
118
119 pub(crate) fn new_basic_block(
123 &mut self,
124 name: impl AsRef<str>,
125 before: Option<BasicBlock<'c>>,
126 ) -> BasicBlock<'c> {
127 if let Some(before) = before {
128 self.iw_context().prepend_basic_block(before, name.as_ref())
129 } else {
130 self.iw_context()
131 .append_basic_block(self.func, name.as_ref())
132 }
133 }
134
135 fn prologue_block(&self) -> BasicBlock<'c> {
136 self.func.get_first_basic_block().unwrap()
138 }
139
140 pub fn new(
148 emit_context: EmitModuleContext<'c, 'a, H>,
149 func: FunctionValue<'c>,
150 ) -> Result<EmitFuncContext<'c, 'a, H>> {
151 if func.get_first_basic_block().is_some() {
152 Err(anyhow!(
153 "EmitContext::new: Function already has a basic block: {:?}",
154 func.get_name()
155 ))?;
156 }
157 let prologue_bb = emit_context
158 .iw_context()
159 .append_basic_block(func, "alloca_block");
160 let launch_bb = emit_context
161 .iw_context()
162 .append_basic_block(func, "entry_block");
163 let builder = emit_context.iw_context().create_builder();
164 builder.position_at_end(launch_bb);
165 Ok(Self {
166 emit_context,
167 todo: Default::default(),
168 func,
169 env: Default::default(),
170 builder,
171 prologue_bb,
172 launch_bb,
173 })
174 }
175
176 fn new_value_mail_box(&mut self, t: &Type, name: impl AsRef<str>) -> Result<ValueMailBox<'c>> {
177 let bte = self.llvm_type(t)?;
178 let ptr = self.build_prologue(|builder| builder.build_alloca(bte, name.as_ref()))?;
179 Ok(ValueMailBox::new(bte, ptr, Some(name.as_ref().into())))
180 }
181
182 pub fn new_row_mail_box<'t>(
186 &mut self,
187 ts: impl IntoIterator<Item = &'t Type>,
188 name: impl AsRef<str>,
189 ) -> Result<RowMailBox<'c>> {
190 Ok(RowMailBox::new(
191 ts.into_iter()
192 .enumerate()
193 .map(|(i, t)| self.new_value_mail_box(t, format!("{i}")))
194 .collect::<Result<Vec<_>>>()?,
195 Some(name.as_ref().into()),
196 ))
197 }
198
199 fn build_prologue<T>(&mut self, f: impl FnOnce(&Builder<'c>) -> T) -> T {
200 let b = self.prologue_block();
201 self.build_positioned(b, |x| f(&x.builder))
202 }
203
204 pub fn build_positioned_new_block<T>(
209 &mut self,
210 name: impl AsRef<str>,
211 before: Option<BasicBlock<'c>>,
212 f: impl FnOnce(&mut Self, BasicBlock<'c>) -> T,
213 ) -> T {
214 let bb = self.new_basic_block(name, before);
215 self.build_positioned(bb, |s| f(s, bb))
216 }
217
218 pub fn build_positioned<T>(
222 &mut self,
223 block: BasicBlock<'c>,
224 f: impl FnOnce(&mut Self) -> T,
225 ) -> T {
226 let current = self.builder.get_insert_block().unwrap();
228 self.builder.position_at_end(block);
229 let r = f(self);
230 self.builder.position_at_end(current);
231 r
232 }
233
234 pub fn node_ins_rmb<'hugr, OT: 'hugr>(
237 &mut self,
238 node: FatNode<'hugr, OT, H>,
239 ) -> Result<RowMailBox<'c>> {
240 let r = node
241 .in_value_types()
242 .map(|(p, t)| {
243 let (slo_n, slo_p) = node
244 .single_linked_output(p)
245 .ok_or(anyhow!("No single linked output"))?;
246 self.map_wire(slo_n, slo_p, &t)
247 })
248 .collect::<Result<RowMailBox>>()?;
249
250 debug_assert!(zip_eq(node.in_value_types(), r.get_types())
251 .all(|((_, t), lt)| self.llvm_type(&t).unwrap() == lt));
252 Ok(r)
253 }
254
255 pub fn node_outs_rmb<'hugr, OT: 'hugr>(
258 &mut self,
259 node: FatNode<'hugr, OT, H>,
260 ) -> Result<RowMailBox<'c>> {
261 let r = node
262 .out_value_types()
263 .map(|(port, hugr_type)| self.map_wire(node, port, &hugr_type))
264 .collect::<Result<RowMailBox>>()?;
265 debug_assert!(zip_eq(node.out_value_types(), r.get_types())
266 .all(|((_, t), lt)| self.llvm_type(&t).unwrap() == lt));
267 Ok(r)
268 }
269
270 fn map_wire<'hugr, OT>(
271 &mut self,
272 node: FatNode<'hugr, OT, H>,
273 port: hugr_core::OutgoingPort,
274 hugr_type: &Type,
275 ) -> Result<ValueMailBox<'c>> {
276 let wire = Wire::new(node.node(), port);
277 if let Some(mb) = self.env.get(&wire) {
278 debug_assert_eq!(self.llvm_type(hugr_type).unwrap(), mb.get_type());
279 return Ok(mb.clone());
280 }
281 let mb = self.new_value_mail_box(
282 hugr_type,
283 format!("{}_{}", node.node().index(), port.index()),
284 )?;
285 self.env.insert(wire, mb.clone());
286 Ok(mb)
287 }
288
289 pub fn get_current_module(&self) -> &Module<'c> {
290 self.emit_context.module()
291 }
292
293 pub(crate) fn emit_custom_const(&mut self, v: &dyn CustomConst) -> Result<BasicValueEnum<'c>> {
294 let exts = self.extensions();
295 exts.as_ref()
296 .load_constant_handlers
297 .emit_load_constant(self, v)
298 }
299
300 pub(crate) fn emit_extension_op(
301 &mut self,
302 args: EmitOpArgs<'c, '_, ExtensionOp, H>,
303 ) -> Result<()> {
304 let exts = self.extensions();
305 exts.as_ref()
306 .extension_op_handlers
307 .emit_extension_op(self, args)
308 }
309
310 pub fn finish(self) -> Result<(EmitModuleContext<'c, 'a, H>, EmissionSet)> {
313 self.builder.position_at_end(self.prologue_bb);
314 self.builder.build_unconditional_branch(self.launch_bb)?;
315 Ok((self.emit_context, self.todo))
316 }
317}
318
319pub fn build_option<'c, H: HugrView<Node = Node>>(
321 ctx: &mut EmitFuncContext<'c, '_, H>,
322 is_some: IntValue<'c>,
323 some_value: BasicValueEnum<'c>,
324 hugr_ty: HugrType,
325) -> Result<BasicValueEnum<'c>> {
326 let option_ty = ctx.llvm_sum_type(option_type(hugr_ty))?;
327 let builder = ctx.builder();
328 let some = option_ty.build_tag(builder, 1, vec![some_value])?;
329 let none = option_ty.build_tag(builder, 0, vec![])?;
330 let option = builder.build_select(is_some, some, none, "")?;
331 Ok(option)
332}
333
334pub fn build_ok_or_else<'c, H: HugrView<Node = Node>>(
337 ctx: &mut EmitFuncContext<'c, '_, H>,
338 is_ok: IntValue<'c>,
339 ok_value: BasicValueEnum<'c>,
340 ok_hugr_ty: HugrType,
341 else_value: BasicValueEnum<'c>,
342 else_hugr_ty: HugrType,
343) -> Result<BasicValueEnum<'c>> {
344 let either_ty = ctx.llvm_sum_type(either_type(else_hugr_ty, ok_hugr_ty))?;
345 let builder = ctx.builder();
346 let left = either_ty.build_tag(builder, 0, vec![else_value])?;
347 let right = either_ty.build_tag(builder, 1, vec![ok_value])?;
348 let either = builder.build_select(is_ok, right, left, "")?;
349 Ok(either)
350}