1use crate::{
2 adjust_shared_library_path,
3 descriptor::{self, source_is_url},
4 get_python_path,
5};
6
7use dora_message::{
8 config::{Input, InputMapping, UserInputMapping},
9 descriptor::{CoreNodeKind, DYNAMIC_SOURCE, OperatorSource, ResolvedNode, SHELL_SOURCE},
10 id::{DataId, NodeId, OperatorId},
11};
12use eyre::{Context, bail, eyre};
13use std::{collections::BTreeMap, path::Path, process::Command, time::Duration};
14use tracing::info;
15
16use super::{Descriptor, DescriptorExt, resolve_path};
/// Version of this crate (`CARGO_PKG_VERSION` at compile time). The installed
/// Python `dora` package is expected to report exactly this version.
const VERSION: &str = env!("CARGO_PKG_VERSION");
18
19pub fn check_dataflow(
20 dataflow: &Descriptor,
21 working_dir: &Path,
22 remote_daemon_id: Option<&[&str]>,
23 coordinator_is_remote: bool,
24) -> eyre::Result<()> {
25 let nodes = dataflow.resolve_aliases_and_set_defaults()?;
26 let mut has_python_operator = false;
27 let mut errors: Vec<String> = Vec::new();
28
29 for node in nodes.values() {
31 match &node.kind {
32 descriptor::CoreNodeKind::Custom(custom) => match &custom.source {
33 dora_message::descriptor::NodeSource::Local => match custom.path.as_str() {
34 SHELL_SOURCE => (),
35 DYNAMIC_SOURCE => (),
36 source => {
37 if source_is_url(source) {
38 if let Err(err) = check_url(source) {
39 errors.push(format!("node `{}`: {err}", node.id));
40 }
41 } else if let Some(remote_daemon_id) = remote_daemon_id {
42 if let Some(deploy) = &node.deploy {
43 if let Some(machine) = &deploy.machine {
44 if remote_daemon_id.contains(&machine.as_str())
45 || coordinator_is_remote
46 {
47 info!("skipping path check for remote node `{}`", node.id);
48 }
49 }
50 }
51 } else if custom.build.is_some() {
52 info!("skipping path check for node with build command");
53 } else if let Err(err) = resolve_path(source, working_dir) {
54 errors.push(format!("node `{}`: {err}", node.id));
55 };
56 }
57 },
58 dora_message::descriptor::NodeSource::GitBranch { .. } => {
59 info!("skipping check for node with git source");
60 }
61 },
62 descriptor::CoreNodeKind::Runtime(runtime_node) => {
63 for operator_definition in &runtime_node.operators {
64 match &operator_definition.config.source {
65 OperatorSource::SharedLibrary(path) => {
66 if source_is_url(path) {
67 if let Err(err) = check_url(path) {
68 errors.push(format!(
69 "node `{}`, operator `{}`: {err}",
70 node.id, operator_definition.id,
71 ));
72 }
73 } else if operator_definition.config.build.is_some() {
74 info!("skipping path check for operator with build command");
75 } else {
76 match adjust_shared_library_path(Path::new(&path)) {
77 Ok(path) => {
78 if !working_dir.join(&path).exists() {
79 errors.push(format!(
80 "node `{}`, operator `{}`: no shared library at `{}`",
81 node.id,
82 operator_definition.id,
83 path.display()
84 ));
85 }
86 }
87 Err(err) => {
88 errors.push(format!(
89 "node `{}`, operator `{}`: {err}",
90 node.id, operator_definition.id,
91 ));
92 }
93 }
94 }
95 }
96 OperatorSource::Python(python_source) => {
97 has_python_operator = true;
98 let path = &python_source.source;
99 if source_is_url(path) {
100 if let Err(err) = check_url(path) {
101 errors.push(format!(
102 "node `{}`, operator `{}`: {err}",
103 node.id, operator_definition.id,
104 ));
105 }
106 } else if !working_dir.join(path).exists() {
107 errors.push(format!(
108 "node `{}`, operator `{}`: no Python library at `{path}`",
109 node.id, operator_definition.id,
110 ));
111 }
112 }
113 OperatorSource::Wasm(path) => {
114 if source_is_url(path) {
115 if let Err(err) = check_url(path) {
116 errors.push(format!(
117 "node `{}`, operator `{}`: {err}",
118 node.id, operator_definition.id,
119 ));
120 }
121 } else if !working_dir.join(path).exists() {
122 errors.push(format!(
123 "node `{}`, operator `{}`: no WASM library at `{path}`",
124 node.id, operator_definition.id,
125 ));
126 }
127 }
128 }
129 }
130 }
131 }
132 }
133
134 for node in nodes.values() {
136 match &node.kind {
137 descriptor::CoreNodeKind::Custom(custom_node) => {
138 for (input_id, input) in &custom_node.run_config.inputs {
139 if let Err(err) = check_input(input, &nodes, &format!("{}/{input_id}", node.id))
140 {
141 errors.push(format!("{err}"));
142 }
143 }
144 }
145 descriptor::CoreNodeKind::Runtime(runtime_node) => {
146 for operator_definition in &runtime_node.operators {
147 for (input_id, input) in &operator_definition.config.inputs {
148 if let Err(err) = check_input(
149 input,
150 &nodes,
151 &format!("{}/{}/{input_id}", operator_definition.id, node.id),
152 ) {
153 errors.push(format!("{err}"));
154 }
155 }
156 }
157 }
158 };
159 }
160
161 for node in nodes.values() {
163 if let Err(err) = node.send_stdout_as() {
164 errors.push(format!(
165 "node `{}`: could not resolve `send_stdout_as` configuration: {err}",
166 node.id
167 ));
168 }
169 }
170
171 if has_python_operator {
172 if let Err(err) = check_python_runtime() {
173 errors.push(format!("{err}"));
174 }
175 }
176
177 if errors.is_empty() {
178 Ok(())
179 } else {
180 let error_list = errors
181 .iter()
182 .map(|e| format!(" - {e}"))
183 .collect::<Vec<_>>()
184 .join("\n");
185 bail!(
186 "found {} validation error(s):\n{}",
187 errors.len(),
188 error_list
189 );
190 }
191}
192
193pub trait ResolvedNodeExt {
194 fn send_stdout_as(&self) -> eyre::Result<Option<String>>;
195}
196
197impl ResolvedNodeExt for ResolvedNode {
198 fn send_stdout_as(&self) -> eyre::Result<Option<String>> {
199 match &self.kind {
200 CoreNodeKind::Runtime(n) => {
202 let count = n
203 .operators
204 .iter()
205 .filter(|op| op.config.send_stdout_as.is_some())
206 .count();
207 if count == 1 && n.operators.len() > 1 {
208 tracing::warn!(
209 "All stdout from all operators of a runtime are going to be sent in the selected `send_stdout_as` operator."
210 )
211 } else if count > 1 {
212 return Err(eyre!(
213 "More than one `send_stdout_as` entries for a runtime node. Please only use one `send_stdout_as` per runtime."
214 ));
215 }
216 Ok(n.operators.iter().find_map(|op| {
217 op.config
218 .send_stdout_as
219 .clone()
220 .map(|stdout| format!("{}/{}", op.id, stdout))
221 }))
222 }
223 CoreNodeKind::Custom(n) => Ok(n.send_stdout_as.clone()),
224 }
225 }
226}
227
228fn check_input(
229 input: &Input,
230 nodes: &BTreeMap<NodeId, super::ResolvedNode>,
231 input_id_str: &str,
232) -> Result<(), eyre::ErrReport> {
233 match &input.mapping {
234 InputMapping::Timer { interval: _ } => {}
235 InputMapping::User(UserInputMapping { source, output }) => {
236 let source_node = nodes.values().find(|n| &n.id == source).ok_or_else(|| {
237 eyre!("source node `{source}` mapped to input `{input_id_str}` does not exist",)
238 })?;
239 match &source_node.kind {
240 CoreNodeKind::Custom(custom_node) => {
241 if !custom_node.run_config.outputs.contains(output) {
242 bail!(
243 "output `{source}/{output}` mapped to \
244 input `{input_id_str}` does not exist",
245 );
246 }
247 }
248 CoreNodeKind::Runtime(runtime) => {
249 let (operator_id, output) = output.split_once('/').unwrap_or_default();
250 let operator_id = OperatorId::from(operator_id.to_owned());
251 let output = DataId::from(output.to_owned());
252
253 let operator = runtime
254 .operators
255 .iter()
256 .find(|o| o.id == operator_id)
257 .ok_or_else(|| {
258 eyre!(
259 "source operator `{source}/{operator_id}` used \
260 for input `{input_id_str}` does not exist",
261 )
262 })?;
263
264 if !operator.config.outputs.contains(&output) {
265 bail!(
266 "output `{source}/{operator_id}/{output}` mapped to \
267 input `{input_id_str}` does not exist",
268 );
269 }
270 }
271 }
272 }
273 };
274 Ok(())
275}
276
277fn check_python_runtime() -> eyre::Result<()> {
278 let reinstall_command =
280 format!("Please reinstall it with: `pip install dora-rs=={VERSION} --force`");
281 let mut command = Command::new(get_python_path().context("Could not get python binary")?);
282 command.args([
283 "-c",
284 &format!(
285 "
286import dora;
287assert dora.__version__=='{VERSION}', 'Python dora-rs should be {VERSION}, but current version is %s. {reinstall_command}' % (dora.__version__)
288 "
289 ),
290 ]);
291 let mut result = command
292 .spawn()
293 .wrap_err("Could not spawn python dora-rs command.")?;
294 let status = result
295 .wait()
296 .wrap_err("Could not get exit status when checking python dora-rs")?;
297
298 if !status.success() {
299 bail!("Something went wrong with Python dora-rs. {reinstall_command}")
300 }
301
302 Ok(())
303}
304
305fn check_url(url: &str) -> eyre::Result<()> {
306 let client = reqwest::blocking::Client::builder()
307 .timeout(Duration::from_secs(5))
308 .build()
309 .wrap_err("failed to build HTTP client for URL validation")?;
310
311 match client.head(url).send() {
312 Ok(response) => {
313 if response.status().is_success() {
314 Ok(())
315 } else if response.status() == reqwest::StatusCode::METHOD_NOT_ALLOWED {
316 match client.get(url).send() {
317 Ok(get_response) if get_response.status().is_success() => Ok(()),
318 Ok(get_response) => eyre::bail!(
319 "URL `{}` is not reachable (status code: {})",
320 url,
321 get_response.status()
322 ),
323 Err(err) => eyre::bail!("Failed to reach URL `{}`: {}", url, err),
324 }
325 } else {
326 eyre::bail!(
327 "URL `{}` is not reachable (status code: {})",
328 url,
329 response.status()
330 )
331 }
332 }
333 Err(err) => eyre::bail!("Failed to reach URL `{}`: {}", url, err),
334 }
335}
336
#[cfg(test)]
mod tests {
    use super::check_url;
    use std::io::{Read, Write};
    use std::net::{SocketAddr, TcpListener, TcpStream};
    use std::thread::{self, JoinHandle};

