1use std::{
2 collections::{BTreeMap, HashMap},
3 sync::{Arc, Mutex},
4};
5
6use arrow::pyarrow::ToPyArrow;
7use chrono::{DateTime, Utc};
8use dora_node_api::{
9 DoraNode, Event, EventStream, Metadata, MetadataParameters, Parameter, StopCause,
10 merged::{MergeExternalSend, MergedEvent},
11};
12use eyre::{Context, Result};
13use futures::{Stream, StreamExt};
14use futures_concurrency::stream::Merge as _;
15use pyo3::{
16 prelude::*,
17 types::{IntoPyDict, PyBool, PyDict, PyFloat, PyInt, PyList, PyModule, PyString, PyTuple},
18};
19use std::time::UNIX_EPOCH;
20
21pub struct PyEvent {
23 pub event: MergedEvent<PyObject>,
24}
25
26#[derive(Clone)]
28#[pyclass]
29pub struct NodeCleanupHandle {
30 pub _handles: Arc<CleanupHandle<DoraNode>>,
31}
32
33pub struct DelayedCleanup<T>(Arc<Mutex<T>>);
35
36impl<T> DelayedCleanup<T> {
37 pub fn new(value: T) -> Self {
38 Self(Arc::new(Mutex::new(value)))
39 }
40
41 pub fn handle(&self) -> CleanupHandle<T> {
42 CleanupHandle(self.0.clone())
43 }
44
45 pub fn get_mut(&self) -> std::sync::MutexGuard<T> {
46 self.0.try_lock().expect("failed to lock DelayedCleanup")
47 }
48}
49
50impl Stream for DelayedCleanup<EventStream> {
51 type Item = Event;
52
53 fn poll_next(
54 self: std::pin::Pin<&mut Self>,
55 cx: &mut std::task::Context<'_>,
56 ) -> std::task::Poll<Option<Self::Item>> {
57 let mut inner: std::sync::MutexGuard<'_, EventStream> = self.get_mut().get_mut();
58 inner.poll_next_unpin(cx)
59 }
60}
61
62impl<'a, E> MergeExternalSend<'a, E> for DelayedCleanup<EventStream>
63where
64 E: 'static,
65{
66 type Item = MergedEvent<E>;
67
68 fn merge_external_send(
69 self,
70 external_events: impl Stream<Item = E> + Unpin + Send + Sync + 'a,
71 ) -> Box<dyn Stream<Item = Self::Item> + Unpin + Send + Sync + 'a> {
72 let dora = self.map(MergedEvent::Dora);
73 let external = external_events.map(MergedEvent::External);
74 Box::new((dora, external).merge())
75 }
76}
77
/// Shares ownership of the value held by a [`DelayedCleanup`].
///
/// Created through `DelayedCleanup::handle`, which clones the underlying
/// `Arc`.
// The inner `Arc` is never read in this file; it is only held (and eventually
// dropped), which is why the dead-code lint is silenced.
#[allow(dead_code)]
pub struct CleanupHandle<T>(Arc<Mutex<T>>);
80
81impl PyEvent {
82 pub fn to_py_dict(self, py: Python<'_>) -> PyResult<Py<PyDict>> {
83 let mut pydict = HashMap::new();
84 match &self.event {
85 MergedEvent::Dora(_) => pydict.insert(
86 "kind",
87 "dora"
88 .into_pyobject(py)
89 .context("Failed to create pystring")?
90 .unbind()
91 .into(),
92 ),
93 MergedEvent::External(_) => pydict.insert(
94 "kind",
95 "external"
96 .into_pyobject(py)
97 .context("Failed to create pystring")?
98 .unbind()
99 .into(),
100 ),
101 };
102 match &self.event {
103 MergedEvent::Dora(event) => {
104 if let Some(id) = Self::id(event) {
105 pydict.insert(
106 "id",
107 id.into_pyobject(py)
108 .context("Failed to create id pyobject")?
109 .into(),
110 );
111 }
112 pydict.insert(
113 "type",
114 Self::ty(event)
115 .into_pyobject(py)
116 .context("Failed to create event pyobject")?
117 .unbind()
118 .into(),
119 );
120
121 if let Some(value) = self.value(py)? {
122 pydict.insert("value", value);
123 }
124 if let Some(metadata) = Self::metadata(event, py)? {
125 pydict.insert("metadata", metadata);
126 }
127 if let Some(error) = Self::error(event) {
128 pydict.insert(
129 "error",
130 error
131 .into_pyobject(py)
132 .context("Failed to create error pyobject")?
133 .unbind()
134 .into(),
135 );
136 }
137 }
138 MergedEvent::External(event) => {
139 pydict.insert("value", event.clone_ref(py));
140 }
141 }
142
143 Ok(pydict
144 .into_py_dict(py)
145 .context("Failed to create py_dict")?
146 .unbind())
147 }
148
149 fn ty(event: &Event) -> &str {
150 match event {
151 Event::Stop(_) => "STOP",
152 Event::Input { .. } => "INPUT",
153 Event::InputClosed { .. } => "INPUT_CLOSED",
154 Event::Error(_) => "ERROR",
155 _other => "UNKNOWN",
156 }
157 }
158
159 fn id(event: &Event) -> Option<&str> {
160 match event {
161 Event::Input { id, .. } => Some(id),
162 Event::InputClosed { id } => Some(id),
163 Event::Stop(cause) => match cause {
164 StopCause::Manual => Some("MANUAL"),
165 StopCause::AllInputsClosed => Some("ALL_INPUTS_CLOSED"),
166 &_ => None,
167 },
168 _ => None,
169 }
170 }
171
172 fn value(&self, py: Python<'_>) -> PyResult<Option<PyObject>> {
174 match &self.event {
175 MergedEvent::Dora(Event::Input { data, .. }) => {
176 let array_data = data.to_data().to_pyarrow(py)?;
178 Ok(Some(array_data))
179 }
180 _ => Ok(None),
181 }
182 }
183
184 fn metadata(event: &Event, py: Python<'_>) -> Result<Option<PyObject>> {
185 match event {
186 Event::Input { metadata, .. } => Ok(Some(
187 metadata_to_pydict(metadata, py)
188 .context("Issue deserializing metadata")?
189 .into_pyobject(py)
190 .context("Failed to create metadata_to_pydice")?
191 .unbind()
192 .into(),
193 )),
194 _ => Ok(None),
195 }
196 }
197
198 fn error(event: &Event) -> Option<&str> {
199 match event {
200 Event::Error(error) => Some(error),
201 _other => None,
202 }
203 }
204}
205
206pub fn pydict_to_metadata(dict: Option<Bound<'_, PyDict>>) -> Result<MetadataParameters> {
207 let mut parameters = BTreeMap::default();
208 if let Some(pymetadata) = dict {
209 for (key, value) in pymetadata.iter() {
210 let key = key.extract::<String>().context("Parsing metadata keys")?;
211 if value.is_exact_instance_of::<PyBool>() {
212 parameters.insert(key, Parameter::Bool(value.extract()?))
213 } else if value.is_instance_of::<PyInt>() {
214 parameters.insert(key, Parameter::Integer(value.extract::<i64>()?))
215 } else if value.is_instance_of::<PyFloat>() {
216 parameters.insert(key, Parameter::Float(value.extract::<f64>()?))
217 } else if value.is_instance_of::<PyString>() {
218 parameters.insert(key, Parameter::String(value.extract()?))
219 } else if (value.is_instance_of::<PyTuple>() || value.is_instance_of::<PyList>())
220 && value.len()? > 0
221 && value.get_item(0)?.is_exact_instance_of::<PyInt>()
222 {
223 let list: Vec<i64> = value.extract()?;
224 parameters.insert(key, Parameter::ListInt(list))
225 } else if (value.is_instance_of::<PyTuple>() || value.is_instance_of::<PyList>())
226 && value.len()? > 0
227 && value.get_item(0)?.is_exact_instance_of::<PyFloat>()
228 {
229 let list: Vec<f64> = value.extract()?;
230 parameters.insert(key, Parameter::ListFloat(list))
231 } else if value.is_instance_of::<PyList>()
232 && value.len()? > 0
233 && value.get_item(0)?.is_exact_instance_of::<PyString>()
234 {
235 let list: Vec<String> = value.extract()?;
236 parameters.insert(key, Parameter::ListString(list))
237 } else {
238 let datetime_module = PyModule::import(value.py(), "datetime")
240 .context("Failed to import datetime module")?;
241 let datetime_class = datetime_module.getattr("datetime")?;
242
243 if value.is_instance(datetime_class.as_ref())? {
244 let timestamp_float: f64 = value
246 .call_method0("timestamp")?
247 .extract()
248 .context("Failed to extract timestamp from datetime")?;
249
250 let system_time = if timestamp_float >= 0.0 {
254 let duration = std::time::Duration::try_from_secs_f64(timestamp_float)
255 .context("Failed to convert timestamp to Duration")?;
256 UNIX_EPOCH + duration
257 } else {
258 let duration = std::time::Duration::try_from_secs_f64(-timestamp_float)
259 .context("Failed to convert timestamp to Duration")?;
260 UNIX_EPOCH.checked_sub(duration).unwrap_or(UNIX_EPOCH)
261 };
262
263 let dt = DateTime::<Utc>::from(system_time);
264
265 parameters.insert(key, Parameter::Timestamp(dt))
266 } else {
267 println!("could not convert type {value}");
268 parameters.insert(key, Parameter::String(value.str()?.to_string()))
269 }
270 };
271 }
272 }
273 Ok(parameters)
274}
275
276pub fn metadata_to_pydict<'a>(
277 metadata: &'a Metadata,
278 py: Python<'a>,
279) -> Result<pyo3::Bound<'a, PyDict>> {
280 let dict = PyDict::new(py);
281
282 let timestamp = metadata.timestamp();
288 let system_time = timestamp.get_time().to_system_time();
289 let duration_since_epoch = system_time
290 .duration_since(UNIX_EPOCH)
291 .context("Failed to calculate duration since epoch")?;
292
293 let seconds = duration_since_epoch.as_secs() as i64;
295 let microseconds = duration_since_epoch.subsec_micros() as u32;
296
297 let datetime_module =
301 PyModule::import(py, "datetime").context("Failed to import datetime module")?;
302 let datetime_class = datetime_module.getattr("datetime")?;
303 let utc_timezone = datetime_module.getattr("timezone")?.getattr("utc")?;
304
305 let total_seconds = seconds as f64 + microseconds as f64 / 1_000_000.0;
309 let py_datetime = datetime_class
310 .call_method1("fromtimestamp", (total_seconds, utc_timezone))
311 .context("Failed to create Python datetime from timestamp")?;
312
313 dict.set_item("timestamp", py_datetime)
314 .context("Could not insert timestamp into python dictionary")?;
315
316 for (k, v) in metadata.parameters.iter() {
318 match v {
319 Parameter::Bool(bool) => dict
320 .set_item(k, bool)
321 .context("Could not insert metadata into python dictionary")?,
322 Parameter::Integer(int) => dict
323 .set_item(k, int)
324 .context("Could not insert metadata into python dictionary")?,
325 Parameter::Float(float) => dict
326 .set_item(k, float)
327 .context("Could not insert metadata into python dictionary")?,
328 Parameter::String(s) => dict
329 .set_item(k, s)
330 .context("Could not insert metadata into python dictionary")?,
331 Parameter::ListInt(l) => dict
332 .set_item(k, l)
333 .context("Could not insert metadata into python dictionary")?,
334 Parameter::ListFloat(l) => dict
335 .set_item(k, l)
336 .context("Could not insert metadata into python dictionary")?,
337 Parameter::ListString(l) => dict
338 .set_item(k, l)
339 .context("Could not insert metadata into python dictionary")?,
340 Parameter::Timestamp(dt) => {
341 let timestamp = dt.timestamp();
343 let microseconds = dt.timestamp_subsec_micros();
344
345 let datetime_module =
347 PyModule::import(py, "datetime").context("Failed to import datetime module")?;
348 let datetime_class = datetime_module.getattr("datetime")?;
349 let utc_timezone = datetime_module.getattr("timezone")?.getattr("utc")?;
350
351 let total_seconds = timestamp as f64 + microseconds as f64 / 1_000_000.0;
353 let py_datetime = datetime_class
354 .call_method1("fromtimestamp", (total_seconds, utc_timezone))
355 .context("Failed to create Python datetime from timestamp")?;
356
357 dict.set_item(k, py_datetime)
358 .context("Could not insert timestamp into python dictionary")?
359 }
360 }
361 }
362
363 Ok(dict)
364}
365
#[cfg(test)]
mod tests {
    use std::{ptr::NonNull, sync::Arc};

