1use std::{
2 collections::{BTreeMap, HashMap},
3 sync::{Arc, Mutex},
4};
5
6use arrow::pyarrow::ToPyArrow;
7use dora_node_api::{
8 DoraNode, Event, EventStream, Metadata, MetadataParameters, Parameter, StopCause,
9 merged::{MergeExternalSend, MergedEvent},
10};
11use eyre::{Context, Result};
12use futures::{Stream, StreamExt};
13use futures_concurrency::stream::Merge as _;
14use pyo3::{
15 prelude::*,
16 types::{IntoPyDict, PyBool, PyDict, PyFloat, PyInt, PyList, PyModule, PyString, PyTuple},
17};
18use std::time::{SystemTime, UNIX_EPOCH};
19
/// An event handed to Python: either an event produced by the dora runtime
/// (`MergedEvent::Dora`) or an external Python object that was merged into the
/// same stream (`MergedEvent::External`).
pub struct PyEvent {
    /// The wrapped event; external events carry an arbitrary Python object.
    pub event: MergedEvent<PyObject>,
}
24
// Python-visible handle holding shared ownership of a `DoraNode` through a
// `CleanupHandle`; cloning it only bumps the outer `Arc` count.
#[derive(Clone)]
#[pyclass]
pub struct NodeCleanupHandle {
    // Never read in this file: the field exists so that the node's shared storage
    // stays alive while any clone of this handle is alive.
    pub _handles: Arc<CleanupHandle<DoraNode>>,
}
31
/// Shared, mutex-protected storage for a value whose drop is delayed until this
/// wrapper and every `CleanupHandle` created from it (via `handle()`) are dropped.
pub struct DelayedCleanup<T>(Arc<Mutex<T>>);
34
35impl<T> DelayedCleanup<T> {
36 pub fn new(value: T) -> Self {
37 Self(Arc::new(Mutex::new(value)))
38 }
39
40 pub fn handle(&self) -> CleanupHandle<T> {
41 CleanupHandle(self.0.clone())
42 }
43
44 pub fn get_mut(&self) -> std::sync::MutexGuard<T> {
45 self.0.try_lock().expect("failed to lock DelayedCleanup")
46 }
47}
48
49impl Stream for DelayedCleanup<EventStream> {
50 type Item = Event;
51
52 fn poll_next(
53 self: std::pin::Pin<&mut Self>,
54 cx: &mut std::task::Context<'_>,
55 ) -> std::task::Poll<Option<Self::Item>> {
56 let mut inner: std::sync::MutexGuard<'_, EventStream> = self.get_mut().get_mut();
57 inner.poll_next_unpin(cx)
58 }
59}
60
61impl<'a, E> MergeExternalSend<'a, E> for DelayedCleanup<EventStream>
62where
63 E: 'static,
64{
65 type Item = MergedEvent<E>;
66
67 fn merge_external_send(
68 self,
69 external_events: impl Stream<Item = E> + Unpin + Send + Sync + 'a,
70 ) -> Box<dyn Stream<Item = Self::Item> + Unpin + Send + Sync + 'a> {
71 let dora = self.map(MergedEvent::Dora);
72 let external = external_events.map(MergedEvent::External);
73 Box::new((dora, external).merge())
74 }
75}
76
/// Shared-ownership handle created by `DelayedCleanup::handle`; while it exists, the
/// value stored in the `DelayedCleanup` is not dropped.
// The inner `Arc` is not read in this file — it is held only for its strong
// reference count, which is why `dead_code` is allowed here.
#[allow(dead_code)]
pub struct CleanupHandle<T>(Arc<Mutex<T>>);
79
80impl PyEvent {
81 pub fn to_py_dict(self, py: Python<'_>) -> PyResult<Py<PyDict>> {
82 let mut pydict = HashMap::new();
83 match &self.event {
84 MergedEvent::Dora(_) => pydict.insert(
85 "kind",
86 "dora"
87 .into_pyobject(py)
88 .context("Failed to create pystring")?
89 .unbind()
90 .into(),
91 ),
92 MergedEvent::External(_) => pydict.insert(
93 "kind",
94 "external"
95 .into_pyobject(py)
96 .context("Failed to create pystring")?
97 .unbind()
98 .into(),
99 ),
100 };
101 match &self.event {
102 MergedEvent::Dora(event) => {
103 if let Some(id) = Self::id(event) {
104 pydict.insert(
105 "id",
106 id.into_pyobject(py)
107 .context("Failed to create id pyobject")?
108 .into(),
109 );
110 }
111 pydict.insert(
112 "type",
113 Self::ty(event)
114 .into_pyobject(py)
115 .context("Failed to create event pyobject")?
116 .unbind()
117 .into(),
118 );
119
120 if let Some(value) = self.value(py)? {
121 pydict.insert("value", value);
122 }
123 if let Some(metadata) = Self::metadata(event, py)? {
124 pydict.insert("metadata", metadata);
125 }
126 if let Some(error) = Self::error(event) {
127 pydict.insert(
128 "error",
129 error
130 .into_pyobject(py)
131 .context("Failed to create error pyobject")?
132 .unbind()
133 .into(),
134 );
135 }
136 }
137 MergedEvent::External(event) => {
138 pydict.insert("value", event.clone_ref(py));
139 }
140 }
141
142 Ok(pydict
143 .into_py_dict(py)
144 .context("Failed to create py_dict")?
145 .unbind())
146 }
147
148 fn ty(event: &Event) -> &str {
149 match event {
150 Event::Stop(_) => "STOP",
151 Event::Input { .. } => "INPUT",
152 Event::InputClosed { .. } => "INPUT_CLOSED",
153 Event::Error(_) => "ERROR",
154 _other => "UNKNOWN",
155 }
156 }
157
158 fn id(event: &Event) -> Option<&str> {
159 match event {
160 Event::Input { id, .. } => Some(id),
161 Event::InputClosed { id } => Some(id),
162 Event::Stop(cause) => match cause {
163 StopCause::Manual => Some("MANUAL"),
164 StopCause::AllInputsClosed => Some("ALL_INPUTS_CLOSED"),
165 &_ => None,
166 },
167 _ => None,
168 }
169 }
170
171 fn value(&self, py: Python<'_>) -> PyResult<Option<PyObject>> {
173 match &self.event {
174 MergedEvent::Dora(Event::Input { data, .. }) => {
175 let array_data = data.to_data().to_pyarrow(py)?;
177 Ok(Some(array_data))
178 }
179 _ => Ok(None),
180 }
181 }
182
183 fn metadata(event: &Event, py: Python<'_>) -> Result<Option<PyObject>> {
184 match event {
185 Event::Input { metadata, .. } => Ok(Some(
186 metadata_to_pydict(metadata, py)
187 .context("Issue deserializing metadata")?
188 .into_pyobject(py)
189 .context("Failed to create metadata_to_pydice")?
190 .unbind()
191 .into(),
192 )),
193 _ => Ok(None),
194 }
195 }
196
197 fn error(event: &Event) -> Option<&str> {
198 match event {
199 Event::Error(error) => Some(error),
200 _other => None,
201 }
202 }
203}
204
205pub fn pydict_to_metadata(dict: Option<Bound<'_, PyDict>>) -> Result<MetadataParameters> {
206 let mut parameters = BTreeMap::default();
207 if let Some(pymetadata) = dict {
208 for (key, value) in pymetadata.iter() {
209 let key = key.extract::<String>().context("Parsing metadata keys")?;
210 if value.is_exact_instance_of::<PyBool>() {
211 parameters.insert(key, Parameter::Bool(value.extract()?))
212 } else if value.is_instance_of::<PyInt>() {
213 parameters.insert(key, Parameter::Integer(value.extract::<i64>()?))
214 } else if value.is_instance_of::<PyFloat>() {
215 parameters.insert(key, Parameter::Float(value.extract::<f64>()?))
216 } else if value.is_instance_of::<PyString>() {
217 parameters.insert(key, Parameter::String(value.extract()?))
218 } else if (value.is_instance_of::<PyTuple>() || value.is_instance_of::<PyList>())
219 && value.len()? > 0
220 && value.get_item(0)?.is_exact_instance_of::<PyInt>()
221 {
222 let list: Vec<i64> = value.extract()?;
223 parameters.insert(key, Parameter::ListInt(list))
224 } else if (value.is_instance_of::<PyTuple>() || value.is_instance_of::<PyList>())
225 && value.len()? > 0
226 && value.get_item(0)?.is_exact_instance_of::<PyFloat>()
227 {
228 let list: Vec<f64> = value.extract()?;
229 parameters.insert(key, Parameter::ListFloat(list))
230 } else if value.is_instance_of::<PyList>()
231 && value.len()? > 0
232 && value.get_item(0)?.is_exact_instance_of::<PyString>()
233 {
234 let list: Vec<String> = value.extract()?;
235 parameters.insert(key, Parameter::ListString(list))
236 } else {
237 println!("could not convert type {value}");
238 parameters.insert(key, Parameter::String(value.str()?.to_string()))
239 };
240 }
241 }
242 Ok(parameters)
243}
244
245pub fn metadata_to_pydict<'a>(
246 metadata: &'a Metadata,
247 py: Python<'a>,
248) -> Result<pyo3::Bound<'a, PyDict>> {
249 let dict = PyDict::new(py);
250
251 let timestamp = metadata.timestamp();
257 let system_time = timestamp.get_time().to_system_time();
258 let duration_since_epoch = system_time
259 .duration_since(UNIX_EPOCH)
260 .context("Failed to calculate duration since epoch")?;
261
262 let seconds = duration_since_epoch.as_secs() as i64;
264 let microseconds = duration_since_epoch.subsec_micros() as u32;
265
266 let datetime_module =
270 PyModule::import(py, "datetime").context("Failed to import datetime module")?;
271 let datetime_class = datetime_module.getattr("datetime")?;
272 let utc_timezone = datetime_module.getattr("timezone")?.getattr("utc")?;
273
274 let total_seconds = seconds as f64 + microseconds as f64 / 1_000_000.0;
278 let py_datetime = datetime_class
279 .call_method1("fromtimestamp", (total_seconds, utc_timezone))
280 .context("Failed to create Python datetime from timestamp")?;
281
282 dict.set_item("timestamp", py_datetime)
283 .context("Could not insert timestamp into python dictionary")?;
284
285 for (k, v) in metadata.parameters.iter() {
287 match v {
288 Parameter::Bool(bool) => dict
289 .set_item(k, bool)
290 .context("Could not insert metadata into python dictionary")?,
291 Parameter::Integer(int) => dict
292 .set_item(k, int)
293 .context("Could not insert metadata into python dictionary")?,
294 Parameter::Float(float) => dict
295 .set_item(k, float)
296 .context("Could not insert metadata into python dictionary")?,
297 Parameter::String(s) => dict
298 .set_item(k, s)
299 .context("Could not insert metadata into python dictionary")?,
300 Parameter::ListInt(l) => dict
301 .set_item(k, l)
302 .context("Could not insert metadata into python dictionary")?,
303 Parameter::ListFloat(l) => dict
304 .set_item(k, l)
305 .context("Could not insert metadata into python dictionary")?,
306 Parameter::ListString(l) => dict
307 .set_item(k, l)
308 .context("Could not insert metadata into python dictionary")?,
309 }
310 }
311
312 Ok(dict)
313}
314
#[cfg(test)]
mod tests {
    use std::{ptr::NonNull, sync::Arc};

