use super::{Error, Signature};
use commonware_formatting::Hex;
use core::convert::{TryFrom, TryInto};
use curve25519_dalek::{
edwards::{CompressedEdwardsY, EdwardsPoint},
scalar::Scalar,
traits::IsIdentity,
};
use sha2::{digest::Update, Sha512};
/// A 32-byte Ed25519 verification key encoding that has **not** been validated.
///
/// Any 32-byte array can be wrapped (see the `From<[u8; 32]>` impl). Decompression, and
/// therefore rejection of malformed encodings, happens only when converting into a
/// [`VerificationKey`]. Equality, ordering and hashing are derived and operate on the raw
/// bytes (ordering is lexicographic).
#[derive(Copy, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct VerificationKeyBytes(pub(super) [u8; 32]);
impl VerificationKeyBytes {
pub const fn to_bytes(self) -> [u8; 32] {
self.0
}
pub const fn as_bytes(&self) -> &[u8; 32] {
&self.0
}
}
impl core::fmt::Debug for VerificationKeyBytes {
    /// Renders as `VerificationKeyBytes(<bytes>)`, with the bytes formatted through
    /// `commonware_formatting::Hex` (presumably a hex rendering — see that type).
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        let mut tuple = f.debug_tuple("VerificationKeyBytes");
        tuple.field(&Hex(&self.0));
        tuple.finish()
    }
}
impl AsRef<[u8]> for VerificationKeyBytes {
    /// Views the 32-byte encoding as an unsized byte slice.
    fn as_ref(&self) -> &[u8] {
        // `&[u8; 32]` unsizes to `&[u8]` at the return site.
        &self.0
    }
}
impl TryFrom<&[u8]> for VerificationKeyBytes {
type Error = Error;
fn try_from(slice: &[u8]) -> Result<Self, Error> {
if slice.len() == 32 {
let mut bytes = [0u8; 32];
bytes[..].copy_from_slice(slice);
Ok(bytes.into())
} else {
Err(Error::InvalidSliceLength)
}
}
}
impl From<[u8; 32]> for VerificationKeyBytes {
fn from(bytes: [u8; 32]) -> Self {
Self(bytes)
}
}
impl From<VerificationKeyBytes> for [u8; 32] {
fn from(refined: VerificationKeyBytes) -> [u8; 32] {
refined.0
}
}
/// A verification key whose encoding has been decompressed to a curve point.
///
/// Two representations are kept. `A_bytes` is hashed into the challenge during
/// verification. `minus_A` (the negated point) lets verification compute `[s]B - [k]A` with
/// a single double-scalar multiplication.
///
/// The derived `PartialEq` compares both fields. Ordering and hashing (implemented
/// separately) use only `A_bytes`.
#[derive(Copy, Clone, Eq, PartialEq)]
// Uppercase `A` mirrors the usual mathematical name for the public point.
#[allow(non_snake_case)]
pub struct VerificationKey {
    // Encoding of the point `A`; for keys built via `TryFrom<VerificationKeyBytes>` these are
    // the caller's bytes, unchanged.
    pub(super) A_bytes: VerificationKeyBytes,
    // The negation of the decoded point `A`.
    pub(super) minus_A: EdwardsPoint,
}
impl PartialOrd for VerificationKey {
    /// Delegates to [`Ord::cmp`], so every pair of keys is comparable.
    fn partial_cmp(&self, other: &Self) -> Option<core::cmp::Ordering> {
        let ordering = Ord::cmp(self, other);
        Some(ordering)
    }
}
impl Ord for VerificationKey {
    /// Orders keys by their 32-byte encoding, compared lexicographically.
    fn cmp(&self, other: &Self) -> core::cmp::Ordering {
        let (lhs, rhs) = (&self.A_bytes, &other.A_bytes);
        lhs.cmp(rhs)
    }
}
impl core::hash::Hash for VerificationKey {
    /// Hashes only the encoding; `minus_A` is derived from it and adds no information.
    fn hash<H: core::hash::Hasher>(&self, state: &mut H) {
        core::hash::Hash::hash(&self.A_bytes, state)
    }
}
impl core::fmt::Debug for VerificationKey {
    /// Renders as `VerificationKey(<bytes>)` using only the encoding, formatted through
    /// `commonware_formatting::Hex`; the cached point is omitted.
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        let encoded = &self.A_bytes.0;
        f.debug_tuple("VerificationKey").field(&Hex(encoded)).finish()
    }
}
impl From<VerificationKey> for VerificationKeyBytes {
fn from(vk: VerificationKey) -> Self {
vk.A_bytes
}
}
impl AsRef<[u8]> for VerificationKey {
    /// Views the stored 32-byte encoding as a byte slice.
    fn as_ref(&self) -> &[u8] {
        let encoded: &[u8; 32] = &self.A_bytes.0;
        encoded
    }
}
impl From<VerificationKey> for [u8; 32] {
fn from(vk: VerificationKey) -> [u8; 32] {
vk.A_bytes.0
}
}
impl TryFrom<VerificationKeyBytes> for VerificationKey {
type Error = Error;
#[allow(non_snake_case)]
fn try_from(bytes: VerificationKeyBytes) -> Result<Self, Self::Error> {
let A = CompressedEdwardsY(bytes.0)
.decompress()
.ok_or(Error::MalformedPublicKey)?;
Ok(Self {
A_bytes: bytes,
minus_A: -A,
})
}
}
impl TryFrom<&[u8]> for VerificationKey {
    type Error = Error;

    /// Parses a key from a byte slice: checks the length first, then decompresses.
    ///
    /// # Errors
    ///
    /// Propagates the errors of `VerificationKeyBytes::try_from` (slice length) and of
    /// `VerificationKey::try_from(VerificationKeyBytes)` (point decompression).
    fn try_from(slice: &[u8]) -> Result<Self, Error> {
        let encoded = VerificationKeyBytes::try_from(slice)?;
        encoded.try_into()
    }
}
impl TryFrom<[u8; 32]> for VerificationKey {
type Error = Error;
fn try_from(bytes: [u8; 32]) -> Result<Self, Self::Error> {
VerificationKeyBytes::from(bytes).try_into()
}
}
impl VerificationKey {
pub const fn to_bytes(self) -> [u8; 32] {
self.A_bytes.0
}
pub const fn as_bytes(&self) -> &[u8; 32] {
&self.A_bytes.0
}
pub fn verify(&self, signature: &Signature, msg: &[u8]) -> Result<(), Error> {
let k = Scalar::from_hash(
Sha512::default()
.chain(&signature.R_bytes[..])
.chain(&self.A_bytes.0[..])
.chain(msg),
);
self.verify_prehashed(signature, k)
}
#[allow(non_snake_case)]
pub(super) fn verify_prehashed(&self, signature: &Signature, k: Scalar) -> Result<(), Error> {
let s = Scalar::from_canonical_bytes(signature.s_bytes)
.into_option()
.ok_or(Error::InvalidSignature)?;
let R = CompressedEdwardsY(signature.R_bytes)
.decompress()
.ok_or(Error::InvalidSignature)?;
let R_prime = EdwardsPoint::vartime_double_scalar_mul_basepoint(&k, &self.minus_A, &s);
if (R - R_prime).mul_by_cofactor().is_identity() {
Ok(())
} else {
Err(Error::InvalidSignature)
}
}
}