use alloc::vec;
use alloc::vec::Vec;
use p3_field::{ExtensionField, Field, batch_multiplicative_inverse};
use p3_matrix::Matrix;
use p3_matrix::dense::DenseMatrix;
/// Generates the auxiliary LogUp trace for a permutation check between the last two columns
/// of `main_trace`.
///
/// Writing `a_i` and `b_i` for the second-to-last and last main-trace entries of row `i`, and
/// `r` for `randomness`, row `i` of the auxiliary trace holds three extension-field values:
///
/// - `t_i = 1 / (r - a_i)`
/// - `w_i = 1 / (r - b_i)`
/// - `s_i = s_{i-1} + t_i - w_i`, the running sum, starting from zero
///
/// Each value is stored as its `EF::DIMENSION` base-field basis coefficients, so the returned
/// matrix has the same height as `main_trace` and a width of `3 * EF::DIMENSION`. An empty
/// main trace yields an empty auxiliary trace.
///
/// # Panics
///
/// - If `main_trace` has fewer than two columns.
/// - In debug builds, if the last two columns are not permutations of each other.
/// - NOTE(review): presumably also when `randomness` equals an entry of either column, since
///   the corresponding denominator is then zero; confirm against
///   `batch_multiplicative_inverse`.
pub fn generate_logup_trace<EF, F>(main_trace: &DenseMatrix<F>, randomness: &EF) -> DenseMatrix<F>
where
    EF: ExtensionField<F>,
    F: Field,
{
    let len = main_trace.height();
    let width = main_trace.width();
    assert!(
        width >= 2,
        "Permutation check is not possible for main trace width ({width}) < 2"
    );

    let aux_width = 3 * EF::DIMENSION;

    // Nothing to accumulate for an empty trace; return early instead of indexing row 0.
    if len == 0 {
        return DenseMatrix::new(Vec::new(), aux_width);
    }

    // Copies one column of the main trace into an owned vector.
    let column = |col_idx: usize| -> Vec<F> {
        (0..len)
            .map(|row_idx| {
                main_trace
                    .get(row_idx, col_idx)
                    .expect("row and column indices are within the main trace bounds")
            })
            .collect()
    };
    let main_second_last_col = column(width - 2);
    let main_last_col = column(width - 1);

    #[cfg(debug_assertions)]
    {
        assert!(
            is_permutation(&main_second_last_col, &main_last_col),
            "The last two columns of the main trace must form a permutation"
        );
    }

    // Computes `1 / (r - x)` for every entry `x` of a column with one batched inversion.
    let inverse_denominators = |col: &[F]| -> Vec<EF> {
        let denominators: Vec<EF> = col.iter().map(|&x| *randomness - EF::from(x)).collect();
        batch_multiplicative_inverse(&denominators)
    };
    let aux_first_col = inverse_denominators(&main_second_last_col);
    let aux_second_col = inverse_denominators(&main_last_col);

    // Emit `[t_i, w_i, s_i]` per row, flattened to base-field coefficients as we go.
    let mut aux_trace_base_values: Vec<F> = Vec::with_capacity(len * aux_width);
    let mut running_sum = EF::ZERO;
    for (t, w) in aux_first_col.iter().zip(&aux_second_col) {
        running_sum = running_sum + *t - *w;
        for value in [t, w, &running_sum] {
            aux_trace_base_values.extend_from_slice(value.as_basis_coefficients_slice());
        }
    }

    DenseMatrix::new(aux_trace_base_values, aux_width)
}
/// Returns `true` when `col1` and `col2` hold the same elements with the same multiplicities,
/// in any order.
///
/// Every element of `col1` claims the first not-yet-claimed equal element of `col2`; the check
/// fails as soon as an element has nothing left to claim. Slices of different lengths are never
/// permutations of each other. Only compiled when debug assertions are enabled.
#[cfg(debug_assertions)]
fn is_permutation<F: Field>(col1: &[F], col2: &[F]) -> bool {
    if col1.len() != col2.len() {
        return false;
    }

    // `claimed[i]` records whether `col2[i]` has already been paired with an entry of `col1`.
    let mut claimed = vec![false; col2.len()];

    col1.iter().all(|needle| {
        let slot = col2
            .iter()
            .zip(claimed.iter())
            .position(|(candidate, &taken)| !taken && candidate == needle);

        match slot {
            Some(idx) => {
                claimed[idx] = true;
                true
            }
            None => false,
        }
    })
}
#[cfg(test)]
mod tests {
    use p3_field::PrimeCharacteristicRing;
    use p3_field::extension::BinomialExtensionField;
    use p3_goldilocks::Goldilocks;

