1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
use crate::ast::{AccessType, BusType, FunctionType, InvalidAccessError, TraceBinding, Type};
use std::fmt;
/// This type provides type and contextual information about a binding,
/// i.e. not only does it tell us the type of a binding, but what type
/// of value was bound. This is used during analysis to check whether a
/// particular access is valid for the context it is in, as well as to
/// propagate type information while retaining information about where
/// the type was derived from.
///
/// See [`BindingType::ty`] for the value type reported by each variant, and
/// [`BindingType::access`] for how indexing/slicing behaves per variant.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum BindingType {
/// A local variable whose value is not an alias of a global/module declaration
Local(Type),
/// A local variable that aliases a global/module declaration
///
/// Querying the type of, accessing, or displaying an alias delegates to the
/// boxed binding it wraps.
Alias(Box<BindingType>),
/// A direct reference to a constant declaration
Constant(Type),
/// A type associated with a function signature
///
/// The result type is None if the function is an evaluator
Function(FunctionType),
/// A binding to a bus definition
///
/// NOTE(review): `ty()` reports `Type::Felt` for buses and `access()` accepts any
/// access type unchanged — confirm both are intended.
Bus(BusType),
/// A function parameter corresponding to trace columns
TraceParam(TraceBinding),
/// A direct reference to one or more contiguous trace columns
///
/// This and `Vector` are the only variants `split_columns`/`pop_column` accept.
TraceColumn(TraceBinding),
/// A potentially non-contiguous set of trace columns
///
/// Its value type is `Type::Vector` of the element count, regardless of how
/// many columns each element covers.
Vector(Vec<BindingType>),
/// A direct reference to a public input
PublicInput(Type),
/// A direct reference to a periodic column
///
/// The `usize` payload is bound as `period` in `access`; presumably the length
/// of the column's repeating cycle — TODO confirm.
PeriodicColumn(usize),
}
impl BindingType {
    /// Get the value type of this binding, if applicable.
    ///
    /// Aliases report the type of the binding they alias; a vector of bindings
    /// reports `Type::Vector` of its element count; periodic columns and buses
    /// report `Type::Felt`; functions report their result type, which is `None`
    /// for evaluators.
    pub fn ty(&self) -> Option<Type> {
        match self {
            Self::TraceColumn(tb) | Self::TraceParam(tb) => Some(tb.ty()),
            Self::Vector(elems) => Some(Type::Vector(elems.len())),
            Self::Alias(aliased) => aliased.ty(),
            Self::Local(ty) | Self::Constant(ty) | Self::PublicInput(ty) => Some(*ty),
            Self::PeriodicColumn(_) => Some(Type::Felt),
            Self::Function(ty) => ty.result(),
            Self::Bus(_) => Some(Type::Felt),
        }
    }

    /// Returns true if this binding type is a trace binding.
    ///
    /// A vector is a trace binding only if every element is one, so an empty
    /// vector is (vacuously) considered a trace binding. Aliases are not
    /// looked through.
    pub fn is_trace_binding(&self) -> bool {
        match self {
            Self::TraceColumn(_) | Self::TraceParam(_) => true,
            Self::Vector(elems) => elems.iter().all(|e| e.is_trace_binding()),
            _ => false,
        }
    }