    /// Canned response for a successful request with an empty body.
    const OK_RESPONSE: &str = "HTTP/1.1 200 OK\r\nContent-Length: 0\r\n\r\n";

    /// Reads one chunk of the incoming request and returns its method (the
    /// first token of the request line), or an empty string if there is none.
    fn read_request_method(stream: &mut TcpStream) -> String {
        let mut buffer = [0_u8; 2048];
        let len = stream.read(&mut buffer).expect("failed to read request");
        let request = std::str::from_utf8(&buffer[..len]).expect("invalid UTF-8 in request");
        let request_line = request.lines().next();
        request_line
            .and_then(|line| line.split_whitespace().next())
            .map(str::to_owned)
            .unwrap_or_default()
    }

    /// Binds a server on an ephemeral localhost port and serves `connections`
    /// connections one after another on a background thread.
    ///
    /// For each connection the request method is read, and the raw response
    /// chosen by `respond` for that method is written back.
    fn spawn_server(
        connections: usize,
        respond: fn(&str) -> &'static str,
    ) -> (SocketAddr, JoinHandle<()>) {
        let listener = TcpListener::bind("127.0.0.1:0").expect("failed to bind test server");
        let addr = listener.local_addr().expect("failed to get local addr");
        let server = thread::spawn(move || {
            for _ in 0..connections {
                let (mut stream, _) = listener.accept().expect("failed to accept connection");
                let method = read_request_method(&mut stream);
                stream
                    .write_all(respond(&method).as_bytes())
                    .expect("failed to write response");
            }
        });
        (addr, server)
    }

    #[test]
    fn check_url_accepts_success_status() {
        let (addr, server) = spawn_server(1, |_| OK_RESPONSE);

        check_url(&format!("http://{addr}/ok")).expect("URL should be reachable");
        server.join().expect("server thread panicked");
    }

    #[test]
    fn check_url_rejects_non_success_status() {
        let (addr, server) =
            spawn_server(1, |_| "HTTP/1.1 404 Not Found\r\nContent-Length: 0\r\n\r\n");

        let err = check_url(&format!("http://{addr}/missing")).expect_err("URL should be rejected");
        assert!(err.to_string().contains("status code: 404"));
        server.join().expect("server thread panicked");
    }

    #[test]
    fn check_url_falls_back_to_get_when_head_not_allowed() {
        let (addr, server) = spawn_server(2, |method| {
            if method == "HEAD" {
                "HTTP/1.1 405 Method Not Allowed\r\nConnection: close\r\nContent-Length: 0\r\n\r\n"
            } else {
                OK_RESPONSE
            }
        });

        check_url(&format!("http://{addr}/head-not-allowed"))
            .expect("GET fallback should mark URL as reachable");
        server.join().expect("server thread panicked");
    }
}