use std::io::Write;
use std::path::Path;
use ed25519_dalek::{Signer, SigningKey, VerifyingKey};
use rand::rngs::OsRng;
use serde::{Deserialize, Serialize};
use zeroize::{Zeroize, ZeroizeOnDrop, Zeroizing};
use crate::error::{CryptoError, CryptoResult};
use crate::signature::Signature;
/// An Ed25519 key pair: a secret `SigningKey` together with the `VerifyingKey`
/// derived from it.
///
/// `ZeroizeOnDrop` is derived so the secret `signing_key` is wiped when the pair
/// is dropped; `verifying_key` is public data and is excluded via `#[zeroize(skip)]`.
#[derive(ZeroizeOnDrop)]
pub struct KeyPair {
/// Public half, computed from `signing_key`; not secret, so not zeroized.
#[zeroize(skip)] verifying_key: VerifyingKey,
/// Secret half; wiped on drop.
signing_key: SigningKey,
}
impl KeyPair {
#[must_use]
pub fn generate() -> Self {
let signing_key = SigningKey::generate(&mut OsRng);
let verifying_key = signing_key.verifying_key();
Self {
verifying_key,
signing_key,
}
}
pub fn from_secret_key(bytes: &[u8]) -> CryptoResult<Self> {
if bytes.len() != 32 {
return Err(CryptoError::InvalidKeyLength {
expected: 32,
actual: bytes.len(),
});
}
let mut secret = [0u8; 32];
secret.copy_from_slice(bytes);
let signing_key = SigningKey::from_bytes(&secret);
let verifying_key = signing_key.verifying_key();
secret.zeroize();
Ok(Self {
verifying_key,
signing_key,
})
}
#[must_use]
pub fn public_key_bytes(&self) -> &[u8; 32] {
self.verifying_key.as_bytes()
}
#[must_use]
pub fn key_id(&self) -> [u8; 8] {
let mut id = [0u8; 8];
id.copy_from_slice(&self.public_key_bytes()[..8]);
id
}
#[must_use]
pub fn key_id_hex(&self) -> String {
hex::encode(self.key_id())
}
#[must_use]
pub fn sign(&self, message: &[u8]) -> Signature {
let sig = self.signing_key.sign(message);
Signature::from(sig)
}
pub fn verify(&self, message: &[u8], signature: &Signature) -> CryptoResult<()> {
signature.verify(message, self.public_key_bytes())
}
#[must_use]
pub fn export_public_key(&self) -> PublicKey {
PublicKey::from_bytes(*self.public_key_bytes())
}
#[must_use]
pub fn secret_key_bytes(&self) -> [u8; 32] {
self.signing_key.to_bytes()
}
pub fn load_or_generate(path: impl AsRef<Path>) -> CryptoResult<Self> {
let path = path.as_ref();
if let Some(parent) = path.parent() {
std::fs::create_dir_all(parent).map_err(|e| CryptoError::IoError(e.to_string()))?;
}
#[cfg(unix)]
{
use std::os::unix::fs::OpenOptionsExt;
match std::fs::OpenOptions::new()
.write(true)
.create_new(true)
.mode(0o600)
.open(path)
{
Ok(mut file) => {
let kp = Self::generate();
file.write_all(&kp.secret_key_bytes())
.map_err(|e| CryptoError::IoError(e.to_string()))?;
return Ok(kp);
},
Err(e) if e.kind() == std::io::ErrorKind::AlreadyExists => {
},
Err(e) => return Err(CryptoError::IoError(e.to_string())),
}
}
#[cfg(not(unix))]
if !path.exists() {
let kp = Self::generate();
std::fs::write(path, kp.secret_key_bytes())
.map_err(|e| CryptoError::IoError(e.to_string()))?;
return Ok(kp);
}
let meta =
std::fs::symlink_metadata(path).map_err(|e| CryptoError::IoError(e.to_string()))?;
if meta.file_type().is_symlink() {
return Err(CryptoError::IoError(
"refusing to read key file: path is a symlink".into(),
));
}
let bytes =
Zeroizing::new(std::fs::read(path).map_err(|e| CryptoError::IoError(e.to_string()))?);
Self::from_secret_key(&bytes)
}
pub fn load_or_generate_pair(path: impl AsRef<Path>) -> CryptoResult<(Self, Self)> {
let first = Self::load_or_generate(path.as_ref())?;
let meta = std::fs::symlink_metadata(path.as_ref())
.map_err(|e| CryptoError::IoError(e.to_string()))?;
if meta.file_type().is_symlink() {
return Err(CryptoError::IoError(
"refusing to read key file: path is a symlink".into(),
));
}
let bytes = Zeroizing::new(
std::fs::read(path.as_ref()).map_err(|e| CryptoError::IoError(e.to_string()))?,
);
let second = Self::from_secret_key(&bytes)?;
Ok((first, second))
}
}
/// Debug output exposes only the public key id; secret material is never printed.
impl std::fmt::Debug for KeyPair {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let key_id = self.key_id_hex();
        let mut builder = f.debug_struct("KeyPair");
        builder.field("key_id", &key_id);
        builder.finish_non_exhaustive()
    }
}
/// An Ed25519 public (verifying) key held as its raw 32-byte encoding.
///
/// Construction checks only the length, never that the bytes encode a valid
/// curve point; any such validation is left to `Signature::verify` (defined
/// elsewhere — TODO confirm it rejects malformed keys).
#[derive(Clone, Copy, PartialEq, Eq, Hash)]
pub struct PublicKey([u8; 32]);
impl PublicKey {
    /// Wraps 32 raw bytes as a public key without further validation.
    #[must_use]
    pub const fn from_bytes(bytes: [u8; 32]) -> Self {
        Self(bytes)
    }

