Skip to main content

haloumi_ir/
circuit.rs

1//! Types for representing a complete circuit.
2
3use crate::{
4    diagnostics::{Diagnostic, Validation},
5    groups::IRGroup,
6    printer::{IRPrintable, IRPrinter},
7    traits::{Canonicalize, ConstantFolding, Validatable},
8};
9
10/// Generic type representing a circuit.
11///
12/// Is parametrized on the expression type and the type used to represent the external context
13/// relative to the circuit.
14#[derive(Debug)]
15pub struct IRCircuit<E, C> {
16    body: Vec<IRGroup<E>>,
17    context: C,
18}
19
20impl<E, C> IRCircuit<E, C> {
21    /// Creates a new circuit.
22    pub fn new(body: Vec<IRGroup<E>>, context: C) -> Self {
23        Self { body, context }
24    }
25
26    /// Returns a reference to the context.
27    pub fn context(&self) -> &C {
28        &self.context
29    }
30
31    /// Returns a list of the groups inside the circuit.
32    pub fn body(&self) -> &[IRGroup<E>] {
33        &self.body
34    }
35
36    /// Returns a list of mutable references to the groups inside the circuit.
37    pub fn body_mut(&mut self) -> &mut [IRGroup<E>] {
38        &mut self.body
39    }
40
41    /// Returns the body of the circuit, consuming it.
42    pub fn take_body(self) -> Vec<IRGroup<E>> {
43        self.body
44    }
45
46    /// Returns a printer of the circuit.
47    pub fn display(&self) -> IRPrinter<'_>
48    where
49        Self: IRPrintable,
50    {
51        IRPrinter::from(self)
52    }
53
54    /// Returns the main group.
55    ///
56    /// Panics if there isn't a main group.
57    pub fn main(&self) -> &IRGroup<E> {
58        // Reverse the iterator because the main group is likely to be the last one.
59        self.body
60            .iter()
61            .rev()
62            .find(|g| g.is_main())
63            .expect("A main group is required")
64    }
65
66    ///// Validates the IR, returning errors if it failed.
67    //pub fn validate(&self) -> (Result<(), ValidationFailed>, Vec<String>) {
68    //    let mut errors = vec![];
69    //
70    //    for group in &self.body {
71    //        let (status, group_errors) = group.validate(&self.body);
72    //        if status.is_err() {
73    //            for err in group_errors {
74    //                errors.push(format!("Error in group \"{}\": {err}", group.name()));
75    //            }
76    //        }
77    //    }
78    //
79    //    (
80    //        if errors.is_empty() {
81    //            Ok(())
82    //        } else {
83    //            Err(ValidationFailed {
84    //                name: self
85    //                    .body
86    //                    .iter()
87    //                    .find_map(|g| g.is_main().then_some(g.name()))
88    //                    .unwrap_or("circuit")
89    //                    .to_string(),
90    //                error_count: errors.len(),
91    //            })
92    //        },
93    //        errors,
94    //    )
95    //}
96}
97
98impl<E, C, D> Validatable for IRCircuit<E, C>
99where
100    IRGroup<E>: Validatable<Diagnostic = D, Context = [IRGroup<E>]>,
101    D: Diagnostic,
102{
103    type Diagnostic = D;
104
105    type Context = ();
106
107    fn validate_with_context(
108        &self,
109        _: &Self::Context,
110    ) -> Result<Vec<Self::Diagnostic>, Vec<Self::Diagnostic>> {
111        let mut validation = Validation::new();
112
113        for group in &self.body {
114            let header = format!("in group \"{}\"", group.name());
115            let result = group.validate_with_context(&self.body);
116            validation.append_from_result(result, &header);
117        }
118        validation.into()
119    }
120}
121
122impl<E, C, Err> ConstantFolding for IRCircuit<E, C>
123where
124    IRGroup<E>: ConstantFolding<Error = Err>,
125{
126    type Error = Err;
127
128    type T = ();
129
130    fn constant_fold(&mut self) -> Result<(), Self::Error> {
131        self.body.constant_fold()
132    }
133}
134
135impl<E, C> Canonicalize for IRCircuit<E, C>
136where
137    IRGroup<E>: Canonicalize,
138{
139    fn canonicalize(&mut self) {
140        self.body.canonicalize()
141    }
142}
143
144impl<E: IRPrintable, C: IRPrintable> IRPrintable for IRCircuit<E, C> {
145    fn fmt(&self, ctx: &mut crate::printer::IRPrinterCtx<'_, '_>) -> crate::printer::Result {
146        self.context().fmt(ctx)?;
147        for group in self.body() {
148            ctx.nl()?;
149            group.fmt(ctx)?;
150        }
151        Ok(())
152    }
153}