    /// This function is used to split the current binding into two parts, the
    /// first of which contains `n` trace columns, the second of which contains
    /// what remains of the original binding. This function returns `Ok` when
    /// there were `n` columns in the input binding type, otherwise `Err` with
    /// a binding that contains as many columns as possible.
    ///
    /// If the input binding type is a single logical binding, then the resulting
    /// binding types will be of the same type. If however, the input binding type
    /// is a vector of bindings, then the first part of the split will be a vector
    /// containing `n` elements, where each element is a single logical binding of
    /// size 1. This corresponds to the way trace column bindings are packed/unpacked
    /// using vectors/lists in AirScript
    ///
    /// # Errors
    ///
    /// Returns `Err` when fewer than `n` columns are available: the whole binding
    /// for a single `TraceColumn`, or a vector of every column that could be
    /// taken for a `Vector`.
    ///
    /// # Panics
    ///
    /// Panics if `self` (or any vector element visited) is not a `TraceColumn`
    /// or `Vector`, or if `n == 1` and `self` is an empty vector.
    pub fn split_columns(&self, n: usize) -> Result<(Self, Option<Self>), Self> {
        use core::cmp::Ordering;

        if n == 1 {
            return Ok(self.pop_column());
        }

        match self {
            Self::TraceColumn(tb) => match n.cmp(&tb.size) {
                Ordering::Equal => Ok((self.clone(), None)),
                Ordering::Less => {
                    let remaining = tb.size - n;
                    // Keep `ty` in sync with `size` (as `pop_column` does); otherwise
                    // both halves would still report the type of the unsplit binding.
                    let first = Self::TraceColumn(TraceBinding {
                        size: n,
                        ty: Type::Vector(n),
                        ..*tb
                    });
                    let rest = Self::TraceColumn(TraceBinding {
                        size: remaining,
                        offset: tb.offset + n,
                        ty: Type::Vector(remaining),
                        ..*tb
                    });
                    Ok((first, Some(rest)))
                }
                Ordering::Greater => Err(self.clone()),
            },
            Self::Vector(elems) if elems.len() == 1 => elems[0].split_columns(n),
            Self::Vector(elems) => {
                let mut index = 0;
                let mut remaining = n;
                let mut set = Vec::with_capacity(elems.len());
                // `next` is always either `elems[index]` itself or the not-yet-popped
                // remainder of `elems[index]`; `None` means the vector is exhausted.
                let mut next = elems.get(index).cloned();
                while remaining > 0 {
                    match next.take() {
                        None => return Err(Self::Vector(set)),
                        Some(binding_ty) => {
                            let (col, rest) = binding_ty.pop_column();
                            set.push(col);
                            remaining -= 1;
                            next = rest.or_else(|| {
                                index += 1;
                                elems.get(index).cloned()
                            });
                        }
                    }
                }
                match next {
                    // Every column was consumed exactly; here `index == elems.len()`,
                    // so no tail length may be computed from it.
                    None => Ok((Self::Vector(set), None)),
                    Some(mid) => {
                        // `mid` stands in for `elems[index]`, so the untouched tail
                        // starts right after it (and `index < elems.len()` here).
                        let tail = &elems[index + 1..];
                        let mut rest = Vec::with_capacity(tail.len() + 1);
                        rest.push(mid);
                        rest.extend_from_slice(tail);
                        Ok((Self::Vector(set), Some(Self::Vector(rest))))
                    }
                }
            }
            invalid => panic!("invalid trace column(s) binding type: {invalid:#?}"),
        }
    }

    /// This function is like `split_columns`, for the use case in which only a
    /// single column is desired. This is used internally by `split_columns` to
    /// handle those cases, but may be used directly as well.
    ///
    /// Returns the popped column and whatever remains, or `None` if nothing does.
    ///
    /// # Panics
    ///
    /// Panics if `self` (or the first vector element) is not a `TraceColumn` or
    /// `Vector`, or if `self` is an empty vector.
    pub fn pop_column(&self) -> (Self, Option<Self>) {
        match self {
            // If we have a single logical binding, return the first half as
            // a binding containing the first column of that binding, and the
            // second half as a binding representing whatever was left, or `None`
            // if it is empty.
            Self::TraceColumn(tb) if tb.is_scalar() => (Self::TraceColumn(*tb), None),
            Self::TraceColumn(tb) => {
                let first = Self::TraceColumn(TraceBinding {
                    size: 1,
                    ty: Type::Felt,
                    ..*tb
                });
                let remaining = tb.size - 1;
                if remaining == 0 {
                    (first, None)
                } else {
                    let rest = Self::TraceColumn(TraceBinding {
                        size: remaining,
                        ty: Type::Vector(remaining),
                        offset: tb.offset + 1,
                        ..*tb
                    });
                    (first, Some(rest))
                }
            }
            // If the vector has only one element, remove the vector and
            // return the result of popping a column on the first element.
            Self::Vector(elems) if elems.len() == 1 => elems[0].pop_column(),
            // If the vector has multiple elements, then we're going to return
            // a vector for the remainder of the split.
            Self::Vector(elems) => {
                // Take the first element out of the vector
                let (popped, rest) = elems
                    .split_first()
                    .expect("cannot pop a trace column from an empty binding vector");
                // Pop a single trace column from that element
                let (first, mid) = popped.pop_column();
                // The `popped` binding must have been a TraceColumn type, as
                // nested binding vectors are not permitted in calls to evaluators
                match mid {
                    None => (first, Some(Self::Vector(rest.to_vec()))),
                    Some(mid) => {
                        let mut mid_and_rest = Vec::with_capacity(rest.len() + 1);
                        mid_and_rest.push(mid);
                        mid_and_rest.extend_from_slice(rest);
                        (first, Some(Self::Vector(mid_and_rest)))
                    }
                }
            }
            invalid => panic!("invalid trace column(s) binding type: {invalid:#?}"),
        }
    }

