use core::ops::{Add, Mul};
use crate::constants::SCALAR_LENGTH;
use zeroize::{Zeroize, ZeroizeOnDrop};
/// 32 raw bytes (256 bits); the `le` suffix marks the bytes as little-endian.
pub type U256le = [u8; 32];
/// 64 raw bytes (512 bits); the `le` suffix marks the bytes as little-endian.
pub type U512le = [u8; 64];
/// A scalar held as `SCALAR_LENGTH` raw bytes in little-endian order.
///
/// Constructing the tuple struct directly performs no reduction, so the
/// stored bytes are whatever the caller supplied.
///
/// The `Zeroize` / `ZeroizeOnDrop` derives wipe the bytes when a value is
/// dropped; as a consequence the type is `Clone` but not `Copy`.
#[repr(C)]
#[derive(Clone, Debug, Default, PartialEq, Zeroize, ZeroizeOnDrop)]
pub struct Scalar(pub [u8; SCALAR_LENGTH]);
// Working representation used for arithmetic: `Scalar::unpack` converts the
// byte form into it and `UnpackedScalar::pack` converts back.
type UnpackedScalar = crate::scalar29::Scalar29;
/// Builds a [`Scalar`] by copying a borrowed byte array verbatim.
impl From<&[u8; SCALAR_LENGTH]> for Scalar {
    /// Copies `bytes` into a new scalar; no reduction or validation is applied.
    fn from(bytes: &[u8; SCALAR_LENGTH]) -> Scalar {
        let raw: [u8; SCALAR_LENGTH] = *bytes;
        Scalar(raw)
    }
}
impl UnpackedScalar {
    /// Serializes this unpacked value with `to_bytes` and wraps the result in
    /// a byte-backed [`Scalar`].
    fn pack(&self) -> Scalar {
        let bytes = self.to_bytes();
        Scalar(bytes)
    }
}
/// Addition of two borrowed scalars, producing a freshly packed [`Scalar`].
impl<'a, 'b> Add<&'b Scalar> for &'a Scalar {
    type Output = Scalar;

    /// Adds the unpacked operands, then multiplies the sum by the `R`
    /// constant and runs `montgomery_reduce` on the product before repacking.
    ///
    /// NOTE(review): the extra multiply-and-reduce step presumably brings the
    /// sum back into reduced form even when the inputs were not reduced —
    /// confirm against `scalar29`.
    fn add(self, rhs: &'b Scalar) -> Scalar {
        let lhs_unpacked = self.unpack();
        let rhs_unpacked = rhs.unpack();
        let raw_sum = UnpackedScalar::add(&lhs_unpacked, &rhs_unpacked);
        let scaled_sum =
            UnpackedScalar::mul_internal(&raw_sum, &crate::scalar29::constants::R);
        UnpackedScalar::montgomery_reduce(&scaled_sum).pack()
    }
}
/// Multiplication of two borrowed scalars, producing a freshly packed [`Scalar`].
impl<'a, 'b> Mul<&'b Scalar> for &'a Scalar {
    type Output = Scalar;

    /// Unpacks both operands, multiplies them with `UnpackedScalar::mul`, and
    /// packs the product back into bytes.
    fn mul(self, rhs: &'b Scalar) -> Scalar {
        let lhs_unpacked = self.unpack();
        let rhs_unpacked = rhs.unpack();
        let product = UnpackedScalar::mul(&lhs_unpacked, &rhs_unpacked);
        product.pack()
    }
}
impl Scalar {
#[allow(non_snake_case)]
const L: [u8; 32] = [
0xed, 0xd3, 0xf5, 0x5c, 0x1a, 0x63, 0x12, 0x58, 0xd6, 0x9c, 0xf7, 0xa2, 0xde, 0xf9, 0xde,
0x14, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0x10,
];
pub fn ell() -> U256le {
Scalar::L
}
pub fn from_bytes(bytes: &[u8; SCALAR_LENGTH]) -> Self {
Scalar(*bytes)
}
pub fn as_bytes(&self) -> &[u8; SCALAR_LENGTH] {
&self.0
}
pub fn to_bytes(&self) -> [u8; SCALAR_LENGTH] {
self.0
}
pub(crate) fn bits(&self) -> [i8; 256] {
let mut bits = [0i8; 256];
for (i, bit) in bits.iter_mut().enumerate() {
*bit = ((self.0[i >> 3] >> (i & 7)) & 1u8) as i8;
}
bits
}
pub fn from_u256_le(x: &U256le) -> Scalar {
let s_unreduced = Scalar(*x);
let s = s_unreduced.reduce();
debug_assert_eq!(0u8, s.0[31] >> 7);
s
}
pub fn from_u512_le(x: &U512le) -> Scalar {
UnpackedScalar::from_bytes_wide(x).pack()
}
pub(crate) fn unpack(&self) -> UnpackedScalar {
UnpackedScalar::from_bytes(&self.0)
}
#[allow(non_snake_case)]
pub fn reduce(&self) -> Scalar {
let x = self.unpack();
let xR = UnpackedScalar::mul_internal(&x, &crate::scalar29::constants::R);
let x_mod_l = UnpackedScalar::montgomery_reduce(&xR);
x_mod_l.pack()
}
pub fn is_canonical(&self) -> bool {
*self == self.reduce()
}
pub fn one() -> Self {
Self::from(1u64)
}
}
/// Conversion from a native `u64` into a [`Scalar`].
impl From<u64> for Scalar {
    /// Places the little-endian bytes of `scalar` in the low 8 bytes of a
    /// zeroed 32-byte buffer and hands it to [`Scalar::from_u256_le`].
    fn from(scalar: u64) -> Self {
        let le_bytes = scalar.to_le_bytes();
        let mut buf = [0u8; 32];
        let (low, _high) = buf.split_at_mut(le_bytes.len());
        low.copy_from_slice(&le_bytes);
        Scalar::from_u256_le(&buf)
    }
}
#[cfg(test)]
mod test {
    use super::*;

    /// Repeated addition starting from `Scalar::one()` must agree with
    /// `Scalar::from(5u64)`.
    #[test]
    fn from_unsigned() {
        let one = Scalar::one();
        let two = &one + &one;
        let three = &two + &one;
        let five = &two + &three;
        assert_eq!(five, Scalar::from(5u64));
    }

    /// Running the destructor must overwrite the scalar's bytes with zeros.
    #[test]
    fn zeroize_on_drop() {
        // `ManuallyDrop` suppresses the automatic drop at end of scope, so the
        // destructor triggered below runs exactly once instead of twice.
        let mut one = core::mem::ManuallyDrop::new(Scalar([1u8; SCALAR_LENGTH]));
        assert_ne!(one.0, [0u8; SCALAR_LENGTH]);
        // SAFETY: `one` is initialized and is dropped here exactly once. The
        // only field is a plain byte array, so the storage still holds valid
        // `u8` values after the destructor has run and may be inspected.
        unsafe {
            core::mem::ManuallyDrop::drop(&mut one);
        }
        assert_eq!(one.0, [0u8; SCALAR_LENGTH]);
    }
}