    use super::*;

    type F = Goldilocks;

    /// Extension degree: each auxiliary value occupies `D` base-field columns.
    const D: usize = 2;
    type EF = BinomialExtensionField<F, D>;

    /// Builds a three-column main trace from rows of small integers.
    fn trace_from_rows(rows: &[[u64; 3]]) -> DenseMatrix<F> {
        let values: Vec<F> = rows.iter().flatten().copied().map(F::from_u64).collect();
        DenseMatrix::new(values, 3)
    }

    #[test]
    fn test_simple_permutation() {
        // Last two columns: [1, 2, 3, 4] and [4, 3, 2, 1].
        let main_trace = trace_from_rows(&[[0, 1, 4], [0, 2, 3], [0, 3, 2], [0, 4, 1]]);
        let randomness = EF::from_u64(100);

        let aux_trace = generate_logup_trace::<EF, _>(&main_trace, &randomness);

        assert_eq!(aux_trace.height(), 4);
        assert_eq!(aux_trace.width(), 3 * D);

        // The running sum lives in the third group of `D` columns; for a valid permutation
        // every one of its basis coefficients must vanish on the last row.
        for coeff in 0..D {
            let last_running_sum = aux_trace.get(3, 2 * D + coeff).unwrap();
            assert_eq!(last_running_sum, F::ZERO);
        }
    }

    #[test]
    fn test_running_sum_initialization() {
        // Last two columns: [5, 7, 8] and [8, 5, 7].
        let main_trace = trace_from_rows(&[[10, 5, 8], [20, 7, 5], [30, 8, 7]]);
        let randomness = EF::from_u64(50);

        let aux_trace = generate_logup_trace::<EF, _>(&main_trace, &randomness);

        assert_eq!(aux_trace.height(), 3);
        assert_eq!(aux_trace.width(), 3 * D);

        // On the first row the running sum equals `t_0 - w_0`, coefficient by coefficient.
        for coeff in 0..D {
            let t0 = aux_trace.get(0, coeff).unwrap();
            let w0 = aux_trace.get(0, D + coeff).unwrap();
            let running_sum_0 = aux_trace.get(0, 2 * D + coeff).unwrap();
            assert_eq!(running_sum_0, t0 - w0);
        }
    }

    #[test]
    #[cfg(debug_assertions)]
    fn test_is_permutation() {
        let a = F::from_u64(1);
        let b = F::from_u64(2);
        let c = F::from_u64(3);
        let d = F::from_u64(4);

        // Distinct elements in reversed order.
        let col1 = vec![a, b, c, d];
        let col2 = vec![d, c, b, a];
        assert!(is_permutation(&col1, &col2));

        // Repeated elements with matching multiplicities.
        let col3 = vec![a, a, b, c];
        let col4 = vec![c, a, b, a];
        assert!(is_permutation(&col3, &col4));

        // Same length but different multisets.
        let col5 = vec![a, b, c, d];
        let col6 = vec![a, a, b, c];
        assert!(!is_permutation(&col5, &col6));

        // Different lengths.
        let col7 = vec![a, b];
        let col8 = vec![a, b, c];
        assert!(!is_permutation(&col7, &col8));
    }

    #[cfg(debug_assertions)]
    #[test]
    #[should_panic(expected = "The last two columns of the main trace must form a permutation")]
    fn test_invalid_permutation() {
        // Last two columns: [1, 2, 3, 4] and [4, 5, 2, 1]; 5 has no counterpart.
        let main_trace = trace_from_rows(&[[0, 1, 4], [0, 2, 5], [0, 3, 2], [0, 4, 1]]);
        let randomness = EF::from_u64(100);

        generate_logup_trace::<EF, _>(&main_trace, &randomness);
    }
}