    use aligned_vec::{AVec, ConstAlign};
    use arrow::{
        array::{
            ArrayData, ArrayRef, BooleanArray, Float64Array, Int8Array, Int32Array, Int64Array,
            ListArray, StructArray,
        },
        buffer::Buffer,
    };

    use arrow_schema::{DataType, Field};
    use dora_node_api::arrow_utils::{
        buffer_into_arrow_array, copy_array_into_sample, required_data_size,
    };
    use eyre::{Context, Result};

    /// Copies `original` into a 128-byte-aligned sample, rebuilds an array from that
    /// raw buffer, and asserts the rebuilt array equals the input.
    fn assert_roundtrip(original: &ArrayData) -> Result<()> {
        let byte_len = required_data_size(original);
        let mut sample: AVec<u8, ConstAlign<128>> = AVec::__from_elem(128, 0, byte_len);
        let type_info = copy_array_into_sample(&mut sample, original);

        let start = NonNull::new(sample.as_ptr() as *mut _).unwrap();
        let sample_len = sample.len();
        // SAFETY: `start` points at the heap allocation owned by `sample`, valid for
        // `sample_len` bytes. Moving `sample` into the `Arc` owner does not move that heap
        // buffer, and the owner keeps it alive for as long as the `Buffer` is alive.
        let raw_buffer =
            unsafe { Buffer::from_custom_allocation(start, sample_len, Arc::new(sample)) };
        let rebuilt = buffer_into_arrow_array(&raw_buffer, &type_info)?;

        assert_eq!(original, &rebuilt);
        Ok(())
    }

