use core::ops::{Add, AddAssign, Neg, ShrAssign, Sub, SubAssign};
/// Integer-like element type used for the intermediate arithmetic of the
/// convolution kernels in this file.
///
/// The bounds list exactly the operations those kernels perform: copying,
/// zero-initialisation through `Default`, by-value addition, subtraction and
/// negation, in-place addition and subtraction, and an in-place right shift
/// by a `u32` (the kernels use `>>= 1` to halve sums).
pub trait RngElt:
    Copy
    + Default
    + Add<Output = Self>
    + Sub<Output = Self>
    + Neg<Output = Self>
    + AddAssign
    + SubAssign
    + ShrAssign<u32>
{
}

// Instances provided here: the 64- and 128-bit signed integers.
impl RngElt for i64 {}
impl RngElt for i128 {}
/// Fixed-size cyclic and negacyclic convolutions evaluated in integer
/// arithmetic, with conversion to and from the element type `F`.
///
/// Implementors supply three primitives:
///
/// * [`read`](Self::read) lifts an `F` into the left-operand type `T`;
/// * [`parity_dot`](Self::parity_dot) combines a `[T; N]` with a `[U; N]`
///   into a single `V`;
/// * [`reduce`](Self::reduce) maps a `V` back to `F`.
///
/// NOTE(review): the provided kernels assume `parity_dot(a, b)` is the exact
/// dot product `sum_i a[i] * b[i]`; confirm this for each implementation.
/// Under that assumption `convN` writes the coefficients of
/// `lhs * rhs mod (X^N - 1)` into `output[..N]` and `negacyclic_convN` writes
/// those of `lhs * rhs mod (X^N + 1)`. The kernels index `output` directly and
/// expect `output.len() >= N`; shorter slices lead to an index panic.
///
/// Provided sizes: cyclic 3, 4, 6, 8, 12, 16, 24, 32, 64; negacyclic 3, 4, 6,
/// 8, 12, 16, 32.
pub trait Convolve<F, T: RngElt, U: RngElt, V: RngElt> {
    /// Lifts one input element into the left-operand type `T`.
    fn read(input: F) -> T;

    /// Combines `lhs` and `rhs` position by position into a single `V`.
    fn parity_dot<const N: usize>(lhs: [T; N], rhs: [U; N]) -> V;

    /// Maps one accumulated output coefficient back to `F`.
    fn reduce(z: V) -> F;

    /// Runs `conv` on `lhs` (each entry passed through [`read`](Self::read))
    /// and `rhs`, then passes every output coefficient through
    /// [`reduce`](Self::reduce).
    #[inline(always)]
    fn apply<const N: usize, C: Fn([T; N], [U; N], &mut [V])>(
        lhs: [F; N],
        rhs: [U; N],
        conv: C,
    ) -> [F; N] {
        let lifted = lhs.map(Self::read);
        let mut acc = [V::default(); N];
        conv(lifted, rhs, &mut acc);
        acc.map(Self::reduce)
    }

    /// Length-3 cyclic convolution: `output[k] = sum_i lhs[i] * rhs[(k - i) mod 3]`.
    #[inline(always)]
    fn conv3(lhs: [T; 3], rhs: [U; 3], output: &mut [V]) {
        let [r0, r1, r2] = rhs;
        // Row k holds rhs[(k - i) mod 3] for i = 0, 1, 2.
        let rows = [[r0, r2, r1], [r1, r0, r2], [r2, r1, r0]];
        for (slot, row) in output[..3].iter_mut().zip(rows) {
            *slot = Self::parity_dot(lhs, row);
        }
    }

    /// Length-3 negacyclic convolution (product modulo `X^3 + 1`).
    #[inline(always)]
    fn negacyclic_conv3(lhs: [T; 3], rhs: [U; 3], output: &mut [V]) {
        let [r0, r1, r2] = rhs;
        // Terms that wrap past X^3 pick up a minus sign because X^3 = -1.
        let rows = [[r0, -r2, -r1], [r1, r0, -r2], [r2, r1, r0]];
        for (slot, row) in output[..3].iter_mut().zip(rows) {
            *slot = Self::parity_dot(lhs, row);
        }
    }

