use core::ops::{Index, IndexMut};
use crate::{
encode::r3::{r3_decode, r3_encode},
params::{params::P, params::R3_BYTES},
};
use super::{error::PolyErrors, f3, fq, rq::Rq};
use crate::math::nums::{i16_negative_mask, i16_nonzero_mask};
/// A polynomial with `P` small integer coefficients, used as an element of
/// R3: coefficients are reduced modulo 3 (via `f3::freeze`) and products are
/// reduced modulo x^P - x - 1 (see `R3::mult`, which folds x^P into x + 1).
///
/// NOTE(review): coefficients appear to be kept in the centered range
/// {-1, 0, 1}; the type itself does not enforce this — confirm with callers.
#[derive(Debug, Clone)]
pub struct R3 {
    /// `coeffs[i]` is the coefficient of x^i (index 0 is the constant term).
    pub coeffs: [i8; P],
}
impl Default for R3 {
fn default() -> Self {
Self::new()
}
}
impl R3 {
    /// Returns the zero polynomial (all `P` coefficients set to 0).
    pub fn new() -> Self {
        Self { coeffs: [0i8; P] }
    }

    /// Wraps an existing coefficient array as-is, without reduction or
    /// validation.
    ///
    /// NOTE(review): the arithmetic below presumably expects coefficients in
    /// {-1, 0, 1}; nothing here checks that.
    pub fn from(coeffs: [i8; P]) -> Self {
        Self { coeffs }
    }

    /// Returns `true` if every coefficient is 0.
    pub fn eq_zero(&self) -> bool {
        self.coeffs.iter().all(|&c| c == 0)
    }

    /// Multiplies `self` by `g3` in (Z/3)[x]/(x^P - x - 1).
    ///
    /// Computes the schoolbook product (degree up to 2P-2), reducing every
    /// partial sum with `f3::freeze`, then folds each coefficient of degree
    /// `i >= P` back into degrees `i - P` and `i - P + 1`, using
    /// x^P = x + 1.
    pub fn mult(&self, g3: &R3) -> R3 {
        let f = &self.coeffs;
        let g = &g3.coeffs;
        let mut fg = [0i8; P + P - 1];

        // acc + a*b reduced mod 3. The sum is formed in i16 (the type
        // `f3::freeze` takes) so the intermediate product cannot overflow i8.
        let mul_add =
            |acc: i8, a: i8, b: i8| f3::freeze(i16::from(acc) + i16::from(a) * i16::from(b));

        // Low half, degrees 0..P: every j in 0..=i contributes f[j]*g[i-j].
        for (i, coeff) in fg.iter_mut().enumerate().take(P) {
            *coeff = (0..=i).fold(0i8, |acc, j| mul_add(acc, f[j], g[i - j]));
        }
        // High half, degrees P..2P-1: only j in (i-P+1)..P keeps i-j in range.
        for (i, coeff) in fg.iter_mut().enumerate().skip(P) {
            *coeff = ((i - P + 1)..P).fold(0i8, |acc, j| mul_add(acc, f[j], g[i - j]));
        }

        // Reduce modulo x^P - x - 1: x^i = x^(i-P) * (x + 1) for i >= P.
        for i in (P..P + P - 1).rev() {
            let top = i16::from(fg[i]);
            fg[i - P] = f3::freeze(i16::from(fg[i - P]) + top);
            fg[i - P + 1] = f3::freeze(i16::from(fg[i - P + 1]) + top);
        }

        let mut out = [0i8; P];
        out.copy_from_slice(&fg[..P]);
        R3::from(out)
    }

    /// Returns `true` if `self` is the constant polynomial 1.
    pub fn eq_one(&self) -> bool {
        self.coeffs[0] == 1 && self.coeffs[1..].iter().all(|&c| c == 0)
    }

