hugr_llvm/emit/
ops.rs

1use anyhow::{Result, anyhow, bail};
2use hugr_core::Node;
3use hugr_core::hugr::internal::PortgraphNodeMap;
4use hugr_core::ops::{
5    CFG, Call, CallIndirect, Case, Conditional, Const, ExtensionOp, Input, LoadConstant,
6    LoadFunction, OpTag, OpTrait, OpType, Output, Tag, TailLoop, Value, constant::Sum,
7};
8use hugr_core::{
9    HugrView, NodeIndex,
10    types::{SumType, Type, TypeEnum},
11};
12use inkwell::types::BasicTypeEnum;
13use inkwell::values::{BasicValueEnum, CallableValue};
14use itertools::{Itertools, zip_eq};
15use petgraph::visit::Walker;
16
17use crate::{
18    sum::LLVMSumValue,
19    utils::fat::{FatExt as _, FatNode},
20};
21
22use super::{
23    EmitOpArgs, deaggregate_call_result,
24    func::{EmitFuncContext, RowPromise},
25};
26
27mod cfg;
28
/// Emits the children of a dataflow parent node (e.g. the node of a `DFG`, a
/// `Case` of a `Conditional`, or a `TailLoop` body) into the current function.
///
/// The values destined for the region's `Input` node, and the promise that the
/// region's `Output` node must fulfil, are held here until the corresponding
/// child is reached during emission.
struct DataflowParentEmitter<'c, 'hugr, OT, H> {
    // The dataflow parent whose children are emitted.
    node: FatNode<'hugr, OT, H>,
    // Values handed to the region's `Input` node; `None` once taken.
    inputs: Option<Vec<BasicValueEnum<'c>>>,
    // Promise completed with the values reaching the region's `Output` node;
    // `None` once taken.
    outputs: Option<RowPromise<'c>>,
}

