Skip to main content

machine_cat/
trace.rs

1//! Execution traces: 2D tables of field elements.
2//!
3//! A [`Trace<F>`] is the witness for an AIR.  It has a fixed
4//! number of columns (the AIR's shape) and a variable number
5//! of rows (the computation length).  Stored in row-major
6//! flat layout for cache locality.
7
8use crate::column::{Column, ColumnCount, ColumnRef};
9use crate::error::Error;
10use field_cat::Field;
11
/// A row count newtype.
///
/// Keeps row counts distinct from column counts and raw indices in
/// signatures, while formatting and comparing exactly like the wrapped
/// `usize`.
///
/// # Examples
///
/// ```
/// use machine_cat::RowCount;
///
/// let rc = RowCount::new(8);
/// assert_eq!(rc.count(), 8);
/// assert_eq!(format!("{rc:>3}"), "  8");
/// ```
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct RowCount(usize);

impl RowCount {
    /// Create a new row count.
    ///
    /// `const` so row counts can be used in constant contexts.
    #[must_use]
    pub const fn new(n: usize) -> Self {
        Self(n)
    }

    /// The underlying count.
    #[must_use]
    pub const fn count(self) -> usize {
        self.0
    }
}

impl core::fmt::Display for RowCount {
    /// Formats the count exactly like the underlying `usize`.
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        // Delegate instead of `write!(f, "{}", self.0)`: delegation forwards
        // the caller's width/fill/alignment flags, so `{:>4}` pads correctly.
        core::fmt::Display::fmt(&self.0, f)
    }
}
44
45/// An execution trace: a 2D table of field elements.
46///
47/// Stored in row-major order: `data[row * column_count + col]`.
48/// Each row has exactly `column_count` elements.
49///
50/// # Examples
51///
52/// ```
53/// use field_cat::F101;
54/// use machine_cat::{Column, ColumnCount, Trace};
55///
56/// let trace = Trace::from_rows(
57///     ColumnCount::new(2),
58///     &[
59///         vec![F101::new(1), F101::new(1)],
60///         vec![F101::new(1), F101::new(2)],
61///         vec![F101::new(2), F101::new(3)],
62///         vec![F101::new(3), F101::new(5)],
63///     ],
64/// )?;
65///
66/// assert_eq!(trace.get(2, Column::new(0))?, F101::new(2));
67/// assert_eq!(trace.get(2, Column::new(1))?, F101::new(3));
68/// # Ok::<(), machine_cat::Error>(())
69/// ```
70#[derive(Debug, Clone)]
71pub struct Trace<F: Field> {
72    data: Vec<F>,
73    column_count: ColumnCount,
74    row_count: RowCount,
75}
76
77impl<F: Field> Trace<F> {
78    /// Construct a trace from a vector of rows.
79    ///
80    /// Each inner vector must have exactly `column_count` elements.
81    ///
82    /// # Errors
83    ///
84    /// Returns [`Error::EmptyTrace`] if `rows` is empty, or
85    /// [`Error::RowLengthMismatch`] if any row has the wrong length.
86    pub fn from_rows(column_count: ColumnCount, rows: &[Vec<F>]) -> Result<Self, Error> {
87        if rows.is_empty() {
88            Err(Error::EmptyTrace)
89        } else {
90            // Validate all row lengths, then flatten.
91            let data: Result<Vec<F>, Error> = rows.iter().enumerate().try_fold(
92                Vec::with_capacity(rows.len() * column_count.count()),
93                |acc, (i, row)| {
94                    if row.len() == column_count.count() {
95                        Ok(acc.into_iter().chain(row.iter().cloned()).collect())
96                    } else {
97                        Err(Error::RowLengthMismatch {
98                            row: i,
99                            expected: column_count.count(),
100                            actual: row.len(),
101                        })
102                    }
103                },
104            );
105            Ok(Self {
106                data: data?,
107                column_count,
108                row_count: RowCount::new(rows.len()),
109            })
110        }
111    }
112
113    /// The number of columns.
114    #[must_use]
115    pub fn column_count(&self) -> ColumnCount {
116        self.column_count
117    }
118
119    /// The number of rows.
120    #[must_use]
121    pub fn row_count(&self) -> RowCount {
122        self.row_count
123    }
124
125    /// Get the value at `(row, col)`.
126    ///
127    /// # Errors
128    ///
129    /// Returns [`Error::ColumnOutOfBounds`] if the row or column
130    /// is out of range.
131    pub fn get(&self, row: usize, col: Column) -> Result<F, Error> {
132        if row >= self.row_count.count() || col.index() >= self.column_count.count() {
133            Err(Error::ColumnOutOfBounds {
134                index: col.index(),
135                column_count: self.column_count.count(),
136            })
137        } else {
138            Ok(self.data[row * self.column_count.count() + col.index()].clone())
139        }
140    }
141
142    /// Build an assignment function for a row pair `(row, row+1)`.
143    ///
144    /// The returned closure maps [`ColumnRef::Current`] to values
145    /// in `row` and [`ColumnRef::Next`] to values in `row + 1`.
146    ///
147    /// # Errors
148    ///
149    /// Returns [`Error::InsufficientRows`] if `row + 1 >= row_count`.
150    pub fn row_pair_assignment(
151        &self,
152        row: usize,
153    ) -> Result<impl Fn(ColumnRef) -> Result<F, Error> + '_, Error> {
154        if row + 1 >= self.row_count.count() {
155            Err(Error::InsufficientRows {
156                row_count: self.row_count.count(),
157            })
158        } else {
159            Ok(move |cr: ColumnRef| {
160                let (r, c) = match cr {
161                    ColumnRef::Current(col) => (row, col),
162                    ColumnRef::Next(col) => (row + 1, col),
163                };
164                self.get(r, c)
165            })
166        }
167    }
168
169    /// Extract all values from a single column.
170    ///
171    /// Returns a vector of length `row_count`.
172    ///
173    /// # Errors
174    ///
175    /// Returns [`Error::ColumnOutOfBounds`] if `col` is out of range.
176    pub fn column_values(&self, col: Column) -> Result<Vec<F>, Error> {
177        if col.index() >= self.column_count.count() {
178            Err(Error::ColumnOutOfBounds {
179                index: col.index(),
180                column_count: self.column_count.count(),
181            })
182        } else {
183            Ok((0..self.row_count.count())
184                .map(|r| self.data[r * self.column_count.count() + col.index()].clone())
185                .collect())
186        }
187    }
188
189    /// The raw data as a flat slice (row-major order).
190    #[must_use]
191    pub fn data(&self) -> &[F] {
192        &self.data
193    }
194}
195
#[cfg(test)]
mod tests {
    use super::*;
    use field_cat::F101;

