use core::fmt::Debug;
use core::ops::{Add, Mul, Neg, Sub};
use zkboo::backend::{BooleanWordRef, Frontend};
use zkboo::{
backend::{Backend, WordRef},
word::{CompositeWord, Word, WordLike},
};
pub trait MontgomeryMod<W: Word, const N: usize>: Clone + Copy + Debug + PartialEq + Eq {
fn n(&self) -> CompositeWord<W, N>;
fn inv_exp(&self) -> Option<CompositeWord<W, N>>;
fn rr_mod_n(&self) -> CompositeWord<W, N>;
fn n_neg_inv(&self) -> CompositeWord<W, N>;
#[inline]
fn const_word<U: WordLike<W, N>>(&self, value: U) -> MontgomeryWord<W, N, Self> {
return MontgomeryWord::new(value, *self);
}
#[inline]
fn zero_word(&self) -> MontgomeryWord<W, N, Self> {
return MontgomeryWord::from_inner(CompositeWord::<W, N>::ZERO, *self);
}
#[inline]
fn one_word(&self) -> MontgomeryWord<W, N, Self> {
return MontgomeryWord::new(CompositeWord::<W, N>::ONE, *self);
}
fn redc<B: Backend>(&self, lo: WordRef<B, W, N>) -> WordRef<B, W, N> {
let n = self.n();
let (m, _) = lo.clone().wide_mul_const(self.n_neg_inv());
let (mn_lo, mn_hi) = m.wide_mul_const(n);
let (_, t_lo_carry) = lo.overflowing_add(mn_lo);
let (mut t, t_hi_lo_carry) = mn_hi.overflowing_add(WordRef::from_bool(t_lo_carry));
t = t_hi_lo_carry.select(t.clone() + n.wrapping_neg(), t);
t = t.clone().ge_const(n).select(t.clone() - n, t);
return t;
}
fn redc_wide<B: Backend>(
&self,
lo: WordRef<B, W, N>,
hi: WordRef<B, W, N>,
) -> WordRef<B, W, N> {
let n = self.n();
let (m, _) = lo.clone().wide_mul_const(self.n_neg_inv());
let (mn_lo, mn_hi) = m.wide_mul_const(n);
let (_, t_lo_carry) = lo.overflowing_add(mn_lo);
let (t, t_hi_carry) = hi.overflowing_add(mn_hi);
let (mut t, t_hi_lo_carry) = t.overflowing_add(WordRef::from_bool(t_lo_carry));
t = t_hi_carry.select(t.clone() + n.wrapping_neg(), t);
t = t_hi_lo_carry.select(t.clone() + n.wrapping_neg(), t);
t = t.clone().ge_const(n).select(t.clone() - n, t);
return t;
}
fn to_montgomery<B: Backend>(&self, value: WordRef<B, W, N>) -> WordRef<B, W, N> {
let (lo, hi) = value.wide_mul_const(self.rr_mod_n());
return self.redc_wide(lo, hi);
}
fn from_montgomery<B: Backend>(&self, value: WordRef<B, W, N>) -> WordRef<B, W, N> {
return self.redc(value);
}
fn reduce<B: Backend>(&self, value: WordRef<B, W, N>) -> WordRef<B, W, N> {
return self.from_montgomery(self.to_montgomery(value));
}
fn redc_const(&self, lo: CompositeWord<W, N>) -> CompositeWord<W, N> {
let n = self.n();
let (m, _) = lo.wide_mul(self.n_neg_inv());
let (mn_lo, mn_hi) = m.wide_mul(n);
let (_, t_lo_carry) = lo.overflowing_add(mn_lo);
let (mut t, t_hi_lo_carry) = mn_hi.overflowing_add(CompositeWord::from_bool(t_lo_carry));
if t_hi_lo_carry {
t = t.wrapping_add(n.wrapping_neg());
}
if t.ge(n) {
t = t.wrapping_sub(n);
}
return t;
}
fn redc_wide_const(
&self,
lo: CompositeWord<W, N>,
hi: CompositeWord<W, N>,
) -> CompositeWord<W, N> {
let n = self.n();
let (m, _) = lo.clone().wide_mul(self.n_neg_inv());
let (mn_lo, mn_hi) = m.wide_mul(n);
let (_, t_lo_carry) = lo.overflowing_add(mn_lo);
let (t, t_hi_carry) = hi.overflowing_add(mn_hi);
let (mut t, t_hi_lo_carry) = t.overflowing_add(CompositeWord::from_bool(t_lo_carry));
if t_hi_carry {
t = t.wrapping_add(n.wrapping_neg());
}
if t_hi_lo_carry {
t = t.wrapping_add(n.wrapping_neg());
}
if t.ge(n) {
t = t.wrapping_sub(n);
}
return t;
}
fn to_montgomery_const(&self, value: CompositeWord<W, N>) -> CompositeWord<W, N> {
let (lo, hi) = value.wide_mul(self.rr_mod_n());
return self.redc_wide_const(lo, hi);
}
fn from_montgomery_const(&self, value: CompositeWord<W, N>) -> CompositeWord<W, N> {
return self.redc_const(value);
}
fn reduce_const(&self, value: CompositeWord<W, N>) -> CompositeWord<W, N> {
return self.from_montgomery_const(self.to_montgomery_const(value));
}
}
/// A constant (non-circuit) value stored in Montgomery form together with the
/// modulus description `M` it belongs to.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct MontgomeryWord<W: Word, const N: usize, M: MontgomeryMod<W, N>> {
/// The value in Montgomery representation (not the plain value).
montgomery_val: CompositeWord<W, N>,
/// The modulus description used for all conversions and arithmetic.
modulus: M,
}
impl<W: Word, const N: usize, M: MontgomeryMod<W, N>> MontgomeryWord<W, N, M> {
    /// Creates a word from a plain `value`, converting it to Montgomery form
    /// under `modulus`.
    #[inline]
    pub fn new<U: WordLike<W, N>>(value: U, modulus: M) -> Self {
        Self {
            montgomery_val: modulus.to_montgomery_const(value.to_word()),
            modulus,
        }
    }

