use curve25519_dalek::ristretto::CompressedRistretto;
use curve25519_dalek::ristretto::RistrettoPoint;
use curve25519_dalek::traits::Identity;
use rand_core::{CryptoRng, Rng};
#[cfg(feature = "serde")]
use serde::de::{Error, Visitor};
#[cfg(feature = "serde")]
use serde::{Deserialize, Deserializer, Serialize, Serializer};
use sha2::Sha256;
#[cfg(feature = "serde")]
use std::fmt::Formatter;
use std::hash::Hash;
/// An element of the Ristretto prime-order group, stored as a [`RistrettoPoint`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct GroupElement(pub(crate) RistrettoPoint);
/// The Ristretto basepoint (`RISTRETTO_BASEPOINT_POINT`) wrapped as a [`GroupElement`].
pub const G: GroupElement = GroupElement(curve25519_dalek::constants::RISTRETTO_BASEPOINT_POINT);
impl GroupElement {
    /// Samples a uniformly random group element using a cryptographically secure RNG.
    #[must_use]
    pub fn random<R: Rng + CryptoRng>(rng: &mut R) -> Self {
        Self(RistrettoPoint::random(rng))
    }

    /// Decodes a 32-byte compressed Ristretto encoding.
    ///
    /// Returns `None` if the bytes do not decompress to a valid group element.
    #[must_use]
    pub fn from_bytes(v: &[u8; 32]) -> Option<Self> {
        CompressedRistretto(*v).decompress().map(Self)
    }

    /// Decodes a compressed Ristretto encoding from an arbitrary byte slice.
    ///
    /// Returns `None` if the slice cannot be turned into a `CompressedRistretto`
    /// (wrong length) or does not decompress to a valid group element.
    #[must_use]
    pub fn from_slice(v: &[u8]) -> Option<Self> {
        CompressedRistretto::from_slice(v)
            .ok()?
            .decompress()
            .map(Self)
    }

    /// Returns the 32-byte compressed encoding of this element.
    #[must_use]
    pub fn to_bytes(&self) -> [u8; 32] {
        self.0.compress().0
    }

    /// Maps 64 bytes to a group element via `RistrettoPoint::from_uniform_bytes`.
    ///
    /// The input is expected to be uniformly random, e.g. a 512-bit hash output.
    #[must_use]
    pub fn from_hash(v: &[u8; 64]) -> Self {
        Self(RistrettoPoint::from_uniform_bytes(v))
    }

    /// Embeds a 16-byte payload into a group element using Lizard encoding (SHA-256).
    #[must_use]
    pub fn from_lizard(v: &[u8; 16]) -> Self {
        Self(RistrettoPoint::lizard_encode::<Sha256>(v))
    }

    /// Recovers a 16-byte payload embedded with [`GroupElement::from_lizard`].
    ///
    /// Returns `None` when no payload can be recovered, for example after the
    /// element has been multiplied by a random scalar.
    #[must_use]
    pub fn to_lizard(&self) -> Option<[u8; 16]> {
        self.0.lizard_decode::<Sha256>()
    }

    /// Parses the hex form of a compressed element, as produced by [`GroupElement::to_hex`].
    ///
    /// Returns `None` if the string is not exactly 64 characters long, is not
    /// valid hex, or does not decode to a valid group element.
    #[must_use]
    pub fn from_hex(s: &str) -> Option<Self> {
        // 64 hex characters encode exactly the 32 bytes of a compressed point.
        if s.len() != 64 {
            return None;
        }
        let bytes = hex::decode(s).ok()?;
        // Delegate to `from_slice` instead of unwrapping the slice conversion,
        // so there is no panic path even if the length check above ever changes.
        Self::from_slice(&bytes)
    }

    /// Returns the hex encoding of [`GroupElement::to_bytes`].
    #[must_use]
    pub fn to_hex(&self) -> String {
        hex::encode(self.to_bytes())
    }

