1use std::collections::HashMap;
2
3use itertools::Itertools;
4use slop_algebra::{ExtensionField, Field};
5
6use crate::{
7 air::AirInteraction,
8 ir::{ExprExtRef, ExprRef, IrVar, Shape},
9 InteractionKind,
10};
11
12impl<F: Field, EF: ExtensionField<F>> Shape<ExprRef<F>, ExprExtRef<EF>> {
14 pub fn to_lean_constructor(&self, mapping: &HashMap<usize, String>) -> String {
16 match self {
17 Shape::Unit => unimplemented!("Unit shouldn't appear in constructors"),
18 Shape::Expr(expr) => expr.to_lean_string(mapping),
19 Shape::ExprExt(_) => todo!(),
20 Shape::Word(word) => {
21 format!("#v[{}]", word.iter().map(|x| x.to_lean_string(mapping)).join(", "))
22 }
23 Shape::Array(vals) => {
24 format!("#v[{}]", vals.iter().map(|x| x.to_lean_constructor(mapping)).join(", "))
25 }
26 Shape::Struct(_, fields) => {
27 format!(
28 "{{ {} }}",
29 fields
30 .iter()
31 .map(|(field_name, field_val)| format!(
32 "{field_name} := {}",
33 field_val.to_lean_constructor(mapping)
34 ))
35 .join(", ")
36 )
37 }
38 }
39 }
40
41 pub fn to_lean_destructor(&self) -> String {
43 match self {
44 Shape::Unit => unimplemented!("Unit shouldn't appear in destructors"),
45 Shape::Expr(expr) => expr.to_lean_string(&HashMap::default()),
46 Shape::ExprExt(_) => todo!(),
47 Shape::Word(word) => format!(
48 "⟨⟨[{}]⟩, _⟩",
49 word.iter().map(|x| x.to_lean_string(&HashMap::default())).join(", ")
50 ),
51 Shape::Array(vals) => {
52 format!("⟨⟨[{}]⟩, _⟩", vals.iter().map(|x| x.to_lean_destructor()).join(", "))
53 }
54 Shape::Struct(_, _) => todo!("why would you need to destruct a struct"),
55 }
56 }
57
58 pub fn map_input(&self, prefix: String, input_mapping: &mut HashMap<usize, String>) {
76 match self {
77 Shape::Unit => unimplemented!("Unit shouldn't appear as input"),
78 Shape::Expr(ExprRef::IrVar(IrVar::InputArg(idx))) => {
79 input_mapping.insert(*idx, prefix);
80 }
81 Shape::Word(vals) => {
82 for (i, val) in vals.iter().enumerate() {
83 match val {
84 ExprRef::IrVar(IrVar::InputArg(idx)) => {
85 if prefix == "c" {
87 input_mapping.insert(*idx, format!("cc[{i}]"));
88 } else {
89 input_mapping.insert(*idx, format!("{prefix}[{i}]"));
90 }
91 }
92 _ => unimplemented!("map_input must be backed by Input(x)"),
93 }
94 }
95 }
96 Shape::Array(vals) => {
97 for (i, val) in vals.iter().enumerate() {
98 val.map_input(format!("{prefix}[{i}]"), input_mapping);
99 }
100 }
101 Shape::Struct(_, fields) => {
102 for (name, field) in fields {
103 field.map_input(format!("{prefix}.{name}"), input_mapping);
104 }
105 }
106 _ => unimplemented!(),
107 }
108 }
109}
110
111impl<F: Field> AirInteraction<ExprRef<F>> {
112 pub fn to_lean_string(&self, input_mapping: &HashMap<usize, String>) -> String {
114 let mut res = "(".to_string();
115
116 let kind_str = match self.kind {
117 InteractionKind::Memory => ".memory",
118 InteractionKind::Program => ".program",
119 InteractionKind::Byte => ".byte",
120 InteractionKind::State => ".state",
121 _ => todo!(),
122 };
123 res.push_str(kind_str);
124
125 match self.kind {
126 InteractionKind::Byte => {
127 assert_eq!(self.values.len(), 4);
128 for (idx, val) in self.values.iter().enumerate() {
129 if idx == 0 {
130 res.push_str(&format!(
132 " (ByteOpcode.ofNat {})",
133 val.to_lean_string(input_mapping)
134 ));
135 } else {
136 res.push_str(&format!(" {}", val.to_lean_string(input_mapping)));
137 }
138 }
139 }
140 InteractionKind::Memory => {
141 assert_eq!(self.values.len(), 9);
142 for val in &self.values {
143 res.push_str(&format!(" {}", val.to_lean_string(input_mapping)));
144 }
145 }
146 InteractionKind::State => {
147 assert_eq!(self.values.len(), 5);
148 for val in &self.values {
149 res.push_str(&format!(" {}", val.to_lean_string(input_mapping)));
150 }
151 }
152 InteractionKind::Program => {
153 assert_eq!(self.values.len(), 16);
154
155 for (idx, val) in self.values.iter().enumerate() {
156 if idx == 3 {
157 res.push_str(&format!(
159 " (Opcode.ofNat {})",
160 val.to_lean_string(input_mapping)
161 ));
162 } else {
163 res.push_str(&format!(" {}", val.to_lean_string(input_mapping)));
164 }
165 }
166 }
167 _ => {
168 todo!();
169 }
170 }
171
172 res.push(')');
173 res
174 }
175}