#![allow(unsafe_code)]
use crate::error::{Error, Result};
use crate::random::OsRng;
use chrono::{DateTime, Utc};
use rand::RngCore;
use serde::{Deserialize, Serialize};
use std::fmt;
use std::mem::ManuallyDrop;
use subtle::{Choice, ConstantTimeEq};
use uuid::Uuid;
use zeroize::ZeroizeOnDrop;
/// Fixed-size secret key material of `N` bytes.
///
/// The bytes are wiped when the value is dropped (via the `ZeroizeOnDrop` derive).
/// `Clone` produces an independent copy, which is wiped on its own drop.
/// The `Debug` impl prints only the length, never the bytes.
#[derive(Clone, ZeroizeOnDrop)]
pub struct SecretKey<const N: usize> {
/// Raw key bytes; kept private so they are only reachable through explicit accessors.
bytes: [u8; N],
}
impl<const N: usize> SecretKey<N> {
    /// Wraps existing key bytes.
    ///
    /// The array is copied into the key. The caller is responsible for wiping its own
    /// copy if it still holds one.
    pub fn new(bytes: [u8; N]) -> Self {
        Self { bytes }
    }

    /// Generates a new key filled with bytes from the operating-system RNG.
    pub fn generate() -> Self {
        // Fill the key's own (zeroize-on-drop) buffer in place. Filling a temporary
        // array and then moving it in would leave a second copy of the secret on the
        // stack that nothing wipes.
        let mut key = Self { bytes: [0u8; N] };
        OsRng.fill_bytes(&mut key.bytes);
        key
    }

    /// Copies key bytes out of `slice`.
    ///
    /// # Errors
    ///
    /// Returns [`Error::InvalidKeyLength`] if `slice.len() != N`.
    pub fn from_slice(slice: &[u8]) -> Result<Self> {
        if slice.len() != N {
            return Err(Error::InvalidKeyLength {
                expected: N,
                actual: slice.len(),
            });
        }
        // Copy straight into the zeroize-on-drop buffer (same reasoning as `generate`).
        let mut key = Self { bytes: [0u8; N] };
        key.bytes.copy_from_slice(slice);
        Ok(key)
    }

    /// Borrows the key bytes as a fixed-size array.
    pub fn as_bytes(&self) -> &[u8; N] {
        &self.bytes
    }

    /// Borrows the key bytes as a slice.
    pub fn as_slice(&self) -> &[u8] {
        &self.bytes
    }

    /// Key length in bytes (`N`).
    pub const fn len() -> usize {
        N
    }

    /// Key length in bits (`N * 8`).
    pub const fn bit_len() -> usize {
        N * 8
    }

