decimal_rs/
serde.rs

1// Copyright 2021 CoD Technologies Corp.
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7// http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15//! serde implementation.
16
17use crate::decimal::{Buf, Decimal};
18
19#[cfg_attr(docsrs, doc(cfg(feature = "serde")))]
20impl serde::Serialize for Decimal {
21    #[inline]
22    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
23    where
24        S: serde::ser::Serializer,
25    {
26        use std::io::Write;
27
28        let mut buf = Buf::new();
29        if serializer.is_human_readable() {
30            write!(&mut buf, "{}", self).map_err(serde::ser::Error::custom)?;
31            let str = unsafe { std::str::from_utf8_unchecked(buf.as_slice()) };
32            str.serialize(serializer)
33        } else {
34            self.encode(&mut buf).map_err(serde::ser::Error::custom)?;
35            buf.as_slice().serialize(serializer)
36        }
37    }
38}
39
40#[cfg_attr(docsrs, doc(cfg(feature = "serde")))]
41impl<'de> serde::Deserialize<'de> for Decimal {
42    #[inline]
43    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
44    where
45        D: serde::de::Deserializer<'de>,
46    {
47        struct DecimalVisitor;
48
49        impl<'de> serde::de::Visitor<'de> for DecimalVisitor {
50            type Value = Decimal;
51
52            #[inline]
53            fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
54                write!(formatter, "a decimal")
55            }
56
57            #[inline]
58            fn visit_str<E>(self, v: &str) -> Result<Decimal, E>
59            where
60                E: serde::de::Error,
61            {
62                v.parse().map_err(serde::de::Error::custom)
63            }
64
65            #[inline]
66            fn visit_bytes<E>(self, v: &[u8]) -> Result<Decimal, E>
67            where
68                E: serde::de::Error,
69            {
70                let n = Decimal::decode(v);
71                Ok(n)
72            }
73        }
74
75        if deserializer.is_human_readable() {
76            deserializer.deserialize_str(DecimalVisitor)
77        } else {
78            deserializer.deserialize_bytes(DecimalVisitor)
79        }
80    }
81}
82
#[cfg(test)]
mod tests {
    use super::*;

    /// Round-trips a decimal through a human-readable format (JSON) and a
    /// compact binary format (bincode), checking that each yields the original value.
    #[test]
    fn test_serde() {
        let expected: Decimal = "123.456".parse().unwrap();

        // Human-readable path: the decimal travels as its string form.
        let text = serde_json::to_string(&expected).unwrap();
        assert_eq!(text, r#""123.456""#);
        let from_text = serde_json::from_str::<Decimal>(&text).unwrap();
        assert_eq!(from_text, expected);

        // Compact path: the decimal travels as its encoded bytes.
        let encoded = bincode::serialize(&expected).unwrap();
        let from_encoded = bincode::deserialize::<Decimal>(&encoded).unwrap();
        assert_eq!(from_encoded, expected);
    }
}