dora_runtime/
lib.rs

1#![warn(unsafe_op_in_unsafe_fn)]
2
3use dora_core::{
4    config::{DataId, OperatorId},
5    descriptor::OperatorConfig,
6};
7use dora_message::daemon_to_node::{NodeConfig, RuntimeConfig};
8use dora_metrics::run_metrics_monitor;
9use dora_node_api::{DoraNode, Event};
10use dora_tracing::TracingBuilder;
11use eyre::{Context, Result, bail};
12use futures::{Stream, StreamExt};
13use futures_concurrency::stream::Merge;
14use operator::{OperatorEvent, StopReason, run_operator};
15
16use std::{
17    collections::{BTreeMap, BTreeSet, HashMap},
18    mem,
19};
20use tokio::{
21    runtime::Builder,
22    sync::{mpsc, oneshot},
23};
24use tokio_stream::wrappers::ReceiverStream;
25mod operator;
26
27pub fn main() -> eyre::Result<()> {
28    let config: RuntimeConfig = {
29        let raw = std::env::var("DORA_RUNTIME_CONFIG")
30            .wrap_err("env variable DORA_RUNTIME_CONFIG must be set")?;
31        serde_yaml::from_str(&raw).context("failed to deserialize runtime config")?
32    };
33    let RuntimeConfig {
34        node: config,
35        operators,
36    } = config;
37    let node_id = config.node_id.clone();
38    #[cfg(feature = "tracing")]
39    {
40        TracingBuilder::new(node_id.as_ref())
41            .with_stdout("warn", false)
42            .build()
43            .wrap_err("failed to set up tracing subscriber")?;
44    }
45
46    let dataflow_descriptor = serde_yaml::from_value(config.dataflow_descriptor.clone())
47        .context("failed to parse dataflow descriptor")?;
48
49    let operator_definition = if operators.is_empty() {
50        bail!("no operators");
51    } else if operators.len() > 1 {
52        bail!("multiple operators are not supported");
53    } else {
54        let mut ops = operators;
55        ops.remove(0)
56    };
57
58    let (operator_events_tx, events) = mpsc::channel(1);
59    let operator_id = operator_definition.id.clone();
60    let operator_events = ReceiverStream::new(events).map(move |event| RuntimeEvent::Operator {
61        id: operator_id.clone(),
62        event,
63    });
64
65    let tokio_runtime = Builder::new_current_thread()
66        .enable_all()
67        .build()
68        .wrap_err("Could not build a tokio runtime.")?;
69
70    let mut operator_channels = HashMap::new();
71    let queue_sizes = queue_sizes(&operator_definition.config);
72    let (operator_channel, incoming_events) =
73        operator::channel::channel(tokio_runtime.handle(), queue_sizes);
74    operator_channels.insert(operator_definition.id.clone(), operator_channel);
75
76    tracing::info!("spawning main task");
77    let operator_config = [(
78        operator_definition.id.clone(),
79        operator_definition.config.clone(),
80    )]
81    .into_iter()
82    .collect();
83    let (init_done_tx, init_done) = oneshot::channel();
84    let main_task = std::thread::spawn(move || -> Result<()> {
85        tokio_runtime.block_on(run(
86            operator_config,
87            config,
88            operator_events,
89            operator_channels,
90            init_done,
91        ))
92    });
93
94    let operator_id = operator_definition.id.clone();
95    run_operator(
96        &node_id,
97        operator_definition,
98        incoming_events,
99        operator_events_tx,
100        init_done_tx,
101        &dataflow_descriptor,
102    )
103    .wrap_err_with(|| format!("failed to run operator {operator_id}"))?;
104
105    match main_task.join() {
106        Ok(result) => result.wrap_err("main task failed")?,
107        Err(panic) => std::panic::resume_unwind(panic),
108    }
109
110    Ok(())
111}
112
113fn queue_sizes(config: &OperatorConfig) -> std::collections::BTreeMap<DataId, usize> {
114    let mut sizes = BTreeMap::new();
115    for (input_id, input) in &config.inputs {
116        let queue_size = input.queue_size.unwrap_or(10);
117        sizes.insert(input_id.clone(), queue_size);
118    }
119    sizes
120}
121
122#[tracing::instrument(skip(operator_events, operator_channels), level = "trace")]
123async fn run(
124    operators: HashMap<OperatorId, OperatorConfig>,
125    config: NodeConfig,
126    operator_events: impl Stream<Item = RuntimeEvent> + Unpin,
127    mut operator_channels: HashMap<OperatorId, flume::Sender<Event>>,
128    init_done: oneshot::Receiver<Result<()>>,
129) -> eyre::Result<()> {
130    #[cfg(feature = "metrics")]
131    let _meter_provider = run_metrics_monitor(config.node_id.to_string());
132    init_done
133        .await
134        .wrap_err("the `init_done` channel was closed unexpectedly")?
135        .wrap_err("failed to init an operator")?;
136    tracing::info!("All operators are ready, starting runtime");
137
138    let (mut node, mut daemon_events) = DoraNode::init(config)?;
139    let (daemon_events_tx, daemon_event_stream) = flume::bounded(1);
140    tokio::task::spawn_blocking(move || {
141        while let Some(event) = daemon_events.recv() {
142            if daemon_events_tx.send(RuntimeEvent::Event(event)).is_err() {
143                break;
144            }
145        }
146    });
147    let mut events = (operator_events, daemon_event_stream.into_stream()).merge();
148
149    let mut open_operator_inputs: HashMap<_, BTreeSet<_>> = operators
150        .iter()
151        .map(|(id, config)| (id, config.inputs.keys().collect()))
152        .collect();
153
154    while let Some(event) = events.next().await {
155        match event {
156            RuntimeEvent::Operator {
157                id: operator_id,
158                event,
159            } => {
160                match event {
161                    OperatorEvent::Error(err) => {
162                        bail!(err.wrap_err(format!(
163                            "operator {}/{operator_id} raised an error",
164                            node.id()
165                        )))
166                    }
167                    OperatorEvent::Panic(payload) => {
168                        bail!("operator {operator_id} panicked: {payload:?}");
169                    }
170                    OperatorEvent::Finished { reason } => {
171                        if let StopReason::ExplicitStopAll = reason {
172                            // let hlc = dora_core::message::uhlc::HLC::default();
173                            // let metadata = dora_core::message::Metadata::new(hlc.new_timestamp());
174                            // let data = metadata
175                            // .serialize()
176                            // .wrap_err("failed to serialize stop message")?;
177                            todo!("instruct dora-daemon/dora-coordinator to stop other nodes");
178                            // manual_stop_publisher
179                            //     .publish(&data)
180                            //     .map_err(|err| eyre::eyre!(err))
181                            //     .wrap_err("failed to send stop message")?;
182                            // break;
183                        }
184
185                        let Some(config) = operators.get(&operator_id) else {
186                            tracing::warn!(
187                                "received Finished event for unknown operator `{operator_id}`"
188                            );
189                            continue;
190                        };
191                        let outputs = config
192                            .outputs
193                            .iter()
194                            .map(|output_id| operator_output_id(&operator_id, output_id))
195                            .collect();
196                        let result;
197                        (node, result) = tokio::task::spawn_blocking(move || {
198                            let result = node.close_outputs(outputs);
199                            (node, result)
200                        })
201                        .await
202                        .wrap_err("failed to wait for close_outputs task")?;
203                        result.wrap_err("failed to close outputs of finished operator")?;
204
205                        operator_channels.remove(&operator_id);
206
207                        if operator_channels.is_empty() {
208                            break;
209                        }
210                    }
211                    OperatorEvent::AllocateOutputSample { len, sample: tx } => {
212                        let sample = node.allocate_data_sample(len);
213                        if tx.send(sample).is_err() {
214                            tracing::warn!(
215                                "output sample requested, but operator {operator_id} exited already"
216                            );
217                        }
218                    }
219                    OperatorEvent::Output {
220                        output_id,
221                        type_info,
222                        parameters,
223                        data,
224                    } => {
225                        let output_id = operator_output_id(&operator_id, &output_id);
226                        let result;
227                        (node, result) = tokio::task::spawn_blocking(move || {
228                            let result =
229                                node.send_output_sample(output_id, type_info, parameters, data);
230                            (node, result)
231                        })
232                        .await
233                        .wrap_err("failed to wait for send_output task")?;
234                        result.wrap_err("failed to send node output")?;
235                    }
236                }
237            }
238            RuntimeEvent::Event(Event::Stop(cause)) => {
239                // forward stop event to all operators and close the event channels
240                for (_, channel) in operator_channels.drain() {
241                    let _ = channel.send_async(Event::Stop(cause.clone())).await;
242                }
243            }
244            RuntimeEvent::Event(Event::Reload {
245                operator_id: Some(operator_id),
246            }) => {
247                let _ = operator_channels
248                    .get(&operator_id)
249                    .unwrap()
250                    .send_async(Event::Reload {
251                        operator_id: Some(operator_id),
252                    })
253                    .await;
254            }
255            RuntimeEvent::Event(Event::Reload { operator_id: None }) => {
256                tracing::warn!("Reloading runtime nodes is not supported");
257            }
258            RuntimeEvent::Event(Event::Input { id, metadata, data }) => {
259                let Some((operator_id, input_id)) = id.as_str().split_once('/') else {
260                    tracing::warn!("received non-operator input {id}");
261                    continue;
262                };
263                let operator_id = OperatorId::from(operator_id.to_owned());
264                let input_id = DataId::from(input_id.to_owned());
265                let Some(operator_channel) = operator_channels.get(&operator_id) else {
266                    tracing::warn!("received input {id} for unknown operator");
267                    continue;
268                };
269
270                if let Err(err) = operator_channel
271                    .send_async(Event::Input {
272                        id: input_id.clone(),
273                        metadata,
274                        data,
275                    })
276                    .await
277                    .wrap_err_with(|| {
278                        format!("failed to send input `{input_id}` to operator `{operator_id}`")
279                    })
280                {
281                    tracing::warn!("{err}");
282                }
283            }
284            RuntimeEvent::Event(Event::InputClosed { id }) => {
285                let Some((operator_id, input_id)) = id.as_str().split_once('/') else {
286                    tracing::warn!("received InputClosed event for non-operator input {id}");
287                    continue;
288                };
289                let operator_id = OperatorId::from(operator_id.to_owned());
290                let input_id = DataId::from(input_id.to_owned());
291
292                let Some(operator_channel) = operator_channels.get(&operator_id) else {
293                    tracing::warn!("received input {id} for unknown operator");
294                    continue;
295                };
296                if let Err(err) = operator_channel
297                    .send_async(Event::InputClosed {
298                        id: input_id.clone(),
299                    })
300                    .await
301                    .wrap_err_with(|| {
302                        format!(
303                            "failed to send InputClosed({input_id}) to operator `{operator_id}`"
304                        )
305                    })
306                {
307                    tracing::warn!("{err}");
308                }
309
310                if let Some(open_inputs) = open_operator_inputs.get_mut(&operator_id) {
311                    open_inputs.remove(&input_id);
312                    if open_inputs.is_empty() {
313                        // all inputs of the node were closed -> close its event channel
314                        tracing::trace!(
315                            "all inputs of operator {}/{operator_id} were closed -> closing event channel",
316                            node.id()
317                        );
318                        open_operator_inputs.remove(&operator_id);
319                        operator_channels.remove(&operator_id);
320                    }
321                }
322            }
323            RuntimeEvent::Event(Event::Error(err)) => eyre::bail!("received error event: {err}"),
324            RuntimeEvent::Event(other) => {
325                tracing::warn!("received unknown event `{other:?}`");
326            }
327        }
328    }
329
330    mem::drop(events);
331
332    Ok(())
333}
334
335fn operator_output_id(operator_id: &OperatorId, output_id: &DataId) -> DataId {
336    DataId::from(format!("{operator_id}/{output_id}"))
337}
338
339#[derive(Debug)]
340enum RuntimeEvent {
341    Operator {
342        id: OperatorId,
343        event: OperatorEvent,
344    },
345    Event(Event),
346}