    /// Compares two keys in constant time and returns a plain `bool`.
    ///
    /// In method-call syntax this inherent method takes precedence over
    /// [`ConstantTimeEq::ct_eq`]. Call `ConstantTimeEq::ct_eq(&a, &b)` to get a
    /// [`Choice`] instead.
    pub fn ct_eq(&self, other: &Self) -> bool {
        self.bytes.ct_eq(&other.bytes).into()
    }
}
impl<const N: usize> ConstantTimeEq for SecretKey<N> {
    /// Constant-time byte comparison that yields a [`Choice`] rather than a `bool`.
    fn ct_eq(&self, other: &Self) -> Choice {
        let lhs: &[u8] = &self.bytes[..];
        let rhs: &[u8] = &other.bytes[..];
        lhs.ct_eq(rhs)
    }
}
impl<const N: usize> fmt::Debug for SecretKey<N> {
    /// Prints only the key length; the key bytes never appear in debug output.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_fmt(format_args!("SecretKey<{N}>[REDACTED]"))
    }
}
impl<const N: usize> AsRef<[u8]> for SecretKey<N> {
    /// Borrows the raw key bytes as a slice.
    fn as_ref(&self) -> &[u8] {
        self.bytes.as_slice()
    }
}
/// Fixed-size public key of `N` bytes.
///
/// Equality and hashing are derived, so they are ordinary (not constant-time)
/// comparisons. `Debug` and `Display` render the bytes as hex.
#[derive(Clone, PartialEq, Eq, Hash)]
pub struct PublicKey<const N: usize> {
/// Raw public key bytes.
bytes: [u8; N],
}
impl<const N: usize> PublicKey<N> {
pub fn new(bytes: [u8; N]) -> Self {
Self { bytes }
}
pub fn from_slice(slice: &[u8]) -> Result<Self> {
if slice.len() != N {
return Err(Error::InvalidKeyLength {
expected: N,
actual: slice.len(),
});
}
let mut bytes = [0u8; N];
bytes.copy_from_slice(slice);
Ok(Self { bytes })
}
pub fn as_bytes(&self) -> &[u8; N] {
&self.bytes
}
pub fn as_slice(&self) -> &[u8] {
&self.bytes
}
pub const fn len() -> usize {
N
}
pub fn to_hex(&self) -> String {
hex::encode(self.bytes)
}
pub fn from_hex(s: &str) -> Result<Self> {
let bytes = hex::decode(s).map_err(|e| Error::ParseError(e.to_string()))?;
Self::from_slice(&bytes)
}
}
impl<const N: usize> fmt::Debug for PublicKey<N> {
    /// Shows the key length together with the full hex encoding of the key.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let hex = self.to_hex();
        write!(f, "PublicKey<{N}>({hex})")
    }
}
impl<const N: usize> fmt::Display for PublicKey<N> {
    /// Writes the key as a bare hex string.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let encoded = self.to_hex();
        f.write_str(&encoded)
    }
}
impl<const N: usize> AsRef<[u8]> for PublicKey<N> {
    /// Borrows the raw key bytes as a slice.
    fn as_ref(&self) -> &[u8] {
        self.bytes.as_slice()
    }
}
#[cfg(feature = "serde")]
impl<const N: usize> Serialize for PublicKey<N> {
    /// Emits a hex string for human-readable formats and the raw key bytes for
    /// binary formats.
    fn serialize<S>(&self, serializer: S) -> std::result::Result<S::Ok, S::Error>
    where
        S: serde::Serializer,
    {
        if !serializer.is_human_readable() {
            return serializer.serialize_bytes(&self.bytes);
        }
        let encoded = self.to_hex();
        serializer.serialize_str(&encoded)
    }
}
#[cfg(feature = "serde")]
impl<'de, const N: usize> Deserialize<'de> for PublicKey<N> {
    /// Reads a hex string from human-readable formats and raw bytes from binary
    /// formats, mirroring the `Serialize` impl.
    ///
    /// The binary path requests the value with `deserialize_bytes`, which matches
    /// the `serialize_bytes` call on the write side. `Vec<u8>::deserialize` asks
    /// for a *sequence* instead, and fails in formats with a native byte-string
    /// type. A plain byte sequence is still accepted, for formats that encode
    /// bytes as arrays.
    fn deserialize<D>(deserializer: D) -> std::result::Result<Self, D::Error>
    where
        D: serde::Deserializer<'de>,
    {
        if deserializer.is_human_readable() {
            let s = String::deserialize(deserializer)?;
            Self::from_hex(&s).map_err(serde::de::Error::custom)
        } else {
            deserializer.deserialize_bytes(PublicKeyBytesVisitor::<N>)
        }
    }
}

/// Serde visitor that accepts exactly `N` bytes of public key material, given
/// either as a byte string or as a sequence of `u8`.
#[cfg(feature = "serde")]
struct PublicKeyBytesVisitor<const N: usize>;

#[cfg(feature = "serde")]
impl<'de, const N: usize> serde::de::Visitor<'de> for PublicKeyBytesVisitor<N> {
    type Value = PublicKey<N>;

    fn expecting(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(f, "{} bytes of public key material", N)
    }

