dora_runtime/
lib.rs

1#![warn(unsafe_op_in_unsafe_fn)]
2
3use dora_core::{
4    config::{DataId, OperatorId},
5    descriptor::OperatorConfig,
6};
7use dora_message::daemon_to_node::{NodeConfig, RuntimeConfig};
8use dora_metrics::init_meter_provider;
9use dora_node_api::{DoraNode, Event};
10use eyre::{bail, Context, Result};
11use futures::{Stream, StreamExt};
12use futures_concurrency::stream::Merge;
13use operator::{run_operator, OperatorEvent, StopReason};
14
15#[cfg(feature = "tracing")]
16use dora_tracing::set_up_tracing;
17use std::{
18    collections::{BTreeMap, BTreeSet, HashMap},
19    mem,
20};
21use tokio::{
22    runtime::Builder,
23    sync::{mpsc, oneshot},
24};
25use tokio_stream::wrappers::ReceiverStream;
26mod operator;
27
28pub fn main() -> eyre::Result<()> {
29    let config: RuntimeConfig = {
30        let raw = std::env::var("DORA_RUNTIME_CONFIG")
31            .wrap_err("env variable DORA_RUNTIME_CONFIG must be set")?;
32        serde_yaml::from_str(&raw).context("failed to deserialize runtime config")?
33    };
34    let RuntimeConfig {
35        node: config,
36        operators,
37    } = config;
38    let node_id = config.node_id.clone();
39    #[cfg(feature = "tracing")]
40    set_up_tracing(node_id.as_ref()).context("failed to set up tracing subscriber")?;
41
42    let dataflow_descriptor = config.dataflow_descriptor.clone();
43
44    let operator_definition = if operators.is_empty() {
45        bail!("no operators");
46    } else if operators.len() > 1 {
47        bail!("multiple operators are not supported");
48    } else {
49        let mut ops = operators;
50        ops.remove(0)
51    };
52
53    let (operator_events_tx, events) = mpsc::channel(1);
54    let operator_id = operator_definition.id.clone();
55    let operator_events = ReceiverStream::new(events).map(move |event| RuntimeEvent::Operator {
56        id: operator_id.clone(),
57        event,
58    });
59
60    let tokio_runtime = Builder::new_current_thread()
61        .enable_all()
62        .build()
63        .wrap_err("Could not build a tokio runtime.")?;
64
65    let mut operator_channels = HashMap::new();
66    let queue_sizes = queue_sizes(&operator_definition.config);
67    let (operator_channel, incoming_events) =
68        operator::channel::channel(tokio_runtime.handle(), queue_sizes);
69    operator_channels.insert(operator_definition.id.clone(), operator_channel);
70
71    tracing::info!("spawning main task");
72    let operator_config = [(
73        operator_definition.id.clone(),
74        operator_definition.config.clone(),
75    )]
76    .into_iter()
77    .collect();
78    let (init_done_tx, init_done) = oneshot::channel();
79    let main_task = std::thread::spawn(move || -> Result<()> {
80        tokio_runtime.block_on(run(
81            operator_config,
82            config,
83            operator_events,
84            operator_channels,
85            init_done,
86        ))
87    });
88
89    let operator_id = operator_definition.id.clone();
90    run_operator(
91        &node_id,
92        operator_definition,
93        incoming_events,
94        operator_events_tx,
95        init_done_tx,
96        &dataflow_descriptor,
97    )
98    .wrap_err_with(|| format!("failed to run operator {operator_id}"))?;
99
100    match main_task.join() {
101        Ok(result) => result.wrap_err("main task failed")?,
102        Err(panic) => std::panic::resume_unwind(panic),
103    }
104
105    Ok(())
106}
107
108fn queue_sizes(config: &OperatorConfig) -> std::collections::BTreeMap<DataId, usize> {
109    let mut sizes = BTreeMap::new();
110    for (input_id, input) in &config.inputs {
111        let queue_size = input.queue_size.unwrap_or(10);
112        sizes.insert(input_id.clone(), queue_size);
113    }
114    sizes
115}
116
117#[tracing::instrument(skip(operator_events, operator_channels), level = "trace")]
118async fn run(
119    operators: HashMap<OperatorId, OperatorConfig>,
120    config: NodeConfig,
121    operator_events: impl Stream<Item = RuntimeEvent> + Unpin,
122    mut operator_channels: HashMap<OperatorId, flume::Sender<Event>>,
123    init_done: oneshot::Receiver<Result<()>>,
124) -> eyre::Result<()> {
125    #[cfg(feature = "metrics")]
126    let _meter_provider = init_meter_provider(config.node_id.to_string());
127    init_done
128        .await
129        .wrap_err("the `init_done` channel was closed unexpectedly")?
130        .wrap_err("failed to init an operator")?;
131    tracing::info!("All operators are ready, starting runtime");
132
133    let (mut node, mut daemon_events) = DoraNode::init(config)?;
134    let (daemon_events_tx, daemon_event_stream) = flume::bounded(1);
135    tokio::task::spawn_blocking(move || {
136        while let Some(event) = daemon_events.recv() {
137            if daemon_events_tx.send(RuntimeEvent::Event(event)).is_err() {
138                break;
139            }
140        }
141    });
142    let mut events = (operator_events, daemon_event_stream.into_stream()).merge();
143
144    let mut open_operator_inputs: HashMap<_, BTreeSet<_>> = operators
145        .iter()
146        .map(|(id, config)| (id, config.inputs.keys().collect()))
147        .collect();
148
149    while let Some(event) = events.next().await {
150        match event {
151            RuntimeEvent::Operator {
152                id: operator_id,
153                event,
154            } => {
155                match event {
156                    OperatorEvent::Error(err) => {
157                        bail!(err.wrap_err(format!(
158                            "operator {}/{operator_id} raised an error",
159                            node.id()
160                        )))
161                    }
162                    OperatorEvent::Panic(payload) => {
163                        bail!("operator {operator_id} panicked: {payload:?}");
164                    }
165                    OperatorEvent::Finished { reason } => {
166                        if let StopReason::ExplicitStopAll = reason {
167                            // let hlc = dora_core::message::uhlc::HLC::default();
168                            // let metadata = dora_core::message::Metadata::new(hlc.new_timestamp());
169                            // let data = metadata
170                            // .serialize()
171                            // .wrap_err("failed to serialize stop message")?;
172                            todo!("instruct dora-daemon/dora-coordinator to stop other nodes");
173                            // manual_stop_publisher
174                            //     .publish(&data)
175                            //     .map_err(|err| eyre::eyre!(err))
176                            //     .wrap_err("failed to send stop message")?;
177                            // break;
178                        }
179
180                        let Some(config) = operators.get(&operator_id) else {
181                            tracing::warn!(
182                                "received Finished event for unknown operator `{operator_id}`"
183                            );
184                            continue;
185                        };
186                        let outputs = config
187                            .outputs
188                            .iter()
189                            .map(|output_id| operator_output_id(&operator_id, output_id))
190                            .collect();
191                        let result;
192                        (node, result) = tokio::task::spawn_blocking(move || {
193                            let result = node.close_outputs(outputs);
194                            (node, result)
195                        })
196                        .await
197                        .wrap_err("failed to wait for close_outputs task")?;
198                        result.wrap_err("failed to close outputs of finished operator")?;
199
200                        operator_channels.remove(&operator_id);
201
202                        if operator_channels.is_empty() {
203                            break;
204                        }
205                    }
206                    OperatorEvent::AllocateOutputSample { len, sample: tx } => {
207                        let sample = node.allocate_data_sample(len);
208                        if tx.send(sample).is_err() {
209                            tracing::warn!("output sample requested, but operator {operator_id} exited already");
210                        }
211                    }
212                    OperatorEvent::Output {
213                        output_id,
214                        type_info,
215                        parameters,
216                        data,
217                    } => {
218                        let output_id = operator_output_id(&operator_id, &output_id);
219                        let result;
220                        (node, result) = tokio::task::spawn_blocking(move || {
221                            let result =
222                                node.send_output_sample(output_id, type_info, parameters, data);
223                            (node, result)
224                        })
225                        .await
226                        .wrap_err("failed to wait for send_output task")?;
227                        result.wrap_err("failed to send node output")?;
228                    }
229                }
230            }
231            RuntimeEvent::Event(Event::Stop) => {
232                // forward stop event to all operators and close the event channels
233                for (_, channel) in operator_channels.drain() {
234                    let _ = channel.send_async(Event::Stop).await;
235                }
236            }
237            RuntimeEvent::Event(Event::Reload {
238                operator_id: Some(operator_id),
239            }) => {
240                let _ = operator_channels
241                    .get(&operator_id)
242                    .unwrap()
243                    .send_async(Event::Reload {
244                        operator_id: Some(operator_id),
245                    })
246                    .await;
247            }
248            RuntimeEvent::Event(Event::Reload { operator_id: None }) => {
249                tracing::warn!("Reloading runtime nodes is not supported");
250            }
251            RuntimeEvent::Event(Event::Input { id, metadata, data }) => {
252                let Some((operator_id, input_id)) = id.as_str().split_once('/') else {
253                    tracing::warn!("received non-operator input {id}");
254                    continue;
255                };
256                let operator_id = OperatorId::from(operator_id.to_owned());
257                let input_id = DataId::from(input_id.to_owned());
258                let Some(operator_channel) = operator_channels.get(&operator_id) else {
259                    tracing::warn!("received input {id} for unknown operator");
260                    continue;
261                };
262
263                if let Err(err) = operator_channel
264                    .send_async(Event::Input {
265                        id: input_id.clone(),
266                        metadata,
267                        data,
268                    })
269                    .await
270                    .wrap_err_with(|| {
271                        format!("failed to send input `{input_id}` to operator `{operator_id}`")
272                    })
273                {
274                    tracing::warn!("{err}");
275                }
276            }
277            RuntimeEvent::Event(Event::InputClosed { id }) => {
278                let Some((operator_id, input_id)) = id.as_str().split_once('/') else {
279                    tracing::warn!("received InputClosed event for non-operator input {id}");
280                    continue;
281                };
282                let operator_id = OperatorId::from(operator_id.to_owned());
283                let input_id = DataId::from(input_id.to_owned());
284
285                let Some(operator_channel) = operator_channels.get(&operator_id) else {
286                    tracing::warn!("received input {id} for unknown operator");
287                    continue;
288                };
289                if let Err(err) = operator_channel
290                    .send_async(Event::InputClosed {
291                        id: input_id.clone(),
292                    })
293                    .await
294                    .wrap_err_with(|| {
295                        format!(
296                            "failed to send InputClosed({input_id}) to operator `{operator_id}`"
297                        )
298                    })
299                {
300                    tracing::warn!("{err}");
301                }
302
303                if let Some(open_inputs) = open_operator_inputs.get_mut(&operator_id) {
304                    open_inputs.remove(&input_id);
305                    if open_inputs.is_empty() {
306                        // all inputs of the node were closed -> close its event channel
307                        tracing::trace!("all inputs of operator {}/{operator_id} were closed -> closing event channel", node.id());
308                        open_operator_inputs.remove(&operator_id);
309                        operator_channels.remove(&operator_id);
310                    }
311                }
312            }
313            RuntimeEvent::Event(Event::Error(err)) => eyre::bail!("received error event: {err}"),
314            RuntimeEvent::Event(other) => {
315                tracing::warn!("received unknown event `{other:?}`");
316            }
317        }
318    }
319
320    mem::drop(events);
321
322    Ok(())
323}
324
325fn operator_output_id(operator_id: &OperatorId, output_id: &DataId) -> DataId {
326    DataId::from(format!("{operator_id}/{output_id}"))
327}
328
329#[derive(Debug)]
330enum RuntimeEvent {
331    Operator {
332        id: OperatorId,
333        event: OperatorEvent,
334    },
335    Event(Event),
336}