dora_operator_api_python/
lib.rs

1use std::{
2    collections::{BTreeMap, HashMap},
3    sync::{Arc, Mutex},
4};
5
6use arrow::pyarrow::ToPyArrow;
7use dora_node_api::{
8    merged::{MergeExternalSend, MergedEvent},
9    DoraNode, Event, EventStream, Metadata, MetadataParameters, Parameter,
10};
11use eyre::{Context, Result};
12use futures::{Stream, StreamExt};
13use futures_concurrency::stream::Merge as _;
14use pyo3::{
15    prelude::*,
16    types::{IntoPyDict, PyBool, PyDict, PyFloat, PyInt, PyList, PyString, PyTuple},
17};
18
19/// Dora Event
20pub struct PyEvent {
21    pub event: MergedEvent<PyObject>,
22}
23
24/// Keeps the dora node alive until all event objects have been dropped.
25#[derive(Clone)]
26#[pyclass]
27pub struct NodeCleanupHandle {
28    pub _handles: Arc<CleanupHandle<DoraNode>>,
29}
30
/// Owned value whose drop can be delayed.
///
/// The value lives behind a shared `Arc<Mutex<T>>`. Every [`CleanupHandle`]
/// obtained through [`DelayedCleanup::handle`] holds another reference to it,
/// so the value outlives this owner until all handles are dropped too.
pub struct DelayedCleanup<T>(Arc<Mutex<T>>);
33
34impl<T> DelayedCleanup<T> {
35    pub fn new(value: T) -> Self {
36        Self(Arc::new(Mutex::new(value)))
37    }
38
39    pub fn handle(&self) -> CleanupHandle<T> {
40        CleanupHandle(self.0.clone())
41    }
42
43    pub fn get_mut(&mut self) -> std::sync::MutexGuard<T> {
44        self.0.try_lock().expect("failed to lock DelayedCleanup")
45    }
46}
47
48impl Stream for DelayedCleanup<EventStream> {
49    type Item = Event;
50
51    fn poll_next(
52        self: std::pin::Pin<&mut Self>,
53        cx: &mut std::task::Context<'_>,
54    ) -> std::task::Poll<Option<Self::Item>> {
55        let mut inner: std::sync::MutexGuard<'_, EventStream> = self.get_mut().get_mut();
56        inner.poll_next_unpin(cx)
57    }
58}
59
60impl<'a, E> MergeExternalSend<'a, E> for DelayedCleanup<EventStream>
61where
62    E: 'static,
63{
64    type Item = MergedEvent<E>;
65
66    fn merge_external_send(
67        self,
68        external_events: impl Stream<Item = E> + Unpin + Send + Sync + 'a,
69    ) -> Box<dyn Stream<Item = Self::Item> + Unpin + Send + Sync + 'a> {
70        let dora = self.map(MergedEvent::Dora);
71        let external = external_events.map(MergedEvent::External);
72        Box::new((dora, external).merge())
73    }
74}
75
/// Handle returned by [`DelayedCleanup::handle`]; while it exists, the shared
/// value is not dropped.
// The inner field is never read: the handle exists purely to hold a reference
// count on the shared value, which is why `dead_code` is allowed here.
#[allow(dead_code)]
pub struct CleanupHandle<T>(Arc<Mutex<T>>);
78
79impl PyEvent {
80    pub fn to_py_dict(self, py: Python<'_>) -> PyResult<Py<PyDict>> {
81        let mut pydict = HashMap::new();
82        match &self.event {
83            MergedEvent::Dora(_) => pydict.insert(
84                "kind",
85                "dora"
86                    .into_pyobject(py)
87                    .context("Failed to create pystring")?
88                    .unbind()
89                    .into(),
90            ),
91            MergedEvent::External(_) => pydict.insert(
92                "kind",
93                "external"
94                    .into_pyobject(py)
95                    .context("Failed to create pystring")?
96                    .unbind()
97                    .into(),
98            ),
99        };
100        match &self.event {
101            MergedEvent::Dora(event) => {
102                if let Some(id) = Self::id(event) {
103                    pydict.insert(
104                        "id",
105                        id.into_pyobject(py)
106                            .context("Failed to create id pyobject")?
107                            .into(),
108                    );
109                }
110                pydict.insert(
111                    "type",
112                    Self::ty(event)
113                        .into_pyobject(py)
114                        .context("Failed to create event pyobject")?
115                        .unbind()
116                        .into(),
117                );
118
119                if let Some(value) = self.value(py)? {
120                    pydict.insert("value", value);
121                }
122                if let Some(metadata) = Self::metadata(event, py)? {
123                    pydict.insert("metadata", metadata);
124                }
125                if let Some(error) = Self::error(event) {
126                    pydict.insert(
127                        "error",
128                        error
129                            .into_pyobject(py)
130                            .context("Failed to create error pyobject")?
131                            .unbind()
132                            .into(),
133                    );
134                }
135            }
136            MergedEvent::External(event) => {
137                pydict.insert("value", event.clone_ref(py));
138            }
139        }
140
141        Ok(pydict
142            .into_py_dict(py)
143            .context("Failed to create py_dict")?
144            .unbind())
145    }
146
147    fn ty(event: &Event) -> &str {
148        match event {
149            Event::Stop => "STOP",
150            Event::Input { .. } => "INPUT",
151            Event::InputClosed { .. } => "INPUT_CLOSED",
152            Event::Error(_) => "ERROR",
153            _other => "UNKNOWN",
154        }
155    }
156
157    fn id(event: &Event) -> Option<&str> {
158        match event {
159            Event::Input { id, .. } => Some(id),
160            Event::InputClosed { id } => Some(id),
161            _ => None,
162        }
163    }
164
165    /// Returns the payload of an input event as an arrow array (if any).
166    fn value(&self, py: Python<'_>) -> PyResult<Option<PyObject>> {
167        match &self.event {
168            MergedEvent::Dora(Event::Input { data, .. }) => {
169                // TODO: Does this call leak data?&
170                let array_data = data.to_data().to_pyarrow(py)?;
171                Ok(Some(array_data))
172            }
173            _ => Ok(None),
174        }
175    }
176
177    fn metadata(event: &Event, py: Python<'_>) -> Result<Option<PyObject>> {
178        match event {
179            Event::Input { metadata, .. } => Ok(Some(
180                metadata_to_pydict(metadata, py)
181                    .context("Issue deserializing metadata")?
182                    .into_pyobject(py)
183                    .context("Failed to create metadata_to_pydice")?
184                    .unbind()
185                    .into(),
186            )),
187            _ => Ok(None),
188        }
189    }
190
191    fn error(event: &Event) -> Option<&str> {
192        match event {
193            Event::Error(error) => Some(error),
194            _other => None,
195        }
196    }
197}
198
199pub fn pydict_to_metadata(dict: Option<Bound<'_, PyDict>>) -> Result<MetadataParameters> {
200    let mut parameters = BTreeMap::default();
201    if let Some(pymetadata) = dict {
202        for (key, value) in pymetadata.iter() {
203            let key = key.extract::<String>().context("Parsing metadata keys")?;
204            if value.is_exact_instance_of::<PyBool>() {
205                parameters.insert(key, Parameter::Bool(value.extract()?))
206            } else if value.is_instance_of::<PyInt>() {
207                parameters.insert(key, Parameter::Integer(value.extract::<i64>()?))
208            } else if value.is_instance_of::<PyFloat>() {
209                parameters.insert(key, Parameter::Float(value.extract::<f64>()?))
210            } else if value.is_instance_of::<PyString>() {
211                parameters.insert(key, Parameter::String(value.extract()?))
212            } else if value.is_instance_of::<PyTuple>()
213                && value.len()? > 0
214                && value.get_item(0)?.is_exact_instance_of::<PyInt>()
215            {
216                let list: Vec<i64> = value.extract()?;
217                parameters.insert(key, Parameter::ListInt(list))
218            } else if value.is_instance_of::<PyList>()
219                && value.len()? > 0
220                && value.get_item(0)?.is_exact_instance_of::<PyInt>()
221            {
222                let list: Vec<i64> = value.extract()?;
223                parameters.insert(key, Parameter::ListInt(list))
224            } else if value.is_instance_of::<PyTuple>()
225                && value.len()? > 0
226                && value.get_item(0)?.is_exact_instance_of::<PyFloat>()
227            {
228                let list: Vec<f64> = value.extract()?;
229                parameters.insert(key, Parameter::ListFloat(list))
230            } else if value.is_instance_of::<PyList>()
231                && value.len()? > 0
232                && value.get_item(0)?.is_exact_instance_of::<PyFloat>()
233            {
234                let list: Vec<f64> = value.extract()?;
235                parameters.insert(key, Parameter::ListFloat(list))
236            } else if value.is_instance_of::<PyList>()
237                && value.len()? > 0
238                && value.get_item(0)?.is_exact_instance_of::<PyString>()
239            {
240                let list: Vec<String> = value.extract()?;
241                parameters.insert(key, Parameter::ListString(list))
242            } else {
243                println!("could not convert type {value}");
244                parameters.insert(key, Parameter::String(value.str()?.to_string()))
245            };
246        }
247    }
248    Ok(parameters)
249}
250
251pub fn metadata_to_pydict<'a>(
252    metadata: &'a Metadata,
253    py: Python<'a>,
254) -> Result<pyo3::Bound<'a, PyDict>> {
255    let dict = PyDict::new(py);
256    for (k, v) in metadata.parameters.iter() {
257        match v {
258            Parameter::Bool(bool) => dict
259                .set_item(k, bool)
260                .context("Could not insert metadata into python dictionary")?,
261            Parameter::Integer(int) => dict
262                .set_item(k, int)
263                .context("Could not insert metadata into python dictionary")?,
264            Parameter::Float(float) => dict
265                .set_item(k, float)
266                .context("Could not insert metadata into python dictionary")?,
267            Parameter::String(s) => dict
268                .set_item(k, s)
269                .context("Could not insert metadata into python dictionary")?,
270            Parameter::ListInt(l) => dict
271                .set_item(k, l)
272                .context("Could not insert metadata into python dictionary")?,
273            Parameter::ListFloat(l) => dict
274                .set_item(k, l)
275                .context("Could not insert metadata into python dictionary")?,
276            Parameter::ListString(l) => dict
277                .set_item(k, l)
278                .context("Could not insert metadata into python dictionary")?,
279        }
280    }
281
282    Ok(dict)
283}
284
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use aligned_vec::{AVec, ConstAlign};
    use arrow::{
        array::{
            ArrayData, ArrayRef, BooleanArray, Float64Array, Int32Array, Int64Array, Int8Array,
            ListArray, StructArray,
        },
        buffer::Buffer,
    };

    use arrow_schema::{DataType, Field};
    use dora_node_api::{
        arrow_utils::{copy_array_into_sample, required_data_size},
        RawData,
    };
    use eyre::{Context, Result};

    /// Copies `arrow_array` into a sample buffer aligned to 128 bytes, reads
    /// it back via `RawData::into_arrow_array`, and asserts that the result
    /// equals the input.
    fn assert_roundtrip(arrow_array: &ArrayData) -> Result<()> {
        let mut sample: AVec<u8, ConstAlign<128>> =
            AVec::__from_elem(128, 0, required_data_size(arrow_array));
        let info = copy_array_into_sample(&mut sample, arrow_array);

        let roundtripped = RawData::Vec(sample)
            .into_arrow_array(&info)
            .context("Could not create arrow array")?;
        assert_eq!(arrow_array, &roundtripped);

        Ok(())
    }

    /// Builds a two-column struct array: boolean field `b` and `Int32` field `c`.
    fn bool_int_struct() -> ArrayData {
        let flags = Arc::new(BooleanArray::from(vec![false, false, true, true]));
        let numbers = Arc::new(Int32Array::from(vec![42, 28, 19, 31]));

        StructArray::from(vec![
            (
                Arc::new(Field::new("b", DataType::Boolean, false)),
                flags as ArrayRef,
            ),
            (
                Arc::new(Field::new("c", DataType::Int32, false)),
                numbers as ArrayRef,
            ),
        ])
        .into()
    }

    /// Builds the `Int32` list array `[[0, 1, 2], [3, 4, 5], [6, 7]]`.
    fn nested_int32_list() -> ArrayData {
        let values = ArrayData::builder(DataType::Int32)
            .len(8)
            .add_buffer(Buffer::from_slice_ref([0, 1, 2, 3, 4, 5, 6, 7]))
            .build()
            .unwrap();

        // Offsets delimiting the three sub-lists inside `values`.
        let offsets = Buffer::from_slice_ref([0, 3, 6, 8]);

        let list_type = DataType::List(Arc::new(Field::new("item", DataType::Int32, false)));
        let list_data = ArrayData::builder(list_type)
            .len(3)
            .add_buffer(offsets)
            .add_child_data(values)
            .build()
            .unwrap();
        ListArray::from(list_data).into()
    }

    #[test]
    fn serialize_deserialize_arrow() -> Result<()> {
        let cases: [(&str, ArrayData); 5] = [
            ("Int8Array", Int8Array::from(vec![1, -2, 3, 4]).into()),
            ("Int64Array", Int64Array::from(vec![1, -2, 3, 4]).into()),
            ("Float64Array", Float64Array::from(vec![1., -2., 3., 4.]).into()),
            ("StructArray", bool_int_struct()),
            ("ListArray", nested_int32_list()),
        ];

        for (name, array) in &cases {
            assert_roundtrip(array).with_context(|| format!("{name} roundtrip failed"))?;
        }

        Ok(())
    }
}