Skip to main content

haloumi_picus/
lib.rs

1#![doc = include_str!("../README.md")]
2#![deny(rustdoc::broken_intra_doc_links)]
3#![deny(missing_debug_implementations)]
4#![deny(missing_docs)]
5
6use std::{
7    cell::RefCell,
8    collections::{HashMap, HashSet},
9    rc::Rc,
10};
11
12use haloumi_backend::{Backend, codegen::Codegen};
13use haloumi_core::{felt::Prime, slot::Slot as FuncIO};
14use haloumi_synthesis::io::{AdviceIO, InstanceIO};
15
16//use anyhow::Result;
17
18use inner::PicusCodegenInner;
19use lowering::PicusModuleLowering;
20pub use params::{PicusParams, PicusParamsBuilder};
21use pcl::{opt::MutOptimizer as _, vars::VarStr};
22use utils::mk_io;
23use vars::{NamingConvention, VarKey, VarKeySeed};
24
25mod inner;
26mod lowering;
27mod params;
28mod pcl;
29mod utils;
30mod vars;
31
32/// Instance of a [`Backend`] prepared for lowering to PCL.
33pub type PicusBackend = Backend<PicusCodegen, InnerState>;
34type InnerState = Rc<RefCell<PicusCodegenInner>>;
35type PicusModule = pcl::Module<VarKey>;
36/// Output produced by the picus backend.
37pub type PicusOutput = pcl::Program<VarKey>;
38type PipelineBuilder = pcl::opt::OptimizerPipelineBuilder<VarKey>;
39type Pipeline = pcl::opt::OptimizerPipeline<VarKey>;
40
41impl From<PicusParams> for InnerState {
42    fn from(value: PicusParams) -> Self {
43        Rc::new(RefCell::new(PicusCodegenInner::new(value)))
44    }
45}
46
47/// Code generator for PCL.
48#[derive(Debug, Clone)]
49pub struct PicusCodegen {
50    inner: InnerState,
51}
52
53impl PicusCodegen {
54    fn naming_convention(&self) -> NamingConvention {
55        self.inner.borrow().naming_convention()
56    }
57
58    fn var_consistency_check(&self, output: &PicusOutput) -> Result<(), PicusCodegenError> {
59        // Var consistency check
60        for module in output.modules() {
61            let vars = module.vars();
62            // Get the set of io variables, without the fqn.
63            // This set will have all the circuit cells that have been queried and resolved
64            // during lowering.
65            let io_vars = vars
66                .keys()
67                .filter_map(|k| match k {
68                    VarKey::Slot(slot) => Some(*slot),
69                    _ => None,
70                })
71                .collect::<HashSet<_>>();
72
73            // The set of io variables, with names, should be the same length.
74            let io_var_count = vars
75                .iter()
76                .filter_map(|(k, v)| match k {
77                    VarKey::Slot(_) => Some(v),
78                    _ => None,
79                })
80                .count();
81            if io_vars.len() != io_var_count {
82                // Inconsistency. Let's see which ones.
83                let mut dups = HashMap::<FuncIO, Vec<&VarStr>>::new();
84                for (k, v) in vars {
85                    if let VarKey::Slot(slot) = k {
86                        dups.entry(*slot).or_default().push(v);
87                    }
88                }
89
90                let dups = dups;
91                for (k, names) in dups {
92                    if names.len() == 1 {
93                        continue;
94                    }
95                    log::error!("Mismatched variable! (key = {k:?}) (names = {names:?})");
96                }
97                return Err(PicusCodegenError::ConsistencyCheckFailed {
98                    expected: io_vars.len(),
99                    actual: io_var_count,
100                });
101            }
102        }
103        Ok(())
104    }
105
106    fn optimization_pipeline(&self) -> Option<Pipeline> {
107        self.inner.borrow().optimization_pipeline()
108    }
109}
110
111impl<'c: 's, 's> Codegen<'c, 's> for PicusCodegen {
112    type FuncOutput = PicusModuleLowering;
113    type Output = PicusOutput;
114    type State = InnerState;
115    type Error = PicusCodegenError;
116
117    fn initialize(state: &'s Self::State) -> Self {
118        Self {
119            inner: state.clone(),
120        }
121    }
122
123    fn set_prime_field(&self, prime: Prime) -> Result<(), PicusCodegenError> {
124        self.inner.borrow_mut().set_prime(prime);
125        Ok(())
126    }
127
128    fn define_main_function(
129        &self,
130        advice_io: &AdviceIO,
131        instance_io: &InstanceIO,
132    ) -> Result<Self::FuncOutput, PicusCodegenError> {
133        let ep = self.inner.borrow().entrypoint();
134        let nc = self.naming_convention();
135        self.inner.borrow_mut().add_module(
136            ep,
137            mk_io(
138                instance_io.inputs().len() + advice_io.inputs().len(),
139                VarKeySeed::arg,
140                nc,
141            ),
142            mk_io(
143                instance_io.outputs().len() + advice_io.outputs().len(),
144                VarKeySeed::field,
145                nc,
146            ),
147        )
148    }
149
150    fn on_scope_end(&self, _scope: Self::FuncOutput) -> Result<(), PicusCodegenError> {
151        log::debug!("Closing scope");
152        Ok(())
153    }
154
155    fn generate_output(self) -> Result<Self::Output, PicusCodegenError> {
156        let mut output = PicusOutput::new(
157            self.inner.borrow().prime()?,
158            self.inner.borrow().modules().to_vec(),
159        );
160        self.var_consistency_check(&output)?;
161        if let Some(mut opt) = self.optimization_pipeline() {
162            opt.optimize(&mut output)?;
163        }
164        Ok(output)
165    }
166
167    fn define_function(
168        &self,
169        name: &str,
170        inputs: usize,
171        outputs: usize,
172    ) -> Result<Self::FuncOutput, PicusCodegenError> {
173        let nc = self.naming_convention();
174        self.inner.borrow_mut().add_module(
175            name.to_owned(),
176            mk_io(inputs, VarKeySeed::arg, nc),
177            mk_io(outputs, VarKeySeed::field, nc),
178        )
179    }
180}
181
182/// Error type used by [`PicusCodegen`].
183#[derive(Debug, thiserror::Error)]
184pub enum PicusCodegenError {
185    /// Wraps a lowering error.
186    #[error(transparent)]
187    Lowering(#[from] haloumi_lowering::error::Error),
188    /// Wraps an IR related error.
189    #[error(transparent)]
190    IR(#[from] haloumi_ir_gen::error::Error),
191    /// Wraps a optimization pass error.
192    #[error("optimization pass error: {0}")]
193    Pass(crate::pcl::opt::PassError),
194    /// Consistency check.
195    #[error(
196        "Inconsistency detected in circuit variables. Was expecting {expected} IO variables by {actual} were generated"
197    )]
198    ConsistencyCheckFailed {
199        /// Expected number of variables.
200        expected: usize,
201        /// Actual number of variables.
202        actual: usize,
203    },
204    /// Prime not set for program.
205    #[error("Prime was not set!")]
206    PrimeNotSet,
207}
208
209impl From<crate::pcl::opt::PassError> for PicusCodegenError {
210    fn from(value: crate::pcl::opt::PassError) -> Self {
211        Self::Pass(value)
212    }
213}