use alloc::string::String;
use core::{ops::Deref, slice};
use crate::utils::{
ByteReader, ByteWriter, Deserializable, DeserializationError, HexParseError, Serializable,
bytes_to_hex_string, hex_to_bytes,
};
/// Number of bytes in a 192-bit digest.
pub const DIGEST192_BYTES: usize = 24;
/// Number of bytes in a 256-bit digest.
pub const DIGEST256_BYTES: usize = 32;
/// Number of bytes in a 512-bit digest.
pub const DIGEST512_BYTES: usize = 64;
/// A [`Digest`] holding [`DIGEST192_BYTES`] bytes.
pub type Digest192 = Digest<DIGEST192_BYTES>;
/// A [`Digest`] holding [`DIGEST256_BYTES`] bytes.
pub type Digest256 = Digest<DIGEST256_BYTES>;
/// A [`Digest`] holding [`DIGEST512_BYTES`] bytes.
pub type Digest512 = Digest<DIGEST512_BYTES>;
/// A fixed-size digest stored as an array of `N` bytes.
///
/// `N` defaults to [`DIGEST256_BYTES`]. Equality, ordering and hashing are derived, so they follow
/// the wrapped byte array.
///
/// When the `serde` feature is enabled, a digest is serialized by converting it into a `String`
/// and deserialized by converting from a `&str`, i.e. through the `From<Digest<N>> for String`
/// and `TryFrom<&str>` implementations of this type.
#[derive(Debug, Copy, Clone, Eq, PartialEq, Ord, PartialOrd, Hash)]
#[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))]
#[cfg_attr(feature = "serde", serde(into = "String", try_from = "&str"))]
// `repr(transparent)` gives the struct the same layout as `[u8; N]`; `digests_as_bytes`
// reinterprets a slice of digests as bytes and relies on this.
#[repr(transparent)]
pub struct Digest<const N: usize = DIGEST256_BYTES>([u8; N]);
impl<const N: usize> Digest<N> {
    /// Creates a digest that wraps the given bytes unchanged.
    #[inline]
    pub const fn new(bytes: [u8; N]) -> Self {
        Self(bytes)
    }

    /// Returns a reference to the underlying byte array.
    ///
    /// This is a `const fn`, so a digest built with [`Digest::new`] can also be inspected in a
    /// constant context.
    #[inline]
    pub const fn as_bytes(&self) -> &[u8; N] {
        &self.0
    }

    /// Views a slice of digests as a single contiguous byte slice without copying.
    ///
    /// The returned slice is `digests.len() * N` bytes long and contains the bytes of each digest
    /// in slice order.
    pub fn digests_as_bytes(digests: &[Digest<N>]) -> &[u8] {
        let byte_len = digests.len() * N;
        // SAFETY:
        // - `Digest<N>` is `#[repr(transparent)]` over `[u8; N]`, so `digests` occupies exactly
        //   `digests.len() * N` contiguous, initialized bytes with no padding.
        // - `u8` has alignment 1, so the pointer is suitably aligned (and non-null, as it comes
        //   from a valid slice).
        // - The returned slice borrows from `digests`, so the memory stays valid and is not
        //   mutated for the lifetime of the result.
        // - The multiplication cannot overflow: the total size of an existing slice never
        //   exceeds `isize::MAX` bytes.
        unsafe { slice::from_raw_parts(digests.as_ptr().cast::<u8>(), byte_len) }
    }
}
impl<const N: usize> Default for Digest<N> {
fn default() -> Self {
Self([0; N])
}
}
/// Lets a digest be used wherever a byte slice is expected.
impl<const N: usize> Deref for Digest<N> {
    type Target = [u8];

    /// Returns the digest's bytes as a slice of length `N`.
    fn deref(&self) -> &[u8] {
        self.0.as_slice()
    }
}
impl<const N: usize> From<Digest<N>> for [u8; N] {
fn from(value: Digest<N>) -> Self {
value.0
}
}
impl<const N: usize> From<[u8; N]> for Digest<N> {
fn from(value: [u8; N]) -> Self {
Self(value)
}
}
impl<const N: usize> From<Digest<N>> for String {
fn from(value: Digest<N>) -> Self {
bytes_to_hex_string(value.0)
}
}
/// Parses a digest from a string using `hex_to_bytes`.
///
/// For every size other than [`DIGEST192_BYTES`] the input is handed to `hex_to_bytes` as-is.
///
/// A 192-bit digest additionally accepts an input of the length used for a 256-bit digest, as
/// long as the trailing `DIGEST256_BYTES - DIGEST192_BYTES` decoded bytes are all zero; those
/// padding bytes are dropped.
///
/// # Errors
///
/// Returns whatever error `hex_to_bytes` reports. For 192-bit digests,
/// [`HexParseError::InvalidLength`] (with the unpadded length as `expected`) is returned when the
/// input has neither of the two accepted lengths, or when a padded input carries non-zero
/// padding bytes.
impl<const N: usize> TryFrom<&str> for Digest<N> {
    type Error = HexParseError;