    /// Wraps a value that is already in Montgomery form; no conversion or
    /// range check is performed.
    pub fn from_inner(montgomery_val: CompositeWord<W, N>, modulus: M) -> Self {
        Self {
            montgomery_val,
            modulus,
        }
    }

    /// Returns a copy of the modulus description.
    pub fn modulus(&self) -> M {
        self.modulus
    }

    /// Consumes the word and returns its modulus description.
    pub fn into_modulus(self) -> M {
        self.modulus
    }

    /// Returns the plain value, converting out of Montgomery form.
    pub fn value(self) -> CompositeWord<W, N> {
        self.modulus.from_montgomery_const(self.montgomery_val)
    }

    /// Borrows the stored Montgomery representation.
    pub fn inner(&self) -> &CompositeWord<W, N> {
        &self.montgomery_val
    }

    /// Returns the stored Montgomery representation.
    pub fn into_inner(self) -> CompositeWord<W, N> {
        self.montgomery_val
    }

    /// Splits the word into `(montgomery representation, modulus)`.
    pub fn destructure(self) -> (CompositeWord<W, N>, M) {
        (self.montgomery_val, self.modulus)
    }

    /// Whether the stored Montgomery representation is zero.
    pub fn is_zero(self) -> bool {
        self.montgomery_val.is_zero()
    }

    /// Whether the stored Montgomery representation is non-zero.
    pub fn is_nonzero(self) -> bool {
        self.montgomery_val.is_nonzero()
    }

    /// Computes the inverse by raising `self` to the modulus' inverse exponent.
    ///
    /// # Panics
    ///
    /// Panics (via `unimplemented!`) when the modulus provides no inverse
    /// exponent.
    pub fn inv(self) -> Self {
        match self.modulus.inv_exp() {
            Some(e) => self.inv_by_rep_squaring(e),
            None => unimplemented!("Cannot yet compute inverse without explicit inverse exponent."),
        }
    }

    /// Raises `self` to the power `e` with right-to-left square-and-multiply.
    fn inv_by_rep_squaring(self, mut e: CompositeWord<W, N>) -> Self {
        let modulus = self.modulus;
        let mut res = Self {
            montgomery_val: modulus.to_montgomery_const(CompositeWord::<W, N>::ONE),
            modulus,
        };
        let mut base = self;
        while e.is_nonzero() {
            if e.lsb() {
                res = res * base;
            }
            e = e >> 1;
            // Square only while exponent bits remain: the square computed
            // after the last bit would never be used.
            if e.is_nonzero() {
                base = base * base;
            }
        }
        res
    }
}
impl<W: Word, const N: usize, M: MontgomeryMod<W, N>> Neg for MontgomeryWord<W, N, M> {
    type Output = Self;