    /// Returns the identity (neutral) element of the group.
    #[must_use]
    pub fn identity() -> Self {
        Self(RistrettoPoint::identity())
    }
}
#[cfg(feature = "serde")]
impl Serialize for GroupElement {
    /// Serializes the element as a string: the hex encoding of its compressed bytes.
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: Serializer,
    {
        let encoded = self.to_hex();
        serializer.serialize_str(&encoded)
    }
}
#[cfg(feature = "serde")]
impl<'de> Deserialize<'de> for GroupElement {
    /// Deserializes a `GroupElement` from a hex string, as accepted by `GroupElement::from_hex`.
    ///
    /// Fails with a custom error if the string is not 64 hex characters or does
    /// not decode to a valid group element.
    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
    where
        D: Deserializer<'de>,
    {
        struct GroupElementVisitor;
        impl Visitor<'_> for GroupElementVisitor {
            type Value = GroupElement;
            fn expecting(&self, formatter: &mut Formatter) -> std::fmt::Result {
                formatter.write_str("a hex encoded string representing a GroupElement")
            }
            fn visit_str<E>(self, v: &str) -> Result<Self::Value, E>
            where
                E: Error,
            {
                // Build the error lazily so the message is only formatted on failure.
                GroupElement::from_hex(v)
                    .ok_or_else(|| E::custom(format!("invalid hex encoded string: {v}")))
            }
        }
        deserializer.deserialize_str(GroupElementVisitor)
    }
}
impl Hash for GroupElement {
fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
self.to_bytes().hash(state);
}
}
use super::scalars::{ScalarCanBeZero, ScalarNonZero};
impl<'b> std::ops::Add<&'b GroupElement> for &GroupElement {
    type Output = GroupElement;
    /// Adds two borrowed group elements into a new element.
    fn add(self, rhs: &'b GroupElement) -> Self::Output {
        let (left, right) = (self.0, rhs.0);
        GroupElement(left + right)
    }
}
impl<'b> std::ops::Add<&'b GroupElement> for GroupElement {
type Output = GroupElement;
fn add(mut self, rhs: &'b GroupElement) -> Self::Output {
self.0 += rhs.0;
self
}
}
impl std::ops::Add<GroupElement> for &GroupElement {
type Output = GroupElement;
fn add(self, mut rhs: GroupElement) -> Self::Output {
rhs.0 += self.0;
rhs
}
}
impl std::ops::Add<GroupElement> for GroupElement {
type Output = GroupElement;
fn add(mut self, rhs: Self) -> Self::Output {
self.0 += rhs.0;
self
}
}
impl<'b> std::ops::Sub<&'b GroupElement> for &GroupElement {
    type Output = GroupElement;
    /// Subtracts one borrowed element from another, producing `self - rhs`.
    fn sub(self, rhs: &'b GroupElement) -> Self::Output {
        let difference = self.0 - rhs.0;
        GroupElement(difference)
    }
}
impl<'b> std::ops::Sub<&'b GroupElement> for GroupElement {
type Output = GroupElement;
fn sub(mut self, rhs: &'b GroupElement) -> Self::Output {
self.0 -= rhs.0;
self
}
}
impl std::ops::Sub<GroupElement> for &GroupElement {
    type Output = GroupElement;
    /// Subtracts an owned element from a borrowed one, producing `self - rhs`.
    fn sub(self, rhs: GroupElement) -> Self::Output {
        let (minuend, subtrahend) = (self.0, rhs.0);
        GroupElement(minuend - subtrahend)
    }
}
impl std::ops::Sub<GroupElement> for GroupElement {
type Output = GroupElement;
fn sub(mut self, rhs: Self) -> Self::Output {
self.0 -= rhs.0;
self
}
}
impl<'b> std::ops::Mul<&'b GroupElement> for &ScalarNonZero {
    type Output = GroupElement;
    /// Multiplies a borrowed group element by a borrowed non-zero scalar.
    fn mul(self, rhs: &'b GroupElement) -> Self::Output {
        let (scalar, point) = (self.0, rhs.0);
        GroupElement(scalar * point)
    }
}
impl<'b> std::ops::Mul<&'b GroupElement> for ScalarNonZero {
    type Output = GroupElement;
    /// Multiplies a borrowed group element by an owned non-zero scalar.
    fn mul(self, rhs: &'b GroupElement) -> Self::Output {
        let product = self.0 * rhs.0;
        GroupElement(product)
    }
}
impl std::ops::Mul<GroupElement> for &ScalarNonZero {
type Output = GroupElement;
fn mul(self, mut rhs: GroupElement) -> Self::Output {
rhs.0 *= self.0;
rhs
}
}
impl std::ops::Mul<GroupElement> for ScalarNonZero {
type Output = GroupElement;
fn mul(self, mut rhs: GroupElement) -> Self::Output {
rhs.0 *= self.0;
rhs
}
}
impl<'b> std::ops::Mul<&'b GroupElement> for &ScalarCanBeZero {
    type Output = GroupElement;
    /// Multiplies a borrowed group element by a borrowed (possibly zero) scalar.
    fn mul(self, rhs: &'b GroupElement) -> Self::Output {
        let (scalar, point) = (self.0, rhs.0);
        GroupElement(scalar * point)
    }
}
impl<'b> std::ops::Mul<&'b GroupElement> for ScalarCanBeZero {
    type Output = GroupElement;
    /// Multiplies a borrowed group element by an owned (possibly zero) scalar.
    fn mul(self, rhs: &'b GroupElement) -> Self::Output {
        let product = self.0 * rhs.0;
        GroupElement(product)
    }
}
impl std::ops::Mul<GroupElement> for &ScalarCanBeZero {
type Output = GroupElement;
fn mul(self, mut rhs: GroupElement) -> Self::Output {
rhs.0 *= self.0;
rhs
}
}
impl std::ops::Mul<GroupElement> for ScalarCanBeZero {
type Output = GroupElement;
fn mul(self, mut rhs: GroupElement) -> Self::Output {
rhs.0 *= self.0;
rhs
}
}
#[cfg(test)]
#[allow(clippy::unwrap_used, clippy::expect_used)]
mod tests {
    use super::*;
    use crate::arithmetic::scalars::ScalarNonZero;

