ftl_jiter/
py_lossless_float.rs

1use pyo3::exceptions::{PyTypeError, PyValueError};
2use pyo3::prelude::*;
3use pyo3::sync::GILOnceCell;
4use pyo3::types::PyType;
5
6use crate::Jiter;
7
/// Float handling mode, parsed from Python via its `FromPyObject` impl from the strings
/// `"float"`, `"decimal"` and `"lossless-float"`.
///
/// Defaults to [`FloatMode::Float`].
// NOTE(review): presumably selects whether JSON floats become Python `float`,
// `decimal.Decimal`, or the raw-bytes `LosslessFloat` wrapper — confirm at call sites.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub enum FloatMode {
    /// Selected by the string `"float"`; this is the default mode.
    #[default]
    Float,
    /// Selected by the string `"decimal"`.
    Decimal,
    /// Selected by the string `"lossless-float"`.
    LosslessFloat,
}
20
21const FLOAT_ERROR: &str = "Invalid float mode, should be `'float'`, `'decimal'` or `'lossless-float'`";
22
23impl<'py> FromPyObject<'py> for FloatMode {
24    fn extract_bound(ob: &Bound<'py, PyAny>) -> PyResult<Self> {
25        if let Ok(str_mode) = ob.extract::<&str>() {
26            match str_mode {
27                "float" => Ok(Self::Float),
28                "decimal" => Ok(Self::Decimal),
29                "lossless-float" => Ok(Self::LosslessFloat),
30                _ => Err(PyValueError::new_err(FLOAT_ERROR)),
31            }
32        } else {
33            Err(PyTypeError::new_err(FLOAT_ERROR))
34        }
35    }
36}
37
38/// Represents a float from JSON, by holding the underlying bytes representing a float from JSON.
39#[derive(Debug, Clone)]
40#[pyclass(module = "jiter")]
41pub struct LosslessFloat(Vec<u8>);
42
43impl LosslessFloat {
44    pub fn new_unchecked(raw: Vec<u8>) -> Self {
45        Self(raw)
46    }
47}
48
49#[pymethods]
50impl LosslessFloat {
51    #[new]
52    fn new(raw: Vec<u8>) -> PyResult<Self> {
53        let s = Self(raw);
54        // check the string is valid by calling `as_float`
55        s.__float__()?;
56        Ok(s)
57    }
58
59    fn as_decimal<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyAny>> {
60        let decimal = get_decimal_type(py)?;
61        let float_str = self.__str__()?;
62        decimal.call1((float_str,))
63    }
64
65    fn __float__(&self) -> PyResult<f64> {
66        let bytes = &self.0;
67        let mut jiter = Jiter::new(bytes).with_allow_inf_nan();
68        let f = jiter
69            .next_float()
70            .map_err(|e| PyValueError::new_err(e.description(&jiter)))?;
71        jiter
72            .finish()
73            .map_err(|e| PyValueError::new_err(e.description(&jiter)))?;
74        Ok(f)
75    }
76
77    fn __bytes__(&self) -> &[u8] {
78        &self.0
79    }
80
81    fn __str__(&self) -> PyResult<&str> {
82        std::str::from_utf8(&self.0).map_err(|_| PyValueError::new_err("Invalid UTF-8"))
83    }
84
85    fn __repr__(&self) -> PyResult<String> {
86        self.__str__().map(|s| format!("LosslessFloat({s})"))
87    }
88}
89
90static DECIMAL_TYPE: GILOnceCell<Py<PyType>> = GILOnceCell::new();
91
92pub fn get_decimal_type(py: Python) -> PyResult<&Bound<'_, PyType>> {
93    DECIMAL_TYPE
94        .get_or_try_init(py, || py.import_bound("decimal")?.getattr("Decimal")?.extract())
95        .map(|t| t.bind(py))
96}