use std::{collections::HashMap, fmt::Display};
use serde::{Deserialize, Serialize};
use slop_algebra::{ExtensionField, Field};
use crate::ir::{Ast, ExprExtRef, ExprRef, Shape};
/// Picus role annotation for a function input argument.
///
/// `Attribute`'s `Display` impl renders `Input` as `#[picus(input)]` and `Output` as
/// `#[picus(output)]`; `Unknown` (the `Default` variant) renders as nothing.
/// NOTE(review): this type derives `Serialize`/`Deserialize`; if it is stored with an
/// index-based serde format, keep the variant order stable.
#[derive(Default, Debug, Clone, Copy, Serialize, Deserialize)]
pub enum PicusArg {
/// Argument is annotated as `#[picus(input)]`.
Input,
/// Argument is annotated as `#[picus(output)]`.
Output,
/// No Picus annotation; this is the `Default` variant.
#[default]
Unknown,
}
/// Attributes attached to a single `FuncDecl` input argument.
///
/// The derived `Default` yields `picus: PicusArg::Unknown`, i.e. no annotation.
#[derive(Default, Debug, Clone, Copy, Serialize, Deserialize)]
pub struct Attribute {
/// Picus role of the argument.
pub picus: PicusArg,
}
impl Display for Attribute {
    /// Writes the Picus annotation for this attribute.
    ///
    /// `Input` becomes `#[picus(input)]` and `Output` becomes `#[picus(output)]`.
    /// `Unknown` writes nothing at all, so callers that want to omit the attribute can
    /// rely on an empty rendering.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let annotation = match self.picus {
            PicusArg::Input => "#[picus(input)]",
            PicusArg::Output => "#[picus(output)]",
            // Nothing to emit: return before touching the formatter.
            PicusArg::Unknown => return Ok(()),
        };
        f.write_str(annotation)
    }
}
/// Signature of a function: its name, its typed input arguments, and its output shape.
///
/// Generic over the expression types (`Expr`, `ExprExt`) carried inside `Shape`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FuncDecl<Expr, ExprExt> {
/// Function name.
pub name: String,
/// Input arguments in declaration order, as (argument name, attributes, shape).
pub input: Vec<(String, Attribute, Shape<Expr, ExprExt>)>,
/// Output shape. `Shape::Unit` is treated as "no return value": `Func`'s `Display` omits
/// the `->` clause, and `to_output_lean_type` yields only `SP1ConstraintList`.
pub output: Shape<Expr, ExprExt>,
}
impl<Expr, ExprExt> FuncDecl<Expr, ExprExt> {
pub fn new(
name: String,
input: Vec<(String, Attribute, Shape<Expr, ExprExt>)>,
output: Shape<Expr, ExprExt>,
) -> Self {
Self { name, input, output }
}
}
impl<F: Field, EF: ExtensionField<F>> FuncDecl<ExprRef<F>, ExprExtRef<EF>> {
    /// Collects the index-to-name mapping contributed by every input argument.
    ///
    /// Inputs are visited in declaration order. Each argument's shape adds its own entries
    /// through `Shape::map_input`, given a clone of the argument name. If two arguments
    /// produce the same key, the later one's `map_input` call decides the outcome.
    pub fn input_mapping(&self) -> HashMap<usize, String> {
        self.input
            .iter()
            .fold(HashMap::new(), |mut acc, (name, _attr, shape)| {
                shape.map_input(name.clone(), &mut acc);
                acc
            })
    }

    /// Lean type string for this function's result.
    ///
    /// A `Shape::Unit` output yields just `SP1ConstraintList`. Any other output yields
    /// `<output lean type> × SP1ConstraintList`, i.e. a product of the output's Lean type
    /// and the constraint list.
    pub fn to_output_lean_type(&self) -> String {
        if matches!(self.output, Shape::Unit) {
            return "SP1ConstraintList".to_string();
        }
        let output_ty = self.output.to_lean_type();
        format!("{} × SP1ConstraintList", output_ty)
    }
}
/// A function definition: its declaration (signature) together with its body.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Func<Expr, ExprExt> {
/// Name, input arguments and output shape.
pub decl: FuncDecl<Expr, ExprExt>,
/// Function body.
pub body: Ast<Expr, ExprExt>,
}
impl<F: Field, EF: ExtensionField<F>> Display for Func<ExprRef<F>, ExprExtRef<EF>> {
    /// Pretty-prints the function in a Rust-like surface syntax.
    ///
    /// The header is `fn <name>(`, followed by one argument per line; arguments are
    /// separated by `,` plus a newline. The header is closed with `)`, then `-> <output>`
    /// unless the output is `Shape::Unit`, then ` {`. The body is printed through
    /// `Ast::to_string_pretty`, followed by a closing `}` line.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let decl = &self.decl;
        writeln!(f, "fn {}(", decl.name)?;

        for (idx, (name, attr, inp)) in decl.input.iter().enumerate() {
            // Emitting the separator before every argument but the first produces the same
            // text as emitting it after every argument but the last.
            if idx > 0 {
                writeln!(f, ",")?;
            }
            // `Unknown` renders as an empty attribute, so skip it to avoid a stray space.
            if let PicusArg::Unknown = attr.picus {
                write!(f, " {name}: {inp:?}")?;
            } else {
                write!(f, " {attr} {name}: {inp:?}")?;
            }
        }

        write!(f, ")")?;
        if !matches!(decl.output, Shape::Unit) {
            write!(f, " -> {:?}", decl.output)?;
        }
        writeln!(f, " {{")?;

        write!(f, "{}", self.body.to_string_pretty(" "))?;
        writeln!(f, "}}")
    }
}