use serde::{Deserialize, Serialize};
use crate::error::CryptoError;
use std::fmt;
use super::{DSA, DSAlgorithm, Ed25519Signer, PublicKey, SignatureIdentifier};
/// Algorithms for which a [`KeyPair`] can be generated or loaded.
///
/// `Ed25519` is currently the only variant and is also the [`Default`].
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq, Serialize, Deserialize)]
pub enum KeyPairAlgorithm {
    /// Ed25519 keys, backed by [`Ed25519Signer`].
    #[default]
    Ed25519,
}
impl From<DSAlgorithm> for KeyPairAlgorithm {
fn from(algo: DSAlgorithm) -> Self {
match algo {
DSAlgorithm::Ed25519 => Self::Ed25519,
}
}
}
impl From<KeyPairAlgorithm> for DSAlgorithm {
fn from(kp_type: KeyPairAlgorithm) -> Self {
match kp_type {
KeyPairAlgorithm::Ed25519 => Self::Ed25519,
}
}
}
impl KeyPairAlgorithm {
    /// Generates a fresh [`KeyPair`] for this algorithm.
    ///
    /// Shorthand for [`KeyPair::generate`].
    ///
    /// # Errors
    /// Whatever [`KeyPair::generate`] returns is passed through unchanged.
    pub fn generate_keypair(&self) -> Result<KeyPair, CryptoError> {
        let algorithm: Self = *self;
        KeyPair::generate(algorithm)
    }
}
impl fmt::Display for KeyPairAlgorithm {
    /// Writes the algorithm's name, e.g. `Ed25519`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let name = match self {
            Self::Ed25519 => "Ed25519",
        };
        f.write_str(name)
    }
}
/// A signing key pair, tagged by algorithm.
///
/// Each variant wraps the algorithm-specific signer. `Debug` is implemented by
/// hand in this module and prints only the key type, algorithm and public key.
#[derive(Clone)]
pub enum KeyPair {
    /// An Ed25519 key pair held by an [`Ed25519Signer`].
    Ed25519(Ed25519Signer),
}
impl KeyPair {
    /// Object identifier for Ed25519 (`1.3.101.112`, RFC 8410).
    ///
    /// Shared by [`KeyPair::from_secret_der`] and [`KeyPair::to_secret_der`]
    /// so the decode and encode paths cannot drift apart.
    const ED25519_OID: pkcs8::ObjectIdentifier =
        pkcs8::ObjectIdentifier::new_unwrap("1.3.101.112");

    /// DER tag byte of an ASN.1 `OCTET STRING`.
    const OCTET_STRING_TAG: u8 = 0x04;

    /// Generates a new key pair of the given algorithm.
    ///
    /// # Errors
    /// Propagates any error from the underlying signer's key generation.
    pub fn generate(key_type: KeyPairAlgorithm) -> Result<Self, CryptoError> {
        match key_type {
            KeyPairAlgorithm::Ed25519 => {
                Ed25519Signer::generate().map(KeyPair::Ed25519)
            }
        }
    }

    /// Loads a key pair from a DER-encoded PKCS#8 `PrivateKeyInfo`.
    ///
    /// The algorithm is selected by the `PrivateKeyInfo`'s OID. For Ed25519 the
    /// private-key field must hold exactly one short-form DER `OCTET STRING`
    /// wrapping the raw key bytes, with nothing after it.
    ///
    /// # Errors
    /// - [`CryptoError::InvalidDerFormat`] if `der` is not a valid `PrivateKeyInfo`.
    /// - [`CryptoError::UnsupportedAlgorithm`] if the OID is not Ed25519.
    /// - [`CryptoError::InvalidSecretKey`] if the inner `OCTET STRING` is
    ///   malformed, uses a long-form length, is truncated, or is followed by
    ///   trailing bytes.
    /// - Any error from `Ed25519Signer::from_secret_key` is propagated.
    pub fn from_secret_der(der: &[u8]) -> Result<Self, CryptoError> {
        use pkcs8::PrivateKeyInfo;
        let private_key_info = PrivateKeyInfo::try_from(der)
            .map_err(|e| CryptoError::InvalidDerFormat(e.to_string()))?;
        let oid = private_key_info.algorithm.oid;
        if oid == Self::ED25519_OID {
            let actual_key =
                Self::unwrap_curve_private_key(private_key_info.private_key)?;
            Ed25519Signer::from_secret_key(actual_key).map(KeyPair::Ed25519)
        } else {
            Err(CryptoError::UnsupportedAlgorithm(format!(
                "Algorithm with OID {} is not supported",
                oid
            )))
        }
    }

