use curve25519_dalek::scalar::Scalar;
use rand_core::{CryptoRng, Rng};
#[cfg(feature = "serde")]
use serde::de::{Error, Visitor};
#[cfg(feature = "serde")]
use serde::{Deserialize, Deserializer, Serialize, Serializer};
#[cfg(feature = "serde")]
use std::fmt::Formatter;
/// Error signalling that a zero scalar was supplied where a non-zero one is required.
///
/// In this file it is produced by the `TryFrom<ScalarCanBeZero>` conversion into
/// `ScalarNonZero`.
#[derive(Debug)]
pub struct ZeroArgumentError;

impl std::fmt::Display for ZeroArgumentError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str("scalar must be non-zero")
    }
}

// Implementing `Error` lets callers propagate this with `?` into `Box<dyn Error>`
// and other error-aggregating types.
impl std::error::Error for ZeroArgumentError {}
/// Wrapper around a [`Scalar`] whose public constructors never yield zero.
///
/// Every constructor below either rejects a zero value (returning `None`),
/// retries (`random`), or substitutes one (`from_hash`).
#[derive(Copy, Clone, Eq, PartialEq, Debug)]
pub struct ScalarNonZero(pub(crate) Scalar);

impl ScalarNonZero {
    /// Draws scalars from `rng` until a non-zero one is found and returns it.
    #[must_use]
    pub fn random<R: Rng + CryptoRng>(rng: &mut R) -> Self {
        loop {
            let candidate = ScalarCanBeZero::random(rng);
            if !candidate.is_zero() {
                return Self(candidate.0);
            }
        }
    }

    /// Decodes a scalar from 32 bytes.
    ///
    /// Returns `None` when [`ScalarCanBeZero::from_bytes`] rejects the input or
    /// when the decoded scalar is zero.
    #[must_use]
    pub fn from_bytes(v: &[u8; 32]) -> Option<Self> {
        let candidate = ScalarCanBeZero::from_bytes(v)?;
        Self::try_from(candidate).ok()
    }

    /// Decodes a scalar from a byte slice.
    ///
    /// Returns `None` when [`ScalarCanBeZero::from_slice`] rejects the input or
    /// when the decoded scalar is zero.
    #[must_use]
    pub fn from_slice(v: &[u8]) -> Option<Self> {
        let candidate = ScalarCanBeZero::from_slice(v)?;
        Self::try_from(candidate).ok()
    }

    /// Maps 64 bytes to a non-zero scalar.
    ///
    /// The bytes are reduced with `Scalar::from_bytes_mod_order_wide`; should the
    /// reduction come out as zero, the scalar one is returned instead so the
    /// result is always non-zero.
    #[must_use]
    pub fn from_hash(v: &[u8; 64]) -> Self {
        let reduced = Scalar::from_bytes_mod_order_wide(v);
        if reduced == Scalar::ZERO {
            Self::one()
        } else {
            Self(reduced)
        }
    }

    /// Decodes a scalar from a hex string.
    ///
    /// Returns `None` when [`ScalarCanBeZero::from_hex`] rejects the input or
    /// when the decoded scalar is zero.
    #[must_use]
    pub fn from_hex(s: &str) -> Option<Self> {
        let candidate = ScalarCanBeZero::from_hex(s)?;
        Self::try_from(candidate).ok()
    }

    /// Returns the scalar one.
    #[must_use]
    pub fn one() -> Self {
        Self(Scalar::ONE)
    }

    /// Returns the inverse of this scalar, as computed by `Scalar::invert`.
    #[must_use]
    pub fn invert(&self) -> Self {
        let Self(inner) = self;
        Self(inner.invert())
    }
}
/// Wrapper around a [`Scalar`] that places no restriction on the value: zero is allowed.
#[derive(Copy, Clone, Eq, PartialEq, Debug)]
pub struct ScalarCanBeZero(pub(crate) Scalar);

impl ScalarCanBeZero {
    /// Generates a scalar using `Scalar::random` with the supplied `rng`.
    #[must_use]
    pub fn random<R: Rng + CryptoRng>(rng: &mut R) -> Self {
        let inner = Scalar::random(rng);
        Self(inner)
    }