    // `visit_borrowed_bytes` and `visit_byte_buf` forward here by default.
    fn visit_bytes<E>(self, v: &[u8]) -> std::result::Result<Self::Value, E>
    where
        E: serde::de::Error,
    {
        PublicKey::from_slice(v).map_err(E::custom)
    }

    fn visit_seq<A>(self, mut seq: A) -> std::result::Result<Self::Value, A::Error>
    where
        A: serde::de::SeqAccess<'de>,
    {
        // Read at most N elements, so an oversized input cannot force a large allocation.
        let mut bytes = [0u8; N];
        for (i, slot) in bytes.iter_mut().enumerate() {
            *slot = seq
                .next_element()?
                .ok_or_else(|| <A::Error as serde::de::Error>::invalid_length(i, &self))?;
        }
        if seq.next_element::<u8>()?.is_some() {
            return Err(<A::Error as serde::de::Error>::invalid_length(N + 1, &self));
        }
        Ok(PublicKey::new(bytes))
    }
}
/// A secret key paired with its public key.
///
/// Both halves are wrapped in `ManuallyDrop` so that `into_secret_key` and
/// `into_parts` can move them out even though `KeyPair` implements `Drop`. In
/// every other case the `Drop` impl drops them. `Debug` redacts the secret half.
#[derive(Clone)]
pub struct KeyPair<const SK: usize, const PK: usize> {
secret_key: ManuallyDrop<SecretKey<SK>>,
public_key: ManuallyDrop<PublicKey<PK>>,
}
impl<const SK: usize, const PK: usize> KeyPair<SK, PK> {
    /// Bundles a secret key with its public key.
    ///
    /// Nothing here checks that the two halves actually correspond.
    pub fn new(secret_key: SecretKey<SK>, public_key: PublicKey<PK>) -> Self {
        Self {
            secret_key: ManuallyDrop::new(secret_key),
            public_key: ManuallyDrop::new(public_key),
        }
    }

    /// Borrows the secret half.
    pub fn secret_key(&self) -> &SecretKey<SK> {
        &self.secret_key
    }

    /// Borrows the public half.
    pub fn public_key(&self) -> &PublicKey<PK> {
        &self.public_key
    }

    /// Consumes the pair and returns only the secret key; the public key is dropped.
    pub fn into_secret_key(self) -> SecretKey<SK> {
        // Delegating to `into_parts` keeps all the `unsafe` in one place. It also drops
        // the public half normally, instead of `mem::forget`ting it undropped.
        let (secret, _public) = self.into_parts();
        secret
    }

