cashu/util/
hex.rs

// Copyright (c) 2022-2023 Yuki Kishimoto
// Distributed under the MIT software license

//! Hex encoding and decoding utilities.
5
6use core::fmt;
7
8use crate::ensure_cdk;
9
/// Errors that can occur while decoding a hex string.
#[derive(Debug, PartialEq, Eq)]
pub enum Error {
    /// The input contained a byte that is not a hexadecimal digit.
    InvalidHexCharacter {
        /// The offending character.
        c: char,
        /// Zero-based position of the character within the input.
        index: usize,
    },
    /// The input has an odd length; each byte is written as exactly two hex
    /// digits, so a well-formed hex string always has an even length.
    OddLength,
}
24
25impl std::error::Error for Error {}
26
27impl fmt::Display for Error {
28    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
29        match self {
30            Self::InvalidHexCharacter { c, index } => {
31                write!(f, "Invalid character {c} at position {index}")
32            }
33            Self::OddLength => write!(f, "Odd number of digits"),
34        }
35    }
36}
37
/// Maps a nibble to its lowercase ASCII hex digit.
///
/// Values `0..=9` become `'0'..='9'` and `10..=15` become `'a'..='f'`.
/// Callers are expected to pass a 4-bit value.
#[inline]
fn from_digit(num: u8) -> char {
    match num {
        0..=9 => char::from(b'0' + num),
        _ => char::from(b'a' + num - 10),
    }
}
46
47/// Hex encode
48pub fn encode<T>(data: T) -> String
49where
50    T: AsRef<[u8]>,
51{
52    let bytes: &[u8] = data.as_ref();
53    let mut hex: String = String::with_capacity(2 * bytes.len());
54    for byte in bytes.iter() {
55        hex.push(from_digit(byte >> 4));
56        hex.push(from_digit(byte & 0xF));
57    }
58    hex
59}
60
61const fn val(c: u8, idx: usize) -> Result<u8, Error> {
62    match c {
63        b'A'..=b'F' => Ok(c - b'A' + 10),
64        b'a'..=b'f' => Ok(c - b'a' + 10),
65        b'0'..=b'9' => Ok(c - b'0'),
66        _ => Err(Error::InvalidHexCharacter {
67            c: c as char,
68            index: idx,
69        }),
70    }
71}
72
73/// Hex decode
74pub fn decode<T>(hex: T) -> Result<Vec<u8>, Error>
75where
76    T: AsRef<[u8]>,
77{
78    let hex = hex.as_ref();
79    let len = hex.len();
80
81    ensure_cdk!(len % 2 == 0, Error::OddLength);
82
83    let mut bytes: Vec<u8> = Vec::with_capacity(len / 2);
84
85    for i in (0..len).step_by(2) {
86        let high = val(hex[i], i)?;
87        let low = val(hex[i + 1], i + 1)?;
88        bytes.push((high << 4) | low);
89    }
90
91    Ok(bytes)
92}
93
#[cfg(test)]
mod tests {
    use super::*;

    /// Encoding produces lowercase digits, two per byte.
    #[test]
    fn test_encode() {
        assert_eq!(encode("foobar"), "666f6f626172");
        // Empty input encodes to an empty string.
        assert_eq!(encode(""), "");
        // Leading zero nibbles must not be dropped.
        assert_eq!(encode([0x00u8, 0x0f, 0xab, 0xff]), "000fabff");
    }

    /// Decoding accepts lowercase, uppercase and empty input.
    #[test]
    fn test_decode() {
        assert_eq!(
            decode("666f6f626172"),
            Ok(String::from("foobar").into_bytes())
        );
        assert_eq!(decode("666F6F626172"), Ok(b"foobar".to_vec()));
        assert_eq!(decode(""), Ok(Vec::new()));
    }

    /// Every byte value survives an encode/decode round trip.
    #[test]
    fn test_round_trip() {
        let all_bytes: Vec<u8> = (0u8..=255).collect();
        assert_eq!(decode(encode(&all_bytes)), Ok(all_bytes));
    }

    /// Inputs with an odd number of digits are rejected.
    #[test]
    fn test_invalid_length() {
        assert_eq!(decode("1").unwrap_err(), Error::OddLength);
        assert_eq!(decode("666f6f6261721").unwrap_err(), Error::OddLength);
    }

    /// Non-hex characters are reported together with their position.
    #[test]
    fn test_invalid_char() {
        // Invalid character in the low-nibble position.
        assert_eq!(
            decode("66ag").unwrap_err(),
            Error::InvalidHexCharacter { c: 'g', index: 3 }
        );
        // Invalid character in the high-nibble position.
        assert_eq!(
            decode("g6").unwrap_err(),
            Error::InvalidHexCharacter { c: 'g', index: 0 }
        );
    }

    /// The `Display` messages are stable, user-facing text.
    #[test]
    fn test_display() {
        assert_eq!(Error::OddLength.to_string(), "Odd number of digits");
        assert_eq!(
            Error::InvalidHexCharacter { c: 'g', index: 3 }.to_string(),
            "Invalid character g at position 3"
        );
    }
}