// libpep/internal/arithmetic.rs
use curve25519_dalek_libpep::ristretto::CompressedRistretto;
use curve25519_dalek_libpep::ristretto::RistrettoPoint;
use curve25519_dalek_libpep::scalar::Scalar;
use curve25519_dalek_libpep::traits::Identity;
use std::fmt::Formatter;
use rand_core::{CryptoRng, RngCore};
use serde::de::{Error, Visitor};
use serde::{Deserialize, Deserializer, Serialize, Serializer};
use sha2::Sha256;
/// The Ristretto basepoint constant exported by `curve25519_dalek_libpep`, wrapped as a
/// [`GroupElement`].
// NOTE(review): the name `G` suggests this is used as the group generator — confirm at call sites.
pub const G: GroupElement =
GroupElement(curve25519_dalek_libpep::constants::RISTRETTO_BASEPOINT_POINT);
/// Error signalling that a scalar which must be non-zero was zero.
///
/// In this file it is the error type of `TryFrom<ScalarCanBeZero> for ScalarNonZero`.
#[derive(Debug)]
pub struct ZeroArgumentError;
/// Newtype wrapping a `RistrettoPoint`.
///
/// Copying, equality and debug formatting are derived from the wrapped point.
#[derive(Copy, Clone, Eq, PartialEq, Debug)]
pub struct GroupElement(RistrettoPoint);
impl GroupElement {
    /// Draws a random group element via `RistrettoPoint::random` using the supplied RNG.
    pub fn random<R: RngCore + CryptoRng>(rng: &mut R) -> Self {
        Self(RistrettoPoint::random(rng))
    }

    /// Decodes a group element from a 32-byte compressed Ristretto encoding.
    ///
    /// Returns `None` when the bytes do not decompress to a point.
    pub fn decode(v: &[u8; 32]) -> Option<Self> {
        CompressedRistretto(*v).decompress().map(Self)
    }

    /// Decodes a group element from a byte slice.
    ///
    /// Returns `None` when `CompressedRistretto::from_slice` rejects the slice or when the
    /// bytes do not decompress to a point.
    pub fn decode_from_slice(v: &[u8]) -> Option<Self> {
        let compressed = CompressedRistretto::from_slice(v).ok()?;
        compressed.decompress().map(Self)
    }

    /// Returns the 32-byte compressed encoding of this element.
    pub fn encode(&self) -> [u8; 32] {
        self.0.compress().0
    }

    /// Builds a group element from 64 bytes via `RistrettoPoint::from_uniform_bytes`.
    ///
    /// This conversion cannot fail. The name indicates the input is meant to be hash output.
    pub fn decode_from_hash(v: &[u8; 64]) -> Self {
        Self(RistrettoPoint::from_uniform_bytes(v))
    }

    /// Turns 16 bytes of data into a group element using the point's `lizard_encode`
    /// with SHA-256.
    ///
    /// Note the inverted naming: from this type's perspective bytes are *decoded* into an
    /// element, which corresponds to the underlying library's lizard *encode* step.
    pub fn decode_lizard(v: &[u8; 16]) -> Self {
        Self(RistrettoPoint::lizard_encode::<Sha256>(v))
    }

    /// Recovers 16 bytes of data from this element using the point's `lizard_decode`
    /// with SHA-256.
    ///
    /// Returns `None` whenever `lizard_decode` does.
    pub fn encode_lizard(&self) -> Option<[u8; 16]> {
        self.0.lizard_decode::<Sha256>()
    }

    /// Decodes a group element from a 64-character hex string.
    ///
    /// Returns `None` when the string has the wrong length, is not valid hex, or does not
    /// describe a valid point. This function never panics.
    pub fn decode_from_hex(s: &str) -> Option<Self> {
        if s.len() != 64 {
            return None;
        }
        let bytes = hex::decode(s).ok()?;
        // Go through the slice decoder so every failure maps to `None`
        // rather than relying on an `unwrap()` of the slice conversion.
        Self::decode_from_slice(&bytes)
    }