    /// Consumes the pair and returns both halves.
    pub fn into_parts(mut self) -> (SecretKey<SK>, PublicKey<PK>) {
        // SAFETY: `secret_key` is taken exactly once here, and `self` is forgotten below
        // with no panicking code in between. The `Drop` impl therefore never sees the
        // emptied slot.
        let secret = unsafe { ManuallyDrop::take(&mut self.secret_key) };
        // SAFETY: same argument for `public_key`.
        let public = unsafe { ManuallyDrop::take(&mut self.public_key) };
        std::mem::forget(self);
        (secret, public)
    }
}
/// Drops both halves of the pair.
///
/// The fields are `ManuallyDrop`, so without this impl neither key would ever be
/// dropped, and the `SecretKey` would never be wiped.
impl<const SK: usize, const PK: usize> Drop for KeyPair<SK, PK> {
fn drop(&mut self) {
// SAFETY: `drop` runs at most once per value. The methods that move the fields
// out (`into_parts`, `into_secret_key`) `mem::forget` the pair afterwards, so this
// never touches a field that was already taken or dropped.
unsafe {
ManuallyDrop::drop(&mut self.secret_key);
ManuallyDrop::drop(&mut self.public_key);
}
}
}
impl<const SK: usize, const PK: usize> fmt::Debug for KeyPair<SK, PK> {
    /// Shows the public key and a `[REDACTED]` placeholder for the secret key.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_struct("KeyPair")
            .field("secret_key", &"[REDACTED]")
            // Deref through the `ManuallyDrop` wrapper. Its derived `Debug` would
            // otherwise print `ManuallyDrop { value: ... }` around the key.
            .field("public_key", &*self.public_key)
            .finish()
    }
}
/// Identifier of a key, wrapping a [`Uuid`].
///
/// `Display` writes the wrapped UUID, and `KeyId::parse` reads that text back.
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub struct KeyId(Uuid);
impl KeyId {
pub fn generate() -> Self {
Self(Uuid::new_v4())
}
pub fn from_uuid(uuid: Uuid) -> Self {
Self(uuid)
}
pub fn as_uuid(&self) -> &Uuid {
&self.0
}
pub fn parse(s: &str) -> Result<Self> {
let uuid = Uuid::parse_str(s).map_err(|e| Error::ParseError(e.to_string()))?;
Ok(Self(uuid))
}
}
impl fmt::Display for KeyId {
    /// Writes the wrapped UUID using its own `Display` form.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let uuid = &self.0;
        write!(f, "{uuid}")
    }
}
/// An operation a key is permitted to perform.
///
/// A key's permitted operations are listed in `KeyMetadata::usages` and checked by
/// `KeyMetadata::can_use_for`. `Serialize`/`Deserialize` are derived, and formats
/// that encode enum variants by index depend on variant order. Append new variants
/// rather than inserting them.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum KeyUsage {
/// Encrypting and decrypting data.
Encrypt,
/// Creating and verifying signatures.
Sign,
/// Key agreement / key exchange.
KeyExchange,
/// Wrapping (encrypting) other keys.
WrapKey,
/// Deriving further keys from this one.
DeriveKey,
/// Authentication. NOTE(review): presumably MAC-style use; confirm intended meaning.
Authenticate,
}
/// Cryptographic algorithm a key is intended for.
///
/// `Display` renders the conventional algorithm name (e.g. `ML-KEM-768`).
/// `Serialize`/`Deserialize` are derived, and formats that encode enum variants by
/// index depend on variant order. Append new variants rather than inserting them.
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
// Variants are named after standard algorithm identifiers and are documented in
// groups below rather than individually.
#[allow(missing_docs)] pub enum KeyAlgorithm {
// Symmetric ciphers and AEADs.
Aes128,
Aes192,
Aes256,
ChaCha20,
ChaCha20Poly1305,
XChaCha20Poly1305,
// RSA; the suffix is the modulus size in bits.
Rsa2048,
Rsa3072,
Rsa4096,
// Classical signature schemes.
Ed25519,
Ed448,
EcdsaP256,
EcdsaP384,
EcdsaP521,
EcdsaSecp256k1,
SchnorrSecp256k1,
// Classical key agreement.
X25519,
X448,
EcdhP256,
EcdhP384,
// Post-quantum key encapsulation (ML-KEM, FIPS 203).
MlKem512,
MlKem768,
MlKem1024,
// Post-quantum signatures (ML-DSA, FIPS 204).
MlDsa44,
MlDsa65,
MlDsa87,
// Stateless hash-based signatures (SLH-DSA, FIPS 205) over SHA-2;
// `f` = fast-signing and `s` = small-signature parameter sets.
SlhDsaSha2_128f,
SlhDsaSha2_128s,
SlhDsaSha2_192f,
SlhDsaSha2_192s,
SlhDsaSha2_256f,
SlhDsaSha2_256s,
// Hybrid classical + post-quantum combinations.
X25519MlKem768,
Ed25519MlDsa65,
/// An algorithm not listed above, identified by name; `Display` prints the name verbatim.
Custom(String),
}
impl fmt::Display for KeyAlgorithm {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
KeyAlgorithm::Aes128 => write!(f, "AES-128"),
KeyAlgorithm::Aes192 => write!(f, "AES-192"),
KeyAlgorithm::Aes256 => write!(f, "AES-256"),
KeyAlgorithm::ChaCha20 => write!(f, "ChaCha20"),
KeyAlgorithm::ChaCha20Poly1305 => write!(f, "ChaCha20-Poly1305"),
KeyAlgorithm::XChaCha20Poly1305 => write!(f, "XChaCha20-Poly1305"),
KeyAlgorithm::Rsa2048 => write!(f, "RSA-2048"),
KeyAlgorithm::Rsa3072 => write!(f, "RSA-3072"),
KeyAlgorithm::Rsa4096 => write!(f, "RSA-4096"),
KeyAlgorithm::Ed25519 => write!(f, "Ed25519"),
KeyAlgorithm::Ed448 => write!(f, "Ed448"),
KeyAlgorithm::EcdsaP256 => write!(f, "ECDSA-P256"),
KeyAlgorithm::EcdsaP384 => write!(f, "ECDSA-P384"),
KeyAlgorithm::EcdsaP521 => write!(f, "ECDSA-P521"),
KeyAlgorithm::EcdsaSecp256k1 => write!(f, "ECDSA-secp256k1"),
KeyAlgorithm::SchnorrSecp256k1 => write!(f, "Schnorr-secp256k1"),
KeyAlgorithm::X25519 => write!(f, "X25519"),
KeyAlgorithm::X448 => write!(f, "X448"),
KeyAlgorithm::EcdhP256 => write!(f, "ECDH-P256"),
KeyAlgorithm::EcdhP384 => write!(f, "ECDH-P384"),
KeyAlgorithm::MlKem512 => write!(f, "ML-KEM-512"),
KeyAlgorithm::MlKem768 => write!(f, "ML-KEM-768"),
KeyAlgorithm::MlKem1024 => write!(f, "ML-KEM-1024"),
KeyAlgorithm::MlDsa44 => write!(f, "ML-DSA-44"),
KeyAlgorithm::MlDsa65 => write!(f, "ML-DSA-65"),
KeyAlgorithm::MlDsa87 => write!(f, "ML-DSA-87"),
KeyAlgorithm::SlhDsaSha2_128f => write!(f, "SLH-DSA-SHA2-128f"),
KeyAlgorithm::SlhDsaSha2_128s => write!(f, "SLH-DSA-SHA2-128s"),
KeyAlgorithm::SlhDsaSha2_192f => write!(f, "SLH-DSA-SHA2-192f"),
KeyAlgorithm::SlhDsaSha2_192s => write!(f, "SLH-DSA-SHA2-192s"),
KeyAlgorithm::SlhDsaSha2_256f => write!(f, "SLH-DSA-SHA2-256f"),
KeyAlgorithm::SlhDsaSha2_256s => write!(f, "SLH-DSA-SHA2-256s"),
KeyAlgorithm::X25519MlKem768 => write!(f, "X25519-ML-KEM-768"),
KeyAlgorithm::Ed25519MlDsa65 => write!(f, "Ed25519-ML-DSA-65"),
KeyAlgorithm::Custom(name) => write!(f, "{}", name),
}
}
}
/// Descriptive and policy metadata stored alongside a key.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct KeyMetadata {
/// Identifier of the key (a fresh random id from `KeyMetadata::new`).
pub id: KeyId,
/// Algorithm the key is intended for.
pub algorithm: KeyAlgorithm,
/// Operations the key may be used for; consulted by `can_use_for`.
pub usages: Vec<KeyUsage>,
/// Creation time; `KeyMetadata::new` sets it to the current time.
pub created_at: DateTime<Utc>,
/// Optional end of the validity window; `is_valid` returns `false` strictly after it.
pub expires_at: Option<DateTime<Utc>>,
/// Optional start of the validity window; `is_valid` returns `false` strictly before it.
pub not_before: Option<DateTime<Utc>>,
/// Optional label for the key.
pub label: Option<String>,
/// Whether the key may be exported; `false` by default.
/// NOTE(review): nothing in this file enforces it; confirm where it is checked.
pub extractable: bool,
/// Metadata version; `KeyMetadata::new` starts it at 1.
/// NOTE(review): meaning of later increments is not visible here; confirm.
pub version: u32,
/// Free-form string key/value attributes; not interpreted anywhere in this file.
pub attributes: std::collections::HashMap<String, String>,
}
impl KeyMetadata {
    /// Creates metadata for a new key with a fresh random [`KeyId`].
    ///
    /// Defaults: `created_at` is the current time; `expires_at`, `not_before` and
    /// `label` are unset; the key is not extractable; `version` is 1; there are no
    /// attributes.
    pub fn new(algorithm: KeyAlgorithm, usages: Vec<KeyUsage>) -> Self {
        Self {
            id: KeyId::generate(),
            algorithm,
            usages,
            created_at: Utc::now(),
            expires_at: None,
            not_before: None,
            label: None,
            extractable: false,
            version: 1,
            attributes: std::collections::HashMap::new(),
        }
    }