    /// Modular negation: zero stays zero, any other value becomes `n - value`.
    fn neg(self) -> Self::Output {
        let Self {
            montgomery_val,
            modulus,
        } = self;
        if montgomery_val != CompositeWord::<W, N>::ZERO {
            Self {
                montgomery_val: modulus.n().wrapping_sub(montgomery_val),
                modulus,
            }
        } else {
            self
        }
    }
}
impl<W: Word, const N: usize, M: MontgomeryMod<W, N>> Add<Self> for MontgomeryWord<W, N, M> {
    type Output = Self;

    /// Modular addition of two words sharing the same modulus.
    ///
    /// # Panics
    ///
    /// Panics if the operands carry different moduli.
    fn add(self, rhs: Self) -> Self::Output {
        assert!(
            self.modulus == rhs.modulus,
            "Cannot add modular words with different moduli"
        );
        let modulus = self.modulus;
        let (sum, overflowed) = self.montgomery_val.overflowing_add(rhs.montgomery_val);
        // Either a carry out of the word or a sum of at least `n` calls for
        // one subtraction of `n`.
        let montgomery_val = if overflowed || sum.ge(modulus.n()) {
            sum.wrapping_sub(modulus.n())
        } else {
            sum
        };
        Self {
            montgomery_val,
            modulus,
        }
    }
}
impl<W: Word, const N: usize, M: MontgomeryMod<W, N>> Sub<Self> for MontgomeryWord<W, N, M> {
    type Output = Self;

    /// Modular subtraction of two words sharing the same modulus.
    ///
    /// # Panics
    ///
    /// Panics if the operands carry different moduli.
    fn sub(self, rhs: Self) -> Self::Output {
        assert!(
            self.modulus == rhs.modulus,
            "Cannot subtract modular words with different moduli"
        );
        let modulus = self.modulus;
        let (diff, borrowed) = self.montgomery_val.overflowing_sub(rhs.montgomery_val);
        // A borrow means the difference went negative; adding `n` wraps it
        // back into range.
        let montgomery_val = if borrowed {
            diff.wrapping_add(modulus.n())
        } else {
            diff
        };
        Self {
            montgomery_val,
            modulus,
        }
    }
}
impl<W: Word, const N: usize, M: MontgomeryMod<W, N>> Mul<Self> for MontgomeryWord<W, N, M> {
    type Output = Self;

    /// Montgomery multiplication: double-width product of the two
    /// representations followed by one wide reduction.
    ///
    /// # Panics
    ///
    /// Panics if the operands carry different moduli.
    fn mul(self, rhs: Self) -> Self::Output {
        assert!(
            self.modulus == rhs.modulus,
            "Cannot multiply modular words with different moduli"
        );
        let modulus = self.modulus;
        let (product_lo, product_hi) = self.montgomery_val.wide_mul(rhs.montgomery_val);
        Self {
            montgomery_val: modulus.redc_wide_const(product_lo, product_hi),
            modulus,
        }
    }
}
/// A backend word reference holding a value in Montgomery form, paired with
/// the modulus description `M` it belongs to.
#[derive(Debug)]
pub struct MontgomeryWordRef<B: Backend, W: Word, const N: usize, M: MontgomeryMod<W, N>> {
/// Backend reference to the Montgomery representation (not the plain value).
montgomery_val: WordRef<B, W, N>,
/// The modulus description used for all conversions and arithmetic.
modulus: M,
}
impl<B: Backend, M: MontgomeryMod<W, N>, W: Word, const N: usize> Clone
    for MontgomeryWordRef<B, W, N, M>
{
    /// Clones the underlying word reference and copies the modulus.
    ///
    /// Implemented by hand so that no `Clone` bound is placed on `B`.
    fn clone(&self) -> Self {
        let Self {
            montgomery_val,
            modulus,
        } = self;
        Self {
            montgomery_val: montgomery_val.clone(),
            modulus: *modulus,
        }
    }
}
impl<B: Backend, M: MontgomeryMod<W, N>, W: Word, const N: usize> MontgomeryWordRef<B, W, N, M> {
    /// Creates a reference from a plain `value`, converting it to Montgomery
    /// form through the backend.
    pub fn new(value: WordRef<B, W, N>, modulus: M) -> Self {
        Self {
            montgomery_val: modulus.to_montgomery(value),
            modulus,
        }
    }