    /// Returns the compressed encoding of this element as a lowercase hex string.
    pub fn encode_as_hex(&self) -> String {
        hex::encode(self.encode())
    }

    /// Returns the identity element (`RistrettoPoint::identity`).
    pub fn identity() -> Self {
        Self(RistrettoPoint::identity())
    }
}
impl Serialize for GroupElement {
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
where
S: Serializer,
{
serializer.serialize_str(self.encode_as_hex().as_str())
}
}
/// Deserializes a [`GroupElement`] from a hex encoded string.
impl<'de> Deserialize<'de> for GroupElement {
    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
    where
        D: Deserializer<'de>,
    {
        /// Visitor that accepts a string and decodes it with `GroupElement::decode_from_hex`.
        struct GroupElementVisitor;

        impl Visitor<'_> for GroupElementVisitor {
            type Value = GroupElement;

            fn expecting(&self, formatter: &mut Formatter) -> std::fmt::Result {
                formatter.write_str("a hex encoded string representing a GroupElement")
            }

            fn visit_str<E>(self, v: &str) -> Result<Self::Value, E>
            where
                E: Error,
            {
                // Build the error lazily so the message is only formatted (and allocated)
                // when decoding actually fails.
                GroupElement::decode_from_hex(v)
                    .ok_or_else(|| E::custom(format!("invalid hex encoded string: {}", v)))
            }
        }

        deserializer.deserialize_str(GroupElementVisitor)
    }
}
/// Newtype wrapping a `Scalar` that is meant never to be zero.
///
/// The constructors in this file either reject zero (`TryFrom<ScalarCanBeZero>`) or replace
/// it with one (`decode_from_hash`).
#[derive(Copy, Clone, Eq, PartialEq, Debug)]
pub struct ScalarNonZero(Scalar);
impl ScalarNonZero {
    /// Keeps drawing random scalars from `rng` until one is non-zero, then returns it.
    pub fn random<R: RngCore + CryptoRng>(rng: &mut R) -> Self {
        loop {
            let candidate: Result<Self, _> = ScalarCanBeZero::random(rng).try_into();
            if let Ok(non_zero) = candidate {
                return non_zero;
            }
        }
    }

    /// Decodes a non-zero scalar from 32 bytes.
    ///
    /// Returns `None` when `ScalarCanBeZero::decode` fails or the decoded scalar is zero.
    pub fn decode(v: &[u8; 32]) -> Option<Self> {
        ScalarCanBeZero::decode(v)?.try_into().ok()
    }

    /// Decodes a non-zero scalar from a byte slice.
    ///
    /// Returns `None` when `ScalarCanBeZero::decode_from_slice` fails or the scalar is zero.
    pub fn decode_from_slice(v: &[u8]) -> Option<Self> {
        ScalarCanBeZero::decode_from_slice(v)?.try_into().ok()
    }

    /// Reduces 64 bytes to a scalar with `Scalar::from_bytes_mod_order_wide`.
    ///
    /// If the reduced scalar is zero, one is returned instead, so the result is never zero.
    pub fn decode_from_hash(v: &[u8; 64]) -> Self {
        let reduced = Scalar::from_bytes_mod_order_wide(v);
        let is_zero = reduced.as_bytes().iter().all(|byte| *byte == 0);
        Self(if is_zero { Scalar::ONE } else { reduced })
    }

    /// Decodes a non-zero scalar from a hex string.
    ///
    /// Returns `None` when `ScalarCanBeZero::decode_from_hex` fails or the scalar is zero.
    pub fn decode_from_hex(s: &str) -> Option<Self> {
        ScalarCanBeZero::decode_from_hex(s)?.try_into().ok()
    }

    /// Returns the scalar one.
    pub fn one() -> Self {
        Self(Scalar::ONE)
    }