    /// Decodes a scalar from 32 bytes.
    ///
    /// Returns `None` when `Scalar::from_canonical_bytes` rejects the input.
    #[must_use]
    pub fn from_bytes(v: &[u8; 32]) -> Option<Self> {
        let parsed: Option<Scalar> = Scalar::from_canonical_bytes(*v).into();
        parsed.map(Self)
    }

    /// Decodes a scalar from a byte slice.
    ///
    /// Returns `None` when the slice is not exactly 32 bytes long or when
    /// `Scalar::from_canonical_bytes` rejects the bytes.
    #[must_use]
    pub fn from_slice(v: &[u8]) -> Option<Self> {
        // The array conversion fails exactly when the length is not 32.
        let array = <[u8; 32]>::try_from(v).ok()?;
        Self::from_bytes(&array)
    }

    /// Decodes a scalar from a 64-character hex string.
    ///
    /// Returns `None` when the string is not 64 bytes long, is not valid hex, or
    /// when `Scalar::from_canonical_bytes` rejects the decoded bytes.
    #[must_use]
    pub fn from_hex(s: &str) -> Option<Self> {
        if s.len() != 64 {
            return None;
        }
        let decoded = hex::decode(s).ok()?;
        // 64 hex characters always decode to 32 bytes, so this conversion succeeds.
        let array = <[u8; 32]>::try_from(decoded.as_slice()).ok()?;
        Self::from_bytes(&array)
    }

    /// Returns the scalar one.
    #[must_use]
    pub fn one() -> Self {
        Self(Scalar::ONE)
    }

    /// Returns the scalar zero.
    #[must_use]
    pub fn zero() -> Self {
        Self(Scalar::ZERO)
    }

    /// Reports whether this scalar equals zero.
    #[must_use]
    pub fn is_zero(&self) -> bool {
        *self == Self::zero()
    }
}
impl From<ScalarNonZero> for ScalarCanBeZero {
fn from(value: ScalarNonZero) -> Self {
Self(value.0)
}
}
/// Fallible narrowing: succeeds for every value except zero.
impl TryFrom<ScalarCanBeZero> for ScalarNonZero {
    type Error = ZeroArgumentError;

    /// Returns `Err(ZeroArgumentError)` when `value` is zero, otherwise wraps the
    /// same underlying scalar.
    fn try_from(value: ScalarCanBeZero) -> Result<Self, Self::Error> {
        if value.is_zero() {
            return Err(ZeroArgumentError);
        }
        Ok(Self(value.0))
    }
}
/// Shared encoding helpers for the scalar wrapper types in this file.
///
/// Implementors only need to provide [`ScalarTraits::raw`]; the encoders are
/// derived from it.
pub trait ScalarTraits {
    /// Borrows the wrapped [`Scalar`].
    fn raw(&self) -> &Scalar;

    /// Returns the 32-byte representation exposed by `Scalar::as_bytes`.
    fn to_bytes(&self) -> [u8; 32] {
        *self.raw().as_bytes()
    }

    /// Returns the hex encoding of [`ScalarTraits::to_bytes`].
    fn to_hex(&self) -> String {
        let bytes = self.to_bytes();
        hex::encode(bytes)
    }
}
impl ScalarTraits for ScalarCanBeZero {
    /// Borrows the scalar stored in this wrapper.
    fn raw(&self) -> &Scalar {
        let Self(inner) = self;
        inner
    }
}

impl ScalarTraits for ScalarNonZero {
    /// Borrows the scalar stored in this wrapper.
    fn raw(&self) -> &Scalar {
        let Self(inner) = self;
        inner
    }
}
#[cfg(feature = "serde")]
impl Serialize for ScalarNonZero {
    /// Serializes the scalar as the string produced by `to_hex`.
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: Serializer,
    {
        let encoded = self.to_hex();
        serializer.serialize_str(&encoded)
    }
}
#[cfg(feature = "serde")]
impl<'de> Deserialize<'de> for ScalarNonZero {
    /// Deserializes a scalar from a string by passing it to `ScalarNonZero::from_hex`.
    ///
    /// Anything `from_hex` rejects (including an encoding of zero) is reported as
    /// a custom deserialization error.
    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
    where
        D: Deserializer<'de>,
    {
        /// Visitor that only accepts string input.
        struct HexScalarVisitor;

        impl Visitor<'_> for HexScalarVisitor {
            type Value = ScalarNonZero;

            fn expecting(&self, formatter: &mut Formatter) -> std::fmt::Result {
                formatter.write_str("a hex encoded string representing a non-zero scalar")
            }

            fn visit_str<E>(self, v: &str) -> Result<Self::Value, E>
            where
                E: Error,
            {
                match ScalarNonZero::from_hex(v) {
                    Some(scalar) => Ok(scalar),
                    None => Err(E::custom(format!("invalid hex encoded string: {v}"))),
                }
            }
        }