    /// Wraps a word reference that already holds a Montgomery representation;
    /// no conversion is performed.
    pub fn from_inner(montgomery_val: WordRef<B, W, N>, modulus: M) -> Self {
        Self {
            montgomery_val,
            modulus,
        }
    }

    /// Returns a copy of the modulus description.
    pub fn modulus(&self) -> M {
        self.modulus
    }

    /// Consumes the reference and returns its modulus description.
    pub fn into_modulus(self) -> M {
        self.modulus
    }

    /// Returns the plain value, converting out of Montgomery form through the
    /// backend.
    pub fn value(self) -> WordRef<B, W, N> {
        self.modulus.from_montgomery(self.montgomery_val)
    }

    /// Borrows the underlying Montgomery-form word reference.
    pub fn inner(&self) -> &WordRef<B, W, N> {
        &self.montgomery_val
    }

    /// Returns the underlying Montgomery-form word reference.
    pub fn into_inner(self) -> WordRef<B, W, N> {
        self.montgomery_val
    }

    /// Splits the reference into `(montgomery representation, modulus)`.
    pub fn destructure(self) -> (WordRef<B, W, N>, M) {
        (self.montgomery_val, self.modulus)
    }

    /// Tests the Montgomery representation for zero.
    pub fn is_zero(self) -> BooleanWordRef<B> {
        self.montgomery_val.is_zero()
    }

    /// Tests the Montgomery representation for non-zero.
    pub fn is_nonzero(self) -> BooleanWordRef<B> {
        self.montgomery_val.is_nonzero()
    }

    /// Equality test on the Montgomery representations.
    ///
    /// # Panics
    ///
    /// Panics if the operands carry different moduli.
    pub fn eq(self, other: Self) -> BooleanWordRef<B> {
        if self.modulus != other.modulus {
            panic!("Cannot compare modular words with different moduli");
        }
        self.montgomery_val.eq(other.montgomery_val)
    }

    /// Inequality test on the Montgomery representations.
    ///
    /// # Panics
    ///
    /// Panics if the operands carry different moduli.
    pub fn ne(self, other: Self) -> BooleanWordRef<B> {
        if self.modulus != other.modulus {
            panic!("Cannot compare modular words with different moduli");
        }
        self.montgomery_val.ne(other.montgomery_val)
    }

    /// Equality test against a constant Montgomery word.
    ///
    /// # Panics
    ///
    /// Panics if the operands carry different moduli: representations under
    /// different moduli are not comparable, and `eq` and the mixed arithmetic
    /// operators reject this case in the same way.
    pub fn eq_const(self, other: MontgomeryWord<W, N, M>) -> BooleanWordRef<B> {
        if self.modulus != other.modulus {
            panic!("Cannot compare modular words with different moduli");
        }
        self.montgomery_val.eq_const(other.into_inner())
    }

    /// Inequality test against a constant Montgomery word.
    ///
    /// # Panics
    ///
    /// Panics if the operands carry different moduli (see `eq_const`).
    pub fn ne_const(self, other: MontgomeryWord<W, N, M>) -> BooleanWordRef<B> {
        if self.modulus != other.modulus {
            panic!("Cannot compare modular words with different moduli");
        }
        self.montgomery_val.ne_const(other.into_inner())
    }

    /// Adds the plain constant `rhs`, which is first converted to Montgomery
    /// form under this reference's modulus.
    pub fn add_const(self, rhs: CompositeWord<W, N>) -> Self {
        let (res, carry) = self
            .montgomery_val
            .overflowing_add_const(self.modulus.to_montgomery_const(rhs));
        Self {
            // A carry out of the word or a sum of at least `n` calls for one
            // subtraction of `n`.
            montgomery_val: (carry | res.clone().ge_const(self.modulus.n()))
                .select(res.clone() - self.modulus.n(), res),
            modulus: self.modulus,
        }
    }