    /// Four rows of a Fibonacci trace: each row is `[a, b]`, the next `[b, a + b]`.
    fn fib_trace() -> Result<Trace<F101>, Error> {
        Trace::from_rows(
            ColumnCount::new(2),
            &[
                vec![F101::new(1), F101::new(1)],
                vec![F101::new(1), F101::new(2)],
                vec![F101::new(2), F101::new(3)],
                vec![F101::new(3), F101::new(5)],
            ],
        )
    }

    #[test]
    fn from_rows_and_get() -> Result<(), Error> {
        let t = fib_trace()?;
        assert_eq!(t.row_count(), RowCount::new(4));
        assert_eq!(t.column_count(), ColumnCount::new(2));
        assert_eq!(t.get(0, Column::new(0))?, F101::new(1));
        assert_eq!(t.get(3, Column::new(1))?, F101::new(5));
        Ok(())
    }

    #[test]
    fn data_is_row_major() -> Result<(), Error> {
        let t = fib_trace()?;
        let expected = vec![
            F101::new(1),
            F101::new(1),
            F101::new(1),
            F101::new(2),
            F101::new(2),
            F101::new(3),
            F101::new(3),
            F101::new(5),
        ];
        assert_eq!(t.data(), expected.as_slice());
        Ok(())
    }

    #[test]
    fn empty_rows_fails() {
        let result = Trace::<F101>::from_rows(ColumnCount::new(2), &[]);
        assert!(matches!(result, Err(Error::EmptyTrace)));
    }

    #[test]
    fn wrong_row_length_fails() {
        let result = Trace::from_rows(
            ColumnCount::new(2),
            &[vec![F101::new(1), F101::new(2)], vec![F101::new(3)]],
        );
        // The error must identify the offending row and both lengths.
        assert!(matches!(
            result,
            Err(Error::RowLengthMismatch {
                row: 1,
                expected: 2,
                actual: 1
            })
        ));
    }

    #[test]
    fn get_out_of_bounds_fails() -> Result<(), Error> {
        let t = fib_trace()?;
        assert!(matches!(
            t.get(0, Column::new(2)),
            Err(Error::ColumnOutOfBounds {
                index: 2,
                column_count: 2
            })
        ));
        // Row past the end: only require an error, not a specific variant.
        assert!(t.get(4, Column::new(0)).is_err());
        Ok(())
    }

    #[test]
    fn column_values_extraction() -> Result<(), Error> {
        let t = fib_trace()?;
        assert_eq!(
            t.column_values(Column::new(0))?,
            vec![F101::new(1), F101::new(1), F101::new(2), F101::new(3)]
        );
        assert_eq!(
            t.column_values(Column::new(1))?,
            vec![F101::new(1), F101::new(2), F101::new(3), F101::new(5)]
        );
        Ok(())
    }

    #[test]
    fn column_values_out_of_bounds_fails() -> Result<(), Error> {
        let t = fib_trace()?;
        assert!(matches!(
            t.column_values(Column::new(2)),
            Err(Error::ColumnOutOfBounds {
                index: 2,
                column_count: 2
            })
        ));
        Ok(())
    }

    #[test]
    fn row_pair_assignment_works() -> Result<(), Error> {
        let t = fib_trace()?;
        let assign = t.row_pair_assignment(1)?;
        // Current row 1: [1, 2], Next row 2: [2, 3]
        assert_eq!(assign(ColumnRef::Current(Column::new(0)))?, F101::new(1));
        assert_eq!(assign(ColumnRef::Current(Column::new(1)))?, F101::new(2));
        assert_eq!(assign(ColumnRef::Next(Column::new(0)))?, F101::new(2));
        assert_eq!(assign(ColumnRef::Next(Column::new(1)))?, F101::new(3));
        // Column lookups through the closure are still bounds-checked.
        assert!(assign(ColumnRef::Next(Column::new(2))).is_err());
        Ok(())
    }

    #[test]
    fn row_pair_at_last_row_fails() -> Result<(), Error> {
        let t = fib_trace()?;
        // Row 3 is the last row; no "next" row exists.
        assert!(matches!(
            t.row_pair_assignment(3),
            Err(Error::InsufficientRows { row_count: 4 })
        ));
        // Rows well past the end are rejected the same way.
        assert!(matches!(
            t.row_pair_assignment(10),
            Err(Error::InsufficientRows { row_count: 4 })
        ));
        Ok(())
    }

    #[test]
    fn row_pair_on_single_row_trace_fails() -> Result<(), Error> {
        let t = Trace::from_rows(ColumnCount::new(1), &[vec![F101::new(7)]])?;
        assert!(matches!(
            t.row_pair_assignment(0),
            Err(Error::InsufficientRows { row_count: 1 })
        ));
        Ok(())
    }
}