dora_message/
config.rs

1use core::fmt;
2use std::{
3    collections::{BTreeMap, BTreeSet},
4    str::FromStr,
5    time::Duration,
6};
7
8use once_cell::sync::OnceCell;
9use schemars::JsonSchema;
10use serde::{Deserialize, Serialize};
11
12pub use crate::id::{DataId, NodeId, OperatorId};
13
/// Contains the input and output configuration of the node.
// NOTE: this type derives `JsonSchema`; the `///` text is kept verbatim and extra
// notes are written as plain `//` comments so they stay out of the derived schema
// descriptions.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, JsonSchema)]
pub struct NodeRunConfig {
    /// Inputs for the nodes as a map from input ID to `node_id/output_id`.
    ///
    /// e.g.
    ///
    /// inputs:
    ///
    ///   example_input: example_node/example_output1
    ///
    // `serde(default)`: when the `inputs` key is absent, this deserializes to an
    // empty map. Each value is an `Input`, which is (de)serialized through
    // `InputDef` (see the `serde(from/into)` attribute on `Input`).
    #[serde(default)]
    pub inputs: BTreeMap<DataId, Input>,
    /// List of output IDs.
    ///
    /// e.g.
    ///
    /// outputs:
    ///
    ///  - output_1
    ///
    ///  - output_2
    // `serde(default)`: when the `outputs` key is absent, this deserializes to an
    // empty set. Being a `BTreeSet`, duplicate IDs collapse and iteration is ordered.
    #[serde(default)]
    pub outputs: BTreeSet<DataId>,
}
39
// A single node input: where the data comes from (`mapping`) plus an optional
// `queue_size`.
//
// Serialization and deserialization are routed through `InputDef` by the
// `serde(from = ..., into = ...)` attribute, so the textual form of an `Input`
// is whatever `InputDef` accepts and produces.
//
// Plain `//` comments are used on purpose: this type derives `JsonSchema`, and
// `///` text would become part of the derived schema descriptions.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, JsonSchema)]
#[serde(from = "InputDef", into = "InputDef")]
pub struct Input {
    // Source of this input (a timer or another node's output).
    pub mapping: InputMapping,
    // Optional queue size; `None` when the configuration does not specify one.
    // NOTE(review): how the queue size is applied is not visible in this file.
    pub queue_size: Option<usize>,
}
46
// Serialized representation of an `Input`.
//
// The enum is `untagged`, so no variant name appears in the data; serde picks the
// first variant (in declaration order) whose shape matches:
//   * `MappingOnly`  - just the mapping itself (an `InputMapping`),
//   * `WithOptions`  - a map with a `source` mapping and an optional `queue_size`.
//
// Plain `//` comments are used on purpose: this type derives `JsonSchema`, and
// `///` text would become part of the derived schema descriptions.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, JsonSchema)]
#[serde(untagged)]
pub enum InputDef {
    // Short form: only the mapping, no extra options.
    MappingOnly(InputMapping),
    // Long form: the mapping is stored under `source`, next to the options.
    WithOptions {
        source: InputMapping,
        queue_size: Option<usize>,
    },
}
56
57impl From<Input> for InputDef {
58    fn from(input: Input) -> Self {
59        match input {
60            Input {
61                mapping,
62                queue_size: None,
63            } => Self::MappingOnly(mapping),
64            Input {
65                mapping,
66                queue_size,
67            } => Self::WithOptions {
68                source: mapping,
69                queue_size,
70            },
71        }
72    }
73}
74
75impl From<InputDef> for Input {
76    fn from(value: InputDef) -> Self {
77        match value {
78            InputDef::MappingOnly(mapping) => Self {
79                mapping,
80                queue_size: None,
81            },
82            InputDef::WithOptions { source, queue_size } => Self {
83                mapping: source,
84                queue_size,
85            },
86        }
87    }
88}
89
// Describes where an input gets its data from.
//
// `Serialize`/`Deserialize` are implemented by hand elsewhere in this file (via the
// `Display` and `FromStr` impls), so a mapping is represented as a single string.
//
// Plain `//` comments are used on purpose: this type derives `JsonSchema`, and
// `///` text would become part of the derived schema descriptions.
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, JsonSchema)]
pub enum InputMapping {
    // Timer input with the given interval; written as `dora/timer/<unit>/<value>`.
    Timer { interval: Duration },
    // Output of another node; written as `<source>/<output>`.
    User(UserInputMapping),
}
95
96impl InputMapping {
97    pub fn source(&self) -> &NodeId {
98        static DORA_NODE_ID: OnceCell<NodeId> = OnceCell::new();
99
100        match self {
101            InputMapping::User(mapping) => &mapping.source,
102            InputMapping::Timer { .. } => DORA_NODE_ID.get_or_init(|| NodeId("dora".to_string())),
103        }
104    }
105}
106
107impl fmt::Display for InputMapping {
108    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
109        match self {
110            InputMapping::Timer { interval } => {
111                let duration = format_duration(*interval);
112                write!(f, "dora/timer/{duration}")
113            }
114            InputMapping::User(mapping) => {
115                write!(f, "{}/{}", mapping.source, mapping.output)
116            }
117        }
118    }
119}
120
121impl FromStr for InputMapping {
122    type Err = String;
123
124    fn from_str(s: &str) -> Result<Self, Self::Err> {
125        let (source, output) = s
126            .split_once('/')
127            .ok_or("input must start with `<source>/`")?;
128
129        let mapping = match source {
130            "dora" => match output.split_once('/') {
131                Some(("timer", output)) => {
132                    let (unit, value) = output.split_once('/').ok_or(
133                        "timer input must specify unit and value (e.g. `secs/5` or `millis/100`)",
134                    )?;
135                    let interval = match unit {
136                        "secs" => {
137                            let value = value
138                                .parse()
139                                .map_err(|_| format!("secs must be an integer (got `{value}`)"))?;
140                            Duration::from_secs(value)
141                        }
142                        "millis" => {
143                            let value = value.parse().map_err(|_| {
144                                format!("millis must be an integer (got `{value}`)")
145                            })?;
146                            Duration::from_millis(value)
147                        }
148                        other => {
149                            return Err(format!(
150                                "timer unit must be either secs or millis (got `{other}`"
151                            ));
152                        }
153                    };
154                    Self::Timer { interval }
155                }
156                Some((other, _)) => {
157                    return Err(format!("unknown dora input `{other}`"));
158                }
159                None => return Err("dora input has invalid format".into()),
160            },
161            _ => Self::User(UserInputMapping {
162                source: source.to_owned().into(),
163                output: output.to_owned().into(),
164            }),
165        };
166
167        Ok(mapping)
168    }
169}
170
/// Wrapper that prints a [`Duration`] as `secs/<n>` or `millis/<n>`.
pub struct FormattedDuration(pub Duration);

