Skip to main content

libtw2_common/
digest.rs

1use std::error::Error;
2use std::fmt;
3use std::iter;
4use std::str::FromStr;
5
/// Error returned by [`Sha256::from_slice`] when the given slice is not
/// exactly 32 bytes long.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub struct InvalidSliceLength;

impl fmt::Display for InvalidSliceLength {
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        // `pad` honours width/alignment flags, like `str`'s own `Display`.
        f.pad("invalid slice length")
    }
}

// Lets the error be propagated with `?` into `Box<dyn Error>` and friends,
// mirroring `Sha256FromStrError`.
impl Error for InvalidSliceLength {}
8
9#[derive(Clone, Copy, Hash, Eq, Ord, PartialEq, PartialOrd)]
10pub struct Sha256(pub [u8; 32]);
11
12impl Sha256 {
13    pub fn from_slice(bytes: &[u8]) -> Result<Sha256, InvalidSliceLength> {
14        let mut result = [0; 32];
15        if bytes.len() != result.len() {
16            return Err(InvalidSliceLength);
17        }
18        result.copy_from_slice(bytes);
19        Ok(Sha256(result))
20    }
21}
22
23impl fmt::Debug for Sha256 {
24    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
25        for &b in &self.0 {
26            write!(f, "{:02x}", b)?;
27        }
28        Ok(())
29    }
30}
31
32impl fmt::Display for Sha256 {
33    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
34        fmt::Debug::fmt(self, f)
35    }
36}
37
/// Error returned when parsing a [`Sha256`] from a hex string fails.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum Sha256FromStrError {
    /// The input was not exactly 64 characters long; carries the number of
    /// characters that were found.
    InvalidLength(usize),
    /// The input contained a character that is not a hexadecimal digit.
    NonHexChar,
}

impl fmt::Display for Sha256FromStrError {
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        use self::Sha256FromStrError::*;
        match *self {
            InvalidLength(len) => write!(f, "invalid length {}, must be 64", len),
            // `pad` is exactly what `str`'s `Display` does, so width and
            // alignment flags keep working for this variant.
            NonHexChar => f.pad("non-hex character"),
        }
    }
}

impl Error for Sha256FromStrError {}
55
56impl FromStr for Sha256 {
57    type Err = Sha256FromStrError;
58    fn from_str(v: &str) -> Result<Sha256, Sha256FromStrError> {
59        let len = v.chars().count();
60        if len != 64 {
61            return Err(Sha256FromStrError::InvalidLength(len));
62        }
63        let mut result = [0; 32];
64        // I just want to get string slices with two characters each. :(
65        // Sorry for this monstrosity.
66        let starts = v
67            .char_indices()
68            .map(|(i, _)| i)
69            .chain(iter::once(v.len()))
70            .step_by(2);
71        let ends = {
72            let mut e = starts.clone();
73            e.next();
74            e
75        };
76        for (i, (s, e)) in starts.zip(ends).enumerate() {
77            result[i] =
78                u8::from_str_radix(&v[s..e], 16).map_err(|_| Sha256FromStrError::NonHexChar)?;
79        }
80        Ok(Sha256(result))
81    }
82}
83
/// Optional serde support: a `Sha256` is (de)serialized as its hex string.
#[cfg(feature = "serde")]
mod serialize {
    use std::fmt;

    use super::Sha256;

    /// Serializes the digest as a string, using its `Display` implementation.
    impl serde::Serialize for Sha256 {
        fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
        where
            S: serde::Serializer,
        {
            // `collect_str` produces the same string as
            // `serialize_str(&format!("{}", self))`, but lets serializers
            // write the `Display` output directly instead of allocating a
            // temporary `String` first.
            serializer.collect_str(self)
        }
    }

    /// Visitor that accepts string input and parses it with `Sha256::from_str`.
    struct HexSha256Visitor;

    impl<'de> serde::de::Visitor<'de> for HexSha256Visitor {
        type Value = Sha256;

        fn expecting(&self, f: &mut fmt::Formatter) -> fmt::Result {
            f.write_str("64 character hex value")
        }

        /// Maps `InvalidLength` to serde's `invalid_length` and `NonHexChar`
        /// to `invalid_value` carrying the offending string.
        fn visit_str<E: serde::de::Error>(self, v: &str) -> Result<Sha256, E> {
            use super::Sha256FromStrError::*;
            v.parse().map_err(|e| match e {
                InvalidLength(len) => E::invalid_length(len, &self),
                NonHexChar => E::invalid_value(serde::de::Unexpected::Str(v), &self),
            })
        }
    }

    /// Deserializes a digest from a string via `deserialize_str`.
    impl<'de> serde::Deserialize<'de> for Sha256 {
        fn deserialize<D>(deserializer: D) -> Result<Sha256, D::Error>
        where
            D: serde::de::Deserializer<'de>,
        {
            deserializer.deserialize_str(HexSha256Visitor)
        }
    }
}
120        {
121            deserializer.deserialize_str(HexSha256Visitor)
122        }
123    }
124}