p3_air/
virtual_column.rs

1use alloc::vec;
2use alloc::vec::Vec;
3use core::ops::Mul;
4
5use p3_field::{Field, PrimeCharacteristicRing};
6
7/// An affine function over columns in a PAIR.
8#[derive(Clone, Debug)]
9pub struct VirtualPairCol<F: Field> {
10    column_weights: Vec<(PairCol, F)>,
11    constant: F,
12}
13
/// A column in a PAIR: either a column of the preprocessed trace or a column of the main trace.
///
/// The wrapped `usize` is the column's index within its own trace (see `PairCol::get`).
///
/// `PartialEq`, `Eq` and `Hash` are derived so that columns can be compared and used as
/// keys in maps/sets; equality requires both the same variant and the same index.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum PairCol {
    /// Index of a column in the preprocessed trace.
    Preprocessed(usize),
    /// Index of a column in the main trace.
    Main(usize),
}
20
21impl PairCol {
22    pub const fn get<T: Copy>(&self, preprocessed: &[T], main: &[T]) -> T {
23        match self {
24            Self::Preprocessed(i) => preprocessed[*i],
25            Self::Main(i) => main[*i],
26        }
27    }
28}
29
30impl<F: Field> VirtualPairCol<F> {
31    pub const fn new(column_weights: Vec<(PairCol, F)>, constant: F) -> Self {
32        Self {
33            column_weights,
34            constant,
35        }
36    }
37
38    pub fn new_preprocessed(column_weights: Vec<(usize, F)>, constant: F) -> Self {
39        Self::new(
40            column_weights
41                .into_iter()
42                .map(|(i, w)| (PairCol::Preprocessed(i), w))
43                .collect(),
44            constant,
45        )
46    }
47
48    pub fn new_main(column_weights: Vec<(usize, F)>, constant: F) -> Self {
49        Self::new(
50            column_weights
51                .into_iter()
52                .map(|(i, w)| (PairCol::Main(i), w))
53                .collect(),
54            constant,
55        )
56    }
57
58    pub const ONE: Self = Self::constant(F::ONE);
59
60    #[must_use]
61    pub const fn constant(x: F) -> Self {
62        Self {
63            column_weights: vec![],
64            constant: x,
65        }
66    }
67
68    #[must_use]
69    pub fn single(column: PairCol) -> Self {
70        Self {
71            column_weights: vec![(column, F::ONE)],
72            constant: F::ZERO,
73        }
74    }
75
76    #[must_use]
77    pub fn single_preprocessed(column: usize) -> Self {
78        Self::single(PairCol::Preprocessed(column))
79    }
80
81    #[must_use]
82    pub fn single_main(column: usize) -> Self {
83        Self::single(PairCol::Main(column))
84    }
85
86    #[must_use]
87    pub fn sum_main(columns: Vec<usize>) -> Self {
88        let column_weights = columns.into_iter().map(|col| (col, F::ONE)).collect();
89        Self::new_main(column_weights, F::ZERO)
90    }
91
92    #[must_use]
93    pub fn sum_preprocessed(columns: Vec<usize>) -> Self {
94        let column_weights = columns.into_iter().map(|col| (col, F::ONE)).collect();
95        Self::new_preprocessed(column_weights, F::ZERO)
96    }
97
98    /// `a - b`, where `a` and `b` are columns in the preprocessed trace.
99    #[must_use]
100    pub fn diff_preprocessed(a_col: usize, b_col: usize) -> Self {
101        Self::new_preprocessed(vec![(a_col, F::ONE), (b_col, F::NEG_ONE)], F::ZERO)
102    }
103
104    /// `a - b`, where `a` and `b` are columns in the main trace.
105    #[must_use]
106    pub fn diff_main(a_col: usize, b_col: usize) -> Self {
107        Self::new_main(vec![(a_col, F::ONE), (b_col, F::NEG_ONE)], F::ZERO)
108    }
109
110    pub fn apply<Expr, Var>(&self, preprocessed: &[Var], main: &[Var]) -> Expr
111    where
112        F: Into<Expr>,
113        Expr: PrimeCharacteristicRing + Mul<F, Output = Expr>,
114        Var: Into<Expr> + Copy,
115    {
116        self.column_weights
117            .iter()
118            .fold(self.constant.into(), |acc, &(col, w)| {
119                acc + col.get(preprocessed, main).into() * w
120            })
121    }
122}
123
#[cfg(test)]
mod tests {
    use p3_baby_bear::BabyBear;

    use super::*;

    type F = BabyBear;

    #[test]
    fn test_pair_col_get_main_and_preprocessed() {
        let pre = [F::from_u8(10), F::from_u8(20)];
        let main = [F::from_u8(30), F::from_u8(40)];

        // Preprocessed(1) should return 20
        assert_eq!(PairCol::Preprocessed(1).get(&pre, &main), F::from_u8(20));

        // Main(0) should return 30
        assert_eq!(PairCol::Main(0).get(&pre, &main), F::from_u8(30));
    }

    #[test]
    fn test_constant_only_virtual_pair_col() {
        let col = VirtualPairCol::<F>::constant(F::from_u8(7));

        // Apply to any input: result should always be the constant
        let pre = [F::ONE];
        let main = [F::ONE];
        let result = col.apply::<F, F>(&pre, &main);

        assert_eq!(result, F::from_u8(7));
    }

    #[test]
    fn test_single_main_column() {
        let col = VirtualPairCol::<F>::single_main(1); // column index 1

        let main = [F::from_u8(9), F::from_u8(5)];
        let pre = [F::ZERO]; // ignored

        let result = col.apply::<F, F>(&pre, &main);

        // Since we used single_main(1), this should equal main[1] = 5
        assert_eq!(result, F::from_u8(5));
    }