impl fmt::Display for FormattedDuration {
    /// Writes `secs/<whole seconds>` when the duration has no whole milliseconds
    /// in its sub-second part, and `millis/<total milliseconds>` otherwise.
    ///
    /// Anything finer than a millisecond is not printed: a sub-second part
    /// below one millisecond is treated like zero.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let Self(duration) = self;

        match duration.subsec_millis() {
            0 => write!(f, "secs/{}", duration.as_secs()),
            _ => write!(f, "millis/{}", duration.as_millis()),
        }
    }
}

/// Wraps `interval` in a [`FormattedDuration`] so it can be used with `{}`.
pub fn format_duration(interval: Duration) -> FormattedDuration {
    FormattedDuration(interval)
}
186
187impl Serialize for InputMapping {
188    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
189    where
190        S: serde::Serializer,
191    {
192        serializer.collect_str(self)
193    }
194}
195
196impl<'de> Deserialize<'de> for InputMapping {
197    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
198    where
199        D: serde::Deserializer<'de>,
200    {
201        let string = String::deserialize(deserializer)?;
202        string.parse().map_err(serde::de::Error::custom)
203    }
204}
205
// Input that is fed by the output of another node; its textual form is
// `<source>/<output>`.
//
// Plain `//` comments are used on purpose: this type derives `JsonSchema`, and
// `///` text would become part of the derived schema descriptions.
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, JsonSchema)]
pub struct UserInputMapping {
    // ID of the node that produces the data.
    pub source: NodeId,
    // ID of the output on that node.
    pub output: DataId,
}
211
// Communication settings, split into a local and a remote part.
//
// Unknown keys are rejected (`deny_unknown_fields`). Both fields are optional in
// the input (`default`) and are read from / written to the keys `_unstable_local`
// and `_unstable_remote` respectively. In the derived JSON schema both are
// described as plain strings (`schemars(with = "String")`).
//
// Plain `//` comments are used on purpose: this type derives `JsonSchema`, and
// `///` text would become part of the derived schema descriptions.
#[derive(Debug, Default, Serialize, Deserialize, JsonSchema, Clone)]
#[serde(deny_unknown_fields, rename_all = "lowercase")]
pub struct CommunicationConfig {
    // see https://github.com/dtolnay/serde-yaml/issues/298
    // (reason for (de)serializing the enums through `singleton_map`)
    #[serde(
        default,
        with = "serde_yaml::with::singleton_map",
        rename = "_unstable_local"
    )]
    #[schemars(with = "String")]
    pub local: LocalCommunicationConfig,
    // Same treatment as `local`, stored under `_unstable_remote`.
    #[serde(
        default,
        with = "serde_yaml::with::singleton_map",
        rename = "_unstable_remote"
    )]
    #[schemars(with = "String")]
    pub remote: RemoteCommunicationConfig,
}
231
/// Choice of mechanism for local communication.
///
/// No serde renaming is applied to this enum, so the variants are
/// (de)serialized under their Rust names. The `Default` implementation in this
/// file selects [`LocalCommunicationConfig::Tcp`].
// NOTE(review): the exact transport behind each variant is not visible in this
// file; the per-variant notes below are inferred from the names only.
#[derive(Debug, Clone, Copy, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
pub enum LocalCommunicationConfig {
    // presumably TCP sockets
    Tcp,
    // presumably shared memory
    Shmem,
    // presumably Unix domain sockets
    UnixDomain,
}
238
239impl Default for LocalCommunicationConfig {
240    fn default() -> Self {
241        Self::Tcp
242    }
243}
244
/// Choice of mechanism for remote communication.
///
/// Variant names are (de)serialized in lowercase (`rename_all = "lowercase"`).
/// `Tcp` is currently the only variant and is also what the `Default`
/// implementation in this file returns.
#[derive(Debug, Serialize, Deserialize, Clone)]
#[serde(deny_unknown_fields, rename_all = "lowercase")]
pub enum RemoteCommunicationConfig {
    Tcp,
    // TODO: the Zenoh variant below is currently disabled.
    // Zenoh {
    //     config: Option<serde_yaml::Value>,
    //     prefix: String,
    // },
}
255
256impl Default for RemoteCommunicationConfig {
257    fn default() -> Self {
258        Self::Tcp
259    }
260}