// datasynth_core/serde_decimal.rs
use std::cell::Cell;
use std::fmt;

use rust_decimal::prelude::ToPrimitive;
use rust_decimal::Decimal;
use serde::{self, Deserializer, Serializer};

thread_local! {
    /// Per-thread serialization mode for this module's `Decimal` helpers.
    ///
    /// `false` (the initial value on every thread) selects string output;
    /// `true` selects native numeric (`f64`) output.
    static NUMERIC_NATIVE: Cell<bool> = const { Cell::new(false) };
}

/// Chooses how `Decimal` values are serialized on the *calling thread*.
///
/// Pass `true` for native numbers (`f64`) or `false` for strings. The setting
/// is stored in a thread-local, so other threads keep their own value.
pub fn set_numeric_native(native: bool) {
    NUMERIC_NATIVE.with(|mode| mode.set(native));
}

/// Returns `true` when the calling thread is in native-number mode.
///
/// Threads that never called [`set_numeric_native`] report `false`.
pub fn is_numeric_native() -> bool {
    NUMERIC_NATIVE.with(Cell::get)
}

40pub fn serialize<S: Serializer>(value: &Decimal, serializer: S) -> Result<S::Ok, S::Error> {
42 if is_numeric_native() {
43 match value.to_f64() {
44 Some(f) => serializer.serialize_f64(f),
45 None => serializer.serialize_str(&value.to_string()),
46 }
47 } else {
48 rust_decimal::serde::str::serialize(value, serializer)
49 }
50}
51
52pub fn deserialize<'de, D: Deserializer<'de>>(deserializer: D) -> Result<Decimal, D::Error> {
54 deserializer.deserialize_any(DecimalVisitor)
55}
56
57pub mod option {
59 use rust_decimal::prelude::ToPrimitive;
60 use rust_decimal::Decimal;
61 use serde::{Deserializer, Serializer};
62
63 use super::{is_numeric_native, OptionDecimalVisitor};
64
65 pub fn serialize<S: Serializer>(
67 value: &Option<Decimal>,
68 serializer: S,
69 ) -> Result<S::Ok, S::Error> {
70 match value {
71 Some(d) => {
72 if is_numeric_native() {
73 match d.to_f64() {
74 Some(f) => serializer.serialize_f64(f),
75 None => serializer.serialize_str(&d.to_string()),
76 }
77 } else {
78 rust_decimal::serde::str_option::serialize(value, serializer)
79 }
80 }
81 None => serializer.serialize_none(),
82 }
83 }
84
85 pub fn deserialize<'de, D: Deserializer<'de>>(
87 deserializer: D,
88 ) -> Result<Option<Decimal>, D::Error> {
89 deserializer.deserialize_any(OptionDecimalVisitor)
90 }
91}
92
93struct DecimalVisitor;
96
97impl<'de> serde::de::Visitor<'de> for DecimalVisitor {
98 type Value = Decimal;
99
100 fn expecting(&self, f: &mut fmt::Formatter) -> fmt::Result {
101 write!(f, "a decimal as a string or number")
102 }
103
104 fn visit_str<E: serde::de::Error>(self, v: &str) -> Result<Decimal, E> {
105 v.parse::<Decimal>().map_err(E::custom)
106 }
107
108 fn visit_f64<E: serde::de::Error>(self, v: f64) -> Result<Decimal, E> {
109 Decimal::try_from(v).map_err(E::custom)
110 }
111
112 fn visit_i64<E: serde::de::Error>(self, v: i64) -> Result<Decimal, E> {
113 Ok(Decimal::from(v))
114 }
115
116 fn visit_u64<E: serde::de::Error>(self, v: u64) -> Result<Decimal, E> {
117 Ok(Decimal::from(v))
118 }
119}
120
121struct OptionDecimalVisitor;
122
123impl<'de> serde::de::Visitor<'de> for OptionDecimalVisitor {
124 type Value = Option<Decimal>;
125
126 fn expecting(&self, f: &mut fmt::Formatter) -> fmt::Result {
127 write!(f, "a decimal as a string or number, or null")
128 }
129
130 fn visit_none<E: serde::de::Error>(self) -> Result<Option<Decimal>, E> {
131 Ok(None)
132 }
133
134 fn visit_unit<E: serde::de::Error>(self) -> Result<Option<Decimal>, E> {
135 Ok(None)
136 }
137
138 fn visit_some<D: Deserializer<'de>>(
139 self,
140 deserializer: D,
141 ) -> Result<Option<Decimal>, D::Error> {
142 deserializer.deserialize_any(DecimalVisitor).map(Some)
143 }
144
145 fn visit_str<E: serde::de::Error>(self, v: &str) -> Result<Option<Decimal>, E> {
146 v.parse::<Decimal>().map(Some).map_err(E::custom)
147 }
148
149 fn visit_f64<E: serde::de::Error>(self, v: f64) -> Result<Option<Decimal>, E> {
150 Decimal::try_from(v).map(Some).map_err(E::custom)
151 }
152
153 fn visit_i64<E: serde::de::Error>(self, v: i64) -> Result<Option<Decimal>, E> {
154 Ok(Some(Decimal::from(v)))
155 }
156
157 fn visit_u64<E: serde::de::Error>(self, v: u64) -> Result<Option<Decimal>, E> {
158 Ok(Some(Decimal::from(v)))
159 }
160}
161
#[cfg(test)]
// Unwrapping is fine in tests: a panic is exactly the failure signal we want.
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;
    use rust_decimal_macros::dec;

    #[derive(serde::Serialize, serde::Deserialize, Debug, PartialEq)]
    struct TestStruct {
        #[serde(with = "super")]
        amount: Decimal,
        #[serde(default, with = "super::option")]
        tax: Option<Decimal>,
    }

    /// Enables native mode on the current thread and restores string mode when
    /// dropped, so a failing assertion cannot leave the thread-local flag set.
    struct NativeModeGuard;

    impl NativeModeGuard {
        fn enable() -> Self {
            set_numeric_native(true);
            NativeModeGuard
        }
    }

    impl Drop for NativeModeGuard {
        fn drop(&mut self) {
            set_numeric_native(false);
        }
    }

    fn sample() -> TestStruct {
        TestStruct {
            amount: dec!(1729237.30),
            tax: Some(dec!(99.95)),
        }
    }

    #[test]
    fn test_string_mode() {
        set_numeric_native(false);
        let json = serde_json::to_string(&sample()).unwrap();
        assert!(json.contains("\"1729237.30\""), "expected string: {json}");
        assert!(json.contains("\"99.95\""), "expected string: {json}");
    }

    #[test]
    fn test_native_mode() {
        let _native = NativeModeGuard::enable();
        let json = serde_json::to_string(&sample()).unwrap();
        assert!(
            json.contains(":1729237.3") || json.contains(":1729237.30"),
            "expected number: {json}"
        );
        assert!(!json.contains("\"99.95\""), "expected numeric tax: {json}");
    }

    #[test]
    fn test_native_mode_round_trip() {
        let _native = NativeModeGuard::enable();
        let original = sample();
        let json = serde_json::to_string(&original).unwrap();
        let back: TestStruct = serde_json::from_str(&json).unwrap();
        assert_eq!(back, original);
    }

    #[test]
    fn test_native_mode_none_is_null() {
        let _native = NativeModeGuard::enable();
        let s = TestStruct {
            amount: dec!(100),
            tax: None,
        };
        let json = serde_json::to_string(&s).unwrap();
        assert!(json.contains("\"tax\":null"), "expected null: {json}");
    }

    #[test]
    fn test_deserialize_from_string() {
        let json = r#"{"amount":"1729237.30","tax":"99.95"}"#;
        let s: TestStruct = serde_json::from_str(json).unwrap();
        assert_eq!(s.amount, dec!(1729237.30));
        assert_eq!(s.tax, Some(dec!(99.95)));
    }

    #[test]
    fn test_deserialize_from_number() {
        let json = r#"{"amount":1729237.30,"tax":99.95}"#;
        let s: TestStruct = serde_json::from_str(json).unwrap();
        assert_eq!(s.amount, dec!(1729237.3));
        assert_eq!(s.tax, Some(dec!(99.95)));
    }

    #[test]
    fn test_deserialize_null_option() {
        let json = r#"{"amount":"100.00","tax":null}"#;
        let s: TestStruct = serde_json::from_str(json).unwrap();
        assert_eq!(s.tax, None);
    }

    #[test]
    fn test_deserialize_missing_option() {
        let json = r#"{"amount":"100.00"}"#;
        let s: TestStruct = serde_json::from_str(json).unwrap();
        assert_eq!(s.tax, None);
    }

    #[test]
    fn test_deserialize_rejects_garbage() {
        let json = r#"{"amount":"not a number"}"#;
        assert!(serde_json::from_str::<TestStruct>(json).is_err());
    }
}