    /// Computes the multiplicative inverse of `self` in R3.
    ///
    /// Runs a fixed 2P-1 iterations of a divstep-style elimination between
    /// the modulus x^P - x - 1 (`f`) and the input (`g`). Both are stored
    /// with coefficients reversed (highest degree first). Swaps are done with
    /// masks rather than branches. The inverse is read from `v` at the end.
    ///
    /// # Errors
    ///
    /// Returns `PolyErrors::R3NoSolutionRecip` when the final `delta` is
    /// nonzero, which is how the algorithm signals that no inverse exists.
    pub fn recip(&self) -> Result<R3, PolyErrors> {
        let mut f = [0i8; P + 1];
        let mut g = [0i8; P + 1];
        let mut v = [0i8; P + 1];
        let mut r = [0i8; P + 1];

        // (a + s*b) reduced mod 3.
        let add_scaled = |a: i8, s: i8, b: i8| f3::freeze((a + s * b) as i16);

        r[0] = 1;
        f[0] = 1;
        f[P - 1] = -1;
        f[P] = -1;
        // g[P - 1 - i] = self.coeffs[i]; g[P] stays 0.
        g[..P].copy_from_slice(&self.coeffs);
        g[..P].reverse();

        // `delta` moves by one step per iteration, so over 2P-1 iterations its
        // magnitude can approach 2P. For example, it climbs straight to P for
        // the input 1, because g[0] stays 0 for P-1 iterations. An i8 overflowed
        // at 128: a panic in debug builds and a wrong answer in release builds.
        // i16 covers every NTRU Prime parameter size and is the type the mask
        // helpers take.
        let mut delta: i16 = 1;

        for _ in 0..2 * P - 1 {
            // v *= x: shift every coefficient one slot toward higher indices.
            v.copy_within(0..P, 1);
            v[0] = 0;

            let sign = -g[0] * f[0];
            // Swap mask: set when -delta is negative (delta > 0) and g[0] != 0,
            // per the helpers' names. It is truncated to i8 for the coefficient
            // arrays; i16::from sign-extends it back for the i16 `delta` update.
            let swap = (i16_negative_mask(-delta) & i16_nonzero_mask(g[0] as i16)) as i8;
            delta ^= i16::from(swap) & (delta ^ -delta);
            delta += 1;

            // Branch-free conditional swap of (f, g) and of (v, r).
            for (fi, gi) in f.iter_mut().zip(g.iter_mut()) {
                let t = swap & (*fi ^ *gi);
                *fi ^= t;
                *gi ^= t;
            }
            for (vi, ri) in v.iter_mut().zip(r.iter_mut()) {
                let t = swap & (*vi ^ *ri);
                *vi ^= t;
                *ri ^= t;
            }

            // g += sign * f and r += sign * v, both mod 3.
            for (gi, &fi) in g.iter_mut().zip(f.iter()) {
                *gi = add_scaled(*gi, sign, fi);
            }
            for (ri, &vi) in r.iter_mut().zip(v.iter()) {
                *ri = add_scaled(*ri, sign, vi);
            }

            // g /= x: shift every coefficient one slot toward lower indices.
            g.copy_within(1..=P, 0);
            g[P] = 0;
        }

        // Result is v[..P] reversed, scaled by f[0].
        let sign = f[0];
        let mut out = [0i8; P];
        for (o, &c) in out.iter_mut().zip(v[..P].iter().rev()) {
            *o = sign * c;
        }

        if i16_nonzero_mask(delta) == 0 {
            Ok(R3::from(out))
        } else {
            Err(PolyErrors::R3NoSolutionRecip)
        }
    }

    /// Lifts `self` into `Rq`, reducing each coefficient with `fq::freeze`.
    pub fn rq_from_r3(&self) -> Rq {
        let mut out = [0i16; P];
        for (o, &c) in out.iter_mut().zip(self.coeffs.iter()) {
            *o = fq::freeze(c.into());
        }
        Rq::from(out)
    }

    /// Serializes the coefficients into `R3_BYTES` bytes using `r3_encode`.
    pub fn to_bytes(&self) -> [u8; R3_BYTES] {
        r3_encode(self.as_ref())
    }
}
/// Converts an `Rq` element into `R3` by delegating to `Rq::r3_from_rq`.
impl From<Rq> for R3 {
    fn from(value: Rq) -> Self {
        value.r3_from_rq()
    }
}

