Skip to main content

dora_core/descriptor/
validate.rs

1use crate::{
2    adjust_shared_library_path,
3    descriptor::{self, source_is_url},
4    get_python_path,
5};
6
7use dora_message::{
8    config::{Input, InputMapping, UserInputMapping},
9    descriptor::{CoreNodeKind, DYNAMIC_SOURCE, OperatorSource, ResolvedNode, SHELL_SOURCE},
10    id::{DataId, NodeId, OperatorId},
11};
12use eyre::{Context, bail, eyre};
13use std::{collections::BTreeMap, path::Path, process::Command, time::Duration};
14use tracing::info;
15
16use super::{Descriptor, DescriptorExt, resolve_path};
/// Version of this crate, taken from `CARGO_PKG_VERSION` at compile time.
///
/// The installed Python `dora-rs` package is required to report exactly this
/// version (see `check_python_runtime`).
const VERSION: &str = env!("CARGO_PKG_VERSION");
18
19pub fn check_dataflow(
20    dataflow: &Descriptor,
21    working_dir: &Path,
22    remote_daemon_id: Option<&[&str]>,
23    coordinator_is_remote: bool,
24) -> eyre::Result<()> {
25    let nodes = dataflow.resolve_aliases_and_set_defaults()?;
26    let mut has_python_operator = false;
27    let mut errors: Vec<String> = Vec::new();
28
29    // check that nodes and operators exist
30    for node in nodes.values() {
31        match &node.kind {
32            descriptor::CoreNodeKind::Custom(custom) => match &custom.source {
33                dora_message::descriptor::NodeSource::Local => match custom.path.as_str() {
34                    SHELL_SOURCE => (),
35                    DYNAMIC_SOURCE => (),
36                    source => {
37                        if source_is_url(source) {
38                            if let Err(err) = check_url(source) {
39                                errors.push(format!("node `{}`: {err}", node.id));
40                            }
41                        } else if let Some(remote_daemon_id) = remote_daemon_id {
42                            if let Some(deploy) = &node.deploy {
43                                if let Some(machine) = &deploy.machine {
44                                    if remote_daemon_id.contains(&machine.as_str())
45                                        || coordinator_is_remote
46                                    {
47                                        info!("skipping path check for remote node `{}`", node.id);
48                                    }
49                                }
50                            }
51                        } else if custom.build.is_some() {
52                            info!("skipping path check for node with build command");
53                        } else if let Err(err) = resolve_path(source, working_dir) {
54                            errors.push(format!("node `{}`: {err}", node.id));
55                        };
56                    }
57                },
58                dora_message::descriptor::NodeSource::GitBranch { .. } => {
59                    info!("skipping check for node with git source");
60                }
61            },
62            descriptor::CoreNodeKind::Runtime(runtime_node) => {
63                for operator_definition in &runtime_node.operators {
64                    match &operator_definition.config.source {
65                        OperatorSource::SharedLibrary(path) => {
66                            if source_is_url(path) {
67                                if let Err(err) = check_url(path) {
68                                    errors.push(format!(
69                                        "node `{}`, operator `{}`: {err}",
70                                        node.id, operator_definition.id,
71                                    ));
72                                }
73                            } else if operator_definition.config.build.is_some() {
74                                info!("skipping path check for operator with build command");
75                            } else {
76                                match adjust_shared_library_path(Path::new(&path)) {
77                                    Ok(path) => {
78                                        if !working_dir.join(&path).exists() {
79                                            errors.push(format!(
80                                                "node `{}`, operator `{}`: no shared library at `{}`",
81                                                node.id,
82                                                operator_definition.id,
83                                                path.display()
84                                            ));
85                                        }
86                                    }
87                                    Err(err) => {
88                                        errors.push(format!(
89                                            "node `{}`, operator `{}`: {err}",
90                                            node.id, operator_definition.id,
91                                        ));
92                                    }
93                                }
94                            }
95                        }
96                        OperatorSource::Python(python_source) => {
97                            has_python_operator = true;
98                            let path = &python_source.source;
99                            if source_is_url(path) {
100                                if let Err(err) = check_url(path) {
101                                    errors.push(format!(
102                                        "node `{}`, operator `{}`: {err}",
103                                        node.id, operator_definition.id,
104                                    ));
105                                }
106                            } else if !working_dir.join(path).exists() {
107                                errors.push(format!(
108                                    "node `{}`, operator `{}`: no Python library at `{path}`",
109                                    node.id, operator_definition.id,
110                                ));
111                            }
112                        }
113                        OperatorSource::Wasm(path) => {
114                            if source_is_url(path) {
115                                if let Err(err) = check_url(path) {
116                                    errors.push(format!(
117                                        "node `{}`, operator `{}`: {err}",
118                                        node.id, operator_definition.id,
119                                    ));
120                                }
121                            } else if !working_dir.join(path).exists() {
122                                errors.push(format!(
123                                    "node `{}`, operator `{}`: no WASM library at `{path}`",
124                                    node.id, operator_definition.id,
125                                ));
126                            }
127                        }
128                    }
129                }
130            }
131        }
132    }
133
134    // check that all inputs mappings point to an existing output
135    for node in nodes.values() {
136        match &node.kind {
137            descriptor::CoreNodeKind::Custom(custom_node) => {
138                for (input_id, input) in &custom_node.run_config.inputs {
139                    if let Err(err) = check_input(input, &nodes, &format!("{}/{input_id}", node.id))
140                    {
141                        errors.push(format!("{err}"));
142                    }
143                }
144            }
145            descriptor::CoreNodeKind::Runtime(runtime_node) => {
146                for operator_definition in &runtime_node.operators {
147                    for (input_id, input) in &operator_definition.config.inputs {
148                        if let Err(err) = check_input(
149                            input,
150                            &nodes,
151                            &format!("{}/{}/{input_id}", operator_definition.id, node.id),
152                        ) {
153                            errors.push(format!("{err}"));
154                        }
155                    }
156                }
157            }
158        };
159    }
160
161    // Check that nodes can resolve `send_stdout_as`
162    for node in nodes.values() {
163        if let Err(err) = node.send_stdout_as() {
164            errors.push(format!(
165                "node `{}`: could not resolve `send_stdout_as` configuration: {err}",
166                node.id
167            ));
168        }
169    }
170
171    if has_python_operator {
172        if let Err(err) = check_python_runtime() {
173            errors.push(format!("{err}"));
174        }
175    }
176
177    if errors.is_empty() {
178        Ok(())
179    } else {
180        let error_list = errors
181            .iter()
182            .map(|e| format!("  - {e}"))
183            .collect::<Vec<_>>()
184            .join("\n");
185        bail!(
186            "found {} validation error(s):\n{}",
187            errors.len(),
188            error_list
189        );
190    }
191}
192
193pub trait ResolvedNodeExt {
194    fn send_stdout_as(&self) -> eyre::Result<Option<String>>;
195}
196
197impl ResolvedNodeExt for ResolvedNode {
198    fn send_stdout_as(&self) -> eyre::Result<Option<String>> {
199        match &self.kind {
200            // TODO: Split stdout between operators
201            CoreNodeKind::Runtime(n) => {
202                let count = n
203                    .operators
204                    .iter()
205                    .filter(|op| op.config.send_stdout_as.is_some())
206                    .count();
207                if count == 1 && n.operators.len() > 1 {
208                    tracing::warn!(
209                        "All stdout from all operators of a runtime are going to be sent in the selected `send_stdout_as` operator."
210                    )
211                } else if count > 1 {
212                    return Err(eyre!(
213                        "More than one `send_stdout_as` entries for a runtime node. Please only use one `send_stdout_as` per runtime."
214                    ));
215                }
216                Ok(n.operators.iter().find_map(|op| {
217                    op.config
218                        .send_stdout_as
219                        .clone()
220                        .map(|stdout| format!("{}/{}", op.id, stdout))
221                }))
222            }
223            CoreNodeKind::Custom(n) => Ok(n.send_stdout_as.clone()),
224        }
225    }
226}
227
228fn check_input(
229    input: &Input,
230    nodes: &BTreeMap<NodeId, super::ResolvedNode>,
231    input_id_str: &str,
232) -> Result<(), eyre::ErrReport> {
233    match &input.mapping {
234        InputMapping::Timer { interval: _ } => {}
235        InputMapping::User(UserInputMapping { source, output }) => {
236            let source_node = nodes.values().find(|n| &n.id == source).ok_or_else(|| {
237                eyre!("source node `{source}` mapped to input `{input_id_str}` does not exist",)
238            })?;
239            match &source_node.kind {
240                CoreNodeKind::Custom(custom_node) => {
241                    if !custom_node.run_config.outputs.contains(output) {
242                        bail!(
243                            "output `{source}/{output}` mapped to \
244                            input `{input_id_str}` does not exist",
245                        );
246                    }
247                }
248                CoreNodeKind::Runtime(runtime) => {
249                    let (operator_id, output) = output.split_once('/').unwrap_or_default();
250                    let operator_id = OperatorId::from(operator_id.to_owned());
251                    let output = DataId::from(output.to_owned());
252
253                    let operator = runtime
254                        .operators
255                        .iter()
256                        .find(|o| o.id == operator_id)
257                        .ok_or_else(|| {
258                            eyre!(
259                                "source operator `{source}/{operator_id}` used \
260                                for input `{input_id_str}` does not exist",
261                            )
262                        })?;
263
264                    if !operator.config.outputs.contains(&output) {
265                        bail!(
266                            "output `{source}/{operator_id}/{output}` mapped to \
267                            input `{input_id_str}` does not exist",
268                        );
269                    }
270                }
271            }
272        }
273    };
274    Ok(())
275}
276
277fn check_python_runtime() -> eyre::Result<()> {
278    // Check if python dora-rs is installed and match cli version
279    let reinstall_command =
280        format!("Please reinstall it with: `pip install dora-rs=={VERSION} --force`");
281    let mut command = Command::new(get_python_path().context("Could not get python binary")?);
282    command.args([
283        "-c",
284        &format!(
285            "
286import dora;
287assert dora.__version__=='{VERSION}',  'Python dora-rs should be {VERSION}, but current version is %s. {reinstall_command}' % (dora.__version__)
288        "
289        ),
290    ]);
291    let mut result = command
292        .spawn()
293        .wrap_err("Could not spawn python dora-rs command.")?;
294    let status = result
295        .wait()
296        .wrap_err("Could not get exit status when checking python dora-rs")?;
297
298    if !status.success() {
299        bail!("Something went wrong with Python dora-rs. {reinstall_command}")
300    }
301
302    Ok(())
303}
304
305fn check_url(url: &str) -> eyre::Result<()> {
306    let client = reqwest::blocking::Client::builder()
307        .timeout(Duration::from_secs(5))
308        .build()
309        .wrap_err("failed to build HTTP client for URL validation")?;
310
311    match client.head(url).send() {
312        Ok(response) => {
313            if response.status().is_success() {
314                Ok(())
315            } else if response.status() == reqwest::StatusCode::METHOD_NOT_ALLOWED {
316                match client.get(url).send() {
317                    Ok(get_response) if get_response.status().is_success() => Ok(()),
318                    Ok(get_response) => eyre::bail!(
319                        "URL `{}` is not reachable (status code: {})",
320                        url,
321                        get_response.status()
322                    ),
323                    Err(err) => eyre::bail!("Failed to reach URL `{}`: {}", url, err),
324                }
325            } else {
326                eyre::bail!(
327                    "URL `{}` is not reachable (status code: {})",
328                    url,
329                    response.status()
330                )
331            }
332        }
333        Err(err) => eyre::bail!("Failed to reach URL `{}`: {}", url, err),
334    }
335}
336
#[cfg(test)]
mod tests {
    use super::check_url;
    use std::io::{Read, Write};
    use std::net::{SocketAddr, TcpListener, TcpStream};
    use std::thread::{self, JoinHandle};