    /// Returns the result of `Scalar::invert` on the wrapped scalar.
    pub fn invert(&self) -> Self {
        let inverse = self.0.invert();
        Self(inverse)
    }
}
/// Newtype wrapping a `Scalar` with no restriction on its value; zero is allowed.
#[derive(Copy, Clone, Eq, PartialEq, Debug)]
pub struct ScalarCanBeZero(Scalar);
impl ScalarCanBeZero {
    /// Draws a random scalar via `Scalar::random` using the supplied RNG.
    pub fn random<R: RngCore + CryptoRng>(rng: &mut R) -> Self {
        let scalar = Scalar::random(rng);
        Self(scalar)
    }

    /// Decodes a scalar from 32 bytes.
    ///
    /// Returns `None` when `Scalar::from_canonical_bytes` rejects the bytes.
    pub fn decode(v: &[u8; 32]) -> Option<Self> {
        let decoded = Scalar::from_canonical_bytes(*v).map(Self);
        decoded.into()
    }

    /// Decodes a scalar from a byte slice.
    ///
    /// Returns `None` when the slice is not exactly 32 bytes long or the bytes are rejected
    /// by `Scalar::from_canonical_bytes`.
    pub fn decode_from_slice(v: &[u8]) -> Option<Self> {
        if v.len() != 32 {
            return None;
        }
        let mut buf = [0u8; 32];
        buf.copy_from_slice(v);
        Self::decode(&buf)
    }

    /// Decodes a scalar from a 64-character hex string.
    ///
    /// Returns `None` when the string has the wrong length, is not valid hex, or the decoded
    /// bytes are rejected by `Scalar::from_canonical_bytes`.
    pub fn decode_from_hex(s: &str) -> Option<Self> {
        if s.len() != 64 {
            return None;
        }
        let bytes = hex::decode(s).ok()?;
        Self::decode_from_slice(&bytes)
    }

    /// Returns the scalar one.
    pub fn one() -> Self {
        Self(Scalar::ONE)
    }

    /// Returns the scalar zero.
    pub fn zero() -> Self {
        Self(Scalar::ZERO)
    }

    /// Reports whether every byte of the wrapped scalar is zero.
    pub fn is_zero(&self) -> bool {
        !self.0.as_bytes().iter().any(|&byte| byte != 0)
    }
}
impl From<ScalarNonZero> for ScalarCanBeZero {
fn from(value: ScalarNonZero) -> Self {
Self(value.0)
}
}
/// Fallible narrowing conversion: succeeds unless the scalar is zero.
impl TryFrom<ScalarCanBeZero> for ScalarNonZero {
    type Error = ZeroArgumentError;

    fn try_from(value: ScalarCanBeZero) -> Result<Self, Self::Error> {
        if !value.is_zero() {
            return Ok(Self(value.0));
        }
        Err(ZeroArgumentError)
    }
}
/// Encoding helpers shared by the scalar wrapper types.
///
/// Implementors only provide [`ScalarTraits::raw`]; the encoders are derived from it.
pub trait ScalarTraits {
    /// Returns a reference to the wrapped `Scalar`.
    fn raw(&self) -> &Scalar;

    /// Returns a copy of the scalar's 32 bytes as given by `Scalar::as_bytes`.
    fn encode(&self) -> [u8; 32] {
        *self.raw().as_bytes()
    }

    /// Returns the bytes from [`ScalarTraits::encode`] as a hex string.
    fn encode_as_hex(&self) -> String {
        let bytes = self.encode();
        hex::encode(bytes)
    }
}
impl ScalarTraits for ScalarCanBeZero {
fn raw(&self) -> &Scalar {
&self.0
}
}
impl ScalarTraits for ScalarNonZero {
fn raw(&self) -> &Scalar {
&self.0
}
}
/// `&a + &b`: sum of two borrowed [`ScalarCanBeZero`] values, returned as a new value.
impl<'b> std::ops::Add<&'b ScalarCanBeZero> for &ScalarCanBeZero {
    type Output = ScalarCanBeZero;

