use rand::TryRngCore;
/// Smallest accepted key material, in bytes; shorter inputs are rejected with
/// [`KeyError::TooShort`] by the `TryFrom` conversions.
const MINIMUM_KEY_LENGTH: usize = 32;

/// Master key material.
///
/// Construct it with `Key::try_from`, `Key::from` or `Key::generate`. Equality is
/// implemented by hand (constant-time over the bytes) and `Debug` redacts the contents.
#[derive(Hash, Clone)]
// `Hash` is derived over the same byte vector that the manual `PartialEq` compares,
// so keys that compare equal always hash equally; the lint is a false positive here.
#[allow(clippy::derived_hash_with_manual_eq)]
pub struct Key(Vec<u8>);
#[cfg(feature = "serde")]
mod deser {
    use crate::Key;
    use serde::de::{self, Deserializer, SeqAccess, Visitor};
    use std::fmt;

    /// Upper bound on how many bytes are preallocated from a sequence's self-reported
    /// length, so a hostile length prefix cannot force a large allocation up front.
    /// Longer sequences still deserialize; the buffer simply grows as elements arrive.
    const MAX_PREALLOCATED_BYTES: usize = 1024;

    /// Collects raw key material from any encoding a `Key` may arrive in:
    /// a native byte buffer (which is what `Serialize` below emits via `serialize_bytes`),
    /// a sequence of `u8` values, or a string (whose UTF-8 bytes are used as-is).
    struct KeyMaterialVisitor;

    impl<'de> Visitor<'de> for KeyMaterialVisitor {
        type Value = Vec<u8>;

        fn expecting(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
            f.write_str("a byte array or a string")
        }

        // Native byte strings: needed so output of `serialize_bytes` in binary,
        // self-describing formats can be read back even when it is not valid UTF-8.
        fn visit_bytes<E: de::Error>(self, v: &[u8]) -> Result<Self::Value, E> {
            Ok(v.to_vec())
        }

        fn visit_byte_buf<E: de::Error>(self, v: Vec<u8>) -> Result<Self::Value, E> {
            Ok(v)
        }

        fn visit_str<E: de::Error>(self, v: &str) -> Result<Self::Value, E> {
            Ok(v.as_bytes().to_vec())
        }

        fn visit_string<E: de::Error>(self, v: String) -> Result<Self::Value, E> {
            Ok(v.into_bytes())
        }

        // Arrays of integers (e.g. JSON `[48, 49, ...]`); each element must fit in a `u8`.
        fn visit_seq<A: SeqAccess<'de>>(self, mut seq: A) -> Result<Self::Value, A::Error> {
            let capacity = seq.size_hint().unwrap_or(0).min(MAX_PREALLOCATED_BYTES);
            let mut bytes = Vec::with_capacity(capacity);
            while let Some(byte) = seq.next_element::<u8>()? {
                bytes.push(byte);
            }
            Ok(bytes)
        }
    }

    impl<'de> serde::Deserialize<'de> for Key {
        /// Deserializes a `Key` from a byte buffer, a byte array or a string.
        ///
        /// Uses `deserialize_any`, so the input format must be self-describing.
        ///
        /// # Errors
        ///
        /// * Any input of the wrong shape yields the fixed message
        ///   "invalid key format: expected a byte array or a string".
        /// * Material shorter than the minimum yields the `KeyError` message.
        fn deserialize<D>(deserializer: D) -> Result<Key, D::Error>
        where
            D: Deserializer<'de>,
        {
            let material = deserializer
                .deserialize_any(KeyMaterialVisitor)
                .map_err(|_| {
                    <D::Error as de::Error>::custom(
                        "invalid key format: expected a byte array or a string",
                    )
                })?;
            // Validated outside the visitor so the length error is reported verbatim.
            Key::try_from(material).map_err(de::Error::custom)
        }
    }

    impl serde::Serialize for Key {
        /// Serializes the key material as a native byte buffer.
        fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
        where
            S: serde::Serializer,
        {
            serializer.serialize_bytes(&self.0)
        }
    }
}
impl PartialEq for Key {
    /// Compares key material using `subtle`'s constant-time equality, so the
    /// comparison time does not depend on where the first differing byte is.
    /// Note that `subtle` does not treat slice lengths as secret.
    fn eq(&self, other: &Self) -> bool {
        use subtle::ConstantTimeEq;
        let same_bytes = self.0.as_slice().ct_eq(other.0.as_slice());
        bool::from(same_bytes)
    }
}