    fn try_from(value: &str) -> Result<Self, Self::Error> {
        // Two hex characters per byte plus two extra characters (presumably the `0x` prefix —
        // confirm against `hex_to_bytes`).
        const UNPADDED_HEX_LEN: usize = (DIGEST192_BYTES * 2) + 2;
        const PADDED_HEX_LEN: usize = (DIGEST256_BYTES * 2) + 2;

        // Only 192-bit digests get the lenient, padding-aware treatment below.
        if N != DIGEST192_BYTES {
            return hex_to_bytes(value).map(Self);
        }

        let actual = value.len();
        match actual {
            UNPADDED_HEX_LEN => hex_to_bytes::<N>(value).map(Self),
            PADDED_HEX_LEN => {
                let padded = hex_to_bytes::<DIGEST256_BYTES>(value)?;
                let (digest_bytes, padding) = padded.split_at(DIGEST192_BYTES);
                if padding.iter().any(|&byte| byte != 0) {
                    // Non-zero padding means this is not a padded 192-bit digest.
                    return Err(HexParseError::InvalidLength {
                        expected: UNPADDED_HEX_LEN,
                        actual,
                    });
                }

                // `N == DIGEST192_BYTES` on this path, so the lengths match.
                let mut unpadded = [0u8; N];
                unpadded.copy_from_slice(digest_bytes);
                Ok(Self(unpadded))
            },
            _ => Err(HexParseError::InvalidLength { expected: UNPADDED_HEX_LEN, actual }),
        }
    }
}
/// Serializes a digest as its `N` raw bytes.
impl<const N: usize> Serializable for Digest<N> {
    fn write_into<W: ByteWriter>(&self, target: &mut W) {
        let Self(bytes) = self;
        target.write_bytes(bytes);
    }
}
/// Deserializes a digest by reading an `N`-byte array from the source.
impl<const N: usize> Deserializable for Digest<N> {
    /// # Errors
    ///
    /// Propagates the error returned by `ByteReader::read_array`.
    fn read_from<R: ByteReader>(source: &mut R) -> Result<Self, DeserializationError> {
        let bytes: [u8; N] = source.read_array()?;
        Ok(Self(bytes))
    }
}
#[cfg(test)]
mod tests {
    use core::mem::{align_of, size_of};

    use super::*;

    /// `digests_as_bytes` reinterprets `&[Digest<N>]` as bytes, which is only valid while
    /// `Digest<N>` has exactly the size and alignment of `[u8; N]`.
    #[test]
    fn test_memory_layout_assumptions() {
        assert_eq!(size_of::<Digest<32>>(), size_of::<[u8; 32]>());
        assert_eq!(align_of::<Digest<32>>(), align_of::<[u8; 32]>());
        assert_eq!(size_of::<Digest<64>>(), size_of::<[u8; 64]>());
        assert_eq!(align_of::<Digest<64>>(), align_of::<[u8; 64]>());
        assert_eq!(size_of::<Digest<24>>(), size_of::<[u8; 24]>());
        assert_eq!(align_of::<Digest<24>>(), align_of::<[u8; 24]>());
        assert_eq!(size_of::<Digest192>(), 24);
        assert_eq!(size_of::<Digest256>(), 32);
        assert_eq!(size_of::<Digest512>(), 64);
    }

    #[test]
    fn test_digest_default_32() {
        let digest = Digest::<32>::default();
        assert_eq!(digest.as_bytes(), &[0u8; 32]);
    }

    #[test]
    fn test_digest_default_64() {
        let digest = Digest::<64>::default();
        assert_eq!(digest.as_bytes(), &[0u8; 64]);
    }

    #[test]
    fn test_digest192_alias() {
        let digest = Digest192::default();
        assert_eq!(digest.as_bytes(), &[0u8; 24]);
    }

    #[test]
    fn test_digest256_alias() {
        let digest = Digest256::default();
        assert_eq!(digest.as_bytes(), &[0u8; 32]);
    }

    #[test]
    fn test_digest512_alias() {
        let digest = Digest512::default();
        assert_eq!(digest.as_bytes(), &[0u8; 64]);
    }

    #[test]
    fn test_digest_from_bytes_32() {
        let bytes = [1u8; 32];
        let digest = Digest::<32>::from(bytes);
        assert_eq!(digest.as_bytes(), &bytes);
    }