impl<'c, 'hugr, OT: OpTrait, H: HugrView<Node = Node>> DataflowParentEmitter<'c, 'hugr, OT, H>
where
    for<'a> &'a OpType: TryInto<&'a OT>,
{
    /// Create an emitter from the node, input values and output promise in `args`.
    pub fn new(args: EmitOpArgs<'c, 'hugr, OT, H>) -> Self {
        Self {
            node: args.node,
            inputs: Some(args.inputs),
            outputs: Some(args.outputs),
        }
    }

    /// Take the values destined for the region's `Input` node.
    ///
    /// This is expected to be called at most once (a region has a single
    /// `Input` node); a second call returns an error rather than panicking.
    fn take_input(&mut self) -> Result<Vec<BasicValueEnum<'c>>> {
        self.inputs
            .take()
            .ok_or(anyhow!("DataflowParentEmitter: Input taken twice"))
    }

    /// Take the promise to be fulfilled by the region's `Output` node.
    ///
    /// Like `take_input`, a second call returns an error rather than panicking.
    fn take_output(&mut self) -> Result<RowPromise<'c>> {
        self.outputs
            .take()
            .ok_or(anyhow!("DataflowParentEmitter: Output taken twice"))
    }

    /// Emit every child of `self.node`, in topological order of the region's
    /// dataflow graph.
    ///
    /// The `Input` child forwards the values from `take_input`. The `Output`
    /// child fulfils the promise from `take_output` with its incoming values.
    /// Every other child is lowered by `emit_optype`.
    ///
    /// # Errors
    ///
    /// Returns an error if `self.node` is not a dataflow parent, if it lacks
    /// `Input`/`Output` children, if `Input`/`Output` is encountered twice, or if
    /// emitting any child fails.
    pub fn emit_children(&mut self, context: &mut EmitFuncContext<'c, '_, H>) -> Result<()> {
        use petgraph::visit::Topo;
        let node = self.node;
        if !OpTag::DataflowParent.is_superset(node.tag()) {
            Err(anyhow!("Not a dataflow parent"))?;
        }

        let (i, o): (FatNode<Input, H>, FatNode<Output, H>) = node
            .get_io()
            .ok_or(anyhow!("emit_dataflow_parent: no io nodes"))?;
        // Debug-only sanity check: the arity of the Input/Output nodes matches
        // the values and promise held by this emitter.
        debug_assert!(i.out_value_types().count() == self.inputs.as_ref().unwrap().len());
        debug_assert!(o.in_value_types().count() == self.outputs.as_ref().unwrap().len());

        // Visiting in topological order means every node is reached after the
        // nodes that feed it.
        let (region_graph, node_map) = node.hugr().region_portgraph(node.node());
        let topo = Topo::new(&region_graph);
        for n in topo.iter(&region_graph) {
            let node = node.hugr().fat_optype(node_map.from_portgraph(n));
            // Read this child's incoming values and get a promise for its outputs.
            let inputs_rmb = context.node_ins_rmb(node)?;
            let inputs = inputs_rmb.read(context.builder(), [])?;
            let outputs = context.node_outs_rmb(node)?.promise();
            match node.as_ref() {
                // The region's Input node forwards the parent's input values.
                OpType::Input(_) => {
                    let i = self.take_input()?;
                    outputs.finish(context.builder(), i)?;
                }
                // The region's Output node hands its incoming values to the
                // parent's output promise.
                OpType::Output(_) => {
                    let o = self.take_output()?;
                    o.finish(context.builder(), inputs)?;
                }
                _ => emit_optype(
                    context,
                    EmitOpArgs {
                        node,
                        inputs,
                        outputs,
                    },
                )?,
            }
        }
        Ok(())
    }
}
102
103fn get_exactly_one_sum_type(ts: impl IntoIterator<Item = Type>) -> Result<SumType> {
104    let Some(TypeEnum::Sum(sum_type)) = ts
105        .into_iter()
106        .map(|t| t.as_type_enum().clone())
107        .exactly_one()
108        .ok()
109    else {
110        Err(anyhow!("Not exactly one SumType"))?
111    };
112    Ok(sum_type)
113}
114
115pub fn emit_value<'c, H: HugrView<Node = Node>>(
116    context: &mut EmitFuncContext<'c, '_, H>,
117    v: &Value,
118) -> Result<BasicValueEnum<'c>> {
119    match v {
120        Value::Extension { e } => context.emit_custom_const(e.value()),
121        Value::Function { .. } => bail!(
122            "Value::Function Const nodes are not supported. \
123            Ensure you eliminate these from the HUGR before lowering to LLVM. \
124            `hugr_llvm::utils::inline_constant_functions` is provided for this purpose."
125        ),
126        Value::Sum(Sum {
127            tag,
128            values,
129            sum_type,
130        }) => {
131            let llvm_st = context.llvm_sum_type(sum_type.clone())?;
132            let vs = values
133                .iter()
134                .map(|x| emit_value(context, x))
135                .collect::<Result<Vec<_>>>()?;
136            Ok(llvm_st.build_tag(context.builder(), *tag, vs)?.into())
137        }
138    }
139}
140
141pub(crate) fn emit_dataflow_parent<'c, 'hugr, OT: OpTrait, H: HugrView<Node = Node>>(
142    context: &mut EmitFuncContext<'c, '_, H>,
143    args: EmitOpArgs<'c, 'hugr, OT, H>,
144) -> Result<()>
145where
146    for<'a> &'a OpType: TryInto<&'a OT>,
147{
148    DataflowParentEmitter::new(args).emit_children(context)
149}
150
151fn emit_tag<'c, H: HugrView<Node = Node>>(
152    context: &mut EmitFuncContext<'c, '_, H>,
153    args: EmitOpArgs<'c, '_, Tag, H>,
154) -> Result<()> {
155    let st = context.llvm_sum_type(get_exactly_one_sum_type(
156        args.node.out_value_types().map(|x| x.1),
157    )?)?;
158    let builder = context.builder();
159    args.outputs.finish(
160        builder,
161        [st.build_tag(builder, args.node.tag, args.inputs)?.into()],
162    )
163}
164
/// Emit a `Conditional` node as a branch on the tag of its first input.
///
/// Emitted structure:
/// * an exit block (`cond_exit_<node>`) that reads the values collected in
///   `exit_rmb` and uses them to fulfil the conditional's outputs;
/// * one block per `Case` child (`cond_<node>_case_<i>`) that reads its inputs
///   from a per-case mailbox, emits the case body (whose outputs are written to
///   `exit_rmb`), and branches to the exit block;
/// * in the current block, a destructure of the first input (a sum value). For
///   each tag, the variant's fields followed by the conditional's remaining
///   inputs are written to that case's mailbox, then control branches to the
///   case block.
///
/// On return the builder is positioned at the end of the exit block.
///
/// # Errors
///
/// Returns an error if a child is not a `Case`, if the first input's type is
/// not a sum type, or if any emission step fails.
///
/// # Panics
///
/// Panics (via `unwrap`/indexing) if the node has no dataflow signature, if a
/// case has no `Input`/`Output` children, or if there are no inputs.
fn emit_conditional<'c, H: HugrView<Node = Node>>(
    context: &mut EmitFuncContext<'c, '_, H>,
    EmitOpArgs {
        node,
        inputs,
        outputs,
    }: EmitOpArgs<'c, '_, Conditional, H>,
) -> Result<()> {
    // Mailbox shared by all cases to deliver their outputs to the exit block.
    let exit_rmb =
        context.new_row_mail_box(node.dataflow_signature().unwrap().output.iter(), "exit_rmb")?;
    let exit_block = context.build_positioned_new_block(
        format!("cond_exit_{}", node.node().index()),
        None,
        |context, bb| {
            let builder = context.builder();
            outputs.finish(builder, exit_rmb.read_vec(builder, [])?)?;
            Ok::<_, anyhow::Error>(bb)
        },
    )?;

    // One (input mailbox, block) pair per case, in child order, so that
    // `rmbs_blocks[i]` corresponds to tag `i`. Passing `Some(exit_block)`
    // presumably places each case block before the exit block (TODO confirm).
    let rmbs_blocks = node
        .children()
        .enumerate()
        .map(|(i, n)| {
            let label = format!("cond_{}_case_{}", node.node().index(), i);
            let node = n.try_into_ot::<Case>().ok_or(anyhow!("not a case node"))?;
            let rmb = context.new_row_mail_box(node.get_io().unwrap().0.types.iter(), &label)?;
            context.build_positioned_new_block(&label, Some(exit_block), |context, bb| {
                let inputs = rmb.read_vec(context.builder(), [])?;
                emit_dataflow_parent(
                    context,
                    EmitOpArgs {
                        node,
                        inputs,
                        outputs: exit_rmb.promise(),
                    },
                )?;
                context.builder().build_unconditional_branch(exit_block)?;
                Ok((rmb, bb))
            })
        })
        .collect::<Result<Vec<_>>>()?;

    // Dispatch on the tag of the first input; the variant's fields plus the
    // remaining ("other") inputs make up the selected case's inputs.
    let sum_type = get_exactly_one_sum_type(node.in_value_types().next().map(|x| x.1))?;
    let sum_input = LLVMSumValue::try_new(inputs[0], context.llvm_sum_type(sum_type)?)?;
    let builder = context.builder();
    sum_input.build_destructure(builder, |builder, tag, mut vs| {
        let (rmb, bb) = &rmbs_blocks[tag];
        vs.extend(&inputs[1..]);
        rmb.write(builder, vs)?;
        builder.build_unconditional_branch(*bb)?;
        Ok(())
    })?;
    // Continue emitting the enclosing region after the conditional.
    builder.position_at_end(exit_block);
    Ok(())
}
221
222fn emit_load_constant<'c, H: HugrView<Node = Node>>(
223    context: &mut EmitFuncContext<'c, '_, H>,
224    args: EmitOpArgs<'c, '_, LoadConstant, H>,
225) -> Result<()> {
226    let konst_node = args
227        .node
228        .single_linked_output(0.into())
229        .unwrap()
230        .0
231        .try_into_ot::<Const>()
232        .unwrap();
233    let r = emit_value(context, konst_node.value())?;
234    args.outputs.finish(context.builder(), [r])
235}
236
237fn emit_call<'c, H: HugrView<Node = Node>>(
238    context: &mut EmitFuncContext<'c, '_, H>,
239    args: EmitOpArgs<'c, '_, Call, H>,
240) -> Result<()> {
241    if !args.node.called_function_type().params().is_empty() {
242        return Err(anyhow!("Call of generic function"));
243    }
244    let (func_node, _) = args
245        .node
246        .single_linked_output(args.node.called_function_port())
247        .unwrap();
248    let func = match func_node.as_ref() {
249        OpType::FuncDecl(_) => context.get_func_decl(func_node.try_into_ot().unwrap()),
250        OpType::FuncDefn(_) => context.get_func_defn(func_node.try_into_ot().unwrap()),
251        _ => Err(anyhow!("emit_call: Not a Decl or Defn")),
252    };
253    let inputs = args.inputs.into_iter().map_into().collect_vec();
254    let builder = context.builder();
255    let call = builder.build_call(func?, inputs.as_slice(), "")?;
256    let call_results = deaggregate_call_result(builder, call, args.outputs.len())?;
257    args.outputs.finish(builder, call_results)
258}
259
260fn emit_call_indirect<'c, H: HugrView<Node = Node>>(
261    context: &mut EmitFuncContext<'c, '_, H>,
262    args: EmitOpArgs<'c, '_, CallIndirect, H>,
263) -> Result<()> {
264    let func_ptr = match args.inputs[0] {
265        BasicValueEnum::PointerValue(v) => Ok(v),
266        _ => Err(anyhow!("emit_call_indirect: Not a pointer")),
267    }?;
268    let func =
269        CallableValue::try_from(func_ptr).expect("emit_call_indirect: Not a function pointer");
270    let inputs = args.inputs.into_iter().skip(1).map_into().collect_vec();
271    let builder = context.builder();
272    let call = builder.build_call(func, inputs.as_slice(), "")?;
273    let call_results = deaggregate_call_result(builder, call, args.outputs.len())?;
274    args.outputs.finish(builder, call_results)
275}
276
277fn emit_load_function<'c, H: HugrView<Node = Node>>(
278    context: &mut EmitFuncContext<'c, '_, H>,
279    args: EmitOpArgs<'c, '_, LoadFunction, H>,
280) -> Result<()> {
281    if !args.node.func_sig.params().is_empty() {
282        return Err(anyhow!("Load of generic function"));
283    }
284    let (func_node, _) = args
285        .node
286        .single_linked_output(args.node.function_port())
287        .unwrap();
288
289    let func = match func_node.as_ref() {
290        OpType::FuncDecl(_) => context.get_func_decl(func_node.try_into_ot().unwrap()),
291        OpType::FuncDefn(_) => context.get_func_defn(func_node.try_into_ot().unwrap()),
292        _ => Err(anyhow!("emit_call: Not a Decl or Defn")),
293    }?;
294    args.outputs.finish(
295        context.builder(),
296        [func.as_global_value().as_pointer_value().into()],
297    )
298}
299
300fn emit_cfg<'c, H: HugrView<Node = Node>>(
301    context: &mut EmitFuncContext<'c, '_, H>,
302    args: EmitOpArgs<'c, '_, CFG, H>,
303) -> Result<()> {
304    cfg::CfgEmitter::new(context, args)?.emit_children(context)
305}
306
/// Emit a `TailLoop` node as an LLVM loop.
///
/// Two blocks are created: `loop_body`, holding the emitted loop body, and
/// `loop_out`, which is reached when the loop breaks. The loop's inputs are
/// written to the mailbox of the body's `Input` node before branching to
/// `loop_body`.
///
/// At the end of each iteration, the body's first output is a control sum whose
/// variants are `just_inputs` (tag 0) and `just_outputs` (tag 1). The selected
/// variant's fields, followed by the body's remaining outputs, either become the
/// next iteration's inputs (tag 0, branch back to `loop_body`) or fulfil the
/// `TailLoop`'s outputs (otherwise, branch to `loop_out`).
///
/// On return the builder is positioned at the end of `loop_out`.
///
/// # Panics
///
/// Panics (via `unwrap`) if the loop body has no `Input`/`Output` children.
fn emit_tail_loop<'c, H: HugrView<Node = Node>>(
    context: &mut EmitFuncContext<'c, '_, H>,
    args: EmitOpArgs<'c, '_, TailLoop, H>,
) -> Result<()> {
    let node = args.node();

    // Make a block to jump to when we `Break`
    let out_bb = context.new_basic_block("loop_out", None);
    // A block for the body of the loop
    let body_bb = context.new_basic_block("loop_body", Some(out_bb));

    // Mailboxes for the values leaving the body's Input node and entering its
    // Output node.
    let (body_i_node, body_o_node) = node.get_io().unwrap();
    let body_i_rmb = context.node_outs_rmb(body_i_node)?;
    let body_o_rmb = context.node_ins_rmb(body_o_node)?;

    // Seed the first iteration with the loop's inputs and enter the body.
    body_i_rmb.write(context.builder(), args.inputs)?;
    context.builder().build_unconditional_branch(body_bb)?;

    // Control sum: variant 0 carries `just_inputs` (continue), variant 1
    // carries `just_outputs` (break).
    let control_llvm_sum_type = {
        let sum_ty = SumType::new([node.just_inputs.clone(), node.just_outputs.clone()]);
        context.llvm_sum_type(sum_ty)?
    };

    context.build_positioned(body_bb, move |context| {
        let inputs = body_i_rmb.read_vec(context.builder(), [])?;
        emit_dataflow_parent(
            context,
            EmitOpArgs {
                node,
                inputs,
                outputs: body_o_rmb.promise(),
            },
        )?;
        let dataflow_outputs = body_o_rmb.read_vec(context.builder(), [])?;
        let control_val = LLVMSumValue::try_new(dataflow_outputs[0], control_llvm_sum_type)?;
        // Wrapped in `Option` so the break arm can move the promise out of the
        // destructure callback (presumably the callback cannot consume captured
        // state directly — TODO confirm against `build_destructure`'s bound).
        let mut outputs = Some(args.outputs);

        control_val.build_destructure(context.builder(), |builder, tag, mut values| {
            values.extend(dataflow_outputs[1..].iter().copied());
            if tag == 0 {
                // Continue: feed the values back into the next iteration.
                body_i_rmb.write(builder, values)?;
                builder.build_unconditional_branch(body_bb)?;
            } else {
                // Break: the values become the TailLoop's outputs.
                outputs.take().unwrap().finish(builder, values)?;
                builder.build_unconditional_branch(out_bb)?;
            }
            Ok(())
        })
    })?;
    // Continue emitting the enclosing region after the loop.
    context.builder().position_at_end(out_bb);
    Ok(())
}
359
360fn emit_optype<'c, H: HugrView<Node = Node>>(
361    context: &mut EmitFuncContext<'c, '_, H>,
362    args: EmitOpArgs<'c, '_, OpType, H>,
363) -> Result<()> {
364    let node = args.node();
365    match node.as_ref() {
366        OpType::Tag(tag) => emit_tag(context, args.into_ot(tag)),
367        OpType::DFG(_) => emit_dataflow_parent(context, args),
368
369        OpType::ExtensionOp(co) => context.emit_extension_op(args.into_ot(co)),
370        OpType::LoadConstant(lc) => emit_load_constant(context, args.into_ot(lc)),
371        OpType::Call(cl) => emit_call(context, args.into_ot(cl)),
372        OpType::CallIndirect(cl) => emit_call_indirect(context, args.into_ot(cl)),
373        OpType::LoadFunction(lf) => emit_load_function(context, args.into_ot(lf)),
374        OpType::Conditional(co) => emit_conditional(context, args.into_ot(co)),
375        OpType::CFG(cfg) => emit_cfg(context, args.into_ot(cfg)),
376        // Const is allowed, but requires no work here. FuncDecl is technically
377        // not allowed, but there is no harm in allowing it.
378        OpType::Const(_) => Ok(()),
379        OpType::FuncDecl(_) => Ok(()),
380        OpType::FuncDefn(fd) => {
381            context.push_todo_func(node.into_ot(fd));
382            Ok(())
383        }
384        OpType::TailLoop(x) => emit_tail_loop(context, args.into_ot(x)),
385        _ => Err(anyhow!("Invalid child for Dataflow Parent: {node}")),
386    }
387}
388
389/// Emit a custom operation with a single input.
390///
391/// # Arguments
392///
393/// * `context` - The context in which to emit the operation.
394/// * `args` - The arguments to the operation.
395/// * `go` - The operation to build the result given a [`Builder`], the input,
396///   and an iterator over the expected output types.
397pub(crate) fn emit_custom_unary_op<'c, 'hugr, H, F>(
398    context: &mut EmitFuncContext<'c, '_, H>,
399    args: EmitOpArgs<'c, 'hugr, ExtensionOp, H>,
400    go: F,
401) -> Result<()>
402where
403    H: HugrView<Node = Node>,
404    F: FnOnce(
405        &mut EmitFuncContext<'c, '_, H>,
406        BasicValueEnum<'c>,
407        &[BasicTypeEnum<'c>],
408    ) -> Result<Vec<BasicValueEnum<'c>>>,
409{
410    let [inp] = TryInto::<[_; 1]>::try_into(args.inputs).map_err(|v| {
411        anyhow!(
412            "emit_custom_unary_op: expected exactly one input, got {}",
413            v.len()
414        )
415    })?;
416    let out_types = args.outputs.get_types().collect_vec();
417    let res = go(context, inp, &out_types)?;
418    if res.len() != args.outputs.len()
419        || zip_eq(res.iter(), out_types).any(|(a, b)| a.get_type() != b)
420    {
421        return Err(anyhow!(
422            "emit_custom_unary_op: expected outputs of types {:?}, got {:?}",
423            args.outputs.get_types().collect_vec(),
424            res.iter().map(BasicValueEnum::get_type).collect_vec()
425        ));
426    }
427    args.outputs.finish(context.builder(), res)
428}
429
430/// Emit a custom operation with two inputs of the same type.
431///
432/// # Arguments
433///
434/// * `context` - The context in which to emit the operation.
435/// * `args` - The arguments to the operation.
436/// * `go` - The operation to build the result given a [`Builder`], the two
437///   inputs, and an iterator over the expected output types.
438pub(crate) fn emit_custom_binary_op<'c, 'hugr, H, F>(
439    context: &mut EmitFuncContext<'c, '_, H>,
440    args: EmitOpArgs<'c, 'hugr, ExtensionOp, H>,
441    go: F,
442) -> Result<()>
443where
444    H: HugrView<Node = Node>,
445    F: FnOnce(
446        &mut EmitFuncContext<'c, '_, H>,
447        (BasicValueEnum<'c>, BasicValueEnum<'c>),
448        &[BasicTypeEnum<'c>],
449    ) -> Result<Vec<BasicValueEnum<'c>>>,
450{
451    let [lhs, rhs] = TryInto::<[_; 2]>::try_into(args.inputs).map_err(|v| {
452        anyhow!(
453            "emit_custom_binary_op: expected exactly 2 inputs, got {}",
454            v.len()
455        )
456    })?;
457    if lhs.get_type() != rhs.get_type() {
458        return Err(anyhow!(
459            "emit_custom_binary_op: expected inputs of the same type, got {} and {}",
460            lhs.get_type(),
461            rhs.get_type()
462        ));
463    }
464    let out_types = args.outputs.get_types().collect_vec();
465    let res = go(context, (lhs, rhs), &out_types)?;
466    if res.len() != out_types.len() || zip_eq(res.iter(), out_types).any(|(a, b)| a.get_type() != b)
467    {
468        return Err(anyhow!(
469            "emit_custom_binary_op: expected outputs of types {:?}, got {:?}",
470            args.outputs.get_types().collect_vec(),
471            res.iter().map(BasicValueEnum::get_type).collect_vec()
472        ));
473    }
474    args.outputs.finish(context.builder(), res)
475}