Skip to main content

dora_operator_api_python/
lib.rs

1use std::{
2    collections::{BTreeMap, HashMap},
3    sync::{Arc, Mutex},
4};
5
6use arrow::pyarrow::ToPyArrow;
7use chrono::{DateTime, Utc};
8use dora_node_api::{
9    DoraNode, Event, EventStream, Metadata, MetadataParameters, Parameter, StopCause,
10    merged::{MergeExternalSend, MergedEvent},
11};
12use eyre::{Context, Result};
13use futures::{Stream, StreamExt};
14use futures_concurrency::stream::Merge as _;
15use pyo3::{
16    prelude::*,
17    types::{IntoPyDict, PyBool, PyDict, PyFloat, PyInt, PyList, PyModule, PyString, PyTuple},
18};
19use std::time::UNIX_EPOCH;
20
/// Event that is handed to Python code.
///
/// Wraps a [`MergedEvent`] whose external payload is an arbitrary Python
/// object: the event is either a dora [`Event`] (`MergedEvent::Dora`) or a
/// Python object coming from an external stream (`MergedEvent::External`).
pub struct PyEvent {
    /// The wrapped dora or external event.
    pub event: MergedEvent<PyObject>,
}
25
/// Keeps the dora node alive until all event objects have been dropped.
///
/// The type is exposed to Python (`#[pyclass]`). Cloning it only clones the
/// inner `Arc`, so all clones refer to the same [`CleanupHandle`].
#[derive(Clone)]
#[pyclass]
pub struct NodeCleanupHandle {
    /// Shared handle to the node; never read, only held.
    pub _handles: Arc<CleanupHandle<DoraNode>>,
}
32
33/// Owned type with delayed cleanup (using `handle` method).
34pub struct DelayedCleanup<T>(Arc<Mutex<T>>);
35
36impl<T> DelayedCleanup<T> {
37    pub fn new(value: T) -> Self {
38        Self(Arc::new(Mutex::new(value)))
39    }
40
41    pub fn handle(&self) -> CleanupHandle<T> {
42        CleanupHandle(self.0.clone())
43    }
44
45    pub fn get_mut(&self) -> std::sync::MutexGuard<T> {
46        self.0.try_lock().expect("failed to lock DelayedCleanup")
47    }
48}
49
50impl Stream for DelayedCleanup<EventStream> {
51    type Item = Event;
52
53    fn poll_next(
54        self: std::pin::Pin<&mut Self>,
55        cx: &mut std::task::Context<'_>,
56    ) -> std::task::Poll<Option<Self::Item>> {
57        let mut inner: std::sync::MutexGuard<'_, EventStream> = self.get_mut().get_mut();
58        inner.poll_next_unpin(cx)
59    }
60}
61
62impl<'a, E> MergeExternalSend<'a, E> for DelayedCleanup<EventStream>
63where
64    E: 'static,
65{
66    type Item = MergedEvent<E>;
67
68    fn merge_external_send(
69        self,
70        external_events: impl Stream<Item = E> + Unpin + Send + Sync + 'a,
71    ) -> Box<dyn Stream<Item = Self::Item> + Unpin + Send + Sync + 'a> {
72        let dora = self.map(MergedEvent::Dora);
73        let external = external_events.map(MergedEvent::External);
74        Box::new((dora, external).merge())
75    }
76}
77
/// Handle created by [`DelayedCleanup::handle`].
///
/// It holds a clone of the `Arc` that owns the wrapped value, so the value
/// cannot be dropped while the handle exists.
// The field is never read -- it is only held to keep the `Arc` alive --
// hence the `dead_code` allowance.
#[allow(dead_code)]
pub struct CleanupHandle<T>(Arc<Mutex<T>>);
80
81impl PyEvent {
82    pub fn to_py_dict(self, py: Python<'_>) -> PyResult<Py<PyDict>> {
83        let mut pydict = HashMap::new();
84        match &self.event {
85            MergedEvent::Dora(_) => pydict.insert(
86                "kind",
87                "dora"
88                    .into_pyobject(py)
89                    .context("Failed to create pystring")?
90                    .unbind()
91                    .into(),
92            ),
93            MergedEvent::External(_) => pydict.insert(
94                "kind",
95                "external"
96                    .into_pyobject(py)
97                    .context("Failed to create pystring")?
98                    .unbind()
99                    .into(),
100            ),
101        };
102        match &self.event {
103            MergedEvent::Dora(event) => {
104                if let Some(id) = Self::id(event) {
105                    pydict.insert(
106                        "id",
107                        id.into_pyobject(py)
108                            .context("Failed to create id pyobject")?
109                            .into(),
110                    );
111                }
112                pydict.insert(
113                    "type",
114                    Self::ty(event)
115                        .into_pyobject(py)
116                        .context("Failed to create event pyobject")?
117                        .unbind()
118                        .into(),
119                );
120
121                if let Some(value) = self.value(py)? {
122                    pydict.insert("value", value);
123                }
124                if let Some(metadata) = Self::metadata(event, py)? {
125                    pydict.insert("metadata", metadata);
126                }
127                if let Some(error) = Self::error(event) {
128                    pydict.insert(
129                        "error",
130                        error
131                            .into_pyobject(py)
132                            .context("Failed to create error pyobject")?
133                            .unbind()
134                            .into(),
135                    );
136                }
137            }
138            MergedEvent::External(event) => {
139                pydict.insert("value", event.clone_ref(py));
140            }
141        }
142
143        Ok(pydict
144            .into_py_dict(py)
145            .context("Failed to create py_dict")?
146            .unbind())
147    }
148
149    fn ty(event: &Event) -> &str {
150        match event {
151            Event::Stop(_) => "STOP",
152            Event::Input { .. } => "INPUT",
153            Event::InputClosed { .. } => "INPUT_CLOSED",
154            Event::Error(_) => "ERROR",
155            _other => "UNKNOWN",
156        }
157    }
158
159    fn id(event: &Event) -> Option<&str> {
160        match event {
161            Event::Input { id, .. } => Some(id),
162            Event::InputClosed { id } => Some(id),
163            Event::Stop(cause) => match cause {
164                StopCause::Manual => Some("MANUAL"),
165                StopCause::AllInputsClosed => Some("ALL_INPUTS_CLOSED"),
166                &_ => None,
167            },
168            _ => None,
169        }
170    }
171
172    /// Returns the payload of an input event as an arrow array (if any).
173    fn value(&self, py: Python<'_>) -> PyResult<Option<PyObject>> {
174        match &self.event {
175            MergedEvent::Dora(Event::Input { data, .. }) => {
176                // TODO: Does this call leak data?&
177                let array_data = data.to_data().to_pyarrow(py)?;
178                Ok(Some(array_data))
179            }
180            _ => Ok(None),
181        }
182    }
183
184    fn metadata(event: &Event, py: Python<'_>) -> Result<Option<PyObject>> {
185        match event {
186            Event::Input { metadata, .. } => Ok(Some(
187                metadata_to_pydict(metadata, py)
188                    .context("Issue deserializing metadata")?
189                    .into_pyobject(py)
190                    .context("Failed to create metadata_to_pydice")?
191                    .unbind()
192                    .into(),
193            )),
194            _ => Ok(None),
195        }
196    }
197
198    fn error(event: &Event) -> Option<&str> {
199        match event {
200            Event::Error(error) => Some(error),
201            _other => None,
202        }
203    }
204}
205
206pub fn pydict_to_metadata(dict: Option<Bound<'_, PyDict>>) -> Result<MetadataParameters> {
207    let mut parameters = BTreeMap::default();
208    if let Some(pymetadata) = dict {
209        for (key, value) in pymetadata.iter() {
210            let key = key.extract::<String>().context("Parsing metadata keys")?;
211            if value.is_exact_instance_of::<PyBool>() {
212                parameters.insert(key, Parameter::Bool(value.extract()?))
213            } else if value.is_instance_of::<PyInt>() {
214                parameters.insert(key, Parameter::Integer(value.extract::<i64>()?))
215            } else if value.is_instance_of::<PyFloat>() {
216                parameters.insert(key, Parameter::Float(value.extract::<f64>()?))
217            } else if value.is_instance_of::<PyString>() {
218                parameters.insert(key, Parameter::String(value.extract()?))
219            } else if (value.is_instance_of::<PyTuple>() || value.is_instance_of::<PyList>())
220                && value.len()? > 0
221                && value.get_item(0)?.is_exact_instance_of::<PyInt>()
222            {
223                let list: Vec<i64> = value.extract()?;
224                parameters.insert(key, Parameter::ListInt(list))
225            } else if (value.is_instance_of::<PyTuple>() || value.is_instance_of::<PyList>())
226                && value.len()? > 0
227                && value.get_item(0)?.is_exact_instance_of::<PyFloat>()
228            {
229                let list: Vec<f64> = value.extract()?;
230                parameters.insert(key, Parameter::ListFloat(list))
231            } else if value.is_instance_of::<PyList>()
232                && value.len()? > 0
233                && value.get_item(0)?.is_exact_instance_of::<PyString>()
234            {
235                let list: Vec<String> = value.extract()?;
236                parameters.insert(key, Parameter::ListString(list))
237            } else {
238                // Check if it's a datetime.datetime object
239                let datetime_module = PyModule::import(value.py(), "datetime")
240                    .context("Failed to import datetime module")?;
241                let datetime_class = datetime_module.getattr("datetime")?;
242
243                if value.is_instance(datetime_class.as_ref())? {
244                    // Extract timestamp using timestamp() method
245                    let timestamp_float: f64 = value
246                        .call_method0("timestamp")?
247                        .extract()
248                        .context("Failed to extract timestamp from datetime")?;
249
250                    // Convert to chrono::DateTime<Utc>
251                    // timestamp() returns seconds since epoch as float
252                    // Convert to SystemTime first, then to DateTime<Utc>
253                    let system_time = if timestamp_float >= 0.0 {
254                        let duration = std::time::Duration::try_from_secs_f64(timestamp_float)
255                            .context("Failed to convert timestamp to Duration")?;
256                        UNIX_EPOCH + duration
257                    } else {
258                        let duration = std::time::Duration::try_from_secs_f64(-timestamp_float)
259                            .context("Failed to convert timestamp to Duration")?;
260                        UNIX_EPOCH.checked_sub(duration).unwrap_or(UNIX_EPOCH)
261                    };
262
263                    let dt = DateTime::<Utc>::from(system_time);
264
265                    parameters.insert(key, Parameter::Timestamp(dt))
266                } else {
267                    println!("could not convert type {value}");
268                    parameters.insert(key, Parameter::String(value.str()?.to_string()))
269                }
270            };
271        }
272    }
273    Ok(parameters)
274}
275
276pub fn metadata_to_pydict<'a>(
277    metadata: &'a Metadata,
278    py: Python<'a>,
279) -> Result<pyo3::Bound<'a, PyDict>> {
280    let dict = PyDict::new(py);
281
282    // Add timestamp as timezone-aware Python datetime (UTC)
283    // Note: uhlc::Timestamp is a Hybrid Logical Clock. We use get_time().to_system_time()
284    // which extracts the physical clock component. This pattern is used consistently
285    // throughout the dora codebase (e.g., in binaries/daemon/src/log.rs, binaries/coordinator/src/lib.rs)
286    // and assumes the physical time component represents UTC wall-clock time.
287    let timestamp = metadata.timestamp();
288    let system_time = timestamp.get_time().to_system_time();
289    let duration_since_epoch = system_time
290        .duration_since(UNIX_EPOCH)
291        .context("Failed to calculate duration since epoch")?;
292
293    // Extract seconds and microseconds (Python datetime supports microsecond precision)
294    let seconds = duration_since_epoch.as_secs() as i64;
295    let microseconds = duration_since_epoch.subsec_micros() as u32;
296
297    // Get UTC timezone from Python's datetime module and create timezone-aware datetime
298    // We use Python's datetime.fromtimestamp() to create a UTC-aware datetime object
299    // This avoids float precision loss by using integer seconds and microseconds
300    let datetime_module =
301        PyModule::import(py, "datetime").context("Failed to import datetime module")?;
302    let datetime_class = datetime_module.getattr("datetime")?;
303    let utc_timezone = datetime_module.getattr("timezone")?.getattr("utc")?;
304
305    // Create timezone-aware datetime using fromtimestamp
306    // We compute total_seconds as float (required by fromtimestamp) but preserve
307    // precision by computing from integer seconds and microseconds separately
308    let total_seconds = seconds as f64 + microseconds as f64 / 1_000_000.0;
309    let py_datetime = datetime_class
310        .call_method1("fromtimestamp", (total_seconds, utc_timezone))
311        .context("Failed to create Python datetime from timestamp")?;
312
313    dict.set_item("timestamp", py_datetime)
314        .context("Could not insert timestamp into python dictionary")?;
315
316    // Add existing parameters
317    for (k, v) in metadata.parameters.iter() {
318        match v {
319            Parameter::Bool(bool) => dict
320                .set_item(k, bool)
321                .context("Could not insert metadata into python dictionary")?,
322            Parameter::Integer(int) => dict
323                .set_item(k, int)
324                .context("Could not insert metadata into python dictionary")?,
325            Parameter::Float(float) => dict
326                .set_item(k, float)
327                .context("Could not insert metadata into python dictionary")?,
328            Parameter::String(s) => dict
329                .set_item(k, s)
330                .context("Could not insert metadata into python dictionary")?,
331            Parameter::ListInt(l) => dict
332                .set_item(k, l)
333                .context("Could not insert metadata into python dictionary")?,
334            Parameter::ListFloat(l) => dict
335                .set_item(k, l)
336                .context("Could not insert metadata into python dictionary")?,
337            Parameter::ListString(l) => dict
338                .set_item(k, l)
339                .context("Could not insert metadata into python dictionary")?,
340            Parameter::Timestamp(dt) => {
341                // Convert chrono::DateTime<Utc> to Python datetime.datetime
342                let timestamp = dt.timestamp();
343                let microseconds = dt.timestamp_subsec_micros();
344
345                // Get UTC timezone from Python's datetime module
346                let datetime_module =
347                    PyModule::import(py, "datetime").context("Failed to import datetime module")?;
348                let datetime_class = datetime_module.getattr("datetime")?;
349                let utc_timezone = datetime_module.getattr("timezone")?.getattr("utc")?;
350
351                // Create timezone-aware datetime using fromtimestamp
352                let total_seconds = timestamp as f64 + microseconds as f64 / 1_000_000.0;
353                let py_datetime = datetime_class
354                    .call_method1("fromtimestamp", (total_seconds, utc_timezone))
355                    .context("Failed to create Python datetime from timestamp")?;
356
357                dict.set_item(k, py_datetime)
358                    .context("Could not insert timestamp into python dictionary")?
359            }
360        }
361    }
362
363    Ok(dict)
364}
365
/// Round-trip tests for the arrow (de)serialization helpers of
/// `dora_node_api::arrow_utils`.
#[cfg(test)]
mod tests {
    use std::{ptr::NonNull, sync::Arc};

