use derive_more::{Display, Error};
use rand::{rngs::OsRng, RngCore};
use std::{fmt, str::FromStr};
/// Error returned when a secret key cannot be produced: an unsupported length in `generate`,
/// a length mismatch in `from_slice`, or input that fails hex decoding in `FromStr`.
#[derive(Debug, Display, Error)]
pub struct SecretKeyError;
/// Maps a [`SecretKeyError`] onto an I/O error of kind `InvalidData`, so key failures can be
/// propagated with `?` from functions returning `std::io::Result`.
impl From<SecretKeyError> for std::io::Error {
    fn from(_: SecretKeyError) -> Self {
        let kind = std::io::ErrorKind::InvalidData;
        Self::new(kind, "not valid secret key format")
    }
}
/// A 16-byte (128-bit) secret key.
pub type SecretKey16 = SecretKey<16>;
/// A 24-byte (192-bit) secret key.
pub type SecretKey24 = SecretKey<24>;
/// A 32-byte (256-bit) secret key.
pub type SecretKey32 = SecretKey<32>;
/// A secret key of exactly `N` bytes, held in a fixed-size array.
///
/// NOTE(review): the derived `PartialEq` is ordinary array equality and is not guaranteed to
/// run in constant time; confirm that is acceptable wherever keys are compared against
/// untrusted input.
#[derive(Clone, PartialEq, Eq)]
pub struct SecretKey<const N: usize>([u8; N]);
/// Redacted debug output: always renders as `SecretKey("**OMITTED**")`, so key bytes never
/// appear in logs or `{:?}` formatting.
impl<const N: usize> fmt::Debug for SecretKey<N> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // A `&str` placeholder formats exactly like the equivalent `String`, without allocating.
        f.debug_tuple("SecretKey").field(&"**OMITTED**").finish()
    }
}
/// Produces a freshly generated random key (see `SecretKey::generate`).
///
/// # Panics
///
/// Panics if `generate` returns an error, e.g. when `N` is zero.
impl<const N: usize> Default for SecretKey<N> {
    fn default() -> Self {
        Self::generate().expect("SecretKey::default requires a non-zero key length N")
    }
}
impl<const N: usize> SecretKey<N> {
    /// Borrows the raw key bytes as a slice.
    ///
    /// The `unprotected_` prefix marks that the secret is handed out in the clear.
    pub fn unprotected_as_bytes(&self) -> &[u8] {
        &self.0
    }

    /// Borrows the raw key bytes as a fixed-size array.
    pub fn unprotected_as_byte_array(&self) -> &[u8; N] {
        &self.0
    }

    /// Consumes the key and returns its raw bytes.
    pub fn unprotected_into_byte_array(self) -> [u8; N] {
        self.0
    }

    /// Converts this key into a [`HeapSecretKey`] holding a copy of the same bytes.
    pub fn into_heap_secret_key(self) -> HeapSecretKey {
        HeapSecretKey(self.0.to_vec())
    }

    /// Returns the key length in bytes, which is always `N`.
    // No `is_empty` companion: presumably because keys are meant to be non-empty
    // (`generate` rejects `N == 0`).
    #[allow(clippy::len_without_is_empty)]
    pub fn len(&self) -> usize {
        N
    }

    /// Generates a new key filled with `N` bytes from `OsRng`.
    ///
    /// # Errors
    ///
    /// Returns [`SecretKeyError`] if `N` is zero or exceeds `isize::MAX`.
    ///
    /// NOTE(review): `OsRng::fill_bytes` is expected to panic if the OS RNG fails; confirm
    /// against the `rand` version in use.
    pub fn generate() -> Result<Self, SecretKeyError> {
        if N == 0 || N > isize::MAX as usize {
            return Err(SecretKeyError);
        }
        let mut key = [0u8; N];
        OsRng.fill_bytes(&mut key);
        Ok(Self(key))
    }

    /// Builds a key by copying `slice`, which must be exactly `N` bytes long.
    ///
    /// # Errors
    ///
    /// Returns [`SecretKeyError`] if `slice.len() != N`.
    pub fn from_slice(slice: &[u8]) -> Result<Self, SecretKeyError> {
        if slice.len() != N {
            return Err(SecretKeyError);
        }
        let mut value = [0u8; N];
        // Lengths were checked above, so the whole array is overwritten in one memcpy.
        value.copy_from_slice(slice);
        Ok(Self(value))
    }
}
impl<const N: usize> From<[u8; N]> for SecretKey<N> {
fn from(arr: [u8; N]) -> Self {
Self(arr)
}
}
/// Parses a key from a hex string whose decoded bytes must number exactly `N`.
///
/// Both invalid hex and a wrong decoded length yield [`SecretKeyError`].
impl<const N: usize> FromStr for SecretKey<N> {
    type Err = SecretKeyError;

