bytesize/
serde.rs

1use alloc::string::{String, ToString as _};
2use core::fmt;
3
4use serde_core::{de, Deserialize, Deserializer, Serialize, Serializer};
5
6use crate::ByteSize;
7
8impl<'de> Deserialize<'de> for ByteSize {
9    fn deserialize<D>(de: D) -> Result<Self, D::Error>
10    where
11        D: Deserializer<'de>,
12    {
13        struct ByteSizeVisitor;
14
15        impl de::Visitor<'_> for ByteSizeVisitor {
16            type Value = ByteSize;
17
18            fn expecting(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
19                formatter.write_str("an integer or string")
20            }
21
22            fn visit_i64<E: de::Error>(self, value: i64) -> Result<Self::Value, E> {
23                if let Ok(val) = u64::try_from(value) {
24                    Ok(ByteSize(val))
25                } else {
26                    Err(E::invalid_value(
27                        de::Unexpected::Signed(value),
28                        &"integer overflow",
29                    ))
30                }
31            }
32
33            fn visit_u64<E: de::Error>(self, value: u64) -> Result<Self::Value, E> {
34                Ok(ByteSize(value))
35            }
36
37            fn visit_str<E: de::Error>(self, value: &str) -> Result<Self::Value, E> {
38                if let Ok(val) = value.parse() {
39                    Ok(val)
40                } else {
41                    Err(E::invalid_value(
42                        de::Unexpected::Str(value),
43                        &"parsable string",
44                    ))
45                }
46            }
47        }
48
49        if de.is_human_readable() {
50            de.deserialize_any(ByteSizeVisitor)
51        } else {
52            de.deserialize_u64(ByteSizeVisitor)
53        }
54    }
55}
56
57impl Serialize for ByteSize {
58    fn serialize<S>(&self, ser: S) -> Result<S::Ok, S::Error>
59    where
60        S: Serializer,
61    {
62        if ser.is_human_readable() {
63            <String>::serialize(&self.to_string(), ser)
64        } else {
65            self.0.serialize(ser)
66        }
67    }
68}
69
#[cfg(test)]
mod tests {
    use super::*;

    use serde::{Deserialize, Serialize};

    #[test]
    fn test_serde() {
        #[derive(Serialize, Deserialize)]
        struct S {
            x: ByteSize,
        }

        let s = serde_json::from_str::<S>(r#"{ "x": "5 B" }"#).unwrap();
        assert_eq!(s.x, ByteSize(5));

        let s = serde_json::from_str::<S>(r#"{ "x": 1048576 }"#).unwrap();
        assert_eq!(s.x, "1 MiB".parse::<ByteSize>().unwrap());

        let s = toml::from_str::<S>(r#"x = "2.5 MiB""#).unwrap();
        assert_eq!(s.x, "2.5 MiB".parse::<ByteSize>().unwrap());

        // i64 MAX, given as a string (exercises the string-parsing path)
        let s = toml::from_str::<S>(r#"x = "9223372036854775807""#).unwrap();
        assert_eq!(s.x, "9223372036854775807".parse::<ByteSize>().unwrap());
    }

    #[test]
    fn test_serde_integers() {
        #[derive(Deserialize)]
        struct S {
            x: ByteSize,
        }

        // i64 MAX as a bare TOML integer, which is handed to the visitor as a signed value.
        let s = toml::from_str::<S>("x = 9223372036854775807").unwrap();
        assert_eq!(s.x, ByteSize(9_223_372_036_854_775_807));

        // Negative byte counts must be rejected, not wrapped into huge sizes.
        assert!(serde_json::from_str::<ByteSize>("-1").is_err());
        assert!(toml::from_str::<S>("x = -1").is_err());
    }

    #[test]
    fn test_serde_json() {
        let json = serde_json::to_string(&ByteSize::mib(1)).unwrap();
        assert_eq!(json, "\"1.0 MiB\"");

        let deserialized = serde_json::from_str::<ByteSize>(&json).unwrap();
        assert_eq!(deserialized.0, 1048576);
    }
}
105}