dora_message/
config.rs

1use core::fmt;
2use std::{
3    collections::{BTreeMap, BTreeSet},
4    time::Duration,
5};
6
7use once_cell::sync::OnceCell;
8use schemars::JsonSchema;
9use serde::{Deserialize, Serialize};
10
11pub use crate::id::{DataId, NodeId, OperatorId};
12
13/// Contains the input and output configuration of the node.
14#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, JsonSchema)]
15pub struct NodeRunConfig {
16    /// Inputs for the nodes as a map from input ID to `node_id/output_id`.
17    ///
18    /// e.g.
19    ///
20    /// inputs:
21    ///
22    ///   example_input: example_node/example_output1
23    ///
24    #[serde(default)]
25    pub inputs: BTreeMap<DataId, Input>,
26    /// List of output IDs.
27    ///
28    /// e.g.
29    ///
30    /// outputs:
31    ///
32    ///  - output_1
33    ///
34    ///  - output_2
35    #[serde(default)]
36    pub outputs: BTreeSet<DataId>,
37}
38
39#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, JsonSchema)]
40#[serde(from = "InputDef", into = "InputDef")]
41pub struct Input {
42    pub mapping: InputMapping,
43    pub queue_size: Option<usize>,
44}
45
46#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, JsonSchema)]
47#[serde(untagged)]
48pub enum InputDef {
49    MappingOnly(InputMapping),
50    WithOptions {
51        source: InputMapping,
52        queue_size: Option<usize>,
53    },
54}
55
56impl From<Input> for InputDef {
57    fn from(input: Input) -> Self {
58        match input {
59            Input {
60                mapping,
61                queue_size: None,
62            } => Self::MappingOnly(mapping),
63            Input {
64                mapping,
65                queue_size,
66            } => Self::WithOptions {
67                source: mapping,
68                queue_size,
69            },
70        }
71    }
72}
73
74impl From<InputDef> for Input {
75    fn from(value: InputDef) -> Self {
76        match value {
77            InputDef::MappingOnly(mapping) => Self {
78                mapping,
79                queue_size: None,
80            },
81            InputDef::WithOptions { source, queue_size } => Self {
82                mapping: source,
83                queue_size,
84            },
85        }
86    }
87}
88
89#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, JsonSchema)]
90pub enum InputMapping {
91    Timer { interval: Duration },
92    User(UserInputMapping),
93}
94
95impl InputMapping {
96    pub fn source(&self) -> &NodeId {
97        static DORA_NODE_ID: OnceCell<NodeId> = OnceCell::new();
98
99        match self {
100            InputMapping::User(mapping) => &mapping.source,
101            InputMapping::Timer { .. } => DORA_NODE_ID.get_or_init(|| NodeId("dora".to_string())),
102        }
103    }
104}
105
106impl fmt::Display for InputMapping {
107    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
108        match self {
109            InputMapping::Timer { interval } => {
110                let duration = format_duration(*interval);
111                write!(f, "dora/timer/{duration}")
112            }
113            InputMapping::User(mapping) => {
114                write!(f, "{}/{}", mapping.source, mapping.output)
115            }
116        }
117    }
118}
119
/// Wrapper that displays a [`Duration`] in the `<unit>/<value>` notation used
/// by timer inputs, e.g. `secs/5` or `millis/100`.
///
/// Durations whose sub-second part has no whole milliseconds are written as
/// `secs/<n>`; all others as `millis/<n>`. Both forms are integer-valued, so
/// any remainder below one millisecond is dropped.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct FormattedDuration(pub Duration);

impl fmt::Display for FormattedDuration {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let duration = self.0;
        match duration.subsec_millis() {
            // No whole milliseconds beyond the full seconds: use the short form.
            0 => write!(f, "secs/{}", duration.as_secs()),
            _ => write!(f, "millis/{}", duration.as_millis()),
        }
    }
}