    fn add(self, rhs: &'b ScalarCanBeZero) -> ScalarCanBeZero {
        let sum = self.0 + rhs.0;
        ScalarCanBeZero(sum)
    }
}
impl<'b> std::ops::Add<&'b ScalarCanBeZero> for ScalarCanBeZero {
type Output = ScalarCanBeZero;
fn add(mut self, rhs: &'b ScalarCanBeZero) -> Self::Output {
self.0 += rhs.0;
self
}
}
impl std::ops::Add<ScalarCanBeZero> for &ScalarCanBeZero {
type Output = ScalarCanBeZero;
fn add(self, mut rhs: ScalarCanBeZero) -> Self::Output {
rhs.0 += self.0;
rhs
}
}
impl std::ops::Add<ScalarCanBeZero> for ScalarCanBeZero {
type Output = ScalarCanBeZero;
fn add(mut self, rhs: ScalarCanBeZero) -> Self::Output {
self.0 += rhs.0;
self
}
}
/// `&a - &b`: difference of two borrowed [`ScalarCanBeZero`] values, returned as a new value.
impl<'b> std::ops::Sub<&'b ScalarCanBeZero> for &ScalarCanBeZero {
    type Output = ScalarCanBeZero;

    fn sub(self, rhs: &'b ScalarCanBeZero) -> ScalarCanBeZero {
        let difference = self.0 - rhs.0;
        ScalarCanBeZero(difference)
    }
}
impl<'b> std::ops::Sub<&'b ScalarCanBeZero> for ScalarCanBeZero {
type Output = ScalarCanBeZero;
fn sub(mut self, rhs: &'b ScalarCanBeZero) -> Self::Output {
self.0 -= rhs.0;
self
}
}
/// `&a - b`: subtracts an owned [`ScalarCanBeZero`] from a borrowed one.
impl std::ops::Sub<ScalarCanBeZero> for &ScalarCanBeZero {
    type Output = ScalarCanBeZero;

    fn sub(self, rhs: ScalarCanBeZero) -> ScalarCanBeZero {
        let difference = self.0 - rhs.0;
        ScalarCanBeZero(difference)
    }
}
impl std::ops::Sub<ScalarCanBeZero> for ScalarCanBeZero {
type Output = ScalarCanBeZero;
fn sub(mut self, rhs: Self) -> Self::Output {
self.0 -= rhs.0;
self
}
}
/// `&a * &b`: product of two borrowed [`ScalarNonZero`] values.
///
/// The product is wrapped as `ScalarNonZero` without a zero check.
impl<'b> std::ops::Mul<&'b ScalarNonZero> for &ScalarNonZero {
    type Output = ScalarNonZero;

    fn mul(self, rhs: &'b ScalarNonZero) -> ScalarNonZero {
        let product = self.0 * rhs.0;
        ScalarNonZero(product)
    }
}
impl<'b> std::ops::Mul<&'b ScalarNonZero> for ScalarNonZero {
type Output = ScalarNonZero;
fn mul(mut self, rhs: &'b ScalarNonZero) -> Self::Output {
self.0 *= rhs.0;
self
}
}
impl std::ops::Mul<ScalarNonZero> for &ScalarNonZero {
type Output = ScalarNonZero;
fn mul(self, mut rhs: ScalarNonZero) -> Self::Output {
rhs.0 *= self.0;
rhs
}
}
impl std::ops::Mul<ScalarNonZero> for ScalarNonZero {
type Output = ScalarNonZero;
fn mul(mut self, rhs: Self) -> Self::Output {
self.0 *= rhs.0;
self
}
}
/// `&p + &q`: sum of two borrowed [`GroupElement`] values, returned as a new element.
impl<'b> std::ops::Add<&'b GroupElement> for &GroupElement {
    type Output = GroupElement;

