1use core::fmt;
2
3use schemars::JsonSchema;
4use serde::{de, ser, Deserialize, Deserializer, Serialize};
5use sha2::{Digest, Sha256};
6use thiserror::Error;
7
8use crate::prelude::*;
9use crate::{StdError, StdResult};
10
// A 32-byte checksum.
//
// `Checksum::generate` produces it as the SHA-256 digest of the given bytes.
// `Display`, `to_hex` and human-readable serializers render it as 64 lowercase
// hex characters; the JSON schema therefore describes it as a `String`.
//
// NOTE(review): plain `//` comments are used on purpose. `///` doc comments on a
// `JsonSchema`-derived type are copied into the generated schema description,
// which would change the emitted schema — switch to `///` only if that is wanted.
#[derive(JsonSchema, Debug, Copy, Clone, PartialEq, Eq, Hash)]
pub struct Checksum(#[schemars(with = "String")] [u8; 32]);
18
19impl Checksum {
20 pub fn generate(wasm: &[u8]) -> Self {
21 Checksum(Sha256::digest(wasm).into())
22 }
23
24 pub fn from_hex(input: &str) -> StdResult<Self> {
27 let mut binary = [0u8; 32];
28 hex::decode_to_slice(input, &mut binary).map_err(StdError::invalid_hex)?;
29
30 Ok(Self(binary))
31 }
32
33 pub fn to_hex(self) -> String {
37 self.to_string()
38 }
39
40 pub fn as_slice(&self) -> &[u8] {
43 &self.0
44 }
45}
46
47impl fmt::Display for Checksum {
48 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
49 for byte in self.0.iter() {
50 write!(f, "{byte:02x}")?;
51 }
52 Ok(())
53 }
54}
55
56impl From<[u8; 32]> for Checksum {
57 fn from(data: [u8; 32]) -> Self {
58 Checksum(data)
59 }
60}
61
62impl AsRef<[u8; 32]> for Checksum {
63 fn as_ref(&self) -> &[u8; 32] {
64 &self.0
65 }
66}
67
68impl Serialize for Checksum {
70 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
71 where
72 S: ser::Serializer,
73 {
74 if serializer.is_human_readable() {
75 serializer.serialize_str(&self.to_hex())
76 } else {
77 serializer.serialize_bytes(&self.0)
78 }
79 }
80}
81
82impl<'de> Deserialize<'de> for Checksum {
84 fn deserialize<D>(deserializer: D) -> Result<Checksum, D::Error>
85 where
86 D: Deserializer<'de>,
87 {
88 if deserializer.is_human_readable() {
89 deserializer.deserialize_str(ChecksumVisitor)
90 } else {
91 deserializer.deserialize_bytes(ChecksumBytesVisitor)
92 }
93 }
94}
95
96struct ChecksumVisitor;
97
98impl de::Visitor<'_> for ChecksumVisitor {
99 type Value = Checksum;
100
101 fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
102 formatter.write_str("valid hex encoded 32 byte checksum")
103 }
104
105 fn visit_str<E>(self, v: &str) -> Result<Self::Value, E>
106 where
107 E: de::Error,
108 {
109 match Checksum::from_hex(v) {
110 Ok(data) => Ok(data),
111 Err(_) => Err(E::custom(format!("invalid checksum: {v}"))),
112 }
113 }
114}
115
116struct ChecksumBytesVisitor;
117
118impl de::Visitor<'_> for ChecksumBytesVisitor {
119 type Value = Checksum;
120
121 fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
122 formatter.write_str("32 byte checksum")
123 }
124
125 fn visit_bytes<E>(self, v: &[u8]) -> Result<Self::Value, E>
126 where
127 E: de::Error,
128 {
129 Checksum::try_from(v).map_err(|ChecksumError| E::invalid_length(v.len(), &"32 bytes"))
130 }
131}
132
/// Error returned when a byte slice cannot be converted into a [`Checksum`]
/// because it is not exactly 32 bytes long.
#[derive(Error, Debug)]
#[error("Checksum not of length 32")]
pub struct ChecksumError;
136
137impl TryFrom<&[u8]> for Checksum {
138 type Error = ChecksumError;
139
140 fn try_from(value: &[u8]) -> Result<Self, Self::Error> {
141 if value.len() != 32 {
142 return Err(ChecksumError);
143 }
144 let mut data = [0u8; 32];
145 data.copy_from_slice(value);
146 Ok(Checksum(data))
147 }
148}
149
150impl From<Checksum> for Vec<u8> {
151 fn from(original: Checksum) -> Vec<u8> {
152 original.0.into()
153 }
154}
155
#[cfg(test)]
mod tests {
    //! Unit tests for [`Checksum`]. Most fixtures use the SHA-256 digest of the
    //! three bytes `hij` (`0x68 0x69 0x6a`).

    use super::*;

    use crate::to_json_string;

    #[test]
    fn generate_works() {
        let wasm = vec![0x68, 0x69, 0x6a];
        let checksum = Checksum::generate(&wasm);

        // SHA-256 digest of "hij"
        let expected = [
            0x72, 0x2c, 0x8c, 0x99, 0x3f, 0xd7, 0x5a, 0x76, 0x27, 0xd6, 0x9e, 0xd9, 0x41, 0x34,
            0x4f, 0xe2, 0xa1, 0x42, 0x3a, 0x3e, 0x75, 0xef, 0xd3, 0xe6, 0x77, 0x8a, 0x14, 0x28,
            0x84, 0x22, 0x71, 0x04,
        ];
        assert_eq!(checksum.0, expected);
    }

