Skip to main content

haloumi_backend/
backend.rs

1use std::marker::PhantomData;
2
3use haloumi_ir_gen::{circuit::resolved::ResolvedIRCircuit, ctx::IRCtx};
4
5use crate::codegen::{
6    Codegen, CodegenParams, CodegenStrategy,
7    strats::{groups::GroupConstraintsStrat, inline::InlineConstraintsStrat},
8};
9
/// Entrypoint for the backend.
///
/// `S` is the state the backend is initialized with (see [`Backend::initialize`]),
/// and `C` is the code generator type that gets constructed from that state
/// whenever code is generated. The backend never stores a `C` itself; it only
/// remembers which generator type to build.
pub struct Backend<C, S> {
    /// State produced from the parameters passed to [`Backend::initialize`].
    state: S,
    // `fn() -> C` instead of `C`: the backend never owns a `C`, so its auto
    // traits (`Send`/`Sync`) and drop check must not depend on `C`. Variance in
    // `C` is still covariant, exactly as with `PhantomData<C>`.
    _codegen: PhantomData<fn() -> C>,
}
15
16impl<C, S: std::fmt::Debug> std::fmt::Debug for Backend<C, S> {
17    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
18        f.debug_struct("Backend")
19            .field("state", &self.state)
20            .finish()
21    }
22}
23
24impl<C, S> Backend<C, S> {
25    /// Initializes the backend.
26    pub fn initialize<P: Clone + Into<S>>(params: P) -> Self {
27        Self {
28            state: params.into(),
29            _codegen: PhantomData,
30        }
31    }
32}
33
34impl<'b, 's: 'b, C> Backend<C, C::State>
35where
36    C: Codegen<'s, 'b>,
37    C::State: 's,
38    C::Output: 's,
39    C::State: CodegenParams,
40{
41    fn create_codegen(&'b self) -> C {
42        C::initialize(&self.state)
43    }
44
45    /// Generate code using the default strategy.
46    pub fn codegen(&'b self, ir: &ResolvedIRCircuit, ctx: &IRCtx) -> Result<C::Output, C::Error> {
47        if self.state.inlining_enabled() {
48            self.codegen_with_strat(ir, ctx, InlineConstraintsStrat::default())
49        } else {
50            self.codegen_with_strat(ir, ctx, GroupConstraintsStrat::default())
51        }
52    }
53
54    /// Generate code using the given strategy.
55    fn codegen_with_strat(
56        &'b self,
57        ir: &ResolvedIRCircuit,
58        ctx: &IRCtx,
59        strat: impl CodegenStrategy,
60    ) -> Result<C::Output, C::Error> {
61        log::debug!("Initializing code generator");
62        let codegen = self.create_codegen();
63        codegen.set_prime_field(ir.prime())?;
64        log::debug!(
65            "Starting code generation with {} strategy...",
66            std::any::type_name_of_val(&strat)
67        );
68
69        strat.codegen(&codegen, ctx, ir)?;
70
71        log::debug!("Code generation completed");
72        codegen.generate_output()
73    }
74}