use std::cell::Cell;
use std::fmt;
use rust_decimal::prelude::ToPrimitive;
use rust_decimal::Decimal;
use serde::{self, Deserializer, Serializer};
thread_local! {
    /// Per-thread encoding switch for this module's serializers.
    /// `false` (the initial value) selects strings; `true` selects native numbers.
    static NUMERIC_NATIVE: Cell<bool> = const { Cell::new(false) };
}

/// Selects how this module's `serialize` functions encode decimals on the
/// calling thread: `true` emits numbers, `false` emits strings.
///
/// The setting is thread-local. Other threads keep their own value, so it must
/// be set on the same thread that performs the serialization.
pub fn set_numeric_native(native: bool) {
    NUMERIC_NATIVE.with(|flag| flag.set(native));
}

/// Returns `true` when native-number encoding is enabled on the calling thread.
pub fn is_numeric_native() -> bool {
    NUMERIC_NATIVE.with(Cell::get)
}
/// Serializes a [`Decimal`] using the calling thread's numeric mode.
///
/// * String mode (the default): delegates to
///   `rust_decimal::serde::str::serialize`.
/// * Native mode (see [`set_numeric_native`]): emits the value as an `f64`
///   when `to_f64` succeeds. Values an `f64` cannot represent exactly may lose
///   precision. If the conversion yields `None`, it falls back to the
///   decimal's `to_string()` text.
pub fn serialize<S: Serializer>(value: &Decimal, serializer: S) -> Result<S::Ok, S::Error> {
    if !is_numeric_native() {
        return rust_decimal::serde::str::serialize(value, serializer);
    }
    if let Some(as_float) = value.to_f64() {
        serializer.serialize_f64(as_float)
    } else {
        serializer.serialize_str(&value.to_string())
    }
}
/// Deserializes a [`Decimal`] from a string, a float, or a 64-bit integer.
///
/// This relies on `deserialize_any`, so formats that cannot describe their own
/// data (and therefore do not support `deserialize_any`) will reject it.
pub fn deserialize<'de, D: Deserializer<'de>>(deserializer: D) -> Result<Decimal, D::Error> {
    let visitor = DecimalVisitor;
    deserializer.deserialize_any(visitor)
}
pub mod option {
use rust_decimal::prelude::ToPrimitive;
use rust_decimal::Decimal;
use serde::{Deserializer, Serializer};
use super::{is_numeric_native, OptionDecimalVisitor};
pub fn serialize<S: Serializer>(
value: &Option<Decimal>,
serializer: S,
) -> Result<S::Ok, S::Error> {
match value {
Some(d) => {
if is_numeric_native() {
match d.to_f64() {
Some(f) => serializer.serialize_f64(f),
None => serializer.serialize_str(&d.to_string()),
}
} else {
rust_decimal::serde::str_option::serialize(value, serializer)
}
}
None => serializer.serialize_none(),
}
}
pub fn deserialize<'de, D: Deserializer<'de>>(
deserializer: D,
) -> Result<Option<Decimal>, D::Error> {
deserializer.deserialize_any(OptionDecimalVisitor)
}
}
/// Visitor that builds a [`Decimal`] from a string, a float, or a 64-bit integer.
///
/// Strings are parsed as plain decimal text first. If that fails, they are
/// parsed as scientific notation (e.g. `"1.5e3"`), matching what
/// `rust_decimal`'s own serde visitor accepts. Floats go through
/// `Decimal::try_from(f64)`, and any conversion failure is reported as a
/// deserialization error.
struct DecimalVisitor;

impl<'de> serde::de::Visitor<'de> for DecimalVisitor {
    type Value = Decimal;

    fn expecting(&self, f: &mut fmt::Formatter) -> fmt::Result {
        write!(f, "a decimal as a string or number")
    }

    fn visit_str<E: serde::de::Error>(self, v: &str) -> Result<Decimal, E> {
        // If both parsers fail, report the plain-parse error: it describes the
        // common, exponent-free input better than the scientific parser's error.
        v.parse::<Decimal>()
            .or_else(|plain_err| Decimal::from_scientific(v).map_err(|_| plain_err))
            .map_err(E::custom)
    }

    fn visit_f64<E: serde::de::Error>(self, v: f64) -> Result<Decimal, E> {
        Decimal::try_from(v).map_err(E::custom)
    }

    fn visit_i64<E: serde::de::Error>(self, v: i64) -> Result<Decimal, E> {
        Ok(Decimal::from(v))
    }

    fn visit_u64<E: serde::de::Error>(self, v: u64) -> Result<Decimal, E> {
        Ok(Decimal::from(v))
    }
}

/// Visitor for `Option<Decimal>`.
///
/// `none` and unit (`null`) become `None`. Any value accepted by
/// `DecimalVisitor` becomes `Some`.
struct OptionDecimalVisitor;

impl<'de> serde::de::Visitor<'de> for OptionDecimalVisitor {
    type Value = Option<Decimal>;

    fn expecting(&self, f: &mut fmt::Formatter) -> fmt::Result {
        write!(f, "a decimal as a string or number, or null")
    }

