1use std::{
2 collections::{BTreeMap, HashMap},
3 sync::{Arc, Mutex},
4};
5
6use arrow::pyarrow::ToPyArrow;
7use dora_node_api::{
8 merged::{MergeExternalSend, MergedEvent},
9 DoraNode, Event, EventStream, Metadata, MetadataParameters, Parameter, StopCause,
10};
11use eyre::{Context, Result};
12use futures::{Stream, StreamExt};
13use futures_concurrency::stream::Merge as _;
14use pyo3::{
15 prelude::*,
16 types::{IntoPyDict, PyBool, PyDict, PyFloat, PyInt, PyList, PyString, PyTuple},
17};
18
/// An event handed to Python code.
///
/// Wraps either a dora [`Event`] (`MergedEvent::Dora`) or an external Python object
/// merged into the same stream (`MergedEvent::External`). `to_py_dict` turns it into the
/// dictionary that Python code receives.
pub struct PyEvent {
    /// The wrapped dora or external event.
    pub event: MergedEvent<PyObject>,
}
23
/// Python-visible handle that shares ownership of a [`DoraNode`].
///
/// Every clone holds another reference to the same `Arc<CleanupHandle<DoraNode>>`.
/// Because `CleanupHandle` itself holds an `Arc` to the node, the node cannot be dropped
/// while any such handle is still alive.
#[derive(Clone)]
#[pyclass]
pub struct NodeCleanupHandle {
    /// Shared ownership of the node; only held to keep it alive, never read.
    pub _handles: Arc<CleanupHandle<DoraNode>>,
}
30
31pub struct DelayedCleanup<T>(Arc<Mutex<T>>);
33
34impl<T> DelayedCleanup<T> {
35 pub fn new(value: T) -> Self {
36 Self(Arc::new(Mutex::new(value)))
37 }
38
39 pub fn handle(&self) -> CleanupHandle<T> {
40 CleanupHandle(self.0.clone())
41 }
42
43 pub fn get_mut(&mut self) -> std::sync::MutexGuard<T> {
44 self.0.try_lock().expect("failed to lock DelayedCleanup")
45 }
46}
47
48impl Stream for DelayedCleanup<EventStream> {
49 type Item = Event;
50
51 fn poll_next(
52 self: std::pin::Pin<&mut Self>,
53 cx: &mut std::task::Context<'_>,
54 ) -> std::task::Poll<Option<Self::Item>> {
55 let mut inner: std::sync::MutexGuard<'_, EventStream> = self.get_mut().get_mut();
56 inner.poll_next_unpin(cx)
57 }
58}
59
60impl<'a, E> MergeExternalSend<'a, E> for DelayedCleanup<EventStream>
61where
62 E: 'static,
63{
64 type Item = MergedEvent<E>;
65
66 fn merge_external_send(
67 self,
68 external_events: impl Stream<Item = E> + Unpin + Send + Sync + 'a,
69 ) -> Box<dyn Stream<Item = Self::Item> + Unpin + Send + Sync + 'a> {
70 let dora = self.map(MergedEvent::Dora);
71 let external = external_events.map(MergedEvent::External);
72 Box::new((dora, external).merge())
73 }
74}
75
/// Shares ownership of a value wrapped by a [`DelayedCleanup`].
///
/// The inner `Arc` is never read. It exists only so that the shared value is not
/// dropped while this handle is alive.
// `dead_code` is allowed because the field is intentionally held, never accessed.
#[allow(dead_code)]
pub struct CleanupHandle<T>(Arc<Mutex<T>>);
78
79impl PyEvent {
80 pub fn to_py_dict(self, py: Python<'_>) -> PyResult<Py<PyDict>> {
81 let mut pydict = HashMap::new();
82 match &self.event {
83 MergedEvent::Dora(_) => pydict.insert(
84 "kind",
85 "dora"
86 .into_pyobject(py)
87 .context("Failed to create pystring")?
88 .unbind()
89 .into(),
90 ),
91 MergedEvent::External(_) => pydict.insert(
92 "kind",
93 "external"
94 .into_pyobject(py)
95 .context("Failed to create pystring")?
96 .unbind()
97 .into(),
98 ),
99 };
100 match &self.event {
101 MergedEvent::Dora(event) => {
102 if let Some(id) = Self::id(event) {
103 pydict.insert(
104 "id",
105 id.into_pyobject(py)
106 .context("Failed to create id pyobject")?
107 .into(),
108 );
109 }
110 pydict.insert(
111 "type",
112 Self::ty(event)
113 .into_pyobject(py)
114 .context("Failed to create event pyobject")?
115 .unbind()
116 .into(),
117 );
118
119 if let Some(value) = self.value(py)? {
120 pydict.insert("value", value);
121 }
122 if let Some(metadata) = Self::metadata(event, py)? {
123 pydict.insert("metadata", metadata);
124 }
125 if let Some(error) = Self::error(event) {
126 pydict.insert(
127 "error",
128 error
129 .into_pyobject(py)
130 .context("Failed to create error pyobject")?
131 .unbind()
132 .into(),
133 );
134 }
135 }
136 MergedEvent::External(event) => {
137 pydict.insert("value", event.clone_ref(py));
138 }
139 }
140
141 Ok(pydict
142 .into_py_dict(py)
143 .context("Failed to create py_dict")?
144 .unbind())
145 }
146
147 fn ty(event: &Event) -> &str {
148 match event {
149 Event::Stop(_) => "STOP",
150 Event::Input { .. } => "INPUT",
151 Event::InputClosed { .. } => "INPUT_CLOSED",
152 Event::Error(_) => "ERROR",
153 _other => "UNKNOWN",
154 }
155 }
156
157 fn id(event: &Event) -> Option<&str> {
158 match event {
159 Event::Input { id, .. } => Some(id),
160 Event::InputClosed { id } => Some(id),
161 Event::Stop(cause) => match cause {
162 StopCause::Manual => Some("MANUAL"),
163 StopCause::AllInputsClosed => Some("ALL_INPUTS_CLOSED"),
164 &_ => None,
165 },
166 _ => None,
167 }
168 }
169
170 fn value(&self, py: Python<'_>) -> PyResult<Option<PyObject>> {
172 match &self.event {
173 MergedEvent::Dora(Event::Input { data, .. }) => {
174 let array_data = data.to_data().to_pyarrow(py)?;
176 Ok(Some(array_data))
177 }
178 _ => Ok(None),
179 }
180 }
181
182 fn metadata(event: &Event, py: Python<'_>) -> Result<Option<PyObject>> {
183 match event {
184 Event::Input { metadata, .. } => Ok(Some(
185 metadata_to_pydict(metadata, py)
186 .context("Issue deserializing metadata")?
187 .into_pyobject(py)
188 .context("Failed to create metadata_to_pydice")?
189 .unbind()
190 .into(),
191 )),
192 _ => Ok(None),
193 }
194 }
195
196 fn error(event: &Event) -> Option<&str> {
197 match event {
198 Event::Error(error) => Some(error),
199 _other => None,
200 }
201 }
202}
203
204pub fn pydict_to_metadata(dict: Option<Bound<'_, PyDict>>) -> Result<MetadataParameters> {
205 let mut parameters = BTreeMap::default();
206 if let Some(pymetadata) = dict {
207 for (key, value) in pymetadata.iter() {
208 let key = key.extract::<String>().context("Parsing metadata keys")?;
209 if value.is_exact_instance_of::<PyBool>() {
210 parameters.insert(key, Parameter::Bool(value.extract()?))
211 } else if value.is_instance_of::<PyInt>() {
212 parameters.insert(key, Parameter::Integer(value.extract::<i64>()?))
213 } else if value.is_instance_of::<PyFloat>() {
214 parameters.insert(key, Parameter::Float(value.extract::<f64>()?))
215 } else if value.is_instance_of::<PyString>() {
216 parameters.insert(key, Parameter::String(value.extract()?))
217 } else if value.is_instance_of::<PyTuple>()
218 && value.len()? > 0
219 && value.get_item(0)?.is_exact_instance_of::<PyInt>()
220 {
221 let list: Vec<i64> = value.extract()?;
222 parameters.insert(key, Parameter::ListInt(list))
223 } else if value.is_instance_of::<PyList>()
224 && value.len()? > 0
225 && value.get_item(0)?.is_exact_instance_of::<PyInt>()
226 {
227 let list: Vec<i64> = value.extract()?;
228 parameters.insert(key, Parameter::ListInt(list))
229 } else if value.is_instance_of::<PyTuple>()
230 && value.len()? > 0
231 && value.get_item(0)?.is_exact_instance_of::<PyFloat>()
232 {
233 let list: Vec<f64> = value.extract()?;
234 parameters.insert(key, Parameter::ListFloat(list))
235 } else if value.is_instance_of::<PyList>()
236 && value.len()? > 0
237 && value.get_item(0)?.is_exact_instance_of::<PyFloat>()
238 {
239 let list: Vec<f64> = value.extract()?;
240 parameters.insert(key, Parameter::ListFloat(list))
241 } else if value.is_instance_of::<PyList>()
242 && value.len()? > 0
243 && value.get_item(0)?.is_exact_instance_of::<PyString>()
244 {
245 let list: Vec<String> = value.extract()?;
246 parameters.insert(key, Parameter::ListString(list))
247 } else {
248 println!("could not convert type {value}");
249 parameters.insert(key, Parameter::String(value.str()?.to_string()))
250 };
251 }
252 }
253 Ok(parameters)
254}
255
256pub fn metadata_to_pydict<'a>(
257 metadata: &'a Metadata,
258 py: Python<'a>,
259) -> Result<pyo3::Bound<'a, PyDict>> {
260 let dict = PyDict::new(py);
261 for (k, v) in metadata.parameters.iter() {
262 match v {
263 Parameter::Bool(bool) => dict
264 .set_item(k, bool)
265 .context("Could not insert metadata into python dictionary")?,
266 Parameter::Integer(int) => dict
267 .set_item(k, int)
268 .context("Could not insert metadata into python dictionary")?,
269 Parameter::Float(float) => dict
270 .set_item(k, float)
271 .context("Could not insert metadata into python dictionary")?,
272 Parameter::String(s) => dict
273 .set_item(k, s)
274 .context("Could not insert metadata into python dictionary")?,
275 Parameter::ListInt(l) => dict
276 .set_item(k, l)
277 .context("Could not insert metadata into python dictionary")?,
278 Parameter::ListFloat(l) => dict
279 .set_item(k, l)
280 .context("Could not insert metadata into python dictionary")?,
281 Parameter::ListString(l) => dict
282 .set_item(k, l)
283 .context("Could not insert metadata into python dictionary")?,
284 }
285 }
286
287 Ok(dict)
288}
289
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use aligned_vec::{AVec, ConstAlign};
    use arrow::{
        array::{
            ArrayData, ArrayRef, BooleanArray, Float64Array, Int32Array, Int64Array, Int8Array,
            ListArray, StructArray,
        },
        buffer::Buffer,
    };

    use arrow_schema::{DataType, Field};
    use dora_node_api::{
        arrow_utils::{copy_array_into_sample, required_data_size},
        RawData,
    };
    use eyre::{Context, Result};

    /// Round-trips `original` through a sample buffer and asserts nothing changed.
    ///
    /// The array is copied into an aligned sample buffer (`ConstAlign<128>`, created with
    /// alignment argument 128). An array is then rebuilt from that buffer and compared
    /// with the original.
    fn assert_roundtrip(original: &ArrayData) -> Result<()> {
        let mut buffer: AVec<u8, ConstAlign<128>> =
            AVec::__from_elem(128, 0, required_data_size(original));
        let type_info = copy_array_into_sample(&mut buffer, original);

        let rebuilt = RawData::Vec(buffer)
            .into_arrow_array(&type_info)
            .context("Could not create arrow array")?;
        assert_eq!(original, &rebuilt);
        Ok(())
    }

    /// Two-column struct array (`b: Boolean`, `c: Int32`) with four rows.
    fn sample_struct_array() -> ArrayData {
        let flags: ArrayRef = Arc::new(BooleanArray::from(vec![false, false, true, true]));
        let numbers: ArrayRef = Arc::new(Int32Array::from(vec![42, 28, 19, 31]));
        StructArray::from(vec![
            (Arc::new(Field::new("b", DataType::Boolean, false)), flags),
            (Arc::new(Field::new("c", DataType::Int32, false)), numbers),
        ])
        .into()
    }

    /// List array `[[0, 1, 2], [3, 4, 5], [6, 7]]` backed by a single Int32 child array.
    fn sample_list_array() -> ArrayData {
        let child = ArrayData::builder(DataType::Int32)
            .len(8)
            .add_buffer(Buffer::from_slice_ref([0, 1, 2, 3, 4, 5, 6, 7]))
            .build()
            .unwrap();
        let offsets = Buffer::from_slice_ref([0, 3, 6, 8]);
        let list_type = DataType::List(Arc::new(Field::new("item", DataType::Int32, false)));
        let list_data = ArrayData::builder(list_type)
            .len(3)
            .add_buffer(offsets)
            .add_child_data(child)
            .build()
            .unwrap();
        ListArray::from(list_data).into()
    }

    #[test]
    fn serialize_deserialize_arrow() -> Result<()> {
        // Each case pairs the context message reported on failure with the array to test.
        let cases: Vec<(&'static str, ArrayData)> = vec![
            (
                "Int8Array roundtrip failed",
                Int8Array::from(vec![1, -2, 3, 4]).into(),
            ),
            (
                "Int64Array roundtrip failed",
                Int64Array::from(vec![1, -2, 3, 4]).into(),
            ),
            (
                "Float64Array roundtrip failed",
                Float64Array::from(vec![1., -2., 3., 4.]).into(),
            ),
            ("StructArray roundtrip failed", sample_struct_array()),
            ("ListArray roundtrip failed", sample_list_array()),
        ];

        for (failure_message, array) in cases {
            assert_roundtrip(&array).context(failure_message)?;
        }
        Ok(())
    }
}