    fn from_str(s: &str) -> Result<Self, Self::Err> {
        match hex::decode(s) {
            Ok(decoded) => Self::from_slice(&decoded),
            Err(_) => Err(SecretKeyError),
        }
    }
}
/// Writes the key as a hex string. Unlike `Debug`, this reveals the secret bytes.
impl<const N: usize> fmt::Display for SecretKey<N> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let encoded = hex::encode(self.unprotected_as_bytes());
        f.write_str(&encoded)
    }
}
/// A secret key of arbitrary, runtime-determined length, stored in a `Vec<u8>`.
///
/// NOTE(review): the derived `PartialEq` is ordinary byte-wise equality and is not guaranteed
/// to run in constant time; confirm that is acceptable wherever keys are compared against
/// untrusted input.
#[derive(Clone, PartialEq, Eq)]
pub struct HeapSecretKey(Vec<u8>);
/// Redacted debug output: always renders as `HeapSecretKey("**OMITTED**")`, so key bytes
/// never appear in logs or `{:?}` formatting.
impl fmt::Debug for HeapSecretKey {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // A `&str` placeholder formats exactly like the equivalent `String`, without allocating.
        f.debug_tuple("HeapSecretKey").field(&"**OMITTED**").finish()
    }
}
impl HeapSecretKey {
    /// Borrows the raw key bytes.
    ///
    /// The `unprotected_` prefix marks that the secret is handed out in the clear.
    pub fn unprotected_as_bytes(&self) -> &[u8] {
        &self.0
    }

    /// Consumes the key and returns its raw bytes.
    ///
    /// The inner buffer is moved out rather than cloned, so no second heap copy of the
    /// secret is created.
    pub fn unprotected_into_bytes(self) -> Vec<u8> {
        self.0
    }

    /// Returns the key length in bytes.
    // No `is_empty` companion: presumably because keys are meant to be non-empty
    // (`generate` rejects `n == 0`).
    #[allow(clippy::len_without_is_empty)]
    pub fn len(&self) -> usize {
        self.0.len()
    }

    /// Generates a new key of `n` bytes from `OsRng`.
    ///
    /// # Errors
    ///
    /// Returns [`SecretKeyError`] if `n` is zero or exceeds `isize::MAX`.
    ///
    /// NOTE(review): `OsRng::fill_bytes` is expected to panic if the OS RNG fails; confirm
    /// against the `rand` version in use.
    pub fn generate(n: usize) -> Result<Self, SecretKeyError> {
        if n == 0 || n > isize::MAX as usize {
            return Err(SecretKeyError);
        }
        // Allocate the final buffer once and fill it in place. Growing a Vec chunk by chunk
        // through a stack buffer can leave copies of key bytes in freed reallocations and on
        // the stack.
        let mut key = vec![0u8; n];
        OsRng.fill_bytes(&mut key);
        Ok(Self(key))
    }
}
impl From<Vec<u8>> for HeapSecretKey {
fn from(bytes: Vec<u8>) -> Self {
Self(bytes)
}
}
impl<const N: usize> From<[u8; N]> for HeapSecretKey {
fn from(arr: [u8; N]) -> Self {
Self::from(arr.to_vec())
}
}
impl<const N: usize> From<SecretKey<N>> for HeapSecretKey {
fn from(key: SecretKey<N>) -> Self {
key.into_heap_secret_key()
}
}
impl FromStr for HeapSecretKey {
type Err = SecretKeyError;
fn from_str(s: &str) -> Result<Self, Self::Err> {
Ok(Self(hex::decode(s).map_err(|_| SecretKeyError)?))
}
}
/// Writes the key as a hex string. Unlike `Debug`, this reveals the secret bytes.
impl fmt::Display for HeapSecretKey {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let encoded = hex::encode(self.unprotected_as_bytes());
        f.write_str(&encoded)
    }
}
impl<const N: usize> PartialEq<[u8; N]> for HeapSecretKey {
fn eq(&self, other: &[u8; N]) -> bool {
self.0.eq(other)
}
}
impl<const N: usize> PartialEq<HeapSecretKey> for [u8; N] {
fn eq(&self, other: &HeapSecretKey) -> bool {
other.eq(self)
}
}
impl<const N: usize> PartialEq<HeapSecretKey> for &[u8; N] {
fn eq(&self, other: &HeapSecretKey) -> bool {
other.eq(*self)
}
}
// Equality between a `HeapSecretKey` and byte slices, in both directions.
// These are ordinary byte-wise comparisons (length first, then contents).