    fn visit_none<E: serde::de::Error>(self) -> Result<Option<Decimal>, E> {
        Ok(None)
    }

    fn visit_unit<E: serde::de::Error>(self) -> Result<Option<Decimal>, E> {
        Ok(None)
    }

    fn visit_some<D: Deserializer<'de>>(
        self,
        deserializer: D,
    ) -> Result<Option<Decimal>, D::Error> {
        deserializer.deserialize_any(DecimalVisitor).map(Some)
    }

    // `option::deserialize` drives this visitor with `deserialize_any`, so a
    // present value typically arrives through the scalar callbacks below rather
    // than `visit_some`. Each one forwards to `DecimalVisitor` so both visitors
    // share a single set of parsing rules.

    fn visit_str<E: serde::de::Error>(self, v: &str) -> Result<Option<Decimal>, E> {
        serde::de::Visitor::visit_str(DecimalVisitor, v).map(Some)
    }

    fn visit_f64<E: serde::de::Error>(self, v: f64) -> Result<Option<Decimal>, E> {
        serde::de::Visitor::visit_f64(DecimalVisitor, v).map(Some)
    }

    fn visit_i64<E: serde::de::Error>(self, v: i64) -> Result<Option<Decimal>, E> {
        serde::de::Visitor::visit_i64(DecimalVisitor, v).map(Some)
    }

    fn visit_u64<E: serde::de::Error>(self, v: u64) -> Result<Option<Decimal>, E> {
        serde::de::Visitor::visit_u64(DecimalVisitor, v).map(Some)
    }
}
#[cfg(test)]
// Unwrapping is fine in tests: any failure should abort the test loudly.
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;
    use rust_decimal_macros::dec;

    #[derive(serde::Serialize, serde::Deserialize, Debug, PartialEq)]
    struct TestStruct {
        #[serde(with = "super")]
        amount: Decimal,
        #[serde(default, with = "super::option")]
        tax: Option<Decimal>,
    }

    /// Enables native-number mode on the current thread and restores string
    /// mode when dropped, even if an assertion panics first.
    struct NativeModeGuard;

    impl NativeModeGuard {
        fn enable() -> Self {
            set_numeric_native(true);
            NativeModeGuard
        }
    }

    impl Drop for NativeModeGuard {
        fn drop(&mut self) {
            set_numeric_native(false);
        }
    }

    /// Fixture shared by the serialization tests.
    fn sample() -> TestStruct {
        TestStruct {
            amount: dec!(1729237.30),
            tax: Some(dec!(99.95)),
        }
    }

    #[test]
    fn test_string_mode() {
        set_numeric_native(false);
        let json = serde_json::to_string(&sample()).unwrap();
        assert!(json.contains("\"1729237.30\""), "expected string: {json}");
        assert!(json.contains("\"99.95\""), "expected string: {json}");
    }

    #[test]
    fn test_native_mode() {
        let _native = NativeModeGuard::enable();
        let json = serde_json::to_string(&sample()).unwrap();
        assert!(
            json.contains(":1729237.3") || json.contains(":1729237.30"),
            "expected number: {json}"
        );
        // The Option helper has its own code path; check it emits a number too.
        assert!(json.contains("\"tax\":99.95"), "expected number: {json}");
    }

    #[test]
    fn test_native_mode_round_trip() {
        let original = sample();
        let json = {
            let _native = NativeModeGuard::enable();
            serde_json::to_string(&original).unwrap()
        };
        let back: TestStruct = serde_json::from_str(&json).unwrap();
        assert_eq!(back, original);
    }

    #[test]
    fn test_none_serializes_as_null() {
        let s = TestStruct {
            amount: dec!(100.00),
            tax: None,
        };
        let json = serde_json::to_string(&s).unwrap();
        assert!(json.contains("\"tax\":null"), "expected null: {json}");
        let back: TestStruct = serde_json::from_str(&json).unwrap();
        assert_eq!(back, s);
    }

    #[test]
    fn test_deserialize_from_string() {
        let json = r#"{"amount":"1729237.30","tax":"99.95"}"#;
        let s: TestStruct = serde_json::from_str(json).unwrap();
        assert_eq!(s.amount, dec!(1729237.30));
        assert_eq!(s.tax, Some(dec!(99.95)));
    }

    #[test]
    fn test_deserialize_from_number() {
        let json = r#"{"amount":1729237.30,"tax":99.95}"#;
        let s: TestStruct = serde_json::from_str(json).unwrap();
        assert_eq!(s.amount, dec!(1729237.3));
        assert_eq!(s.tax, Some(dec!(99.95)));
    }

    #[test]
    fn test_deserialize_null_option() {
        let json = r#"{"amount":"100.00","tax":null}"#;
        let s: TestStruct = serde_json::from_str(json).unwrap();
        assert_eq!(s.tax, None);
    }

    #[test]
    fn test_deserialize_missing_option() {
        let json = r#"{"amount":"100.00"}"#;
        let s: TestStruct = serde_json::from_str(json).unwrap();
        assert_eq!(s.tax, None);
    }
}