        deserializer.deserialize_str(HexScalarVisitor)
    }
}
// Addition of possibly-zero scalars, provided for every combination of owned and
// borrowed operands. Each impl adds the wrapped scalars and re-wraps the sum.

impl std::ops::Add<&ScalarCanBeZero> for &ScalarCanBeZero {
    type Output = ScalarCanBeZero;

    fn add(self, rhs: &ScalarCanBeZero) -> Self::Output {
        let mut sum = *self;
        sum.0 += rhs.0;
        sum
    }
}

impl std::ops::Add<&ScalarCanBeZero> for ScalarCanBeZero {
    type Output = ScalarCanBeZero;

    fn add(self, rhs: &ScalarCanBeZero) -> Self::Output {
        ScalarCanBeZero(self.0 + rhs.0)
    }
}

impl std::ops::Add<ScalarCanBeZero> for &ScalarCanBeZero {
    type Output = ScalarCanBeZero;

    fn add(self, rhs: ScalarCanBeZero) -> Self::Output {
        ScalarCanBeZero(self.0 + rhs.0)
    }
}

impl std::ops::Add<ScalarCanBeZero> for ScalarCanBeZero {
    type Output = ScalarCanBeZero;

    fn add(self, rhs: ScalarCanBeZero) -> Self::Output {
        ScalarCanBeZero(self.0 + rhs.0)
    }
}
// Subtraction of possibly-zero scalars, provided for every combination of owned
// and borrowed operands. Each impl computes `self - rhs` on the wrapped scalars.

impl std::ops::Sub<&ScalarCanBeZero> for &ScalarCanBeZero {
    type Output = ScalarCanBeZero;

    fn sub(self, rhs: &ScalarCanBeZero) -> Self::Output {
        let mut difference = *self;
        difference.0 -= rhs.0;
        difference
    }
}

impl std::ops::Sub<&ScalarCanBeZero> for ScalarCanBeZero {
    type Output = ScalarCanBeZero;

    fn sub(self, rhs: &ScalarCanBeZero) -> Self::Output {
        ScalarCanBeZero(self.0 - rhs.0)
    }
}

impl std::ops::Sub<ScalarCanBeZero> for &ScalarCanBeZero {
    type Output = ScalarCanBeZero;

    fn sub(self, rhs: ScalarCanBeZero) -> Self::Output {
        let mut difference = *self;
        difference.0 -= rhs.0;
        difference
    }
}

impl std::ops::Sub<ScalarCanBeZero> for ScalarCanBeZero {
    type Output = ScalarCanBeZero;

    fn sub(self, rhs: Self) -> Self::Output {
        ScalarCanBeZero(self.0 - rhs.0)
    }
}
// Multiplication of non-zero scalars, provided for every combination of owned and
// borrowed operands. The product is wrapped as `ScalarNonZero` without a zero
// check. NOTE(review): presumably this relies on the product of two non-zero
// scalars never being zero; confirm against the scalar field's properties.

impl std::ops::Mul<&ScalarNonZero> for &ScalarNonZero {
    type Output = ScalarNonZero;

    fn mul(self, rhs: &ScalarNonZero) -> Self::Output {
        let mut product = *self;
        product.0 *= rhs.0;
        product
    }
}

impl std::ops::Mul<&ScalarNonZero> for ScalarNonZero {
    type Output = ScalarNonZero;

    fn mul(self, rhs: &ScalarNonZero) -> Self::Output {
        ScalarNonZero(self.0 * rhs.0)
    }
}

impl std::ops::Mul<ScalarNonZero> for &ScalarNonZero {
    type Output = ScalarNonZero;

