use core::convert::{TryFrom, TryInto};
use num_bigint::{BigInt, Sign};
use rand_core::{CryptoRng, RngCore};
use sha3::{
digest::{ExtendableOutput, Update},
Shake256,
};
use crate::{
init_sig, point::Point, shake256, Ed448Error, PreHash, PublicKey, KEY_LENGTH, SIG_LENGTH,
};
/// Raw `KEY_LENGTH`-byte buffer. Used both for the private key bytes held by
/// `PrivateKey` and for the clamped secret scalar returned by `PrivateKey::expand`.
// NOTE(review): `redundant_pub_crate` is presumably allowed because this module is
// private, which makes clippy treat `pub(crate)` as redundant; it is kept to state
// the crate-internal intent explicitly.
#[allow(clippy::redundant_pub_crate)]
pub(crate) type PrivateKeyRaw = [u8; KEY_LENGTH];
/// Second half of the SHAKE256 key expansion (see `PrivateKey::expand`), which
/// seeds the per-message nonce hash in `sign_real`.
// Same reason for the lint allowance as above.
#[allow(clippy::redundant_pub_crate)]
pub(crate) type SeedRaw = [u8; KEY_LENGTH];
/// An Ed448 private key: `KEY_LENGTH` secret bytes.
///
/// Create one with `PrivateKey::new` (random), `From<[u8; KEY_LENGTH]>`, or
/// `TryFrom<&[u8]>` (length-checked).
// NOTE(review): the type is `Copy`, so secret bytes get duplicated implicitly and
// are never zeroed on drop. Consider removing `Copy` (a breaking change) and
// zeroizing on drop.
#[derive(Copy, Clone)]
pub struct PrivateKey(PrivateKeyRaw);
// Opaque `Debug` impl: `{:?}` prints only the type name, never the key bytes.
opaque_debug::implement!(PrivateKey);
impl PrivateKey {
pub fn new<T>(rnd: &mut T) -> Self
where
T: CryptoRng + RngCore,
{
let mut key = [0; KEY_LENGTH];
rnd.fill_bytes(&mut key);
Self::from(key)
}
#[inline]
#[must_use]
pub const fn as_bytes(&self) -> &[u8; KEY_LENGTH] {
&self.0
}
pub(crate) fn expand(&self) -> (PrivateKeyRaw, SeedRaw) {
let h = Shake256::default()
.chain(self.as_bytes())
.finalize_boxed(114);
let mut s: [u8; KEY_LENGTH] = h[..KEY_LENGTH].try_into().unwrap();
s[0] &= 0b1111_1100;
s[56] = 0;
s[55] |= 0b1000_0000;
let seed: [u8; KEY_LENGTH] = h[KEY_LENGTH..].try_into().unwrap();
(s, seed)
}
#[inline]
pub fn sign(&self, msg: &[u8], ctx: Option<&[u8]>) -> crate::Result<[u8; SIG_LENGTH]> {
self.sign_real(msg, ctx, PreHash::False)
}
#[inline]
pub fn sign_ph(&self, msg: &[u8], ctx: Option<&[u8]>) -> crate::Result<[u8; SIG_LENGTH]> {
self.sign_real(msg, ctx, PreHash::True)
}
fn sign_real(
&self,
msg: &[u8],
ctx: Option<&[u8]>,
pre_hash: PreHash,
) -> crate::Result<[u8; SIG_LENGTH]> {
let (ctx, msg) = init_sig(ctx, pre_hash, msg)?;
let (a, seed) = &self.expand();
let a = BigInt::from_bytes_le(Sign::Plus, a);
let r = shake256(vec![seed, &msg], ctx.as_ref(), pre_hash);
let r = BigInt::from_bytes_le(Sign::Plus, r.as_ref()) % Point::l();
let R = (Point::default() * &r).encode();
let h = shake256(
vec![&R, &PublicKey::from(a.clone()).as_byte(), &msg],
ctx.as_ref(),
pre_hash,
);
let h = BigInt::from_bytes_le(Sign::Plus, h.as_ref()) % Point::l();
let S = (r + h * a) % Point::l();
let mut S = S.magnitude().to_bytes_le();
S.resize_with(KEY_LENGTH, Default::default);
let S: [u8; KEY_LENGTH] = S.try_into().unwrap();
Ok([R, S].concat().try_into().unwrap())
}
}
impl From<PrivateKeyRaw> for PrivateKey {
#[inline]
fn from(array: PrivateKeyRaw) -> Self {
Self(array)
}
}
impl TryFrom<&'_ [u8]> for PrivateKey {
type Error = Ed448Error;
fn try_from(bytes: &[u8]) -> crate::Result<Self> {
if bytes.len() != KEY_LENGTH {
return Err(Ed448Error::WrongKeyLength);
}
let bytes: &[u8; KEY_LENGTH] = bytes.try_into().unwrap();
Ok(Self::from(bytes))
}
}
impl From<&'_ PrivateKeyRaw> for PrivateKey {
#[inline]
fn from(bytes: &PrivateKeyRaw) -> Self {
Self::from(*bytes)
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use rand_core::OsRng;

    #[test]
    fn create_new_pkey() {
        // The length is fixed by the return type, so also check something
        // observable: two keys drawn from the OS RNG should differ (a collision
        // has probability 2^-456).
        let first = PrivateKey::new(&mut OsRng);
        let second = PrivateKey::new(&mut OsRng);
        assert_eq!(first.as_bytes().len(), KEY_LENGTH);
        assert_ne!(first.as_bytes(), second.as_bytes());
    }

    #[test]
    fn round_trip_through_slice() {
        // Parsing a key's own bytes must reproduce the same key.
        let pkey = PrivateKey::new(&mut OsRng);
        let parsed = PrivateKey::try_from(&pkey.as_bytes()[..]).unwrap();
        assert_eq!(parsed.as_bytes(), pkey.as_bytes());
    }

    #[test]
    fn invalid_key_len() {
        // Both shorter and longer slices are rejected.
        let too_short = PrivateKey::try_from(&[0x01_u8][..]);
        assert_eq!(too_short.unwrap_err(), Ed448Error::WrongKeyLength);
        let too_long = PrivateKey::try_from(&[0x01_u8; KEY_LENGTH + 1][..]);
        assert_eq!(too_long.unwrap_err(), Ed448Error::WrongKeyLength);
    }

    #[test]
    fn invalid_context_length() {
        let pkey = PrivateKey::new(&mut OsRng);
        let ctx = [0; 256];
        let invalid_sig = pkey.sign(b"message", Some(&ctx));
        assert_eq!(invalid_sig.unwrap_err(), Ed448Error::ContextTooLong);
    }

    #[test]
    fn signing_is_deterministic() {
        // sign_real draws no randomness, so repeated signatures must match.
        let pkey = PrivateKey::new(&mut OsRng);
        let first = pkey.sign(b"message", None).unwrap();
        let second = pkey.sign(b"message", None).unwrap();
        assert_eq!(first, second);
    }
}