    /// Length-4 cyclic convolution.
    ///
    /// This is the `HALF_N = 2` case of `conv_n_recursive` written out by hand,
    /// with the length-2 cyclic and negacyclic kernels inlined as dot products.
    #[inline(always)]
    fn conv4(lhs: [T; 4], rhs: [U; 4], output: &mut [V]) {
        let [a0, a1, a2, a3] = lhs;
        let [b0, b1, b2, b3] = rhs;
        // Images modulo X^2 - 1 (sums) and modulo X^2 + 1 (differences).
        let (a_sum, a_diff) = ([a0 + a2, a1 + a3], [a0 - a2, a1 - a3]);
        let (b_sum, b_diff) = ([b0 + b2, b1 + b3], [b0 - b2, b1 - b3]);

        // Low half <- negacyclic length-2 product of the differences,
        // high half <- cyclic length-2 product of the sums.
        output[0] = Self::parity_dot(a_diff, [b_diff[0], -b_diff[1]]);
        output[1] = Self::parity_dot(a_diff, [b_diff[1], b_diff[0]]);
        output[2] = Self::parity_dot(a_sum, b_sum);
        output[3] = Self::parity_dot(a_sum, [b_sum[1], b_sum[0]]);

        // Recombine each coefficient pair: low = (neg + cyc) / 2, high = cyc - low.
        // With exact arithmetic the sum is even, so the shift halves it exactly.
        for k in 0..2 {
            output[k] += output[k + 2];
            output[k] >>= 1;
            output[k + 2] -= output[k];
        }
    }

    /// Length-4 negacyclic convolution (product modulo `X^4 + 1`).
    #[inline(always)]
    fn negacyclic_conv4(lhs: [T; 4], rhs: [U; 4], output: &mut [V]) {
        let [r0, r1, r2, r3] = rhs;
        // Terms that wrap past X^4 pick up a minus sign because X^4 = -1.
        let rows = [
            [r0, -r3, -r2, -r1],
            [r1, r0, -r3, -r2],
            [r2, r1, r0, -r3],
            [r3, r2, r1, r0],
        ];
        for (slot, row) in output[..4].iter_mut().zip(rows) {
            *slot = Self::parity_dot(lhs, row);
        }
    }

    /// Length-6 cyclic convolution: one split into length-3 halves.
    #[inline(always)]
    fn conv6(lhs: [T; 6], rhs: [U; 6], output: &mut [V]) {
        conv_n_recursive::<6, 3, T, U, V, _, _>(
            lhs, rhs, output, Self::conv3, Self::negacyclic_conv3,
        );
    }

    /// Length-6 negacyclic convolution from three length-3 negacyclic products.
    #[inline(always)]
    fn negacyclic_conv6(lhs: [T; 6], rhs: [U; 6], output: &mut [V]) {
        negacyclic_conv_n_recursive::<6, 3, T, U, V, _>(
            lhs, rhs, output, Self::negacyclic_conv3,
        );
    }

    /// Length-8 cyclic convolution: one split into length-4 halves.
    #[inline(always)]
    fn conv8(lhs: [T; 8], rhs: [U; 8], output: &mut [V]) {
        conv_n_recursive::<8, 4, T, U, V, _, _>(
            lhs, rhs, output, Self::conv4, Self::negacyclic_conv4,
        );
    }

    /// Length-8 negacyclic convolution from three length-4 negacyclic products.
    #[inline(always)]
    fn negacyclic_conv8(lhs: [T; 8], rhs: [U; 8], output: &mut [V]) {
        negacyclic_conv_n_recursive::<8, 4, T, U, V, _>(
            lhs, rhs, output, Self::negacyclic_conv4,
        );
    }

    /// Length-12 cyclic convolution: one split into length-6 halves.
    #[inline(always)]
    fn conv12(lhs: [T; 12], rhs: [U; 12], output: &mut [V]) {
        conv_n_recursive::<12, 6, T, U, V, _, _>(
            lhs, rhs, output, Self::conv6, Self::negacyclic_conv6,
        );
    }

    /// Length-12 negacyclic convolution from three length-6 negacyclic products.
    #[inline(always)]
    fn negacyclic_conv12(lhs: [T; 12], rhs: [U; 12], output: &mut [V]) {
        negacyclic_conv_n_recursive::<12, 6, T, U, V, _>(
            lhs, rhs, output, Self::negacyclic_conv6,
        );
    }

