dora_runtime/
lib.rs

1#![warn(unsafe_op_in_unsafe_fn)]
2
3use dora_core::{
4    config::{DataId, OperatorId},
5    descriptor::OperatorConfig,
6};
7use dora_message::daemon_to_node::{NodeConfig, RuntimeConfig};
8use dora_metrics::run_metrics_monitor;
9use dora_node_api::{DoraNode, Event};
10use dora_tracing::TracingBuilder;
11use eyre::{bail, Context, Result};
12use futures::{Stream, StreamExt};
13use futures_concurrency::stream::Merge;
14use operator::{run_operator, OperatorEvent, StopReason};
15
16use std::{
17    collections::{BTreeMap, BTreeSet, HashMap},
18    mem,
19};
20use tokio::{
21    runtime::Builder,
22    sync::{mpsc, oneshot},
23};
24use tokio_stream::wrappers::ReceiverStream;
25mod operator;
26
27pub fn main() -> eyre::Result<()> {
28    let config: RuntimeConfig = {
29        let raw = std::env::var("DORA_RUNTIME_CONFIG")
30            .wrap_err("env variable DORA_RUNTIME_CONFIG must be set")?;
31        serde_yaml::from_str(&raw).context("failed to deserialize runtime config")?
32    };
33    let RuntimeConfig {
34        node: config,
35        operators,
36    } = config;
37    let node_id = config.node_id.clone();
38    #[cfg(feature = "tracing")]
39    {
40        TracingBuilder::new(node_id.as_ref())
41            .with_stdout("warn")
42            .build()
43            .wrap_err("failed to set up tracing subscriber")?;
44    }
45
46    let dataflow_descriptor = serde_yaml::from_value(config.dataflow_descriptor.clone())
47        .context("failed to parse dataflow descriptor")?;
48
49    let operator_definition = if operators.is_empty() {
50        bail!("no operators");
51    } else if operators.len() > 1 {
52        bail!("multiple operators are not supported");
53    } else {
54        let mut ops = operators;
55        ops.remove(0)
56    };
57
58    let (operator_events_tx, events) = mpsc::channel(1);
59    let operator_id = operator_definition.id.clone();
60    let operator_events = ReceiverStream::new(events).map(move |event| RuntimeEvent::Operator {
61        id: operator_id.clone(),
62        event,
63    });
64
65    let tokio_runtime = Builder::new_current_thread()
66        .enable_all()
67        .build()
68        .wrap_err("Could not build a tokio runtime.")?;
69
70    let mut operator_channels = HashMap::new();
71    let queue_sizes = queue_sizes(&operator_definition.config);
72    let (operator_channel, incoming_events) =
73        operator::channel::channel(tokio_runtime.handle(), queue_sizes);
74    operator_channels.insert(operator_definition.id.clone(), operator_channel);
75
76    tracing::info!("spawning main task");
77    let operator_config = [(
78        operator_definition.id.clone(),
79        operator_definition.config.clone(),
80    )]
81    .into_iter()
82    .collect();
83    let (init_done_tx, init_done) = oneshot::channel();
84    let main_task = std::thread::spawn(move || -> Result<()> {
85        tokio_runtime.block_on(run(
86            operator_config,
87            config,
88            operator_events,
89            operator_channels,
90            init_done,
91        ))
92    });
93
94    let operator_id = operator_definition.id.clone();
95    run_operator(
96        &node_id,
97        operator_definition,
98        incoming_events,
99        operator_events_tx,
100        init_done_tx,
101        &dataflow_descriptor,
102    )
103    .wrap_err_with(|| format!("failed to run operator {operator_id}"))?;
104
105    match main_task.join() {
106        Ok(result) => result.wrap_err("main task failed")?,
107        Err(panic) => std::panic::resume_unwind(panic),
108    }
109
110    Ok(())
111}
112
113fn queue_sizes(config: &OperatorConfig) -> std::collections::BTreeMap<DataId, usize> {
114    let mut sizes = BTreeMap::new();
115    for (input_id, input) in &config.inputs {
116        let queue_size = input.queue_size.unwrap_or(10);
117        sizes.insert(input_id.clone(), queue_size);
118    }
119    sizes
120}
121
122#[tracing::instrument(skip(operator_events, operator_channels), level = "trace")]
123async fn run(
124    operators: HashMap<OperatorId, OperatorConfig>,
125    config: NodeConfig,
126    operator_events: impl Stream<Item = RuntimeEvent> + Unpin,
127    mut operator_channels: HashMap<OperatorId, flume::Sender<Event>>,
128    init_done: oneshot::Receiver<Result<()>>,
129) -> eyre::Result<()> {
130    #[cfg(feature = "metrics")]
131    let _meter_provider = run_metrics_monitor(config.node_id.to_string());
132    init_done
133        .await
134        .wrap_err("the `init_done` channel was closed unexpectedly")?
135        .wrap_err("failed to init an operator")?;
136    tracing::info!("All operators are ready, starting runtime");
137
138    let (mut node, mut daemon_events) = DoraNode::init(config)?;
139    let (daemon_events_tx, daemon_event_stream) = flume::bounded(1);
140    tokio::task::spawn_blocking(move || {
141        while let Some(event) = daemon_events.recv() {
142            if daemon_events_tx.send(RuntimeEvent::Event(event)).is_err() {
143                break;
144            }
145        }
146    });
147    let mut events = (operator_events, daemon_event_stream.into_stream()).merge();
148
149    let mut open_operator_inputs: HashMap<_, BTreeSet<_>> = operators
150        .iter()
151        .map(|(id, config)| (id, config.inputs.keys().collect()))
152        .collect();
153
154    while let Some(event) = events.next().await {
155        match event {
156            RuntimeEvent::Operator {
157                id: operator_id,
158                event,
159            } => {
160                match event {
161                    OperatorEvent::Error(err) => {
162                        bail!(err.wrap_err(format!(
163                            "operator {}/{operator_id} raised an error",
164                            node.id()
165                        )))
166                    }
167                    OperatorEvent::Panic(payload) => {
168                        bail!("operator {operator_id} panicked: {payload:?}");
169                    }
170                    OperatorEvent::Finished { reason } => {
171                        if let StopReason::ExplicitStopAll = reason {
172                            // let hlc = dora_core::message::uhlc::HLC::default();
173                            // let metadata = dora_core::message::Metadata::new(hlc.new_timestamp());
174                            // let data = metadata
175                            // .serialize()
176                            // .wrap_err("failed to serialize stop message")?;
177                            todo!("instruct dora-daemon/dora-coordinator to stop other nodes");
178                            // manual_stop_publisher
179                            //     .publish(&data)
180                            //     .map_err(|err| eyre::eyre!(err))
181                            //     .wrap_err("failed to send stop message")?;
182                            // break;
183                        }
184
185                        let Some(config) = operators.get(&operator_id) else {
186                            tracing::warn!(
187                                "received Finished event for unknown operator `{operator_id}`"
188                            );
189                            continue;
190                        };
191                        let outputs = config
192                            .outputs
193                            .iter()
194                            .map(|output_id| operator_output_id(&operator_id, output_id))
195                            .collect();
196                        let result;
197                        (node, result) = tokio::task::spawn_blocking(move || {
198                            let result = node.close_outputs(outputs);
199                            (node, result)
200                        })
201                        .await
202                        .wrap_err("failed to wait for close_outputs task")?;
203                        result.wrap_err("failed to close outputs of finished operator")?;
204
205                        operator_channels.remove(&operator_id);
206
207                        if operator_channels.is_empty() {
208                            break;
209                        }
210                    }
211                    OperatorEvent::AllocateOutputSample { len, sample: tx } => {
212                        let sample = node.allocate_data_sample(len);
213                        if tx.send(sample).is_err() {
214                            tracing::warn!("output sample requested, but operator {operator_id} exited already");
215                        }
216                    }
217                    OperatorEvent::Output {
218                        output_id,
219                        type_info,
220                        parameters,
221                        data,
222                    } => {
223                        let output_id = operator_output_id(&operator_id, &output_id);
224                        let result;
225                        (node, result) = tokio::task::spawn_blocking(move || {
226                            let result =
227                                node.send_output_sample(output_id, type_info, parameters, data);
228                            (node, result)
229                        })
230                        .await
231                        .wrap_err("failed to wait for send_output task")?;
232                        result.wrap_err("failed to send node output")?;
233                    }
234                }
235            }
236            RuntimeEvent::Event(Event::Stop(cause)) => {
237                // forward stop event to all operators and close the event channels
238                for (_, channel) in operator_channels.drain() {
239                    let _ = channel.send_async(Event::Stop(cause.clone())).await;
240                }
241            }
242            RuntimeEvent::Event(Event::Reload {
243                operator_id: Some(operator_id),
244            }) => {
245                let _ = operator_channels
246                    .get(&operator_id)
247                    .unwrap()
248                    .send_async(Event::Reload {
249                        operator_id: Some(operator_id),
250                    })
251                    .await;
252            }
253            RuntimeEvent::Event(Event::Reload { operator_id: None }) => {
254                tracing::warn!("Reloading runtime nodes is not supported");
255            }
256            RuntimeEvent::Event(Event::Input { id, metadata, data }) => {
257                let Some((operator_id, input_id)) = id.as_str().split_once('/') else {
258                    tracing::warn!("received non-operator input {id}");
259                    continue;
260                };
261                let operator_id = OperatorId::from(operator_id.to_owned());
262                let input_id = DataId::from(input_id.to_owned());
263                let Some(operator_channel) = operator_channels.get(&operator_id) else {
264                    tracing::warn!("received input {id} for unknown operator");
265                    continue;
266                };
267
268                if let Err(err) = operator_channel
269                    .send_async(Event::Input {
270                        id: input_id.clone(),
271                        metadata,
272                        data,
273                    })
274                    .await
275                    .wrap_err_with(|| {
276                        format!("failed to send input `{input_id}` to operator `{operator_id}`")
277                    })
278                {
279                    tracing::warn!("{err}");
280                }
281            }
282            RuntimeEvent::Event(Event::InputClosed { id }) => {
283                let Some((operator_id, input_id)) = id.as_str().split_once('/') else {
284                    tracing::warn!("received InputClosed event for non-operator input {id}");
285                    continue;
286                };
287                let operator_id = OperatorId::from(operator_id.to_owned());
288                let input_id = DataId::from(input_id.to_owned());
289
290                let Some(operator_channel) = operator_channels.get(&operator_id) else {
291                    tracing::warn!("received input {id} for unknown operator");
292                    continue;
293                };
294                if let Err(err) = operator_channel
295                    .send_async(Event::InputClosed {
296                        id: input_id.clone(),
297                    })
298                    .await
299                    .wrap_err_with(|| {
300                        format!(
301                            "failed to send InputClosed({input_id}) to operator `{operator_id}`"
302                        )
303                    })
304                {
305                    tracing::warn!("{err}");
306                }
307
308                if let Some(open_inputs) = open_operator_inputs.get_mut(&operator_id) {
309                    open_inputs.remove(&input_id);
310                    if open_inputs.is_empty() {
311                        // all inputs of the node were closed -> close its event channel
312                        tracing::trace!("all inputs of operator {}/{operator_id} were closed -> closing event channel", node.id());
313                        open_operator_inputs.remove(&operator_id);
314                        operator_channels.remove(&operator_id);
315                    }
316                }
317            }
318            RuntimeEvent::Event(Event::Error(err)) => eyre::bail!("received error event: {err}"),
319            RuntimeEvent::Event(other) => {
320                tracing::warn!("received unknown event `{other:?}`");
321            }
322        }
323    }
324
325    mem::drop(events);
326
327    Ok(())
328}
329
330fn operator_output_id(operator_id: &OperatorId, output_id: &DataId) -> DataId {
331    DataId::from(format!("{operator_id}/{output_id}"))
332}
333
334#[derive(Debug)]
335enum RuntimeEvent {
336    Operator {
337        id: OperatorId,
338        event: OperatorEvent,
339    },
340    Event(Event),
341}