use crate::keys::{PublicKey, SecretKey, PUBLIC_KEY_LENGTH, SECRET_KEY_LENGTH};
use crate::SchnorrError;
use mohan::ser;
use rand::{CryptoRng, RngCore};
use zeroize::Zeroize;
/// Length in bytes of a serialized [`Keypair`]: the secret key bytes followed
/// by the public key bytes, as produced by [`Keypair::to_bytes`].
pub const KEYPAIR_LENGTH: usize = SECRET_KEY_LENGTH + PUBLIC_KEY_LENGTH;
/// A secret key stored together with a public key.
///
/// The constructors in this module derive `public` from `secret`, but both
/// fields are `pub`, so a hand-built value is not guaranteed to be consistent.
///
/// NOTE(review): `Debug` is derived, so formatting a `Keypair` also formats
/// `secret` through `SecretKey`'s `Debug` impl — confirm that impl does not
/// print key material.
#[derive(Clone, Debug, Default)]
pub struct Keypair {
    /// The secret half of the keypair.
    pub secret: SecretKey,
    /// The public half of the keypair.
    pub public: PublicKey,
}
/// Builds a keypair from an owned secret key, deriving the public key from it.
impl From<SecretKey> for Keypair {
    fn from(secret: SecretKey) -> Keypair {
        // Fields are evaluated in source order: `public` borrows `secret`
        // before `secret` is moved into the struct.
        Keypair {
            public: PublicKey::from_secret(&secret),
            secret,
        }
    }
}
impl ::zeroize::Zeroize for Keypair {
fn zeroize(&mut self) {
self.secret.zeroize();
}
}
impl Drop for Keypair {
fn drop(&mut self) {
self.zeroize();
}
}
impl Keypair {
pub fn to_bytes(&self) -> [u8; KEYPAIR_LENGTH] {
let mut bytes: [u8; KEYPAIR_LENGTH] = [0u8; KEYPAIR_LENGTH];
bytes[..SECRET_KEY_LENGTH].copy_from_slice(self.secret.as_bytes());
bytes[SECRET_KEY_LENGTH..].copy_from_slice(self.public.as_bytes());
bytes
}
pub fn from_bytes<'a>(bytes: &'a [u8]) -> Result<Keypair, SchnorrError> {
if bytes.len() != KEYPAIR_LENGTH {
return Err(SchnorrError::SerError);
}
let secret = SecretKey::from_bytes(&bytes[..SECRET_KEY_LENGTH])?;
let public = PublicKey::from_bytes(&bytes[SECRET_KEY_LENGTH..])?;
Ok(Keypair{ secret: secret, public: public })
}
pub fn generate<R>(csprng: &mut R) -> Keypair
where
R: CryptoRng + RngCore,
{
let sk: SecretKey = SecretKey::generate(csprng);
let pk: PublicKey = PublicKey::from_secret(&sk);
Keypair {
public: pk,
secret: sk,
}
}
pub fn from_secret(s: &SecretKey) -> Keypair {
Keypair {
secret: s.clone(),
public: PublicKey::from_secret(s),
}
}
}
/// Serializes a keypair by writing `secret` and then `public`.
impl ser::Writeable for Keypair {
    /// Stops at, and returns, the first write error.
    fn write<W: ser::Writer>(&self, writer: &mut W) -> Result<(), ser::Error> {
        self.secret.write(writer)?;
        self.public.write(writer)
    }
}
/// Deserializes a keypair by reading a `SecretKey` and then a `PublicKey`.
///
/// NOTE(review): no check is made here that `public` corresponds to `secret`.
impl ser::Readable for Keypair {
    fn read(reader: &mut dyn ser::Reader) -> Result<Keypair, ser::Error> {
        let secret = SecretKey::read(reader)?;
        let public = PublicKey::read(reader)?;
        Ok(Keypair { secret, public })
    }
}
#[cfg(test)]
mod test {
    use super::*;
    // NOTE(review): despite its name, this test never drops the keypair. It
    // calls `zeroize()` directly, then asserts that the keypair's raw bytes are
    // *not* all zero. `Keypair::zeroize` only clears `secret`, so this passes
    // whenever `public` has any non-zero byte. It does not confirm that the
    // secret key itself was wiped.
    #[test]
    fn keypair_clear_on_drop() {
        let mut keypair: Keypair = Keypair::generate(&mut rand::prelude::thread_rng());
        keypair.zeroize();
        // Views a value as the raw bytes of its in-memory representation.
        fn as_bytes<T>(x: &T) -> &[u8] {
            use core::mem;
            use core::slice;
            // SAFETY: the pointer comes from a valid reference, and the length is
            // exactly `size_of_val(x)`, so the slice stays inside `*x`. The
            // returned lifetime is tied to `x`, so the slice cannot outlive it.
            // NOTE(review): reading padding bytes as `u8` is undefined behavior.
            // This is only sound if `T` has no padding — confirm for `Keypair`.
            unsafe { slice::from_raw_parts(x as *const T as *const u8, mem::size_of_val(x)) }
        }
        assert!(!as_bytes(&keypair).iter().all(|x| *x == 0u8));
    }
}