    /// Strips the single DER `OCTET STRING` header that wraps the raw key
    /// inside a PKCS#8 private-key field and returns the contents.
    ///
    /// Only short-form lengths (0..=127) are accepted, and the contents must
    /// span the remainder of `encoded` exactly.
    fn unwrap_curve_private_key(encoded: &[u8]) -> Result<&[u8], CryptoError> {
        let (length_byte, contents) = match encoded {
            [tag, length_byte, contents @ ..]
                if *tag == Self::OCTET_STRING_TAG =>
            {
                (*length_byte, contents)
            }
            _ => {
                return Err(CryptoError::InvalidSecretKey(
                    "Invalid Ed25519 key encoding in DER".to_string(),
                ));
            }
        };
        // A set high bit marks a DER long-form length; treating that byte as a
        // byte count would misparse the key.
        if length_byte >= 0x80 {
            return Err(CryptoError::InvalidSecretKey(
                "Unsupported long-form DER length for Ed25519 key".to_string(),
            ));
        }
        let key_length = usize::from(length_byte);
        match contents.len().cmp(&key_length) {
            std::cmp::Ordering::Less => Err(CryptoError::InvalidSecretKey(
                "Truncated Ed25519 key in DER".to_string(),
            )),
            std::cmp::Ordering::Greater => Err(CryptoError::InvalidSecretKey(
                "Trailing data after Ed25519 key in DER".to_string(),
            )),
            std::cmp::Ordering::Equal => Ok(contents),
        }
    }

    /// Creates a key pair deterministically from a 32-byte seed.
    ///
    /// # Errors
    /// Propagates any error from the underlying signer.
    pub fn from_seed(
        key_type: KeyPairAlgorithm,
        seed: &[u8; 32],
    ) -> Result<Self, CryptoError> {
        match key_type {
            KeyPairAlgorithm::Ed25519 => {
                Ed25519Signer::from_seed(seed).map(KeyPair::Ed25519)
            }
        }
    }

    /// Derives a key pair from arbitrary input bytes; the same `data` yields
    /// the same key pair. How the bytes are turned into a key is up to the
    /// underlying signer.
    ///
    /// # Errors
    /// Propagates any error from the underlying signer.
    pub fn derive_from_data(
        key_type: KeyPairAlgorithm,
        data: &[u8],
    ) -> Result<Self, CryptoError> {
        match key_type {
            KeyPairAlgorithm::Ed25519 => {
                Ed25519Signer::derive_from_data(data).map(KeyPair::Ed25519)
            }
        }
    }

    /// Creates a key pair from raw secret-key bytes, inferring the algorithm
    /// from the length.
    ///
    /// 32- and 64-byte inputs are treated as Ed25519; how each length is
    /// interpreted is decided by `Ed25519Signer::from_secret_key`.
    ///
    /// # Errors
    /// [`CryptoError::InvalidSecretKey`] for any other length; otherwise any
    /// error from the signer is propagated.
    pub fn from_secret_key(secret_key: &[u8]) -> Result<Self, CryptoError> {
        match secret_key.len() {
            32 | 64 => {
                Ed25519Signer::from_secret_key(secret_key).map(KeyPair::Ed25519)
            }
            _ => Err(CryptoError::InvalidSecretKey(format!(
                "Unsupported key length: {} bytes",
                secret_key.len()
            ))),
        }
    }

    /// Creates a key pair of an explicitly chosen algorithm from raw
    /// secret-key bytes. Length validation is left to the signer.
    ///
    /// # Errors
    /// Propagates any error from the underlying signer.
    pub fn from_secret_key_with_type(
        key_type: KeyPairAlgorithm,
        secret_key: &[u8],
    ) -> Result<Self, CryptoError> {
        match key_type {
            KeyPairAlgorithm::Ed25519 => {
                Ed25519Signer::from_secret_key(secret_key).map(KeyPair::Ed25519)
            }
        }
    }