    #[test]
    fn test_single_preprocessed_column() {
        let col = VirtualPairCol::<F>::single_preprocessed(0);

        let pre = [F::from_u8(12)];
        let main = [];

        let result = col.apply::<F, F>(&pre, &main);

        assert_eq!(result, F::from_u8(12));
    }

    #[test]
    fn test_sum_main_columns() {
        // This adds up main[0] + main[2]
        let col = VirtualPairCol::<F>::sum_main(vec![0, 2]);

        let main = [
            F::TWO,
            F::from_u8(99), // ignored
            F::from_u8(5),
        ];
        let pre = [];

        let result = col.apply::<F, F>(&pre, &main);

        assert_eq!(result, F::from_u8(2) + F::from_u8(5));
    }

    #[test]
    fn test_sum_preprocessed_columns() {
        let col = VirtualPairCol::<F>::sum_preprocessed(vec![1, 2]);

        let pre = [
            F::from_u8(3), // ignored
            F::from_u8(4),
            F::from_u8(6),
        ];
        let main = [];

        let result = col.apply::<F, F>(&pre, &main);

        assert_eq!(result, F::from_u8(4) + F::from_u8(6));
    }

    #[test]
    fn test_diff_main_columns() {
        // Computes main[2] - main[0]
        let col = VirtualPairCol::<F>::diff_main(2, 0);

        let main = [
            F::from_u8(7),
            F::ZERO, // ignored
            F::from_u8(10),
        ];
        let pre = [];

        let result = col.apply::<F, F>(&pre, &main);

        assert_eq!(result, F::from_u8(10) - F::from_u8(7));
    }

    #[test]
    fn test_diff_preprocessed_columns() {
        // Computes pre[1] - pre[0]
        let col = VirtualPairCol::<F>::diff_preprocessed(1, 0);

        let pre = [F::from_u8(4), F::from_u8(15)];
        let main = [];

        let result = col.apply::<F, F>(&pre, &main);

        assert_eq!(result, F::from_u8(15) - F::from_u8(4));
    }

    #[test]
    fn test_combination_with_constant_and_weights() {
        // Computes: 3 * main[1] + 2 * pre[0] + constant (5)
        let col = VirtualPairCol {
            column_weights: vec![
                (PairCol::Main(1), F::from_u8(3)),
                (PairCol::Preprocessed(0), F::TWO),
            ],
            constant: F::from_u8(5),
        };

        let main = [F::ZERO, F::from_u8(4)];
        let pre = [F::from_u8(6)];

        let result = col.apply::<F, F>(&pre, &main);

        // result = 3*4 + 2*6 + 5
        assert_eq!(result, F::from_u8(29));
    }

    #[test]
    fn test_virtual_pair_col_one_evaluates_to_one() {
        // VirtualPairCol::ONE is the constant column 1: it ignores every input value.
        let col = VirtualPairCol::<F>::ONE;
        let pre = [F::from_u8(99)];
        let main = [F::from_u8(42)];

        let result = col.apply::<F, F>(&pre, &main);

        assert_eq!(result, F::ONE);
    }

    #[test]
    fn test_new_main_applies_weights_and_constant() {
        // Computes: 2 * main[0] + 3 * main[1] + 1
        let col = VirtualPairCol::<F>::new_main(vec![(0, F::TWO), (1, F::from_u8(3))], F::ONE);

        let main = [F::from_u8(4), F::from_u8(5)];
        let pre: [F; 0] = [];

        // 2*4 + 3*5 + 1 = 24
        assert_eq!(col.apply::<F, F>(&pre, &main), F::from_u8(24));
    }

    #[test]
    fn test_new_preprocessed_applies_weights_and_constant() {
        // Computes: 5 * pre[1] + 2
        let col = VirtualPairCol::<F>::new_preprocessed(vec![(1, F::from_u8(5))], F::TWO);

        let pre = [F::from_u8(7), F::from_u8(3)];
        let main: [F; 0] = [];

        // 5*3 + 2 = 17
        assert_eq!(col.apply::<F, F>(&pre, &main), F::from_u8(17));
    }

    #[test]
    fn test_new_mixes_main_and_preprocessed() {
        // Computes: pre[0] - main[0]
        let col = VirtualPairCol::<F>::new(
            vec![
                (PairCol::Main(0), F::NEG_ONE),
                (PairCol::Preprocessed(0), F::ONE),
            ],
            F::ZERO,
        );

        let pre = [F::from_u8(10)];
        let main = [F::from_u8(3)];

        assert_eq!(col.apply::<F, F>(&pre, &main), F::from_u8(7));
    }

    #[test]
    fn test_sum_main_empty_is_zero() {
        // No columns and a zero offset: the sum is the additive identity.
        let col = VirtualPairCol::<F>::sum_main(vec![]);

        let main = [F::from_u8(9)];
        let pre: [F; 0] = [];

        assert_eq!(col.apply::<F, F>(&pre, &main), F::ZERO);
    }

    #[test]
    fn test_sum_main_repeated_column_counts_twice() {
        // Each occurrence of an index is a separate unit-weight term.
        let col = VirtualPairCol::<F>::sum_main(vec![1, 1]);

        let main = [F::from_u8(9), F::from_u8(4)];
        let pre: [F; 0] = [];

        assert_eq!(col.apply::<F, F>(&pre, &main), F::from_u8(8));
    }

    #[test]
    #[should_panic]
    fn test_apply_panics_on_out_of_range_column() {
        // Column 3 does not exist in a one-column main trace.
        let col = VirtualPairCol::<F>::single_main(3);

        let main = [F::ONE];
        let pre: [F; 0] = [];

        let _ = col.apply::<F, F>(&pre, &main);
    }
}