// dora_message/config.rs

1use core::fmt;
2use std::{
3    collections::{BTreeMap, BTreeSet},
4    time::Duration,
5};
6
7use once_cell::sync::OnceCell;
8use schemars::JsonSchema;
9use serde::{Deserialize, Serialize};
10
11pub use crate::id::{DataId, NodeId, OperatorId};
12
// Run configuration of a single node: its inputs (keyed by input ID) and the IDs of the
// outputs it declares.
//
// Both fields carry `#[serde(default)]`, so either key may be omitted from the config; a
// missing key deserializes as an empty collection.
// (Plain `//` comments on purpose: this type derives `JsonSchema`, and schemars copies `///`
// doc comments into the generated schema, so the existing field docs below are schema text.)
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, JsonSchema)]
pub struct NodeRunConfig {
    /// Inputs for the nodes as a map from input ID to `node_id/output_id`.
    ///
    /// e.g.
    ///
    /// inputs:
    ///
    ///   example_input: example_node/example_output1
    ///
    // Missing `inputs` key -> empty map.
    #[serde(default)]
    pub inputs: BTreeMap<DataId, Input>,
    /// List of output IDs.
    ///
    /// e.g.
    ///
    /// outputs:
    ///
    ///  - output_1
    ///
    ///  - output_2
    // Missing `outputs` key -> empty set.
    #[serde(default)]
    pub outputs: BTreeSet<DataId>,
}
37
// One input of a node: where its data comes from plus optional per-input settings.
//
// (De)serialized through `InputDef` (`#[serde(from/into)]`), so in a config an input is
// either a bare mapping string such as `node/output` or a map with a `source` key and an
// optional `queue_size` key.
// NOTE(review): confirm whether the schemars version in use honours `#[serde(from/into)]`;
// if it does not, the derived schema describes `mapping`/`queue_size` fields rather than the
// `InputDef` shape that is actually accepted.
// (Plain `//` comments on purpose: `///` would be copied into the derived JSON schema.)
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, JsonSchema)]
#[serde(from = "InputDef", into = "InputDef")]
pub struct Input {
    // Where this input's data comes from (a user node output or a dora timer).
    pub mapping: InputMapping,
    // Optional queue size; `None` when the config does not set `queue_size`.
    pub queue_size: Option<usize>,
}
44
/// Serialized form of [`Input`].
///
/// The enum is `#[serde(untagged)]`, so serde tries the variants in declaration order: a
/// plain mapping string (e.g. `node/output`) becomes `MappingOnly`; otherwise a map with a
/// `source` key (and optionally `queue_size`) becomes `WithOptions`.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[serde(untagged)]
pub enum InputDef {
    /// Just the input mapping, with no extra options.
    MappingOnly(InputMapping),
    /// Input mapping together with per-input options.
    WithOptions {
        /// The input mapping, e.g. `node/output` or `dora/timer/secs/1`.
        source: InputMapping,
        /// Optional queue size; the key may be omitted, which yields `None`.
        queue_size: Option<usize>,
    },
}
54
55impl From<Input> for InputDef {
56    fn from(input: Input) -> Self {
57        match input {
58            Input {
59                mapping,
60                queue_size: None,
61            } => Self::MappingOnly(mapping),
62            Input {
63                mapping,
64                queue_size,
65            } => Self::WithOptions {
66                source: mapping,
67                queue_size,
68            },
69        }
70    }
71}
72
73impl From<InputDef> for Input {
74    fn from(value: InputDef) -> Self {
75        match value {
76            InputDef::MappingOnly(mapping) => Self {
77                mapping,
78                queue_size: None,
79            },
80            InputDef::WithOptions { source, queue_size } => Self {
81                mapping: source,
82                queue_size,
83            },
84        }
85    }
86}
87
// The source of a node input.
//
// Written in configs as a single string (see the manual `Serialize`/`Deserialize` impls):
// `dora/timer/secs/<n>` or `dora/timer/millis/<n>` for `Timer`, `<node>/<output>` for `User`.
// The derived `Ord` orders every `Timer` mapping before every `User` mapping (declaration
// order of the variants).
//
// NOTE(review): the derived `JsonSchema` follows this enum's Rust structure, not the string
// form produced/accepted by the manual serde impls, so the generated schema likely does not
// match real configs; consider a hand-written `JsonSchema` impl that describes a string.
// (Plain `//` comments on purpose: `///` would be copied into the derived JSON schema.)
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, JsonSchema)]
pub enum InputMapping {
    // Built-in timer input with the given `interval` (`dora/timer/...`).
    Timer { interval: Duration },
    // Output of a user node (`<node>/<output>`).
    User(UserInputMapping),
}
93
94impl InputMapping {
95    pub fn source(&self) -> &NodeId {
96        static DORA_NODE_ID: OnceCell<NodeId> = OnceCell::new();
97
98        match self {
99            InputMapping::User(mapping) => &mapping.source,
100            InputMapping::Timer { .. } => DORA_NODE_ID.get_or_init(|| NodeId("dora".to_string())),
101        }
102    }
103}
104
105impl fmt::Display for InputMapping {
106    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
107        match self {
108            InputMapping::Timer { interval } => {
109                let duration = format_duration(*interval);
110                write!(f, "dora/timer/{duration}")
111            }
112            InputMapping::User(mapping) => {
113                write!(f, "{}/{}", mapping.source, mapping.output)
114            }
115        }
116    }
117}
118
/// Display adapter that renders a [`Duration`] in the `<unit>/<value>` form used by
/// `dora/timer/...` inputs.
///
/// A duration whose sub-second part has no whole milliseconds is written as whole seconds
/// (`secs/<n>`); anything else is written as milliseconds (`millis/<n>`). Precision below one
/// millisecond is not represented in either form.
pub struct FormattedDuration(pub Duration);