    /// Returns which algorithm this key pair uses.
    #[inline]
    pub const fn key_type(&self) -> KeyPairAlgorithm {
        match self {
            Self::Ed25519(_) => KeyPairAlgorithm::Ed25519,
        }
    }

    /// Signs `message` with the secret key.
    ///
    /// # Errors
    /// Propagates any error from the underlying signer.
    #[inline]
    pub fn sign(
        &self,
        message: &[u8],
    ) -> Result<SignatureIdentifier, CryptoError> {
        match self {
            Self::Ed25519(signer) => signer.sign(message),
        }
    }

    /// Returns the signature algorithm reported by the underlying signer.
    #[inline]
    pub fn algorithm(&self) -> DSAlgorithm {
        match self {
            Self::Ed25519(signer) => signer.algorithm(),
        }
    }

    /// Returns the one-byte algorithm identifier reported by the signer; it is
    /// the first byte written by [`KeyPair::to_bytes`].
    #[inline]
    pub fn algorithm_id(&self) -> u8 {
        match self {
            Self::Ed25519(signer) => signer.algorithm_id(),
        }
    }

    /// Returns the raw public-key bytes.
    #[inline]
    pub fn public_key_bytes(&self) -> Vec<u8> {
        match self {
            Self::Ed25519(signer) => signer.public_key_bytes(),
        }
    }

    /// Returns the public half as a [`PublicKey`].
    ///
    /// # Panics
    /// Panics if `PublicKey::new` rejects the algorithm/bytes produced by this
    /// key pair, which would indicate a bug in the signer.
    #[inline]
    pub fn public_key(&self) -> PublicKey {
        PublicKey::new(self.algorithm(), self.public_key_bytes())
            .expect("KeyPair should always have valid public key")
    }

    /// Returns a copy of the raw secret-key bytes.
    ///
    /// # Errors
    /// Propagates any error from the underlying signer.
    #[inline]
    pub fn secret_key_bytes(&self) -> Result<Vec<u8>, CryptoError> {
        match self {
            Self::Ed25519(signer) => signer.secret_key_bytes(),
        }
    }

    /// Serializes the key pair as one algorithm-identifier byte followed by
    /// the raw secret-key bytes. Inverse of [`KeyPair::from_bytes`].
    ///
    /// # Errors
    /// Propagates errors from [`KeyPair::secret_key_bytes`].
    pub fn to_bytes(&self) -> Result<Vec<u8>, CryptoError> {
        let secret = self.secret_key_bytes()?;
        let mut result = Vec::with_capacity(1 + secret.len());
        result.push(self.algorithm_id());
        result.extend_from_slice(&secret);
        Ok(result)
    }

    /// Parses the format produced by [`KeyPair::to_bytes`].
    ///
    /// # Errors
    /// - [`CryptoError::InvalidSecretKey`] if `bytes` is empty.
    /// - Errors from `DSAlgorithm::from_identifier` for an unknown identifier.
    /// - Errors from the signer when the secret key is rejected.
    pub fn from_bytes(bytes: &[u8]) -> Result<Self, CryptoError> {
        let (&id, secret_key) = bytes.split_first().ok_or_else(|| {
            CryptoError::InvalidSecretKey(
                "Data too short to contain algorithm identifier".to_string(),
            )
        })?;
        let algorithm = DSAlgorithm::from_identifier(id)?;
        let key_type = KeyPairAlgorithm::from(algorithm);
        Self::from_secret_key_with_type(key_type, secret_key)
    }

    /// Encodes the secret key as a DER PKCS#8 `PrivateKeyInfo` with no
    /// embedded public key. Inverse of [`KeyPair::from_secret_der`].
    ///
    /// # Errors
    /// - Propagates errors from [`KeyPair::secret_key_bytes`].
    /// - [`CryptoError::InvalidSecretKey`] if the secret key is 128 bytes or
    ///   longer (not representable with a short-form DER length), or if DER
    ///   encoding fails.
    pub fn to_secret_der(&self) -> Result<Vec<u8>, CryptoError> {
        use pkcs8::{PrivateKeyInfo, der::Encode};
        let oid = match self {
            Self::Ed25519(_) => Self::ED25519_OID,
        };
        let secret_key_bytes = self.secret_key_bytes()?;
        let wrapped_key = Self::wrap_curve_private_key(&secret_key_bytes)?;
        let private_key_info = PrivateKeyInfo {
            algorithm: pkcs8::AlgorithmIdentifierRef {
                oid,
                parameters: None,
            },
            private_key: &wrapped_key,
            public_key: None,
        };
        private_key_info.to_der().map_err(|e| {
            CryptoError::InvalidSecretKey(format!("DER encoding failed: {}", e))
        })
    }

