use std::{
    collections::HashMap,
    fmt::Write as _,
    sync::{Arc, LazyLock, Mutex},
};

use serde::{Deserialize, Serialize};
use slop_algebra::{extension::BinomialExtensionField, ExtensionField, Field};
use sp1_primitives::SP1Field;

use crate::{
    air::{AirInteraction, InteractionScope},
    ir::{Attribute, BinOp, ExprExtRef, ExprRef, FuncDecl, IrVar, OpExpr, Shape},
    InteractionKind,
};
type F = SP1Field;
type EF = BinomialExtensionField<SP1Field, 4>;
type AstType = Ast<ExprRef<F>, ExprExtRef<EF>>;
pub static GLOBAL_AST: LazyLock<Arc<Mutex<AstType>>> =
LazyLock::new(|| Arc::new(Mutex::new(Ast::new())));
/// A flat list of IR operations, plus bookkeeping for the variables those operations define.
///
/// `Expr` / `ExprExt` are the operand types for base-field and extension-field expressions
/// (e.g. `ExprRef<F>` / `ExprExtRef<EF>`).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Ast<Expr, ExprExt> {
    /// One entry per allocated base-field variable, indexed by variable id. Each entry holds
    /// the length of `operations` at the moment that variable was allocated.
    assignments: Vec<usize>,
    /// Same as `assignments`, but for extension-field variables.
    ext_assignments: Vec<usize>,
    /// The recorded operations, in the order they were pushed.
    operations: Vec<OpExpr<Expr, ExprExt>>,
}
impl<F: Field, EF: ExtensionField<F>> Ast<ExprRef<F>, ExprExtRef<EF>> {
    /// Creates an empty AST: no variables allocated and no operations recorded.
    #[must_use]
    pub fn new() -> Self {
        Self { assignments: vec![], ext_assignments: vec![], operations: vec![] }
    }

    /// Allocates a fresh base-field variable and returns `ExprRef::Expr(id)` for it.
    ///
    /// `id` is the variable's index in `assignments`. The entry stored there is the number of
    /// operations recorded so far, i.e. the index the next pushed operation will occupy.
    pub fn alloc(&mut self) -> ExprRef<F> {
        let id = self.assignments.len();
        self.assignments.push(self.operations.len());
        ExprRef::Expr(id)
    }

    /// Allocates `N` fresh base-field variables, in index order (see [`Self::alloc`]).
    pub fn alloc_array<const N: usize>(&mut self) -> [ExprRef<F>; N] {
        core::array::from_fn(|_| self.alloc())
    }

    /// Records an `OpExpr::Assign(a, b)` operation.
    pub fn assign(&mut self, a: ExprRef<F>, b: ExprRef<F>) {
        self.operations.push(OpExpr::Assign(a, b));
    }

    /// Allocates a fresh extension-field variable and returns `ExprExtRef::Expr(id)` for it.
    ///
    /// This is the extension-field counterpart of [`Self::alloc`]; it is tracked in
    /// `ext_assignments`.
    pub fn alloc_ext(&mut self) -> ExprExtRef<EF> {
        let id = self.ext_assignments.len();
        self.ext_assignments.push(self.operations.len());
        ExprExtRef::Expr(id)
    }

    /// Records the constraint that the base-field expression `x` equals zero.
    pub fn assert_zero(&mut self, x: ExprRef<F>) {
        self.operations.push(OpExpr::AssertZero(x));
    }

    /// Records the constraint that the extension-field expression `x` equals zero.
    pub fn assert_ext_zero(&mut self, x: ExprExtRef<EF>) {
        self.operations.push(OpExpr::AssertExtZero(x));
    }

    /// Records `result = a <op> b` over the base field and returns the new `result` variable.
    ///
    /// `result` is allocated before the operation is pushed, so its `assignments` entry is
    /// the index of this operation.
    pub fn bin_op(&mut self, op: BinOp, a: ExprRef<F>, b: ExprRef<F>) -> ExprRef<F> {
        let result = self.alloc();
        self.operations.push(OpExpr::BinOp(op, result, a, b));
        result
    }

