Skip to main content

graphrecords_python/graphrecord/
value.rs

1use super::{Lut, traits::DeepFrom};
2use crate::{conversion_lut::ConversionLut, graphrecord::errors::PyGraphRecordError};
3use chrono::{NaiveDateTime, TimeDelta};
4use graphrecords_core::{errors::GraphRecordError, graphrecord::GraphRecordValue};
5use pyo3::{
6    Borrowed, Bound, FromPyObject, IntoPyObject, IntoPyObjectExt, PyAny, PyErr, PyResult, Python,
7    types::{PyAnyMethods, PyBool, PyDateTime, PyDelta, PyFloat, PyInt, PyString},
8};
9use std::ops::Deref;
10
11#[repr(transparent)]
12#[derive(Clone, Debug)]
13pub struct PyGraphRecordValue(GraphRecordValue);
14
15impl From<GraphRecordValue> for PyGraphRecordValue {
16    fn from(value: GraphRecordValue) -> Self {
17        Self(value)
18    }
19}
20
21impl From<PyGraphRecordValue> for GraphRecordValue {
22    fn from(value: PyGraphRecordValue) -> Self {
23        value.0
24    }
25}
26
27impl DeepFrom<PyGraphRecordValue> for GraphRecordValue {
28    fn deep_from(value: PyGraphRecordValue) -> Self {
29        value.into()
30    }
31}
32
33impl DeepFrom<GraphRecordValue> for PyGraphRecordValue {
34    fn deep_from(value: GraphRecordValue) -> Self {
35        value.into()
36    }
37}
38
39impl Deref for PyGraphRecordValue {
40    type Target = GraphRecordValue;
41
42    fn deref(&self) -> &Self::Target {
43        &self.0
44    }
45}
46
/// Process-wide cache used by `convert_pyobject_to_graphrecordvalue`: maps the address of a
/// Python type object (`ob.get_type_ptr() as usize`) to the converter function selected for
/// that type.
// NOTE(review): entries appear to be keyed purely by type-object address. If a heap type is
// deallocated and an unrelated type later reuses its address, a stale converter could be
// chosen — confirm whether `ConversionLut` ever evicts or validates entries.
static GRAPHRECORDVALUE_CONVERSION_LUT: Lut<GraphRecordValue> = ConversionLut::new();
48
49#[allow(clippy::unnecessary_wraps)]
50pub(crate) fn convert_pyobject_to_graphrecordvalue(
51    ob: &Bound<'_, PyAny>,
52) -> PyResult<GraphRecordValue> {
53    fn convert_string(ob: &Bound<'_, PyAny>) -> PyResult<GraphRecordValue> {
54        Ok(GraphRecordValue::String(
55            ob.extract::<String>().expect("Extraction must succeed"),
56        ))
57    }
58
59    fn convert_int(ob: &Bound<'_, PyAny>) -> PyResult<GraphRecordValue> {
60        Ok(GraphRecordValue::Int(
61            ob.extract::<i64>().expect("Extraction must succeed"),
62        ))
63    }
64
65    fn convert_float(ob: &Bound<'_, PyAny>) -> PyResult<GraphRecordValue> {
66        Ok(GraphRecordValue::Float(
67            ob.extract::<f64>().expect("Extraction must succeed"),
68        ))
69    }
70
71    fn convert_bool(ob: &Bound<'_, PyAny>) -> PyResult<GraphRecordValue> {
72        Ok(GraphRecordValue::Bool(
73            ob.extract::<bool>().expect("Extraction must succeed"),
74        ))
75    }
76
77    fn convert_datetime(ob: &Bound<'_, PyAny>) -> PyResult<GraphRecordValue> {
78        Ok(GraphRecordValue::DateTime(
79            ob.extract::<NaiveDateTime>()
80                .expect("Extraction must succeed"),
81        ))
82    }
83
84    fn convert_duration(ob: &Bound<'_, PyAny>) -> PyResult<GraphRecordValue> {
85        Ok(GraphRecordValue::Duration(
86            ob.extract::<TimeDelta>().expect("Extraction must succeed"),
87        ))
88    }
89
90    const fn convert_null(_ob: &Bound<'_, PyAny>) -> PyResult<GraphRecordValue> {
91        Ok(GraphRecordValue::Null)
92    }
93
94    fn throw_error(ob: &Bound<'_, PyAny>) -> PyResult<GraphRecordValue> {
95        Err(
96            PyGraphRecordError::from(GraphRecordError::ConversionError(format!(
97                "Failed to convert {ob} into GraphRecordValue",
98            )))
99            .into(),
100        )
101    }
102
103    let type_pointer = ob.get_type_ptr() as usize;
104
105    let conversion_function = GRAPHRECORDVALUE_CONVERSION_LUT.get_or_insert(type_pointer, || {
106        if ob.is_instance_of::<PyString>() {
107            convert_string
108        } else if ob.is_instance_of::<PyBool>() {
109            convert_bool
110        } else if ob.is_instance_of::<PyInt>() {
111            convert_int
112        } else if ob.is_instance_of::<PyFloat>() {
113            convert_float
114        } else if ob.is_instance_of::<PyDateTime>() {
115            convert_datetime
116        } else if ob.is_instance_of::<PyDelta>() {
117            convert_duration
118        } else if ob.is_none() {
119            convert_null
120        } else {
121            throw_error
122        }
123    });
124
125    conversion_function(ob)
126}
127
128impl FromPyObject<'_, '_> for PyGraphRecordValue {
129    type Error = PyErr;
130
131    fn extract(ob: Borrowed<'_, '_, PyAny>) -> PyResult<Self> {
132        convert_pyobject_to_graphrecordvalue(&ob).map(Self::from)
133    }
134}
135
136impl<'py> IntoPyObject<'py> for PyGraphRecordValue {
137    type Target = PyAny;
138    type Output = Bound<'py, Self::Target>;
139    type Error = PyErr;
140
141    fn into_pyobject(self, py: Python<'py>) -> Result<Self::Output, Self::Error> {
142        match self.0 {
143            GraphRecordValue::String(value) => value.into_bound_py_any(py),
144            GraphRecordValue::Int(value) => value.into_bound_py_any(py),
145            GraphRecordValue::Float(value) => value.into_bound_py_any(py),
146            GraphRecordValue::Bool(value) => value.into_bound_py_any(py),
147            GraphRecordValue::DateTime(value) => value.into_bound_py_any(py),
148            GraphRecordValue::Duration(value) => value.into_bound_py_any(py),
149            GraphRecordValue::Null => py.None().into_bound_py_any(py),
150        }
151    }
152}