    use aligned_vec::{AVec, ConstAlign};
    use arrow::{
        array::{
            ArrayData, ArrayRef, BooleanArray, Float64Array, Int8Array, Int32Array, Int64Array,
            ListArray, StructArray,
        },
        buffer::Buffer,
    };

    use arrow_schema::{DataType, Field};
    use dora_node_api::arrow_utils::{
        buffer_into_arrow_array, copy_array_into_sample, required_data_size,
    };
    use eyre::{Context, Result};

    /// Copies `arrow_array` into a 128-byte-aligned sample, reads it back and
    /// asserts that the result equals the input.
    fn assert_roundtrip(arrow_array: &ArrayData) -> Result<()> {
        let required = required_data_size(arrow_array);
        let mut sample: AVec<u8, ConstAlign<128>> = AVec::__from_elem(128, 0, required);
        let type_info = copy_array_into_sample(&mut sample, arrow_array);

        let ptr = NonNull::new(sample.as_ptr() as *mut _).unwrap();
        let len = sample.len();
        // SAFETY: `ptr`/`len` describe the bytes owned by `sample`, and
        // `sample` is handed to the buffer as its owner, so the memory stays
        // alive as long as the buffer does. NOTE(review): this assumes moving
        // an `AVec` does not relocate its storage - confirm.
        let raw_buffer = unsafe { Buffer::from_custom_allocation(ptr, len, Arc::new(sample)) };
        let restored = buffer_into_arrow_array(&raw_buffer, &type_info)?;

        assert_eq!(arrow_array, &restored);

        Ok(())
    }

