use pyo3::exceptions::{PyTypeError, PyValueError};
use pyo3::prelude::*;
use pyo3::py_format;
use pyo3::sync::PyOnceLock;
use pyo3::types::{PyString, PyType};
use crate::Jiter;
/// Float handling mode, extracted from a Python string by its `FromPyObject` impl.
///
/// `Float` is the `Default` variant.
#[derive(Debug, Clone, Copy, Default)]
pub enum FloatMode {
/// Selected by the Python string `'float'`.
#[default]
Float,
/// Selected by the Python string `'decimal'`.
Decimal,
/// Selected by the Python string `'lossless-float'`.
LosslessFloat,
}
/// Message used for both the `ValueError` (unrecognised string) and the `TypeError`
/// (not extractable as a string) raised while extracting a `FloatMode`.
const FLOAT_ERROR: &str = "Invalid float mode, should be `'float'`, `'decimal'` or `'lossless-float'`";
/// Extraction of a [`FloatMode`] from a Python object.
impl<'py> FromPyObject<'_, 'py> for FloatMode {
    type Error = PyErr;

    /// Maps the Python strings `'float'`, `'decimal'` and `'lossless-float'` to the
    /// matching variant.
    ///
    /// # Errors
    ///
    /// * `ValueError` when `ob` is a string other than the three accepted names.
    /// * `TypeError` when `ob` cannot be extracted as a string at all.
    ///
    /// Both errors carry the `FLOAT_ERROR` message.
    fn extract(ob: Borrowed<'_, 'py, PyAny>) -> PyResult<Self> {
        match ob.extract::<&str>() {
            Ok("float") => Ok(Self::Float),
            Ok("decimal") => Ok(Self::Decimal),
            Ok("lossless-float") => Ok(Self::LosslessFloat),
            // A string, but not one of the known mode names.
            Ok(_) => Err(PyValueError::new_err(FLOAT_ERROR)),
            // Not a string: the underlying extraction error is deliberately discarded.
            Err(_) => Err(PyTypeError::new_err(FLOAT_ERROR)),
        }
    }
}
// Python-exposed class (declared in module `jiter`) holding the raw bytes of a float.
// The Python constructor checks that the bytes parse as a float; `new_unchecked` does not.
// NOTE: plain `//` comments are used here so the Python-visible class docstring is unchanged.
#[derive(Debug, Clone)]
#[pyclass(module = "jiter", skip_from_py_object)]
pub struct LosslessFloat(Vec<u8>);
impl LosslessFloat {
/// Wraps `raw` without running the float validation that the Python-facing `new` performs.
///
/// Nothing here checks that `raw` is a parseable float or valid UTF-8; presumably callers
/// have already validated the bytes — TODO confirm against call sites.
pub fn new_unchecked(raw: Vec<u8>) -> Self {
Self(raw)
}
}
// Python-facing methods of `LosslessFloat`.
// NOTE: plain `//` comments are used in this block so that no Python-visible
// docstrings are added or changed.
#[pymethods]
impl LosslessFloat {
    // Python constructor: accepts `raw` only if it parses as a float.
    // Errors: propagates the `ValueError` produced by `__float__`.
    #[new]
    fn new(raw: Vec<u8>) -> PyResult<Self> {
        let candidate = Self(raw);
        // The parsed value is thrown away; the call exists purely as validation.
        candidate.__float__().map(|_| candidate)
    }

    // Builds a Python `decimal.Decimal` by calling the type with this value's text.
    // Errors: failure to obtain the `Decimal` type, `ValueError("Invalid UTF-8")`
    // from `__str__`, or whatever the `Decimal(...)` call itself raises.
    fn as_decimal<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyAny>> {
        // Resolve the type first so error precedence matches: type lookup, then UTF-8.
        let decimal_cls = get_decimal_type(py)?;
        self.__str__()
            .and_then(|text| decimal_cls.call1((text,)))
    }

    // Parses the stored bytes as an `f64`, with inf/NaN permitted by the parser.
    // Errors: `ValueError` carrying the parser's description if the float cannot be
    // read or if `finish()` reports a problem after it.
    fn __float__(&self) -> PyResult<f64> {
        let mut parser = Jiter::new(self.0.as_slice()).with_allow_inf_nan();

        let parsed = parser.next_float();
        let value = parsed.map_err(|err| PyValueError::new_err(err.description(&parser)))?;

        // `finish()` must also succeed before the value is accepted.
        let finished = parser.finish();
        finished.map_err(|err| PyValueError::new_err(err.description(&parser)))?;

        Ok(value)
    }

    // Returns the raw bytes exactly as stored.
    fn __bytes__(&self) -> &[u8] {
        self.0.as_slice()
    }

    // Returns the stored bytes viewed as UTF-8 text.
    // Errors: `ValueError("Invalid UTF-8")` if the bytes are not valid UTF-8.
    fn __str__(&self) -> PyResult<&str> {
        let raw = self.0.as_slice();
        std::str::from_utf8(raw).map_err(|_utf8_err| PyValueError::new_err("Invalid UTF-8"))
    }

    // Formats as `LosslessFloat(<text>)`, where `<text>` comes from `__str__`.
    // Errors: propagates the `ValueError` from `__str__`.
    fn __repr__<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyString>> {
        let s = self.__str__()?;
        py_format!(py, "LosslessFloat({s})")
    }
}
/// Cache slot for the Python type object that `get_decimal_type` imports
/// (`Decimal` from the `decimal` module).
static DECIMAL_TYPE: PyOnceLock<Py<PyType>> = PyOnceLock::new();
/// Returns the `Decimal` type from Python's `decimal` module, obtained through
/// `DECIMAL_TYPE.import`.
///
/// # Errors
///
/// Returns whatever error `PyOnceLock::import` produces — presumably when the module
/// or attribute cannot be imported, or is not a type; confirm against the PyO3 docs.
pub fn get_decimal_type(py: Python<'_>) -> PyResult<&Bound<'_, PyType>> {
DECIMAL_TYPE.import(py, "decimal", "Decimal")
}