    /// Length-16 cyclic convolution: one split into length-8 halves.
    #[inline(always)]
    fn conv16(lhs: [T; 16], rhs: [U; 16], output: &mut [V]) {
        conv_n_recursive::<16, 8, T, U, V, _, _>(
            lhs, rhs, output, Self::conv8, Self::negacyclic_conv8,
        );
    }

    /// Length-16 negacyclic convolution from three length-8 negacyclic products.
    #[inline(always)]
    fn negacyclic_conv16(lhs: [T; 16], rhs: [U; 16], output: &mut [V]) {
        negacyclic_conv_n_recursive::<16, 8, T, U, V, _>(
            lhs, rhs, output, Self::negacyclic_conv8,
        );
    }

    /// Length-24 cyclic convolution: one split into length-12 halves.
    #[inline(always)]
    fn conv24(lhs: [T; 24], rhs: [U; 24], output: &mut [V]) {
        conv_n_recursive::<24, 12, T, U, V, _, _>(
            lhs, rhs, output, Self::conv12, Self::negacyclic_conv12,
        );
    }

    /// Length-32 cyclic convolution: one split into length-16 halves.
    #[inline(always)]
    fn conv32(lhs: [T; 32], rhs: [U; 32], output: &mut [V]) {
        conv_n_recursive::<32, 16, T, U, V, _, _>(
            lhs, rhs, output, Self::conv16, Self::negacyclic_conv16,
        );
    }

    /// Length-32 negacyclic convolution from three length-16 negacyclic products.
    #[inline(always)]
    fn negacyclic_conv32(lhs: [T; 32], rhs: [U; 32], output: &mut [V]) {
        negacyclic_conv_n_recursive::<32, 16, T, U, V, _>(
            lhs, rhs, output, Self::negacyclic_conv16,
        );
    }