/// Wraps `interval` in a [`FormattedDuration`] so it can be displayed in the
/// `<unit>/<value>` notation.
pub fn format_duration(interval: Duration) -> FormattedDuration {
    FormattedDuration(interval)
}
135
136impl Serialize for InputMapping {
137    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
138    where
139        S: serde::Serializer,
140    {
141        serializer.collect_str(self)
142    }
143}
144
145impl<'de> Deserialize<'de> for InputMapping {
146    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
147    where
148        D: serde::Deserializer<'de>,
149    {
150        let string = String::deserialize(deserializer)?;
151        let (source, output) = string
152            .split_once('/')
153            .ok_or_else(|| serde::de::Error::custom("input must start with `<source>/`"))?;
154
155        let deserialized = match source {
156            "dora" => match output.split_once('/') {
157                Some(("timer", output)) => {
158                    let (unit, value) = output.split_once('/').ok_or_else(|| {
159                        serde::de::Error::custom(
160                            "timer input must specify unit and value (e.g. `secs/5` or `millis/100`)",
161                        )
162                    })?;
163                    let interval = match unit {
164                        "secs" => {
165                            let value = value.parse().map_err(|_| {
166                                serde::de::Error::custom(format!(
167                                    "secs must be an integer (got `{value}`)"
168                                ))
169                            })?;
170                            Duration::from_secs(value)
171                        }
172                        "millis" => {
173                            let value = value.parse().map_err(|_| {
174                                serde::de::Error::custom(format!(
175                                    "millis must be an integer (got `{value}`)"
176                                ))
177                            })?;
178                            Duration::from_millis(value)
179                        }
180                        other => {
181                            return Err(serde::de::Error::custom(format!(
182                                "timer unit must be either secs or millis (got `{other}`"
183                            )));
184                        }
185                    };
186                    Self::Timer { interval }
187                }
188                Some((other, _)) => {
189                    return Err(serde::de::Error::custom(format!(
190                        "unknown dora input `{other}`"
191                    )));
192                }
193                None => return Err(serde::de::Error::custom("dora input has invalid format")),
194            },
195            _ => Self::User(UserInputMapping {
196                source: source.to_owned().into(),
197                output: output.to_owned().into(),
198            }),
199        };
200
201        Ok(deserialized)
202    }
203}
204
205#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, JsonSchema)]
206pub struct UserInputMapping {
207    pub source: NodeId,
208    pub output: DataId,
209}
210
211#[derive(Debug, Default, Serialize, Deserialize, JsonSchema, Clone)]
212#[serde(deny_unknown_fields, rename_all = "lowercase")]
213pub struct CommunicationConfig {
214    // see https://github.com/dtolnay/serde-yaml/issues/298
215    #[serde(
216        default,
217        with = "serde_yaml::with::singleton_map",
218        rename = "_unstable_local"
219    )]
220    #[schemars(with = "String")]
221    pub local: LocalCommunicationConfig,
222    #[serde(
223        default,
224        with = "serde_yaml::with::singleton_map",
225        rename = "_unstable_remote"
226    )]
227    #[schemars(with = "String")]
228    pub remote: RemoteCommunicationConfig,
229}
230
231#[derive(Debug, Clone, Copy, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
232pub enum LocalCommunicationConfig {
233    Tcp,
234    Shmem,
235    UnixDomain,
236}
237
238impl Default for LocalCommunicationConfig {
239    fn default() -> Self {
240        Self::Tcp
241    }
242}
243
244#[derive(Debug, Serialize, Deserialize, Clone)]
245#[serde(deny_unknown_fields, rename_all = "lowercase")]
246pub enum RemoteCommunicationConfig {
247    Tcp,
248    // TODO:a
249    // Zenoh {
250    //     config: Option<serde_yaml::Value>,
251    //     prefix: String,
252    // },
253}
254
255impl Default for RemoteCommunicationConfig {
256    fn default() -> Self {
257        Self::Tcp
258    }
259}