    /// Produce a new [BindingType] which represents accessing the current binding via `access_type`.
    ///
    /// # Errors
    ///
    /// Returns `InvalidAccessError::IndexOutOfBounds` when an index, row, or slice
    /// does not fit a `Vector` (including inverted slice ranges),
    /// `InvalidAccessError::IndexIntoScalar` for any non-default access of a
    /// periodic column, `InvalidAccessError::InvalidBinding` for functions, and
    /// otherwise whatever the underlying type/trace binding access reports.
    pub fn access(&self, access_type: AccessType) -> Result<Self, InvalidAccessError> {
        match self {
            Self::Alias(aliased) => aliased.access(access_type),
            Self::Local(ty) => ty.access(access_type).map(Self::Local),
            // NOTE(review): unlike the other variants, accessing a constant wraps the
            // result in an `Alias` rather than returning a plain `Constant` — confirm
            // callers rely on this.
            Self::Constant(ty) => ty
                .access(access_type)
                .map(|t| Self::Alias(Box::new(Self::Constant(t)))),
            Self::TraceColumn(tb) => tb.access(access_type).map(Self::TraceColumn),
            Self::TraceParam(tb) => tb.access(access_type).map(Self::TraceParam),
            Self::Vector(elems) => match access_type {
                AccessType::Default => Ok(Self::Vector(elems.clone())),
                AccessType::Index(idx) => elems
                    .get(idx)
                    .cloned()
                    .ok_or(InvalidAccessError::IndexOutOfBounds),
                // `get` rejects both an out-of-range end and an inverted range
                // (start > end); either would panic with direct indexing.
                AccessType::Slice(range) => elems
                    .get(range.to_slice_range())
                    .map(|slice| Self::Vector(slice.to_vec()))
                    .ok_or(InvalidAccessError::IndexOutOfBounds),
                AccessType::Matrix(row, col) => elems
                    .get(row)
                    .ok_or(InvalidAccessError::IndexOutOfBounds)?
                    .access(AccessType::Index(col)),
            },
            Self::PublicInput(ty) => ty.access(access_type).map(Self::PublicInput),
            Self::PeriodicColumn(period) => match access_type {
                AccessType::Default => Ok(Self::PeriodicColumn(*period)),
                _ => Err(InvalidAccessError::IndexIntoScalar),
            },
            Self::Function(_) => Err(InvalidAccessError::InvalidBinding),
            // NOTE(review): every access type is accepted for a bus and yields the
            // bus unchanged — confirm that indexing a bus is meant to be a no-op.
            Self::Bus(bus) => Ok(Self::Bus(*bus)),
        }
    }
}
impl fmt::Display for BindingType {
    /// Renders a short, human-readable label for the kind of binding.
    ///
    /// Aliases are transparent: they render as whatever binding they wrap.
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        let label = match self {
            Self::Alias(inner) => return write!(f, "{inner}"),
            Self::Bus(_) => "bus",
            Self::Constant(_) => "constant",
            Self::Function(_) => "function",
            Self::Local(_) => "local",
            Self::PeriodicColumn(_) => "periodic column(s)",
            Self::PublicInput(_) => "public input(s)",
            Self::TraceColumn(_) | Self::TraceParam(_) => "trace column(s)",
            Self::Vector(_) => "vector",
        };
        f.write_str(label)
    }
}