    use aligned_vec::{AVec, ConstAlign};
    use arrow::{
        array::{
            ArrayData, ArrayRef, BooleanArray, Float64Array, Int8Array, Int32Array, Int64Array,
            ListArray, StructArray,
        },
        buffer::Buffer,
    };

    use arrow_schema::{DataType, Field};
    use dora_node_api::arrow_utils::{
        buffer_into_arrow_array, copy_array_into_sample, required_data_size,
    };
    use eyre::{Context, Result};

    /// Copies `arrow_array` into a raw sample buffer, reads it back and
    /// asserts that the result equals the input.
    fn assert_roundtrip(arrow_array: &ArrayData) -> Result<()> {
        let size = required_data_size(arrow_array);
        // Scratch buffer of `size` elements, all initialised to 0
        // (presumably 128-byte aligned, given `ConstAlign<128>` -- confirm).
        let mut sample: AVec<u8, ConstAlign<128>> = AVec::__from_elem(128, 0, size);

        let info = copy_array_into_sample(&mut sample, arrow_array);

        let serialized_deserialized_arrow_array = {
            // Pointer and length must be read before `sample` is moved below.
            let ptr = NonNull::new(sample.as_ptr() as *mut _).unwrap();
            let len = sample.len();

            // SAFETY (review note): `ptr` and `len` were taken from `sample`
            // just above, and `sample` itself is moved into the `Arc` that is
            // passed as the owner of the allocation, so the bytes should stay
            // alive as long as the buffer does. This assumes that moving the
            // `AVec` does not relocate its heap storage -- confirm.
            let raw_buffer = unsafe {
                arrow::buffer::Buffer::from_custom_allocation(ptr, len, Arc::new(sample))
            };
            buffer_into_arrow_array(&raw_buffer, &info)?
        };

        assert_eq!(arrow_array, &serialized_deserialized_arrow_array);

        Ok(())
    }

