1use std::{collections::HashMap, fmt::Display};
2
3use serde::{Deserialize, Serialize};
4use slop_algebra::{ExtensionField, Field};
5
6use crate::ir::{Ast, ExprExtRef, ExprRef, Shape};
7
8#[derive(Default, Debug, Clone, Copy, Serialize, Deserialize)]
13pub enum PicusArg {
14 Input,
16 Output,
18 #[default]
20 Unknown,
21}
22
23#[derive(Default, Debug, Clone, Copy, Serialize, Deserialize)]
25pub struct Attribute {
26 pub picus: PicusArg,
29}
30
31impl Display for Attribute {
32 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
33 match self.picus {
34 PicusArg::Input => write!(f, "#[picus(input)]"),
35 PicusArg::Output => write!(f, "#[picus(output)]"),
36 PicusArg::Unknown => Ok(()),
37 }
38 }
39}
40
41#[derive(Debug, Clone, Serialize, Deserialize)]
44pub struct FuncDecl<Expr, ExprExt> {
45 pub name: String,
47 pub input: Vec<(String, Attribute, Shape<Expr, ExprExt>)>,
49 pub output: Shape<Expr, ExprExt>,
51}
52
53impl<Expr, ExprExt> FuncDecl<Expr, ExprExt> {
54 pub fn new(
56 name: String,
57 input: Vec<(String, Attribute, Shape<Expr, ExprExt>)>,
58 output: Shape<Expr, ExprExt>,
59 ) -> Self {
60 Self { name, input, output }
61 }
62}
63
64impl<F: Field, EF: ExtensionField<F>> FuncDecl<ExprRef<F>, ExprExtRef<EF>> {
65 pub fn input_mapping(&self) -> HashMap<usize, String> {
67 let mut mapping = HashMap::new();
68 for (name, _, arg) in &self.input {
69 arg.map_input(name.clone(), &mut mapping);
70 }
71 mapping
72 }
73
74 pub fn to_output_lean_type(&self) -> String {
76 match self.output {
77 Shape::Unit => "SP1ConstraintList".to_string(),
78 _ => format!("{} × SP1ConstraintList", self.output.to_lean_type()),
79 }
80 }
81}
82
83#[derive(Debug, Clone, Serialize, Deserialize)]
85pub struct Func<Expr, ExprExt> {
86 pub decl: FuncDecl<Expr, ExprExt>,
88 pub body: Ast<Expr, ExprExt>,
91}
92
93impl<F: Field, EF: ExtensionField<F>> Display for Func<ExprRef<F>, ExprExtRef<EF>> {
94 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
95 writeln!(f, "fn {}(", self.decl.name)?;
96 for (i, (name, attr, inp)) in self.decl.input.iter().enumerate() {
97 match attr.picus {
99 PicusArg::Unknown => write!(f, " {name}: {inp:?}")?,
100 _ => write!(f, " {attr} {name}: {inp:?}")?,
101 }
102 if i < self.decl.input.len() - 1 {
103 writeln!(f, ",")?;
104 }
105 }
106 write!(f, ")")?;
107 match self.decl.output {
108 Shape::Unit => {}
109 _ => write!(f, " -> {:?}", self.decl.output)?,
110 }
111 writeln!(f, " {{")?;
112 write!(f, "{}", self.body.to_string_pretty(" "))?;
113 writeln!(f, "}}")
114 }
115}