    fn mul(self, rhs: ScalarNonZero) -> Self::Output {
        ScalarNonZero(self.0 * rhs.0)
    }
}

impl std::ops::Mul<ScalarNonZero> for ScalarNonZero {
    type Output = ScalarNonZero;

    fn mul(self, rhs: Self) -> Self::Output {
        ScalarNonZero(self.0 * rhs.0)
    }
}
#[cfg(test)]
// Tests may unwrap/expect freely: a panic is the intended failure mode here.
#[allow(clippy::unwrap_used, clippy::expect_used)]
mod tests {
    use super::*;

    /// A non-zero scalar survives a bytes round trip unchanged.
    #[test]
    fn encode_decode_non_zero() {
        let mut rng = rand::rng();
        let original = ScalarNonZero::random(&mut rng);
        let encoded = original.to_bytes();
        let decoded = ScalarNonZero::from_bytes(&encoded).expect("decoding should succeed");
        assert_eq!(decoded, original);
    }

    /// A possibly-zero scalar survives a bytes round trip unchanged.
    #[test]
    fn encode_decode_can_be_zero() {
        let mut rng = rand::rng();
        let original = ScalarCanBeZero::random(&mut rng);
        let encoded = original.to_bytes();
        let decoded = ScalarCanBeZero::from_bytes(&encoded).expect("decoding should succeed");
        assert_eq!(decoded, original);
    }

    /// A non-zero scalar survives a hex round trip unchanged.
    #[test]
    fn encode_decode_hex() {
        let mut rng = rand::rng();
        let original = ScalarNonZero::random(&mut rng);
        let decoded =
            ScalarNonZero::from_hex(&original.to_hex()).expect("decoding should succeed");
        assert_eq!(decoded, original);
    }

    /// Addition is checked through identities that hold for any operands, rather
    /// than through a probabilistic "result is not zero" assertion.
    #[test]
    fn addition() {
        let mut rng = rand::rng();
        let a = ScalarCanBeZero::from(ScalarNonZero::random(&mut rng));
        let b = ScalarCanBeZero::from(ScalarNonZero::random(&mut rng));
        let sum = a + b;
        assert_eq!(sum, b + a);
        assert_eq!(sum - b, a);
        assert_eq!(a + ScalarCanBeZero::zero(), a);
        assert_eq!(a - a, ScalarCanBeZero::zero());
    }

    /// Multiplication is checked through identities that hold for any operands,
    /// rather than through a probabilistic "result is not one" assertion.
    #[test]
    fn multiplication() {
        let mut rng = rand::rng();
        let a = ScalarNonZero::random(&mut rng);
        let b = ScalarNonZero::random(&mut rng);
        let product = a * b;
        assert_eq!(product, b * a);
        assert_eq!(product * b.invert(), a);
        assert_eq!(a * ScalarNonZero::one(), a);
    }

    /// Multiplying a scalar by its inverse yields one.
    #[test]
    fn inversion() {
        let mut rng = rand::rng();
        let a = ScalarNonZero::random(&mut rng);
        let inv = a.invert();
        let should_be_one = a * inv;
        assert_eq!(should_be_one, ScalarNonZero::one());
    }

    /// Zero is accepted by `ScalarCanBeZero` and rejected by every
    /// `ScalarNonZero` entry point exercised here.
    #[test]
    fn zero_is_rejected_for_non_zero() {
        let zero_bytes = [0u8; 32];
        assert_eq!(
            ScalarCanBeZero::from_bytes(&zero_bytes),
            Some(ScalarCanBeZero::zero())
        );
        assert!(ScalarCanBeZero::zero().is_zero());
        assert!(ScalarNonZero::try_from(ScalarCanBeZero::zero()).is_err());
        assert!(ScalarNonZero::from_bytes(&zero_bytes).is_none());
        assert!(ScalarNonZero::from_slice(&zero_bytes).is_none());
    }

    /// Slices of the wrong length are rejected instead of panicking.
    #[test]
    fn wrong_length_slice_is_rejected() {
        assert!(ScalarCanBeZero::from_slice(&[0u8; 31]).is_none());
        assert!(ScalarCanBeZero::from_slice(&[0u8; 33]).is_none());
        assert!(ScalarCanBeZero::from_hex("00").is_none());
    }
}