    /// Records `result = -a` over the base field and returns the new `result` variable.
    pub fn negate(&mut self, a: ExprRef<F>) -> ExprRef<F> {
        let result = self.alloc();
        self.operations.push(OpExpr::Neg(result, a));
        result
    }

    /// Records `result = a <op> b` over the extension field and returns the new `result`.
    pub fn bin_op_ext(
        &mut self,
        op: BinOp,
        a: ExprExtRef<EF>,
        b: ExprExtRef<EF>,
    ) -> ExprExtRef<EF> {
        let result = self.alloc_ext();
        self.operations.push(OpExpr::BinOpExt(op, result, a, b));
        result
    }

    /// Records `result = a <op> b`, where `a` is an extension-field expression and `b` is a
    /// base-field expression. Returns the new extension-field `result`.
    pub fn bin_op_base_ext(
        &mut self,
        op: BinOp,
        a: ExprExtRef<EF>,
        b: ExprRef<F>,
    ) -> ExprExtRef<EF> {
        let result = self.alloc_ext();
        self.operations.push(OpExpr::BinOpBaseExt(op, result, a, b));
        result
    }

    /// Records `result = -a` over the extension field and returns the new `result`.
    pub fn neg_ext(&mut self, a: ExprExtRef<EF>) -> ExprExtRef<EF> {
        let result = self.alloc_ext();
        self.operations.push(OpExpr::NegExt(result, a));
        result
    }

    /// Records the embedding of the base-field expression `a` into a new extension-field
    /// variable, and returns that variable.
    pub fn ext_from_base(&mut self, a: ExprRef<F>) -> ExprExtRef<EF> {
        let result = self.alloc_ext();
        self.operations.push(OpExpr::ExtFromBase(result, a));
        result
    }

    /// Records sending `message` in the given interaction `scope`.
    pub fn send(&mut self, message: AirInteraction<ExprRef<F>>, scope: InteractionScope) {
        self.operations.push(OpExpr::Send(message, scope));
    }

    /// Records receiving `message` in the given interaction `scope`.
    pub fn receive(&mut self, message: AirInteraction<ExprRef<F>>, scope: InteractionScope) {
        self.operations.push(OpExpr::Receive(message, scope));
    }

    /// Renders every recorded operation on its own line with `OpExpr`'s `Display` impl.
    /// Each line starts with `prefix` and ends with `\n`.
    #[must_use]
    pub fn to_string_pretty(&self, prefix: &str) -> String {
        let mut s = String::new();
        for op in &self.operations {
            // Append in place instead of building a temporary `String` per operation.
            writeln!(s, "{prefix}{op}").expect("formatting into a String is infallible");
        }
        s
    }

    /// Records a call to the operation `name`.
    ///
    /// `inputs` are `(String, Attribute, Shape)` triples; presumably these are the argument
    /// name, attribute and shape (confirm against `FuncDecl::new`). `output` is the shape of
    /// the call's result, and `Shape::Unit` means there is no result value.
    pub fn call_operation(
        &mut self,
        name: String,
        inputs: Vec<(String, Attribute, Shape<ExprRef<F>, ExprExtRef<EF>>)>,
        output: Shape<ExprRef<F>, ExprExtRef<EF>>,
    ) {
        let func = FuncDecl::new(name, inputs, output);
        self.operations.push(OpExpr::Call(func));
    }

