Skip to main content

haloumi_ir/
groups.rs

1//! IR for representing groups.
2
3use crate::{
4    diagnostics::{SimpleDiagnostic, Validation},
5    groups::callsite::CallSite,
6    printer::IRPrintable,
7    stmt::IRStmt,
8    traits::{Canonicalize, ConstantFolding, Validatable},
9};
10use eqv::{EqvRelation, equiv};
11use haloumi_core::eqv::SymbolicEqv;
12use haloumi_lowering::{
13    Lowering,
14    lowerable::{LowerableExpr, LowerableStmt},
15};
16use std::fmt::Write;
17use thiserror::Error;
18
/// Uniquely identifies groups that represent the same semantics.
///
/// An [`IRGroup`] stores its key as `Option<GroupKey>`; a group without a key
/// reports itself as the top-level group through [`IRGroup::is_main`].
pub type GroupKey = u64;
21
22pub mod callsite;
23
/// Body of a group.
///
/// Holds the identity of the group (name, id and optional key), its declared
/// input/output arity, and the IR that makes up its body, split by origin.
#[derive(Debug, Clone)]
pub struct IRGroup<E> {
    /// Name of the group.
    name: String,
    /// Index in the original groups array.
    id: usize,
    /// Number of inputs; callsites targeting this group are validated against it.
    input_count: usize,
    /// Number of outputs; callsites targeting this group are validated against it.
    output_count: usize,
    /// Key of the group. `None` marks the top-level (main) group.
    key: Option<GroupKey>,
    /// IR of the PLONK gates.
    gates: IRStmt<E>,
    /// IR of the copy (equality) constraints.
    eq_constraints: IRStmt<E>,
    /// Calls made from this group to other groups.
    callsites: Vec<CallSite<E>>,
    /// IR of the lookups.
    lookups: IRStmt<E>,
    /// Extra IR pushed into the body after construction; lowered after everything else.
    injected: Vec<IRStmt<E>>,
    /// When set, lowering emits a comment before each section of the body.
    generate_debug_comments: bool,
}
40
41impl<E> IRGroup<E> {
42    /// Creates a new empty group.
43    pub fn new(name: String, id: usize) -> Self {
44        Self {
45            name,
46            id,
47            input_count: Default::default(),
48            output_count: Default::default(),
49            key: Default::default(),
50            gates: Default::default(),
51            eq_constraints: Default::default(),
52            callsites: Default::default(),
53            lookups: Default::default(),
54            injected: Default::default(),
55            generate_debug_comments: Default::default(),
56        }
57    }
58
59    /// Updates the input count of the group.
60    pub fn with_input_count(mut self, input_count: usize) -> Self {
61        self.input_count = input_count;
62        self
63    }
64
65    /// Updates the output count of the group.
66    pub fn with_output_count(mut self, output_count: usize) -> Self {
67        self.output_count = output_count;
68        self
69    }
70
71    /// Updates the group key.
72    pub fn with_key(mut self, key: Option<GroupKey>) -> Self {
73        self.key = key;
74        self
75    }
76
77    /// Removes the group key.
78    pub fn no_key(mut self) -> Self {
79        self.key = None;
80        self
81    }
82
83    /// Updates the IR of the PLONK gates.
84    pub fn with_gates(mut self, gates: impl IntoIterator<Item = IRStmt<E>>) -> Self {
85        self.gates = IRStmt::from_iter(gates);
86        self
87    }
88
89    /// Updates the IR of the copy constraints.
90    pub fn with_copy_constraints(
91        mut self,
92        constraints: impl IntoIterator<Item = IRStmt<E>>,
93    ) -> Self {
94        self.eq_constraints = IRStmt::from_iter(constraints);
95        self
96    }
97
98    /// Adds a callsite to the group.
99    pub fn with_callsites(mut self, callsites: impl IntoIterator<Item = CallSite<E>>) -> Self {
100        self.callsites = Vec::from_iter(callsites);
101        self
102    }
103
104    /// Updates the IR of the lookups.
105    pub fn with_lookups(mut self, lookups: impl IntoIterator<Item = IRStmt<E>>) -> Self {
106        self.lookups = IRStmt::from_iter(lookups);
107        self
108    }
109
110    /// Injects IR into the body of the group.
111    pub fn inject(&mut self, ir: IRStmt<E>) {
112        self.injected.push(ir);
113    }
114
115    /// Returns the number of statements injected into this group.
116    pub fn injected_count(&self) -> usize {
117        self.injected.len()
118    }
119
120    /// Sets the flag that control the generation of debug comments.
121    pub fn do_debug_comments(mut self, do_it: bool) -> Self {
122        self.generate_debug_comments = do_it;
123        self
124    }
125
126    /// Returns true if the group is the top-level.
127    pub fn is_main(&self) -> bool {
128        self.key.is_none()
129    }
130
131    /// Returns the name of the group.
132    pub fn name(&self) -> &str {
133        &self.name
134    }
135
136    /// Returns a mutable reference to the group's name.
137    pub fn name_mut(&mut self) -> &mut String {
138        &mut self.name
139    }
140
141    /// Returns the id of the group.
142    pub fn id(&self) -> usize {
143        self.id
144    }
145
146    /// Sets the id of the group.
147    pub fn set_id(&mut self, id: usize) {
148        self.id = id;
149    }
150
151    /// Returns the number of inputs.
152    pub fn input_count(&self) -> usize {
153        self.input_count
154    }
155
156    /// Returns the number of outputs.
157    pub fn output_count(&self) -> usize {
158        self.output_count
159    }
160
161    /// Returns the key of the group.
162    pub fn key(&self) -> Option<GroupKey> {
163        self.key
164    }
165
166    /// Returns a referece to the callsites.
167    pub fn callsites(&self) -> &[CallSite<E>] {
168        &self.callsites
169    }
170
171    /// Returns a mutable referece to the callsites.
172    pub fn callsites_mut(&mut self) -> &mut Vec<CallSite<E>> {
173        &mut self.callsites
174    }
175
176    /// Returns an iterator with all the [`IRStmt`] in the group.
177    pub fn statements(&self) -> impl Iterator<Item = &IRStmt<E>> {
178        self.gates
179            .iter()
180            .chain(self.eq_constraints.iter())
181            .chain(self.lookups.iter())
182            .chain(self.injected.iter().flatten())
183    }
184
185    /// Tries to convert the inner expression type to another.
186    pub fn try_map<O, Err>(
187        self,
188        f: &mut impl FnMut(E) -> Result<O, Err>,
189    ) -> Result<IRGroup<O>, Err> {
190        Ok(IRGroup {
191            name: self.name,
192            id: self.id,
193            input_count: self.input_count,
194            output_count: self.output_count,
195            key: self.key,
196            gates: self.gates.try_map(f)?,
197            eq_constraints: self.eq_constraints.try_map(f)?,
198            callsites: self
199                .callsites
200                .into_iter()
201                .map(|cs| cs.try_map(f))
202                .collect::<Result<Vec<_>, _>>()?,
203            lookups: self.lookups.try_map(f)?,
204            injected: self
205                .injected
206                .into_iter()
207                .map(|i| i.try_map(f))
208                .collect::<Result<Vec<_>, _>>()?,
209            generate_debug_comments: self.generate_debug_comments,
210        })
211    }
212
213    fn validate_callsite(
214        &self,
215        callsite: &CallSite<E>,
216        groups: &[Self],
217    ) -> Result<(), ValidationErrors> {
218        let callee_id = callsite.callee_id();
219        let callee = groups
220            .get(callee_id)
221            .ok_or(ValidationErrors::CalleeNotFound { callee_id })?;
222        if callee.id() != callsite.callee_id() {
223            return Err(ValidationErrors::WrongCallee {
224                callsite_name: callsite.name().to_string(),
225                callsite_id: callee_id,
226                callee_name: callee.name().to_string(),
227                callee_id: callee.id(),
228            });
229        }
230        if callee.input_count != callsite.inputs().len() {
231            return Err(ValidationErrors::UnexpectedInputs {
232                callee_name: callee.name().to_string(),
233                callee_id: callee.id(),
234                callee_count: callee.input_count,
235                callsite_count: callsite.inputs().len(),
236            });
237        }
238        if callee.output_count != callsite.outputs().len() {
239            return Err(ValidationErrors::UnexpectedOutputs {
240                callee_name: callee.name().to_string(),
241                callee_id: callee.id(),
242                callee_count: callee.output_count,
243                callsite_count: callsite.outputs().len(),
244            });
245        }
246        if callsite.outputs().len() != callsite.output_vars().len() {
247            return Err(ValidationErrors::UnexpectedOutputsVars {
248                callsite_name: callsite.name().to_string(),
249                callsite_id: callee_id,
250                callsite_count: callsite.outputs().len(),
251                callsite_vars_count: callsite.output_vars().len(),
252            });
253        }
254
255        Ok(())
256    }
257
258    /// Returns a mutable reference to the copy constraints.
259    pub fn eq_constraints_mut(&mut self) -> &mut IRStmt<E> {
260        &mut self.eq_constraints
261    }
262}
263
264impl<E> Validatable for IRGroup<E>
265where
266    IRStmt<E>: Validatable<Diagnostic = SimpleDiagnostic, Context = ()>,
267{
268    type Diagnostic = SimpleDiagnostic;
269
270    type Context = [Self];
271
272    fn validate_with_context(
273        &self,
274        groups: &Self::Context,
275    ) -> Result<Vec<Self::Diagnostic>, Vec<Self::Diagnostic>> {
276        let mut validation = Validation::new();
277
278        // Check 1. Consistency of callsites arity.
279        validation.with_errors(self.callsites().iter().enumerate().filter_map(
280            |(call_no, callsite)| match self.validate_callsite(callsite, groups) {
281                Ok(_) => None,
282                Err(err) => Some(SimpleDiagnostic::error(format!(
283                    "on callsite {call_no}: {err}"
284                ))),
285            },
286        ));
287
288        // Check 2. Each's statement validation.
289        validation.append_from_result(self.gates.validate(), "on gates");
290        validation.append_from_result(self.eq_constraints.validate(), "on copy constraints");
291        validation.append_from_result(self.lookups.validate(), "on lookups");
292        for ir in &self.injected {
293            validation.append_from_result(ir.validate(), "on injected ir");
294        }
295
296        validation.into()
297    }
298}
299
/// Error type for when validation fails.
///
/// Its message names the group and reports how many errors were found.
#[derive(Error, Debug)]
#[error("Validation of group {name} failed with {error_count} errors")]
pub struct ValidationFailed {
    /// Name of the group that failed validation.
    pub(crate) name: String,
    /// Number of errors found.
    pub(crate) error_count: usize,
}
307
impl ValidationFailed {
    /// Returns the number of errors recorded in this failure.
    pub fn error_count(&self) -> usize {
        self.error_count
    }
}
314
/// Reasons why a callsite can be rejected when validated against its callee.
#[derive(Error, Debug)]
enum ValidationErrors {
    /// The callee id of the callsite is not a valid index in the list of groups.
    #[error("Callee with id {callee_id} was not found")]
    CalleeNotFound { callee_id: usize },
    /// The group found at the callee id reports a different id.
    #[error(
        "Callsite points to \"{callsite_name}\" ({callsite_id}) but callee was \"{callee_name}\" ({callee_id})"
    )]
    WrongCallee {
        callsite_name: String,
        callsite_id: usize,
        callee_name: String,
        callee_id: usize,
    },
    /// The callsite passes a number of inputs different from the callee's input count.
    #[error(
        "Callee \"{callee_name}\" ({callee_id}) was expecting {callee_count} inputs but callsite has {callsite_count}"
    )]
    UnexpectedInputs {
        callee_name: String,
        callee_id: usize,
        callee_count: usize,
        callsite_count: usize,
    },
    /// The callsite has a number of outputs different from the callee's output count.
    #[error(
        "Callee \"{callee_name}\" ({callee_id}) was expecting {callee_count} outputs but callsite has {callsite_count}"
    )]
    UnexpectedOutputs {
        callee_name: String,
        callee_id: usize,
        callee_count: usize,
        callsite_count: usize,
    },
    /// The callsite's outputs and its declared output variables differ in number.
    #[error(
        "Call to \"{callsite_name}\" ({callsite_id}) has {callsite_count} outputs but declared {callsite_vars_count} output variables"
    )]
    UnexpectedOutputsVars {
        callsite_name: String,
        callsite_id: usize,
        callsite_count: usize,
        callsite_vars_count: usize,
    },
}
356
357impl<E: ConstantFolding> ConstantFolding for IRGroup<E>
358where
359    IRStmt<E>: ConstantFolding,
360{
361    type Error = ConstantFoldingError<E>;
362
363    type T = ();
364
365    fn constant_fold(&mut self) -> Result<(), Self::Error> {
366        self.gates
367            .constant_fold()
368            .map_err(ConstantFoldingError::Stmt)?;
369        self.eq_constraints
370            .constant_fold()
371            .map_err(ConstantFoldingError::Stmt)?;
372        self.lookups
373            .constant_fold()
374            .map_err(ConstantFoldingError::Stmt)?;
375        self.callsites
376            .constant_fold()
377            .map_err(ConstantFoldingError::CallsiteArg)?;
378        self.injected
379            .constant_fold()
380            .map_err(ConstantFoldingError::Stmt)
381    }
382}
383
/// Error type for constant folding of [`IRGroup`].
///
/// Each variant wraps, transparently, the error produced by the part of the
/// group that failed to fold.
#[derive(Error)]
pub enum ConstantFoldingError<E>
where
    IRStmt<E>: ConstantFolding,
    E: ConstantFolding,
{
    /// Case where the error happened while folding a statement.
    #[error(transparent)]
    Stmt(<IRStmt<E> as ConstantFolding>::Error),
    /// Case where the error happened while folding a callsite argument.
    #[error(transparent)]
    CallsiteArg(<E as ConstantFolding>::Error),
}
398
399impl<E, Err1, Err2> std::fmt::Debug for ConstantFoldingError<E>
400where
401    IRStmt<E>: ConstantFolding<Error = Err1>,
402    E: ConstantFolding<Error = Err2>,
403    Err1: std::fmt::Debug,
404    Err2: std::fmt::Debug,
405{
406    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
407        match self {
408            Self::Stmt(e) => std::fmt::Debug::fmt(e, f),
409            Self::CallsiteArg(e) => std::fmt::Debug::fmt(e, f),
410        }
411    }
412}
413
414impl<E> Canonicalize for IRGroup<E>
415where
416    IRStmt<E>: Canonicalize,
417    CallSite<E>: Canonicalize,
418{
419    fn canonicalize(&mut self) {
420        self.gates.canonicalize();
421        self.eq_constraints.canonicalize();
422        self.lookups.canonicalize();
423        self.callsites.canonicalize();
424        self.injected.canonicalize();
425    }
426}
427
428impl<E> EqvRelation<IRGroup<E>> for SymbolicEqv
429where
430    SymbolicEqv: EqvRelation<E>,
431{
432    /// Two groups are equivalent if the code they represent is equivalent and have the same key.
433    ///
434    /// Special case is main which is never equivalent to anything.
435    fn equivalent(lhs: &IRGroup<E>, rhs: &IRGroup<E>) -> bool {
436        // Main is never equivalent to others
437        if lhs.is_main() || rhs.is_main() {
438            return false;
439        }
440
441        let lhs_key = lhs.key.unwrap();
442        let rhs_key = rhs.key.unwrap();
443
444        let k = lhs_key == rhs_key;
445        log::debug!("[equivalent({} ~ {})] key: {k}", lhs.id(), rhs.id());
446        let io = lhs.input_count == rhs.input_count && lhs.output_count == rhs.output_count;
447        log::debug!("[equivalent({} ~ {})] io: {io}", lhs.id(), rhs.id());
448        let gates = equiv! { Self | &lhs.gates, &rhs.gates };
449        log::debug!("[equivalent({} ~ {})] gates: {gates}", lhs.id(), rhs.id());
450        let eqc = equiv! { Self | &lhs.eq_constraints, &rhs.eq_constraints };
451        log::debug!("[equivalent({} ~ {})] eqc: {eqc}", lhs.id(), rhs.id());
452        let lookups = equiv! { Self | &lhs.lookups, &rhs.lookups };
453        log::debug!(
454            "[equivalent({} ~ {})] lookups: {lookups}",
455            lhs.id(),
456            rhs.id()
457        );
458        let callsites = equiv! { Self | &lhs.callsites, &rhs.callsites };
459        log::debug!(
460            "[equivalent({} ~ {})] callsites: {callsites}",
461            lhs.id(),
462            rhs.id()
463        );
464
465        k && io && gates && eqc && lookups && callsites
466    }
467}
468
469impl<E> LowerableStmt for IRGroup<E>
470where
471    E: LowerableExpr + std::fmt::Debug,
472    CallSite<E>: LowerableStmt,
473    IRStmt<E>: LowerableStmt,
474{
475    fn lower<L>(self, l: &L) -> haloumi_lowering::Result<()>
476    where
477        L: Lowering + ?Sized,
478    {
479        log::debug!("Lowering {self:?}");
480        if self.generate_debug_comments {
481            l.generate_comment("Calls to subgroups".to_owned())?;
482        }
483        log::debug!("  Lowering callsites");
484        for callsite in self.callsites {
485            callsite.lower(l)?;
486        }
487        if self.generate_debug_comments {
488            l.generate_comment("Gate constraints".to_owned())?;
489        }
490        log::debug!("  Lowering gates");
491        self.gates.lower(l)?;
492        if self.generate_debug_comments {
493            l.generate_comment("Equality constraints".to_owned())?;
494        }
495        log::debug!("  Lowering equality constraints");
496        self.eq_constraints.lower(l)?;
497        if self.generate_debug_comments {
498            l.generate_comment("Lookups".to_owned())?;
499        }
500        log::debug!("  Lowering lookups");
501        self.lookups.lower(l)?;
502        if self.generate_debug_comments {
503            l.generate_comment("Injected".to_owned())?;
504        }
505        log::debug!("  Lowering injected IR");
506        for stmt in self.injected {
507            stmt.lower(l)?;
508        }
509
510        Ok(())
511    }
512}
513
514impl<E: IRPrintable> IRPrintable for IRGroup<E> {
515    fn fmt(&self, ctx: &mut crate::printer::IRPrinterCtx<'_, '_>) -> crate::printer::Result {
516        ctx.block("group", |ctx| {
517            writeln!(
518                ctx,
519                "{} \"{}\" (inputs {}) (outputs {})",
520                self.id(),
521                self.name(),
522                self.input_count(),
523                self.output_count()
524            )?;
525
526            for callsite in self.callsites() {
527                ctx.fmt_call(
528                    callsite.name(),
529                    callsite.inputs(),
530                    callsite.output_vars(),
531                    Some(callsite.callee_id()),
532                )?;
533                ctx.nl()?;
534            }
535
536            for stmt in self.statements() {
537                stmt.fmt(ctx)?;
538                ctx.nl()?;
539            }
540
541            Ok(())
542        })
543    }
544}