1use core::fmt;
2
3use schemars::JsonSchema;
4use serde::{de, ser, Deserialize, Deserializer, Serialize};
5use sha2::{Digest, Sha256};
6use thiserror::Error;
7
8use crate::errors::ErrorKind;
9use crate::prelude::*;
10use crate::{StdError, StdResult};
11
12#[derive(JsonSchema, Debug, Copy, Clone, PartialEq, Eq, Hash, cw_schema::Schemaifier)]
18#[schemaifier(type = cw_schema::NodeType::Checksum)]
19pub struct Checksum(#[schemars(with = "String")] [u8; 32]);
20
21impl Checksum {
22    pub fn generate(wasm: &[u8]) -> Self {
23        Checksum(Sha256::digest(wasm).into())
24    }
25
26    pub fn from_hex(input: &str) -> StdResult<Self> {
29        let mut binary = [0u8; 32];
30        hex::decode_to_slice(input, &mut binary)
31            .map_err(|err| StdError::from(err).with_kind(ErrorKind::Parsing))?;
32
33        Ok(Self(binary))
34    }
35
36    pub fn to_hex(self) -> String {
40        self.to_string()
41    }
42
43    pub fn as_slice(&self) -> &[u8] {
46        &self.0
47    }
48}
49
50impl fmt::Display for Checksum {
51    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
52        for byte in self.0.iter() {
53            write!(f, "{byte:02x}")?;
54        }
55        Ok(())
56    }
57}
58
59impl From<[u8; 32]> for Checksum {
60    fn from(data: [u8; 32]) -> Self {
61        Checksum(data)
62    }
63}
64
65impl AsRef<[u8; 32]> for Checksum {
66    fn as_ref(&self) -> &[u8; 32] {
67        &self.0
68    }
69}
70
71impl Serialize for Checksum {
73    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
74    where
75        S: ser::Serializer,
76    {
77        if serializer.is_human_readable() {
78            serializer.serialize_str(&self.to_hex())
79        } else {
80            serializer.serialize_bytes(&self.0)
81        }
82    }
83}
84
85impl<'de> Deserialize<'de> for Checksum {
87    fn deserialize<D>(deserializer: D) -> Result<Checksum, D::Error>
88    where
89        D: Deserializer<'de>,
90    {
91        if deserializer.is_human_readable() {
92            deserializer.deserialize_str(ChecksumVisitor)
93        } else {
94            deserializer.deserialize_bytes(ChecksumBytesVisitor)
95        }
96    }
97}
98
99struct ChecksumVisitor;
100
101impl de::Visitor<'_> for ChecksumVisitor {
102    type Value = Checksum;
103
104    fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
105        formatter.write_str("valid hex encoded 32 byte checksum")
106    }
107
108    fn visit_str<E>(self, v: &str) -> Result<Self::Value, E>
109    where
110        E: de::Error,
111    {
112        match Checksum::from_hex(v) {
113            Ok(data) => Ok(data),
114            Err(_) => Err(E::custom(format!("invalid checksum: {v}"))),
115        }
116    }
117}
118
119struct ChecksumBytesVisitor;
120
121impl de::Visitor<'_> for ChecksumBytesVisitor {
122    type Value = Checksum;
123
124    fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
125        formatter.write_str("32 byte checksum")
126    }
127
128    fn visit_bytes<E>(self, v: &[u8]) -> Result<Self::Value, E>
129    where
130        E: de::Error,
131    {
132        Checksum::try_from(v).map_err(|ChecksumError| E::invalid_length(v.len(), &"32 bytes"))
133    }
134}
135
136#[derive(Error, Debug)]
137#[error("Checksum not of length 32")]
138pub struct ChecksumError;
139
140impl TryFrom<&[u8]> for Checksum {
141    type Error = ChecksumError;
142
143    fn try_from(value: &[u8]) -> Result<Self, Self::Error> {
144        if value.len() != 32 {
145            return Err(ChecksumError);
146        }
147        let mut data = [0u8; 32];
148        data.copy_from_slice(value);
149        Ok(Checksum(data))
150    }
151}
152
153impl From<Checksum> for Vec<u8> {
154    fn from(original: Checksum) -> Vec<u8> {
155        original.0.into()
156    }
157}
158
#[cfg(test)]
mod tests {
    use super::*;

    use crate::to_json_string;

    #[test]
    fn generate_works() {
        // b"hij" is the byte sequence [0x68, 0x69, 0x6a]
        let checksum = Checksum::generate(b"hij");

        // SHA-256 of b"hij"
        let expected = [
            0x72, 0x2c, 0x8c, 0x99, 0x3f, 0xd7, 0x5a, 0x76, 0x27, 0xd6, 0x9e, 0xd9, 0x41, 0x34,
            0x4f, 0xe2, 0xa1, 0x42, 0x3a, 0x3e, 0x75, 0xef, 0xd3, 0xe6, 0x77, 0x8a, 0x14, 0x28,
            0x84, 0x22, 0x71, 0x04,
        ];
        assert_eq!(checksum.0, expected);
    }

    #[test]
    fn implemented_display() {
        let checksum = Checksum::generate(b"hij");
        let embedded = format!("Check: {checksum}");
        assert_eq!(
            embedded,
            "Check: 722c8c993fd75a7627d69ed941344fe2a1423a3e75efd3e6778a142884227104"
        );
        assert_eq!(
            checksum.to_string(),
            "722c8c993fd75a7627d69ed941344fe2a1423a3e75efd3e6778a142884227104"
        );
    }