    /// Lowers the recorded operations into the pieces of a Lean constraint definition.
    ///
    /// Returns `(steps, constraints, calls)`:
    /// - `steps`: `let` bindings, in operation order, for negations, base-field binary
    ///   operations and operation calls;
    /// - `constraints`: `.assertZero`, `.send` and `.receive` terms, in operation order;
    /// - `calls`: the number of `Call` operations lowered. The `i`-th call binds its
    ///   constraint list to `CS{i}`, for `i` in `0..calls`.
    ///
    /// `mapping` is forwarded to the operands' `to_lean_string` / `to_lean_constructor`
    /// helpers. Sends and receives are skipped unless their interaction kind is `Byte`,
    /// `State`, `Memory` or `Program`. Assignments to output arguments are also skipped.
    ///
    /// # Panics
    ///
    /// Panics (via `todo!`) on any other operation, for example the extension-field
    /// operations. The panic message includes the offending operation.
    #[must_use]
    pub fn to_lean_components(
        &self,
        mapping: &HashMap<usize, String>,
    ) -> (Vec<String>, Vec<String>, usize) {
        let mut steps: Vec<String> = Vec::new();
        let mut constraints: Vec<String> = Vec::new();
        let mut calls: usize = 0;
        for opexpr in &self.operations {
            match opexpr {
                OpExpr::AssertZero(expr) => {
                    constraints.push(format!("(.assertZero {})", expr.to_lean_string(mapping)));
                }
                OpExpr::Neg(result, operand) => {
                    steps.push(format!(
                        "let {} : Fin KB := -{}",
                        result.expr_to_lean_string(),
                        operand.to_lean_string(mapping),
                    ));
                }
                OpExpr::BinOp(op, result, a, b) => {
                    let symbol = match op {
                        BinOp::Add => "+",
                        BinOp::Sub => "-",
                        BinOp::Mul => "*",
                    };
                    let result = result.expr_to_lean_string();
                    let lhs = a.to_lean_string(mapping);
                    let rhs = b.to_lean_string(mapping);
                    steps.push(format!("let {result} : Fin KB := {lhs} {symbol} {rhs}"));
                }
                OpExpr::Send(interaction, _) => {
                    constraints.extend(Self::interaction_to_lean("send", interaction, mapping));
                }
                OpExpr::Receive(interaction, _) => {
                    constraints.extend(Self::interaction_to_lean(
                        "receive",
                        interaction,
                        mapping,
                    ));
                }
                OpExpr::Call(decl) => {
                    // Calls without an output bind only their constraint list; otherwise the
                    // output is destructured alongside it.
                    let mut step = match decl.output {
                        Shape::Unit => format!("let CS{calls} : SP1ConstraintList := "),
                        _ => format!(
                            "let ⟨{}, CS{calls}⟩ := ",
                            decl.output.to_lean_destructor()
                        ),
                    };
                    write!(step, "{}.constraints", decl.name)
                        .expect("formatting into a String is infallible");
                    for input in &decl.input {
                        step.push(' ');
                        step.push_str(&input.2.to_lean_constructor(mapping));
                    }
                    calls += 1;
                    steps.push(step);
                }
                // Assignments to output arguments produce neither a Lean step nor a constraint.
                OpExpr::Assign(ExprRef::IrVar(IrVar::OutputArg(_)), _) => {}
                unsupported => {
                    todo!("to_lean_components: unsupported operation `{unsupported}`")
                }
            }
        }
        (steps, constraints, calls)
    }

    /// Formats a `(.send ...)` / `(.receive ...)` constraint, where `direction` is `"send"`
    /// or `"receive"`. Returns `None` for interaction kinds that are not exported to Lean.
    fn interaction_to_lean(
        direction: &str,
        interaction: &AirInteraction<ExprRef<F>>,
        mapping: &HashMap<usize, String>,
    ) -> Option<String> {
        match interaction.kind {
            InteractionKind::Byte
            | InteractionKind::State
            | InteractionKind::Memory
            | InteractionKind::Program => Some(format!(
                "(.{direction} {} {})",
                interaction.to_lean_string(mapping),
                interaction.multiplicity.to_lean_string(mapping),
            )),
            _ => None,
        }
    }
}
impl<F: Field, EF: ExtensionField<F>> Default for Ast<ExprRef<F>, ExprExtRef<EF>> {
fn default() -> Self {
Self::new()
}
}