Skip to main content

sp1_hypercube/ir/
ast.rs

1use std::{
2    collections::HashMap,
3    sync::{Arc, LazyLock, Mutex},
4};
5
6use serde::{Deserialize, Serialize};
7use slop_algebra::{extension::BinomialExtensionField, ExtensionField, Field};
8
9use crate::{
10    air::{AirInteraction, InteractionScope},
11    ir::{Attribute, BinOp, ExprExtRef, ExprRef, FuncDecl, IrVar, OpExpr, Shape},
12    InteractionKind,
13};
14
15use sp1_primitives::SP1Field;
16type F = SP1Field;
17type EF = BinomialExtensionField<SP1Field, 4>;
18
19type AstType = Ast<ExprRef<F>, ExprExtRef<EF>>;
20
21/// This should only be used under two scenarios:
22/// 1. In the `SP1OperationBuilder` macro.
23/// 2. When `SP1OperationBuilder` doesn't do its job and you need to implement `SP1Operation`
24///    manually.
25pub static GLOBAL_AST: LazyLock<Arc<Mutex<AstType>>> =
26    LazyLock::new(|| Arc::new(Mutex::new(Ast::new())));
27
28/// Ast for the constraint compiler.
29#[derive(Debug, Clone, Serialize, Deserialize)]
30pub struct Ast<Expr, ExprExt> {
31    assignments: Vec<usize>,
32    ext_assignments: Vec<usize>,
33    operations: Vec<OpExpr<Expr, ExprExt>>,
34}
35
36impl<F: Field, EF: ExtensionField<F>> Ast<ExprRef<F>, ExprExtRef<EF>> {
37    /// Constructs a new AST.
38    #[must_use]
39    pub fn new() -> Self {
40        Self { assignments: vec![], ext_assignments: vec![], operations: vec![] }
41    }
42
43    /// Allocate a new [`ExprRef`] and assign the result of the **next** operation
44    /// to it.
45    ///
46    /// In practice, this usually means after calling [`Self::alloc`] you would have to push an
47    /// [`OpExpr`]. For example, when `a = b + c`, call [`Self::alloc`] to allocate the LHS, and
48    /// then push the [`OpExpr`] representing `b + c` to `self.operations`.
49    pub fn alloc(&mut self) -> ExprRef<F> {
50        let id = self.assignments.len();
51        self.assignments.push(self.operations.len());
52        ExprRef::Expr(id)
53    }
54
55    /// Allocate an array of [`ExprRef`] of constant size using [`Self::alloc`] and assign all of
56    /// them to the **next** operation.
57    pub fn alloc_array<const N: usize>(&mut self) -> [ExprRef<F>; N] {
58        core::array::from_fn(|_| self.alloc())
59    }
60
61    /// Push an assignment operation.
62    pub fn assign(&mut self, a: ExprRef<F>, b: ExprRef<F>) {
63        let op = OpExpr::Assign(a, b);
64        self.operations.push(op);
65    }
66
67    /// Same as [`Self::alloc`] but for [`ExprExtRef`]
68    pub fn alloc_ext(&mut self) -> ExprExtRef<EF> {
69        let id = self.ext_assignments.len();
70        self.ext_assignments.push(self.operations.len());
71        ExprExtRef::Expr(id)
72    }
73
74    /// Push an operation that asserts [`ExprRef`] x is zero.
75    pub fn assert_zero(&mut self, x: ExprRef<F>) {
76        let op = OpExpr::AssertZero(x);
77        self.operations.push(op);
78    }
79
80    /// Same for [`Self::assert_zero`] but for [`ExprExtRef`].
81    pub fn assert_ext_zero(&mut self, x: ExprExtRef<EF>) {
82        let op = OpExpr::AssertExtZero(x);
83        self.operations.push(op);
84    }
85
86    /// Records a binary operation and returns a new [`ExprRef`] that represents the result of this
87    /// operation.
88    pub fn bin_op(&mut self, op: BinOp, a: ExprRef<F>, b: ExprRef<F>) -> ExprRef<F> {
89        let result = self.alloc();
90        let op = OpExpr::BinOp(op, result, a, b);
91        self.operations.push(op);
92        result
93    }
94
95    /// Same with [`Self::bin_op`] but specifically for negation.
96    pub fn negate(&mut self, a: ExprRef<F>) -> ExprRef<F> {
97        let result = self.alloc();
98        let op = OpExpr::Neg(result, a);
99        self.operations.push(op);
100        result
101    }
102
103    /// Same with [`Self::bin_op`] but for [`ExprExtRef`].
104    pub fn bin_op_ext(
105        &mut self,
106        op: BinOp,
107        a: ExprExtRef<EF>,
108        b: ExprExtRef<EF>,
109    ) -> ExprExtRef<EF> {
110        let result = self.alloc_ext();
111        let op = OpExpr::BinOpExt(op, result, a, b);
112        self.operations.push(op);
113        result
114    }
115
116    /// Same with [`Self::bin_op`] but for [`ExprExtRef`] and [`ExprRef`].
117    pub fn bin_op_base_ext(
118        &mut self,
119        op: BinOp,
120        a: ExprExtRef<EF>,
121        b: ExprRef<F>,
122    ) -> ExprExtRef<EF> {
123        let result = self.alloc_ext();
124        let op = OpExpr::BinOpBaseExt(op, result, a, b);
125        self.operations.push(op);
126        result
127    }
128
129    /// Same with [`Self::neg`] but for [`ExprExtRef`]
130    pub fn neg_ext(&mut self, a: ExprExtRef<EF>) -> ExprExtRef<EF> {
131        let result = self.alloc_ext();
132        let op = OpExpr::NegExt(result, a);
133        self.operations.push(op);
134        result
135    }
136
137    /// Get an [`ExprExtRef`] from [`ExprRef`]
138    pub fn ext_from_base(&mut self, a: ExprRef<F>) -> ExprExtRef<EF> {
139        let result = self.alloc_ext();
140        let op = OpExpr::ExtFromBase(result, a);
141        self.operations.push(op);
142        result
143    }
144
145    /// Records a send [`AirInteraction`]
146    pub fn send(&mut self, message: AirInteraction<ExprRef<F>>, scope: InteractionScope) {
147        let op = OpExpr::Send(message, scope);
148        self.operations.push(op);
149    }
150
151    /// Records a receive [`AirInteraction`]
152    pub fn receive(&mut self, message: AirInteraction<ExprRef<F>>, scope: InteractionScope) {
153        let op = OpExpr::Receive(message, scope);
154        self.operations.push(op);
155    }
156
157    /// A [String] of all the operations with [prefix] padding in the front.
158    #[must_use]
159    pub fn to_string_pretty(&self, prefix: &str) -> String {
160        let mut s = String::new();
161        for op in &self.operations {
162            s.push_str(&format!("{prefix}{op}\n"));
163        }
164        s
165    }
166
167    /// Records an operation that represents a function call.
168    pub fn call_operation(
169        &mut self,
170        name: String,
171        inputs: Vec<(String, Attribute, Shape<ExprRef<F>, ExprExtRef<EF>>)>,
172        output: Shape<ExprRef<F>, ExprExtRef<EF>>,
173    ) {
174        let func = FuncDecl::new(name, inputs, output);
175        let op = OpExpr::Call(func);
176        self.operations.push(op);
177    }
178
179    /// Go through the AST and returns a tuple that contains:
180    /// 1. All the evaluation steps and function calls.
181    /// 2. All the constraints.
182    /// 3. Number of calls.
183    #[must_use]
184    pub fn to_lean_components(
185        &self,
186        mapping: &HashMap<usize, String>,
187    ) -> (Vec<String>, Vec<String>, usize) {
188        let mut steps: Vec<String> = Vec::default();
189        let mut calls: usize = 0;
190        let mut constraints: Vec<String> = Vec::default();
191
192        for opexpr in &self.operations {
193            match opexpr {
194                OpExpr::AssertZero(expr) => {
195                    constraints.push(format!("(.assertZero {})", expr.to_lean_string(mapping)));
196                }
197                OpExpr::Neg(a, b) => {
198                    steps.push(format!(
199                        "let {} : Fin KB := -{}",
200                        a.expr_to_lean_string(),
201                        b.to_lean_string(mapping),
202                    ));
203                }
204                OpExpr::BinOp(op, result, a, b) => {
205                    let result_str = result.expr_to_lean_string();
206                    let a_str = a.to_lean_string(mapping);
207                    let b_str = b.to_lean_string(mapping);
208                    match op {
209                        BinOp::Add => {
210                            steps.push(format!("let {result_str} : Fin KB := {a_str} + {b_str}"));
211                        }
212                        BinOp::Sub => {
213                            steps.push(format!("let {result_str} : Fin KB := {a_str} - {b_str}"));
214                        }
215                        BinOp::Mul => {
216                            steps.push(format!("let {result_str} : Fin KB := {a_str} * {b_str}"));
217                        }
218                    }
219                }
220                OpExpr::Send(interaction, _) => match interaction.kind {
221                    InteractionKind::Byte
222                    | InteractionKind::State
223                    | InteractionKind::Memory
224                    | InteractionKind::Program => {
225                        constraints.push(format!(
226                            "(.send {} {})",
227                            interaction.to_lean_string(mapping),
228                            interaction.multiplicity.to_lean_string(mapping)
229                        ));
230                    }
231                    _ => {}
232                },
233                OpExpr::Receive(interaction, _) => match interaction.kind {
234                    InteractionKind::Byte
235                    | InteractionKind::State
236                    | InteractionKind::Memory
237                    | InteractionKind::Program => {
238                        constraints.push(format!(
239                            "(.receive {} {})",
240                            interaction.to_lean_string(mapping),
241                            interaction.multiplicity.to_lean_string(mapping),
242                        ));
243                    }
244                    _ => {}
245                },
246                OpExpr::Call(decl) => {
247                    let mut step = String::new();
248                    match decl.output {
249                        Shape::Unit => {
250                            step.push_str(&format!("let CS{calls} : SP1ConstraintList := "));
251                        }
252                        _ => {
253                            step.push_str(&format!(
254                                "let ⟨{}, CS{}⟩ := ",
255                                decl.output.to_lean_destructor(),
256                                calls,
257                            ));
258                        }
259                    }
260
261                    step.push_str(&format!("{}.constraints", decl.name));
262
263                    for input in &decl.input {
264                        step.push(' ');
265                        step.push_str(&input.2.to_lean_constructor(mapping));
266                    }
267
268                    calls += 1;
269                    steps.push(step);
270                }
271                OpExpr::Assign(ExprRef::IrVar(IrVar::OutputArg(_)), _) => {
272                    // Output(x) are specifically ignored
273                }
274                _ => todo!(),
275            }
276        }
277
278        (steps, constraints, calls)
279    }
280}
281
282impl<F: Field, EF: ExtensionField<F>> Default for Ast<ExprRef<F>, ExprExtRef<EF>> {
283    fn default() -> Self {
284        Self::new()
285    }
286}