flow_value/
decimal.rs

1use rust_decimal::Decimal;
2
/// Newtype-struct name used to tag decimals: it is the name passed to
/// `serialize_newtype_struct` and `deserialize_newtype_struct` in this module.
pub(crate) const TOKEN: &str = "$$d";

/// The Rust type this module serializes and deserializes.
pub type Target = Decimal;
6
7pub mod opt {
8    pub fn serialize<S>(sig: &Option<super::Target>, s: S) -> Result<S::Ok, S::Error>
9    where
10        S: serde::Serializer,
11    {
12        match sig {
13            Some(sig) => super::serialize(sig, s),
14            None => s.serialize_none(),
15        }
16    }
17
18    pub fn deserialize<'de, D>(d: D) -> Result<Option<super::Target>, D::Error>
19    where
20        D: serde::Deserializer<'de>,
21    {
22        d.deserialize_option(crate::OptionVisitor(super::Visitor))
23    }
24}
25
26pub fn serialize<S>(d: &Decimal, s: S) -> Result<S::Ok, S::Error>
27where
28    S: serde::Serializer,
29{
30    s.serialize_newtype_struct(TOKEN, &crate::Bytes(&d.serialize()))
31}
32
33struct Visitor;
34
35impl<'de> serde::de::Visitor<'de> for Visitor {
36    type Value = Decimal;
37
38    fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
39        formatter.write_str("decimal")
40    }
41
42    fn visit_bytes<E>(self, v: &[u8]) -> Result<Self::Value, E>
43    where
44        E: serde::de::Error,
45    {
46        if v.len() != 16 {
47            return Err(serde::de::Error::invalid_length(v.len(), &"16"));
48        }
49
50        let buf: [u8; 16] = v.try_into().unwrap();
51        Ok(Decimal::deserialize(buf))
52    }
53
54    fn visit_u64<E>(self, v: u64) -> Result<Self::Value, E>
55    where
56        E: serde::de::Error,
57    {
58        Ok(Decimal::from(v))
59    }
60
61    fn visit_i64<E>(self, v: i64) -> Result<Self::Value, E>
62    where
63        E: serde::de::Error,
64    {
65        Ok(Decimal::from(v))
66    }
67
68    fn visit_f64<E>(self, v: f64) -> Result<Self::Value, E>
69    where
70        E: serde::de::Error,
71    {
72        // TODO: this is lossy
73        Decimal::try_from(v).map_err(serde::de::Error::custom)
74    }
75
76    fn visit_str<E>(self, v: &str) -> Result<Self::Value, E>
77    where
78        E: serde::de::Error,
79    {
80        let v = v.trim();
81        if v.bytes().any(|c| c == b'e' || c == b'E') {
82            Decimal::from_scientific(v).map_err(serde::de::Error::custom)
83        } else {
84            v.parse().map_err(serde::de::Error::custom)
85        }
86    }
87
88    fn visit_newtype_struct<D>(self, d: D) -> Result<Self::Value, D::Error>
89    where
90        D: serde::Deserializer<'de>,
91    {
92        d.deserialize_any(self)
93    }
94}
95
96pub fn deserialize<'de, D>(d: D) -> Result<Decimal, D::Error>
97where
98    D: serde::Deserializer<'de>,
99{
100    d.deserialize_newtype_struct(TOKEN, Visitor)
101}
102
#[cfg(test)]
mod tests {
    use super::*;
    use crate::Value;
    use rust_decimal_macros::dec;

    /// Runs `deserialize` on `d` and panics if it fails.
    fn de<'de, D: serde::Deserializer<'de>>(d: D) -> Decimal {
        deserialize(d).unwrap()
    }

    #[test]
    fn test_deserialize_value() {
        // Each pair is (input value, decimal it must deserialize to).
        let cases = [
            (Value::U64(100), dec!(100)),
            (Value::I64(-1), dec!(-1)),
            (Value::Decimal(Decimal::MAX), Decimal::MAX),
            (Value::F64(1231.2221), dec!(1231.2221)),
            (Value::String("1234.0".to_owned()), dec!(1234)),
            (Value::String("  1234.0".to_owned()), dec!(1234)),
            (Value::String("1e5".to_owned()), dec!(100000)),
            (Value::String("  1e5".to_owned()), dec!(100000)),
        ];

        for (input, expected) in cases {
            assert_eq!(de(input), expected);
        }
    }

    #[test]
    fn test_serialize() {
        let produced = serialize(&Decimal::MAX, crate::ser::Serializer).unwrap();
        assert_eq!(produced, Value::Decimal(Decimal::MAX));
    }
}
132}