use std::ops::{BitAnd, BitOrAssign, Shl, ShrAssign};
#[cfg(feature = "prover")]
use std::simd::u32x16;
use stwo::core::fields::m31::M31;
#[cfg(feature = "prover")]
use stwo::prover::backend::simd::m31::{PackedM31, N_LANES};
use super::cpu::{FELT252_BITS_PER_WORD, FELT252_N_WORDS};
/// Splits the bits of `x` into `M` limbs of `FELT252_BITS_PER_WORD` bits each.
///
/// `x` is read as a little-endian multi-word integer: `x[0]` holds bits `0..32`,
/// `x[1]` holds bits `32..64`, and so on. Limb `i` of the result holds, in its low
/// bits, bits `i * FELT252_BITS_PER_WORD .. (i + 1) * FELT252_BITS_PER_WORD` of
/// that integer. A limb that straddles two input words is stitched together from
/// the high bits of one word and the low bits of the next.
///
/// Words past the end of `x` read as zero. So when
/// `M * FELT252_BITS_PER_WORD > 32 * N` the extra limbs are zero, and when it is
/// smaller the bits above the last limb are ignored. In particular, `N == 0`
/// yields all-zero limbs rather than panicking.
///
/// `TU32` is generic so the same code serves both scalar `u32` and SIMD vectors of
/// `u32` lanes (see the two callers in this module).
///
/// # Requirements
/// - `mask` must be `(1 << FELT252_BITS_PER_WORD) - 1`, replicated in every lane
///   for SIMD types. It is a parameter because `TU32` offers no way to build it.
/// - `>>=` on `TU32` must be a logical (zero-filling) shift, as it is for `u32`
///   and `u32x16`. The algorithm relies on vacated high bits being zero.
/// - `FELT252_BITS_PER_WORD` must be at most 32 (see `32 - FELT252_BITS_PER_WORD`).
pub fn split<const N: usize, const M: usize, TU32>(x: [TU32; N], mask: TU32) -> [TU32; M]
where
    TU32: BitAnd<Output = TU32>
        + BitOrAssign
        + Copy
        + ShrAssign<u32>
        + Shl<u32, Output = TU32>
        + Default,
{
    let mut res = [TU32::default(); M];
    // Number of not-yet-consumed bits left at the bottom of `word`.
    let mut n_bits_in_word = 32;
    let mut word_i = 0;
    // Read the first word the same way as every later one: a missing word is
    // zero. Indexing `x[0]` directly would panic for `N == 0`.
    let mut word = x.first().copied().unwrap_or_default();
    for e in res.iter_mut() {
        if n_bits_in_word > FELT252_BITS_PER_WORD {
            // The whole limb lies inside the current word.
            *e = word & mask;
            word >>= FELT252_BITS_PER_WORD as u32;
            n_bits_in_word -= FELT252_BITS_PER_WORD;
            continue;
        }
        // The limb starts with all `n_bits_in_word` remaining bits of the current
        // word. No mask is needed, because earlier logical shifts have already
        // cleared everything above them.
        *e = word;
        word_i += 1;
        word = x.get(word_i).copied().unwrap_or_default();
        if n_bits_in_word < FELT252_BITS_PER_WORD {
            // Borrow the missing `FELT252_BITS_PER_WORD - n_bits_in_word` high
            // bits of the limb from the bottom of the next word.
            *e |= (word << n_bits_in_word as u32) & mask;
            word >>= (FELT252_BITS_PER_WORD - n_bits_in_word) as u32;
        }
        // A fresh 32-bit word, minus the bits just borrowed for this limb.
        n_bits_in_word += 32 - FELT252_BITS_PER_WORD;
    }
    res
}
#[cfg(feature = "prover")]
/// SIMD counterpart of `split_f252`: splits `N_LANES` values at once.
///
/// Each lane of `x[j]` holds word `j` (little-endian, `x[0]` least significant)
/// of an independent 256-bit input (a felt252, judging by the name). The result
/// holds `FELT252_N_WORDS` packed limbs of `FELT252_BITS_PER_WORD` bits each,
/// least significant first.
///
/// Every limb is masked to `FELT252_BITS_PER_WORD` bits by `split`. The unchecked
/// `M31` conversion relies on that width keeping limbs below the M31 modulus.
pub fn split_f252_simd(x: [u32x16; 8]) -> [PackedM31; FELT252_N_WORDS] {
    let lane_mask: u32 = (1 << FELT252_BITS_PER_WORD) - 1;
    let limbs: [u32x16; FELT252_N_WORDS] = split(x, u32x16::from_array([lane_mask; N_LANES]));
    limbs.map(|limb| {
        let lanes: [M31; N_LANES] = limb.to_array().map(M31::from_u32_unchecked);
        PackedM31::from(lanes)
    })
}
/// Splits eight little-endian `u32` words (`x[0]` least significant; a felt252,
/// judging by the name) into `FELT252_N_WORDS` limbs of `FELT252_BITS_PER_WORD`
/// bits each, least significant first.
///
/// Every limb is masked to `FELT252_BITS_PER_WORD` bits by `split`. The unchecked
/// `M31` conversion relies on that width keeping limbs below the M31 modulus.
pub fn split_f252(x: [u32; 8]) -> [M31; FELT252_N_WORDS] {
    let limb_mask: u32 = (1 << FELT252_BITS_PER_WORD) - 1;
    let limbs: [u32; FELT252_N_WORDS] = split(x, limb_mask);
    limbs.map(M31::from_u32_unchecked)
}