    /// Multiplies by the plain constant `rhs`, which is first converted to
    /// Montgomery form under this reference's modulus.
    pub fn mul_const(self, rhs: CompositeWord<W, N>) -> Self {
        let (res_lo, res_hi) = self
            .montgomery_val
            .wide_mul_const(self.modulus.to_montgomery_const(rhs));
        Self {
            montgomery_val: self.modulus.redc_wide(res_lo, res_hi),
            modulus: self.modulus,
        }
    }

    /// Replaces the held word with zero, keeping the modulus.
    pub fn into_zero(self) -> Self {
        let Self {
            montgomery_val,
            modulus,
        } = self;
        Self {
            montgomery_val: montgomery_val.into_zero(),
            modulus,
        }
    }

    /// Replaces the held word with the plain constant `word`, converted to
    /// Montgomery form under this reference's modulus.
    pub fn into_const(self, word: CompositeWord<W, N>) -> Self {
        let Self {
            montgomery_val,
            modulus,
        } = self;
        let word = modulus.to_montgomery_const(word);
        Self {
            montgomery_val: montgomery_val.into_const_same_width(word),
            modulus,
        }
    }

    /// Replaces the held word with the Montgomery representation of `word`.
    ///
    /// NOTE(review): the modulus of `word` is not compared with this
    /// reference's modulus; callers are expected to pass a matching constant.
    pub fn into_montgomery_const(self, word: MontgomeryWord<W, N, M>) -> Self {
        let Self {
            montgomery_val,
            modulus,
        } = self;
        Self {
            montgomery_val: montgomery_val.into_const_same_width(word.into_inner()),
            modulus,
        }
    }

    /// Computes the inverse by raising `self` to the modulus' inverse exponent.
    ///
    /// # Panics
    ///
    /// Panics (via `unimplemented!`) when the modulus provides no inverse
    /// exponent.
    pub fn inv(self) -> Self {
        match self.modulus.inv_exp() {
            Some(e) => self.inv_by_rep_squaring(e),
            None => unimplemented!("Cannot yet compute inverse without explicit inverse exponent."),
        }
    }

