datasynth_core/
serde_decimal.rs1use std::cell::Cell;
17use std::fmt;
18
19use rust_decimal::prelude::ToPrimitive;
20use rust_decimal::Decimal;
21use serde::{self, Deserializer, Serializer};
22
thread_local! {
    /// Per-thread switch consulted by this module's serializers.
    ///
    /// `true` means decimals are written as native numbers; `false` (the
    /// initial value on every thread) means they are written as strings.
    static NUMERIC_NATIVE: Cell<bool> = const { Cell::new(false) };
}
26
27pub fn set_numeric_native(native: bool) {
32 NUMERIC_NATIVE.with(|c| c.set(native));
33}
34
35pub fn is_numeric_native() -> bool {
37 NUMERIC_NATIVE.with(|c| c.get())
38}
39
40pub fn serialize<S: Serializer>(value: &Decimal, serializer: S) -> Result<S::Ok, S::Error> {
42 if is_numeric_native() {
43 match value.to_f64() {
44 Some(f) => serializer.serialize_f64(f),
45 None => serializer.serialize_str(&value.to_string()),
46 }
47 } else {
48 rust_decimal::serde::str::serialize(value, serializer)
49 }
50}
51
52pub fn deserialize<'de, D: Deserializer<'de>>(deserializer: D) -> Result<Decimal, D::Error> {
54 deserializer.deserialize_any(DecimalVisitor)
55}
56
57pub mod option {
59 use rust_decimal::prelude::ToPrimitive;
60 use rust_decimal::Decimal;
61 use serde::{Deserializer, Serializer};
62
63 use super::{is_numeric_native, OptionDecimalVisitor};
64
65 pub fn serialize<S: Serializer>(
67 value: &Option<Decimal>,
68 serializer: S,
69 ) -> Result<S::Ok, S::Error> {
70 match value {
71 Some(d) => {
72 if is_numeric_native() {
73 match d.to_f64() {
74 Some(f) => serializer.serialize_f64(f),
75 None => serializer.serialize_str(&d.to_string()),
76 }
77 } else {
78 rust_decimal::serde::str_option::serialize(value, serializer)
79 }
80 }
81 None => serializer.serialize_none(),
82 }
83 }
84
85 pub fn deserialize<'de, D: Deserializer<'de>>(
87 deserializer: D,
88 ) -> Result<Option<Decimal>, D::Error> {
89 deserializer.deserialize_any(OptionDecimalVisitor)
90 }
91}
92
93struct DecimalVisitor;
96
97impl<'de> serde::de::Visitor<'de> for DecimalVisitor {
98 type Value = Decimal;
99
100 fn expecting(&self, f: &mut fmt::Formatter) -> fmt::Result {
101 write!(f, "a decimal as a string or number")
102 }
103
104 fn visit_str<E: serde::de::Error>(self, v: &str) -> Result<Decimal, E> {
105 v.parse::<Decimal>().map_err(E::custom)
106 }
107
108 fn visit_f64<E: serde::de::Error>(self, v: f64) -> Result<Decimal, E> {
109 Decimal::try_from(v).map_err(E::custom)
110 }
111
112 fn visit_i64<E: serde::de::Error>(self, v: i64) -> Result<Decimal, E> {
113 Ok(Decimal::from(v))
114 }
115
116 fn visit_u64<E: serde::de::Error>(self, v: u64) -> Result<Decimal, E> {
117 Ok(Decimal::from(v))
118 }
119}
120
121struct OptionDecimalVisitor;
122
123impl<'de> serde::de::Visitor<'de> for OptionDecimalVisitor {
124 type Value = Option<Decimal>;
125
126 fn expecting(&self, f: &mut fmt::Formatter) -> fmt::Result {
127 write!(f, "a decimal as a string or number, or null")
128 }
129
130 fn visit_none<E: serde::de::Error>(self) -> Result<Option<Decimal>, E> {
131 Ok(None)
132 }
133
134 fn visit_unit<E: serde::de::Error>(self) -> Result<Option<Decimal>, E> {
135 Ok(None)
136 }
137
138 fn visit_some<D: Deserializer<'de>>(
139 self,
140 deserializer: D,
141 ) -> Result<Option<Decimal>, D::Error> {
142 deserializer.deserialize_any(DecimalVisitor).map(Some)
143 }
144
145 fn visit_str<E: serde::de::Error>(self, v: &str) -> Result<Option<Decimal>, E> {
146 v.parse::<Decimal>().map(Some).map_err(E::custom)
147 }
148
149 fn visit_f64<E: serde::de::Error>(self, v: f64) -> Result<Option<Decimal>, E> {
150 Decimal::try_from(v).map(Some).map_err(E::custom)
151 }
152
153 fn visit_i64<E: serde::de::Error>(self, v: i64) -> Result<Option<Decimal>, E> {
154 Ok(Some(Decimal::from(v)))
155 }
156
157 fn visit_u64<E: serde::de::Error>(self, v: u64) -> Result<Option<Decimal>, E> {
158 Ok(Some(Decimal::from(v)))
159 }
160}
161
#[cfg(test)]
mod tests {
    use super::*;
    use rust_decimal_macros::dec;

