1use alloc::vec;
2use alloc::vec::Vec;
3use core::ops::Mul;
4
5use p3_field::{Field, PrimeCharacteristicRing};
6
7#[derive(Clone, Debug)]
9pub struct VirtualPairCol<F: Field> {
10 column_weights: Vec<(PairCol, F)>,
11 constant: F,
12}
13
/// Identifies a single column in one of two slices: preprocessed or main.
#[derive(Debug, Clone, Copy)]
pub enum PairCol {
    /// Index into the preprocessed slice.
    Preprocessed(usize),
    /// Index into the main slice.
    Main(usize),
}

impl PairCol {
    /// Reads the entry this column refers to.
    ///
    /// A `Preprocessed(i)` column yields `preprocessed[i]`; a `Main(i)` column
    /// yields `main[i]`. The slice that is not selected is never accessed.
    ///
    /// # Panics
    ///
    /// Panics if the stored index is out of bounds for the selected slice.
    pub const fn get<T: Copy>(&self, preprocessed: &[T], main: &[T]) -> T {
        // Pick the backing slice first, then perform a single indexed read.
        let (values, index) = match *self {
            Self::Preprocessed(index) => (preprocessed, index),
            Self::Main(index) => (main, index),
        };
        values[index]
    }
}
29
30impl<F: Field> VirtualPairCol<F> {
31 pub const fn new(column_weights: Vec<(PairCol, F)>, constant: F) -> Self {
32 Self {
33 column_weights,
34 constant,
35 }
36 }
37
38 pub fn new_preprocessed(column_weights: Vec<(usize, F)>, constant: F) -> Self {
39 Self::new(
40 column_weights
41 .into_iter()
42 .map(|(i, w)| (PairCol::Preprocessed(i), w))
43 .collect(),
44 constant,
45 )
46 }
47
48 pub fn new_main(column_weights: Vec<(usize, F)>, constant: F) -> Self {
49 Self::new(
50 column_weights
51 .into_iter()
52 .map(|(i, w)| (PairCol::Main(i), w))
53 .collect(),
54 constant,
55 )
56 }
57
58 pub const ONE: Self = Self::constant(F::ONE);
59
60 #[must_use]
61 pub const fn constant(x: F) -> Self {
62 Self {
63 column_weights: vec![],
64 constant: x,
65 }
66 }
67
68 #[must_use]
69 pub fn single(column: PairCol) -> Self {
70 Self {
71 column_weights: vec![(column, F::ONE)],
72 constant: F::ZERO,
73 }
74 }
75
76 #[must_use]
77 pub fn single_preprocessed(column: usize) -> Self {
78 Self::single(PairCol::Preprocessed(column))
79 }
80
81 #[must_use]
82 pub fn single_main(column: usize) -> Self {
83 Self::single(PairCol::Main(column))
84 }
85
86 #[must_use]
87 pub fn sum_main(columns: Vec<usize>) -> Self {
88 let column_weights = columns.into_iter().map(|col| (col, F::ONE)).collect();
89 Self::new_main(column_weights, F::ZERO)
90 }
91
92 #[must_use]
93 pub fn sum_preprocessed(columns: Vec<usize>) -> Self {
94 let column_weights = columns.into_iter().map(|col| (col, F::ONE)).collect();
95 Self::new_preprocessed(column_weights, F::ZERO)
96 }
97
98 #[must_use]
100 pub fn diff_preprocessed(a_col: usize, b_col: usize) -> Self {
101 Self::new_preprocessed(vec![(a_col, F::ONE), (b_col, F::NEG_ONE)], F::ZERO)
102 }
103
104 #[must_use]
106 pub fn diff_main(a_col: usize, b_col: usize) -> Self {
107 Self::new_main(vec![(a_col, F::ONE), (b_col, F::NEG_ONE)], F::ZERO)
108 }
109
110 pub fn apply<Expr, Var>(&self, preprocessed: &[Var], main: &[Var]) -> Expr
111 where
112 F: Into<Expr>,
113 Expr: PrimeCharacteristicRing + Mul<F, Output = Expr>,
114 Var: Into<Expr> + Copy,
115 {
116 self.column_weights
117 .iter()
118 .fold(self.constant.into(), |acc, &(col, w)| {
119 acc + col.get(preprocessed, main).into() * w
120 })
121 }
122}
123
#[cfg(test)]
mod tests {
    use p3_baby_bear::BabyBear;