    #[test]
    fn implemented_display() {
        let wasm = vec![0x68, 0x69, 0x6a];
        let checksum = Checksum::generate(&wasm);
        let embedded = format!("Check: {checksum}");
        assert_eq!(
            embedded,
            "Check: 722c8c993fd75a7627d69ed941344fe2a1423a3e75efd3e6778a142884227104"
        );
        assert_eq!(
            checksum.to_string(),
            "722c8c993fd75a7627d69ed941344fe2a1423a3e75efd3e6778a142884227104"
        );
    }

    #[test]
    fn from_hex_works() {
        // Round trip: parse, compare with the generated value, re-encode.
        let checksum = "722c8c993fd75a7627d69ed941344fe2a1423a3e75efd3e6778a142884227104";
        let parsed = Checksum::from_hex(checksum).unwrap();
        assert_eq!(parsed, Checksum::generate(b"hij"));
        assert_eq!(parsed.to_hex(), checksum);

        // One byte short
        let too_short = "722c8c993fd75a7627d69ed941344fe2a1423a3e75efd3e6778a1428842271";
        assert!(Checksum::from_hex(too_short).is_err());
        // Correct length but contains a non-hex character
        let invalid_char = "722c8c993fd75a7627d69ed941344fe2a1423a3e75efd3e6778a1428842271g4";
        assert!(Checksum::from_hex(invalid_char).is_err());
        // One byte too long
        let too_long = "722c8c993fd75a7627d69ed941344fe2a1423a3e75efd3e6778a14288422710400";
        assert!(Checksum::from_hex(too_long).is_err());
    }

    #[test]
    fn to_hex_works() {
        let wasm = vec![0x68, 0x69, 0x6a];
        let checksum = Checksum::generate(&wasm);
        assert_eq!(
            checksum.to_hex(),
            "722c8c993fd75a7627d69ed941344fe2a1423a3e75efd3e6778a142884227104"
        );
    }

    #[test]
    fn into_vec_works() {
        let checksum = Checksum::generate(&[12u8; 17]);
        let as_vec: Vec<u8> = checksum.into();
        assert_eq!(as_vec, checksum.0);
    }

    #[test]
    fn ref_conversions_work() {
        let checksum = Checksum::generate(&[12u8; 17]);

        // as_ref
        let _: &[u8; 32] = checksum.as_ref();
        let _: &[u8] = checksum.as_ref();

        // as_slice (previously this section repeated the `as_ref` checks, so
        // `as_slice` was never exercised)
        let slice: &[u8] = checksum.as_slice();
        assert_eq!(slice, &checksum.0[..]);
    }

    #[test]
    fn try_from_slice_works() {
        let bytes = [0xabu8; 32];
        let checksum = Checksum::try_from(&bytes[..]).unwrap();
        assert_eq!(checksum, Checksum::from(bytes));

        // Any length other than 32 is rejected
        let empty: &[u8] = &[];
        assert!(Checksum::try_from(empty).is_err());
        assert!(Checksum::try_from(&bytes[..31]).is_err());
        assert!(Checksum::try_from(&[0u8; 33][..]).is_err());
    }

    #[test]
    fn serde_works() {
        // JSON is human-readable, so the checksum is encoded as a hex string.
        let checksum =
            Checksum::from_hex("722c8c993fd75a7627d69ed941344fe2a1423a3e75efd3e6778a142884227104")
                .unwrap();

        let serialized = to_json_string(&checksum).unwrap();
        assert_eq!(
            serialized,
            "\"722c8c993fd75a7627d69ed941344fe2a1423a3e75efd3e6778a142884227104\""
        );

        let deserialized: Checksum = serde_json::from_str(&serialized).unwrap();
        assert_eq!(deserialized, checksum);
    }

    #[test]
    fn msgpack_works() {
        // MessagePack is not human-readable, so the raw bytes are written.
        let checksum =
            Checksum::from_hex("722c8c993fd75a7627d69ed941344fe2a1423a3e75efd3e6778a142884227104")
                .unwrap();

        let serialized = rmp_serde::to_vec(&checksum).unwrap();
        // 0xc4 0x20: "bin 8" header announcing 32 bytes, followed by the bytes
        let expected = vec![
            0xc4, 0x20, 0x72, 0x2c, 0x8c, 0x99, 0x3f, 0xd7, 0x5a, 0x76, 0x27, 0xd6, 0x9e, 0xd9,
            0x41, 0x34, 0x4f, 0xe2, 0xa1, 0x42, 0x3a, 0x3e, 0x75, 0xef, 0xd3, 0xe6, 0x77, 0x8a,
            0x14, 0x28, 0x84, 0x22, 0x71, 0x04,
        ];
        assert_eq!(serialized, expected);

        let deserialized: Checksum = rmp_serde::from_slice(&serialized).unwrap();
        assert_eq!(deserialized, checksum);
    }
}