    /// Returns `true` if the key is inside its validity window right now.
    ///
    /// Equivalent to `self.is_valid_at(Utc::now())`.
    pub fn is_valid(&self) -> bool {
        self.is_valid_at(Utc::now())
    }

    /// Returns `true` if the instant `at` falls inside the key's validity window.
    ///
    /// Both bounds are inclusive: the key is valid at exactly `not_before` and at
    /// exactly `expires_at`. A missing bound imposes no restriction. Passing the
    /// instant explicitly makes the check deterministic (e.g. in tests) and lets
    /// callers ask about times other than "now".
    pub fn is_valid_at(&self, at: DateTime<Utc>) -> bool {
        let started = !self.not_before.is_some_and(|not_before| at < not_before);
        let unexpired = !self.expires_at.is_some_and(|expires_at| at > expires_at);
        started && unexpired
    }

    /// Returns `true` if `usage` is among the key's permitted usages and the key is
    /// currently valid.
    pub fn can_use_for(&self, usage: KeyUsage) -> bool {
        self.usages.contains(&usage) && self.is_valid()
    }

    /// Sets the last instant at which the key is still valid.
    pub fn with_expiration(mut self, expires_at: DateTime<Utc>) -> Self {
        self.expires_at = Some(expires_at);
        self
    }

