#[cfg(not(feature = "std"))]
use alloc::vec::Vec;
use falcon_rust::falcon512::{self as fr, SecretKey as FrSecretKey};
use rand::RngCore;
use sha3::{Digest, Sha3_256};
use zeroize::{Zeroize, ZeroizeOnDrop};
use crate::{
error::Error,
verify::verify_raw,
DOMAIN_TAG, PUBLIC_KEY_BYTES,
};
/// Raw public-key bytes (the `falcon512` encoding produced in `KeyPair::generate`).
///
/// The public constructor [`PublicKey::from_bytes`] enforces that the length is
/// exactly `PUBLIC_KEY_BYTES`. With the `serde` feature the key is serialized as
/// a hex string, and deserialization routes through the same length check.
#[derive(Clone, Debug, PartialEq, Eq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize))]
pub struct PublicKey(
    #[cfg_attr(feature = "serde", serde(serialize_with = "hex_bytes::serialize"))]
    Vec<u8>,
);

#[cfg(feature = "serde")]
impl<'de> serde::Deserialize<'de> for PublicKey {
    /// Decodes the hex form and rejects keys whose length is not `PUBLIC_KEY_BYTES`.
    ///
    /// A derived impl would write directly into the private field and accept
    /// keys of any length; this one validates via [`PublicKey::from_bytes`].
    fn deserialize<D: serde::Deserializer<'de>>(deserializer: D) -> Result<Self, D::Error> {
        // Mirrors the wire shape of the derived `Serialize`: a newtype struct
        // named "PublicKey" whose single field is hex-encoded via `hex_bytes`.
        #[derive(serde::Deserialize)]
        #[serde(rename = "PublicKey")]
        struct Unchecked(#[serde(with = "hex_bytes")] Vec<u8>);

        let Unchecked(bytes) =
            <Unchecked as serde::Deserialize<'de>>::deserialize(deserializer)?;
        let actual = bytes.len();
        Self::from_bytes(bytes).map_err(|_| {
            serde::de::Error::invalid_length(actual, &"a correctly sized public key")
        })
    }
}

impl PublicKey {
    /// Wraps `bytes` as a public key after checking its length.
    ///
    /// # Errors
    /// Returns `Error::InvalidPublicKeyLength` when `bytes.len()` is not
    /// exactly `PUBLIC_KEY_BYTES`.
    pub fn from_bytes(bytes: Vec<u8>) -> Result<Self, Error> {
        if bytes.len() != PUBLIC_KEY_BYTES {
            return Err(Error::InvalidPublicKeyLength {
                expected: PUBLIC_KEY_BYTES,
                actual: bytes.len(),
            });
        }
        Ok(Self(bytes))
    }

    /// Borrows the raw key bytes.
    pub fn as_bytes(&self) -> &[u8] {
        &self.0
    }

    /// Consumes the key and returns its raw bytes.
    pub fn into_bytes(self) -> Vec<u8> {
        self.0
    }