/// Decodes an `R3_BYTES`-byte encoding (borrowed) with `r3_decode`.
impl From<&[u8; R3_BYTES]> for R3 {
    fn from(value: &[u8; R3_BYTES]) -> Self {
        let decoded = r3_decode(value);
        decoded.into()
    }
}

/// Decodes an owned `R3_BYTES`-byte encoding by borrowing it and reusing the
/// by-reference conversion above.
impl From<[u8; R3_BYTES]> for R3 {
    fn from(value: [u8; R3_BYTES]) -> Self {
        // `R3::from` would resolve to the inherent constructor, so go through
        // `Into` to reach the `From<&[u8; R3_BYTES]>` impl.
        (&value).into()
    }
}

/// Wraps a coefficient array without validation.
impl From<[i8; P]> for R3 {
    fn from(coeffs: [i8; P]) -> Self {
        Self { coeffs }
    }
}

/// Unwraps an `R3` into its raw coefficient array.
impl From<R3> for [i8; P] {
    fn from(r3: R3) -> Self {
        let R3 { coeffs } = r3;
        coeffs
    }
}
impl AsRef<[i8; P]> for R3 {
fn as_ref(&self) -> &[i8; P] {
&self.coeffs
}
}
impl AsRef<[i8]> for R3 {
fn as_ref(&self) -> &[i8] {
&self.coeffs
}
}
impl AsMut<[i8; P]> for R3 {
fn as_mut(&mut self) -> &mut [i8; P] {
&mut self.coeffs
}
}
impl AsMut<[i8]> for R3 {
fn as_mut(&mut self) -> &mut [i8] {
&mut self.coeffs
}
}
/// `r3[i]` reads the coefficient of x^i; panics if `i >= P`.
impl Index<usize> for R3 {
    type Output = i8;
    fn index(&self, idx: usize) -> &i8 {
        &self.coeffs[idx]
    }
}

/// `r3[i] = c` writes the coefficient of x^i; panics if `i >= P`.
impl IndexMut<usize> for R3 {
    fn index_mut(&mut self, idx: usize) -> &mut i8 {
        &mut self.coeffs[idx]
    }
}
/// Builds an `R3` from a slice holding exactly `P` coefficients.
///
/// # Errors
///
/// Returns `PolyErrors::SliceLengthNotR3Size` when `slice.len() != P`.
impl TryFrom<&[i8]> for R3 {
    type Error = PolyErrors;
    fn try_from(slice: &[i8]) -> Result<Self, Self::Error> {
        <[i8; P]>::try_from(slice)
            .map(|coeffs| R3 { coeffs })
            .map_err(|_| PolyErrors::SliceLengthNotR3Size)
    }
}
impl IntoIterator for R3 {
type Item = i8;
type IntoIter = core::array::IntoIter<i8, P>;
fn into_iter(self) -> Self::IntoIter {
self.coeffs.into_iter()
}
}
impl<'a> IntoIterator for &'a R3 {
type Item = &'a i8;
type IntoIter = core::slice::Iter<'a, i8>;
fn into_iter(self) -> Self::IntoIter {
self.coeffs.iter()
}
}
impl<'a> IntoIterator for &'a mut R3 {
type Item = &'a mut i8;
type IntoIter = core::slice::IterMut<'a, i8>;
fn into_iter(self) -> Self::IntoIter {
self.coeffs.iter_mut()
}
}
/// Compares the coefficients element-wise against a raw array.
impl PartialEq<[i8; P]> for R3 {
    fn eq(&self, other: &[i8; P]) -> bool {
        other == &self.coeffs
    }
}
/// Unit tests for `R3` multiplication and inversion.
#[cfg(test)]
mod test_r3 {
    use super::*;
    use crate::rng::random_small;