// Equality over byte contents is reflexive, symmetric and transitive.
impl Eq for Key {}
impl std::fmt::Debug for Key {
    /// Renders the key with its material redacted, as `Key("***")`, so keys can
    /// appear in logs and panic messages without leaking secret bytes.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        const REDACTED: &str = "***";
        let mut tuple = f.debug_tuple("Key");
        tuple.field(&REDACTED);
        tuple.finish()
    }
}
impl Key {
    /// Builds a `Key` from owned key material.
    ///
    /// # Panics
    ///
    /// Panics (message starting with "Invalid key material") when `key` is shorter
    /// than `MINIMUM_KEY_LENGTH` bytes. Use `Key::try_from` for a fallible conversion.
    #[inline]
    pub fn from(key: Vec<u8>) -> Key {
        let checked = crate::Key::try_from(key);
        checked.expect("Invalid key material")
    }

    /// Generates a fresh random key of `2 * MINIMUM_KEY_LENGTH` bytes.
    ///
    /// # Panics
    ///
    /// Panics if [`Key::try_generate`] reports that the RNG failed.
    pub fn generate() -> Key {
        let maybe_key = Key::try_generate();
        maybe_key.expect("failed to generate `Key` from randomness")
    }

    /// Generates a random key of `2 * MINIMUM_KEY_LENGTH` bytes filled from
    /// `rand::rng()`, returning `None` if filling the buffer reports an error.
    pub fn try_generate() -> Option<Key> {
        let mut material = vec![0u8; 2 * MINIMUM_KEY_LENGTH];
        let mut source = rand::rng();
        if source.try_fill_bytes(&mut material).is_err() {
            return None;
        }
        Some(Key::from(material))
    }

    /// Returns the raw master key material.
    pub fn master(&self) -> &[u8] {
        self.0.as_slice()
    }
}
/// Reasons raw material can be rejected when building a `Key`.
#[derive(Debug)]
#[non_exhaustive]
pub enum KeyError {
    /// The material was shorter than the required minimum length.
    TooShort(ShortKeyError),
}

impl std::fmt::Display for KeyError {
    /// Displays the message of the wrapped error. It is deliberately not also
    /// exposed via `source()`, to avoid printing the same text twice in error chains.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::TooShort(e) => write!(f, "{e}"),
        }
    }
}

impl std::error::Error for KeyError {}
/// Details about key material rejected for being shorter than `MINIMUM_KEY_LENGTH`.
#[derive(Debug)]
pub struct ShortKeyError {
    /// Length, in bytes, of the rejected material.
    length: usize,
}

impl std::error::Error for ShortKeyError {}

impl std::fmt::Display for ShortKeyError {
    /// States both the required minimum and the length actually supplied.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self { length } = self;
        write!(
            f,
            "key material is too short: expected >= {} bytes, got {} bytes",
            MINIMUM_KEY_LENGTH, length
        )
    }
}
impl TryFrom<&[u8]> for Key {
type Error = KeyError;
fn try_from(key: &[u8]) -> Result<Self, Self::Error> {
crate::Key::try_from(key.to_vec())
}
}
impl TryFrom<Vec<u8>> for Key {
    type Error = KeyError;

    /// Takes ownership of `key`, accepting it only if it holds at least
    /// `MINIMUM_KEY_LENGTH` bytes; otherwise reports the offending length.
    fn try_from(key: Vec<u8>) -> Result<Self, Self::Error> {
        let length = key.len();
        if length >= MINIMUM_KEY_LENGTH {
            return Ok(Key(key));
        }
        Err(KeyError::TooShort(ShortKeyError { length }))
    }
}
#[cfg(test)]
mod test {
    // Unit tests for `Key`: conversions, generation, equality, redaction and serde.
    use super::Key;
    use crate::crypto::encryption::EncryptionKey;
    use crate::crypto::signing::SigningKey;

