ftl_jiter/
py_lossless_float.rs1use pyo3::exceptions::{PyTypeError, PyValueError};
2use pyo3::prelude::*;
3use pyo3::sync::GILOnceCell;
4use pyo3::types::PyType;
5
6use crate::Jiter;
7
/// Float handling mode, selectable from Python by name.
///
/// Parsed (see the `FromPyObject` impl) from the strings `"float"`,
/// `"decimal"` and `"lossless-float"`. The default is [`FloatMode::Float`].
// NOTE(review): how each mode is applied lives outside this file — confirm
// against the consumer before documenting conversion semantics here.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub enum FloatMode {
    /// Selected by the string `"float"`; this is the default mode.
    #[default]
    Float,
    /// Selected by the string `"decimal"`.
    Decimal,
    /// Selected by the string `"lossless-float"`.
    LosslessFloat,
}
20
21const FLOAT_ERROR: &str = "Invalid float mode, should be `'float'`, `'decimal'` or `'lossless-float'`";
22
23impl<'py> FromPyObject<'py> for FloatMode {
24 fn extract_bound(ob: &Bound<'py, PyAny>) -> PyResult<Self> {
25 if let Ok(str_mode) = ob.extract::<&str>() {
26 match str_mode {
27 "float" => Ok(Self::Float),
28 "decimal" => Ok(Self::Decimal),
29 "lossless-float" => Ok(Self::LosslessFloat),
30 _ => Err(PyValueError::new_err(FLOAT_ERROR)),
31 }
32 } else {
33 Err(PyTypeError::new_err(FLOAT_ERROR))
34 }
35 }
36}
37
38#[derive(Debug, Clone)]
40#[pyclass(module = "jiter")]
41pub struct LosslessFloat(Vec<u8>);
42
43impl LosslessFloat {
44 pub fn new_unchecked(raw: Vec<u8>) -> Self {
45 Self(raw)
46 }
47}
48
49#[pymethods]
50impl LosslessFloat {
51 #[new]
52 fn new(raw: Vec<u8>) -> PyResult<Self> {
53 let s = Self(raw);
54 s.__float__()?;
56 Ok(s)
57 }
58
59 fn as_decimal<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyAny>> {
60 let decimal = get_decimal_type(py)?;
61 let float_str = self.__str__()?;
62 decimal.call1((float_str,))
63 }
64
65 fn __float__(&self) -> PyResult<f64> {
66 let bytes = &self.0;
67 let mut jiter = Jiter::new(bytes).with_allow_inf_nan();
68 let f = jiter
69 .next_float()
70 .map_err(|e| PyValueError::new_err(e.description(&jiter)))?;
71 jiter
72 .finish()
73 .map_err(|e| PyValueError::new_err(e.description(&jiter)))?;
74 Ok(f)
75 }
76
77 fn __bytes__(&self) -> &[u8] {
78 &self.0
79 }
80
81 fn __str__(&self) -> PyResult<&str> {
82 std::str::from_utf8(&self.0).map_err(|_| PyValueError::new_err("Invalid UTF-8"))
83 }
84
85 fn __repr__(&self) -> PyResult<String> {
86 self.__str__().map(|s| format!("LosslessFloat({s})"))
87 }
88}
89
90static DECIMAL_TYPE: GILOnceCell<Py<PyType>> = GILOnceCell::new();
91
92pub fn get_decimal_type(py: Python) -> PyResult<&Bound<'_, PyType>> {
93 DECIMAL_TYPE
94 .get_or_try_init(py, || py.import_bound("decimal")?.getattr("Decimal")?.extract())
95 .map(|t| t.bind(py))
96}