Skip to main content

graphrecords_python/graphrecord/
connector.rs

1use crate::prelude::PyGraphRecord;
2use graphrecords_core::{
3    GraphRecord,
4    errors::{GraphRecordError, GraphRecordResult},
5    graphrecord::connector::{Connector, ExportConnector, IngestConnector},
6};
7use pyo3::{Py, PyAny, Python, types::PyAnyMethods};
8use serde::{Deserialize, Deserializer, Serialize, Serializer};
9
10#[derive(Debug)]
11pub struct PyConnector(Py<PyAny>);
12
13impl PyConnector {
14    pub const fn new(connector: Py<PyAny>) -> Self {
15        Self(connector)
16    }
17}
18
19impl Clone for PyConnector {
20    fn clone(&self) -> Self {
21        Python::attach(|py| Self(self.0.clone_ref(py)))
22    }
23}
24
25impl Serialize for PyConnector {
26    fn serialize<S: Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
27        Python::attach(|py| {
28            let cloudpickle = py
29                .import("cloudpickle")
30                .map_err(serde::ser::Error::custom)?;
31
32            let bytes: Vec<u8> = cloudpickle
33                .call_method1("dumps", (&self.0,))
34                .map_err(serde::ser::Error::custom)?
35                .extract()
36                .map_err(serde::ser::Error::custom)?;
37
38            serializer.serialize_bytes(&bytes)
39        })
40    }
41}
42
43impl<'de> Deserialize<'de> for PyConnector {
44    fn deserialize<D: Deserializer<'de>>(deserializer: D) -> Result<Self, D::Error> {
45        let bytes: Vec<u8> = Deserialize::deserialize(deserializer)?;
46
47        Python::attach(|py| {
48            let cloudpickle = py.import("cloudpickle").map_err(serde::de::Error::custom)?;
49
50            let obj: Py<PyAny> = cloudpickle
51                .call_method1("loads", (bytes.as_slice(),))
52                .map_err(serde::de::Error::custom)?
53                .into();
54
55            Ok(Self(obj))
56        })
57    }
58}
59
60impl Connector for PyConnector {
61    fn initialize(&self, graphrecord: &mut GraphRecord) -> GraphRecordResult<()> {
62        Python::attach(|py| {
63            PyGraphRecord::scope_mut(py, graphrecord, |py, graphrecord| {
64                self.0
65                    .call_method1(py, "initialize", (graphrecord,))
66                    .map_err(|err| GraphRecordError::ConversionError(format!("{err}")))?;
67
68                Ok(())
69            })
70        })
71    }
72
73    fn disconnect(&self, graphrecord: &mut GraphRecord) -> GraphRecordResult<()> {
74        Python::attach(|py| {
75            PyGraphRecord::scope_mut(py, graphrecord, |py, graphrecord| {
76                self.0
77                    .call_method1(py, "disconnect", (graphrecord,))
78                    .map_err(|err| GraphRecordError::ConversionError(format!("{err}")))?;
79
80                Ok(())
81            })
82        })
83    }
84}
85
86impl IngestConnector for PyConnector {
87    type DataSet = Py<PyAny>;
88
89    fn ingest(&self, graphrecord: &mut GraphRecord, data: Self::DataSet) -> GraphRecordResult<()> {
90        Python::attach(|py| {
91            PyGraphRecord::scope_mut(py, graphrecord, |py, graphrecord| {
92                self.0
93                    .call_method1(py, "ingest", (graphrecord, data))
94                    .map_err(|err| GraphRecordError::ConversionError(format!("{err}")))?;
95
96                Ok(())
97            })
98        })
99    }
100}
101
102impl ExportConnector for PyConnector {
103    type DataSet = Py<PyAny>;
104
105    fn export(&self, graphrecord: &GraphRecord) -> GraphRecordResult<Self::DataSet> {
106        Python::attach(|py| {
107            PyGraphRecord::scope(py, graphrecord, |py, graphrecord| {
108                let data = self
109                    .0
110                    .call_method1(py, "export", (graphrecord,))
111                    .map_err(|err| GraphRecordError::ConversionError(format!("{err}")))?;
112
113                Ok(data)
114            })
115        })
116    }
117}