    use super::*;

    type F = BabyBear;

    /// Shorthand for lifting a small integer into the test field.
    fn fe(value: u8) -> F {
        F::from_u8(value)
    }

    #[test]
    fn test_pair_col_get_main_and_preprocessed() {
        let pre = [fe(10), fe(20)];
        let main = [fe(30), fe(40)];

        // Each variant must read from its own slice.
        let from_pre = PairCol::Preprocessed(1).get(&pre, &main);
        assert_eq!(from_pre, fe(20));

        let from_main = PairCol::Main(0).get(&pre, &main);
        assert_eq!(from_main, fe(30));
    }

    #[test]
    fn test_constant_only_virtual_pair_col() {
        let virtual_col = VirtualPairCol::<F>::constant(fe(7));

        // The slice contents are irrelevant: no column is referenced.
        let pre = [F::ONE];
        let main = [F::ONE];

        assert_eq!(virtual_col.apply::<F, F>(&pre, &main), fe(7));
    }

    #[test]
    fn test_single_main_column() {
        let virtual_col = VirtualPairCol::<F>::single_main(1);

        let main = [fe(9), fe(5)];
        let pre = [F::ZERO];

        // Selects main[1].
        assert_eq!(virtual_col.apply::<F, F>(&pre, &main), fe(5));
    }

    #[test]
    fn test_single_preprocessed_column() {
        let virtual_col = VirtualPairCol::<F>::single_preprocessed(0);

        let pre = [fe(12)];
        let main: [F; 0] = [];

        // Selects pre[0].
        assert_eq!(virtual_col.apply::<F, F>(&pre, &main), fe(12));
    }

    #[test]
    fn test_sum_main_columns() {
        let virtual_col = VirtualPairCol::<F>::sum_main(vec![0, 2]);

        // Index 1 holds a decoy value that must not contribute.
        let main = [F::TWO, fe(99), fe(5)];
        let pre: [F; 0] = [];

        let expected = fe(2) + fe(5);
        assert_eq!(virtual_col.apply::<F, F>(&pre, &main), expected);
    }

    #[test]
    fn test_sum_preprocessed_columns() {
        let virtual_col = VirtualPairCol::<F>::sum_preprocessed(vec![1, 2]);

        // Index 0 is not part of the sum.
        let pre = [fe(3), fe(4), fe(6)];
        let main: [F; 0] = [];

        let expected = fe(4) + fe(6);
        assert_eq!(virtual_col.apply::<F, F>(&pre, &main), expected);
    }

    #[test]
    fn test_diff_main_columns() {
        let virtual_col = VirtualPairCol::<F>::diff_main(2, 0);

        let main = [fe(7), F::ZERO, fe(10)];
        let pre: [F; 0] = [];

        // main[2] - main[0]
        let expected = fe(10) - fe(7);
        assert_eq!(virtual_col.apply::<F, F>(&pre, &main), expected);
    }

    #[test]
    fn test_diff_preprocessed_columns() {
        let virtual_col = VirtualPairCol::<F>::diff_preprocessed(1, 0);

        let pre = [fe(4), fe(15)];
        let main: [F; 0] = [];

        // pre[1] - pre[0]
        let expected = fe(15) - fe(4);
        assert_eq!(virtual_col.apply::<F, F>(&pre, &main), expected);
    }

    #[test]
    fn test_combination_with_constant_and_weights() {
        let terms = vec![
            (PairCol::Main(1), fe(3)),
            (PairCol::Preprocessed(0), F::TWO),
        ];
        let virtual_col = VirtualPairCol::new(terms, fe(5));

        let main = [F::ZERO, fe(4)];
        let pre = [fe(6)];

        // 5 + 3 * main[1] + 2 * pre[0] = 5 + 12 + 12 = 29
        assert_eq!(virtual_col.apply::<F, F>(&pre, &main), fe(29));
    }

    #[test]
    fn test_virtual_pair_col_one_is_identity() {
        let virtual_col = VirtualPairCol::<F>::ONE;

        // Arbitrary slice contents: `ONE` references no columns.
        let pre = [fe(99)];
        let main = [fe(42)];

        assert_eq!(virtual_col.apply::<F, F>(&pre, &main), F::ONE);
    }
}