    /// Length-64 cyclic convolution: one split into length-32 halves.
    #[inline(always)]
    fn conv64(lhs: [T; 64], rhs: [U; 64], output: &mut [V]) {
        conv_n_recursive::<64, 32, T, U, V, _, _>(
            lhs, rhs, output, Self::conv32, Self::negacyclic_conv32,
        );
    }
}
/// Cyclic convolution of length `N = 2 * HALF_N`, built from one cyclic and
/// one negacyclic convolution of length `HALF_N`.
///
/// Write `lhs = a + X^HALF_N * b` and `rhs = c + X^HALF_N * d`. The product
/// modulo `X^N - 1` is rebuilt from its two images:
///
/// * `w0 = (a + b) * (c + d) mod (X^HALF_N - 1)`, computed by `inner_conv`;
/// * `w1 = (a - b) * (c - d) mod (X^HALF_N + 1)`, computed by
///   `inner_negacyclic_conv`.
///
/// From these, `low = (w0 + w1) / 2` goes to `output[..HALF_N]` and
/// `high = w0 - low` goes to `output[HALF_N..]`. With exact inner products,
/// `w0 + w1` is even, so the right shift by one halves it exactly.
///
/// `output` is expected to hold `N` elements; `split_at_mut` panics if it has
/// fewer than `HALF_N`.
#[inline(always)]
fn conv_n_recursive<const N: usize, const HALF_N: usize, T, U, V, C, NC>(
    lhs: [T; N],
    rhs: [U; N],
    output: &mut [V],
    inner_conv: C,
    inner_negacyclic_conv: NC,
) where
    T: RngElt,
    U: RngElt,
    V: RngElt,
    C: Fn([T; HALF_N], [U; HALF_N], &mut [V]),
    NC: Fn([T; HALF_N], [U; HALF_N], &mut [V]),
{
    debug_assert_eq!(2 * HALF_N, N);

    // Fold the top half onto the bottom half. Sums give the image modulo
    // X^HALF_N - 1 and differences give the image modulo X^HALF_N + 1.
    let lhs_plus: [T; HALF_N] = core::array::from_fn(|i| lhs[i] + lhs[i + HALF_N]);
    let lhs_minus: [T; HALF_N] = core::array::from_fn(|i| lhs[i] - lhs[i + HALF_N]);
    let rhs_plus: [U; HALF_N] = core::array::from_fn(|i| rhs[i] + rhs[i + HALF_N]);
    let rhs_minus: [U; HALF_N] = core::array::from_fn(|i| rhs[i] - rhs[i + HALF_N]);

    let (low, high) = output.split_at_mut(HALF_N);
    // low <- w1 (negacyclic image), high <- w0 (cyclic image).
    inner_negacyclic_conv(lhs_minus, rhs_minus, low);
    inner_conv(lhs_plus, rhs_plus, high);

    // Recombine in place: low = (w0 + w1) / 2, then high = w0 - low.
    for (lo, hi) in low.iter_mut().zip(high.iter_mut()) {
        *lo += *hi;
        *lo >>= 1;
        *hi -= *lo;
    }
}
/// Negacyclic convolution of length `N = 2 * HALF_N` (product modulo
/// `X^N + 1`), built from three negacyclic convolutions of length `HALF_N`.
///
/// Each operand is split into even- and odd-indexed coefficients,
/// `lhs = E(X^2) + X * O(X^2)` (and likewise `E'`, `O'` for `rhs`). With
/// `Y = X^2`, the result is reduced modulo `Y^HALF_N + 1`:
///
/// * even coefficients: `E*E' + Y * (O*O')`;
/// * odd coefficients: `(E + O)*(E' + O') - E*E' - O*O'`.
///
/// Multiplying by `Y` shifts `O*O'` up by one slot and wraps its top
/// coefficient around to slot 0 with a sign flip. The two halves are
/// interleaved into `output[..N]`.
///
/// `output` is expected to hold `N` elements; `split_at_mut` panics if it has
/// fewer than `HALF_N`.
#[inline(always)]
fn negacyclic_conv_n_recursive<const N: usize, const HALF_N: usize, T, U, V, NC>(
    lhs: [T; N],
    rhs: [U; N],
    output: &mut [V],
    inner_negacyclic_conv: NC,
) where
    T: RngElt,
    U: RngElt,
    V: RngElt,
    NC: Fn([T; HALF_N], [U; HALF_N], &mut [V]),
{
    debug_assert_eq!(2 * HALF_N, N);

    // De-interleave both operands into even- and odd-indexed coefficients.
    let lhs_even: [T; HALF_N] = core::array::from_fn(|i| lhs[2 * i]);
    let lhs_odd: [T; HALF_N] = core::array::from_fn(|i| lhs[2 * i + 1]);
    let rhs_even: [U; HALF_N] = core::array::from_fn(|i| rhs[2 * i]);
    let rhs_odd: [U; HALF_N] = core::array::from_fn(|i| rhs[2 * i + 1]);
    let lhs_sum: [T; HALF_N] = core::array::from_fn(|i| lhs_even[i] + lhs_odd[i]);
    let rhs_sum: [U; HALF_N] = core::array::from_fn(|i| rhs_even[i] + rhs_odd[i]);

    let mut even_part = [V::default(); HALF_N];
    let (odd_prod, cross) = output.split_at_mut(HALF_N);

    // Three half-length negacyclic products: E*E', O*O', (E+O)*(E'+O').
    inner_negacyclic_conv(lhs_even, rhs_even, &mut even_part);
    inner_negacyclic_conv(lhs_odd, rhs_odd, odd_prod);
    inner_negacyclic_conv(lhs_sum, rhs_sum, cross);

    // Odd result coefficients. This must read E*E' before the shift below
    // modifies `even_part`.
    for ((c, &e), &o) in cross.iter_mut().zip(even_part.iter()).zip(odd_prod.iter()) {
        *c -= e + o;
    }

    // Even result coefficients: add Y * O*O'. The top term wraps to slot 0
    // negated (Y^HALF_N = -1); the rest shift up by one.
    even_part[0] -= odd_prod[HALF_N - 1];
    for (e, &o) in even_part[1..].iter_mut().zip(odd_prod.iter()) {
        *e += o;
    }

    // Interleave in place. At step i, slot HALF_N + i has not yet been
    // overwritten, because every slot written so far is at index <= 2 * i.
    for (i, &e) in even_part.iter().enumerate() {
        output[2 * i] = e;
        output[2 * i + 1] = output[HALF_N + i];
    }
}