    #[test]
    fn serialize_deserialize_arrow() -> Result<()> {
        let flags = Arc::new(BooleanArray::from(vec![false, false, true, true]));
        let numbers = Arc::new(Int32Array::from(vec![42, 28, 19, 31]));
        let struct_array: ArrayData = StructArray::from(vec![
            (
                Arc::new(Field::new("b", DataType::Boolean, false)),
                flags as ArrayRef,
            ),
            (
                Arc::new(Field::new("c", DataType::Int32, false)),
                numbers as ArrayRef,
            ),
        ])
        .into();

        // Three lists over eight Int32 values: [0, 1, 2], [3, 4, 5], [6, 7].
        let list_values = ArrayData::builder(DataType::Int32)
            .len(8)
            .add_buffer(Buffer::from_slice_ref([0, 1, 2, 3, 4, 5, 6, 7]))
            .build()
            .unwrap();
        let list_offsets = Buffer::from_slice_ref([0, 3, 6, 8]);
        let list_type = DataType::List(Arc::new(Field::new("item", DataType::Int32, false)));
        let list_data = ArrayData::builder(list_type)
            .len(3)
            .add_buffer(list_offsets)
            .add_child_data(list_values)
            .build()
            .unwrap();
        let list_array: ArrayData = ListArray::from(list_data).into();

        let cases: [(&str, ArrayData); 5] = [
            ("Int8Array", Int8Array::from(vec![1, -2, 3, 4]).into()),
            ("Int64Array", Int64Array::from(vec![1, -2, 3, 4]).into()),
            ("Float64Array", Float64Array::from(vec![1., -2., 3., 4.]).into()),
            ("StructArray", struct_array),
            ("ListArray", list_array),
        ];
        for (label, array) in &cases {
            assert_roundtrip(array).context(format!("{label} roundtrip failed"))?;
        }

        Ok(())
    }
}