    /// Derives the single-key address for this public key via
    /// `SingleKeyAddress::from_public_key`.
    pub fn to_address(&self) -> crate::address::SingleKeyAddress {
        crate::address::SingleKeyAddress::from_public_key(self)
    }
}
/// A signing key pair built on `falcon_rust::falcon512`.
///
/// The secret half lives in `SecretKeyBytes`, which zeroizes its buffer on
/// drop. `Debug` and (with the `serde` feature) `Serialize` are implemented by
/// hand for this type and emit only the public key.
pub struct KeyPair {
    /// Public verification key.
    public_key: PublicKey,
    /// Encoded secret key, either from `falcon_rust`'s `SecretKey::to_bytes`
    /// (in `generate`) or copied verbatim from the caller (in `from_bytes`).
    secret_key: SecretKeyBytes,
}
/// Owned secret-key bytes that are wiped with `zeroize` when dropped.
///
/// Derives neither `Debug` nor `Clone`, so the bytes cannot be printed via
/// `{:?}` or duplicated through this type.
#[derive(Zeroize, ZeroizeOnDrop)]
struct SecretKeyBytes(Vec<u8>);
impl KeyPair {
pub fn generate() -> Self {
let mut seed = [0u8; 32];
rand::thread_rng().fill_bytes(&mut seed);
let (sk, pk) = fr::keygen(seed);
Self {
public_key: PublicKey(pk.to_bytes().to_vec()),
secret_key: SecretKeyBytes(sk.to_bytes().to_vec()),
}
}
pub fn from_bytes(sk_bytes: &[u8], pk_bytes: &[u8]) -> Result<Self, Error> {
if pk_bytes.len() != PUBLIC_KEY_BYTES {
return Err(Error::InvalidPublicKeyLength {
expected: PUBLIC_KEY_BYTES,
actual: pk_bytes.len(),
});
}
Ok(Self {
public_key: PublicKey(pk_bytes.to_vec()),
secret_key: SecretKeyBytes(sk_bytes.to_vec()),
})
}
pub fn public_key(&self) -> &PublicKey {
&self.public_key
}
pub fn secret_key_bytes(&self) -> &[u8] {
&self.secret_key.0
}
pub fn sign(&self, message: &[u8]) -> Vec<u8> {
let digest = domain_hash(message);
let sk = FrSecretKey::from_bytes(&self.secret_key.0)
.expect("KeyPair::sign: stored secret key is malformed — this is a bug");
fr::sign(&digest, &sk).to_bytes().to_vec()
}
pub fn address(&self) -> crate::address::SingleKeyAddress {
self.public_key.to_address()
}
pub fn verify_own_signature(&self, message: &[u8], signature: &[u8]) -> Result<bool, Error> {
verify_raw(message, signature, self.public_key.as_bytes(), 0)
}
}
impl core::fmt::Debug for KeyPair {
    /// Formats the public key normally and replaces the secret key with a
    /// fixed `"[REDACTED]"` placeholder, so it never reaches logs.
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        const HIDDEN: &str = "[REDACTED]";
        let mut builder = f.debug_struct("KeyPair");
        builder.field("public_key", &self.public_key);
        builder.field("secret_key", &HIDDEN);
        builder.finish()
    }
}
#[cfg(feature = "serde")]
impl serde::Serialize for KeyPair {
    /// Serializes a one-field `KeyPair` struct that carries only the public key.
    ///
    /// The secret key is intentionally omitted, so the output cannot be used
    /// to reconstruct a signing key.
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: serde::Serializer,
    {
        use serde::ser::SerializeStruct as _;
        let mut fields = serializer.serialize_struct("KeyPair", 1)?;
        fields.serialize_field("public_key", self.public_key())?;
        fields.end()
    }
}
/// Returns `SHA3-256(DOMAIN_TAG || message)`.
///
/// Prefixing the crate's `DOMAIN_TAG` makes these digests differ from a plain
/// SHA3-256 hash of the same message.
pub(crate) fn domain_hash(message: &[u8]) -> [u8; 32] {
    let mut sponge = Sha3_256::new();
    sponge.update(DOMAIN_TAG);
    sponge.update(message);
    sponge.finalize().into()
}
/// Serde field adapter that stores a `Vec<u8>` as a hex string.
///
/// Intended for `#[serde(with = "hex_bytes")]` or the `serialize_with` /
/// `deserialize_with` attributes.
#[cfg(feature = "serde")]
mod hex_bytes {
    use serde::{Deserialize, Deserializer, Serializer};
    // Without `std`, neither `Vec` nor `String` is in the prelude; both must
    // come from `alloc` (`String` is needed by `deserialize`).
    #[cfg(not(feature = "std"))]
    use alloc::{string::String, vec::Vec};

    /// Serializes `bytes` as the string produced by `hex::encode`.
    // `&Vec<u8>` rather than `&[u8]` matches the `&T` signature serde's
    // `with`/`serialize_with` calls for a `Vec<u8>` field.
    #[allow(clippy::ptr_arg)]
    pub fn serialize<S: Serializer>(bytes: &Vec<u8>, s: S) -> Result<S::Ok, S::Error> {
        s.serialize_str(&hex::encode(bytes))
    }