impl PartialEq<[u8]> for HeapSecretKey {
    fn eq(&self, other: &[u8]) -> bool {
        self.0[..] == *other
    }
}

impl PartialEq<HeapSecretKey> for [u8] {
    fn eq(&self, other: &HeapSecretKey) -> bool {
        *other == *self
    }
}

impl PartialEq<HeapSecretKey> for &[u8] {
    fn eq(&self, other: &HeapSecretKey) -> bool {
        *other == **self
    }
}
impl PartialEq<String> for HeapSecretKey {
fn eq(&self, other: &String) -> bool {
self.0.eq(other.as_bytes())
}
}
impl PartialEq<HeapSecretKey> for String {
fn eq(&self, other: &HeapSecretKey) -> bool {
other.eq(self)
}
}
impl PartialEq<HeapSecretKey> for &String {
fn eq(&self, other: &HeapSecretKey) -> bool {
other.eq(*self)
}
}
impl PartialEq<str> for HeapSecretKey {
fn eq(&self, other: &str) -> bool {
self.0.eq(other.as_bytes())
}
}
impl PartialEq<HeapSecretKey> for str {
fn eq(&self, other: &HeapSecretKey) -> bool {
other.eq(self)
}
}
impl PartialEq<HeapSecretKey> for &str {
fn eq(&self, other: &HeapSecretKey) -> bool {
other.eq(*self)
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use test_log::test;

    #[test]
    fn secret_key_should_be_able_to_be_generated() {
        SecretKey::<0>::generate().unwrap_err();
        let key = SecretKey::<1>::generate().unwrap();
        assert_eq!(key.len(), 1);
        let key = SecretKey::<100>::generate().unwrap();
        assert_eq!(key.len(), 100);
    }

    #[test]
    fn heap_secret_key_should_be_able_to_be_generated() {
        HeapSecretKey::generate(0).unwrap_err();
        let key = HeapSecretKey::generate(1).unwrap();
        assert_eq!(key.len(), 1);
        let key = HeapSecretKey::generate(100).unwrap();
        assert_eq!(key.len(), 100);
    }

    #[test]
    fn secret_key_from_slice_should_require_exact_length() {
        SecretKey::<4>::from_slice(&[1, 2, 3]).unwrap_err();
        SecretKey::<2>::from_slice(&[1, 2, 3]).unwrap_err();
        let key = SecretKey::<3>::from_slice(&[1, 2, 3]).unwrap();
        assert_eq!(key.unprotected_as_byte_array(), &[1u8, 2, 3]);
    }

    #[test]
    fn secret_key_should_round_trip_through_hex() {
        let key = SecretKey32::generate().unwrap();
        let encoded = key.to_string();
        assert_eq!(encoded.len(), 64);
        let decoded: SecretKey32 = encoded.parse().unwrap();
        assert_eq!(decoded, key);

        // Invalid hex and a wrong decoded length are both rejected.
        "zz".parse::<SecretKey<1>>().unwrap_err();
        "0011".parse::<SecretKey<1>>().unwrap_err();
    }

    #[test]
    fn heap_secret_key_should_round_trip_through_hex() {
        // 40 bytes spans more than one 32-byte chunk.
        let key = HeapSecretKey::generate(40).unwrap();
        let decoded: HeapSecretKey = key.to_string().parse().unwrap();
        assert_eq!(decoded, key);
        "zz".parse::<HeapSecretKey>().unwrap_err();
    }

    #[test]
    fn debug_output_should_not_reveal_key_bytes() {
        let key = SecretKey16::generate().unwrap();
        assert_eq!(format!("{:?}", key), "SecretKey(\"**OMITTED**\")");
        let heap = HeapSecretKey::from(vec![1u8, 2, 3]);
        assert_eq!(format!("{:?}", heap), "HeapSecretKey(\"**OMITTED**\")");
    }

    #[test]
    fn conversions_should_preserve_bytes() {
        let key: SecretKey<4> = [9u8; 4].into();
        let heap = HeapSecretKey::from(key.clone());
        assert!(heap == [9u8; 4]);
        assert_eq!(key.into_heap_secret_key(), heap);
        assert_eq!(
            HeapSecretKey::from(vec![1u8, 2, 3]).unprotected_into_bytes(),
            vec![1u8, 2, 3]
        );
    }
}