1use std::error::Error;
2use std::fmt;
3use std::iter;
4use std::str::FromStr;
5
/// Error returned by [`Sha256::from_slice`] when the input slice is not
/// exactly 32 bytes long.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub struct InvalidSliceLength;

impl fmt::Display for InvalidSliceLength {
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        f.write_str("invalid slice length, must be 32 bytes")
    }
}

impl Error for InvalidSliceLength {}
8
/// A 32-byte hash value, sized to hold a SHA-256 digest.
///
/// The raw bytes are public; both `Debug` and `Display` render them as
/// 64 lowercase hex digits.
#[derive(Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct Sha256(pub [u8; 32]);
11
12impl Sha256 {
13 pub fn from_slice(bytes: &[u8]) -> Result<Sha256, InvalidSliceLength> {
14 let mut result = [0; 32];
15 if bytes.len() != result.len() {
16 return Err(InvalidSliceLength);
17 }
18 result.copy_from_slice(bytes);
19 Ok(Sha256(result))
20 }
21}
22
23impl fmt::Debug for Sha256 {
24 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
25 for &b in &self.0 {
26 write!(f, "{:02x}", b)?;
27 }
28 Ok(())
29 }
30}
31
32impl fmt::Display for Sha256 {
33 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
34 fmt::Debug::fmt(self, f)
35 }
36}
37
/// Error returned when parsing a [`Sha256`] from a string fails.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum Sha256FromStrError {
    /// The input was not exactly 64 characters long; carries the actual
    /// character count.
    InvalidLength(usize),
    /// The input contained a character that is not a hex digit.
    NonHexChar,
}

impl fmt::Display for Sha256FromStrError {
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        match *self {
            Sha256FromStrError::InvalidLength(len) => {
                write!(f, "invalid length {}, must be 64", len)
            }
            // `pad` is what `str`'s own `Display` impl uses, so width and
            // precision flags from the caller are still honoured here.
            Sha256FromStrError::NonHexChar => f.pad("non-hex character"),
        }
    }
}

impl Error for Sha256FromStrError {}
55
56impl FromStr for Sha256 {
57 type Err = Sha256FromStrError;
58 fn from_str(v: &str) -> Result<Sha256, Sha256FromStrError> {
59 let len = v.chars().count();
60 if len != 64 {
61 return Err(Sha256FromStrError::InvalidLength(len));
62 }
63 let mut result = [0; 32];
64 let starts = v
67 .char_indices()
68 .map(|(i, _)| i)
69 .chain(iter::once(v.len()))
70 .step_by(2);
71 let ends = {
72 let mut e = starts.clone();
73 e.next();
74 e
75 };
76 for (i, (s, e)) in starts.zip(ends).enumerate() {
77 result[i] =
78 u8::from_str_radix(&v[s..e], 16).map_err(|_| Sha256FromStrError::NonHexChar)?;
79 }
80 Ok(Sha256(result))
81 }
82}
83
#[cfg(feature = "serde")]
mod serialize {
    //! `serde` support: a `Sha256` is serialized as its `Display` (hex) string
    //! and deserialized from a string via its `FromStr` impl.
    use std::fmt;

    use super::Sha256;

    impl serde::Serialize for Sha256 {
        fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
        where
            S: serde::Serializer,
        {
            // `collect_str` feeds the `Display` output to the serializer
            // without forcing an intermediate `String` allocation here.
            serializer.collect_str(self)
        }
    }

    /// Visitor that turns a string into a `Sha256` using `str::parse`.
    struct HexSha256Visitor;

    impl<'de> serde::de::Visitor<'de> for HexSha256Visitor {
        type Value = Sha256;

        fn expecting(&self, f: &mut fmt::Formatter) -> fmt::Result {
            f.write_str("64 character hex value")
        }

        fn visit_str<E: serde::de::Error>(self, v: &str) -> Result<Sha256, E> {
            use super::Sha256FromStrError::*;
            // Translate parse errors into serde's standard error constructors,
            // passing the visitor itself as the `Expected` description.
            v.parse().map_err(|e| match e {
                InvalidLength(len) => E::invalid_length(len, &self),
                NonHexChar => E::invalid_value(serde::de::Unexpected::Str(v), &self),
            })
        }
    }

    impl<'de> serde::Deserialize<'de> for Sha256 {
        fn deserialize<D>(deserializer: D) -> Result<Sha256, D::Error>
        where
            D: serde::de::Deserializer<'de>,
        {
            deserializer.deserialize_str(HexSha256Visitor)
        }
    }
}