1use std::{collections::HashMap, rc::Rc};
2
3use anyhow::{anyhow, Result};
4use hugr_core::{
5    extension::prelude::{either_type, option_type},
6    ops::{constant::CustomConst, ExtensionOp, FuncDecl, FuncDefn},
7    types::Type,
8    HugrView, Node, NodeIndex, PortIndex, Wire,
9};
10use inkwell::{
11    basic_block::BasicBlock,
12    builder::Builder,
13    context::Context,
14    module::Module,
15    types::{BasicType, BasicTypeEnum, FunctionType},
16    values::{BasicValueEnum, FunctionValue, GlobalValue, IntValue},
17};
18use itertools::zip_eq;
19
20use crate::types::{HugrFuncType, HugrSumType, HugrType, TypingSession};
21use crate::{custom::CodegenExtsMap, types::LLVMSumType, utils::fat::FatNode};
22use delegate::delegate;
23
24use self::mailbox::ValueMailBox;
25
26use super::{EmissionSet, EmitModuleContext, EmitOpArgs};
27
28mod mailbox;
29pub use mailbox::{RowMailBox, RowPromise};
30
/// Per-function state used while emitting the LLVM body of a single function.
///
/// Construct with [EmitFuncContext::new] and consume with [EmitFuncContext::finish],
/// which hands back the module-level context together with the set of functions
/// still to be emitted.
pub struct EmitFuncContext<'c, 'a, H>
where
    'a: 'c,
{
    // Module-level context; most type/function lookups are delegated to it.
    emit_context: EmitModuleContext<'c, 'a, H>,
    // FuncDefn nodes queued by `push_todo_func`; returned from `finish`.
    todo: EmissionSet,
    // The LLVM function whose body is being emitted.
    func: FunctionValue<'c>,
    // Lazily-populated map from each hugr Wire to the mailbox holding its value.
    env: HashMap<Wire, ValueMailBox<'c>>,
    // Instruction builder; positioned at the end of `launch_bb` by `new`.
    builder: Builder<'c>,
    // First block of the function ("alloca_block"); mailbox allocas are placed here.
    // `finish` terminates it with a branch to `launch_bb`.
    prologue_bb: BasicBlock<'c>,
    // Block ("entry_block") where ordinary emission starts.
    launch_bb: BasicBlock<'c>,
}
59
60impl<'c, 'a, H: HugrView<Node = Node>> EmitFuncContext<'c, 'a, H> {
61    delegate! {
62        to self.emit_context {
63            pub fn iw_context(&self) ->  &'c Context;
65            pub fn extensions(&self) ->  Rc<CodegenExtsMap<'a,H>>;
67            pub fn typing_session(&self) -> TypingSession<'c, 'a>;
69            pub fn llvm_type(&self, hugr_type: &HugrType) -> Result<BasicTypeEnum<'c> >;
71            pub fn llvm_func_type(&self, hugr_type: &HugrFuncType) -> Result<FunctionType<'c> >;
73            pub fn llvm_sum_type(&self, sum_type: HugrSumType) -> Result<LLVMSumType<'c>>;
75            pub fn get_func_defn(&self, node: FatNode<FuncDefn, H>) -> Result<FunctionValue<'c>>;
79            pub fn get_func_decl(&self, node: FatNode<FuncDecl, H>) -> Result<FunctionValue<'c>>;
83            pub fn get_extern_func(&self, symbol: impl AsRef<str>, typ: FunctionType<'c>,) -> Result<FunctionValue<'c>>;
95            pub fn get_global(&self, symbol: impl AsRef<str>, typ: impl BasicType<'c>, constant: bool) -> Result<GlobalValue<'c>>;
103        }
104    }
105
106    pub fn push_todo_func(&mut self, node: FatNode<'_, FuncDefn, H>) {
109        self.todo.insert(node.node());
110    }
111
112    pub fn builder(&self) -> &Builder<'c> {
116        &self.builder
117    }
118
119    pub(crate) fn new_basic_block(
123        &mut self,
124        name: impl AsRef<str>,
125        before: Option<BasicBlock<'c>>,
126    ) -> BasicBlock<'c> {
127        if let Some(before) = before {
128            self.iw_context().prepend_basic_block(before, name.as_ref())
129        } else {
130            self.iw_context()
131                .append_basic_block(self.func, name.as_ref())
132        }
133    }
134
135    fn prologue_block(&self) -> BasicBlock<'c> {
136        self.func.get_first_basic_block().unwrap()
138    }
139
140    pub fn new(
148        emit_context: EmitModuleContext<'c, 'a, H>,
149        func: FunctionValue<'c>,
150    ) -> Result<EmitFuncContext<'c, 'a, H>> {
151        if func.get_first_basic_block().is_some() {
152            Err(anyhow!(
153                "EmitContext::new: Function already has a basic block: {:?}",
154                func.get_name()
155            ))?;
156        }
157        let prologue_bb = emit_context
158            .iw_context()
159            .append_basic_block(func, "alloca_block");
160        let launch_bb = emit_context
161            .iw_context()
162            .append_basic_block(func, "entry_block");
163        let builder = emit_context.iw_context().create_builder();
164        builder.position_at_end(launch_bb);
165        Ok(Self {
166            emit_context,
167            todo: Default::default(),
168            func,
169            env: Default::default(),
170            builder,
171            prologue_bb,
172            launch_bb,
173        })
174    }
175
176    fn new_value_mail_box(&mut self, t: &Type, name: impl AsRef<str>) -> Result<ValueMailBox<'c>> {
177        let bte = self.llvm_type(t)?;
178        let ptr = self.build_prologue(|builder| builder.build_alloca(bte, name.as_ref()))?;
179        Ok(ValueMailBox::new(bte, ptr, Some(name.as_ref().into())))
180    }
181
182    pub fn new_row_mail_box<'t>(
186        &mut self,
187        ts: impl IntoIterator<Item = &'t Type>,
188        name: impl AsRef<str>,
189    ) -> Result<RowMailBox<'c>> {
190        Ok(RowMailBox::new(
191            ts.into_iter()
192                .enumerate()
193                .map(|(i, t)| self.new_value_mail_box(t, format!("{i}")))
194                .collect::<Result<Vec<_>>>()?,
195            Some(name.as_ref().into()),
196        ))
197    }
198
199    fn build_prologue<T>(&mut self, f: impl FnOnce(&Builder<'c>) -> T) -> T {
200        let b = self.prologue_block();
201        self.build_positioned(b, |x| f(&x.builder))
202    }
203
204    pub fn build_positioned_new_block<T>(
209        &mut self,
210        name: impl AsRef<str>,
211        before: Option<BasicBlock<'c>>,
212        f: impl FnOnce(&mut Self, BasicBlock<'c>) -> T,
213    ) -> T {
214        let bb = self.new_basic_block(name, before);
215        self.build_positioned(bb, |s| f(s, bb))
216    }
217
218    pub fn build_positioned<T>(
222        &mut self,
223        block: BasicBlock<'c>,
224        f: impl FnOnce(&mut Self) -> T,
225    ) -> T {
226        let current = self.builder.get_insert_block().unwrap();
228        self.builder.position_at_end(block);
229        let r = f(self);
230        self.builder.position_at_end(current);
231        r
232    }
233
234    pub fn node_ins_rmb<'hugr, OT: 'hugr>(
237        &mut self,
238        node: FatNode<'hugr, OT, H>,
239    ) -> Result<RowMailBox<'c>> {
240        let r = node
241            .in_value_types()
242            .map(|(p, t)| {
243                let (slo_n, slo_p) = node
244                    .single_linked_output(p)
245                    .ok_or(anyhow!("No single linked output"))?;
246                self.map_wire(slo_n, slo_p, &t)
247            })
248            .collect::<Result<RowMailBox>>()?;
249
250        debug_assert!(zip_eq(node.in_value_types(), r.get_types())
251            .all(|((_, t), lt)| self.llvm_type(&t).unwrap() == lt));
252        Ok(r)
253    }
254
255    pub fn node_outs_rmb<'hugr, OT: 'hugr>(
258        &mut self,
259        node: FatNode<'hugr, OT, H>,
260    ) -> Result<RowMailBox<'c>> {
261        let r = node
262            .out_value_types()
263            .map(|(port, hugr_type)| self.map_wire(node, port, &hugr_type))
264            .collect::<Result<RowMailBox>>()?;
265        debug_assert!(zip_eq(node.out_value_types(), r.get_types())
266            .all(|((_, t), lt)| self.llvm_type(&t).unwrap() == lt));
267        Ok(r)
268    }
269
270    fn map_wire<'hugr, OT>(
271        &mut self,
272        node: FatNode<'hugr, OT, H>,
273        port: hugr_core::OutgoingPort,
274        hugr_type: &Type,
275    ) -> Result<ValueMailBox<'c>> {
276        let wire = Wire::new(node.node(), port);
277        if let Some(mb) = self.env.get(&wire) {
278            debug_assert_eq!(self.llvm_type(hugr_type).unwrap(), mb.get_type());
279            return Ok(mb.clone());
280        }
281        let mb = self.new_value_mail_box(
282            hugr_type,
283            format!("{}_{}", node.node().index(), port.index()),
284        )?;
285        self.env.insert(wire, mb.clone());
286        Ok(mb)
287    }
288
289    pub fn get_current_module(&self) -> &Module<'c> {
290        self.emit_context.module()
291    }
292
293    pub(crate) fn emit_custom_const(&mut self, v: &dyn CustomConst) -> Result<BasicValueEnum<'c>> {
294        let exts = self.extensions();
295        exts.as_ref()
296            .load_constant_handlers
297            .emit_load_constant(self, v)
298    }
299
300    pub(crate) fn emit_extension_op(
301        &mut self,
302        args: EmitOpArgs<'c, '_, ExtensionOp, H>,
303    ) -> Result<()> {
304        let exts = self.extensions();
305        exts.as_ref()
306            .extension_op_handlers
307            .emit_extension_op(self, args)
308    }
309
310    pub fn finish(self) -> Result<(EmitModuleContext<'c, 'a, H>, EmissionSet)> {
313        self.builder.position_at_end(self.prologue_bb);
314        self.builder.build_unconditional_branch(self.launch_bb)?;
315        Ok((self.emit_context, self.todo))
316    }
317}
318
319pub fn build_option<'c, H: HugrView<Node = Node>>(
321    ctx: &mut EmitFuncContext<'c, '_, H>,
322    is_some: IntValue<'c>,
323    some_value: BasicValueEnum<'c>,
324    hugr_ty: HugrType,
325) -> Result<BasicValueEnum<'c>> {
326    let option_ty = ctx.llvm_sum_type(option_type(hugr_ty))?;
327    let builder = ctx.builder();
328    let some = option_ty.build_tag(builder, 1, vec![some_value])?;
329    let none = option_ty.build_tag(builder, 0, vec![])?;
330    let option = builder.build_select(is_some, some, none, "")?;
331    Ok(option)
332}
333
334pub fn build_ok_or_else<'c, H: HugrView<Node = Node>>(
337    ctx: &mut EmitFuncContext<'c, '_, H>,
338    is_ok: IntValue<'c>,
339    ok_value: BasicValueEnum<'c>,
340    ok_hugr_ty: HugrType,
341    else_value: BasicValueEnum<'c>,
342    else_hugr_ty: HugrType,
343) -> Result<BasicValueEnum<'c>> {
344    let either_ty = ctx.llvm_sum_type(either_type(else_hugr_ty, ok_hugr_ty))?;
345    let builder = ctx.builder();
346    let left = either_ty.build_tag(builder, 0, vec![else_value])?;
347    let right = either_ty.build_tag(builder, 1, vec![ok_value])?;
348    let either = builder.build_select(is_ok, right, left, "")?;
349    Ok(either)
350}