Skip to main content

sp1_hypercube/ir/
lean.rs

1use std::collections::HashMap;
2
3use itertools::Itertools;
4use slop_algebra::{ExtensionField, Field};
5
6use crate::{
7    air::AirInteraction,
8    ir::{ExprExtRef, ExprRef, IrVar, Shape},
9    InteractionKind,
10};
11
12// TODO(gzgz): implement constructor and destructor
13impl<F: Field, EF: ExtensionField<F>> Shape<ExprRef<F>, ExprExtRef<EF>> {
14    /// Output the string that would construct a value of this [Shape]
15    pub fn to_lean_constructor(&self, mapping: &HashMap<usize, String>) -> String {
16        match self {
17            Shape::Unit => unimplemented!("Unit shouldn't appear in constructors"),
18            Shape::Expr(expr) => expr.to_lean_string(mapping),
19            Shape::ExprExt(_) => todo!(),
20            Shape::Word(word) => {
21                format!("#v[{}]", word.iter().map(|x| x.to_lean_string(mapping)).join(", "))
22            }
23            Shape::Array(vals) => {
24                format!("#v[{}]", vals.iter().map(|x| x.to_lean_constructor(mapping)).join(", "))
25            }
26            Shape::Struct(_, fields) => {
27                format!(
28                    "{{ {} }}",
29                    fields
30                        .iter()
31                        .map(|(field_name, field_val)| format!(
32                            "{field_name} := {}",
33                            field_val.to_lean_constructor(mapping)
34                        ))
35                        .join(", ")
36                )
37            }
38        }
39    }
40
41    /// Output the string that would destruct a value of this [Shape]
42    pub fn to_lean_destructor(&self) -> String {
43        match self {
44            Shape::Unit => unimplemented!("Unit shouldn't appear in destructors"),
45            Shape::Expr(expr) => expr.to_lean_string(&HashMap::default()),
46            Shape::ExprExt(_) => todo!(),
47            Shape::Word(word) => format!(
48                "⟨⟨[{}]⟩, _⟩",
49                word.iter().map(|x| x.to_lean_string(&HashMap::default())).join(", ")
50            ),
51            Shape::Array(vals) => {
52                format!("⟨⟨[{}]⟩, _⟩", vals.iter().map(|x| x.to_lean_destructor()).join(", "))
53            }
54            Shape::Struct(_, _) => todo!("why would you need to destruct a struct"),
55        }
56    }
57
58    /// Calculates the full variable name that corresponds to `InputArg(x)`.
59    ///
60    /// For example,
61    /// ```lean
62    /// structure AddOperation where
63    ///   value : Word SP1Field
64    ///
65    /// def AddOperation.constraints
66    ///   (b : SP1Field)
67    ///   (c : SP1Field)
68    ///   (cols : AddOperation)
69    ///   (is_real : SP1Field) := sorry
70    /// ```
71    ///
72    /// `Expr(InputArg(3))` then maps to "cols.value[1]" because if you recursively flatten the
73    /// input arguments to `AddOperation.constraints` in argument/field declaration order, then the
74    /// element at index 3 corresponds to `cols.value[1]`.
75    pub fn map_input(&self, prefix: String, input_mapping: &mut HashMap<usize, String>) {
76        match self {
77            Shape::Unit => unimplemented!("Unit shouldn't appear as input"),
78            Shape::Expr(ExprRef::IrVar(IrVar::InputArg(idx))) => {
79                input_mapping.insert(*idx, prefix);
80            }
81            Shape::Word(vals) => {
82                for (i, val) in vals.iter().enumerate() {
83                    match val {
84                        ExprRef::IrVar(IrVar::InputArg(idx)) => {
85                            // In Mathlib, c[i] means some permutation stuff...
86                            if prefix == "c" {
87                                input_mapping.insert(*idx, format!("cc[{i}]"));
88                            } else {
89                                input_mapping.insert(*idx, format!("{prefix}[{i}]"));
90                            }
91                        }
92                        _ => unimplemented!("map_input must be backed by Input(x)"),
93                    }
94                }
95            }
96            Shape::Array(vals) => {
97                for (i, val) in vals.iter().enumerate() {
98                    val.map_input(format!("{prefix}[{i}]"), input_mapping);
99                }
100            }
101            Shape::Struct(_, fields) => {
102                for (name, field) in fields {
103                    field.map_input(format!("{prefix}.{name}"), input_mapping);
104                }
105            }
106            _ => unimplemented!(),
107        }
108    }
109}
110
111impl<F: Field> AirInteraction<ExprRef<F>> {
112    /// Converts an Air interaction to an `AirInteraction` in sp1-lean.
113    pub fn to_lean_string(&self, input_mapping: &HashMap<usize, String>) -> String {
114        let mut res = "(".to_string();
115
116        let kind_str = match self.kind {
117            InteractionKind::Memory => ".memory",
118            InteractionKind::Program => ".program",
119            InteractionKind::Byte => ".byte",
120            InteractionKind::State => ".state",
121            _ => todo!(),
122        };
123        res.push_str(kind_str);
124
125        match self.kind {
126            InteractionKind::Byte => {
127                assert_eq!(self.values.len(), 4);
128                for (idx, val) in self.values.iter().enumerate() {
129                    if idx == 0 {
130                        // ByteOpcode
131                        res.push_str(&format!(
132                            " (ByteOpcode.ofNat {})",
133                            val.to_lean_string(input_mapping)
134                        ));
135                    } else {
136                        res.push_str(&format!(" {}", val.to_lean_string(input_mapping)));
137                    }
138                }
139            }
140            InteractionKind::Memory => {
141                assert_eq!(self.values.len(), 9);
142                for val in &self.values {
143                    res.push_str(&format!(" {}", val.to_lean_string(input_mapping)));
144                }
145            }
146            InteractionKind::State => {
147                assert_eq!(self.values.len(), 5);
148                for val in &self.values {
149                    res.push_str(&format!(" {}", val.to_lean_string(input_mapping)));
150                }
151            }
152            InteractionKind::Program => {
153                assert_eq!(self.values.len(), 16);
154
155                for (idx, val) in self.values.iter().enumerate() {
156                    if idx == 3 {
157                        // Opcode
158                        res.push_str(&format!(
159                            " (Opcode.ofNat {})",
160                            val.to_lean_string(input_mapping)
161                        ));
162                    } else {
163                        res.push_str(&format!(" {}", val.to_lean_string(input_mapping)));
164                    }
165                }
166            }
167            _ => {
168                todo!();
169            }
170        }
171
172        res.push(')');
173        res
174    }
175}