use std::collections::BTreeMap;
use crate::{
air::{AirInteraction, InteractionScope, MachineAir, MessageBuilder},
ir::{Ast, Attribute, ExprExtRef, ExprRef, Func, Shape, GLOBAL_AST},
};
use slop_air::{AirBuilder, AirBuilderWithPublicValues, ExtensionBuilder, PairBuilder};
use slop_matrix::dense::RowMajorMatrix;
use crate::ir::expr_impl::{Expr, ExprExt, EF, F};
/// Symbolic AIR builder that records constraints and interactions into the
/// shared [`GLOBAL_AST`] instead of evaluating them on concrete values.
///
/// Columns and public values are represented by [`Expr`] placeholders; the
/// recorded program is retrieved with `ast()` and named sub-programs with
/// `modules()`.
///
/// NOTE(review): this type derives `Clone` and its `Drop` impl writes `parent`
/// back into `GLOBAL_AST`. A clone of a region builder carries its own copy of
/// `parent`, so dropping the clone looks like it would restore the parent AST
/// as well — confirm that region builders are never cloned.
#[derive(Clone, Debug)]
pub struct ConstraintCompiler {
/// One symbolic expression per public value, built with `Expr::public(i)`.
public_values: Vec<Expr>,
/// Symbolic preprocessed columns, built with `Expr::preprocessed(i)`; the
/// backing vector holds exactly `preprocessed_width` entries.
preprocessed: RowMajorMatrix<Expr>,
/// Symbolic main-trace columns, built with `Expr::main(i)`; the backing
/// vector holds exactly `main_width` entries.
main: RowMajorMatrix<Expr>,
/// Functions registered through `register_module`, keyed by module name.
modules: BTreeMap<String, Func<Expr, ExprExt>>,
/// AST that was in `GLOBAL_AST` when this builder was created by `region()`.
/// `None` for builders created by `new`/`with_sizes`; when `Some`, it is
/// written back to `GLOBAL_AST` on drop.
parent: Option<Ast<ExprRef<F>, ExprExtRef<EF>>>,
}
impl ConstraintCompiler {
    /// Creates a compiler whose preprocessed and main widths are taken from `air`.
    ///
    /// Delegates to [`Self::with_sizes`], so it also resets [`GLOBAL_AST`].
    ///
    /// # Panics
    ///
    /// Panics if locking [`GLOBAL_AST`] fails.
    pub fn new<A: MachineAir<F>>(air: &A, num_public_values: usize) -> Self {
        let preprocessed_width = air.preprocessed_width();
        let main_width = air.width();
        Self::with_sizes(num_public_values, preprocessed_width, main_width)
    }

    /// Creates a compiler with explicit sizes and resets [`GLOBAL_AST`] to an
    /// empty AST.
    ///
    /// * `num_public_values` - number of `Expr::public` placeholders to create.
    /// * `preprocessed_width` - number of `Expr::preprocessed` columns.
    /// * `main_width` - number of `Expr::main` columns.
    ///
    /// # Panics
    ///
    /// Panics if locking [`GLOBAL_AST`] fails.
    pub fn with_sizes(
        num_public_values: usize,
        preprocessed_width: usize,
        main_width: usize,
    ) -> Self {
        // Only the reset needs the global lock; the guard is a temporary of
        // this statement, so the lock is released before the columns are built.
        *GLOBAL_AST.lock().unwrap() = Ast::new();

        let public_values = (0..num_public_values).map(Expr::public).collect();

        let preprocessed = (0..preprocessed_width).map(Expr::preprocessed).collect();
        let preprocessed = RowMajorMatrix::new(preprocessed, preprocessed_width);

        let main = (0..main_width).map(Expr::main).collect();
        let main = RowMajorMatrix::new(main, main_width);

        Self { public_values, preprocessed, main, modules: BTreeMap::new(), parent: None }
    }

    /// Returns a clone of the AST currently stored in [`GLOBAL_AST`].
    ///
    /// # Panics
    ///
    /// Panics if locking [`GLOBAL_AST`] fails.
    pub fn ast(&self) -> Ast<ExprRef<F>, ExprExtRef<EF>> {
        let ast = GLOBAL_AST.lock().unwrap();
        ast.clone()
    }

    /// Opens a nested recording region.
    ///
    /// The AST currently in [`GLOBAL_AST`] is moved into the returned builder's
    /// `parent` field and replaced by an empty AST. The returned builder shares
    /// this builder's symbolic columns and public values, starts with no
    /// modules, and writes `parent` back to [`GLOBAL_AST`] when dropped.
    fn region(&self) -> Self {
        // Swap under a single lock acquisition: no other thread can observe or
        // modify the AST between the snapshot and the reset, and the old AST is
        // moved out rather than deep-cloned.
        let parent = std::mem::replace(&mut *GLOBAL_AST.lock().unwrap(), Ast::new());
        Self {
            public_values: self.public_values.clone(),
            preprocessed: self.preprocessed.clone(),
            main: self.main.clone(),
            modules: BTreeMap::new(),
            parent: Some(parent),
        }
    }

