1use anyhow::{Result, anyhow, bail};
2use hugr_core::Node;
3use hugr_core::hugr::internal::PortgraphNodeMap;
4use hugr_core::ops::{
5 CFG, Call, CallIndirect, Case, Conditional, Const, ExtensionOp, Input, LoadConstant,
6 LoadFunction, OpTag, OpTrait, OpType, Output, Tag, TailLoop, Value, constant::Sum,
7};
8use hugr_core::{
9 HugrView, NodeIndex,
10 types::{SumType, Type, TypeEnum},
11};
12use inkwell::types::BasicTypeEnum;
13use inkwell::values::{BasicValueEnum, CallableValue};
14use itertools::{Itertools, zip_eq};
15use petgraph::visit::Walker;
16
17use crate::{
18 sum::LLVMSumValue,
19 utils::fat::{FatExt as _, FatNode},
20};
21
22use super::{
23 EmitOpArgs, deaggregate_call_result,
24 func::{EmitFuncContext, RowPromise},
25};
26
27mod cfg;
28
29struct DataflowParentEmitter<'c, 'hugr, OT, H> {
30 node: FatNode<'hugr, OT, H>,
31 inputs: Option<Vec<BasicValueEnum<'c>>>,
32 outputs: Option<RowPromise<'c>>,
33}
34
35impl<'c, 'hugr, OT: OpTrait, H: HugrView<Node = Node>> DataflowParentEmitter<'c, 'hugr, OT, H>
36where
37 for<'a> &'a OpType: TryInto<&'a OT>,
38{
39 pub fn new(args: EmitOpArgs<'c, 'hugr, OT, H>) -> Self {
40 Self {
41 node: args.node,
42 inputs: Some(args.inputs),
43 outputs: Some(args.outputs),
44 }
45 }
46
47 fn take_input(&mut self) -> Result<Vec<BasicValueEnum<'c>>> {
49 self.inputs
50 .take()
51 .ok_or(anyhow!("DataflowParentEmitter: Input taken twice"))
52 }
53
54 fn take_output(&mut self) -> Result<RowPromise<'c>> {
55 self.outputs
56 .take()
57 .ok_or(anyhow!("DataflowParentEmitter: Output taken twice"))
58 }
59
60 pub fn emit_children(&mut self, context: &mut EmitFuncContext<'c, '_, H>) -> Result<()> {
61 use petgraph::visit::Topo;
62 let node = self.node;
63 if !OpTag::DataflowParent.is_superset(node.tag()) {
64 Err(anyhow!("Not a dataflow parent"))?;
65 }
66
67 let (i, o): (FatNode<Input, H>, FatNode<Output, H>) = node
68 .get_io()
69 .ok_or(anyhow!("emit_dataflow_parent: no io nodes"))?;
70 debug_assert!(i.out_value_types().count() == self.inputs.as_ref().unwrap().len());
71 debug_assert!(o.in_value_types().count() == self.outputs.as_ref().unwrap().len());
72
73 let (region_graph, node_map) = node.hugr().region_portgraph(node.node());
74 let topo = Topo::new(®ion_graph);
75 for n in topo.iter(®ion_graph) {
76 let node = node.hugr().fat_optype(node_map.from_portgraph(n));
77 let inputs_rmb = context.node_ins_rmb(node)?;
78 let inputs = inputs_rmb.read(context.builder(), [])?;
79 let outputs = context.node_outs_rmb(node)?.promise();
80 match node.as_ref() {
81 OpType::Input(_) => {
82 let i = self.take_input()?;
83 outputs.finish(context.builder(), i)?;
84 }
85 OpType::Output(_) => {
86 let o = self.take_output()?;
87 o.finish(context.builder(), inputs)?;
88 }
89 _ => emit_optype(
90 context,
91 EmitOpArgs {
92 node,
93 inputs,
94 outputs,
95 },
96 )?,
97 }
98 }
99 Ok(())
100 }
101}
102
103fn get_exactly_one_sum_type(ts: impl IntoIterator<Item = Type>) -> Result<SumType> {
104 let Some(TypeEnum::Sum(sum_type)) = ts
105 .into_iter()
106 .map(|t| t.as_type_enum().clone())
107 .exactly_one()
108 .ok()
109 else {
110 Err(anyhow!("Not exactly one SumType"))?
111 };
112 Ok(sum_type)
113}
114
115pub fn emit_value<'c, H: HugrView<Node = Node>>(
116 context: &mut EmitFuncContext<'c, '_, H>,
117 v: &Value,
118) -> Result<BasicValueEnum<'c>> {
119 match v {
120 Value::Extension { e } => context.emit_custom_const(e.value()),
121 Value::Function { .. } => bail!(
122 "Value::Function Const nodes are not supported. \
123 Ensure you eliminate these from the HUGR before lowering to LLVM. \
124 `hugr_llvm::utils::inline_constant_functions` is provided for this purpose."
125 ),
126 Value::Sum(Sum {
127 tag,
128 values,
129 sum_type,
130 }) => {
131 let llvm_st = context.llvm_sum_type(sum_type.clone())?;
132 let vs = values
133 .iter()
134 .map(|x| emit_value(context, x))
135 .collect::<Result<Vec<_>>>()?;
136 Ok(llvm_st.build_tag(context.builder(), *tag, vs)?.into())
137 }
138 }
139}
140
141pub(crate) fn emit_dataflow_parent<'c, 'hugr, OT: OpTrait, H: HugrView<Node = Node>>(
142 context: &mut EmitFuncContext<'c, '_, H>,
143 args: EmitOpArgs<'c, 'hugr, OT, H>,
144) -> Result<()>
145where
146 for<'a> &'a OpType: TryInto<&'a OT>,
147{
148 DataflowParentEmitter::new(args).emit_children(context)
149}
150
151fn emit_tag<'c, H: HugrView<Node = Node>>(
152 context: &mut EmitFuncContext<'c, '_, H>,
153 args: EmitOpArgs<'c, '_, Tag, H>,
154) -> Result<()> {
155 let st = context.llvm_sum_type(get_exactly_one_sum_type(
156 args.node.out_value_types().map(|x| x.1),
157 )?)?;
158 let builder = context.builder();
159 args.outputs.finish(
160 builder,
161 [st.build_tag(builder, args.node.tag, args.inputs)?.into()],
162 )
163}
164
165fn emit_conditional<'c, H: HugrView<Node = Node>>(
166 context: &mut EmitFuncContext<'c, '_, H>,
167 EmitOpArgs {
168 node,
169 inputs,
170 outputs,
171 }: EmitOpArgs<'c, '_, Conditional, H>,
172) -> Result<()> {
173 let exit_rmb =
174 context.new_row_mail_box(node.dataflow_signature().unwrap().output.iter(), "exit_rmb")?;
175 let exit_block = context.build_positioned_new_block(
176 format!("cond_exit_{}", node.node().index()),
177 None,
178 |context, bb| {
179 let builder = context.builder();
180 outputs.finish(builder, exit_rmb.read_vec(builder, [])?)?;
181 Ok::<_, anyhow::Error>(bb)
182 },
183 )?;
184
185 let rmbs_blocks = node
186 .children()
187 .enumerate()
188 .map(|(i, n)| {
189 let label = format!("cond_{}_case_{}", node.node().index(), i);
190 let node = n.try_into_ot::<Case>().ok_or(anyhow!("not a case node"))?;
191 let rmb = context.new_row_mail_box(node.get_io().unwrap().0.types.iter(), &label)?;
192 context.build_positioned_new_block(&label, Some(exit_block), |context, bb| {
193 let inputs = rmb.read_vec(context.builder(), [])?;
194 emit_dataflow_parent(
195 context,
196 EmitOpArgs {
197 node,
198 inputs,
199 outputs: exit_rmb.promise(),
200 },
201 )?;
202 context.builder().build_unconditional_branch(exit_block)?;
203 Ok((rmb, bb))
204 })
205 })
206 .collect::<Result<Vec<_>>>()?;
207
208 let sum_type = get_exactly_one_sum_type(node.in_value_types().next().map(|x| x.1))?;
209 let sum_input = LLVMSumValue::try_new(inputs[0], context.llvm_sum_type(sum_type)?)?;
210 let builder = context.builder();
211 sum_input.build_destructure(builder, |builder, tag, mut vs| {
212 let (rmb, bb) = &rmbs_blocks[tag];
213 vs.extend(&inputs[1..]);
214 rmb.write(builder, vs)?;
215 builder.build_unconditional_branch(*bb)?;
216 Ok(())
217 })?;
218 builder.position_at_end(exit_block);
219 Ok(())
220}
221
222fn emit_load_constant<'c, H: HugrView<Node = Node>>(
223 context: &mut EmitFuncContext<'c, '_, H>,
224 args: EmitOpArgs<'c, '_, LoadConstant, H>,
225) -> Result<()> {
226 let konst_node = args
227 .node
228 .single_linked_output(0.into())
229 .unwrap()
230 .0
231 .try_into_ot::<Const>()
232 .unwrap();
233 let r = emit_value(context, konst_node.value())?;
234 args.outputs.finish(context.builder(), [r])
235}
236
237fn emit_call<'c, H: HugrView<Node = Node>>(
238 context: &mut EmitFuncContext<'c, '_, H>,
239 args: EmitOpArgs<'c, '_, Call, H>,
240) -> Result<()> {
241 if !args.node.called_function_type().params().is_empty() {
242 return Err(anyhow!("Call of generic function"));
243 }
244 let (func_node, _) = args
245 .node
246 .single_linked_output(args.node.called_function_port())
247 .unwrap();
248 let func = match func_node.as_ref() {
249 OpType::FuncDecl(_) => context.get_func_decl(func_node.try_into_ot().unwrap()),
250 OpType::FuncDefn(_) => context.get_func_defn(func_node.try_into_ot().unwrap()),
251 _ => Err(anyhow!("emit_call: Not a Decl or Defn")),
252 };
253 let inputs = args.inputs.into_iter().map_into().collect_vec();
254 let builder = context.builder();
255 let call = builder.build_call(func?, inputs.as_slice(), "")?;
256 let call_results = deaggregate_call_result(builder, call, args.outputs.len())?;
257 args.outputs.finish(builder, call_results)
258}
259
260fn emit_call_indirect<'c, H: HugrView<Node = Node>>(
261 context: &mut EmitFuncContext<'c, '_, H>,
262 args: EmitOpArgs<'c, '_, CallIndirect, H>,
263) -> Result<()> {
264 let func_ptr = match args.inputs[0] {
265 BasicValueEnum::PointerValue(v) => Ok(v),
266 _ => Err(anyhow!("emit_call_indirect: Not a pointer")),
267 }?;
268 let func =
269 CallableValue::try_from(func_ptr).expect("emit_call_indirect: Not a function pointer");
270 let inputs = args.inputs.into_iter().skip(1).map_into().collect_vec();
271 let builder = context.builder();
272 let call = builder.build_call(func, inputs.as_slice(), "")?;
273 let call_results = deaggregate_call_result(builder, call, args.outputs.len())?;
274 args.outputs.finish(builder, call_results)
275}
276
277fn emit_load_function<'c, H: HugrView<Node = Node>>(
278 context: &mut EmitFuncContext<'c, '_, H>,
279 args: EmitOpArgs<'c, '_, LoadFunction, H>,
280) -> Result<()> {
281 if !args.node.func_sig.params().is_empty() {
282 return Err(anyhow!("Load of generic function"));
283 }
284 let (func_node, _) = args
285 .node
286 .single_linked_output(args.node.function_port())
287 .unwrap();
288
289 let func = match func_node.as_ref() {
290 OpType::FuncDecl(_) => context.get_func_decl(func_node.try_into_ot().unwrap()),
291 OpType::FuncDefn(_) => context.get_func_defn(func_node.try_into_ot().unwrap()),
292 _ => Err(anyhow!("emit_call: Not a Decl or Defn")),
293 }?;
294 args.outputs.finish(
295 context.builder(),
296 [func.as_global_value().as_pointer_value().into()],
297 )
298}
299
300fn emit_cfg<'c, H: HugrView<Node = Node>>(
301 context: &mut EmitFuncContext<'c, '_, H>,
302 args: EmitOpArgs<'c, '_, CFG, H>,
303) -> Result<()> {
304 cfg::CfgEmitter::new(context, args)?.emit_children(context)
305}
306
307fn emit_tail_loop<'c, H: HugrView<Node = Node>>(
308 context: &mut EmitFuncContext<'c, '_, H>,
309 args: EmitOpArgs<'c, '_, TailLoop, H>,
310) -> Result<()> {
311 let node = args.node();
312
313 let out_bb = context.new_basic_block("loop_out", None);
315 let body_bb = context.new_basic_block("loop_body", Some(out_bb));
317
318 let (body_i_node, body_o_node) = node.get_io().unwrap();
319 let body_i_rmb = context.node_outs_rmb(body_i_node)?;
320 let body_o_rmb = context.node_ins_rmb(body_o_node)?;
321
322 body_i_rmb.write(context.builder(), args.inputs)?;
323 context.builder().build_unconditional_branch(body_bb)?;
324
325 let control_llvm_sum_type = {
326 let sum_ty = SumType::new([node.just_inputs.clone(), node.just_outputs.clone()]);
327 context.llvm_sum_type(sum_ty)?
328 };
329
330 context.build_positioned(body_bb, move |context| {
331 let inputs = body_i_rmb.read_vec(context.builder(), [])?;
332 emit_dataflow_parent(
333 context,
334 EmitOpArgs {
335 node,
336 inputs,
337 outputs: body_o_rmb.promise(),
338 },
339 )?;
340 let dataflow_outputs = body_o_rmb.read_vec(context.builder(), [])?;
341 let control_val = LLVMSumValue::try_new(dataflow_outputs[0], control_llvm_sum_type)?;
342 let mut outputs = Some(args.outputs);
343
344 control_val.build_destructure(context.builder(), |builder, tag, mut values| {
345 values.extend(dataflow_outputs[1..].iter().copied());
346 if tag == 0 {
347 body_i_rmb.write(builder, values)?;
348 builder.build_unconditional_branch(body_bb)?;
349 } else {
350 outputs.take().unwrap().finish(builder, values)?;
351 builder.build_unconditional_branch(out_bb)?;
352 }
353 Ok(())
354 })
355 })?;
356 context.builder().position_at_end(out_bb);
357 Ok(())
358}
359
360fn emit_optype<'c, H: HugrView<Node = Node>>(
361 context: &mut EmitFuncContext<'c, '_, H>,
362 args: EmitOpArgs<'c, '_, OpType, H>,
363) -> Result<()> {
364 let node = args.node();
365 match node.as_ref() {
366 OpType::Tag(tag) => emit_tag(context, args.into_ot(tag)),
367 OpType::DFG(_) => emit_dataflow_parent(context, args),
368
369 OpType::ExtensionOp(co) => context.emit_extension_op(args.into_ot(co)),
370 OpType::LoadConstant(lc) => emit_load_constant(context, args.into_ot(lc)),
371 OpType::Call(cl) => emit_call(context, args.into_ot(cl)),
372 OpType::CallIndirect(cl) => emit_call_indirect(context, args.into_ot(cl)),
373 OpType::LoadFunction(lf) => emit_load_function(context, args.into_ot(lf)),
374 OpType::Conditional(co) => emit_conditional(context, args.into_ot(co)),
375 OpType::CFG(cfg) => emit_cfg(context, args.into_ot(cfg)),
376 OpType::Const(_) => Ok(()),
379 OpType::FuncDecl(_) => Ok(()),
380 OpType::FuncDefn(fd) => {
381 context.push_todo_func(node.into_ot(fd));
382 Ok(())
383 }
384 OpType::TailLoop(x) => emit_tail_loop(context, args.into_ot(x)),
385 _ => Err(anyhow!("Invalid child for Dataflow Parent: {node}")),
386 }
387}
388
389pub(crate) fn emit_custom_unary_op<'c, 'hugr, H, F>(
398 context: &mut EmitFuncContext<'c, '_, H>,
399 args: EmitOpArgs<'c, 'hugr, ExtensionOp, H>,
400 go: F,
401) -> Result<()>
402where
403 H: HugrView<Node = Node>,
404 F: FnOnce(
405 &mut EmitFuncContext<'c, '_, H>,
406 BasicValueEnum<'c>,
407 &[BasicTypeEnum<'c>],
408 ) -> Result<Vec<BasicValueEnum<'c>>>,
409{
410 let [inp] = TryInto::<[_; 1]>::try_into(args.inputs).map_err(|v| {
411 anyhow!(
412 "emit_custom_unary_op: expected exactly one input, got {}",
413 v.len()
414 )
415 })?;
416 let out_types = args.outputs.get_types().collect_vec();
417 let res = go(context, inp, &out_types)?;
418 if res.len() != args.outputs.len()
419 || zip_eq(res.iter(), out_types).any(|(a, b)| a.get_type() != b)
420 {
421 return Err(anyhow!(
422 "emit_custom_unary_op: expected outputs of types {:?}, got {:?}",
423 args.outputs.get_types().collect_vec(),
424 res.iter().map(BasicValueEnum::get_type).collect_vec()
425 ));
426 }
427 args.outputs.finish(context.builder(), res)
428}
429
430pub(crate) fn emit_custom_binary_op<'c, 'hugr, H, F>(
439 context: &mut EmitFuncContext<'c, '_, H>,
440 args: EmitOpArgs<'c, 'hugr, ExtensionOp, H>,
441 go: F,
442) -> Result<()>
443where
444 H: HugrView<Node = Node>,
445 F: FnOnce(
446 &mut EmitFuncContext<'c, '_, H>,
447 (BasicValueEnum<'c>, BasicValueEnum<'c>),
448 &[BasicTypeEnum<'c>],
449 ) -> Result<Vec<BasicValueEnum<'c>>>,
450{
451 let [lhs, rhs] = TryInto::<[_; 2]>::try_into(args.inputs).map_err(|v| {
452 anyhow!(
453 "emit_custom_binary_op: expected exactly 2 inputs, got {}",
454 v.len()
455 )
456 })?;
457 if lhs.get_type() != rhs.get_type() {
458 return Err(anyhow!(
459 "emit_custom_binary_op: expected inputs of the same type, got {} and {}",
460 lhs.get_type(),
461 rhs.get_type()
462 ));
463 }
464 let out_types = args.outputs.get_types().collect_vec();
465 let res = go(context, (lhs, rhs), &out_types)?;
466 if res.len() != out_types.len() || zip_eq(res.iter(), out_types).any(|(a, b)| a.get_type() != b)
467 {
468 return Err(anyhow!(
469 "emit_custom_binary_op: expected outputs of types {:?}, got {:?}",
470 args.outputs.get_types().collect_vec(),
471 res.iter().map(BasicValueEnum::get_type).collect_vec()
472 ));
473 }
474 args.outputs.finish(context.builder(), res)
475}