dora_operator_api_python/
lib.rs

1use std::{
2    collections::{BTreeMap, HashMap},
3    sync::{Arc, Mutex},
4};
5
6use arrow::pyarrow::ToPyArrow;
7use dora_node_api::{
8    merged::{MergeExternalSend, MergedEvent},
9    DoraNode, Event, EventStream, Metadata, MetadataParameters, Parameter, StopCause,
10};
11use eyre::{Context, Result};
12use futures::{Stream, StreamExt};
13use futures_concurrency::stream::Merge as _;
14use pyo3::{
15    prelude::*,
16    types::{IntoPyDict, PyBool, PyDict, PyFloat, PyInt, PyList, PyString, PyTuple},
17};
18
19/// Dora Event
20pub struct PyEvent {
21    pub event: MergedEvent<PyObject>,
22}
23
24/// Keeps the dora node alive until all event objects have been dropped.
25#[derive(Clone)]
26#[pyclass]
27pub struct NodeCleanupHandle {
28    pub _handles: Arc<CleanupHandle<DoraNode>>,
29}
30
31/// Owned type with delayed cleanup (using `handle` method).
32pub struct DelayedCleanup<T>(Arc<Mutex<T>>);
33
34impl<T> DelayedCleanup<T> {
35    pub fn new(value: T) -> Self {
36        Self(Arc::new(Mutex::new(value)))
37    }
38
39    pub fn handle(&self) -> CleanupHandle<T> {
40        CleanupHandle(self.0.clone())
41    }
42
43    pub fn get_mut(&mut self) -> std::sync::MutexGuard<T> {
44        self.0.try_lock().expect("failed to lock DelayedCleanup")
45    }
46}
47
48impl Stream for DelayedCleanup<EventStream> {
49    type Item = Event;
50
51    fn poll_next(
52        self: std::pin::Pin<&mut Self>,
53        cx: &mut std::task::Context<'_>,
54    ) -> std::task::Poll<Option<Self::Item>> {
55        let mut inner: std::sync::MutexGuard<'_, EventStream> = self.get_mut().get_mut();
56        inner.poll_next_unpin(cx)
57    }
58}
59
60impl<'a, E> MergeExternalSend<'a, E> for DelayedCleanup<EventStream>
61where
62    E: 'static,
63{
64    type Item = MergedEvent<E>;
65
66    fn merge_external_send(
67        self,
68        external_events: impl Stream<Item = E> + Unpin + Send + Sync + 'a,
69    ) -> Box<dyn Stream<Item = Self::Item> + Unpin + Send + Sync + 'a> {
70        let dora = self.map(MergedEvent::Dora);
71        let external = external_events.map(MergedEvent::External);
72        Box::new((dora, external).merge())
73    }
74}
75
/// Handle that shares ownership of a value wrapped by [`DelayedCleanup`].
///
/// Instances are created through `DelayedCleanup::handle`.
// The wrapped `Arc` is only held, never read in this file, hence the lint allowance.
#[allow(dead_code)]
pub struct CleanupHandle<T>(Arc<Mutex<T>>);
78
79impl PyEvent {
80    pub fn to_py_dict(self, py: Python<'_>) -> PyResult<Py<PyDict>> {
81        let mut pydict = HashMap::new();
82        match &self.event {
83            MergedEvent::Dora(_) => pydict.insert(
84                "kind",
85                "dora"
86                    .into_pyobject(py)
87                    .context("Failed to create pystring")?
88                    .unbind()
89                    .into(),
90            ),
91            MergedEvent::External(_) => pydict.insert(
92                "kind",
93                "external"
94                    .into_pyobject(py)
95                    .context("Failed to create pystring")?
96                    .unbind()
97                    .into(),
98            ),
99        };
100        match &self.event {
101            MergedEvent::Dora(event) => {
102                if let Some(id) = Self::id(event) {
103                    pydict.insert(
104                        "id",
105                        id.into_pyobject(py)
106                            .context("Failed to create id pyobject")?
107                            .into(),
108                    );
109                }
110                pydict.insert(
111                    "type",
112                    Self::ty(event)
113                        .into_pyobject(py)
114                        .context("Failed to create event pyobject")?
115                        .unbind()
116                        .into(),
117                );
118
119                if let Some(value) = self.value(py)? {
120                    pydict.insert("value", value);
121                }
122                if let Some(metadata) = Self::metadata(event, py)? {
123                    pydict.insert("metadata", metadata);
124                }
125                if let Some(error) = Self::error(event) {
126                    pydict.insert(
127                        "error",
128                        error
129                            .into_pyobject(py)
130                            .context("Failed to create error pyobject")?
131                            .unbind()
132                            .into(),
133                    );
134                }
135            }
136            MergedEvent::External(event) => {
137                pydict.insert("value", event.clone_ref(py));
138            }
139        }
140
141        Ok(pydict
142            .into_py_dict(py)
143            .context("Failed to create py_dict")?
144            .unbind())
145    }
146
147    fn ty(event: &Event) -> &str {
148        match event {
149            Event::Stop(_) => "STOP",
150            Event::Input { .. } => "INPUT",
151            Event::InputClosed { .. } => "INPUT_CLOSED",
152            Event::Error(_) => "ERROR",
153            _other => "UNKNOWN",
154        }
155    }
156
157    fn id(event: &Event) -> Option<&str> {
158        match event {
159            Event::Input { id, .. } => Some(id),
160            Event::InputClosed { id } => Some(id),
161            Event::Stop(cause) => match cause {
162                StopCause::Manual => Some("MANUAL"),
163                StopCause::AllInputsClosed => Some("ALL_INPUTS_CLOSED"),
164                &_ => None,
165            },
166            _ => None,
167        }
168    }
169
170    /// Returns the payload of an input event as an arrow array (if any).
171    fn value(&self, py: Python<'_>) -> PyResult<Option<PyObject>> {
172        match &self.event {
173            MergedEvent::Dora(Event::Input { data, .. }) => {
174                // TODO: Does this call leak data?&
175                let array_data = data.to_data().to_pyarrow(py)?;
176                Ok(Some(array_data))
177            }
178            _ => Ok(None),
179        }
180    }
181
182    fn metadata(event: &Event, py: Python<'_>) -> Result<Option<PyObject>> {
183        match event {
184            Event::Input { metadata, .. } => Ok(Some(
185                metadata_to_pydict(metadata, py)
186                    .context("Issue deserializing metadata")?
187                    .into_pyobject(py)
188                    .context("Failed to create metadata_to_pydice")?
189                    .unbind()
190                    .into(),
191            )),
192            _ => Ok(None),
193        }
194    }
195
196    fn error(event: &Event) -> Option<&str> {
197        match event {
198            Event::Error(error) => Some(error),
199            _other => None,
200        }
201    }
202}
203
204pub fn pydict_to_metadata(dict: Option<Bound<'_, PyDict>>) -> Result<MetadataParameters> {
205    let mut parameters = BTreeMap::default();
206    if let Some(pymetadata) = dict {
207        for (key, value) in pymetadata.iter() {
208            let key = key.extract::<String>().context("Parsing metadata keys")?;
209            if value.is_exact_instance_of::<PyBool>() {
210                parameters.insert(key, Parameter::Bool(value.extract()?))
211            } else if value.is_instance_of::<PyInt>() {
212                parameters.insert(key, Parameter::Integer(value.extract::<i64>()?))
213            } else if value.is_instance_of::<PyFloat>() {
214                parameters.insert(key, Parameter::Float(value.extract::<f64>()?))
215            } else if value.is_instance_of::<PyString>() {
216                parameters.insert(key, Parameter::String(value.extract()?))
217            } else if value.is_instance_of::<PyTuple>()
218                && value.len()? > 0
219                && value.get_item(0)?.is_exact_instance_of::<PyInt>()
220            {
221                let list: Vec<i64> = value.extract()?;
222                parameters.insert(key, Parameter::ListInt(list))
223            } else if value.is_instance_of::<PyList>()
224                && value.len()? > 0
225                && value.get_item(0)?.is_exact_instance_of::<PyInt>()
226            {
227                let list: Vec<i64> = value.extract()?;
228                parameters.insert(key, Parameter::ListInt(list))
229            } else if value.is_instance_of::<PyTuple>()
230                && value.len()? > 0
231                && value.get_item(0)?.is_exact_instance_of::<PyFloat>()
232            {
233                let list: Vec<f64> = value.extract()?;
234                parameters.insert(key, Parameter::ListFloat(list))
235            } else if value.is_instance_of::<PyList>()
236                && value.len()? > 0
237                && value.get_item(0)?.is_exact_instance_of::<PyFloat>()
238            {
239                let list: Vec<f64> = value.extract()?;
240                parameters.insert(key, Parameter::ListFloat(list))
241            } else if value.is_instance_of::<PyList>()
242                && value.len()? > 0
243                && value.get_item(0)?.is_exact_instance_of::<PyString>()
244            {
245                let list: Vec<String> = value.extract()?;
246                parameters.insert(key, Parameter::ListString(list))
247            } else {
248                println!("could not convert type {value}");
249                parameters.insert(key, Parameter::String(value.str()?.to_string()))
250            };
251        }
252    }
253    Ok(parameters)
254}
255
256pub fn metadata_to_pydict<'a>(
257    metadata: &'a Metadata,
258    py: Python<'a>,
259) -> Result<pyo3::Bound<'a, PyDict>> {
260    let dict = PyDict::new(py);
261    for (k, v) in metadata.parameters.iter() {
262        match v {
263            Parameter::Bool(bool) => dict
264                .set_item(k, bool)
265                .context("Could not insert metadata into python dictionary")?,
266            Parameter::Integer(int) => dict
267                .set_item(k, int)
268                .context("Could not insert metadata into python dictionary")?,
269            Parameter::Float(float) => dict
270                .set_item(k, float)
271                .context("Could not insert metadata into python dictionary")?,
272            Parameter::String(s) => dict
273                .set_item(k, s)
274                .context("Could not insert metadata into python dictionary")?,
275            Parameter::ListInt(l) => dict
276                .set_item(k, l)
277                .context("Could not insert metadata into python dictionary")?,
278            Parameter::ListFloat(l) => dict
279                .set_item(k, l)
280                .context("Could not insert metadata into python dictionary")?,
281            Parameter::ListString(l) => dict
282                .set_item(k, l)
283                .context("Could not insert metadata into python dictionary")?,
284        }
285    }
286
287    Ok(dict)
288}
289
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use aligned_vec::{AVec, ConstAlign};
    use arrow::{
        array::{
            ArrayData, ArrayRef, BooleanArray, Float64Array, Int32Array, Int64Array, Int8Array,
            ListArray, StructArray,
        },
        buffer::Buffer,
    };

    use arrow_schema::{DataType, Field};
    use dora_node_api::{
        arrow_utils::{copy_array_into_sample, required_data_size},
        RawData,
    };
    use eyre::{Context, Result};

    /// Copies `arrow_array` into a sample buffer, reads it back as an arrow array and
    /// asserts that the result equals the input.
    fn assert_roundtrip(arrow_array: &ArrayData) -> Result<()> {
        let required = required_data_size(arrow_array);
        let mut buffer: AVec<u8, ConstAlign<128>> = AVec::__from_elem(128, 0, required);

        let type_info = copy_array_into_sample(&mut buffer, arrow_array);

        let restored = RawData::Vec(buffer)
            .into_arrow_array(&type_info)
            .context("Could not create arrow array")?;
        assert_eq!(arrow_array, &restored);

        Ok(())
    }

    #[test]
    fn serialize_deserialize_arrow() -> Result<()> {
        // Primitive arrays: Int8, Int64 and Float64.
        let int8_data: ArrayData = Int8Array::from(vec![1, -2, 3, 4]).into();
        assert_roundtrip(&int8_data).context("Int8Array roundtrip failed")?;

        let int64_data: ArrayData = Int64Array::from(vec![1, -2, 3, 4]).into();
        assert_roundtrip(&int64_data).context("Int64Array roundtrip failed")?;

        let float64_data: ArrayData = Float64Array::from(vec![1., -2., 3., 4.]).into();
        assert_roundtrip(&float64_data).context("Float64Array roundtrip failed")?;

        // Struct array with a boolean column "b" and an Int32 column "c".
        let bool_column = Arc::new(BooleanArray::from(vec![false, false, true, true]));
        let int_column = Arc::new(Int32Array::from(vec![42, 28, 19, 31]));
        let struct_data: ArrayData = StructArray::from(vec![
            (
                Arc::new(Field::new("b", DataType::Boolean, false)),
                bool_column as ArrayRef,
            ),
            (
                Arc::new(Field::new("c", DataType::Int32, false)),
                int_column as ArrayRef,
            ),
        ])
        .into();
        assert_roundtrip(&struct_data).context("StructArray roundtrip failed")?;

        // List array [[0, 1, 2], [3, 4, 5], [6, 7]], assembled from a flat Int32 child
        // array and a buffer of value offsets.
        let child_values = ArrayData::builder(DataType::Int32)
            .len(8)
            .add_buffer(Buffer::from_slice_ref([0, 1, 2, 3, 4, 5, 6, 7]))
            .build()
            .unwrap();
        let offsets = Buffer::from_slice_ref([0, 3, 6, 8]);
        let list_type = DataType::List(Arc::new(Field::new("item", DataType::Int32, false)));
        let list_parts = ArrayData::builder(list_type)
            .len(3)
            .add_buffer(offsets)
            .add_child_data(child_values)
            .build()
            .unwrap();
        let list_data: ArrayData = ListArray::from(list_parts).into();
        assert_roundtrip(&list_data).context("ListArray roundtrip failed")?;

        Ok(())
    }
}