cashu 0.16.0

Cashu shared types and crypto utilities, used as the foundation for the CDK and its crates
Documentation
// Copyright (c) 2022-2023 Yuki Kishimoto
// Distributed under the MIT software license

//! Hex encoding and decoding

use core::fmt;

use crate::ensure_cdk;

/// Hex error
///
/// Returned by [`decode`] when its input is not a valid hexadecimal string.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Error {
    /// A byte that is not an ASCII hexadecimal digit (`0-9`, `a-f`, `A-F`)
    /// was found.
    InvalidHexCharacter {
        /// The offending character
        c: char,
        /// Byte offset of the offending character within the input
        index: usize,
    },
    /// A hex string's length needs to be even, as two digits correspond to
    /// one byte.
    OddLength,
}

impl std::error::Error for Error {}

impl fmt::Display for Error {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            Self::InvalidHexCharacter { c, index } => {
                write!(f, "Invalid character {c} at position {index}")
            }
            // No formatting arguments, so write the literal directly.
            Self::OddLength => f.write_str("Odd number of digits"),
        }
    }
}

/// Maps a nibble (a value in `0..=15`) to its lower-case hexadecimal digit.
#[inline]
fn from_digit(num: u8) -> char {
    match num {
        0..=9 => char::from(b'0' + num),
        _ => char::from(b'a' + num - 10),
    }
}

/// Hex encode
///
/// Renders every byte of `data` as two lower-case hexadecimal digits, high
/// nibble first, so the result is always twice as long as the input.
pub fn encode<T>(data: T) -> String
where
    T: AsRef<[u8]>,
{
    let input: &[u8] = data.as_ref();
    input
        .iter()
        .fold(String::with_capacity(2 * input.len()), |mut out, &byte| {
            out.push(from_digit(byte >> 4));
            out.push(from_digit(byte & 0xF));
            out
        })
}

const fn val(c: u8, idx: usize) -> Result<u8, Error> {
    match c {
        b'A'..=b'F' => Ok(c - b'A' + 10),
        b'a'..=b'f' => Ok(c - b'a' + 10),
        b'0'..=b'9' => Ok(c - b'0'),
        _ => Err(Error::InvalidHexCharacter {
            c: c as char,
            index: idx,
        }),
    }
}

/// Hex decode
pub fn decode<T>(hex: T) -> Result<Vec<u8>, Error>
where
    T: AsRef<[u8]>,
{
    let hex = hex.as_ref();
    let len = hex.len();

    ensure_cdk!(len % 2 == 0, Error::OddLength);

    let mut bytes: Vec<u8> = Vec::with_capacity(len / 2);

    for i in (0..len).step_by(2) {
        let high = val(hex[i], i)?;
        let low = val(hex[i + 1], i + 1)?;
        bytes.push((high << 4) | low);
    }

    Ok(bytes)
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_encode() {
        assert_eq!(encode("foobar"), "666f6f626172");
    }

    #[test]
    fn test_encode_empty_and_extremes() {
        assert_eq!(encode(""), "");
        assert_eq!(encode([0x00u8, 0x0f, 0xf0, 0xff]), "000ff0ff");
    }

    #[test]
    fn test_decode() {
        assert_eq!(
            decode("666f6f626172"),
            Ok(String::from("foobar").into_bytes())
        );
    }

    #[test]
    fn test_decode_empty() {
        assert_eq!(decode(""), Ok(Vec::new()));
    }

    #[test]
    fn test_decode_accepts_upper_and_lower_case() {
        assert_eq!(decode("DEADbeef"), Ok(vec![0xde, 0xad, 0xbe, 0xef]));
    }

    #[test]
    fn test_roundtrip_all_bytes() {
        let bytes: Vec<u8> = (0..=u8::MAX).collect();
        assert_eq!(decode(encode(&bytes)), Ok(bytes));
    }

    #[test]
    fn test_invalid_length() {
        assert_eq!(decode("1").unwrap_err(), Error::OddLength);
        assert_eq!(decode("666f6f6261721").unwrap_err(), Error::OddLength);
    }

    #[test]
    fn test_invalid_char() {
        assert_eq!(
            decode("66ag").unwrap_err(),
            Error::InvalidHexCharacter { c: 'g', index: 3 }
        );
        // Invalid character in the high-nibble position of a pair.
        assert_eq!(
            decode("zz").unwrap_err(),
            Error::InvalidHexCharacter { c: 'z', index: 0 }
        );
    }

    #[test]
    fn test_error_display() {
        assert_eq!(Error::OddLength.to_string(), "Odd number of digits");
        assert_eq!(
            Error::InvalidHexCharacter { c: 'g', index: 3 }.to_string(),
            "Invalid character g at position 3"
        );
    }
}