    /// Raises `self` to the constant power `e` with right-to-left
    /// square-and-multiply.
    fn inv_by_rep_squaring(self, mut e: CompositeWord<W, N>) -> Self {
        let mut res = self.clone().into_const(CompositeWord::<W, N>::ONE);
        let mut base = self;
        while e.is_nonzero() {
            if e.lsb() {
                res = res * base.clone();
            }
            e = e >> 1;
            // Square only while exponent bits remain: the square computed
            // after the last bit would never be used, yet it would still cost
            // a full multiplication and reduction in the backend.
            if e.is_nonzero() {
                base = base.clone() * base;
            }
        }
        res
    }
}
impl<B: Backend, W: Word, const N: usize, M: MontgomeryMod<W, N>> Neg
    for MontgomeryWordRef<B, W, N, M>
{
    type Output = Self;

    /// Modular negation: selects the value itself when it is zero and
    /// `-value + n` otherwise.
    fn neg(self) -> Self::Output {
        let Self {
            montgomery_val,
            modulus,
        } = self;
        let value_is_zero = montgomery_val.clone().is_zero();
        let unchanged = montgomery_val.clone();
        let negated = -montgomery_val + modulus.n();
        Self {
            montgomery_val: value_is_zero.select(unchanged, negated),
            modulus,
        }
    }
}
impl<B: Backend, W: Word, const N: usize, M: MontgomeryMod<W, N>> Add<Self>
    for MontgomeryWordRef<B, W, N, M>
{
    type Output = Self;

    /// Modular addition of two references sharing the same modulus.
    ///
    /// # Panics
    ///
    /// Panics if the operands carry different moduli.
    fn add(self, rhs: Self) -> Self::Output {
        assert!(
            self.modulus == rhs.modulus,
            "Cannot add modular words with different moduli"
        );
        let modulus = self.modulus;
        let (sum, carry) = self.montgomery_val.overflowing_add(rhs.montgomery_val);
        // Subtract `n` once when the addition carried or the sum reached `n`.
        let needs_reduction = carry | sum.clone().ge_const(modulus.n());
        let reduced = sum.clone() - modulus.n();
        Self {
            montgomery_val: needs_reduction.select(reduced, sum),
            modulus,
        }
    }
}
impl<B: Backend, W: Word, const N: usize, M: MontgomeryMod<W, N>> Add<MontgomeryWord<W, N, M>>
    for MontgomeryWordRef<B, W, N, M>
{
    type Output = Self;

    /// Modular addition of a constant Montgomery word to this reference.
    ///
    /// # Panics
    ///
    /// Panics if the operands carry different moduli.
    fn add(self, rhs: MontgomeryWord<W, N, M>) -> Self::Output {
        assert!(
            self.modulus == rhs.modulus,
            "Cannot add modular words with different moduli"
        );
        let modulus = self.modulus;
        let (sum, carry) = self
            .montgomery_val
            .overflowing_add_const(rhs.montgomery_val);
        // Subtract `n` once when the addition carried or the sum reached `n`.
        let needs_reduction = carry | sum.clone().ge_const(modulus.n());
        let reduced = sum.clone() - modulus.n();
        Self {
            montgomery_val: needs_reduction.select(reduced, sum),
            modulus,
        }
    }
}
impl<B: Backend, W: Word, const N: usize, M: MontgomeryMod<W, N>> Sub<Self>
    for MontgomeryWordRef<B, W, N, M>
{
    type Output = Self;

    /// Modular subtraction of two references sharing the same modulus.
    ///
    /// # Panics
    ///
    /// Panics if the operands carry different moduli.
    fn sub(self, rhs: Self) -> Self::Output {
        assert!(
            self.modulus == rhs.modulus,
            "Cannot subtract modular words with different moduli"
        );
        let modulus = self.modulus;
        let (diff, borrow) = self.montgomery_val.overflowing_sub(rhs.montgomery_val);
        // On borrow the difference wrapped around; adding `n` brings it back
        // into range.
        let wrapped_back = diff.clone() + modulus.n();
        Self {
            montgomery_val: borrow.select(wrapped_back, diff),
            modulus,
        }
    }
}
impl<B: Backend, W: Word, const N: usize, M: MontgomeryMod<W, N>> Sub<MontgomeryWord<W, N, M>>
    for MontgomeryWordRef<B, W, N, M>
{
    type Output = Self;

    /// Modular subtraction of a constant Montgomery word from this reference.
    ///
    /// # Panics
    ///
    /// Panics if the operands carry different moduli.
    fn sub(self, rhs: MontgomeryWord<W, N, M>) -> Self::Output {
        assert!(
            self.modulus == rhs.modulus,
            "Cannot subtract modular words with different moduli"
        );
        let modulus = self.modulus;
        let (diff, borrow) = self
            .montgomery_val
            .overflowing_sub_const(rhs.montgomery_val);
        // On borrow the difference wrapped around; adding `n` brings it back
        // into range.
        let wrapped_back = diff.clone() + modulus.n();
        Self {
            montgomery_val: borrow.select(wrapped_back, diff),
            modulus,
        }
    }
}
impl<B: Backend, W: Word, const N: usize, M: MontgomeryMod<W, N>> Mul<Self>
    for MontgomeryWordRef<B, W, N, M>
{
    type Output = Self;

    /// Montgomery multiplication of two references: double-width product
    /// followed by one wide reduction.
    ///
    /// # Panics
    ///
    /// Panics if the operands carry different moduli.
    fn mul(self, rhs: Self) -> Self::Output {
        assert!(
            self.modulus == rhs.modulus,
            "Cannot multiply modular words with different moduli"
        );
        let modulus = self.modulus;
        let (product_lo, product_hi) = self.montgomery_val.wide_mul(rhs.montgomery_val);
        Self {
            montgomery_val: modulus.redc_wide(product_lo, product_hi),
            modulus,
        }
    }
}
impl<B: Backend, W: Word, const N: usize, M: MontgomeryMod<W, N>> Mul<MontgomeryWord<W, N, M>>
    for MontgomeryWordRef<B, W, N, M>
{
    type Output = Self;

    /// Montgomery multiplication by a constant Montgomery word: double-width
    /// product with the constant followed by one wide reduction.
    ///
    /// # Panics
    ///
    /// Panics if the operands carry different moduli.
    fn mul(self, rhs: MontgomeryWord<W, N, M>) -> Self::Output {
        assert!(
            self.modulus == rhs.modulus,
            "Cannot multiply modular words with different moduli"
        );
        let modulus = self.modulus;
        let (product_lo, product_hi) = self.montgomery_val.wide_mul_const(rhs.montgomery_val);
        Self {
            montgomery_val: modulus.redc_wide(product_lo, product_hi),
            modulus,
        }
    }
}
/// Conditional selection between Montgomery words, driven by a boolean
/// selector value.
///
/// The four methods cover every combination of constant
/// ([`MontgomeryWord`]) and backend-held ([`MontgomeryWordRef`]) branches;
/// each returns a [`MontgomeryWordRef`].
pub trait MontgomeryBooleanWordRefSelector<
B: Backend,
W: Word,
const N: usize,
M: MontgomeryMod<W, N>,
>
{
/// Selects between two constant words.
fn montgomery_select_const_const(
self,
then: MontgomeryWord<W, N, M>,
else_: MontgomeryWord<W, N, M>,
) -> MontgomeryWordRef<B, W, N, M>;
/// Selects between a constant `then` branch and a backend-held `else_` branch.
fn montgomery_select_const_var(
self,
then: MontgomeryWord<W, N, M>,
else_: MontgomeryWordRef<B, W, N, M>,
) -> MontgomeryWordRef<B, W, N, M>;
/// Selects between a backend-held `then` branch and a constant `else_` branch.
fn montgomery_select_var_const(
self,
then: MontgomeryWordRef<B, W, N, M>,
else_: MontgomeryWord<W, N, M>,
) -> MontgomeryWordRef<B, W, N, M>;
/// Selects between two backend-held words.
fn montgomery_select(
self,
then: MontgomeryWordRef<B, W, N, M>,
else_: MontgomeryWordRef<B, W, N, M>,
) -> MontgomeryWordRef<B, W, N, M>;
}
impl<B: Backend, W: Word, const N: usize, M: MontgomeryMod<W, N>>
    MontgomeryBooleanWordRefSelector<B, W, N, M> for BooleanWordRef<B>
{
    /// Selects between two constants by delegating to `select_const_const` on
    /// their Montgomery representations.
    ///
    /// # Panics
    ///
    /// Panics if the branches carry different moduli.
    fn montgomery_select_const_const(
        self,
        then: MontgomeryWord<W, N, M>,
        else_: MontgomeryWord<W, N, M>,
    ) -> MontgomeryWordRef<B, W, N, M> {
        assert!(
            then.modulus == else_.modulus,
            "Cannot select between words with different moduli"
        );
        let modulus = then.modulus;
        let then_val = then.into_inner();
        let else_val = else_.into_inner();
        MontgomeryWordRef {
            montgomery_val: self.select_const_const(then_val, else_val),
            modulus,
        }
    }

    /// Selects between a constant and a backend-held word by delegating to
    /// `select_const_var` on their Montgomery representations.
    ///
    /// # Panics
    ///
    /// Panics if the branches carry different moduli.
    fn montgomery_select_const_var(
        self,
        then: MontgomeryWord<W, N, M>,
        else_: MontgomeryWordRef<B, W, N, M>,
    ) -> MontgomeryWordRef<B, W, N, M> {
        assert!(
            then.modulus == else_.modulus,
            "Cannot select between words with different moduli"
        );
        let modulus = then.modulus;
        let then_val = then.into_inner();
        let else_val = else_.into_inner();
        MontgomeryWordRef {
            montgomery_val: self.select_const_var(then_val, else_val),
            modulus,
        }
    }

    /// Selects between a backend-held word and a constant by delegating to
    /// `select_var_const` on their Montgomery representations.
    ///
    /// # Panics
    ///
    /// Panics if the branches carry different moduli.
    fn montgomery_select_var_const(
        self,
        then: MontgomeryWordRef<B, W, N, M>,
        else_: MontgomeryWord<W, N, M>,
    ) -> MontgomeryWordRef<B, W, N, M> {
        assert!(
            then.modulus == else_.modulus,
            "Cannot select between words with different moduli"
        );
        let modulus = then.modulus;
        let then_val = then.into_inner();
        let else_val = else_.into_inner();
        MontgomeryWordRef {
            montgomery_val: self.select_var_const(then_val, else_val),
            modulus,
        }
    }

    /// Selects between two backend-held words by delegating to `select` on
    /// their Montgomery representations.
    ///
    /// # Panics
    ///
    /// Panics if the branches carry different moduli.
    fn montgomery_select(
        self,
        then: MontgomeryWordRef<B, W, N, M>,
        else_: MontgomeryWordRef<B, W, N, M>,
    ) -> MontgomeryWordRef<B, W, N, M> {
        assert!(
            then.modulus == else_.modulus,
            "Cannot select between words with different moduli"
        );
        let modulus = then.modulus;
        let then_val = then.into_inner();
        let else_val = else_.into_inner();
        MontgomeryWordRef {
            montgomery_val: self.select(then_val, else_val),
            modulus,
        }
    }
}
/// Allocation of constant Montgomery words as [`MontgomeryWordRef`]s, using an
/// existing value as the allocation handle.
pub trait MontgomeryWordRefAllocator<B: Backend, W: Word, const N: usize, M: MontgomeryMod<W, N>> {
/// Allocates `word` as a new constant reference carrying `word`'s modulus.
fn alloc_montgomery_constant(
&self,
word: MontgomeryWord<W, N, M>,
) -> MontgomeryWordRef<B, W, N, M>;
}
impl<_B: Backend, W: Word, const N: usize, M: MontgomeryMod<W, N>, _W: Word, const _N: usize>
    MontgomeryWordRefAllocator<_B, W, N, M> for WordRef<_B, _W, _N>
{
    /// Allocates `word` as a constant through this word reference.
    ///
    /// The Montgomery representation held by `word` is allocated directly and
    /// wrapped without conversion. Converting to the plain value first and
    /// back to Montgomery form through the backend would produce the same
    /// representation at the cost of a full multiplication and reduction.
    fn alloc_montgomery_constant(
        &self,
        word: MontgomeryWord<W, N, M>,
    ) -> MontgomeryWordRef<_B, W, N, M> {
        // Zero keeps its dedicated allocation path.
        if word.is_zero() {
            return MontgomeryWordRef::from_inner(self.alloc_new_zero(), word.modulus());
        }
        MontgomeryWordRef::from_inner(self.alloc_new_word(word.into_inner()), word.modulus())
    }
}
/// Input, allocation and output of Montgomery words.
pub trait MontgomeryFrontendIO<B: Backend, W: Word, const N: usize, M: MontgomeryMod<W, N>> {
/// Registers `in_` as an input and returns a reference to it.
fn montgomery_input(&self, in_: MontgomeryWord<W, N, M>) -> MontgomeryWordRef<B, W, N, M>;
/// Allocates `in_` and returns a reference to it.
fn montgomery_alloc(&self, in_: MontgomeryWord<W, N, M>) -> MontgomeryWordRef<B, W, N, M>;
/// Emits `out` as an output.
fn montgomery_output(&self, out: MontgomeryWordRef<B, W, N, M>);
}
impl<B: Backend, W: Word, const N: usize, M: MontgomeryMod<W, N>> MontgomeryFrontendIO<B, W, N, M>
    for Frontend<B>
{
    /// Feeds the Montgomery representation of `in_` to `Frontend::input` and
    /// wraps the resulting reference without further conversion.
    fn montgomery_input(&self, in_: MontgomeryWord<W, N, M>) -> MontgomeryWordRef<B, W, N, M> {
        let modulus = in_.modulus();
        let word_ref = self.input(in_.into_inner());
        MontgomeryWordRef::from_inner(word_ref, modulus)
    }

    /// Passes the Montgomery representation of `in_` to `Frontend::alloc` and
    /// wraps the resulting reference without further conversion.
    fn montgomery_alloc(&self, in_: MontgomeryWord<W, N, M>) -> MontgomeryWordRef<B, W, N, M> {
        let modulus = in_.modulus();
        let word_ref = self.alloc(in_.into_inner());
        MontgomeryWordRef::from_inner(word_ref, modulus)
    }

    /// Converts `out` back to its plain value and hands that to
    /// `Frontend::output`.
    fn montgomery_output(&self, out: MontgomeryWordRef<B, W, N, M>) {
        let plain = out.value();
        self.output(plain);
    }
}