use core::marker::PhantomData;
use crate::field::SumcheckField;
use spongefish::{Decoding, Encoding, NargDeserialize, NargSerialize};
use crate::field::Goldilocks4;
pub(crate) trait Field:
SumcheckField + Encoding + Decoding + NargSerialize + NargDeserialize
{
}
impl<F> Field for F where F: SumcheckField + Encoding + Decoding + NargSerialize + NargDeserialize {}
pub(crate) trait AdditiveCode {
type InputAlphabet: Field;
type OutputAlphabet: Clone + Eq;
fn msg_len(&self) -> usize;
fn codeword_len(&self) -> usize;
fn encode(&self, input: &[Self::InputAlphabet]) -> Vec<Self::OutputAlphabet>;
}
pub(crate) trait LinearCode:
AdditiveCode<InputAlphabet = Self::Alphabet, OutputAlphabet = Self::Alphabet>
{
type Alphabet: Field;
}
pub(crate) struct ReedSolomon<F> {
pub(crate) msg_len: usize,
pub(crate) codeword_len: usize,
_f: PhantomData<F>,
}
impl<F> ReedSolomon<F> {
pub(crate) fn new(msg_len: usize) -> Self {
assert!(
msg_len.is_power_of_two(),
"ReedSolomon::new: msg_len must be a power of 2, got {msg_len}"
);
Self {
msg_len,
codeword_len: msg_len * 4,
_f: PhantomData,
}
}
}
impl AdditiveCode for ReedSolomon<Goldilocks4> {
type InputAlphabet = Goldilocks4;
type OutputAlphabet = Goldilocks4;
fn msg_len(&self) -> usize {
self.msg_len
}
fn codeword_len(&self) -> usize {
self.codeword_len
}
fn encode(&self, input: &[Self::InputAlphabet]) -> Vec<Self::OutputAlphabet> {
assert_eq!(
input.len(),
self.msg_len,
"ReedSolomon::encode: input length mismatch"
);
let mut padded = input.to_vec();
padded.resize(self.codeword_len, Goldilocks4::default());
Goldilocks4::ntt(padded)
}
}
impl LinearCode for ReedSolomon<Goldilocks4> {
type Alphabet = Goldilocks4;
}
pub(crate) struct InterleavedCode<EC> {
interleaving_factor: usize,
inner_code: EC,
}
impl<EC> InterleavedCode<EC> {
pub(crate) fn new(inner_code: EC, interleaving_factor: usize) -> Self {
Self {
interleaving_factor,
inner_code,
}
}
pub(crate) fn interleaving_factor(&self) -> usize {
self.interleaving_factor
}
pub(crate) fn inner_code(&self) -> &EC {
&self.inner_code
}
}
impl<EC> AdditiveCode for InterleavedCode<EC>
where
EC: AdditiveCode,
{
type InputAlphabet = EC::InputAlphabet;
type OutputAlphabet = Vec<EC::OutputAlphabet>;
fn msg_len(&self) -> usize {
self.inner_code.msg_len() * self.interleaving_factor
}
fn codeword_len(&self) -> usize {
self.inner_code.codeword_len()
}
fn encode(&self, input: &[Self::InputAlphabet]) -> Vec<Self::OutputAlphabet> {
assert!(self.interleaving_factor > 0);
assert!(input.len().is_multiple_of(self.interleaving_factor));
assert_eq!(input.len(), self.msg_len());
let chunk_size = input.len() / self.interleaving_factor;
assert_eq!(chunk_size, self.inner_code.msg_len());
let codeword_len = self.codeword_len();
let encoded_chunks: Vec<_> = input
.chunks_exact(chunk_size)
.map(|chunk| self.inner_code.encode(chunk))
.collect();
(0..codeword_len)
.map(|i| encoded_chunks.iter().map(|c| c[i].clone()).collect())
.collect()
}
}
impl InterleavedCode<ReedSolomon<Goldilocks4>> {
pub(crate) fn encode_slab(&self, input: &[Goldilocks4]) -> CodewordSlab {
let k_int = self.interleaving_factor;
assert!(k_int > 0);
assert_eq!(input.len(), self.msg_len());
let chunk_size = input.len() / k_int;
let codeword_len = self.codeword_len();
let encoded_chunks: Vec<Vec<Goldilocks4>> = input
.chunks_exact(chunk_size)
.map(|chunk| self.inner_code.encode(chunk))
.collect();
let mut flat: Vec<Goldilocks4> = Vec::with_capacity(codeword_len * k_int);
for i in 0..codeword_len {
for chunk_enc in &encoded_chunks {
flat.push(chunk_enc[i]);
}
}
CodewordSlab::new(flat, k_int)
}
}
#[derive(Clone, Debug)]
pub(crate) struct CodewordSlab {
pub(crate) data: Vec<Goldilocks4>,
pub(crate) width: usize,
}
impl CodewordSlab {
pub(crate) fn new(data: Vec<Goldilocks4>, width: usize) -> Self {
assert!(width > 0);
assert_eq!(
data.len() % width,
0,
"CodewordSlab::new: data.len() ({}) not a multiple of width ({})",
data.len(),
width
);
Self { data, width }
}
pub(crate) fn positions(&self) -> usize {
self.data.len() / self.width
}
pub(crate) fn position(&self, i: usize) -> &[Goldilocks4] {
&self.data[i * self.width..(i + 1) * self.width]
}
pub(crate) fn iter_positions(&self) -> impl Iterator<Item = &[Goldilocks4]> + '_ {
self.data.chunks_exact(self.width)
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::field::Goldilocks;
fn from_base(g: Goldilocks) -> Goldilocks4 {
Goldilocks4::new([
g,
Goldilocks::new(0),
Goldilocks::new(0),
Goldilocks::new(0),
])
}
#[test]
fn reed_solomon_evaluates_polynomial_at_subgroup() {
let k = 8;
let code = ReedSolomon::<Goldilocks4>::new(k);
let mut msg = vec![Goldilocks4::default(); k];
msg[0] = from_base(Goldilocks::new(2));
msg[1] = from_base(Goldilocks::new(1));
let codeword = code.encode(&msg);
assert_eq!(codeword.len(), 32);
let g = Goldilocks4::two_adic_generator(5);
for (i, cw_i) in codeword.iter().enumerate() {
let x = g.pow(i as u64);
let expected = from_base(Goldilocks::new(2)) + x;
assert_eq!(*cw_i, expected, "i={i}");
}
}
#[test]
fn reed_solomon_linearity() {
let k = 16;
let code = ReedSolomon::<Goldilocks4>::new(k);
let m1: Vec<_> = (0..k)
.map(|i| from_base(Goldilocks::new(i as u64)))
.collect();
let m2: Vec<_> = (0..k)
.map(|i| from_base(Goldilocks::new((i * 3) as u64)))
.collect();
let sum: Vec<_> = m1.iter().zip(&m2).map(|(a, b)| *a + *b).collect();
let c1 = code.encode(&m1);
let c2 = code.encode(&m2);
let cs = code.encode(&sum);
for (i, (a, (b, c))) in cs.iter().zip(c1.iter().zip(c2.iter())).enumerate() {
assert_eq!(*a, *b + *c, "linearity failed at i={i}");
}
}
#[test]
fn encode_slab_matches_encode_layout() {
use crate::field::Goldilocks;
let inner = ReedSolomon::<Goldilocks4>::new(8);
let ic = InterleavedCode::new(inner, 4);
let zero = Goldilocks::new(0);
let msg: Vec<Goldilocks4> = (0..32u64)
.map(|n| Goldilocks4::new([Goldilocks::new(n), Goldilocks::new(n + 100), zero, zero]))
.collect();
let nested = ic.encode(&msg);
let slab = ic.encode_slab(&msg);
assert_eq!(nested.len(), slab.positions());
for (i, pos) in nested.iter().enumerate() {
assert_eq!(slab.position(i), pos.as_slice(), "mismatch at position {i}");
}
}
#[test]
fn codeword_slab_position_iteration() {
use crate::field::Goldilocks;
let zero = Goldilocks::new(0);
let data: Vec<Goldilocks4> = (0..12u64)
.map(|n| Goldilocks4::new([Goldilocks::new(n), zero, zero, zero]))
.collect();
let slab = CodewordSlab::new(data, 3);
assert_eq!(slab.positions(), 4);
assert_eq!(slab.position(0).len(), 3);
assert_eq!(slab.iter_positions().count(), 4);
}
}