dora_operator_api_python/
lib.rs

1use std::{
2    collections::{BTreeMap, HashMap},
3    sync::{Arc, Mutex},
4};
5
6use arrow::pyarrow::ToPyArrow;
7use dora_node_api::{
8    DoraNode, Event, EventStream, Metadata, MetadataParameters, Parameter, StopCause,
9    merged::{MergeExternalSend, MergedEvent},
10};
11use eyre::{Context, Result};
12use futures::{Stream, StreamExt};
13use futures_concurrency::stream::Merge as _;
14use pyo3::{
15    prelude::*,
16    types::{IntoPyDict, PyBool, PyDict, PyFloat, PyInt, PyList, PyModule, PyString, PyTuple},
17};
18use std::time::{SystemTime, UNIX_EPOCH};
19
20/// Dora Event
21pub struct PyEvent {
22    pub event: MergedEvent<PyObject>,
23}
24
25/// Keeps the dora node alive until all event objects have been dropped.
26#[derive(Clone)]
27#[pyclass]
28pub struct NodeCleanupHandle {
29    pub _handles: Arc<CleanupHandle<DoraNode>>,
30}
31
32/// Owned type with delayed cleanup (using `handle` method).
33pub struct DelayedCleanup<T>(Arc<Mutex<T>>);
34
35impl<T> DelayedCleanup<T> {
36    pub fn new(value: T) -> Self {
37        Self(Arc::new(Mutex::new(value)))
38    }
39
40    pub fn handle(&self) -> CleanupHandle<T> {
41        CleanupHandle(self.0.clone())
42    }
43
44    pub fn get_mut(&self) -> std::sync::MutexGuard<T> {
45        self.0.try_lock().expect("failed to lock DelayedCleanup")
46    }
47}
48
49impl Stream for DelayedCleanup<EventStream> {
50    type Item = Event;
51
52    fn poll_next(
53        self: std::pin::Pin<&mut Self>,
54        cx: &mut std::task::Context<'_>,
55    ) -> std::task::Poll<Option<Self::Item>> {
56        let mut inner: std::sync::MutexGuard<'_, EventStream> = self.get_mut().get_mut();
57        inner.poll_next_unpin(cx)
58    }
59}
60
61impl<'a, E> MergeExternalSend<'a, E> for DelayedCleanup<EventStream>
62where
63    E: 'static,
64{
65    type Item = MergedEvent<E>;
66
67    fn merge_external_send(
68        self,
69        external_events: impl Stream<Item = E> + Unpin + Send + Sync + 'a,
70    ) -> Box<dyn Stream<Item = Self::Item> + Unpin + Send + Sync + 'a> {
71        let dora = self.map(MergedEvent::Dora);
72        let external = external_events.map(MergedEvent::External);
73        Box::new((dora, external).merge())
74    }
75}
76
/// A strong reference to a value owned by a [`DelayedCleanup`].
///
/// The wrapped value is only dropped once the `DelayedCleanup` and every `CleanupHandle`
/// created from it have been dropped.
// The field is never read: the handle exists solely to hold the `Arc` alive.
#[allow(dead_code)]
pub struct CleanupHandle<T>(Arc<Mutex<T>>);
79
80impl PyEvent {
81    pub fn to_py_dict(self, py: Python<'_>) -> PyResult<Py<PyDict>> {
82        let mut pydict = HashMap::new();
83        match &self.event {
84            MergedEvent::Dora(_) => pydict.insert(
85                "kind",
86                "dora"
87                    .into_pyobject(py)
88                    .context("Failed to create pystring")?
89                    .unbind()
90                    .into(),
91            ),
92            MergedEvent::External(_) => pydict.insert(
93                "kind",
94                "external"
95                    .into_pyobject(py)
96                    .context("Failed to create pystring")?
97                    .unbind()
98                    .into(),
99            ),
100        };
101        match &self.event {
102            MergedEvent::Dora(event) => {
103                if let Some(id) = Self::id(event) {
104                    pydict.insert(
105                        "id",
106                        id.into_pyobject(py)
107                            .context("Failed to create id pyobject")?
108                            .into(),
109                    );
110                }
111                pydict.insert(
112                    "type",
113                    Self::ty(event)
114                        .into_pyobject(py)
115                        .context("Failed to create event pyobject")?
116                        .unbind()
117                        .into(),
118                );
119
120                if let Some(value) = self.value(py)? {
121                    pydict.insert("value", value);
122                }
123                if let Some(metadata) = Self::metadata(event, py)? {
124                    pydict.insert("metadata", metadata);
125                }
126                if let Some(error) = Self::error(event) {
127                    pydict.insert(
128                        "error",
129                        error
130                            .into_pyobject(py)
131                            .context("Failed to create error pyobject")?
132                            .unbind()
133                            .into(),
134                    );
135                }
136            }
137            MergedEvent::External(event) => {
138                pydict.insert("value", event.clone_ref(py));
139            }
140        }
141
142        Ok(pydict
143            .into_py_dict(py)
144            .context("Failed to create py_dict")?
145            .unbind())
146    }
147
148    fn ty(event: &Event) -> &str {
149        match event {
150            Event::Stop(_) => "STOP",
151            Event::Input { .. } => "INPUT",
152            Event::InputClosed { .. } => "INPUT_CLOSED",
153            Event::Error(_) => "ERROR",
154            _other => "UNKNOWN",
155        }
156    }
157
158    fn id(event: &Event) -> Option<&str> {
159        match event {
160            Event::Input { id, .. } => Some(id),
161            Event::InputClosed { id } => Some(id),
162            Event::Stop(cause) => match cause {
163                StopCause::Manual => Some("MANUAL"),
164                StopCause::AllInputsClosed => Some("ALL_INPUTS_CLOSED"),
165                &_ => None,
166            },
167            _ => None,
168        }
169    }
170
171    /// Returns the payload of an input event as an arrow array (if any).
172    fn value(&self, py: Python<'_>) -> PyResult<Option<PyObject>> {
173        match &self.event {
174            MergedEvent::Dora(Event::Input { data, .. }) => {
175                // TODO: Does this call leak data?&
176                let array_data = data.to_data().to_pyarrow(py)?;
177                Ok(Some(array_data))
178            }
179            _ => Ok(None),
180        }
181    }
182
183    fn metadata(event: &Event, py: Python<'_>) -> Result<Option<PyObject>> {
184        match event {
185            Event::Input { metadata, .. } => Ok(Some(
186                metadata_to_pydict(metadata, py)
187                    .context("Issue deserializing metadata")?
188                    .into_pyobject(py)
189                    .context("Failed to create metadata_to_pydice")?
190                    .unbind()
191                    .into(),
192            )),
193            _ => Ok(None),
194        }
195    }
196
197    fn error(event: &Event) -> Option<&str> {
198        match event {
199            Event::Error(error) => Some(error),
200            _other => None,
201        }
202    }
203}
204
205pub fn pydict_to_metadata(dict: Option<Bound<'_, PyDict>>) -> Result<MetadataParameters> {
206    let mut parameters = BTreeMap::default();
207    if let Some(pymetadata) = dict {
208        for (key, value) in pymetadata.iter() {
209            let key = key.extract::<String>().context("Parsing metadata keys")?;
210            if value.is_exact_instance_of::<PyBool>() {
211                parameters.insert(key, Parameter::Bool(value.extract()?))
212            } else if value.is_instance_of::<PyInt>() {
213                parameters.insert(key, Parameter::Integer(value.extract::<i64>()?))
214            } else if value.is_instance_of::<PyFloat>() {
215                parameters.insert(key, Parameter::Float(value.extract::<f64>()?))
216            } else if value.is_instance_of::<PyString>() {
217                parameters.insert(key, Parameter::String(value.extract()?))
218            } else if (value.is_instance_of::<PyTuple>() || value.is_instance_of::<PyList>())
219                && value.len()? > 0
220                && value.get_item(0)?.is_exact_instance_of::<PyInt>()
221            {
222                let list: Vec<i64> = value.extract()?;
223                parameters.insert(key, Parameter::ListInt(list))
224            } else if (value.is_instance_of::<PyTuple>() || value.is_instance_of::<PyList>())
225                && value.len()? > 0
226                && value.get_item(0)?.is_exact_instance_of::<PyFloat>()
227            {
228                let list: Vec<f64> = value.extract()?;
229                parameters.insert(key, Parameter::ListFloat(list))
230            } else if value.is_instance_of::<PyList>()
231                && value.len()? > 0
232                && value.get_item(0)?.is_exact_instance_of::<PyString>()
233            {
234                let list: Vec<String> = value.extract()?;
235                parameters.insert(key, Parameter::ListString(list))
236            } else {
237                println!("could not convert type {value}");
238                parameters.insert(key, Parameter::String(value.str()?.to_string()))
239            };
240        }
241    }
242    Ok(parameters)
243}
244
245pub fn metadata_to_pydict<'a>(
246    metadata: &'a Metadata,
247    py: Python<'a>,
248) -> Result<pyo3::Bound<'a, PyDict>> {
249    let dict = PyDict::new(py);
250
251    // Add timestamp as timezone-aware Python datetime (UTC)
252    // Note: uhlc::Timestamp is a Hybrid Logical Clock. We use get_time().to_system_time()
253    // which extracts the physical clock component. This pattern is used consistently
254    // throughout the dora codebase (e.g., in binaries/daemon/src/log.rs, binaries/coordinator/src/lib.rs)
255    // and assumes the physical time component represents UTC wall-clock time.
256    let timestamp = metadata.timestamp();
257    let system_time = timestamp.get_time().to_system_time();
258    let duration_since_epoch = system_time
259        .duration_since(UNIX_EPOCH)
260        .context("Failed to calculate duration since epoch")?;
261
262    // Extract seconds and microseconds (Python datetime supports microsecond precision)
263    let seconds = duration_since_epoch.as_secs() as i64;
264    let microseconds = duration_since_epoch.subsec_micros() as u32;
265
266    // Get UTC timezone from Python's datetime module and create timezone-aware datetime
267    // We use Python's datetime.fromtimestamp() to create a UTC-aware datetime object
268    // This avoids float precision loss by using integer seconds and microseconds
269    let datetime_module =
270        PyModule::import(py, "datetime").context("Failed to import datetime module")?;
271    let datetime_class = datetime_module.getattr("datetime")?;
272    let utc_timezone = datetime_module.getattr("timezone")?.getattr("utc")?;
273
274    // Create timezone-aware datetime using fromtimestamp
275    // We compute total_seconds as float (required by fromtimestamp) but preserve
276    // precision by computing from integer seconds and microseconds separately
277    let total_seconds = seconds as f64 + microseconds as f64 / 1_000_000.0;
278    let py_datetime = datetime_class
279        .call_method1("fromtimestamp", (total_seconds, utc_timezone))
280        .context("Failed to create Python datetime from timestamp")?;
281
282    dict.set_item("timestamp", py_datetime)
283        .context("Could not insert timestamp into python dictionary")?;
284
285    // Add existing parameters
286    for (k, v) in metadata.parameters.iter() {
287        match v {
288            Parameter::Bool(bool) => dict
289                .set_item(k, bool)
290                .context("Could not insert metadata into python dictionary")?,
291            Parameter::Integer(int) => dict
292                .set_item(k, int)
293                .context("Could not insert metadata into python dictionary")?,
294            Parameter::Float(float) => dict
295                .set_item(k, float)
296                .context("Could not insert metadata into python dictionary")?,
297            Parameter::String(s) => dict
298                .set_item(k, s)
299                .context("Could not insert metadata into python dictionary")?,
300            Parameter::ListInt(l) => dict
301                .set_item(k, l)
302                .context("Could not insert metadata into python dictionary")?,
303            Parameter::ListFloat(l) => dict
304                .set_item(k, l)
305                .context("Could not insert metadata into python dictionary")?,
306            Parameter::ListString(l) => dict
307                .set_item(k, l)
308                .context("Could not insert metadata into python dictionary")?,
309        }
310    }
311
312    Ok(dict)
313}
314
#[cfg(test)]
mod tests {
    use std::{ptr::NonNull, sync::Arc};

