Skip to main content

nautilus_core/
serialization.rs

1// -------------------------------------------------------------------------------------------------
2//  Copyright (C) 2015-2026 Nautech Systems Pty Ltd. All rights reserved.
3//  https://nautechsystems.io
4//
5//  Licensed under the GNU Lesser General Public License Version 3.0 (the "License");
6//  You may not use this file except in compliance with the License.
7//  You may obtain a copy of the License at https://www.gnu.org/licenses/lgpl-3.0.en.html
8//
9//  Unless required by applicable law or agreed to in writing, software
10//  distributed under the License is distributed on an "AS IS" BASIS,
11//  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12//  See the License for the specific language governing permissions and
13//  limitations under the License.
14// -------------------------------------------------------------------------------------------------
15
16//! Common serialization traits and functions.
17//!
18//! This module provides custom serde deserializers and serializers for common
19//! patterns encountered when parsing exchange API responses, particularly:
20//!
21//! - Empty strings that should be interpreted as `None` or zero.
22//! - Type conversions from strings to primitives.
23//! - Decimal values represented as strings.
24
25use std::str::FromStr;
26
27use bytes::Bytes;
28use rust_decimal::Decimal;
29use serde::{
30    Deserialize, Deserializer, Serialize, Serializer,
31    de::{Error, Unexpected, Visitor},
32    ser::SerializeSeq,
33};
34use ustr::Ustr;
35
36/// Sorted serialization for `AHashSet<T>` where element order must be deterministic.
37///
38/// Use with `#[serde(with = "nautilus_core::serialization::sorted_hashset")]`.
39pub mod sorted_hashset {
40    use ahash::AHashSet;
41    use serde::{Deserialize, Deserializer, Serialize, Serializer};
42
43    /// Serializes an `AHashSet<T>` as a sorted array for deterministic output.
44    ///
45    /// # Errors
46    ///
47    /// Returns any error produced by the underlying [`Serializer`] when writing
48    /// the sorted vector.
49    pub fn serialize<T, S>(set: &AHashSet<T>, s: S) -> Result<S::Ok, S::Error>
50    where
51        T: Serialize + Ord,
52        S: Serializer,
53    {
54        let mut sorted: Vec<&T> = set.iter().collect();
55        sorted.sort_unstable();
56        sorted.serialize(s)
57    }
58
59    /// Deserializes an array into an `AHashSet<T>`.
60    ///
61    /// # Errors
62    ///
63    /// Returns any error produced by the underlying [`Deserializer`] when reading
64    /// the source array.
65    pub fn deserialize<'de, T, D>(d: D) -> Result<AHashSet<T>, D::Error>
66    where
67        T: Deserialize<'de> + Eq + std::hash::Hash,
68        D: Deserializer<'de>,
69    {
70        let vec = Vec::<T>::deserialize(d)?;
71        Ok(vec.into_iter().collect())
72    }
73}
74
/// Serde visitor that reads a boolean, or one of the unsigned integers `0`/`1`, as a `u8`.
struct BoolVisitor;
76
77/// Zero-allocation decimal visitor for maximum deserialization performance.
78///
79/// Directly visits JSON tokens without intermediate `serde_json::Value` allocation.
80/// Handles all JSON numeric representations: strings, integers, floats, and null.
81struct DecimalVisitor;
82
83impl Visitor<'_> for DecimalVisitor {
84    type Value = Decimal;
85
86    fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
87        formatter.write_str("a decimal number as string, integer, or float")
88    }
89
90    // Fast path: borrowed string (zero-copy)
91    fn visit_str<E: Error>(self, v: &str) -> Result<Self::Value, E> {
92        if v.is_empty() {
93            return Ok(Decimal::ZERO);
94        }
95        // Check for scientific notation
96        if v.contains('e') || v.contains('E') {
97            Decimal::from_scientific(v).map_err(E::custom)
98        } else {
99            Decimal::from_str(v).map_err(E::custom)
100        }
101    }
102
103    // Owned string (rare case, delegates to visit_str)
104    fn visit_string<E: Error>(self, v: String) -> Result<Self::Value, E> {
105        self.visit_str(&v)
106    }
107
108    // Direct integer handling - no string conversion needed
109    fn visit_i64<E: Error>(self, v: i64) -> Result<Self::Value, E> {
110        Ok(Decimal::from(v))
111    }
112
113    fn visit_u64<E: Error>(self, v: u64) -> Result<Self::Value, E> {
114        Ok(Decimal::from(v))
115    }
116
117    fn visit_i128<E: Error>(self, v: i128) -> Result<Self::Value, E> {
118        Ok(Decimal::from(v))
119    }
120
121    fn visit_u128<E: Error>(self, v: u128) -> Result<Self::Value, E> {
122        Ok(Decimal::from(v))
123    }
124
125    // Float handling - direct conversion
126    fn visit_f64<E: Error>(self, v: f64) -> Result<Self::Value, E> {
127        if v.is_nan() {
128            return Err(E::invalid_value(Unexpected::Float(v), &self));
129        }
130
131        if v.is_infinite() {
132            return Err(E::invalid_value(Unexpected::Float(v), &self));
133        }
134        Decimal::try_from(v).map_err(E::custom)
135    }
136
137    // Null → zero (matches existing behavior)
138    fn visit_unit<E: Error>(self) -> Result<Self::Value, E> {
139        Ok(Decimal::ZERO)
140    }
141
142    fn visit_none<E: Error>(self) -> Result<Self::Value, E> {
143        Ok(Decimal::ZERO)
144    }
145}
146
147/// Zero-allocation optional decimal visitor for maximum deserialization performance.
148///
149/// Handles null values as `None` and empty strings as `None`.
150/// Uses `deserialize_any` approach to handle all JSON value types uniformly.
151struct OptionalDecimalVisitor;
152
153impl Visitor<'_> for OptionalDecimalVisitor {
154    type Value = Option<Decimal>;
155
156    fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
157        formatter.write_str("null or a decimal number as string, integer, or float")
158    }
159
160    // Fast path: borrowed string (zero-copy)
161    // Empty string → None (different from DecimalVisitor which returns ZERO)
162    fn visit_str<E: Error>(self, v: &str) -> Result<Self::Value, E> {
163        if v.is_empty() {
164            return Ok(None);
165        }
166        DecimalVisitor.visit_str(v).map(Some)
167    }
168
169    fn visit_string<E: Error>(self, v: String) -> Result<Self::Value, E> {
170        self.visit_str(&v)
171    }
172
173    fn visit_i64<E: Error>(self, v: i64) -> Result<Self::Value, E> {
174        DecimalVisitor.visit_i64(v).map(Some)
175    }
176
177    fn visit_u64<E: Error>(self, v: u64) -> Result<Self::Value, E> {
178        DecimalVisitor.visit_u64(v).map(Some)
179    }
180
181    fn visit_i128<E: Error>(self, v: i128) -> Result<Self::Value, E> {
182        DecimalVisitor.visit_i128(v).map(Some)
183    }
184
185    fn visit_u128<E: Error>(self, v: u128) -> Result<Self::Value, E> {
186        DecimalVisitor.visit_u128(v).map(Some)
187    }
188
189    fn visit_f64<E: Error>(self, v: f64) -> Result<Self::Value, E> {
190        DecimalVisitor.visit_f64(v).map(Some)
191    }
192
193    // Null → None
194    fn visit_unit<E: Error>(self) -> Result<Self::Value, E> {
195        Ok(None)
196    }
197
198    fn visit_none<E: Error>(self) -> Result<Self::Value, E> {
199        Ok(None)
200    }
201}
202
203/// Represents types which are serializable for JSON specifications.
204pub trait Serializable: Serialize + for<'de> Deserialize<'de> {
205    /// Deserialize an object from JSON encoded bytes.
206    ///
207    /// # Errors
208    ///
209    /// Returns serialization errors.
210    fn from_json_bytes(data: &[u8]) -> Result<Self, serde_json::Error> {
211        serde_json::from_slice(data)
212    }
213
214    /// Serialize an object to JSON encoded bytes.
215    ///
216    /// # Errors
217    ///
218    /// Returns serialization errors.
219    fn to_json_bytes(&self) -> Result<Bytes, serde_json::Error> {
220        serde_json::to_vec(self).map(Bytes::from)
221    }
222}
223
224pub use self::msgpack::{FromMsgPack, MsgPackSerializable, ToMsgPack};
225
226/// Provides `MsgPack` serialization support for types implementing [`Serializable`].
227///
228/// This module contains traits for `MsgPack` serialization and deserialization,
229/// separated from the core [`Serializable`] trait to allow independent opt-in.
230pub mod msgpack {
231    use bytes::Bytes;
232    use serde::{Deserialize, Serialize};
233
234    use super::Serializable;
235
236    /// Provides deserialization from `MsgPack` encoded bytes.
237    pub trait FromMsgPack: for<'de> Deserialize<'de> + Sized {
238        /// Deserialize an object from `MsgPack` encoded bytes.
239        ///
240        /// # Errors
241        ///
242        /// Returns serialization errors.
243        fn from_msgpack_bytes(data: &[u8]) -> Result<Self, rmp_serde::decode::Error> {
244            rmp_serde::from_slice(data)
245        }
246    }
247
248    /// Provides serialization to `MsgPack` encoded bytes.
249    pub trait ToMsgPack: Serialize {
250        /// Serialize an object to `MsgPack` encoded bytes.
251        ///
252        /// # Errors
253        ///
254        /// Returns serialization errors.
255        fn to_msgpack_bytes(&self) -> Result<Bytes, rmp_serde::encode::Error> {
256            rmp_serde::to_vec_named(self).map(Bytes::from)
257        }
258    }
259
260    /// Marker trait combining [`Serializable`], [`FromMsgPack`], and [`ToMsgPack`].
261    ///
262    /// This trait is automatically implemented for all types that implement [`Serializable`].
263    pub trait MsgPackSerializable: Serializable + FromMsgPack + ToMsgPack {}
264
265    impl<T> FromMsgPack for T where T: Serializable {}
266
267    impl<T> ToMsgPack for T where T: Serializable {}
268
269    impl<T> MsgPackSerializable for T where T: Serializable {}
270}
271
272impl Visitor<'_> for BoolVisitor {
273    type Value = u8;
274
275    fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
276        formatter.write_str("a boolean as u8")
277    }
278
279    fn visit_bool<E>(self, value: bool) -> Result<Self::Value, E>
280    where
281        E: serde::de::Error,
282    {
283        Ok(u8::from(value))
284    }
285
286    #[expect(
287        clippy::cast_possible_truncation,
288        reason = "Intentional for parsing, value range validated"
289    )]
290    fn visit_u64<E>(self, value: u64) -> Result<Self::Value, E>
291    where
292        E: serde::de::Error,
293    {
294        // Only 0 or 1 are considered valid representations when provided as an
295        // integer. We deliberately reject values outside this range to avoid
296        // silently truncating larger integers into impl-defined boolean
297        // semantics.
298        if value > 1 {
299            Err(E::invalid_value(Unexpected::Unsigned(value), &self))
300        } else {
301            Ok(value as u8)
302        }
303    }
304}
305
/// Returns `true`, for use as a serde field-default provider.
///
/// Annotate boolean fields with `#[serde(default = "default_true")]` so a
/// missing field deserializes as `true`.
#[must_use]
pub const fn default_true() -> bool {
    true
}
313
/// Returns `false`, for use as a serde field-default provider.
///
/// Annotate boolean fields with `#[serde(default = "default_false")]` so a
/// missing field deserializes as `false`.
#[must_use]
pub const fn default_false() -> bool {
    false
}
321
322/// Deserialize the boolean value as a `u8`.
323///
324/// # Errors
325///
326/// Returns serialization errors.
327pub fn from_bool_as_u8<'de, D>(deserializer: D) -> Result<u8, D::Error>
328where
329    D: Deserializer<'de>,
330{
331    deserializer.deserialize_any(BoolVisitor)
332}
333
334/// Deserializes a `Decimal` from either a JSON string or number.
335///
336/// High-performance implementation using a custom visitor that avoids intermediate
337/// `serde_json::Value` allocations. Handles all JSON numeric representations:
338///
339/// - JSON string: `"123.456"` → Decimal (zero-copy for borrowed strings)
340/// - JSON integer: `123` → Decimal (direct conversion, no string allocation)
341/// - JSON float: `123.456` → Decimal
342/// - JSON null: → `Decimal::ZERO`
343/// - Scientific notation: `"1.5e-8"` → Decimal
344///
345/// # Performance
346///
347/// This implementation is optimized for high-frequency trading scenarios:
348/// - Zero allocations for string values (uses borrowed `&str`)
349/// - Direct integer conversion without string intermediary
350/// - No intermediate `serde_json::Value` heap allocation
351///
352/// # Errors
353///
354/// Returns an error if the value cannot be parsed as a valid decimal.
355pub fn deserialize_decimal<'de, D>(deserializer: D) -> Result<Decimal, D::Error>
356where
357    D: Deserializer<'de>,
358{
359    deserializer.deserialize_any(DecimalVisitor)
360}
361
362/// Deserializes an `Option<Decimal>` from a JSON string, number, or null.
363///
364/// High-performance implementation using a custom visitor that avoids intermediate
365/// `serde_json::Value` allocations. Handles all JSON numeric representations:
366///
367/// - JSON string: `"123.456"` → Some(Decimal) (zero-copy for borrowed strings)
368/// - JSON integer: `123` → Some(Decimal) (direct conversion)
369/// - JSON float: `123.456` → Some(Decimal)
370/// - JSON null: → `None`
371/// - Empty string: `""` → `None`
372/// - Scientific notation: `"1.5e-8"` → Some(Decimal)
373///
374/// # Performance
375///
376/// This implementation is optimized for high-frequency trading scenarios:
377/// - Zero allocations for string values (uses borrowed `&str`)
378/// - Direct integer conversion without string intermediary
379/// - No intermediate `serde_json::Value` heap allocation
380///
381/// # Errors
382///
383/// Returns an error if the value cannot be parsed as a valid decimal.
384pub fn deserialize_optional_decimal<'de, D>(deserializer: D) -> Result<Option<Decimal>, D::Error>
385where
386    D: Deserializer<'de>,
387{
388    // Use deserialize_any to handle all JSON value types uniformly
389    // (deserialize_option would route non-null through visit_some, losing empty string handling)
390    deserializer.deserialize_any(OptionalDecimalVisitor)
391}
392
393/// Serializes a `Decimal` as a JSON number (float).
394///
395/// Used for outgoing requests where exchange APIs expect JSON numbers.
396///
397/// # Errors
398///
399/// Returns an error if serialization fails.
400pub fn serialize_decimal<S: Serializer>(d: &Decimal, s: S) -> Result<S::Ok, S::Error> {
401    rust_decimal::serde::float::serialize(d, s)
402}
403
404/// Serializes an `Option<Decimal>` as a JSON number or null.
405///
406/// # Errors
407///
408/// Returns an error if serialization fails.
409pub fn serialize_optional_decimal<S: Serializer>(
410    d: &Option<Decimal>,
411    s: S,
412) -> Result<S::Ok, S::Error> {
413    match d {
414        Some(decimal) => rust_decimal::serde::float::serialize(decimal, s),
415        None => s.serialize_none(),
416    }
417}
418
419/// Deserializes a `Decimal` from a JSON string.
420///
421/// This is the strict form that requires the value to be a string, rejecting
422/// numeric JSON values to avoid precision loss.
423///
424/// # Errors
425///
426/// Returns an error if the string cannot be parsed as a valid decimal.
427pub fn deserialize_decimal_from_str<'de, D>(deserializer: D) -> Result<Decimal, D::Error>
428where
429    D: Deserializer<'de>,
430{
431    let s: std::borrow::Cow<'de, str> = Deserialize::deserialize(deserializer)?;
432    Decimal::from_str(s.as_ref()).map_err(D::Error::custom)
433}
434
435/// Deserializes a `Decimal` from a string field that might be empty.
436///
437/// Handles edge cases where empty string "" or "0" becomes `Decimal::ZERO`.
438///
439/// # Errors
440///
441/// Returns an error if the string cannot be parsed as a valid decimal.
442pub fn deserialize_decimal_or_zero<'de, D>(deserializer: D) -> Result<Decimal, D::Error>
443where
444    D: Deserializer<'de>,
445{
446    let s: std::borrow::Cow<'de, str> = Deserialize::deserialize(deserializer)?;
447    if s.is_empty() || s == "0" {
448        Ok(Decimal::ZERO)
449    } else {
450        Decimal::from_str(s.as_ref()).map_err(D::Error::custom)
451    }
452}
453
454/// Deserializes an optional `Decimal` from a string field.
455///
456/// Returns `None` if the string is empty or "0", otherwise parses to `Decimal`.
457/// This is a strict string-only deserializer; for flexible handling of strings,
458/// numbers, and null, use [`deserialize_optional_decimal`].
459///
460/// # Errors
461///
462/// Returns an error if the string cannot be parsed as a valid decimal.
463pub fn deserialize_optional_decimal_str<'de, D>(
464    deserializer: D,
465) -> Result<Option<Decimal>, D::Error>
466where
467    D: Deserializer<'de>,
468{
469    let s: std::borrow::Cow<'de, str> = Deserialize::deserialize(deserializer)?;
470    if s.is_empty() || s == "0" {
471        Ok(None)
472    } else {
473        Decimal::from_str(s.as_ref())
474            .map(Some)
475            .map_err(D::Error::custom)
476    }
477}
478
479/// Deserializes an optional `Decimal` from a string-only field.
480///
481/// Returns `None` if the value is null or the string is empty, otherwise
482/// parses to `Decimal`.
483///
484/// # Errors
485///
486/// Returns an error if the string cannot be parsed as a valid decimal.
487pub fn deserialize_optional_decimal_from_str<'de, D>(
488    deserializer: D,
489) -> Result<Option<Decimal>, D::Error>
490where
491    D: Deserializer<'de>,
492{
493    let opt = Option::<String>::deserialize(deserializer)?;
494    match opt {
495        Some(s) if !s.is_empty() => Decimal::from_str(&s).map(Some).map_err(D::Error::custom),
496        _ => Ok(None),
497    }
498}
499
500/// Deserializes a `Decimal` from an optional string field, defaulting to zero.
501///
502/// Handles edge cases: `None`, empty string "", or "0" all become `Decimal::ZERO`.
503///
504/// # Errors
505///
506/// Returns an error if the string cannot be parsed as a valid decimal.
507pub fn deserialize_optional_decimal_or_zero<'de, D>(deserializer: D) -> Result<Decimal, D::Error>
508where
509    D: Deserializer<'de>,
510{
511    let opt: Option<String> = Deserialize::deserialize(deserializer)?;
512    match opt {
513        None => Ok(Decimal::ZERO),
514        Some(s) if s.is_empty() || s == "0" => Ok(Decimal::ZERO),
515        Some(s) => Decimal::from_str(&s).map_err(D::Error::custom),
516    }
517}
518
519/// Deserializes a `Vec<Decimal>` from a JSON array of strings.
520///
521/// # Errors
522///
523/// Returns an error if any string cannot be parsed as a valid decimal.
524pub fn deserialize_vec_decimal_from_str<'de, D>(deserializer: D) -> Result<Vec<Decimal>, D::Error>
525where
526    D: Deserializer<'de>,
527{
528    let strings = Vec::<String>::deserialize(deserializer)?;
529    strings
530        .into_iter()
531        .map(|s| Decimal::from_str(&s).map_err(D::Error::custom))
532        .collect()
533}
534
535/// Serializes a `Decimal` as a string (lossless, no scientific notation).
536///
537/// # Errors
538///
539/// Returns an error if serialization fails.
540pub fn serialize_decimal_as_str<S>(decimal: &Decimal, serializer: S) -> Result<S::Ok, S::Error>
541where
542    S: Serializer,
543{
544    serializer.serialize_str(&decimal.to_string())
545}
546
547/// Serializes an optional `Decimal` as a string.
548///
549/// # Errors
550///
551/// Returns an error if serialization fails.
552pub fn serialize_optional_decimal_as_str<S>(
553    decimal: &Option<Decimal>,
554    serializer: S,
555) -> Result<S::Ok, S::Error>
556where
557    S: Serializer,
558{
559    match decimal {
560        Some(d) => serializer.serialize_str(&d.to_string()),
561        None => serializer.serialize_none(),
562    }
563}
564
565/// Serializes a `Vec<Decimal>` as an array of strings.
566///
567/// # Errors
568///
569/// Returns an error if serialization fails.
570pub fn serialize_vec_decimal_as_str<S>(
571    decimals: &Vec<Decimal>,
572    serializer: S,
573) -> Result<S::Ok, S::Error>
574where
575    S: Serializer,
576{
577    let mut seq = serializer.serialize_seq(Some(decimals.len()))?;
578    for decimal in decimals {
579        seq.serialize_element(&decimal.to_string())?;
580    }
581    seq.end()
582}
583
584/// Parses a string to `Decimal`, returning an error if parsing fails.
585///
586/// # Errors
587///
588/// Returns an error if the string cannot be parsed as a Decimal.
589pub fn parse_decimal(s: &str) -> anyhow::Result<Decimal> {
590    Decimal::from_str(s).map_err(|e| anyhow::anyhow!("Failed to parse decimal from '{s}': {e}"))
591}
592
593/// Parses an optional string to `Decimal`, returning `None` if the string is `None` or empty.
594///
595/// # Errors
596///
597/// Returns an error if the string cannot be parsed as a Decimal.
598pub fn parse_optional_decimal(s: &Option<String>) -> anyhow::Result<Option<Decimal>> {
599    match s {
600        None => Ok(None),
601        Some(s) if s.is_empty() => Ok(None),
602        Some(s) => parse_decimal(s).map(Some),
603    }
604}
605
606/// Deserializes an empty string into `None`.
607///
608/// Many exchange APIs represent null string fields as an empty string (`""`).
609/// When such a payload is mapped onto `Option<String>` the default behavior
610/// would yield `Some("")`, which is semantically different from the intended
611/// absence of a value. This helper ensures that empty strings are normalized
612/// to `None` during deserialization.
613///
614/// # Errors
615///
616/// Returns an error if the JSON value cannot be deserialized into a string.
617pub fn deserialize_empty_string_as_none<'de, D>(deserializer: D) -> Result<Option<String>, D::Error>
618where
619    D: Deserializer<'de>,
620{
621    let opt = Option::<String>::deserialize(deserializer)?;
622    Ok(opt.filter(|s| !s.is_empty()))
623}
624
625/// Deserializes an empty [`Ustr`] into `None`.
626///
627/// # Errors
628///
629/// Returns an error if the JSON value cannot be deserialized into a string.
630pub fn deserialize_empty_ustr_as_none<'de, D>(deserializer: D) -> Result<Option<Ustr>, D::Error>
631where
632    D: Deserializer<'de>,
633{
634    let opt = Option::<Ustr>::deserialize(deserializer)?;
635    Ok(opt.filter(|s| !s.is_empty()))
636}
637
638/// Deserializes a `u8` from a string field.
639///
640/// Returns 0 if the string is empty.
641///
642/// # Errors
643///
644/// Returns an error if the string cannot be parsed as a u8.
645pub fn deserialize_string_to_u8<'de, D>(deserializer: D) -> Result<u8, D::Error>
646where
647    D: Deserializer<'de>,
648{
649    let s: std::borrow::Cow<'de, str> = Deserialize::deserialize(deserializer)?;
650    if s.is_empty() {
651        return Ok(0);
652    }
653    s.as_ref().parse::<u8>().map_err(D::Error::custom)
654}
655
656/// Deserializes a `u64` from a string field.
657///
658/// Returns 0 if the string is empty.
659///
660/// # Errors
661///
662/// Returns an error if the string cannot be parsed as a u64.
663pub fn deserialize_string_to_u64<'de, D>(deserializer: D) -> Result<u64, D::Error>
664where
665    D: Deserializer<'de>,
666{
667    let s: std::borrow::Cow<'de, str> = Deserialize::deserialize(deserializer)?;
668    if s.is_empty() {
669        Ok(0)
670    } else {
671        s.as_ref().parse::<u64>().map_err(D::Error::custom)
672    }
673}
674
675/// Deserializes an optional `u64` from a string field.
676///
677/// Returns `None` if the value is null or the string is empty.
678///
679/// # Errors
680///
681/// Returns an error if the string cannot be parsed as a u64.
682pub fn deserialize_optional_string_to_u64<'de, D>(deserializer: D) -> Result<Option<u64>, D::Error>
683where
684    D: Deserializer<'de>,
685{
686    let s: Option<String> = Option::deserialize(deserializer)?;
687    match s {
688        Some(s) if s.is_empty() => Ok(None),
689        Some(s) => s.parse().map(Some).map_err(D::Error::custom),
690        None => Ok(None),
691    }
692}
693
694#[cfg(test)]
695mod tests {
696    use rstest::*;
697    use rust_decimal::Decimal;
698    use rust_decimal_macros::dec;
699    use serde::{Deserialize, Serialize};
700    use ustr::Ustr;
701
702    use super::{
703        Serializable, deserialize_decimal, deserialize_decimal_from_str,
704        deserialize_decimal_or_zero, deserialize_empty_string_as_none,
705        deserialize_empty_ustr_as_none, deserialize_optional_decimal,
706        deserialize_optional_decimal_or_zero, deserialize_optional_decimal_str,
707        deserialize_optional_string_to_u64, deserialize_string_to_u8, deserialize_string_to_u64,
708        deserialize_vec_decimal_from_str, from_bool_as_u8,
709        msgpack::{FromMsgPack, ToMsgPack},
710        parse_decimal, parse_optional_decimal, serialize_decimal, serialize_decimal_as_str,
711        serialize_optional_decimal, serialize_optional_decimal_as_str,
712        serialize_vec_decimal_as_str,
713    };
714
715    #[derive(Deserialize)]
716    pub struct TestStruct {
717        #[serde(deserialize_with = "from_bool_as_u8")]
718        pub value: u8,
719    }
720
721    #[rstest]
722    #[case(r#"{"value": true}"#, 1)]
723    #[case(r#"{"value": false}"#, 0)]
724    fn test_deserialize_bool_as_u8_with_boolean(#[case] json_str: &str, #[case] expected: u8) {
725        let test_struct: TestStruct = serde_json::from_str(json_str).unwrap();
726        assert_eq!(test_struct.value, expected);
727    }
728
729    #[rstest]
730    #[case(r#"{"value": 1}"#, 1)]
731    #[case(r#"{"value": 0}"#, 0)]
732    fn test_deserialize_bool_as_u8_with_u64(#[case] json_str: &str, #[case] expected: u8) {
733        let test_struct: TestStruct = serde_json::from_str(json_str).unwrap();
734        assert_eq!(test_struct.value, expected);
735    }
736
737    #[rstest]
738    fn test_deserialize_bool_as_u8_with_invalid_integer() {
739        // Any integer other than 0/1 is invalid and should error
740        let json = r#"{"value": 2}"#;
741        let result: Result<TestStruct, _> = serde_json::from_str(json);
742        assert!(result.is_err());
743    }
744
745    #[derive(Serialize, Deserialize, PartialEq, Debug)]
746    struct SerializableTestStruct {
747        id: u32,
748        name: String,
749        value: f64,
750    }
751
752    impl Serializable for SerializableTestStruct {}
753
754    #[rstest]
755    fn test_serializable_json_roundtrip() {
756        let original = SerializableTestStruct {
757            id: 42,
758            name: "test".to_string(),
759            value: std::f64::consts::PI,
760        };
761
762        let json_bytes = original.to_json_bytes().unwrap();
763        let deserialized = SerializableTestStruct::from_json_bytes(&json_bytes).unwrap();
764
765        assert_eq!(original, deserialized);
766    }
767
768    #[rstest]
769    fn test_serializable_msgpack_roundtrip() {
770        let original = SerializableTestStruct {
771            id: 123,
772            name: "msgpack_test".to_string(),
773            value: std::f64::consts::E,
774        };
775
776        let msgpack_bytes = original.to_msgpack_bytes().unwrap();
777        let deserialized = SerializableTestStruct::from_msgpack_bytes(&msgpack_bytes).unwrap();
778
779        assert_eq!(original, deserialized);
780    }
781
782    #[rstest]
783    fn test_serializable_json_invalid_data() {
784        let invalid_json = b"invalid json data";
785        let result = SerializableTestStruct::from_json_bytes(invalid_json);
786        assert!(result.is_err());
787    }
788
789    #[rstest]
790    fn test_serializable_msgpack_invalid_data() {
791        let invalid_msgpack = b"invalid msgpack data";
792        let result = SerializableTestStruct::from_msgpack_bytes(invalid_msgpack);
793        assert!(result.is_err());
794    }
795
796    #[rstest]
797    fn test_serializable_json_empty_values() {
798        let test_struct = SerializableTestStruct {
799            id: 0,
800            name: String::new(),
801            value: 0.0,
802        };
803
804        let json_bytes = test_struct.to_json_bytes().unwrap();
805        let deserialized = SerializableTestStruct::from_json_bytes(&json_bytes).unwrap();
806
807        assert_eq!(test_struct, deserialized);
808    }
809
810    #[rstest]
811    fn test_serializable_msgpack_empty_values() {
812        let test_struct = SerializableTestStruct {
813            id: 0,
814            name: String::new(),
815            value: 0.0,
816        };
817
818        let msgpack_bytes = test_struct.to_msgpack_bytes().unwrap();
819        let deserialized = SerializableTestStruct::from_msgpack_bytes(&msgpack_bytes).unwrap();
820
821        assert_eq!(test_struct, deserialized);
822    }
823
824    #[derive(Deserialize)]
825    struct TestOptionalDecimalStr {
826        #[serde(deserialize_with = "deserialize_optional_decimal_str")]
827        value: Option<Decimal>,
828    }
829
830    #[derive(Deserialize)]
831    struct TestDecimalOrZero {
832        #[serde(deserialize_with = "deserialize_decimal_or_zero")]
833        value: Decimal,
834    }
835
836    #[derive(Deserialize)]
837    struct TestOptionalDecimalOrZero {
838        #[serde(deserialize_with = "deserialize_optional_decimal_or_zero")]
839        value: Decimal,
840    }
841
842    #[derive(Serialize, Deserialize, PartialEq, Debug)]
843    struct TestDecimalRoundtrip {
844        #[serde(
845            serialize_with = "serialize_decimal_as_str",
846            deserialize_with = "deserialize_decimal_from_str"
847        )]
848        value: Decimal,
849        #[serde(
850            serialize_with = "serialize_optional_decimal_as_str",
851            deserialize_with = "super::deserialize_optional_decimal_from_str"
852        )]
853        optional_value: Option<Decimal>,
854    }
855
856    #[rstest]
857    #[case(r#"{"value":"123.45"}"#, Some(dec!(123.45)))]
858    #[case(r#"{"value":"0"}"#, None)]
859    #[case(r#"{"value":""}"#, None)]
860    fn test_deserialize_optional_decimal_str(
861        #[case] json: &str,
862        #[case] expected: Option<Decimal>,
863    ) {
864        let result: TestOptionalDecimalStr = serde_json::from_str(json).unwrap();
865        assert_eq!(result.value, expected);
866    }
867
868    #[rstest]
869    #[case(r#"{"value":"123.45"}"#, dec!(123.45))]
870    #[case(r#"{"value":"0"}"#, Decimal::ZERO)]
871    #[case(r#"{"value":""}"#, Decimal::ZERO)]
872    fn test_deserialize_decimal_or_zero(#[case] json: &str, #[case] expected: Decimal) {
873        let result: TestDecimalOrZero = serde_json::from_str(json).unwrap();
874        assert_eq!(result.value, expected);
875    }
876
877    #[rstest]
878    #[case(r#"{"value":"123.45"}"#, dec!(123.45))]
879    #[case(r#"{"value":"0"}"#, Decimal::ZERO)]
880    #[case(r#"{"value":null}"#, Decimal::ZERO)]
881    fn test_deserialize_optional_decimal_or_zero(#[case] json: &str, #[case] expected: Decimal) {
882        let result: TestOptionalDecimalOrZero = serde_json::from_str(json).unwrap();
883        assert_eq!(result.value, expected);
884    }
885
886    #[rstest]
887    fn test_decimal_serialization_roundtrip() {
888        let original = TestDecimalRoundtrip {
889            value: dec!(123.456789012345678),
890            optional_value: Some(dec!(0.000000001)),
891        };
892
893        let json = serde_json::to_string(&original).unwrap();
894
895        // Check that it's serialized as strings
896        assert!(json.contains("\"123.456789012345678\""));
897        assert!(json.contains("\"0.000000001\""));
898
899        let deserialized: TestDecimalRoundtrip = serde_json::from_str(&json).unwrap();
900        assert_eq!(original.value, deserialized.value);
901        assert_eq!(original.optional_value, deserialized.optional_value);
902    }
903
904    #[rstest]
905    fn test_decimal_optional_none_handling() {
906        let test_struct = TestDecimalRoundtrip {
907            value: dec!(42.0),
908            optional_value: None,
909        };
910
911        let json = serde_json::to_string(&test_struct).unwrap();
912        assert!(json.contains("null"));
913
914        let parsed: TestDecimalRoundtrip = serde_json::from_str(&json).unwrap();
915        assert_eq!(test_struct.value, parsed.value);
916        assert_eq!(None, parsed.optional_value);
917    }
918
919    #[derive(Deserialize)]
920    struct TestEmptyStringAsNone {
921        #[serde(deserialize_with = "deserialize_empty_string_as_none")]
922        value: Option<String>,
923    }
924
925    #[rstest]
926    #[case(r#"{"value":"hello"}"#, Some("hello".to_string()))]
927    #[case(r#"{"value":""}"#, None)]
928    #[case(r#"{"value":null}"#, None)]
929    fn test_deserialize_empty_string_as_none(#[case] json: &str, #[case] expected: Option<String>) {
930        let result: TestEmptyStringAsNone = serde_json::from_str(json).unwrap();
931        assert_eq!(result.value, expected);
932    }
933
934    #[derive(Deserialize)]
935    struct TestEmptyUstrAsNone {
936        #[serde(deserialize_with = "deserialize_empty_ustr_as_none")]
937        value: Option<Ustr>,
938    }
939
940    #[rstest]
941    #[case(r#"{"value":"hello"}"#, Some(Ustr::from("hello")))]
942    #[case(r#"{"value":""}"#, None)]
943    #[case(r#"{"value":null}"#, None)]
944    fn test_deserialize_empty_ustr_as_none(#[case] json: &str, #[case] expected: Option<Ustr>) {
945        let result: TestEmptyUstrAsNone = serde_json::from_str(json).unwrap();
946        assert_eq!(result.value, expected);
947    }
948
949    #[derive(Serialize, Deserialize, PartialEq, Debug)]
950    struct TestVecDecimal {
951        #[serde(
952            serialize_with = "serialize_vec_decimal_as_str",
953            deserialize_with = "deserialize_vec_decimal_from_str"
954        )]
955        values: Vec<Decimal>,
956    }
957
958    #[rstest]
959    fn test_vec_decimal_roundtrip() {
960        let original = TestVecDecimal {
961            values: vec![dec!(1.5), dec!(2.25), dec!(100.001)],
962        };
963
964        let json = serde_json::to_string(&original).unwrap();
965        assert!(json.contains("[\"1.5\",\"2.25\",\"100.001\"]"));
966
967        let parsed: TestVecDecimal = serde_json::from_str(&json).unwrap();
968        assert_eq!(original.values, parsed.values);
969    }
970
971    #[rstest]
972    fn test_vec_decimal_empty() {
973        let original = TestVecDecimal { values: vec![] };
974
975        let json = serde_json::to_string(&original).unwrap();
976        let parsed: TestVecDecimal = serde_json::from_str(&json).unwrap();
977        assert_eq!(original.values, parsed.values);
978    }
979
980    #[derive(Deserialize)]
981    struct TestStringToU8 {
982        #[serde(deserialize_with = "deserialize_string_to_u8")]
983        value: u8,
984    }
985
986    #[rstest]
987    #[case(r#"{"value":"42"}"#, 42)]
988    #[case(r#"{"value":"0"}"#, 0)]
989    #[case(r#"{"value":"255"}"#, 255)]
990    #[case(r#"{"value":""}"#, 0)]
991    fn test_deserialize_string_to_u8(#[case] json: &str, #[case] expected: u8) {
992        let result: TestStringToU8 = serde_json::from_str(json).unwrap();
993        assert_eq!(result.value, expected);
994    }
995
996    #[rstest]
997    #[case(r#"{"value":"256"}"#)]
998    #[case(r#"{"value":"999"}"#)]
999    #[case(r#"{"value":"abc"}"#)]
1000    fn test_deserialize_string_to_u8_invalid(#[case] json: &str) {
1001        let result: Result<TestStringToU8, _> = serde_json::from_str(json);
1002        assert!(result.is_err());
1003    }
1004
1005    #[derive(Deserialize)]
1006    struct TestStringToU64 {
1007        #[serde(deserialize_with = "deserialize_string_to_u64")]
1008        value: u64,
1009    }
1010
1011    #[rstest]
1012    #[case(r#"{"value":"12345678901234"}"#, 12_345_678_901_234)]
1013    #[case(r#"{"value":"0"}"#, 0)]
1014    #[case(r#"{"value":"18446744073709551615"}"#, u64::MAX)]
1015    #[case(r#"{"value":""}"#, 0)]
1016    fn test_deserialize_string_to_u64(#[case] json: &str, #[case] expected: u64) {
1017        let result: TestStringToU64 = serde_json::from_str(json).unwrap();
1018        assert_eq!(result.value, expected);
1019    }
1020
1021    #[rstest]
1022    #[case(r#"{"value":"18446744073709551616"}"#)]
1023    #[case(r#"{"value":"abc"}"#)]
1024    #[case(r#"{"value":"-1"}"#)]
1025    fn test_deserialize_string_to_u64_invalid(#[case] json: &str) {
1026        let result: Result<TestStringToU64, _> = serde_json::from_str(json);
1027        assert!(result.is_err());
1028    }
1029
1030    #[derive(Deserialize)]
1031    struct TestOptionalStringToU64 {
1032        #[serde(deserialize_with = "deserialize_optional_string_to_u64")]
1033        value: Option<u64>,
1034    }
1035
1036    #[rstest]
1037    #[case(r#"{"value":"12345678901234"}"#, Some(12_345_678_901_234))]
1038    #[case(r#"{"value":"0"}"#, Some(0))]
1039    #[case(r#"{"value":""}"#, None)]
1040    #[case(r#"{"value":null}"#, None)]
1041    fn test_deserialize_optional_string_to_u64(#[case] json: &str, #[case] expected: Option<u64>) {
1042        let result: TestOptionalStringToU64 = serde_json::from_str(json).unwrap();
1043        assert_eq!(result.value, expected);
1044    }
1045
1046    #[rstest]
1047    #[case("123.45", dec!(123.45))]
1048    #[case("0", Decimal::ZERO)]
1049    #[case("0.0", Decimal::ZERO)]
1050    fn test_parse_decimal(#[case] input: &str, #[case] expected: Decimal) {
1051        let result = parse_decimal(input).unwrap();
1052        assert_eq!(result, expected);
1053    }
1054
1055    #[rstest]
1056    fn test_parse_decimal_invalid() {
1057        assert!(parse_decimal("invalid").is_err());
1058        assert!(parse_decimal("").is_err());
1059    }
1060
1061    #[rstest]
1062    #[case(&Some("123.45".to_string()), Some(dec!(123.45)))]
1063    #[case(&Some("0".to_string()), Some(Decimal::ZERO))]
1064    #[case(&Some(String::new()), None)]
1065    #[case(&None, None)]
1066    fn test_parse_optional_decimal(
1067        #[case] input: &Option<String>,
1068        #[case] expected: Option<Decimal>,
1069    ) {
1070        let result = parse_optional_decimal(input).unwrap();
1071        assert_eq!(result, expected);
1072    }
1073
1074    // Tests for flexible decimal deserializers (handles both string and number JSON values)
1075
1076    #[derive(Debug, Serialize, Deserialize, PartialEq)]
1077    struct TestFlexibleDecimal {
1078        #[serde(
1079            serialize_with = "serialize_decimal",
1080            deserialize_with = "deserialize_decimal"
1081        )]
1082        value: Decimal,
1083        #[serde(
1084            serialize_with = "serialize_optional_decimal",
1085            deserialize_with = "deserialize_optional_decimal"
1086        )]
1087        optional_value: Option<Decimal>,
1088    }
1089
1090    #[rstest]
1091    #[case(r#"{"value": 123.456, "optional_value": 789.012}"#, dec!(123.456), Some(dec!(789.012)))]
1092    #[case(r#"{"value": "123.456", "optional_value": "789.012"}"#, dec!(123.456), Some(dec!(789.012)))]
1093    #[case(r#"{"value": 100, "optional_value": null}"#, dec!(100), None)]
1094    #[case(r#"{"value": null, "optional_value": null}"#, Decimal::ZERO, None)]
1095    fn test_deserialize_flexible_decimal(
1096        #[case] json: &str,
1097        #[case] expected_value: Decimal,
1098        #[case] expected_optional: Option<Decimal>,
1099    ) {
1100        let result: TestFlexibleDecimal = serde_json::from_str(json).unwrap();
1101        assert_eq!(result.value, expected_value);
1102        assert_eq!(result.optional_value, expected_optional);
1103    }
1104
1105    #[rstest]
1106    fn test_flexible_decimal_roundtrip() {
1107        let original = TestFlexibleDecimal {
1108            value: dec!(123.456),
1109            optional_value: Some(dec!(789.012)),
1110        };
1111
1112        let json = serde_json::to_string(&original).unwrap();
1113        let deserialized: TestFlexibleDecimal = serde_json::from_str(&json).unwrap();
1114
1115        assert_eq!(original.value, deserialized.value);
1116        assert_eq!(original.optional_value, deserialized.optional_value);
1117    }
1118
1119    #[rstest]
1120    fn test_flexible_decimal_scientific_notation() {
1121        // Test that scientific notation from serde_json is handled correctly.
1122        // serde_json outputs very small numbers like 0.00000001 as "1e-8".
1123        // Note: JSON numbers are parsed as f64, so values are limited to ~15 significant digits.
1124        let json = r#"{"value": 0.00000001, "optional_value": 12345678.12345}"#;
1125        let parsed: TestFlexibleDecimal = serde_json::from_str(json).unwrap();
1126        assert_eq!(parsed.value, dec!(0.00000001));
1127        assert_eq!(parsed.optional_value, Some(dec!(12345678.12345)));
1128    }
1129
1130    #[rstest]
1131    fn test_flexible_decimal_empty_string_optional() {
1132        let json = r#"{"value": 100, "optional_value": ""}"#;
1133        let parsed: TestFlexibleDecimal = serde_json::from_str(json).unwrap();
1134        assert_eq!(parsed.value, dec!(100));
1135        assert_eq!(parsed.optional_value, None);
1136    }
1137
1138    // Additional tests for DecimalVisitor edge cases
1139
1140    #[derive(Debug, Deserialize)]
1141    struct TestDecimalOnly {
1142        #[serde(deserialize_with = "deserialize_decimal")]
1143        value: Decimal,
1144    }
1145
1146    #[rstest]
1147    #[case(r#"{"value": "1.5e-8"}"#, dec!(0.000000015))]
1148    #[case(r#"{"value": "1E10"}"#, dec!(10000000000))]
1149    #[case(r#"{"value": "-1.23e5"}"#, dec!(-123000))]
1150    fn test_deserialize_decimal_scientific_string(#[case] json: &str, #[case] expected: Decimal) {
1151        let result: TestDecimalOnly = serde_json::from_str(json).unwrap();
1152        assert_eq!(result.value, expected);
1153    }
1154
1155    #[rstest]
1156    #[case(r#"{"value": 9223372036854775807}"#, dec!(9223372036854775807))] // i64::MAX
1157    #[case(r#"{"value": -9223372036854775808}"#, dec!(-9223372036854775808))] // i64::MIN
1158    #[case(r#"{"value": 0}"#, Decimal::ZERO)]
1159    fn test_deserialize_decimal_large_integers(#[case] json: &str, #[case] expected: Decimal) {
1160        let result: TestDecimalOnly = serde_json::from_str(json).unwrap();
1161        assert_eq!(result.value, expected);
1162    }
1163
1164    #[rstest]
1165    #[case(r#"{"value": "-123.456789"}"#, dec!(-123.456789))]
1166    #[case(r#"{"value": -999.99}"#, dec!(-999.99))]
1167    fn test_deserialize_decimal_negative(#[case] json: &str, #[case] expected: Decimal) {
1168        let result: TestDecimalOnly = serde_json::from_str(json).unwrap();
1169        assert_eq!(result.value, expected);
1170    }
1171
1172    #[rstest]
1173    #[case(r#"{"value": "123456789.123456789012345678"}"#)] // High precision string
1174    fn test_deserialize_decimal_high_precision(#[case] json: &str) {
1175        let result: TestDecimalOnly = serde_json::from_str(json).unwrap();
1176        assert_eq!(result.value, dec!(123456789.123456789012345678));
1177    }
1178
1179    #[derive(Debug, Deserialize)]
1180    struct TestOptionalDecimalOnly {
1181        #[serde(deserialize_with = "deserialize_optional_decimal")]
1182        value: Option<Decimal>,
1183    }
1184
1185    #[rstest]
1186    #[case(r#"{"value": "1.5e-8"}"#, Some(dec!(0.000000015)))]
1187    #[case(r#"{"value": null}"#, None)]
1188    #[case(r#"{"value": ""}"#, None)]
1189    #[case(r#"{"value": 42}"#, Some(dec!(42)))]
1190    #[case(r#"{"value": -100.5}"#, Some(dec!(-100.5)))]
1191    fn test_deserialize_optional_decimal_various(
1192        #[case] json: &str,
1193        #[case] expected: Option<Decimal>,
1194    ) {
1195        let result: TestOptionalDecimalOnly = serde_json::from_str(json).unwrap();
1196        assert_eq!(result.value, expected);
1197    }
1198}