    /// Wraps `key` in a DER `OCTET STRING` header for the PKCS#8 private-key
    /// field.
    ///
    /// Only short-form lengths are emitted, so keys of 128 bytes or more are
    /// rejected rather than truncated by a narrowing cast.
    fn wrap_curve_private_key(key: &[u8]) -> Result<Vec<u8>, CryptoError> {
        let length_byte = u8::try_from(key.len())
            .ok()
            .filter(|len| *len < 0x80)
            .ok_or_else(|| {
                CryptoError::InvalidSecretKey(format!(
                    "Secret key too long for short-form DER length: {} bytes",
                    key.len()
                ))
            })?;
        let mut wrapped = Vec::with_capacity(2 + key.len());
        wrapped.push(Self::OCTET_STRING_TAG);
        wrapped.push(length_byte);
        wrapped.extend_from_slice(key);
        Ok(wrapped)
    }
}
impl Default for KeyPair {
fn default() -> Self {
Self::Ed25519(Ed25519Signer::default())
}
}
impl fmt::Debug for KeyPair {
    /// Prints the key type, algorithm and base64-encoded public key.
    ///
    /// No secret material is printed; `finish_non_exhaustive` renders `..`
    /// in place of the omitted fields.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        use crate::common::base64_encoding;

        let kind = self.key_type();
        let algorithm = self.algorithm();
        let encoded_public_key =
            base64_encoding::encode(&self.public_key_bytes());

        let mut builder = f.debug_struct("KeyPair");
        builder.field("type", &kind);
        builder.field("algorithm", &algorithm);
        builder.field("public_key", &encoded_public_key);
        builder.finish_non_exhaustive()
    }
}
impl fmt::Display for KeyPair {
    /// Formats as `"<algorithm> KeyPair"`, e.g. `Ed25519 KeyPair`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Build user-facing text from the algorithm's `Display` impl rather
        // than its `Debug` output, which is a diagnostic format and not a
        // stable representation to expose.
        write!(f, "{} KeyPair", self.key_type())
    }
}
impl DSA for KeyPair {
#[inline]
fn algorithm_id(&self) -> u8 {
Self::algorithm_id(self)
}
#[inline]
fn signature_length(&self) -> usize {
match self {
Self::Ed25519(signer) => signer.signature_length(),
}
}
#[inline]
fn sign(&self, message: &[u8]) -> Result<SignatureIdentifier, CryptoError> {
Self::sign(self, message)
}
#[inline]
fn algorithm(&self) -> DSAlgorithm {
Self::algorithm(self)
}
#[inline]
fn public_key_bytes(&self) -> Vec<u8> {
Self::public_key_bytes(self)
}
}
#[cfg(test)]
mod tests {
    //! Unit tests for `KeyPairAlgorithm` and `KeyPair`: generation,
    //! signing, the byte/DER round-trips, and error paths.
    use super::*;

    #[test]
    fn test_keypair_generate() {
        let keypair = KeyPair::generate(KeyPairAlgorithm::Ed25519).unwrap();
        assert_eq!(keypair.algorithm(), DSAlgorithm::Ed25519);
        assert_eq!(keypair.key_type(), KeyPairAlgorithm::Ed25519);
        assert_eq!(keypair.public_key_bytes().len(), 32);
    }

    #[test]
    fn test_keypair_sign_verify() {
        let keypair = KeyPair::generate(KeyPairAlgorithm::Ed25519).unwrap();
        let message = b"Test message";
        let signature = keypair.sign(message).unwrap();
        let public_key = keypair.public_key();
        assert!(public_key.verify(message, &signature).is_ok());
        assert!(public_key.verify(b"Wrong message", &signature).is_err());
    }