    const OK: &[u8] = b"HTTP/1.1 200 OK\r\nContent-Length: 0\r\n\r\n";
    const NOT_FOUND: &[u8] = b"HTTP/1.1 404 Not Found\r\nContent-Length: 0\r\n\r\n";
    const METHOD_NOT_ALLOWED: &[u8] =
        b"HTTP/1.1 405 Method Not Allowed\r\nConnection: close\r\nContent-Length: 0\r\n\r\n";

    /// Reads one request from `stream` and returns the method of its request line
    /// (an empty string if there is no request line).
    fn read_request_method(stream: &mut TcpStream) -> String {
        let mut buffer = [0_u8; 2048];
        let len = stream.read(&mut buffer).expect("failed to read request");
        let request = std::str::from_utf8(&buffer[..len]).expect("invalid UTF-8 in request");
        let request_line = request.lines().next().unwrap_or_default();
        request_line
            .split_whitespace()
            .next()
            .unwrap_or_default()
            .to_owned()
    }

    /// Starts a minimal HTTP server on an ephemeral localhost port.
    ///
    /// The server accepts `connections` connections one after another, reads a single
    /// request from each, and answers with the raw bytes `respond` returns for that
    /// request's method.
    fn serve(
        connections: usize,
        respond: fn(&str) -> &'static [u8],
    ) -> (SocketAddr, JoinHandle<()>) {
        let listener = TcpListener::bind("127.0.0.1:0").expect("failed to bind test server");
        let addr = listener.local_addr().expect("failed to get local addr");
        let server = thread::spawn(move || {
            for _ in 0..connections {
                let (mut stream, _) = listener.accept().expect("failed to accept connection");
                let method = read_request_method(&mut stream);
                stream
                    .write_all(respond(&method))
                    .expect("failed to write response");
            }
        });
        (addr, server)
    }

    #[test]
    fn check_url_accepts_success_status() {
        let (addr, server) = serve(1, |_| OK);

        check_url(&format!("http://{addr}/ok")).expect("URL should be reachable");
        server.join().expect("server thread panicked");
    }

    #[test]
    fn check_url_rejects_non_success_status() {
        let (addr, server) = serve(1, |_| NOT_FOUND);

        let err = check_url(&format!("http://{addr}/missing")).expect_err("URL should be rejected");
        assert!(err.to_string().contains("status code: 404"));
        server.join().expect("server thread panicked");
    }

    #[test]
    fn check_url_falls_back_to_get_when_head_not_allowed() {
        let (addr, server) = serve(2, |method| {
            if method == "HEAD" {
                METHOD_NOT_ALLOWED
            } else {
                OK
            }
        });

        check_url(&format!("http://{addr}/head-not-allowed"))
            .expect("GET fallback should mark URL as reachable");
        server.join().expect("server thread panicked");
    }
}