    /// Fixture with one required and one optional decimal field, each wired to
    /// the adapter under test.
    #[derive(serde::Serialize, serde::Deserialize, Debug, PartialEq)]
    struct TestStruct {
        #[serde(with = "super")]
        amount: Decimal,
        #[serde(default, with = "super::option")]
        tax: Option<Decimal>,
    }

    /// Enables native mode on the current thread and switches it back off when
    /// dropped, so a failing assertion cannot leave the thread-local flag set.
    struct NativeModeGuard;

    impl NativeModeGuard {
        fn enable() -> Self {
            set_numeric_native(true);
            NativeModeGuard
        }
    }

    impl Drop for NativeModeGuard {
        fn drop(&mut self) {
            set_numeric_native(false);
        }
    }

    /// Builds the value shared by the serialization tests.
    fn sample() -> TestStruct {
        TestStruct {
            amount: dec!(1729237.30),
            tax: Some(dec!(99.95)),
        }
    }

    #[test]
    fn test_string_mode() {
        set_numeric_native(false);
        let json = serde_json::to_string(&sample()).unwrap();
        assert!(json.contains("\"1729237.30\""), "expected string: {json}");
        assert!(json.contains("\"99.95\""), "expected string: {json}");
    }

    #[test]
    fn test_native_mode() {
        // The guard resets the flag on every exit path, including a panic.
        let _native = NativeModeGuard::enable();
        let json = serde_json::to_string(&sample()).unwrap();
        assert!(
            json.contains(":1729237.3") || json.contains(":1729237.30"),
            "expected number: {json}"
        );
    }

    #[test]
    fn test_deserialize_from_string() {
        let json = r#"{"amount":"1729237.30","tax":"99.95"}"#;
        let s: TestStruct = serde_json::from_str(json).unwrap();
        assert_eq!(s.amount, dec!(1729237.30));
        assert_eq!(s.tax, Some(dec!(99.95)));
    }

    #[test]
    fn test_deserialize_from_number() {
        let json = r#"{"amount":1729237.30,"tax":99.95}"#;
        let s: TestStruct = serde_json::from_str(json).unwrap();
        assert_eq!(s.amount, dec!(1729237.3));
        assert_eq!(s.tax, Some(dec!(99.95)));
    }

    #[test]
    fn test_deserialize_null_option() {
        let json = r#"{"amount":"100.00","tax":null}"#;
        let s: TestStruct = serde_json::from_str(json).unwrap();
        assert_eq!(s.tax, None);
    }

    #[test]
    fn test_deserialize_missing_option() {
        let json = r#"{"amount":"100.00"}"#;
        let s: TestStruct = serde_json::from_str(json).unwrap();
        assert_eq!(s.tax, None);
    }
}