    #[test]
    fn test_keypair_from_seed() {
        let seed = [42u8; 32];
        let keypair1 =
            KeyPair::from_seed(KeyPairAlgorithm::Ed25519, &seed).unwrap();
        let keypair2 =
            KeyPair::from_seed(KeyPairAlgorithm::Ed25519, &seed).unwrap();
        assert_eq!(keypair1.public_key_bytes(), keypair2.public_key_bytes());
    }

    #[test]
    fn test_keypair_derive_from_data() {
        let data = b"my passphrase";
        let keypair1 =
            KeyPair::derive_from_data(KeyPairAlgorithm::Ed25519, data).unwrap();
        let keypair2 =
            KeyPair::derive_from_data(KeyPairAlgorithm::Ed25519, data).unwrap();
        assert_eq!(keypair1.public_key_bytes(), keypair2.public_key_bytes());
        let keypair3 =
            KeyPair::derive_from_data(KeyPairAlgorithm::Ed25519, b"different")
                .unwrap();
        assert_ne!(keypair1.public_key_bytes(), keypair3.public_key_bytes());
    }

    #[test]
    fn test_keypair_serialization() {
        let keypair = KeyPair::generate(KeyPairAlgorithm::Ed25519).unwrap();
        let message = b"Test message";
        let bytes = keypair.to_bytes().unwrap();
        assert_eq!(bytes[0], b'E');
        let keypair2 = KeyPair::from_bytes(&bytes).unwrap();
        assert_eq!(keypair.public_key_bytes(), keypair2.public_key_bytes());
        let sig1 = keypair.sign(message).unwrap();
        let sig2 = keypair2.sign(message).unwrap();
        let public_key = keypair.public_key();
        assert!(public_key.verify(message, &sig1).is_ok());
        assert!(public_key.verify(message, &sig2).is_ok());
    }

    #[test]
    fn test_keypair_from_bytes_empty() {
        let result = KeyPair::from_bytes(&[]);
        assert!(matches!(result, Err(CryptoError::InvalidSecretKey(_))));
    }

    #[test]
    fn test_keypair_dsa_trait() {
        let keypair = KeyPair::generate(KeyPairAlgorithm::Ed25519).unwrap();
        let message = b"Test message";
        let signature = DSA::sign(&keypair, message).unwrap();
        assert_eq!(DSA::algorithm(&keypair), DSAlgorithm::Ed25519);
        assert_eq!(DSA::algorithm_id(&keypair), b'E');
        let public_key = keypair.public_key();
        assert!(public_key.verify(message, &signature).is_ok());
    }

    #[test]
    fn test_keypair_public_key_wrapper() {
        let keypair = KeyPair::generate(KeyPairAlgorithm::Ed25519).unwrap();
        let public_key = keypair.public_key();
        assert_eq!(public_key.algorithm(), keypair.algorithm());
        assert_eq!(public_key.as_bytes(), &keypair.public_key_bytes()[..]);
    }

    #[test]
    fn test_keypair_from_secret_key_autodetect() {
        let keypair1 = KeyPair::generate(KeyPairAlgorithm::Ed25519).unwrap();
        let secret_bytes = keypair1.secret_key_bytes().unwrap();
        let keypair2 = KeyPair::from_secret_key(&secret_bytes).unwrap();
        assert_eq!(keypair1.public_key_bytes(), keypair2.public_key_bytes());
    }

    #[test]
    fn test_keypair_from_secret_key_invalid_length() {
        let result = KeyPair::from_secret_key(&[0u8; 31]);
        assert!(matches!(result, Err(CryptoError::InvalidSecretKey(_))));
    }

    #[test]
    fn test_keypair_type_conversion() {
        let kp_type = KeyPairAlgorithm::Ed25519;
        let algo: DSAlgorithm = kp_type.into();
        assert_eq!(algo, DSAlgorithm::Ed25519);
        let kp_type2: KeyPairAlgorithm = algo.into();
        assert_eq!(kp_type, kp_type2);
    }