    /// Deserializes a string and decodes it with `hex::decode`.
    ///
    /// Invalid hex is reported as a custom deserializer error.
    pub fn deserialize<'de, D: Deserializer<'de>>(d: D) -> Result<Vec<u8>, D::Error> {
        let s = String::deserialize(d)?;
        hex::decode(&s).map_err(serde::de::Error::custom)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{SIGNATURE_MAX_BYTES, SIGNATURE_MIN_BYTES};

    #[test]
    fn generate_and_sign_verify_roundtrip() {
        let kp = KeyPair::generate();
        let message = b"test payload for signing";
        let sig = kp.sign(message);
        assert!(sig.len() >= SIGNATURE_MIN_BYTES);
        assert!(sig.len() <= SIGNATURE_MAX_BYTES);
        let ok = kp.verify_own_signature(message, &sig).unwrap();
        assert!(ok, "freshly generated signature must verify");
    }

    #[test]
    fn wrong_message_fails_verification() {
        let kp = KeyPair::generate();
        let sig = kp.sign(b"correct message");
        let ok = kp.verify_own_signature(b"wrong message", &sig).unwrap();
        assert!(!ok);
    }

    #[test]
    fn wrong_key_fails_verification() {
        let kp1 = KeyPair::generate();
        let kp2 = KeyPair::generate();
        let message = b"some message";
        let sig = kp1.sign(message);
        let ok = verify_raw(message, &sig, kp2.public_key().as_bytes(), 0).unwrap();
        assert!(!ok);
    }

    #[test]
    fn public_key_length_is_correct() {
        let kp = KeyPair::generate();
        assert_eq!(kp.public_key().as_bytes().len(), PUBLIC_KEY_BYTES);
    }

    #[test]
    fn public_key_from_bytes_checks_length() {
        let short = PublicKey::from_bytes(vec![0u8; PUBLIC_KEY_BYTES - 1]);
        assert!(matches!(
            short,
            Err(Error::InvalidPublicKeyLength { expected, actual })
                if expected == PUBLIC_KEY_BYTES && actual == PUBLIC_KEY_BYTES - 1
        ));
        let exact = PublicKey::from_bytes(vec![0u8; PUBLIC_KEY_BYTES]).unwrap();
        assert_eq!(exact.into_bytes().len(), PUBLIC_KEY_BYTES);
    }

    #[test]
    fn from_bytes_rejects_bad_pk_length() {
        let kp = KeyPair::generate();
        let result = KeyPair::from_bytes(kp.secret_key_bytes(), &[0u8; 64]);
        assert!(matches!(result, Err(Error::InvalidPublicKeyLength { .. })));
    }

    #[test]
    fn from_bytes_roundtrip_produces_working_key() {
        let original = KeyPair::generate();
        let restored = KeyPair::from_bytes(
            original.secret_key_bytes(),
            original.public_key().as_bytes(),
        )
        .unwrap();
        assert_eq!(restored.public_key(), original.public_key());
        let message = b"restored key signs";
        let sig = restored.sign(message);
        assert!(original.verify_own_signature(message, &sig).unwrap());
    }

    #[test]
    fn debug_does_not_expose_secret_key() {
        let kp = KeyPair::generate();
        let debug_str = format!("{kp:?}");
        assert!(debug_str.contains("[REDACTED]"));
        // Look for the secret bytes themselves, which is what would appear if
        // `Debug` ever printed the secret buffer.
        let secret_rendered = format!("{:?}", kp.secret_key_bytes());
        assert!(!debug_str.contains(&secret_rendered));
    }

    #[test]
    fn domain_hash_includes_tag() {
        let with_tag = domain_hash(b"hello");
        let mut hasher = Sha3_256::new();
        hasher.update(b"hello");
        let raw: [u8; 32] = hasher.finalize().into();
        assert_ne!(with_tag, raw, "domain hash must differ from untagged hash");
    }

    #[test]
    fn domain_hash_matches_tagged_sha3() {
        let mut hasher = Sha3_256::new();
        hasher.update(DOMAIN_TAG);
        hasher.update(b"hello");
        let expected: [u8; 32] = hasher.finalize().into();
        assert_eq!(domain_hash(b"hello"), expected);
    }

    #[test]
    fn domain_hash_is_deterministic() {
        assert_eq!(domain_hash(b"data"), domain_hash(b"data"));
    }

    #[test]
    fn domain_hash_separates_messages() {
        assert_ne!(domain_hash(b"a"), domain_hash(b"b"));
    }
}