    /// Copies a 32-byte slice into a new public key.
    ///
    /// # Errors
    ///
    /// Returns `CryptoError::InvalidKeyLength` when `slice` is not exactly 32 bytes.
    pub fn try_from_slice(slice: &[u8]) -> CryptoResult<Self> {
        let actual = slice.len();
        if actual == 32 {
            let mut raw = [0u8; 32];
            raw.copy_from_slice(slice);
            Ok(Self(raw))
        } else {
            Err(CryptoError::InvalidKeyLength {
                expected: 32,
                actual,
            })
        }
    }

    /// Borrows the raw 32-byte encoding.
    #[must_use]
    pub const fn as_bytes(&self) -> &[u8; 32] {
        &self.0
    }

    /// Returns the first 8 bytes of the key as a short identifier.
    #[must_use]
    pub fn key_id(&self) -> [u8; 8] {
        let mut prefix = [0u8; 8];
        for (dst, src) in prefix.iter_mut().zip(self.0.iter()) {
            *dst = *src;
        }
        prefix
    }

    /// Hex encoding of [`Self::key_id`].
    #[must_use]
    pub fn key_id_hex(&self) -> String {
        let id = self.key_id();
        hex::encode(id)
    }

    /// Hex encoding of the full 32-byte key.
    #[must_use]
    pub fn to_hex(&self) -> String {
        hex::encode(self.as_bytes())
    }

    /// Parses a hex string into a public key.
    ///
    /// # Errors
    ///
    /// `CryptoError::InvalidHexEncoding` if `s` is not valid hex, or
    /// `CryptoError::InvalidKeyLength` if it does not decode to exactly 32 bytes.
    pub fn from_hex(s: &str) -> CryptoResult<Self> {
        match hex::decode(s) {
            Ok(decoded) => Self::try_from_slice(&decoded),
            Err(_) => Err(CryptoError::InvalidHexEncoding),
        }
    }

    /// Base64 encoding of the key using the `STANDARD` general-purpose engine.
    #[must_use]
    pub fn to_base64(&self) -> String {
        use base64::Engine;
        let engine = base64::engine::general_purpose::STANDARD;
        engine.encode(self.as_bytes())
    }

    /// Parses a `STANDARD`-engine base64 string into a public key.
    ///
    /// # Errors
    ///
    /// `CryptoError::InvalidBase64Encoding` if `s` is not valid base64, or
    /// `CryptoError::InvalidKeyLength` if it does not decode to exactly 32 bytes.
    pub fn from_base64(s: &str) -> CryptoResult<Self> {
        use base64::Engine;
        let engine = base64::engine::general_purpose::STANDARD;
        match engine.decode(s) {
            Ok(decoded) => Self::try_from_slice(&decoded),
            Err(_) => Err(CryptoError::InvalidBase64Encoding),
        }
    }