    #[test]
    fn test_keypair_algorithm_generate() {
        let algorithm = KeyPairAlgorithm::Ed25519;
        let keypair = algorithm.generate_keypair().unwrap();
        assert_eq!(keypair.key_type(), KeyPairAlgorithm::Ed25519);
        assert_eq!(keypair.algorithm(), DSAlgorithm::Ed25519);
        let message = b"test";
        let signature = keypair.sign(message).unwrap();
        let public_key = keypair.public_key();
        assert!(public_key.verify(message, &signature).is_ok());
    }

    #[test]
    fn test_keypair_algorithm_display() {
        let algorithm = KeyPairAlgorithm::Ed25519;
        assert_eq!(algorithm.to_string(), "Ed25519");
    }

    #[test]
    fn test_keypair_display() {
        let keypair = KeyPair::generate(KeyPairAlgorithm::Ed25519).unwrap();
        assert_eq!(keypair.to_string(), "Ed25519 KeyPair");
    }

    #[test]
    fn test_keypair_debug_shape() {
        let keypair = KeyPair::generate(KeyPairAlgorithm::Ed25519).unwrap();
        let rendered = format!("{:?}", keypair);
        assert!(rendered.starts_with("KeyPair {"));
        // `finish_non_exhaustive` marks the omitted (secret) fields with `..`.
        assert!(rendered.ends_with(", .. }"));
    }

    #[test]
    fn test_default_keypair() {
        let keypair = KeyPair::default();
        assert_eq!(keypair.key_type(), KeyPairAlgorithm::Ed25519);
    }

    #[test]
    fn test_keypair_clone() {
        let keypair = KeyPair::generate(KeyPairAlgorithm::Ed25519).unwrap();
        let keypair_clone = keypair.clone();
        assert_eq!(
            keypair.public_key_bytes(),
            keypair_clone.public_key_bytes()
        );
        let message = b"test message";
        let sig1 = keypair.sign(message).unwrap();
        let sig2 = keypair_clone.sign(message).unwrap();
        assert_eq!(sig1, sig2);
        let public_key = keypair.public_key();
        assert!(public_key.verify(message, &sig1).is_ok());
        assert!(public_key.verify(message, &sig2).is_ok());
    }

    #[test]
    fn test_keypair_der_roundtrip() {
        let keypair1 = KeyPair::generate(KeyPairAlgorithm::Ed25519).unwrap();
        let message = b"Test message for DER roundtrip";
        let der_bytes = keypair1.to_secret_der().unwrap();
        // Outer DER SEQUENCE tag.
        assert_eq!(der_bytes[0], 0x30);
        let keypair2 = KeyPair::from_secret_der(&der_bytes).unwrap();
        assert_eq!(keypair1.public_key_bytes(), keypair2.public_key_bytes());
        let sig1 = keypair1.sign(message).unwrap();
        let sig2 = keypair2.sign(message).unwrap();
        let public_key = keypair1.public_key();
        assert!(public_key.verify(message, &sig1).is_ok());
        assert!(public_key.verify(message, &sig2).is_ok());
    }

    #[test]
    fn test_keypair_from_der_invalid() {
        let invalid_der = [0x00u8, 0x01, 0x02];
        let result = KeyPair::from_secret_der(&invalid_der);
        assert!(matches!(result, Err(CryptoError::InvalidDerFormat(_))));
    }

    #[test]
    fn test_keypair_from_der_unsupported_algorithm() {
        use pkcs8::{ObjectIdentifier, PrivateKeyInfo, der::Encode};
        // secp256k1 OID: well-formed PKCS#8 but not an algorithm we support.
        let unsupported_oid = ObjectIdentifier::new_unwrap("1.3.132.0.10");
        // OCTET STRING header (tag 0x04, length 32) followed by 32 zero bytes.
        let mut fake_key = vec![0x04u8, 0x20];
        fake_key.extend_from_slice(&[0u8; 32]);
        let algorithm_identifier = pkcs8::AlgorithmIdentifierRef {
            oid: unsupported_oid,
            parameters: None,
        };
        let private_key_info = PrivateKeyInfo {
            algorithm: algorithm_identifier,
            private_key: &fake_key,
            public_key: None,
        };
        let der_bytes = private_key_info.to_der().unwrap();
        let result = KeyPair::from_secret_der(&der_bytes);
        assert!(matches!(
            result,
            Err(CryptoError::UnsupportedAlgorithm(_))
        ));
    }
}