    fn add(self, rhs: &'b GroupElement) -> GroupElement {
        let sum = self.0 + rhs.0;
        GroupElement(sum)
    }
}
impl<'b> std::ops::Add<&'b GroupElement> for GroupElement {
type Output = GroupElement;
fn add(mut self, rhs: &'b GroupElement) -> Self::Output {
self.0 += rhs.0;
self
}
}
impl std::ops::Add<GroupElement> for &GroupElement {
type Output = GroupElement;
fn add(self, mut rhs: GroupElement) -> Self::Output {
rhs.0 += self.0;
rhs
}
}
impl std::ops::Add<GroupElement> for GroupElement {
type Output = GroupElement;
fn add(mut self, rhs: Self) -> Self::Output {
self.0 += rhs.0;
self
}
}
/// `&p - &q`: difference of two borrowed [`GroupElement`] values, returned as a new element.
impl<'b> std::ops::Sub<&'b GroupElement> for &GroupElement {
    type Output = GroupElement;

    fn sub(self, rhs: &'b GroupElement) -> GroupElement {
        let difference = self.0 - rhs.0;
        GroupElement(difference)
    }
}
impl<'b> std::ops::Sub<&'b GroupElement> for GroupElement {
type Output = GroupElement;
fn sub(mut self, rhs: &'b GroupElement) -> Self::Output {
self.0 -= rhs.0;
self
}
}
/// `&p - q`: subtracts an owned [`GroupElement`] from a borrowed one.
impl std::ops::Sub<GroupElement> for &GroupElement {
    type Output = GroupElement;

    fn sub(self, rhs: GroupElement) -> GroupElement {
        let difference = self.0 - rhs.0;
        GroupElement(difference)
    }
}
impl std::ops::Sub<GroupElement> for GroupElement {
type Output = GroupElement;
fn sub(mut self, rhs: Self) -> Self::Output {
self.0 -= rhs.0;
self
}
}
/// `&s * &p`: multiplies a borrowed [`GroupElement`] by a borrowed [`ScalarNonZero`].
impl<'b> std::ops::Mul<&'b GroupElement> for &ScalarNonZero {
    type Output = GroupElement;

    fn mul(self, rhs: &'b GroupElement) -> GroupElement {
        let product = self.0 * rhs.0;
        GroupElement(product)
    }
}
/// `s * &p`: multiplies a borrowed [`GroupElement`] by an owned [`ScalarNonZero`].
impl<'b> std::ops::Mul<&'b GroupElement> for ScalarNonZero {
    type Output = GroupElement;

    fn mul(self, rhs: &'b GroupElement) -> GroupElement {
        let product = self.0 * rhs.0;
        GroupElement(product)
    }
}
impl std::ops::Mul<GroupElement> for &ScalarNonZero {
type Output = GroupElement;
fn mul(self, mut rhs: GroupElement) -> Self::Output {
rhs.0 *= self.0;
rhs
}
}
impl std::ops::Mul<GroupElement> for ScalarNonZero {
type Output = GroupElement;
fn mul(self, mut rhs: GroupElement) -> Self::Output {
rhs.0 *= self.0;
rhs
}
}
/// `&s * &p`: multiplies a borrowed [`GroupElement`] by a borrowed [`ScalarCanBeZero`].
impl<'b> std::ops::Mul<&'b GroupElement> for &ScalarCanBeZero {
    type Output = GroupElement;

    fn mul(self, rhs: &'b GroupElement) -> GroupElement {
        let product = self.0 * rhs.0;
        GroupElement(product)
    }
}
/// `s * &p`: multiplies a borrowed [`GroupElement`] by an owned [`ScalarCanBeZero`].
impl<'b> std::ops::Mul<&'b GroupElement> for ScalarCanBeZero {
    type Output = GroupElement;

    fn mul(self, rhs: &'b GroupElement) -> GroupElement {
        let product = self.0 * rhs.0;
        GroupElement(product)
    }
}
impl std::ops::Mul<GroupElement> for &ScalarCanBeZero {
type Output = GroupElement;
fn mul(self, mut rhs: GroupElement) -> Self::Output {
rhs.0 *= self.0;
rhs
}
}
impl std::ops::Mul<GroupElement> for ScalarCanBeZero {
type Output = GroupElement;
fn mul(self, mut rhs: GroupElement) -> Self::Output {
rhs.0 *= self.0;
rhs
}
}