    /// Known-answer test for `R3::mult`: multiplies two fixed polynomials and
    /// compares against a fixed expected product.
    ///
    /// The literals must have exactly `P` entries to fit `[i8; P]`, which ties
    /// this test to the `ntrup761` feature.
    /// NOTE(review): origin of the vectors is not recorded here — presumably a
    /// reference implementation; confirm.
    #[cfg(feature = "ntrup761")]
    #[test]
    fn test_r3_mult() {
        let f: R3 = R3::from([
            1, 0, -1, 0, 1, -1, 0, 0, -1, 0, -1, 1, -1, -1, 0, 1, 1, 0, 0, 0, 0, -1, 0, -1, 0, 1,
            0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, -1, -1, 1, 0, 0, 0, -1, 0, 0, 1, 1, 1, -1, 1, 1, 1, 1,
            0, 0, 1, -1, 0, 0, -1, 1, 0, 0, 0, 0, 0, 0, 0, -1, 0, 0, 0, 0, 0, 0, 1, -1, -1, -1, 0,
            0, 1, 0, -1, 1, 1, -1, 0, 0, 0, -1, 0, 0, 0, -1, 0, 0, 0, -1, 0, -1, 0, -1, 1, 1, 0, 0,
            1, -1, 0, 1, 0, -1, 0, -1, 0, 0, 1, 0, 0, 1, 0, 1, 1, 0, 0, -1, 0, 1, 0, 0, 1, 0, 0,
            -1, 0, 0, 0, 0, 0, 0, 0, 0, 0, -1, -1, 0, 0, 0, 1, 0, 0, -1, 0, 0, -1, 0, 0, 1, 0, 1,
            0, 0, 0, -1, 1, 0, -1, 1, 0, 0, 1, 0, 1, -1, 0, 0, 1, 0, 1, -1, 1, 0, 1, 0, -1, 1, 0,
            0, -1, 0, 0, 0, 0, 0, 1, -1, -1, 0, -1, 0, 0, 0, 0, 0, -1, 0, 0, 1, 0, 0, 0, 1, -1, 0,
            0, 0, -1, 0, -1, 1, 1, -1, 0, 0, 0, 0, 1, 0, 0, 1, 1, 1, 0, 0, -1, 0, 0, 1, 0, 0, -1,
            0, 0, -1, 1, 1, 0, 0, 1, 0, 1, 1, -1, -1, 0, 0, 0, -1, 0, 1, 0, -1, 0, 0, 0, 0, 0, -1,
            0, 1, 1, -1, -1, -1, 0, 0, 1, 0, 0, 0, 0, 0, 0, -1, 0, 1, 0, -1, -1, 0, -1, 0, -1, -1,
            0, 0, 1, 0, 1, 0, -1, 0, 1, 0, 0, 0, 0, 1, 1, 0, 0, 1, -1, 0, 0, 0, 0, 1, 0, 0, -1, 0,
            0, -1, -1, 0, 0, 0, 1, 0, 1, 0, -1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, -1, 0, 0, 0, 0,
            0, 0, 1, -1, 0, 0, 0, -1, 1, 1, 1, 0, 0, 0, 0, -1, 0, 0, 0, -1, 0, 0, 0, 0, 0, 1, 0,
            -1, 0, 1, 0, 0, 1, -1, 0, 0, 0, 1, 0, 0, 1, -1, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, -1, 0,
            0, 0, -1, -1, 0, 0, 0, 1, 1, 0, 0, -1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, -1, 0, -1,
            0, 0, 1, -1, -1, 0, -1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, -1, 0, 0, -1, -1, 0, 0, 0, 0, -1,
            -1, -1, 0, 1, 0, 1, -1, 0, -1, 0, -1, -1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 1,
            1, 0, 0, 0, 1, 0, 1, 0, 0, 0, 1, -1, 0, 0, 0, 0, 0, 1, -1, -1, 0, -1, 0, 1, 0, -1, 0,
            0, 0, 0, 0, 1, -1, 0, 0, -1, 1, 0, 1, 0, 0, 1, -1, 0, 0, 0, 1, 0, 0, 0, 0, -1, 1, 0, 0,
            0, 0, 0, 0, -1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, -1, 0, -1, 1, 0, 1, 0, 0, 1, -1, 1, 0, 1,
            1, -1, -1, 0, 0, 0, -1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, -1, 0, 0, 0,
            1, -1, 0, -1, 1, 0, 0, 1, 0, -1, 0, -1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
            -1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, -1, 0, 0, 0, -1, -1, 0, 0, 0, 0, 1, 0,
            0, 0, 0, 0, 0, 0, 0, 0, -1, 0, 0, -1, 0, 1, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 1, 0, 0, 0,
            -1, 0, 0, 0, 0, -1, 0, 0, 0, 0, -1, 0, -1, 0, -1, 1, 1, 0, 0, 1, 0, 1, -1, -1, 0, 1,
            -1, -1, 0, 0, 0, 0, -1, 1, 0, 0, -1, -1, 0, 0, 1, 0, -1, 0, 0, 0, 0, 0, 0, 1, -1, 1, 0,
            0, 0, 1, 1, 1, 0, 0, -1, 0, 0, -1, 0, 0, 0, 1, -1, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0,
        ]);
        let g: R3 = R3::from([
            -1, 1, -1, 0, 0, -1, 0, -1, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 1, 0, 0, 1, 1, 0, 0, -1,
            -1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 1,
            -1, 0, -1, -1, 0, 1, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, -1, 0, 0, 1, 1, 0, 0, -1, -1,
            0, -1, 0, 0, 0, -1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, -1, 0, 0,
            0, -1, 1, 1, -1, 0, -1, -1, 0, 1, 0, 0, -1, -1, 1, 1, 0, -1, 0, 0, -1, 1, 0, -1, 0, 1,
            0, 0, 0, 0, 0, 1, 0, 0, 0, 0, -1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, -1, 0, 1, 0, 0, 0,
            1, 0, 1, 1, -1, 0, 1, 0, -1, 1, 0, 0, 0, 1, 1, 0, 1, -1, 1, 0, 1, -1, 0, 0, 0, -1, 1,
            0, 1, 1, -1, 0, 0, 1, 0, 0, -1, -1, 1, 1, 1, 0, 1, 0, 0, 0, 0, 0, 0, 1, 0, -1, 0, 0,
            -1, 1, 0, -1, 0, 0, 1, 0, 0, 0, 0, 0, -1, 0, 0, 1, 0, 1, 0, 1, -1, 0, 0, 0, 1, 0, 0, 1,
            -1, 1, -1, 0, 0, -1, 0, 1, 0, 1, 0, 0, 0, 0, 0, 0, -1, 0, 0, 1, 1, 1, 1, 0, 0, -1, 1,
            0, 0, 0, 0, 0, 0, 0, 1, -1, 0, 0, 0, 0, 1, 0, 0, 1, -1, 0, -1, 0, 0, 0, 0, 0, 1, 0, -1,
            1, 0, -1, 0, 0, 0, 0, 0, -1, 1, 0, 0, 0, 0, -1, -1, 0, 1, 1, 1, -1, 0, 0, 0, -1, -1, 1,
            0, -1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 1, 0, 0, -1, 0, 0, -1, 0, 1, 1,
            0, -1, -1, 0, 0, 1, 0, 1, -1, -1, 0, 1, 0, 0, 0, 1, 0, 0, -1, -1, -1, 0, -1, 1, -1, 0,
            0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 1, -1, 0, 1, 0, 0, 0, 1, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0,
            0, -1, 1, 0, 0, -1, 0, 0, 0, -1, 0, -1, 0, -1, 0, 0, 0, 0, 0, 0, -1, 0, 0, 0, -1, 0, 0,
            -1, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 1, -1, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0,
            1, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 1, -1, 0, 0, 1, -1, 0, 0, 1, -1, 0, 0, 0, 0,
            0, 1, 0, 0, 0, 0, 1, 1, 0, -1, 1, 0, 0, 0, 1, 1, 1, -1, -1, -1, 0, 0, 0, 1, 0, 1, -1,
            0, 0, -1, 1, -1, 0, 1, 0, 1, 0, 0, 0, -1, -1, 0, 0, 0, 0, 0, 0, 0, -1, 0, 0, 1, 0, 1,
            0, 0, 1, 0, 0, -1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, -1, 1, 0, 0, 0, 1, 0, 1,
            -1, 0, 1, 0, 0, 0, 0, 1, -1, 0, 0, 1, 0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0,
            1, 0, 0, 0, 0, -1, 0, 1, 0, 0, -1, -1, 0, 0, 1, -1, 1, -1, -1, 1, 0, 1, -1, -1, 0, 0,
            0, 1, -1, -1, 1, 0, 1, -1, 1, 0, 0, 0, 0, -1, 0, 0, 0, -1, 1, 1, 0, 0, 0, 1, 1, 0, 0,
            0, 0, 0, 0, 0, 0, 0, 0, -1, -1, 1, 0, -1, 0, 0, 0, 1, -1, 0, -1, 0, -1, 0, 0, -1, 0, 0,
            1, -1, 0, 0, 1, 0, 0, 0, 1, -1, 0, -1, 0, -1, 0, 0, 0, 0, 0, -1, 0, 0, 0, 0, 0, -1, 0,
            0, 0, 1, 1, 1, -1, -1, -1, 0, 0, 0, 0, 1, -1, 1, 0, 0, 0, 0, 0, 1, 0, -1, 0, 1, -1, 0,
            0, 0, 0, 1, 1, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, -1, 1,
        ]);
        let h = f.mult(&g);
        assert_eq!(
            h.coeffs,
            [
                -1, 1, 1, 0, 0, 1, -1, 1, 0, 1, 0, 1, 0, 1, 1, -1, 0, 0, 0, 1, 0, 1, -1, 0, -1, -1,
                0, 0, 0, 0, -1, -1, 0, 1, 0, 1, -1, -1, 1, 0, -1, -1, 1, 0, 0, -1, 1, 1, 1, -1, 1,
                1, 0, 1, -1, -1, 0, 1, 1, -1, -1, -1, 0, -1, 0, -1, 1, 1, -1, 0, 0, 0, -1, 0, 0,
                -1, -1, 0, -1, 1, 1, 1, -1, 0, -1, -1, -1, 1, -1, 0, -1, 0, 1, 1, -1, 0, -1, 0, 0,
                0, -1, 0, -1, -1, -1, -1, 0, -1, -1, 1, 0, -1, 0, 1, 1, 0, 0, 1, 0, 0, -1, 0, 1,
                -1, -1, -1, 0, -1, 1, -1, 0, 1, 1, 1, 0, -1, 1, -1, -1, 0, -1, 1, 1, 1, 1, -1, 1,
                -1, 1, 0, 1, 1, 1, -1, 1, 1, 0, -1, 1, -1, 0, 1, -1, -1, 0, 0, 1, -1, -1, -1, 1, 0,
                0, -1, -1, 0, 0, 0, 0, -1, -1, 0, 1, -1, -1, 0, 1, 1, 0, 1, 1, -1, 0, 0, 1, 1, -1,
                0, 0, 1, 0, 0, 1, -1, -1, 1, -1, -1, -1, 1, -1, -1, 1, 1, -1, -1, -1, 1, 0, 1, 0,
                0, 1, 1, 1, 0, 1, 0, 0, -1, 0, -1, 1, 1, -1, -1, 0, 0, 0, -1, 1, -1, 1, 0, 0, -1,
                0, 0, 0, -1, -1, 0, 0, -1, -1, -1, -1, -1, -1, -1, 1, 0, 0, 0, -1, 0, -1, 0, 1, 1,
                0, -1, -1, 0, 1, 1, 0, 0, 1, -1, 0, 0, -1, 0, 0, -1, 1, 1, -1, -1, 0, -1, 1, 0, 0,
                0, 1, 0, -1, 1, -1, 1, -1, 0, 0, 1, 0, -1, 1, -1, -1, -1, -1, -1, -1, 1, -1, -1,
                -1, -1, 0, 1, 1, 0, 1, 1, 0, 1, 0, 0, 1, 1, -1, -1, 1, -1, 1, -1, 0, 0, 1, 1, 1, 1,
                0, 0, 1, 0, -1, -1, -1, -1, 0, -1, 1, -1, -1, 0, -1, 0, 0, 0, -1, -1, 0, -1, 0, -1,
                0, 0, -1, 1, 1, 1, -1, -1, 0, 0, 0, -1, -1, 0, 0, 1, 0, -1, 1, -1, -1, 1, 0, 0, 1,
                0, 0, 1, 0, 1, 0, -1, 0, 0, -1, 0, 1, 0, 1, 0, -1, -1, 0, 1, 1, 1, 0, 1, -1, -1,
                -1, 1, 0, 1, -1, 1, 0, 0, 0, 1, 0, -1, -1, -1, 0, 0, 1, 1, -1, 0, 0, 1, 1, 1, 1,
                -1, 0, -1, -1, -1, 0, 1, 0, 1, -1, 0, -1, 0, -1, 1, -1, 0, -1, 0, -1, 1, 0, 0, 1,
                -1, 1, -1, 0, 0, -1, 0, -1, 1, 0, -1, -1, 0, 0, -1, 0, 0, 1, -1, 1, 0, 1, -1, 0, 0,
                1, 1, 0, 0, -1, 1, -1, 0, -1, 0, 1, 1, 0, 0, 1, 0, -1, -1, 1, 1, 0, 0, 1, 1, 1, 1,
                -1, 1, 1, -1, -1, -1, 1, 1, 1, 1, 1, 1, 1, -1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, -1,
                -1, -1, 1, 1, 0, -1, -1, 1, 1, -1, 0, 1, -1, 1, 0, 0, 0, 1, 1, -1, 0, 1, 1, 1, 1,
                1, 1, -1, 1, 0, 1, 0, -1, 1, -1, 1, -1, 1, -1, 1, 0, 0, -1, 0, -1, 1, 1, -1, 1, -1,
                0, 1, 0, -1, 1, 0, 0, -1, 1, 1, 0, 1, -1, 0, 1, -1, 1, -1, 1, 1, -1, 0, 1, -1, -1,
                1, 0, -1, 0, 1, 0, 0, 0, -1, -1, 0, 0, 0, 1, 1, 1, 1, -1, 1, 1, 1, -1, 1, -1, 1, 1,
                0, -1, -1, 0, -1, -1, 0, 0, 0, 0, -1, 0, -1, 1, 0, -1, 0, 0, -1, -1, -1, 1, -1, 1,
                -1, -1, 0, -1, 0, 1, 0, -1, 1, -1, 1, 0, 0, -1, 0, -1, -1, 1, 1, 0, 0, -1, -1, 0,
                0, 0, 1, -1, 0, -1, -1, -1, 0, -1, -1, -1, 1, 1, 0, 0, 0, 0, -1, -1, 1, 0, 1, 0,
                -1, -1, 0, 0, 1, 0, 1, 0, 0, 0, -1, -1, 0, 1, 0, 0, -1, 1, 1, 0, 0, -1, 0, 0, 1,
                -1, 0, -1, 0, 0, -1, 1, -1, -1, -1, -1, -1, 1, 1, 1, 1, 0, 1, -1, 1
            ]
        );
    }

    /// Round-trip property test for `R3::recip`. For random polynomials from
    /// `random_small` that have an inverse, `recip(x) * x` must equal 1.
    /// Samples for which `recip` returns `Err` are skipped.
    ///
    /// NOTE(review): only two random samples are drawn, and
    /// `random_small`'s distribution is defined in `crate::rng` (presumably
    /// short ternary vectors — confirm). Deterministic cases such as the
    /// constant polynomial 1 would give more reliable coverage.
    #[test]
    fn test_recip() {
        let mut rng = rand::rng();
        for _ in 0..2 {
            let r3: R3 = R3::from(random_small(&mut rng));
            let out = match r3.recip() {
                Ok(o) => o,
                // Not invertible: nothing to check for this sample.
                Err(_) => continue,
            };
            let one = out.mult(&r3);
            assert_eq!(one.coeffs[0], 1);
            assert!(one.eq_one());
        }
    }
}