    /// Verifies `signature` over `message` against this key.
    ///
    /// # Errors
    ///
    /// Propagates the error returned by `Signature::verify`.
    pub fn verify(&self, message: &[u8], signature: &Signature) -> CryptoResult<()> {
        signature.verify(message, self.as_bytes())
    }
}
/// Debug output abbreviates the key to its hex key id (first 8 bytes).
impl std::fmt::Debug for PublicKey {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str("PublicKey(")?;
        f.write_str(&self.key_id_hex())?;
        f.write_str(")")
    }
}
/// Displays the full hex encoding of the key.
impl std::fmt::Display for PublicKey {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let encoded = self.to_hex();
        f.write_str(&encoded)
    }
}
/// Serializes the key as a string produced by [`PublicKey::to_base64`].
impl Serialize for PublicKey {
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: serde::Serializer,
    {
        let encoded = self.to_base64();
        serializer.serialize_str(encoded.as_str())
    }
}
impl<'de> Deserialize<'de> for PublicKey {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: serde::Deserializer<'de>,
{
let s = String::deserialize(deserializer)?;
Self::from_base64(&s).map_err(serde::de::Error::custom)
}
}
impl From<[u8; 32]> for PublicKey {
fn from(bytes: [u8; 32]) -> Self {
Self(bytes)
}
}
impl From<PublicKey> for [u8; 32] {
fn from(pk: PublicKey) -> Self {
pk.0
}
}
/// Exposes the raw key bytes as a slice.
impl AsRef<[u8]> for PublicKey {
    fn as_ref(&self) -> &[u8] {
        self.as_bytes()
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_keypair_generation() {
        let kp1 = KeyPair::generate();
        let kp2 = KeyPair::generate();
        assert_ne!(kp1.public_key_bytes(), kp2.public_key_bytes());
    }

    #[test]
    fn test_keypair_from_secret() {
        let original = KeyPair::generate();
        let secret = original.secret_key_bytes();
        let restored = KeyPair::from_secret_key(&secret).unwrap();
        assert_eq!(original.public_key_bytes(), restored.public_key_bytes());
    }

    #[test]
    fn test_sign_verify() {
        let keypair = KeyPair::generate();
        let message = b"hello world";
        let signature = keypair.sign(message);
        assert!(keypair.verify(message, &signature).is_ok());
        assert!(keypair.verify(b"wrong", &signature).is_err());
    }

    #[test]
    fn test_key_id() {
        let keypair = KeyPair::generate();
        let key_id = keypair.key_id();
        assert_eq!(&key_id[..], &keypair.public_key_bytes()[..8]);
        let hex_id = keypair.key_id_hex();
        assert_eq!(hex_id.len(), 16);
    }

    #[test]
    fn test_public_key_key_id_matches_keypair() {
        let keypair = KeyPair::generate();
        let pk = keypair.export_public_key();
        assert_eq!(pk.key_id(), keypair.key_id());
        assert_eq!(pk.key_id_hex(), keypair.key_id_hex());
    }

    #[test]
    fn test_public_key_encoding() {
        let keypair = KeyPair::generate();
        let pk = keypair.export_public_key();
        let hex = pk.to_hex();
        let decoded = PublicKey::from_hex(&hex).unwrap();
        assert_eq!(pk, decoded);
        let b64 = pk.to_base64();
        let decoded = PublicKey::from_base64(&b64).unwrap();
        assert_eq!(pk, decoded);
    }

    #[test]
    fn test_public_key_rejects_bad_encodings() {
        assert!(matches!(
            PublicKey::from_hex("zz"),
            Err(CryptoError::InvalidHexEncoding)
        ));
        assert!(matches!(
            PublicKey::from_hex("00"),
            Err(CryptoError::InvalidKeyLength { expected: 32, actual: 1, .. })
        ));
        assert!(matches!(
            PublicKey::from_base64("not base64!"),
            Err(CryptoError::InvalidBase64Encoding)
        ));
    }

    #[test]
    fn test_public_key_byte_conversions() {
        let pk = KeyPair::generate().export_public_key();
        let raw: [u8; 32] = pk.into();
        assert_eq!(&raw, pk.as_bytes());
        assert_eq!(PublicKey::from(raw), pk);
        let slice: &[u8] = pk.as_ref();
        assert_eq!(slice, &raw[..]);
        assert_eq!(PublicKey::try_from_slice(slice).unwrap(), pk);
    }

    #[test]
    fn test_public_key_formatting() {
        let pk = KeyPair::generate().export_public_key();
        assert_eq!(format!("{pk}"), pk.to_hex());
        assert_eq!(format!("{pk:?}"), format!("PublicKey({})", pk.key_id_hex()));
    }

    #[test]
    fn test_keypair_debug_hides_secret() {
        let kp = KeyPair::generate();
        let rendered = format!("{kp:?}");
        assert!(rendered.contains(&kp.key_id_hex()));
        assert!(!rendered.contains(&hex::encode(kp.secret_key_bytes())));
    }

    #[test]
    fn test_public_key_verify() {
        let keypair = KeyPair::generate();
        let pk = keypair.export_public_key();
        let message = b"test";
        let sig = keypair.sign(message);
        assert!(pk.verify(message, &sig).is_ok());
    }

    #[test]
    fn test_invalid_key_length() {
        let result = KeyPair::from_secret_key(&[0u8; 31]);
        assert!(matches!(result, Err(CryptoError::InvalidKeyLength { .. })));
    }

    #[test]
    fn test_invalid_key_length_too_long() {
        let result = KeyPair::from_secret_key(&[0u8; 33]);
        assert!(matches!(
            result,
            Err(CryptoError::InvalidKeyLength { expected: 32, actual: 33, .. })
        ));
    }

    #[test]
    fn test_load_or_generate_creates_new() {
        let dir = tempfile::tempdir().unwrap();
        let path = dir.path().join("keys").join("test.key");
        let kp1 = KeyPair::load_or_generate(&path).unwrap();
        assert!(path.exists());
        let kp2 = KeyPair::load_or_generate(&path).unwrap();
        assert_eq!(kp1.public_key_bytes(), kp2.public_key_bytes());
    }

    #[test]
    fn test_load_or_generate_rejects_corrupt() {
        let dir = tempfile::tempdir().unwrap();
        let path = dir.path().join("bad.key");
        std::fs::write(&path, [0u8; 16]).unwrap();
        let result = KeyPair::load_or_generate(&path);
        assert!(matches!(result, Err(CryptoError::InvalidKeyLength { .. })));
    }

    #[cfg(unix)]
    #[test]
    fn test_load_or_generate_sets_permissions() {
        use std::os::unix::fs::PermissionsExt;
        let dir = tempfile::tempdir().unwrap();
        let path = dir.path().join("secure.key");
        KeyPair::load_or_generate(&path).unwrap();
        let perms = std::fs::metadata(&path).unwrap().permissions();
        assert_eq!(perms.mode() & 0o777, 0o600);
    }

    #[cfg(unix)]
    #[test]
    fn test_load_or_generate_rejects_symlink() {
        let dir = tempfile::tempdir().unwrap();
        let real_path = dir.path().join("real.key");
        let link_path = dir.path().join("link.key");
        KeyPair::load_or_generate(&real_path).unwrap();
        std::os::unix::fs::symlink(&real_path, &link_path).unwrap();
        let result = KeyPair::load_or_generate(&link_path);
        assert!(result.is_err());
        let err = result.unwrap_err();
        assert!(
            err.to_string().contains("symlink"),
            "expected symlink error, got: {err}"
        );
    }

    #[cfg(unix)]
    #[test]
    fn test_load_or_generate_does_not_write_through_dangling_symlink() {
        let dir = tempfile::tempdir().unwrap();
        let target = dir.path().join("target.key");
        let link = dir.path().join("dangling.key");
        std::os::unix::fs::symlink(&target, &link).unwrap();
        assert!(KeyPair::load_or_generate(&link).is_err());
        assert!(
            !target.exists(),
            "secret key must not be written through a symlink"
        );
    }

    #[test]
    fn test_load_or_generate_pair() {
        let dir = tempfile::tempdir().unwrap();
        let path = dir.path().join("keys").join("pair.key");
        let (kp1, kp2) = KeyPair::load_or_generate_pair(&path).unwrap();
        assert_eq!(kp1.public_key_bytes(), kp2.public_key_bytes());
        let msg = b"test message";
        let sig1 = kp1.sign(msg);
        let sig2 = kp2.sign(msg);
        assert!(kp1.verify(msg, &sig1).is_ok());
        assert!(kp2.verify(msg, &sig2).is_ok());
    }

    #[test]
    fn test_load_or_generate_pair_existing_file() {
        let dir = tempfile::tempdir().unwrap();
        let path = dir.path().join("existing.key");
        let original = KeyPair::load_or_generate(&path).unwrap();
        let (a, b) = KeyPair::load_or_generate_pair(&path).unwrap();
        assert_eq!(a.public_key_bytes(), original.public_key_bytes());
        assert_eq!(b.public_key_bytes(), original.public_key_bytes());
    }
}