    use aligned_vec::{AVec, ConstAlign};
    use arrow::{
        array::{
            ArrayData, ArrayRef, BooleanArray, Float64Array, Int8Array, Int32Array, Int64Array,
            ListArray, StructArray,
        },
        buffer::Buffer,
    };

    use arrow_schema::{DataType, Field};
    use dora_node_api::arrow_utils::{
        buffer_into_arrow_array, copy_array_into_sample, required_data_size,
    };
    use eyre::{Context, Result};

    /// Copies `original` into a 128-byte-aligned sample and rebuilds an array from that
    /// sample. Asserts that the rebuilt array equals `original`.
    fn assert_roundtrip(original: &ArrayData) -> Result<()> {
        let mut sample: AVec<u8, ConstAlign<128>> =
            AVec::__from_elem(128, 0, required_data_size(original));
        let type_info = copy_array_into_sample(&mut sample, original);

        let start = NonNull::new(sample.as_ptr() as *mut _).unwrap();
        let byte_len = sample.len();
        // SAFETY: `start` and `byte_len` describe the heap allocation owned by `sample`.
        // Moving `sample` into the `Arc` owner does not move that allocation, and the owner
        // keeps it alive for as long as the resulting `Buffer` exists.
        let buffer = unsafe { Buffer::from_custom_allocation(start, byte_len, Arc::new(sample)) };
        let rebuilt = buffer_into_arrow_array(&buffer, &type_info)?;

        assert_eq!(original, &rebuilt);
        Ok(())
    }