    #[test]
    fn try_from_works() {
        use core::convert::TryInto;
        let data = (0..32).collect::<Vec<_>>();
        let key_res: Result<Key, _> = data[0..31].try_into();
        assert!(key_res.is_err());
        let key_res: Result<Key, _> = data.as_slice().try_into();
        assert!(key_res.is_ok());
    }

    #[test]
    fn short_key_error_reports_length() {
        let err = Key::try_from(vec![0u8; 10]).unwrap_err();
        assert_eq!(
            err.to_string(),
            "key material is too short: expected >= 32 bytes, got 10 bytes"
        );
    }

    #[test]
    #[should_panic(expected = "Invalid key material")]
    fn from_panics_on_short_key() {
        let _ = Key::from(vec![0u8; 31]);
    }

    #[test]
    fn generated_key_is_twice_minimum_length() {
        assert_eq!(
            Key::generate().master().len(),
            2 * super::MINIMUM_KEY_LENGTH
        );
    }

    #[test]
    fn equality_compares_key_material() {
        let a = Key::from(vec![1u8; 32]);
        assert_eq!(a, Key::from(vec![1u8; 32]));
        assert_ne!(a, Key::from(vec![2u8; 32]));
        // Same prefix, different length: must not compare equal.
        assert_ne!(a, Key::from(vec![1u8; 33]));
    }

    #[test]
    fn non_deterministic_generate() {
        let key_a = Key::generate();
        let key_b = Key::generate();
        assert_ne!(SigningKey::derive(&key_a), SigningKey::derive(&key_b));
        assert_ne!(EncryptionKey::derive(&key_a), EncryptionKey::derive(&key_b));
    }

    #[test]
    fn debug_does_not_leak_key() {
        let key = Key::generate();
        assert_eq!(format!("{key:?}"), "Key(\"***\")");
    }

    #[cfg(feature = "serde")]
    #[test]
    fn deserialize_key_from_string() {
        use serde_json::from_str;
        let key_str = "\"01234567890123456789012345678901\"";
        let key: Result<Key, _> = from_str(key_str);
        assert!(key.is_ok());
        assert_eq!(key.unwrap().master(), b"01234567890123456789012345678901");
    }

    #[cfg(feature = "serde")]
    #[test]
    fn deserialize_key_from_bytes() {
        use serde_json::from_slice;
        let key_bytes = b"[48,49,50,51,52,53,54,55,56,57,48,49,50,51,52,53,54,55,56,57,48,49,50,51,52,53,54,55,56,57,48,49]";
        let key: Result<Key, _> = from_slice(key_bytes);
        assert!(key.is_ok());
        assert_eq!(key.unwrap().master(), b"01234567890123456789012345678901");
    }

    #[cfg(feature = "serde")]
    #[test]
    fn deserialize_key_invalid_format() {
        use serde_json::from_str;
        let invalid_key = "{\"invalid\": \"format\"}";
        let key: Result<Key, _> = from_str(invalid_key);
        assert_eq!(
            key.unwrap_err().to_string(),
            "invalid key format: expected a byte array or a string"
        );
    }

    #[cfg(feature = "serde")]
    #[test]
    fn deserialize_key_too_short() {
        let key: Result<Key, _> = serde_json::from_str("\"short\"");
        assert_eq!(
            key.unwrap_err().to_string(),
            "key material is too short: expected >= 32 bytes, got 5 bytes"
        );
    }

    #[cfg(feature = "serde")]
    #[test]
    fn serde_json_round_trip() {
        let key = Key::generate();
        let json = serde_json::to_string(&key).unwrap();
        let back: Key = serde_json::from_str(&json).unwrap();
        assert_eq!(back, key);
    }
}