    /// Builds a struct array with a boolean column `b` and an int32 column `c`.
    fn struct_fixture() -> ArrayData {
        let boolean = Arc::new(BooleanArray::from(vec![false, false, true, true]));
        let int = Arc::new(Int32Array::from(vec![42, 28, 19, 31]));

        StructArray::from(vec![
            (
                Arc::new(Field::new("b", DataType::Boolean, false)),
                boolean as ArrayRef,
            ),
            (
                Arc::new(Field::new("c", DataType::Int32, false)),
                int as ArrayRef,
            ),
        ])
        .into()
    }

    /// Builds a list array with three int32 lists: `[0, 1, 2]`, `[3, 4, 5]`
    /// and `[6, 7]`.
    fn list_fixture() -> ArrayData {
        let value_data = ArrayData::builder(DataType::Int32)
            .len(8)
            .add_buffer(Buffer::from_slice_ref([0, 1, 2, 3, 4, 5, 6, 7]))
            .build()
            .unwrap();
        let value_offsets = Buffer::from_slice_ref([0, 3, 6, 8]);

        let list_data_type = DataType::List(Arc::new(Field::new("item", DataType::Int32, false)));
        let list_data = ArrayData::builder(list_data_type)
            .len(3)
            .add_buffer(value_offsets)
            .add_child_data(value_data)
            .build()
            .unwrap();
        ListArray::from(list_data).into()
    }

    #[test]
    fn serialize_deserialize_arrow() -> Result<()> {
        let int8: ArrayData = Int8Array::from(vec![1, -2, 3, 4]).into();
        assert_roundtrip(&int8).context("Int8Array roundtrip failed")?;

        let int64: ArrayData = Int64Array::from(vec![1, -2, 3, 4]).into();
        assert_roundtrip(&int64).context("Int64Array roundtrip failed")?;

        let float64: ArrayData = Float64Array::from(vec![1., -2., 3., 4.]).into();
        assert_roundtrip(&float64).context("Float64Array roundtrip failed")?;

        let struct_array = struct_fixture();
        assert_roundtrip(&struct_array).context("StructArray roundtrip failed")?;

        let list_array = list_fixture();
        assert_roundtrip(&list_array).context("ListArray roundtrip failed")?;

        Ok(())
    }
}