    /// Records `body` as a named function and stores it in this builder's
    /// module map.
    ///
    /// * `name` - key under which the function is stored; an existing entry
    ///   with the same name is replaced.
    /// * `params` - parameter list forwarded to `FuncDecl::new`.
    /// * `body` - closure run against a fresh region builder; whatever it
    ///   records into [`GLOBAL_AST`] becomes the function body and its return
    ///   value becomes the declared result.
    ///
    /// Modules registered on the region builder while `body` runs are moved
    /// into this builder as well.
    ///
    /// # Panics
    ///
    /// Panics if locking [`GLOBAL_AST`] fails.
    pub fn register_module(
        &mut self,
        name: String,
        params: Vec<(String, Attribute, Shape<ExprRef<F>, ExprExtRef<EF>>)>,
        body: impl FnOnce(&mut Self) -> Shape<ExprRef<F>, ExprExtRef<EF>>,
    ) {
        let mut body_builder = self.region();
        let result = body(&mut body_builder);
        // Snapshot the recorded body before `body_builder` is dropped: its
        // `Drop` impl puts the parent AST back into `GLOBAL_AST`.
        let body = body_builder.ast();
        let decl = crate::ir::FuncDecl::new(name.clone(), params, result);
        // Hoist nested modules first so that `name` wins on a key collision.
        self.modules.append(&mut body_builder.modules);
        self.modules.insert(name, Func { decl, body });
        // `body_builder` goes out of scope here, restoring the parent AST.
    }

    /// Returns the functions registered so far, keyed by module name.
    #[must_use]
    pub fn modules(&self) -> &BTreeMap<String, Func<Expr, ExprExt>> {
        &self.modules
    }

    /// Returns the width of the symbolic main trace.
    #[must_use]
    pub fn num_cols(&self) -> usize {
        self.main.width
    }
}
impl Drop for ConstraintCompiler {
    /// Restores the parent AST into [`GLOBAL_AST`] when a region builder goes
    /// out of scope.
    ///
    /// Builders without a `parent` (those not created by `region()`) leave the
    /// global AST untouched.
    fn drop(&mut self) {
        if let Some(parent) = self.parent.take() {
            // This destructor can run while unwinding from a panic that
            // poisoned the mutex; panicking again here would abort the
            // process. The whole AST is overwritten, so a possibly inconsistent
            // poisoned value is never read. The mutex stays poisoned, so later
            // `lock().unwrap()` callers still observe the original failure.
            let mut ast = GLOBAL_AST.lock().unwrap_or_else(std::sync::PoisonError::into_inner);
            *ast = parent;
        }
    }
}
impl AirBuilder for ConstraintCompiler {
type F = F;
type Expr = Expr;
type Var = Expr;
type M = RowMajorMatrix<Expr>;
fn main(&self) -> Self::M {
self.main.clone()
}
fn is_first_row(&self) -> Self::Expr {
unreachable!("first row is not supported")
}
fn is_last_row(&self) -> Self::Expr {
unreachable!("last row is not supported")
}
fn is_transition_window(&self, _size: usize) -> Self::Expr {
unreachable!("transition window is not supported")
}
fn assert_zero<I: Into<Self::Expr>>(&mut self, x: I) {
let x = x.into();
let mut ast = GLOBAL_AST.lock().unwrap();
ast.assert_zero(x);
}
}
/// Interaction recording: sends and receives are appended to [`GLOBAL_AST`].
impl MessageBuilder<AirInteraction<Expr>> for ConstraintCompiler {
    /// Records a `send` of `message` with the given `scope`.
    fn send(&mut self, message: AirInteraction<Expr>, scope: InteractionScope) {
        GLOBAL_AST.lock().unwrap().send(message, scope);
    }

    /// Records a `receive` of `message` with the given `scope`.
    fn receive(&mut self, message: AirInteraction<Expr>, scope: InteractionScope) {
        GLOBAL_AST.lock().unwrap().receive(message, scope);
    }
}
impl PairBuilder for ConstraintCompiler {
fn preprocessed(&self) -> Self::M {
self.preprocessed.clone()
}
}
impl AirBuilderWithPublicValues for ConstraintCompiler {
    type PublicVar = Expr;

    /// Borrows the symbolic public values, one [`Expr`] per public input.
    fn public_values(&self) -> &[Self::PublicVar] {
        self.public_values.as_slice()
    }
}
/// Extension-field support: extension expressions and variables are both
/// [`ExprExt`].
impl ExtensionBuilder for ConstraintCompiler {
    type EF = EF;
    type ExprEF = ExprExt;
    type VarEF = ExprExt;

    /// Records the extension-field constraint `x == 0` in [`GLOBAL_AST`].
    fn assert_zero_ext<I: Into<Self::ExprEF>>(&mut self, x: I) {
        // The conversion runs first; the lock is only held for the recording.
        let constraint: Self::ExprEF = x.into();
        GLOBAL_AST.lock().unwrap().assert_ext_zero(constraint);
    }
}