    #[test]
    fn test_digest_from_bytes_64() {
        let bytes = [1u8; 64];
        let digest = Digest::<64>::from(bytes);
        assert_eq!(digest.as_bytes(), &bytes);
    }

    /// A 192-bit digest may be given as a 256-bit hex string whose last 8 bytes are zero.
    #[test]
    fn test_digest192_accepts_zero_padded_hex() {
        let mut bytes = [0u8; DIGEST192_BYTES];
        bytes[0] = 1;
        bytes[23] = 255;
        let mut padded = [0u8; DIGEST256_BYTES];
        padded[..DIGEST192_BYTES].copy_from_slice(&bytes);
        let hex = bytes_to_hex_string(padded);
        let parsed = Digest192::try_from(hex.as_str()).expect("digest192 should accept padding");
        assert_eq!(parsed.as_bytes(), &bytes);
    }

    /// A 256-bit hex string with non-zero bytes in the padding region is not a valid 192-bit
    /// digest and must be rejected instead of being silently truncated.
    #[test]
    fn test_digest192_rejects_nonzero_padding() {
        let mut padded = [0u8; DIGEST256_BYTES];
        padded[0] = 1;
        padded[DIGEST256_BYTES - 1] = 1;
        let hex = bytes_to_hex_string(padded);
        let result = Digest192::try_from(hex.as_str());
        assert!(matches!(
            result,
            Err(HexParseError::InvalidLength { expected, actual })
                if expected == (DIGEST192_BYTES * 2) + 2 && actual == hex.len()
        ));
    }

    /// Inputs that have neither the unpadded nor the padded length are rejected up front.
    #[test]
    fn test_digest192_rejects_unexpected_length() {
        let input = "0x1234";
        let result = Digest192::try_from(input);
        assert!(matches!(
            result,
            Err(HexParseError::InvalidLength { expected, actual })
                if expected == (DIGEST192_BYTES * 2) + 2 && actual == input.len()
        ));
    }

    #[test]
    fn test_digest_hex_roundtrip_32() {
        let bytes = [0xab; 32];
        let digest = Digest::<32>::from(bytes);
        let hex: String = digest.into();
        let recovered = Digest::<32>::try_from(hex.as_str()).unwrap();
        assert_eq!(recovered.as_bytes(), &bytes);
    }

    #[test]
    fn test_digest_hex_roundtrip_64() {
        let bytes = [0xcd; 64];
        let digest = Digest::<64>::from(bytes);
        let hex: String = digest.into();
        let recovered = Digest::<64>::try_from(hex.as_str()).unwrap();
        assert_eq!(recovered.as_bytes(), &bytes);
    }

    #[test]
    fn test_digest_hex_roundtrip_24() {
        let bytes = [0xef; 24];
        let digest = Digest::<24>::from(bytes);
        let hex: String = digest.into();
        let recovered = Digest::<24>::try_from(hex.as_str()).unwrap();
        assert_eq!(recovered.as_bytes(), &bytes);
    }

    #[test]
    fn test_digest_digests_as_bytes_24() {
        let d1 = Digest::<24>::from([1u8; 24]);
        let d2 = Digest::<24>::from([2u8; 24]);
        let digests = [d1, d2];
        let bytes = Digest::<24>::digests_as_bytes(&digests);
        assert_eq!(bytes.len(), 48);
        assert_eq!(&bytes[0..24], &[1u8; 24]);
        assert_eq!(&bytes[24..48], &[2u8; 24]);
    }

    #[test]
    fn test_digest_digests_as_bytes_32() {
        let d1 = Digest::<32>::from([1u8; 32]);
        let d2 = Digest::<32>::from([2u8; 32]);
        let digests = [d1, d2];
        let bytes = Digest::<32>::digests_as_bytes(&digests);
        assert_eq!(bytes.len(), 64);
        assert_eq!(&bytes[0..32], &[1u8; 32]);
        assert_eq!(&bytes[32..64], &[2u8; 32]);
    }

    #[test]
    fn test_digest_digests_as_bytes_64() {
        let d1 = Digest::<64>::from([1u8; 64]);
        let d2 = Digest::<64>::from([2u8; 64]);
        let digests = [d1, d2];
        let bytes = Digest::<64>::digests_as_bytes(&digests);
        assert_eq!(bytes.len(), 128);
        assert_eq!(&bytes[0..64], &[1u8; 64]);
        assert_eq!(&bytes[64..128], &[2u8; 64]);
    }

    /// An empty slice of digests maps to an empty byte slice.
    #[test]
    fn test_digest_digests_as_bytes_empty() {
        let digests: [Digest<32>; 0] = [];
        assert!(Digest::<32>::digests_as_bytes(&digests).is_empty());
    }
}