haloumi_backend/
codegen.rs1use std::cell::RefCell;
4use std::rc::Rc;
5
6use haloumi_ir::Prime;
7use haloumi_ir::Slot;
8use haloumi_ir_gen::{circuit::resolved::ResolvedIRCircuit, ctx::IRCtx};
9use haloumi_lowering::{
10 ExprLowering as _, Lowering, error::Error as LoweringError, lowerable::LowerableStmt,
11};
12use haloumi_synthesis::io::{AdviceIO, InstanceIO};
13
14pub(crate) mod strats;
15
16pub trait Codegen<'c: 's, 's>: Sized + 's {
18 type FuncOutput: Lowering;
20 type Output;
22 type State: 'c;
24 type Error: From<LoweringError> + From<haloumi_ir_gen::error::Error>;
26
27 fn initialize(state: &'s Self::State) -> Self;
29
30 #[allow(unused_variables)]
34 fn set_prime_field(&self, prime: Prime) -> Result<(), Self::Error> {
35 Ok(())
36 }
37
38 fn define_function(
40 &self,
41 name: &str,
42 inputs: usize,
43 outputs: usize,
44 ) -> Result<Self::FuncOutput, Self::Error>;
45
46 fn define_function_with_body<FN, L, I>(
48 &self,
49 name: &str,
50 inputs: usize,
51 outputs: usize,
52 f: FN,
53 ) -> Result<(), Self::Error>
54 where
55 FN: FnOnce(&Self::FuncOutput, &[Slot], &[Slot]) -> Result<I, Self::Error>,
56 I: IntoIterator<Item = L>,
57 L: LowerableStmt,
58 {
59 let func = self.define_function(name, inputs, outputs)?;
60 let inputs = func.lower_function_inputs(0..inputs);
61 let outputs = func.lower_function_outputs(0..outputs);
62 let stmts = f(&func, &inputs, &outputs)?;
63 for stmt in stmts {
64 stmt.lower(&func)?;
65 }
66 self.on_scope_end(func)
67 }
68
69 fn define_main_function(
71 &self,
72 advice_io: &AdviceIO,
73 instance_io: &InstanceIO,
74 ) -> Result<Self::FuncOutput, Self::Error>;
75
76 fn define_main_function_with_body<L>(
78 &self,
79 advice_io: &AdviceIO,
80 instance_io: &InstanceIO,
81 stmts: impl IntoIterator<Item = L>,
82 ) -> Result<(), Self::Error>
83 where
84 L: LowerableStmt + std::fmt::Debug,
85 {
86 let main = self.define_main_function(advice_io, instance_io)?;
87 log::debug!("Defined main function");
88 for stmt in stmts {
89 log::debug!("Lowering statement {stmt:?}");
90 stmt.lower(&main)?;
91 }
92 log::debug!("Lowered function body");
93 self.on_scope_end(main)
94 }
95
96 fn on_scope_end(&self, _: Self::FuncOutput) -> Result<(), Self::Error> {
98 Ok(())
99 }
100
101 fn generate_output(self) -> Result<Self::Output, Self::Error>
103 where
104 Self::Output: 'c;
105}
106
107pub(crate) trait CodegenStrategy {
108 fn codegen<'c: 'st, 's, 'st, C>(
109 &self,
110 codegen: &C,
111 ctx: &IRCtx,
112 ir: &ResolvedIRCircuit,
113 ) -> Result<(), C::Error>
114 where
115 C: Codegen<'c, 'st>;
116}
117
/// Parameters that configure code generation.
pub trait CodegenParams {
    /// Returns `true` when inlining is enabled.
    fn inlining_enabled(&self) -> bool;
}
123
124impl<T: CodegenParams> CodegenParams for Rc<RefCell<T>> {
125 fn inlining_enabled(&self) -> bool {
126 self.borrow().inlining_enabled()
127 }
128}