    /// Checks the round trip for primitive, struct and list arrays.
    #[test]
    fn serialize_deserialize_arrow() -> Result<()> {
        // Int8
        let arrow_array = Int8Array::from(vec![1, -2, 3, 4]).into();
        assert_roundtrip(&arrow_array).context("Int8Array roundtrip failed")?;

        // Int64
        let arrow_array = Int64Array::from(vec![1, -2, 3, 4]).into();
        assert_roundtrip(&arrow_array).context("Int64Array roundtrip failed")?;

        // Float64
        let arrow_array = Float64Array::from(vec![1., -2., 3., 4.]).into();
        assert_roundtrip(&arrow_array).context("Float64Array roundtrip failed")?;

        // Struct
        let boolean = Arc::new(BooleanArray::from(vec![false, false, true, true]));
        let int = Arc::new(Int32Array::from(vec![42, 28, 19, 31]));

        let struct_array = StructArray::from(vec![
            (
                Arc::new(Field::new("b", DataType::Boolean, false)),
                boolean as ArrayRef,
            ),
            (
                Arc::new(Field::new("c", DataType::Int32, false)),
                int as ArrayRef,
            ),
        ])
        .into();
        assert_roundtrip(&struct_array).context("StructArray roundtrip failed")?;

        // List
        let value_data = ArrayData::builder(DataType::Int32)
            .len(8)
            .add_buffer(Buffer::from_slice_ref([0, 1, 2, 3, 4, 5, 6, 7]))
            .build()
            .unwrap();

        // Construct a buffer for value offsets, for the nested array:
        //  [[0, 1, 2], [3, 4, 5], [6, 7]]
        let value_offsets = Buffer::from_slice_ref([0, 3, 6, 8]);

        // Construct a list array from the above two
        let list_data_type = DataType::List(Arc::new(Field::new("item", DataType::Int32, false)));
        let list_data = ArrayData::builder(list_data_type)
            .len(3)
            .add_buffer(value_offsets)
            .add_child_data(value_data)
            .build()
            .unwrap();
        let list_array = ListArray::from(list_data).into();
        assert_roundtrip(&list_array).context("ListArray roundtrip failed")?;

        Ok(())
    }
}