use miette::Diagnostic;
use serde::de::Error as DeError;
use serde::{Deserialize, Deserializer, Serialize, Serializer};
use std::fmt;
use std::fmt::{Debug, Display, Formatter, Write};
use std::num::ParseIntError;
use thiserror::Error;
/// Radix handed to `u8::from_str_radix` when decoding hexadecimal digit pairs.
const HEXADECIMAL_RADIX: u32 = 0x10;
#[derive(Clone, Copy, Eq, Hash, Ord, PartialEq, PartialOrd)]
pub struct Hash<const N: usize> {
bytes: [u8; N],
}
impl<const N: usize> Hash<N> {
    /// Wraps raw digest bytes in a `Hash`.
    #[must_use]
    pub fn new(bytes: [u8; N]) -> Self {
        Self { bytes }
    }

    /// Parses a hash from its hexadecimal string representation.
    ///
    /// # Errors
    ///
    /// Returns a [`HashError`] when `hex` is not exactly `2 * N` bytes long or when one of
    /// its two-character pairs cannot be decoded as a hexadecimal byte.
    pub fn from_string(hex: &str) -> Result<Self, HashError> {
        let bytes = to_bytes(hex)?;
        Ok(Hash::new(bytes))
    }

    /// Renders the hash as lowercase hexadecimal, two digits per byte.
    #[must_use]
    pub fn to_hex(&self) -> String {
        // Exactly two hex digits are emitted per byte, so the final length is known up front.
        let mut hex = String::with_capacity(N * 2);
        for byte in &self.bytes {
            // `fmt::Write` for `String` never returns an error, so the result can be ignored.
            let _ = write!(hex, "{byte:02x}");
        }
        hex
    }

    /// Borrows the raw digest bytes.
    #[must_use]
    pub fn as_bytes(&self) -> &[u8; N] {
        &self.bytes
    }

    /// Returns a new hash made of the first `M` bytes of this one.
    ///
    /// Returns `None` when `M` is larger than `N` (there are not enough bytes to take).
    #[must_use]
    pub fn truncate<const M: usize>(&self) -> Option<Hash<M>> {
        // `get` yields `None` instead of panicking when the range exceeds the array.
        let prefix: [u8; M] = self.bytes.get(..M)?.try_into().ok()?;
        Some(Hash::new(prefix))
    }
}
impl<const N: usize> Debug for Hash<N> {
    /// Debug output is the bare lowercase hex string, without a type name or brackets.
    fn fmt(&self, formatter: &mut Formatter<'_>) -> fmt::Result {
        let hex = self.to_hex();
        formatter.write_str(&hex)
    }
}
impl<const N: usize> Default for Hash<N> {
    /// The default hash has every byte set to zero.
    fn default() -> Self {
        Self { bytes: [0_u8; N] }
    }
}
impl<const N: usize> Display for Hash<N> {
    /// Displays the hash as its lowercase hexadecimal string.
    fn fmt(&self, formatter: &mut Formatter<'_>) -> fmt::Result {
        formatter.write_str(self.to_hex().as_str())
    }
}
impl<const N: usize> Serialize for Hash<N> {
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
where
S: Serializer,
{
serializer.serialize_str(&self.to_hex())
}
}
impl<'de, const N: usize> Deserialize<'de> for Hash<N> {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: Deserializer<'de>,
{
let hex_str = String::deserialize(deserializer)?;
Hash::from_string(&hex_str).map_err(DeError::custom)
}
}
/// Decodes a string of exactly `2 * N` hexadecimal characters into `N` bytes.
///
/// Each output byte is built from one consecutive pair of characters.
///
/// # Errors
///
/// - [`HashError::InvalidLength`] when `hex` is not exactly `2 * N` bytes long. The length
///   is measured in UTF-8 bytes, not characters.
/// - [`HashError::InvalidCharacter`] when a pair contains anything other than ASCII hex
///   digits. `position` is the byte offset at which that pair starts.
fn to_bytes<const N: usize>(hex: &str) -> Result<[u8; N], HashError> {
    let expected = N * 2;
    let length = hex.len();
    if length != expected {
        return Err(HashError::InvalidLength {
            expected,
            actual: length,
        });
    }
    let raw = hex.as_bytes();
    let mut bytes = [0_u8; N];
    for (i, byte) in bytes.iter_mut().enumerate() {
        let start = i * 2;
        let end = start + 2;
        let invalid = HashError::InvalidCharacter { position: start };
        // Validate the raw bytes before slicing the `&str`. This does two things:
        // - it rejects the optional leading `+` that `u8::from_str_radix` would otherwise
        //   accept (so "+f" is not silently decoded as 0x0f);
        // - it guarantees `start..end` lies on char boundaries, so the slice below cannot
        //   panic on multi-byte UTF-8 input.
        if !raw[start..end].iter().all(u8::is_ascii_hexdigit) {
            return Err(invalid);
        }
        *byte = to_byte(&hex[start..end]).map_err(|_| invalid)?;
    }
    Ok(bytes)
}
/// Parses a hexadecimal digit string into a single byte.
///
/// Note: `u8::from_str_radix` tolerates a leading `+` sign, so `"+f"` parses as `15`.
/// Callers needing strict hex digits must validate the input before calling this.
fn to_byte(hex: &str) -> Result<u8, ParseIntError> {
    let radix = HEXADECIMAL_RADIX;
    u8::from_str_radix(hex, radix)
}
#[derive(Clone, Copy, Debug, Eq, PartialEq, Error, Diagnostic)]
pub enum HashError {
#[error("Invalid hex length\nExpected: {expected}\nActual: {actual}")]
InvalidLength { expected: usize, actual: usize },
#[error("Invalid hex character at position {position}")]
InvalidCharacter { position: usize },
}