    /// Builds a struct array with a boolean column `b` and an `Int32` column `c`.
    fn bool_int_struct() -> ArrayData {
        let b: ArrayRef = Arc::new(BooleanArray::from(vec![false, false, true, true]));
        let c: ArrayRef = Arc::new(Int32Array::from(vec![42, 28, 19, 31]));
        StructArray::from(vec![
            (Arc::new(Field::new("b", DataType::Boolean, false)), b),
            (Arc::new(Field::new("c", DataType::Int32, false)), c),
        ])
        .into()
    }

    /// Builds the list array `[[0, 1, 2], [3, 4, 5], [6, 7]]` of non-nullable `Int32`.
    fn nested_int32_list() -> ArrayData {
        let values = ArrayData::builder(DataType::Int32)
            .len(8)
            .add_buffer(Buffer::from_slice_ref([0, 1, 2, 3, 4, 5, 6, 7]))
            .build()
            .unwrap();
        // Offsets delimit the three sub-lists inside `values`.
        let offsets = Buffer::from_slice_ref([0, 3, 6, 8]);
        let list_type = DataType::List(Arc::new(Field::new("item", DataType::Int32, false)));
        let list = ArrayData::builder(list_type)
            .len(3)
            .add_buffer(offsets)
            .add_child_data(values)
            .build()
            .unwrap();
        ListArray::from(list).into()
    }

    #[test]
    fn serialize_deserialize_arrow() -> Result<()> {
        let cases: [(&'static str, ArrayData); 5] = [
            ("Int8Array roundtrip failed", Int8Array::from(vec![1, -2, 3, 4]).into()),
            ("Int64Array roundtrip failed", Int64Array::from(vec![1, -2, 3, 4]).into()),
            (
                "Float64Array roundtrip failed",
                Float64Array::from(vec![1., -2., 3., 4.]).into(),
            ),
            ("StructArray roundtrip failed", bool_int_struct()),
            ("ListArray roundtrip failed", nested_int32_list()),
        ];
        for (failure_context, array) in &cases {
            assert_roundtrip(array).context(*failure_context)?;
        }
        Ok(())
    }
}