dora_operator_api_python/
lib.rs

1use std::{
2    collections::{BTreeMap, HashMap},
3    sync::{Arc, Mutex},
4};
5
6use arrow::pyarrow::ToPyArrow;
7use dora_node_api::{
8    merged::{MergeExternalSend, MergedEvent},
9    DoraNode, Event, EventStream, Metadata, MetadataParameters, Parameter,
10};
11use eyre::{Context, Result};
12use futures::{Stream, StreamExt};
13use futures_concurrency::stream::Merge as _;
14use pyo3::{
15    prelude::*,
16    types::{IntoPyDict, PyBool, PyDict, PyFloat, PyInt, PyList, PyString, PyTuple},
17};
18
/// Dora Event
///
/// An event handed to Python code. It is either an event coming from the dora
/// runtime (`MergedEvent::Dora`) or an external Python object that was merged
/// into the event stream (`MergedEvent::External`).
pub struct PyEvent {
    /// The wrapped event.
    pub event: MergedEvent<PyObject>,
    /// Optional handle that keeps the dora node alive. When present,
    /// `to_py_dict` copies it into the resulting dict under `"_cleanup"`.
    pub _cleanup: Option<NodeCleanupHandle>,
}
24
/// Keeps the dora node alive until all event objects have been dropped.
///
/// Cloning is cheap. All clones share one `Arc` around the node and
/// event-stream cleanup handles, so both values stay alive at least until the
/// last clone is dropped.
#[derive(Clone)]
#[pyclass]
pub struct NodeCleanupHandle {
    pub _handles: Arc<(CleanupHandle<DoraNode>, CleanupHandle<EventStream>)>,
}
31
/// Owned type with delayed cleanup (using the `handle` method).
///
/// The value is stored in an `Arc<Mutex<T>>`. Every `CleanupHandle` obtained
/// through `handle` shares that allocation, so the value is only dropped once
/// this owner and all of its handles are gone.
pub struct DelayedCleanup<T>(Arc<Mutex<T>>);
34
35impl<T> DelayedCleanup<T> {
36    pub fn new(value: T) -> Self {
37        Self(Arc::new(Mutex::new(value)))
38    }
39
40    pub fn handle(&self) -> CleanupHandle<T> {
41        CleanupHandle(self.0.clone())
42    }
43
44    pub fn get_mut(&mut self) -> std::sync::MutexGuard<T> {
45        self.0.try_lock().expect("failed to lock DelayedCleanup")
46    }
47}
48
49impl Stream for DelayedCleanup<EventStream> {
50    type Item = Event;
51
52    fn poll_next(
53        self: std::pin::Pin<&mut Self>,
54        cx: &mut std::task::Context<'_>,
55    ) -> std::task::Poll<Option<Self::Item>> {
56        let mut inner: std::sync::MutexGuard<'_, EventStream> = self.get_mut().get_mut();
57        inner.poll_next_unpin(cx)
58    }
59}
60
61impl<'a, E> MergeExternalSend<'a, E> for DelayedCleanup<EventStream>
62where
63    E: 'static,
64{
65    type Item = MergedEvent<E>;
66
67    fn merge_external_send(
68        self,
69        external_events: impl Stream<Item = E> + Unpin + Send + Sync + 'a,
70    ) -> Box<dyn Stream<Item = Self::Item> + Unpin + Send + Sync + 'a> {
71        let dora = self.map(MergedEvent::Dora);
72        let external = external_events.map(MergedEvent::External);
73        Box::new((dora, external).merge())
74    }
75}
76
/// Shared reference to a value owned by a `DelayedCleanup`; holding it keeps
/// that value alive.
// The inner `Arc` does not appear to be read anywhere in this file. It exists
// to hold a reference count, which is why `dead_code` is allowed.
#[allow(dead_code)]
pub struct CleanupHandle<T>(Arc<Mutex<T>>);
79
80impl PyEvent {
81    pub fn to_py_dict(self, py: Python<'_>) -> PyResult<Py<PyDict>> {
82        let mut pydict = HashMap::new();
83        match &self.event {
84            MergedEvent::Dora(_) => pydict.insert("kind", "dora".to_object(py)),
85            MergedEvent::External(_) => pydict.insert("kind", "external".to_object(py)),
86        };
87        match &self.event {
88            MergedEvent::Dora(event) => {
89                if let Some(id) = Self::id(event) {
90                    pydict.insert("id", id.into_py(py));
91                }
92                pydict.insert("type", Self::ty(event).to_object(py));
93
94                if let Some(value) = self.value(py)? {
95                    pydict.insert("value", value);
96                }
97                if let Some(metadata) = Self::metadata(event, py)? {
98                    pydict.insert("metadata", metadata);
99                }
100                if let Some(error) = Self::error(event) {
101                    pydict.insert("error", error.to_object(py));
102                }
103            }
104            MergedEvent::External(event) => {
105                pydict.insert("value", event.clone_ref(py));
106            }
107        }
108
109        if let Some(cleanup) = self._cleanup.clone() {
110            pydict.insert("_cleanup", cleanup.into_py(py));
111        }
112
113        Ok(pydict.into_py_dict_bound(py).unbind())
114    }
115
116    fn ty(event: &Event) -> &str {
117        match event {
118            Event::Stop => "STOP",
119            Event::Input { .. } => "INPUT",
120            Event::InputClosed { .. } => "INPUT_CLOSED",
121            Event::Error(_) => "ERROR",
122            _other => "UNKNOWN",
123        }
124    }
125
126    fn id(event: &Event) -> Option<&str> {
127        match event {
128            Event::Input { id, .. } => Some(id),
129            Event::InputClosed { id } => Some(id),
130            _ => None,
131        }
132    }
133
134    /// Returns the payload of an input event as an arrow array (if any).
135    fn value(&self, py: Python<'_>) -> PyResult<Option<PyObject>> {
136        match &self.event {
137            MergedEvent::Dora(Event::Input { data, .. }) => {
138                // TODO: Does this call leak data?&
139                let array_data = data.to_data().to_pyarrow(py)?;
140                Ok(Some(array_data))
141            }
142            _ => Ok(None),
143        }
144    }
145
146    fn metadata(event: &Event, py: Python<'_>) -> Result<Option<PyObject>> {
147        match event {
148            Event::Input { metadata, .. } => Ok(Some(
149                metadata_to_pydict(metadata, py)
150                    .context("Issue deserializing metadata")?
151                    .to_object(py),
152            )),
153            _ => Ok(None),
154        }
155    }
156
157    fn error(event: &Event) -> Option<&str> {
158        match event {
159            Event::Error(error) => Some(error),
160            _other => None,
161        }
162    }
163}
164
165pub fn pydict_to_metadata(dict: Option<Bound<'_, PyDict>>) -> Result<MetadataParameters> {
166    let mut parameters = BTreeMap::default();
167    if let Some(pymetadata) = dict {
168        for (key, value) in pymetadata.iter() {
169            let key = key.extract::<String>().context("Parsing metadata keys")?;
170            if value.is_exact_instance_of::<PyBool>() {
171                parameters.insert(key, Parameter::Bool(value.extract()?))
172            } else if value.is_instance_of::<PyInt>() {
173                parameters.insert(key, Parameter::Integer(value.extract::<i64>()?))
174            } else if value.is_instance_of::<PyFloat>() {
175                parameters.insert(key, Parameter::Float(value.extract::<f64>()?))
176            } else if value.is_instance_of::<PyString>() {
177                parameters.insert(key, Parameter::String(value.extract()?))
178            } else if value.is_instance_of::<PyTuple>()
179                && value.len()? > 0
180                && value.get_item(0)?.is_exact_instance_of::<PyInt>()
181            {
182                let list: Vec<i64> = value.extract()?;
183                parameters.insert(key, Parameter::ListInt(list))
184            } else if value.is_instance_of::<PyList>()
185                && value.len()? > 0
186                && value.get_item(0)?.is_exact_instance_of::<PyInt>()
187            {
188                let list: Vec<i64> = value.extract()?;
189                parameters.insert(key, Parameter::ListInt(list))
190            } else if value.is_instance_of::<PyTuple>()
191                && value.len()? > 0
192                && value.get_item(0)?.is_exact_instance_of::<PyFloat>()
193            {
194                let list: Vec<f64> = value.extract()?;
195                parameters.insert(key, Parameter::ListFloat(list))
196            } else if value.is_instance_of::<PyList>()
197                && value.len()? > 0
198                && value.get_item(0)?.is_exact_instance_of::<PyFloat>()
199            {
200                let list: Vec<f64> = value.extract()?;
201                parameters.insert(key, Parameter::ListFloat(list))
202            } else if value.is_instance_of::<PyList>()
203                && value.len()? > 0
204                && value.get_item(0)?.is_exact_instance_of::<PyString>()
205            {
206                let list: Vec<String> = value.extract()?;
207                parameters.insert(key, Parameter::ListString(list))
208            } else {
209                println!("could not convert type {value}");
210                parameters.insert(key, Parameter::String(value.str()?.to_string()))
211            };
212        }
213    }
214    Ok(parameters)
215}
216
217pub fn metadata_to_pydict<'a>(
218    metadata: &'a Metadata,
219    py: Python<'a>,
220) -> Result<pyo3::Bound<'a, PyDict>> {
221    let dict = PyDict::new_bound(py);
222    for (k, v) in metadata.parameters.iter() {
223        match v {
224            Parameter::Bool(bool) => dict
225                .set_item(k, bool)
226                .context("Could not insert metadata into python dictionary")?,
227            Parameter::Integer(int) => dict
228                .set_item(k, int)
229                .context("Could not insert metadata into python dictionary")?,
230            Parameter::Float(float) => dict
231                .set_item(k, float)
232                .context("Could not insert metadata into python dictionary")?,
233            Parameter::String(s) => dict
234                .set_item(k, s)
235                .context("Could not insert metadata into python dictionary")?,
236            Parameter::ListInt(l) => dict
237                .set_item(k, l)
238                .context("Could not insert metadata into python dictionary")?,
239            Parameter::ListFloat(l) => dict
240                .set_item(k, l)
241                .context("Could not insert metadata into python dictionary")?,
242            Parameter::ListString(l) => dict
243                .set_item(k, l)
244                .context("Could not insert metadata into python dictionary")?,
245        }
246    }
247
248    Ok(dict)
249}
250
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use aligned_vec::{AVec, ConstAlign};
    use arrow::{
        array::{
            ArrayData, ArrayRef, BooleanArray, Float64Array, Int32Array, Int64Array, Int8Array,
            ListArray, StructArray,
        },
        buffer::Buffer,
    };

    use arrow_schema::{DataType, Field};
    use dora_node_api::{
        arrow_utils::{copy_array_into_sample, required_data_size},
        RawData,
    };
    use eyre::{Context, Result};

    /// Copies `arrow_array` into a zero-filled, 128-byte-aligned sample buffer.
    /// It then rebuilds an array from that buffer and asserts the result equals
    /// the input.
    fn assert_roundtrip(arrow_array: &ArrayData) -> Result<()> {
        let mut sample: AVec<u8, ConstAlign<128>> =
            AVec::__from_elem(128, 0, required_data_size(arrow_array));
        let type_info = copy_array_into_sample(&mut sample, arrow_array);

        let decoded = RawData::Vec(sample)
            .into_arrow_array(&type_info)
            .context("Could not create arrow array")?;
        assert_eq!(arrow_array, &decoded);

        Ok(())
    }

    /// Builds a struct array with a non-nullable boolean column `b` and a
    /// non-nullable int32 column `c`.
    fn sample_struct() -> ArrayData {
        let flags = Arc::new(BooleanArray::from(vec![false, false, true, true])) as ArrayRef;
        let numbers = Arc::new(Int32Array::from(vec![42, 28, 19, 31])) as ArrayRef;
        StructArray::from(vec![
            (Arc::new(Field::new("b", DataType::Boolean, false)), flags),
            (Arc::new(Field::new("c", DataType::Int32, false)), numbers),
        ])
        .into()
    }

    /// Builds the nested int32 list `[[0, 1, 2], [3, 4, 5], [6, 7]]`.
    fn sample_list() -> ArrayData {
        let values = ArrayData::builder(DataType::Int32)
            .len(8)
            .add_buffer(Buffer::from_slice_ref([0, 1, 2, 3, 4, 5, 6, 7]))
            .build()
            .unwrap();
        // Offsets split the 8 values into the three sub-lists above.
        let offsets = Buffer::from_slice_ref([0, 3, 6, 8]);
        let list_type = DataType::List(Arc::new(Field::new("item", DataType::Int32, false)));
        let list = ArrayData::builder(list_type)
            .len(3)
            .add_buffer(offsets)
            .add_child_data(values)
            .build()
            .unwrap();
        ListArray::from(list).into()
    }

    #[test]
    fn serialize_deserialize_arrow() -> Result<()> {
        let cases: [(&str, ArrayData); 5] = [
            ("Int8Array", Int8Array::from(vec![1, -2, 3, 4]).into()),
            ("Int64Array", Int64Array::from(vec![1, -2, 3, 4]).into()),
            ("Float64Array", Float64Array::from(vec![1., -2., 3., 4.]).into()),
            ("StructArray", sample_struct()),
            ("ListArray", sample_list()),
        ];

        for (name, array) in &cases {
            assert_roundtrip(array).context(format!("{name} roundtrip failed"))?;
        }

        Ok(())
    }
}