    /// Compressing a random element and decompressing it again is lossless.
    #[test]
    fn encode_decode() {
        let mut rng = rand::rng();
        let point = GroupElement::random(&mut rng);
        let compressed = point.to_bytes();
        let restored = GroupElement::from_bytes(&compressed).expect("decoding should succeed");
        assert_eq!(restored, point);
    }

    /// The serde representation round-trips through JSON.
    #[test]
    #[cfg(feature = "serde")]
    fn serde_json_roundtrip() {
        let mut rng = rand::rng();
        let point = GroupElement::random(&mut rng);
        let text = serde_json::to_string(&point).expect("serialization should succeed");
        let parsed: GroupElement =
            serde_json::from_str(&text).expect("deserialization should succeed");
        assert_eq!(parsed, point);
    }

    /// The test expects this fixed 32-byte string to be a valid encoding that
    /// re-encodes to exactly the same bytes.
    #[test]
    fn decode_arbitrary_bytes() {
        let input = b"test data dsfdsdfsd wefwefew dfd";
        let point = GroupElement::from_bytes(input).expect("decoding should succeed");
        assert_eq!(point.to_bytes(), *input);
    }

    /// Point addition does not depend on operand order.
    #[test]
    fn addition_is_commutative() {
        let mut rng = rand::rng();
        let first = GroupElement::random(&mut rng);
        let second = GroupElement::random(&mut rng);
        assert_eq!(first + second, second + first);
    }

    /// Scalar multiplication distributes over addition: s(g + h) == sg + sh.
    #[test]
    fn scalar_multiplication_distributes() {
        let mut rng = rand::rng();
        let g = GroupElement::random(&mut rng);
        let h = GroupElement::random(&mut rng);
        let s = ScalarNonZero::random(&mut rng);
        assert_eq!(s * (g + h), s * g + s * h);
    }

    /// Multiplying by the scalar one leaves an element unchanged.
    #[test]
    fn scalar_identity() {
        let mut rng = rand::rng();
        let point = GroupElement::random(&mut rng);
        assert_eq!(ScalarNonZero::one() * point, point);
    }

    /// Lizard encoding round-trips for a set of boundary-value 16-byte payloads.
    #[test]
    fn lizard_edge_cases() {
        const EDGE_CASES: [&str; 8] = [
            "00000000000000000000000000000000",
            "00ffffffffffffffffffffffffffffff",
            "f3ffffffffffffffffffffffffffff7f",
            "ffffffffffffffffffffffffffffffff",
            "01ffffffffffffffffffffffffffffff",
            "edffffffffffffffffffffffffffff7f",
            "01000000000000000000000000000000",
            "ecffffffffffffffffffffffffffff7f",
        ];
        for hex_payload in EDGE_CASES {
            let raw = hex::decode(hex_payload).expect("hex decoding should succeed");
            let payload = <[u8; 16]>::try_from(raw.as_slice()).expect("should be 16 bytes");
            let recovered = GroupElement::from_lizard(&payload)
                .to_lizard()
                .expect("lizard encoding should succeed");
            assert_eq!(recovered, payload);
        }
    }

    /// Lizard encoding round-trips for a random 16-byte payload.
    #[test]
    fn lizard_random_roundtrip() {
        let mut rng = rand::rng();
        let mut payload = [0u8; 16];
        rng.fill_bytes(&mut payload);
        let recovered = GroupElement::from_lizard(&payload)
            .to_lizard()
            .expect("lizard encoding should succeed");
        assert_eq!(recovered, payload);
    }

    /// After multiplication by a random scalar the payload is no longer recoverable.
    #[test]
    fn lizard_fails_after_scalar_multiplication() {
        let mut rng = rand::rng();
        let mut payload = [0u8; 16];
        rng.fill_bytes(&mut payload);
        let embedded = GroupElement::from_lizard(&payload);
        let rescaled = ScalarNonZero::random(&mut rng) * embedded;
        assert!(rescaled.to_lizard().is_none());
    }
}