Skip to main content

haloumi_backend/
codegen.rs

1//! Traits for defining code generators.
2
3use std::cell::RefCell;
4use std::rc::Rc;
5
6use haloumi_ir::Prime;
7use haloumi_ir::Slot;
8use haloumi_ir_gen::{circuit::resolved::ResolvedIRCircuit, ctx::IRCtx};
9use haloumi_lowering::{
10    ExprLowering as _, Lowering, error::Error as LoweringError, lowerable::LowerableStmt,
11};
12use haloumi_synthesis::io::{AdviceIO, InstanceIO};
13
14pub(crate) mod strats;
15
16/// Implemented by code generators that emit some code representation of the circuit.
17pub trait Codegen<'c: 's, 's>: Sized + 's {
18    /// Implementation of the lowering behavior.
19    type FuncOutput: Lowering;
20    /// Final output of the code generator.
21    type Output;
22    /// Internal state of the code generator.
23    type State: 'c;
24    /// Error type of the code generator.
25    type Error: From<LoweringError> + From<haloumi_ir_gen::error::Error>;
26
27    /// Initializes the code generator.
28    fn initialize(state: &'s Self::State) -> Self;
29
30    /// Sets the prime field used by the circuit.
31    ///
32    /// By default does nothing.
33    #[allow(unused_variables)]
34    fn set_prime_field(&self, prime: Prime) -> Result<(), Self::Error> {
35        Ok(())
36    }
37
38    /// Defines an empty function.
39    fn define_function(
40        &self,
41        name: &str,
42        inputs: usize,
43        outputs: usize,
44    ) -> Result<Self::FuncOutput, Self::Error>;
45
46    /// Defines a function filled with the given body.
47    fn define_function_with_body<FN, L, I>(
48        &self,
49        name: &str,
50        inputs: usize,
51        outputs: usize,
52        f: FN,
53    ) -> Result<(), Self::Error>
54    where
55        FN: FnOnce(&Self::FuncOutput, &[Slot], &[Slot]) -> Result<I, Self::Error>,
56        I: IntoIterator<Item = L>,
57        L: LowerableStmt,
58    {
59        let func = self.define_function(name, inputs, outputs)?;
60        let inputs = func.lower_function_inputs(0..inputs);
61        let outputs = func.lower_function_outputs(0..outputs);
62        let stmts = f(&func, &inputs, &outputs)?;
63        for stmt in stmts {
64            stmt.lower(&func)?;
65        }
66        self.on_scope_end(func)
67    }
68
69    /// Defines the entrypoint function of the circuit.
70    fn define_main_function(
71        &self,
72        advice_io: &AdviceIO,
73        instance_io: &InstanceIO,
74    ) -> Result<Self::FuncOutput, Self::Error>;
75
76    /// Defines the entrypoint function of the circuit and fills it with the given body.
77    fn define_main_function_with_body<L>(
78        &self,
79        advice_io: &AdviceIO,
80        instance_io: &InstanceIO,
81        stmts: impl IntoIterator<Item = L>,
82    ) -> Result<(), Self::Error>
83    where
84        L: LowerableStmt + std::fmt::Debug,
85    {
86        let main = self.define_main_function(advice_io, instance_io)?;
87        log::debug!("Defined main function");
88        for stmt in stmts {
89            log::debug!("Lowering statement {stmt:?}");
90            stmt.lower(&main)?;
91        }
92        log::debug!("Lowered function body");
93        self.on_scope_end(main)
94    }
95
96    /// Callback called when a scope ends.
97    fn on_scope_end(&self, _: Self::FuncOutput) -> Result<(), Self::Error> {
98        Ok(())
99    }
100
101    /// Consumes the code generator and returns the final output.
102    fn generate_output(self) -> Result<Self::Output, Self::Error>
103    where
104        Self::Output: 'c;
105}
106
107pub(crate) trait CodegenStrategy {
108    fn codegen<'c: 'st, 's, 'st, C>(
109        &self,
110        codegen: &C,
111        ctx: &IRCtx,
112        ir: &ResolvedIRCircuit,
113    ) -> Result<(), C::Error>
114    where
115        C: Codegen<'c, 'st>;
116}
117
/// Generic configuration parameters shared by code generators.
pub trait CodegenParams {
    /// Reports whether inlining is enabled: `true` when it is, `false` otherwise.
    fn inlining_enabled(&self) -> bool;
}
123
124impl<T: CodegenParams> CodegenParams for Rc<RefCell<T>> {
125    fn inlining_enabled(&self) -> bool {
126        self.borrow().inlining_enabled()
127    }
128}