1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
#![doc = include_str!("../README.md")]

use std::fmt;

use bytesize::ByteSize;
use serde::{de, Serialize, Serializer};

pub fn serialize<S>(size: &ByteSize, serializer: S) -> Result<S::Ok, S::Error>
where
    S: Serializer,
{
    if serializer.is_human_readable() {
        <str>::serialize(size.to_string().as_str(), serializer)
    } else {
        size.0.serialize(serializer)
    }
}

pub fn deserialize<'de, D>(deserializer: D) -> Result<ByteSize, D::Error>
where
    D: de::Deserializer<'de>,
{
    struct Helper;
    impl<'de> de::Visitor<'de> for Helper {
        type Value = ByteSize;

        fn expecting(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
            formatter.write_str("an integer or string")
        }

        fn visit_u64<E: de::Error>(self, value: u64) -> Result<Self::Value, E> {
            Ok(ByteSize(value))
        }

        fn visit_str<E: de::Error>(self, value: &str) -> Result<Self::Value, E> {
            if let Ok(val) = value.parse() {
                Ok(val)
            } else {
                Err(E::invalid_value(
                    de::Unexpected::Str(value),
                    &"parsable string",
                ))
            }
        }
    }

    deserializer.deserialize_any(Helper)
}

#[cfg(test)]
mod tests {
    use super::*;

    use quickcheck_macros::quickcheck;
    use serde::Deserialize;

    /// Wrapper that routes (de)serialization through this module's
    /// `serialize` / `deserialize` functions.
    #[derive(Serialize, Deserialize)]
    #[serde(transparent)]
    struct W(#[serde(with = "self")] ByteSize);

    #[quickcheck]
    fn deserializes_any(x: u64) {
        let w: W = serde_json::from_str(&x.to_string()).unwrap();
        assert_eq!(w.0, ByteSize(x));
    }

    #[quickcheck]
    fn serializes_any(x: u64) {
        // JSON is human-readable, so the wrapper must emit the Display string.
        assert_eq!(
            serde_json::to_string(&W(ByteSize(x))).unwrap(),
            serde_json::to_string(&ByteSize(x).to_string()).unwrap()
        );
    }

    #[test]
    fn rejects_invalid_input() {
        assert!(serde_json::from_str::<W>("\"not a size\"").is_err());
        assert!(serde_json::from_str::<W>("-1").is_err());
        assert!(serde_json::from_str::<W>("true").is_err());
    }

    #[test]
    fn deserialize_sizes() {
        /// Checks a JSON string literal against `ByteSize`'s own `FromStr`.
        #[track_caller]
        fn check_str(s: &str) {
            assert_eq!(
                serde_json::from_str::<W>(&format!("{:?}", s)).unwrap().0,
                s.parse().unwrap()
            );
        }

        /// Checks a bare JSON number against `ByteSize`'s own `FromStr`.
        #[track_caller]
        fn check(s: &str) {
            assert_eq!(serde_json::from_str::<W>(s).unwrap().0, s.parse().unwrap());
        }

        check_str("5 MB");
        check_str("12.34 KB");
        check("123");
        check("0");
    }
}