Skip to main content

dora_runtime/
lib.rs

1#![warn(unsafe_op_in_unsafe_fn)]
2
3use dora_core::{
4    config::{DataId, OperatorId},
5    descriptor::OperatorConfig,
6};
7use dora_message::daemon_to_node::{NodeConfig, RuntimeConfig};
8use dora_metrics::run_metrics_monitor;
9use dora_node_api::{DoraNode, Event};
10use dora_tracing::TracingBuilder;
11use eyre::{Context, Result, bail};
12use futures::{Stream, StreamExt};
13use futures_concurrency::stream::Merge;
14use operator::{OperatorEvent, StopReason, run_operator};
15
16use std::{
17    collections::{BTreeMap, BTreeSet, HashMap},
18    mem,
19};
20use tokio::{runtime::Builder, sync::oneshot};
21mod operator;
22
23pub fn main() -> eyre::Result<()> {
24    let config: RuntimeConfig = {
25        let raw = std::env::var("DORA_RUNTIME_CONFIG")
26            .wrap_err("env variable DORA_RUNTIME_CONFIG must be set")?;
27        serde_yaml::from_str(&raw).context("failed to deserialize runtime config")?
28    };
29    let RuntimeConfig {
30        node: config,
31        operators,
32    } = config;
33    let node_id = config.node_id.clone();
34    #[cfg(feature = "tracing")]
35    {
36        TracingBuilder::new(node_id.as_ref())
37            .with_stdout("warn", false)
38            .build()
39            .wrap_err("failed to set up tracing subscriber")?;
40    }
41
42    let dataflow_descriptor = serde_yaml::from_value(config.dataflow_descriptor.clone())
43        .context("failed to parse dataflow descriptor")?;
44
45    let operator_definition = if operators.is_empty() {
46        bail!("no operators");
47    } else if operators.len() > 1 {
48        bail!("multiple operators are not supported");
49    } else {
50        let mut ops = operators;
51        ops.remove(0)
52    };
53
54    let (operator_events_tx, events) = flume::bounded(1);
55    let operator_id = operator_definition.id.clone();
56    let operator_events = events
57        .into_stream()
58        .map(move |event| RuntimeEvent::Operator {
59            id: operator_id.clone(),
60            event,
61        });
62
63    let tokio_runtime = Builder::new_current_thread()
64        .enable_all()
65        .build()
66        .wrap_err("Could not build a tokio runtime.")?;
67
68    let mut operator_channels = HashMap::new();
69    let queue_sizes = queue_sizes(&operator_definition.config);
70    let (operator_channel, incoming_events) =
71        operator::channel::channel(tokio_runtime.handle(), queue_sizes);
72    operator_channels.insert(operator_definition.id.clone(), operator_channel);
73
74    tracing::info!("spawning main task");
75    let operator_config = [(
76        operator_definition.id.clone(),
77        operator_definition.config.clone(),
78    )]
79    .into_iter()
80    .collect();
81    let (init_done_tx, init_done) = oneshot::channel();
82    let main_task = std::thread::spawn(move || -> Result<()> {
83        tokio_runtime.block_on(run(
84            operator_config,
85            config,
86            operator_events,
87            operator_channels,
88            init_done,
89        ))
90    });
91
92    let operator_id = operator_definition.id.clone();
93    run_operator(
94        &node_id,
95        operator_definition,
96        incoming_events,
97        operator_events_tx,
98        init_done_tx,
99        &dataflow_descriptor,
100    )
101    .wrap_err_with(|| format!("failed to run operator {operator_id}"))?;
102
103    match main_task.join() {
104        Ok(result) => result.wrap_err("main task failed")?,
105        Err(panic) => std::panic::resume_unwind(panic),
106    }
107
108    Ok(())
109}
110
111fn queue_sizes(config: &OperatorConfig) -> std::collections::BTreeMap<DataId, usize> {
112    let mut sizes = BTreeMap::new();
113    for (input_id, input) in &config.inputs {
114        let queue_size = input.queue_size.unwrap_or(10);
115        sizes.insert(input_id.clone(), queue_size);
116    }
117    sizes
118}
119
120#[tracing::instrument(skip(operator_events, operator_channels), level = "trace")]
121async fn run(
122    operators: HashMap<OperatorId, OperatorConfig>,
123    config: NodeConfig,
124    operator_events: impl Stream<Item = RuntimeEvent> + Unpin,
125    mut operator_channels: HashMap<OperatorId, flume::Sender<Event>>,
126    init_done: oneshot::Receiver<Result<()>>,
127) -> eyre::Result<()> {
128    #[cfg(feature = "metrics")]
129    let _meter_provider = run_metrics_monitor(config.node_id.to_string());
130    init_done
131        .await
132        .wrap_err("the `init_done` channel was closed unexpectedly")?
133        .wrap_err("failed to init an operator")?;
134    tracing::info!("All operators are ready, starting runtime");
135
136    let (mut node, mut daemon_events) = DoraNode::init(config)?;
137    let (daemon_events_tx, daemon_event_stream) = flume::bounded(1);
138    tokio::task::spawn_blocking(move || {
139        while let Some(event) = daemon_events.recv() {
140            if daemon_events_tx.send(RuntimeEvent::Event(event)).is_err() {
141                break;
142            }
143        }
144    });
145    let mut events = (operator_events, daemon_event_stream.into_stream()).merge();
146
147    let mut open_operator_inputs: HashMap<_, BTreeSet<_>> = operators
148        .iter()
149        .map(|(id, config)| (id, config.inputs.keys().collect()))
150        .collect();
151
152    while let Some(event) = events.next().await {
153        match event {
154            RuntimeEvent::Operator {
155                id: operator_id,
156                event,
157            } => {
158                match event {
159                    OperatorEvent::Error(err) => {
160                        bail!(err.wrap_err(format!(
161                            "operator {}/{operator_id} raised an error",
162                            node.id()
163                        )))
164                    }
165                    OperatorEvent::Panic(payload) => {
166                        bail!("operator {operator_id} panicked: {payload:?}");
167                    }
168                    OperatorEvent::Finished { reason } => {
169                        if let StopReason::ExplicitStopAll = reason {
170                            // let hlc = dora_core::message::uhlc::HLC::default();
171                            // let metadata = dora_core::message::Metadata::new(hlc.new_timestamp());
172                            // let data = metadata
173                            // .serialize()
174                            // .wrap_err("failed to serialize stop message")?;
175                            todo!("instruct dora-daemon/dora-coordinator to stop other nodes");
176                            // manual_stop_publisher
177                            //     .publish(&data)
178                            //     .map_err(|err| eyre::eyre!(err))
179                            //     .wrap_err("failed to send stop message")?;
180                            // break;
181                        }
182
183                        let Some(config) = operators.get(&operator_id) else {
184                            tracing::warn!(
185                                "received Finished event for unknown operator `{operator_id}`"
186                            );
187                            continue;
188                        };
189                        let outputs = config
190                            .outputs
191                            .iter()
192                            .map(|output_id| operator_output_id(&operator_id, output_id))
193                            .collect();
194                        let result;
195                        (node, result) = tokio::task::spawn_blocking(move || {
196                            let result = node.close_outputs(outputs);
197                            (node, result)
198                        })
199                        .await
200                        .wrap_err("failed to wait for close_outputs task")?;
201                        result.wrap_err("failed to close outputs of finished operator")?;
202
203                        operator_channels.remove(&operator_id);
204
205                        if operator_channels.is_empty() {
206                            break;
207                        }
208                    }
209                    OperatorEvent::AllocateOutputSample { len, sample: tx } => {
210                        let sample = node.allocate_data_sample(len);
211                        if tx.send(sample).is_err() {
212                            tracing::warn!(
213                                "output sample requested, but operator {operator_id} exited already"
214                            );
215                        }
216                    }
217                    OperatorEvent::Output {
218                        output_id,
219                        type_info,
220                        parameters,
221                        data,
222                    } => {
223                        let output_id = operator_output_id(&operator_id, &output_id);
224                        let result;
225                        (node, result) = tokio::task::spawn_blocking(move || {
226                            let result =
227                                node.send_output_sample(output_id, type_info, parameters, data);
228                            (node, result)
229                        })
230                        .await
231                        .wrap_err("failed to wait for send_output task")?;
232                        result.wrap_err("failed to send node output")?;
233                    }
234                }
235            }
236            RuntimeEvent::Event(Event::Stop(cause)) => {
237                // forward stop event to all operators and close the event channels
238                for (_, channel) in operator_channels.drain() {
239                    let _ = channel.send_async(Event::Stop(cause.clone())).await;
240                }
241            }
242            RuntimeEvent::Event(Event::Reload {
243                operator_id: Some(operator_id),
244            }) => {
245                let _ = operator_channels
246                    .get(&operator_id)
247                    .unwrap()
248                    .send_async(Event::Reload {
249                        operator_id: Some(operator_id),
250                    })
251                    .await;
252            }
253            RuntimeEvent::Event(Event::Reload { operator_id: None }) => {
254                tracing::warn!("Reloading runtime nodes is not supported");
255            }
256            RuntimeEvent::Event(Event::Input { id, metadata, data }) => {
257                let Some((operator_id, input_id)) = id.as_str().split_once('/') else {
258                    tracing::warn!("received non-operator input {id}");
259                    continue;
260                };
261                let operator_id = OperatorId::from(operator_id.to_owned());
262                let input_id = DataId::from(input_id.to_owned());
263                let Some(operator_channel) = operator_channels.get(&operator_id) else {
264                    tracing::warn!("received input {id} for unknown operator");
265                    continue;
266                };
267
268                if let Err(err) = operator_channel
269                    .send_async(Event::Input {
270                        id: input_id.clone(),
271                        metadata,
272                        data,
273                    })
274                    .await
275                    .wrap_err_with(|| {
276                        format!("failed to send input `{input_id}` to operator `{operator_id}`")
277                    })
278                {
279                    tracing::warn!("{err}");
280                }
281            }
282            RuntimeEvent::Event(Event::InputClosed { id }) => {
283                let Some((operator_id, input_id)) = id.as_str().split_once('/') else {
284                    tracing::warn!("received InputClosed event for non-operator input {id}");
285                    continue;
286                };
287                let operator_id = OperatorId::from(operator_id.to_owned());
288                let input_id = DataId::from(input_id.to_owned());
289
290                let Some(operator_channel) = operator_channels.get(&operator_id) else {
291                    tracing::warn!("received input {id} for unknown operator");
292                    continue;
293                };
294                if let Err(err) = operator_channel
295                    .send_async(Event::InputClosed {
296                        id: input_id.clone(),
297                    })
298                    .await
299                    .wrap_err_with(|| {
300                        format!(
301                            "failed to send InputClosed({input_id}) to operator `{operator_id}`"
302                        )
303                    })
304                {
305                    tracing::warn!("{err}");
306                }
307
308                if let Some(open_inputs) = open_operator_inputs.get_mut(&operator_id) {
309                    open_inputs.remove(&input_id);
310                    if open_inputs.is_empty() {
311                        // all inputs of the node were closed -> close its event channel
312                        tracing::trace!(
313                            "all inputs of operator {}/{operator_id} were closed -> closing event channel",
314                            node.id()
315                        );
316                        open_operator_inputs.remove(&operator_id);
317                        operator_channels.remove(&operator_id);
318                    }
319                }
320            }
321            RuntimeEvent::Event(Event::Error(err)) => eyre::bail!("received error event: {err}"),
322            RuntimeEvent::Event(other) => {
323                tracing::warn!("received unknown event `{other:?}`");
324            }
325        }
326    }
327
328    mem::drop(events);
329
330    Ok(())
331}
332
333fn operator_output_id(operator_id: &OperatorId, output_id: &DataId) -> DataId {
334    DataId::from(format!("{operator_id}/{output_id}"))
335}
336
337#[derive(Debug)]
338enum RuntimeEvent {
339    Operator {
340        id: OperatorId,
341        event: OperatorEvent,
342    },
343    Event(Event),
344}