    /// Sets the first instant at which the key becomes valid.
    pub fn with_not_before(mut self, not_before: DateTime<Utc>) -> Self {
        self.not_before = Some(not_before);
        self
    }

    /// Sets the key's label.
    pub fn with_label(mut self, label: impl Into<String>) -> Self {
        self.label = Some(label.into());
        self
    }

    /// Sets whether the key may be exported.
    pub fn with_extractable(mut self, extractable: bool) -> Self {
        self.extractable = extractable;
        self
    }

    /// Adds a free-form attribute, replacing any existing value for the same key.
    pub fn with_attribute(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
        self.attributes.insert(key.into(), value.into());
        self
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_secret_key_generation() {
        let key1 = SecretKey::<32>::generate();
        let key2 = SecretKey::<32>::generate();
        assert!(!key1.ct_eq(&key2));
    }

    #[test]
    fn test_secret_key_from_slice() {
        let bytes = [0u8; 32];
        let key = SecretKey::<32>::from_slice(&bytes).unwrap();
        assert_eq!(key.as_bytes(), &bytes);
        let short = [0u8; 16];
        assert!(SecretKey::<32>::from_slice(&short).is_err());
    }

    #[test]
    fn test_secret_key_debug_is_redacted() {
        let key = SecretKey::<4>::new([1, 2, 3, 4]);
        assert_eq!(format!("{:?}", key), "SecretKey<4>[REDACTED]");
    }

    #[test]
    fn test_secret_key_trait_ct_eq() {
        let a = SecretKey::<8>::new([5; 8]);
        let b = a.clone();
        let c = SecretKey::<8>::new([6; 8]);
        assert!(bool::from(ConstantTimeEq::ct_eq(&a, &b)));
        assert!(!bool::from(ConstantTimeEq::ct_eq(&a, &c)));
    }

    #[test]
    fn test_public_key_hex() {
        let bytes = [0xab; 32];
        let key = PublicKey::<32>::new(bytes);
        let hex = key.to_hex();
        let decoded = PublicKey::<32>::from_hex(&hex).unwrap();
        assert_eq!(key, decoded);
    }

    #[test]
    fn test_public_key_from_hex_rejects_bad_input() {
        assert!(PublicKey::<32>::from_hex("not hex").is_err());
        // Valid hex, wrong length.
        assert!(PublicKey::<32>::from_hex("abab").is_err());
    }

    #[test]
    fn test_key_pair_parts_and_debug() {
        let secret = SecretKey::<32>::new([7; 32]);
        let public = PublicKey::<32>::new([9; 32]);
        let pair = KeyPair::new(secret.clone(), public.clone());
        assert!(pair.secret_key().ct_eq(&secret));
        assert_eq!(pair.public_key(), &public);

        let debug = format!("{:?}", pair);
        assert!(debug.contains("[REDACTED]"));
        assert!(!debug.contains(&hex::encode([7u8; 32])));

        let (s, p) = pair.clone().into_parts();
        assert!(s.ct_eq(&secret));
        assert_eq!(p, public);
        assert!(pair.into_secret_key().ct_eq(&secret));
    }

    #[test]
    fn test_key_metadata_validity() {
        let meta = KeyMetadata::new(KeyAlgorithm::Aes256, vec![KeyUsage::Encrypt]);
        assert!(meta.is_valid());
        assert!(meta.can_use_for(KeyUsage::Encrypt));
        assert!(!meta.can_use_for(KeyUsage::Sign));
    }

    #[test]
    fn test_key_metadata_expired() {
        let meta = KeyMetadata::new(KeyAlgorithm::Ed25519, vec![KeyUsage::Sign])
            .with_expiration(Utc::now() - chrono::Duration::hours(1));
        assert!(!meta.is_valid());
        assert!(!meta.can_use_for(KeyUsage::Sign));
    }

    #[test]
    fn test_key_metadata_not_yet_valid() {
        let mut meta = KeyMetadata::new(KeyAlgorithm::Ed25519, vec![KeyUsage::Sign]);
        meta.not_before = Some(Utc::now() + chrono::Duration::hours(1));
        assert!(!meta.is_valid());
        assert!(!meta.can_use_for(KeyUsage::Sign));
    }

    #[test]
    fn test_key_metadata_builders_and_defaults() {
        let meta = KeyMetadata::new(KeyAlgorithm::Aes128, vec![])
            .with_label("k")
            .with_extractable(true);
        assert_eq!(meta.version, 1);
        assert!(meta.extractable);
        assert_eq!(meta.label.as_deref(), Some("k"));
        assert!(meta.attributes.is_empty());
        assert!(meta.expires_at.is_none());
        assert!(meta.not_before.is_none());
    }

    #[test]
    fn test_key_algorithm_display() {
        assert_eq!(KeyAlgorithm::MlKem768.to_string(), "ML-KEM-768");
        assert_eq!(KeyAlgorithm::SlhDsaSha2_128f.to_string(), "SLH-DSA-SHA2-128f");
        assert_eq!(KeyAlgorithm::Custom("Foo-1".into()).to_string(), "Foo-1");
    }

    #[test]
    fn test_key_id() {
        let id = KeyId::generate();
        let s = id.to_string();
        let parsed = KeyId::parse(&s).unwrap();
        assert_eq!(id, parsed);
    }

    #[test]
    fn test_key_id_parse_rejects_garbage() {
        assert!(KeyId::parse("not-a-uuid").is_err());
    }
}