use crate::DataValue;
use pyo3::{
exceptions::PyTypeError,
prelude::*,
types::{PyBytes, PyBytesMethods, PyDict, PyList, PyNone},
};
/// Intermediate representation of a Python value on its way to a
/// [`DataValue`].
///
/// The derived `FromPyObject` implementation tries the variants in declaration
/// order and keeps the first one that extracts successfully, so the order
/// below is significant: smaller integer types are listed before larger ones
/// and `Double` is listed before `Float`. The `annotation` strings only label
/// a variant in the error produced when no variant matches.
#[derive(FromPyObject, Debug)]
pub enum PyDataValue<'py> {
    /// Text value.
    #[pyo3(transparent, annotation = "str")]
    String(String),
    /// Byte string, kept as a handle to the Python `bytes` object.
    #[pyo3(transparent, annotation = "bytes")]
    Bytes(Bound<'py, PyBytes>),
    /// Boolean value.
    #[pyo3(transparent, annotation = "bool")]
    Bool(bool),
    /// Integer that fits into `i32`.
    #[pyo3(transparent, annotation = "i32")]
    Int(i32),
    /// Integer that fits into `u32`.
    #[pyo3(transparent, annotation = "u32")]
    UInt(u32),
    /// Integer that fits into `i64`.
    #[pyo3(transparent, annotation = "i64")]
    BigInt(i64),
    /// Integer that fits into `u64`.
    #[pyo3(transparent, annotation = "u64")]
    UBigInt(u64),
    /// Integer that fits into `u128`.
    #[pyo3(transparent, annotation = "u128")]
    UBigBigInt(u128),
    /// Integer that fits into `i128`.
    #[pyo3(transparent, annotation = "i128")]
    BigBigInt(i128),
    /// Double precision floating point value.
    #[pyo3(transparent, annotation = "f64")]
    Double(f64),
    /// Single precision floating point value.
    // NOTE(review): `Double` is tried first, so this variant is presumably
    // never selected during extraction — confirm before relying on it.
    #[pyo3(transparent, annotation = "f32")]
    Float(f32),
    /// Mapping with string keys and nested values.
    #[pyo3(transparent, annotation = "map")]
    Map(std::collections::HashMap<String, PyDataValue<'py>>),
    /// Sequence of nested values.
    #[pyo3(transparent, annotation = "vec")]
    Vec(Vec<PyDataValue<'py>>),
    /// Python `None`.
    #[pyo3(transparent, annotation = "none")]
    Null(Bound<'py, PyNone>),
}
/// Converts a [`DataValue`] into a Python object.
///
/// Scalar variants are delegated to the `IntoPyObject` implementation of the
/// wrapped Rust value. `Vec` is turned into a Python `list`, `Map` into a
/// `dict` keyed by the string form of each key, and `Null` into `None`.
/// Nested values are converted recursively.
///
/// # Errors
/// The first element that fails to convert, or that cannot be inserted into
/// its container, aborts the conversion and its error is returned.
impl<'a> IntoPyObject<'a> for DataValue {
    type Target = PyAny;
    type Output = Bound<'a, PyAny>;
    type Error = PyErr;

    fn into_pyobject(self, py: Python<'a>) -> Result<Self::Output, Self::Error> {
        match self {
            DataValue::U8(b) => Ok(b.into_pyobject(py)?.into_any()),
            DataValue::U32(b) => Ok(b.into_pyobject(py)?.into_any()),
            // Unlike the other scalars, the `bool` conversion result is only
            // viewed through `as_any()`, so it is cloned to get an owned handle.
            DataValue::Bool(b) => Ok(b.into_pyobject(py)?.as_any().clone()),
            DataValue::I32(b) => Ok(b.into_pyobject(py)?.into_any()),
            DataValue::U64(b) => Ok(b.into_pyobject(py)?.into_any()),
            DataValue::I64(b) => Ok(b.into_pyobject(py)?.into_any()),
            DataValue::F64(b) => Ok(b.into_pyobject(py)?.into_any()),
            DataValue::F32(b) => Ok(b.into_pyobject(py)?.into_any()),
            DataValue::String(b) => Ok(b.as_str().into_pyobject(py)?.into_any()),
            DataValue::Bytes(b) => Ok(b.into_pyobject(py)?.into_any()),
            DataValue::Vec(items) => {
                let list = PyList::empty(py);
                for item in items {
                    // The converted element is already an owned handle, so it
                    // is passed on directly instead of being cloned first.
                    list.append(item.into_pyobject(py)?)?;
                }
                Ok(list.into_any())
            }
            DataValue::Map(entries) => {
                let dict = PyDict::new(py);
                for (key, item) in entries {
                    dict.set_item(key.as_str(), item.into_pyobject(py)?)?;
                }
                Ok(dict.into_any())
            }
            DataValue::Null => Ok(py.None().into_pyobject(py)?.into_any()),
            DataValue::U128(b) => Ok(b.into_pyobject(py)?.into_any()),
            DataValue::I128(b) => Ok(b.into_pyobject(py)?.into_any()),
            DataValue::EnumNumber(b) => Ok(b.into_pyobject(py)?.into_any()),
        }
    }
}
/// Builds an owned [`DataValue`] from an extracted [`PyDataValue`].
///
/// Scalar variants are handed to the matching `From` implementation of
/// `DataValue`, the contents of a Python `bytes` object are copied into an
/// owned buffer, and maps and sequences are converted element by element.
impl<'py> From<PyDataValue<'py>> for DataValue {
    fn from(value: PyDataValue<'py>) -> Self {
        // Arms follow the declaration order of `PyDataValue`.
        match value {
            PyDataValue::String(text) => Self::from(text),
            PyDataValue::Bytes(raw) => Self::Bytes(raw.as_bytes().to_vec()),
            PyDataValue::Bool(flag) => Self::from(flag),
            PyDataValue::Int(number) => Self::from(number),
            PyDataValue::UInt(number) => Self::from(number),
            PyDataValue::BigInt(number) => Self::from(number),
            PyDataValue::UBigInt(number) => Self::from(number),
            PyDataValue::UBigBigInt(number) => Self::from(number),
            PyDataValue::BigBigInt(number) => Self::from(number),
            PyDataValue::Double(number) => Self::from(number),
            PyDataValue::Float(number) => Self::from(number),
            PyDataValue::Map(entries) => Self::Map(
                entries
                    .into_iter()
                    .map(|(key, item)| (key.into(), Self::from(item)))
                    .collect(),
            ),
            PyDataValue::Vec(items) => Self::Vec(items.into_iter().map(Self::from).collect()),
            PyDataValue::Null(_) => Self::Null,
        }
    }
}
/// Extracts a [`DataValue`] from an arbitrary Python object by going through
/// [`PyDataValue`].
///
/// # Errors
/// Returns a `TypeError` naming the Python class of `object` when the object
/// matches none of the [`PyDataValue`] variants. The original extraction error
/// is attached to that `TypeError` as its cause.
impl FromPyObject<'_> for DataValue {
    fn extract_bound(object: &Bound<'_, PyAny>) -> Result<Self, pyo3::PyErr> {
        PyDataValue::extract_bound(object)
            .map(Self::from)
            .map_err(|error| {
                // Unwrap the class name so the message shows the name itself
                // rather than the debug form of a `Result`.
                let class_name = object
                    .get_type()
                    .name()
                    .map_or_else(|_| String::from("<unknown>"), |name| name.to_string());
                let type_error = PyTypeError::new_err(format!(
                    "Cannot create DataValue from class {class_name}"
                ));
                // `object` already carries the interpreter token, so there is
                // no need to acquire it a second time.
                type_error.set_cause(object.py(), Some(error));
                type_error
            })
    }
}
#[cfg(test)]
mod test {
    use super::*;
    use pyo3::ffi::c_str;
    use rstest::*;

    /// Passes `value` to a Python identity function, extracts the returned
    /// object as a [`PyDataValue`], converts it back to a [`DataValue`] and
    /// asserts that the result equals `expected`.
    fn check_py_rust_conversion(value: DataValue, expected: DataValue) {
        let extracted = Python::with_gil(|py| -> DataValue {
            // The snippet is spelled with explicit newlines and spaces so the
            // function body stays indented however this file is formatted.
            let module = PyModule::from_code(
                py,
                c_str!("def example(df):\n    return df\n"),
                c_str!(""),
                c_str!(""),
            )
            .expect("BUG: Cannot compile testing function");
            let fun: Py<PyAny> = module
                .getattr("example")
                .expect("BUG: Cannot get testing function")
                .into();
            let result = fun
                .call1(py, (value,))
                .expect("BUG: Cannot call test function");
            let extracted = result
                .extract::<PyDataValue<'_>>(py)
                .expect("BUG: Cannot extract PyDataValue from python object");
            DataValue::from(extracted)
        });
        assert_eq!(extracted, expected);
    }

    /// Values that are expected to come back unchanged from Python.
    #[rstest]
    #[case(DataValue::Bool(true))]
    #[case(DataValue::F64(f64::MAX))]
    #[case(DataValue::U32(u32::MAX))]
    #[case(DataValue::I32(i32::MIN))]
    #[case(DataValue::U64(u64::MAX))]
    #[case(DataValue::I64(i64::MIN))]
    #[case(DataValue::U128(u128::MAX))]
    #[case(DataValue::I128(i128::MIN))]
    #[case(DataValue::Bytes(b"10".to_vec()))]
    #[case(DataValue::String("10".into()))]
    #[case(DataValue::Vec(vec![DataValue::I32(10), DataValue::I32(i32::MIN)]))]
    #[case(DataValue::Map(crate::stdhashmap! {
        "a" => 10,
        "b" => i32::MIN,
    }))]
    #[case(DataValue::Null)]
    fn test_data_value_conversion(#[case] value: DataValue) {
        check_py_rust_conversion(value.clone(), value)
    }

    /// Values that come back as a different variant than the one sent: a `U8`
    /// returns as `I32` and an `F32` returns as `F64`.
    #[rstest]
    #[case(DataValue::U8(1), DataValue::I32(1))]
    #[case(DataValue::F32(f32::MAX), DataValue::F64(f32::MAX as f64))]
    fn test_special_cases(#[case] value: DataValue, #[case] expected: DataValue) {
        check_py_rust_conversion(value, expected)
    }
}