    #[test]
    fn from_hex_works() {
        let checksum = "722c8c993fd75a7627d69ed941344fe2a1423a3e75efd3e6778a142884227104";
        let parsed = Checksum::from_hex(checksum).unwrap();
        assert_eq!(parsed, Checksum::generate(b"hij"));
        // Round-trips back to the same lowercase string
        assert_eq!(parsed.to_hex(), checksum);

        // 31 bytes
        let too_short = "722c8c993fd75a7627d69ed941344fe2a1423a3e75efd3e6778a1428842271";
        assert!(Checksum::from_hex(too_short).is_err());
        // 'g' is not a hex digit
        let invalid_char = "722c8c993fd75a7627d69ed941344fe2a1423a3e75efd3e6778a1428842271g4";
        assert!(Checksum::from_hex(invalid_char).is_err());
        // 33 bytes
        let too_long = "722c8c993fd75a7627d69ed941344fe2a1423a3e75efd3e6778a14288422710400";
        assert!(Checksum::from_hex(too_long).is_err());
    }

    #[test]
    fn to_hex_works() {
        let checksum = Checksum::generate(b"hij");
        assert_eq!(
            checksum.to_hex(),
            "722c8c993fd75a7627d69ed941344fe2a1423a3e75efd3e6778a142884227104"
        );
    }

    #[test]
    fn into_vec_works() {
        let checksum = Checksum::generate(&[12u8; 17]);
        let as_vec: Vec<u8> = checksum.into();
        assert_eq!(as_vec, checksum.0);
    }

    #[test]
    fn ref_conversions_work() {
        let checksum = Checksum::generate(&[12u8; 17]);

        // AsRef<[u8; 32]>, both directly and coerced to a slice
        let as_array: &[u8; 32] = checksum.as_ref();
        assert_eq!(as_array, &checksum.0);
        let as_ref_slice: &[u8] = checksum.as_ref();
        assert_eq!(as_ref_slice, &checksum.0[..]);

        // as_slice
        let slice: &[u8] = checksum.as_slice();
        assert_eq!(slice, &checksum.0[..]);
        assert_eq!(slice.len(), 32);
    }

    #[test]
    fn try_from_slice_works() {
        let bytes = [0xabu8; 32];
        let checksum = Checksum::try_from(&bytes[..]).unwrap();
        assert_eq!(checksum, Checksum::from(bytes));

        // Any length other than 32 is rejected
        assert!(matches!(
            Checksum::try_from(&bytes[..31]),
            Err(ChecksumError)
        ));
        assert!(matches!(
            Checksum::try_from(&[0u8; 33][..]),
            Err(ChecksumError)
        ));
        let empty: &[u8] = &[];
        assert!(matches!(Checksum::try_from(empty), Err(ChecksumError)));
    }

    #[test]
    fn serde_works() {
        let checksum =
            Checksum::from_hex("722c8c993fd75a7627d69ed941344fe2a1423a3e75efd3e6778a142884227104")
                .unwrap();

        let serialized = to_json_string(&checksum).unwrap();
        assert_eq!(
            serialized,
            "\"722c8c993fd75a7627d69ed941344fe2a1423a3e75efd3e6778a142884227104\""
        );

        let deserialized: Checksum = serde_json::from_str(&serialized).unwrap();
        assert_eq!(deserialized, checksum);
    }

    #[test]
    fn serde_rejects_invalid_input() {
        // Wrong length: the visitor reports the offending input
        let err = serde_json::from_str::<Checksum>("\"722c8c\"").unwrap_err();
        assert!(err.to_string().contains("invalid checksum: 722c8c"));

        // Non-hex characters
        assert!(serde_json::from_str::<Checksum>(
            "\"zz2c8c993fd75a7627d69ed941344fe2a1423a3e75efd3e6778a142884227104\""
        )
        .is_err());

        // Not a string at all
        assert!(serde_json::from_str::<Checksum>("42").is_err());
    }

    #[test]
    fn msgpack_works() {
        let checksum =
            Checksum::from_hex("722c8c993fd75a7627d69ed941344fe2a1423a3e75efd3e6778a142884227104")
                .unwrap();

        let serialized = rmp_serde::to_vec(&checksum).unwrap();
        // bin8 header (0xc4) with length 32 (0x20), followed by the raw digest
        let expected = vec![
            0xc4, 0x20, 0x72, 0x2c, 0x8c, 0x99, 0x3f, 0xd7, 0x5a, 0x76, 0x27, 0xd6, 0x9e, 0xd9,
            0x41, 0x34, 0x4f, 0xe2, 0xa1, 0x42, 0x3a, 0x3e, 0x75, 0xef, 0xd3, 0xe6, 0x77, 0x8a,
            0x14, 0x28, 0x84, 0x22, 0x71, 0x04,
        ];
        assert_eq!(serialized, expected);

        let deserialized: Checksum = rmp_serde::from_slice(&serialized).unwrap();
        assert_eq!(deserialized, checksum);
    }

    #[test]
    fn msgpack_rejects_wrong_length() {
        // bin8 header (0xc4) announcing 31 bytes (0x1f), followed by 31 zero bytes
        let mut serialized = vec![0xc4, 0x1f];
        serialized.extend_from_slice(&[0u8; 31]);
        assert!(rmp_serde::from_slice::<Checksum>(&serialized).is_err());
    }
}