impl fmt::Display for FormattedDuration {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let FormattedDuration(duration) = self;
        match duration.subsec_millis() {
            0 => write!(f, "secs/{}", duration.as_secs()),
            _ => write!(f, "millis/{}", duration.as_millis()),
        }
    }
}

/// Wraps `interval` so it can be displayed as `secs/<n>` or `millis/<n>`.
pub fn format_duration(interval: Duration) -> FormattedDuration {
    FormattedDuration(interval)
}
134
135impl Serialize for InputMapping {
136    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
137    where
138        S: serde::Serializer,
139    {
140        serializer.collect_str(self)
141    }
142}
143
144impl<'de> Deserialize<'de> for InputMapping {
145    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
146    where
147        D: serde::Deserializer<'de>,
148    {
149        let string = String::deserialize(deserializer)?;
150        let (source, output) = string
151            .split_once('/')
152            .ok_or_else(|| serde::de::Error::custom("input must start with `<source>/`"))?;
153
154        let deserialized = match source {
155            "dora" => match output.split_once('/') {
156                Some(("timer", output)) => {
157                    let (unit, value) = output.split_once('/').ok_or_else(|| {
158                        serde::de::Error::custom(
159                            "timer input must specify unit and value (e.g. `secs/5` or `millis/100`)",
160                        )
161                    })?;
162                    let interval = match unit {
163                        "secs" => {
164                            let value = value.parse().map_err(|_| {
165                                serde::de::Error::custom(format!(
166                                    "secs must be an integer (got `{value}`)"
167                                ))
168                            })?;
169                            Duration::from_secs(value)
170                        }
171                        "millis" => {
172                            let value = value.parse().map_err(|_| {
173                                serde::de::Error::custom(format!(
174                                    "millis must be an integer (got `{value}`)"
175                                ))
176                            })?;
177                            Duration::from_millis(value)
178                        }
179                        other => {
180                            return Err(serde::de::Error::custom(format!(
181                                "timer unit must be either secs or millis (got `{other}`"
182                            )))
183                        }
184                    };
185                    Self::Timer { interval }
186                }
187                Some((other, _)) => {
188                    return Err(serde::de::Error::custom(format!(
189                        "unknown dora input `{other}`"
190                    )))
191                }
192                None => return Err(serde::de::Error::custom("dora input has invalid format")),
193            },
194            _ => Self::User(UserInputMapping {
195                source: source.to_owned().into(),
196                output: output.to_owned().into(),
197            }),
198        };
199
200        Ok(deserialized)
201    }
202}
203
// Reference to output `output` of node `source`, written in configs as `<source>/<output>`.
//
// Deserialization splits on the first `/` only, so `output` may itself contain `/`.
// (Plain `//` comments on purpose: `///` would be copied into the derived JSON schema.)
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, JsonSchema)]
pub struct UserInputMapping {
    // ID of the node that produces the data.
    pub source: NodeId,
    // ID of the output on `source` that feeds this input.
    pub output: DataId,
}
209
// Communication settings of a dataflow.
//
// Both fields are optional (`default`) and are read from the `_unstable_local` /
// `_unstable_remote` keys; unknown keys are rejected (`deny_unknown_fields`). The
// `rename_all = "lowercase"` currently has no effect because every field has an explicit
// `rename`. In the JSON schema both fields are described as plain strings
// (`#[schemars(with = "String")]`).
// (Plain `//` comments on purpose: `///` would be copied into the derived JSON schema.)
#[derive(Debug, Default, Serialize, Deserialize, JsonSchema, Clone)]
#[serde(deny_unknown_fields, rename_all = "lowercase")]
pub struct CommunicationConfig {
    // see https://github.com/dtolnay/serde-yaml/issues/298
    // (reason both enum fields go through serde_yaml's `singleton_map` helper)
    #[serde(
        default,
        with = "serde_yaml::with::singleton_map",
        rename = "_unstable_local"
    )]
    #[schemars(with = "String")]
    // Local communication mechanism; `LocalCommunicationConfig::default()` when omitted.
    pub local: LocalCommunicationConfig,
    #[serde(
        default,
        with = "serde_yaml::with::singleton_map",
        rename = "_unstable_remote"
    )]
    #[schemars(with = "String")]
    // Remote communication mechanism; `RemoteCommunicationConfig::default()` when omitted.
    pub remote: RemoteCommunicationConfig,
}
229
230#[derive(Debug, Clone, Copy, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
231pub enum LocalCommunicationConfig {
232    Tcp,
233    Shmem,
234    UnixDomain,
235}
236
237impl Default for LocalCommunicationConfig {
238    fn default() -> Self {
239        Self::Tcp
240    }
241}
242
243#[derive(Debug, Serialize, Deserialize, Clone)]
244#[serde(deny_unknown_fields, rename_all = "lowercase")]
245pub enum RemoteCommunicationConfig {
246    Tcp,
247    // TODO:a
248    // Zenoh {
249    //     config: Option<serde_yaml::Value>,
250    //     prefix: String,
251    // },
252}
253
254impl Default for RemoteCommunicationConfig {
255    fn default() -> Self {
256        Self::Tcp
257    }
258}