#![warn(unsafe_op_in_unsafe_fn)]

use dora_core::{
    config::{DataId, OperatorId},
    descriptor::OperatorConfig,
};
use dora_message::daemon_to_node::{NodeConfig, RuntimeConfig};
use dora_metrics::init_meter_provider;
use dora_node_api::{DoraNode, Event};
use eyre::{bail, Context, Result};
use futures::{Stream, StreamExt};
use futures_concurrency::stream::Merge;
use operator::{run_operator, OperatorEvent, StopReason};

#[cfg(feature = "tracing")]
use dora_tracing::set_up_tracing;
use std::{
    collections::{BTreeMap, BTreeSet, HashMap},
    mem,
};
use tokio::{
    runtime::Builder,
    sync::{mpsc, oneshot},
};
use tokio_stream::wrappers::ReceiverStream;
mod operator;

pub fn main() -> eyre::Result<()> {
    let config: RuntimeConfig = {
        let raw = std::env::var("DORA_RUNTIME_CONFIG")
            .wrap_err("env variable DORA_RUNTIME_CONFIG must be set")?;
        serde_yaml::from_str(&raw).context("failed to deserialize runtime config")?
    };
    let RuntimeConfig {
        node: config,
        operators,
    } = config;
    let node_id = config.node_id.clone();
    #[cfg(feature = "tracing")]
    set_up_tracing(node_id.as_ref()).context("failed to set up tracing subscriber")?;

    let dataflow_descriptor = config.dataflow_descriptor.clone();

    // A runtime node currently hosts exactly one operator.
    let operator_definition = if operators.is_empty() {
        bail!("no operators");
    } else if operators.len() > 1 {
        bail!("multiple operators are not supported");
    } else {
        let mut ops = operators;
        ops.remove(0)
    };

    let (operator_events_tx, events) = mpsc::channel(1);
    let operator_id = operator_definition.id.clone();
    let operator_events = ReceiverStream::new(events).map(move |event| RuntimeEvent::Operator {
        id: operator_id.clone(),
        event,
    });

    let tokio_runtime = Builder::new_current_thread()
        .enable_all()
        .build()
        .wrap_err("Could not build a tokio runtime.")?;

    let mut operator_channels = HashMap::new();
    let queue_sizes = queue_sizes(&operator_definition.config);
    let (operator_channel, incoming_events) =
        operator::channel::channel(tokio_runtime.handle(), queue_sizes);
    operator_channels.insert(operator_definition.id.clone(), operator_channel);

    tracing::info!("spawning main task");
    let operator_config = [(
        operator_definition.id.clone(),
        operator_definition.config.clone(),
    )]
    .into_iter()
    .collect();
    let (init_done_tx, init_done) = oneshot::channel();
    // Run the async event loop on a separate thread; the operator itself runs on the
    // current thread below.
    let main_task = std::thread::spawn(move || -> Result<()> {
        tokio_runtime.block_on(run(
            operator_config,
            config,
            operator_events,
            operator_channels,
            init_done,
        ))
    });

    let operator_id = operator_definition.id.clone();
    run_operator(
        &node_id,
        operator_definition,
        incoming_events,
        operator_events_tx,
        init_done_tx,
        &dataflow_descriptor,
    )
    .wrap_err_with(|| format!("failed to run operator {operator_id}"))?;

    match main_task.join() {
        Ok(result) => result.wrap_err("main task failed")?,
        Err(panic) => std::panic::resume_unwind(panic),
    }

    Ok(())
}

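/// Returns the configured queue size for each input of the given operator, defaulting to 10.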
fn queue_sizes(config: &OperatorConfig) -> BTreeMap<DataId, usize> {
    let mut sizes = BTreeMap::new();
    for (input_id, input) in &config.inputs {
        let queue_size = input.queue_size.unwrap_or(10);
        sizes.insert(input_id.clone(), queue_size);
    }
    sizes
}

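/// Runs the runtime's async event loop: waits until the operator is initialized, then forwards
/// daemon events and inputs to the operator and publishes operator outputs through the node.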
#[tracing::instrument(skip(operator_events, operator_channels), level = "trace")]
async fn run(
    operators: HashMap<OperatorId, OperatorConfig>,
    config: NodeConfig,
    operator_events: impl Stream<Item = RuntimeEvent> + Unpin,
    mut operator_channels: HashMap<OperatorId, flume::Sender<Event>>,
    init_done: oneshot::Receiver<Result<()>>,
) -> eyre::Result<()> {
    #[cfg(feature = "metrics")]
    let _meter_provider = init_meter_provider(config.node_id.to_string());
    init_done
        .await
        .wrap_err("the `init_done` channel was closed unexpectedly")?
        .wrap_err("failed to init an operator")?;
    tracing::info!("All operators are ready, starting runtime");

    let (mut node, mut daemon_events) = DoraNode::init(config)?;
    let (daemon_events_tx, daemon_event_stream) = flume::bounded(1);
    // Forward events from the blocking daemon channel into an async stream.
    tokio::task::spawn_blocking(move || {
        while let Some(event) = daemon_events.recv() {
            if daemon_events_tx.send(RuntimeEvent::Event(event)).is_err() {
                break;
            }
        }
    });
    let mut events = (operator_events, daemon_event_stream.into_stream()).merge();

    // Track which inputs are still open for each operator.
    let mut open_operator_inputs: HashMap<_, BTreeSet<_>> = operators
        .iter()
        .map(|(id, config)| (id, config.inputs.keys().collect()))
        .collect();

    while let Some(event) = events.next().await {
        match event {
            RuntimeEvent::Operator {
                id: operator_id,
                event,
            } => {
                match event {
                    OperatorEvent::Error(err) => {
                        bail!(err.wrap_err(format!(
                            "operator {}/{operator_id} raised an error",
                            node.id()
                        )))
                    }
                    OperatorEvent::Panic(payload) => {
                        bail!("operator {operator_id} panicked: {payload:?}");
                    }
                    OperatorEvent::Finished { reason } => {
                        if let StopReason::ExplicitStopAll = reason {
                            todo!("instruct dora-daemon/dora-coordinator to stop other nodes");
                        }

                        let Some(config) = operators.get(&operator_id) else {
                            tracing::warn!(
                                "received Finished event for unknown operator `{operator_id}`"
                            );
                            continue;
                        };
                        let outputs = config
                            .outputs
                            .iter()
                            .map(|output_id| operator_output_id(&operator_id, output_id))
                            .collect();
                        let result;
                        (node, result) = tokio::task::spawn_blocking(move || {
                            let result = node.close_outputs(outputs);
                            (node, result)
                        })
                        .await
                        .wrap_err("failed to wait for close_outputs task")?;
                        result.wrap_err("failed to close outputs of finished operator")?;

                        operator_channels.remove(&operator_id);

                        if operator_channels.is_empty() {
                            break;
                        }
                    }
                    OperatorEvent::AllocateOutputSample { len, sample: tx } => {
                        let sample = node.allocate_data_sample(len);
                        if tx.send(sample).is_err() {
                            tracing::warn!(
                                "output sample requested, but operator {operator_id} exited already"
                            );
                        }
                    }
                    OperatorEvent::Output {
                        output_id,
                        type_info,
                        parameters,
                        data,
                    } => {
                        let output_id = operator_output_id(&operator_id, &output_id);
                        let result;
                        (node, result) = tokio::task::spawn_blocking(move || {
                            let result =
                                node.send_output_sample(output_id, type_info, parameters, data);
                            (node, result)
                        })
                        .await
                        .wrap_err("failed to wait for send_output task")?;
                        result.wrap_err("failed to send node output")?;
                    }
                }
            }
            RuntimeEvent::Event(Event::Stop) => {
                for (_, channel) in operator_channels.drain() {
                    let _ = channel.send_async(Event::Stop).await;
                }
            }
            RuntimeEvent::Event(Event::Reload {
                operator_id: Some(operator_id),
            }) => {
                // Forward the reload request to the operator, if it is still running.
                let Some(operator_channel) = operator_channels.get(&operator_id) else {
                    tracing::warn!("received reload event for unknown operator `{operator_id}`");
                    continue;
                };
                let _ = operator_channel
                    .send_async(Event::Reload {
                        operator_id: Some(operator_id),
                    })
                    .await;
            }
            RuntimeEvent::Event(Event::Reload { operator_id: None }) => {
                tracing::warn!("Reloading runtime nodes is not supported");
            }
            RuntimeEvent::Event(Event::Input { id, metadata, data }) => {
                let Some((operator_id, input_id)) = id.as_str().split_once('/') else {
                    tracing::warn!("received non-operator input {id}");
                    continue;
                };
                let operator_id = OperatorId::from(operator_id.to_owned());
                let input_id = DataId::from(input_id.to_owned());
                let Some(operator_channel) = operator_channels.get(&operator_id) else {
                    tracing::warn!("received input {id} for unknown operator");
                    continue;
                };

                if let Err(err) = operator_channel
                    .send_async(Event::Input {
                        id: input_id.clone(),
                        metadata,
                        data,
                    })
                    .await
                    .wrap_err_with(|| {
                        format!("failed to send input `{input_id}` to operator `{operator_id}`")
                    })
                {
                    tracing::warn!("{err}");
                }
            }
            RuntimeEvent::Event(Event::InputClosed { id }) => {
                let Some((operator_id, input_id)) = id.as_str().split_once('/') else {
                    tracing::warn!("received InputClosed event for non-operator input {id}");
                    continue;
                };
                let operator_id = OperatorId::from(operator_id.to_owned());
                let input_id = DataId::from(input_id.to_owned());

                let Some(operator_channel) = operator_channels.get(&operator_id) else {
                    tracing::warn!("received InputClosed event {id} for unknown operator");
                    continue;
                };
                if let Err(err) = operator_channel
                    .send_async(Event::InputClosed {
                        id: input_id.clone(),
                    })
                    .await
                    .wrap_err_with(|| {
                        format!(
                            "failed to send InputClosed({input_id}) to operator `{operator_id}`"
                        )
                    })
                {
                    tracing::warn!("{err}");
                }

                if let Some(open_inputs) = open_operator_inputs.get_mut(&operator_id) {
                    open_inputs.remove(&input_id);
                    if open_inputs.is_empty() {
                        tracing::trace!(
                            "all inputs of operator {}/{operator_id} were closed -> closing event channel",
                            node.id()
                        );
                        open_operator_inputs.remove(&operator_id);
                        operator_channels.remove(&operator_id);
                    }
                }
            }
            RuntimeEvent::Event(Event::Error(err)) => eyre::bail!("received error event: {err}"),
            RuntimeEvent::Event(other) => {
                tracing::warn!("received unknown event `{other:?}`");
            }
        }
    }

    mem::drop(events);

    Ok(())
}

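/// Builds the node-level output ID for an operator output (`<operator_id>/<output_id>`).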
fn operator_output_id(operator_id: &OperatorId, output_id: &DataId) -> DataId {
    DataId::from(format!("{operator_id}/{output_id}"))
}

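/// An event handled by the runtime event loop, originating either from the hosted operator
/// or from the dora daemon.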
#[derive(Debug)]
enum RuntimeEvent {
    Operator {
        id: OperatorId,
        event: OperatorEvent,
    },
    Event(Event),
}