Skip to main content

connect_zmq_types/
lib.rs

1use std::borrow::Cow;
2use std::collections::HashMap;
3use std::ffi::OsString;
4use std::fmt::Debug;
5use std::fmt::Formatter;
6use std::path::PathBuf;
7use std::str::FromStr;
8
9use base64::Engine;
10use connect_stream_types::Value;
11use connect_stream_types::ValueSeries;
12use connect_user_value::TableData;
13use connect_user_value::UserValue;
14use serde::Deserialize;
15use serde::Serialize;
16use serde::de::Error;
17use uuid::Uuid;
18
19// STREAMING AND COMMANDS
20
/// A single streamed packet: either stream data or a control command.
///
/// Deserialization is `untagged`: serde tries `Data` first and falls back to
/// `Command`, so the variant order below is significant.
#[derive(Serialize, Deserialize, Debug)]
#[serde(untagged)]
pub enum StreamPacket {
    /// Data points for a stream (single- or multi-timestamp).
    Data(StreamData),
    /// A control command, identified by its `"command"` tag.
    Command(StreamCommand),
}
27
/// Stream data in one of two shapes.
///
/// Untagged: `Single` (requires a scalar `timestamp`) is tried before `Multi`
/// (requires a `timestamps` array), so the variant order is significant.
#[derive(Serialize, Deserialize, Debug)]
#[serde(untagged)]
pub enum StreamData {
    Single(SingleTimestampStreamData),
    Multi(MultiTimestampStreamData),
}
34
/// Stream data that shares one timestamp: a single `value`, or one value per
/// channel in `multi_values`.
///
/// Only `stream_id` and `timestamp` are required on input; every other field
/// is `#[serde(default)]`. Fields serialize in declaration order.
#[derive(Clone, Serialize, Deserialize, Debug)]
pub struct SingleTimestampStreamData {
    /// Stream this data belongs to.
    pub stream_id: String,
    pub timestamp: u64, // absolute timestamp in nanoseconds
    /// Identifies the origin of this stream for metrics purposes. When present,
    /// stream ingress rollups are tracked for this source via
    /// `connect_metrics::ConnectMetrics::track_stream_ingress_rollup`. When
    /// absent, rollups are tracked under the `"unknown"` source.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub source: Option<String>,
    #[serde(default)]
    pub value: Value, // For simple 2D plots
    #[serde(default)]
    pub channel_name: Option<String>,
    #[serde(default)]
    pub channel_unit: Option<String>,
    #[serde(default)]
    pub multi_values: Vec<Value>, // For 2D plots with many lines
    /// NOTE(review): presumably parallel to `multi_values` (same length and
    /// order) — not validated here; confirm with producers.
    #[serde(default)]
    pub channel_names: Vec<String>, // For showing multi-channel data on a single plot
    /// Unit strings; presumably keyed by channel name — TODO confirm.
    #[serde(default)]
    pub channel_units: HashMap<String, String>,
}
58
/// Stream data carrying a batch of samples at several timestamps.
///
/// Only `stream_id` and `timestamps` are required on input; every other field
/// is `#[serde(default)]`. Fields serialize in declaration order.
/// NOTE(review): presumably each `ValueSeries` has one entry per element of
/// `timestamps`; this type does not validate that — confirm with producers.
#[derive(Clone, Serialize, Deserialize, Debug)]
pub struct MultiTimestampStreamData {
    /// Stream this data belongs to.
    pub stream_id: String,
    pub timestamps: Vec<u64>, // absolute timestamps in nanoseconds
    /// Identifies the origin of this stream for metrics purposes. When present,
    /// stream ingress rollups are tracked for this source. When absent, rollups
    /// are tracked under the `"unknown"` source.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub source: Option<String>,
    #[serde(default)]
    pub values: ValueSeries,
    #[serde(default)]
    pub channel_name: Option<String>,
    #[serde(default)]
    pub channel_unit: Option<String>,
    #[serde(default)]
    pub multi_values: Vec<ValueSeries>, // multi-channel data (each entry corresponds to points for a given channel)
    #[serde(default)]
    pub channel_names: Vec<String>,
    /// Unit strings; presumably keyed by channel name — TODO confirm.
    #[serde(default)]
    pub channel_units: HashMap<String, String>,
}
81
/// A control command carried by a [`StreamPacket`].
///
/// Internally tagged: the JSON object carries a `"command"` field holding the
/// snake_case variant name next to the variant's own fields, e.g.
/// `{"command": "clear", "stream_id": "s1"}` or `{"command": "clear_all_streams"}`.
#[derive(Serialize, Deserialize, Debug)]
#[serde(tag = "command", rename_all = "snake_case")]
pub enum StreamCommand {
    /// Targets a single stream by `stream_id`.
    Clear {
        stream_id: String,
    },
    ClearAllStreams,
    /// `config` is arbitrary JSON; its shape is not constrained by this type.
    SetChannelScaling {
        id: String,
        channels: Vec<String>,
        config: serde_json::Value,
    },
    SetValue {
        id: String,
        value: UserValue,
    },
    ClearValues {
        ids: Vec<String>,
    },
    ClearAllValues,
    /// Newtype variant: the [`SetOutputCommand`] fields, including its own
    /// `"format"` tag, are merged into the same object as `"command"`.
    SetOutput(SetOutputCommand),
    SetTestWorkflowState {
        test_workflow_script_identifier: ScriptIdentifier,
        state: TestWorkflowState,
    },
    SetTestWorkflowRecords {
        test_workflow_script_identifier: ScriptIdentifier,
        records: TestWorkflowRecords,
    },
    /// Carries the same four flags as [`TestWorkflowFlags`].
    SetTestWorkflowFlags {
        test_workflow_script_identifier: ScriptIdentifier,
        can_submit: bool,
        can_individual_rerun: bool,
        can_rerun_full_workflow: bool,
        can_run_before_full_workflow: bool,
    },
    ClearCreatedModels,
    /// Each entry in `configs` is arbitrary JSON; not constrained by this type.
    CreateModels {
        configs: Vec<serde_json::Value>,
    },
    SetLogFile {
        script_identifier: ScriptIdentifier,
        file_path: PathBuf,
    },
    /// Pushes one image frame into the buffer named `buffer_name`.
    ///
    /// `data` is flattened: the pixel bytes appear at the top level of the
    /// command object as either `data` (a byte list) or `data_base64` (a base64
    /// string). Unlike `PostPointCloud`, `timestamp` is an `f64` here; its
    /// unit is not established by this type — TODO confirm.
    /// NOTE(review): `frame_width` is presumably in pixels — confirm.
    SetImageFrame {
        buffer_name: String,
        timestamp: f64,
        frame_width: usize,
        #[serde(flatten)]
        data: ByteListOrBase64,
    },
    ClearImageFrameBuffer {
        buffer_name: String,
    },
    PostPointCloud {
        pointcloud_name: String,
        timestamp: u64,
        points: Vec<SamplePoint>,
    },
    ClearPointCloud {
        pointcloud_name: String,
    },
    /// The target point cloud is named inside `scan` (`LidarScan::pointcloud_name`).
    CreateLidarScan {
        scan: LidarScan,
    },
    SetDropdownOptions {
        id: String,
        options: Vec<String>,
    },
    /// `tooltip` and `color` may be omitted. The unit of `time` is not
    /// established by this type — TODO confirm.
    AddTimelineMilestone {
        time: f64,
        #[serde(default)]
        tooltip: Option<String>,
        #[serde(default)]
        color: Option<String>,
    },
    /// `level` is a free-form string (accepted values are not constrained
    /// here); `duration_seconds` may be omitted.
    SendNotification {
        message: String,
        level: String,
        #[serde(default)]
        duration_seconds: Option<u64>,
    },
    /// `error_code` may be omitted on input and is not serialized when `None`.
    DeviceStreamingError {
        message: String,
        #[serde(default, skip_serializing_if = "Option::is_none")]
        error_code: Option<i32>,
        device_driver: String,
        device_id: String,
        stopped: bool,
    },
    UpdatePaneVisibility {
        id: Uuid,
        action: PaneVisbilityAction,
    },
    /// `param_schema` is a JSON Schema describing the command's parameters.
    /// NOTE(review): `param_default` presumably conforms to `param_schema`;
    /// nothing here enforces it.
    RegisterCommand {
        command_id: String,
        command_alias: String,
        description: Option<String>,
        param_schema: schemars::Schema,
        param_default: serde_json::Value,
    },
}
184
/// Sets a script's output, in one of two encodings.
///
/// Internally tagged by `"format"` (`"json"` or `"arrow"`). When wrapped in
/// `StreamCommand::SetOutput`, this tag sits in the same object as the
/// `"command"` tag.
#[derive(Clone, Serialize, Deserialize, Debug)]
#[serde(tag = "format", rename_all = "snake_case")]
pub enum SetOutputCommand {
    Json {
        script_identifier: ScriptIdentifier,
        value: JsonOutputValue,
    },
    /// NOTE(review): `value` is presumably an encoded Arrow payload (e.g.
    /// base64 IPC bytes); the encoding is not established here — confirm.
    Arrow {
        script_identifier: ScriptIdentifier,
        value: String,
    },
}
197
/// A JSON output value.
///
/// Untagged: serde tries `Table`, then `String`, then falls back to any JSON
/// value via `Other`, so the variant order is significant.
#[derive(Clone, Serialize, Deserialize, Debug)]
#[serde(untagged)]
pub enum JsonOutputValue {
    Table(TableData),
    String(String),
    Other(serde_json::Value),
}
205
/// Message sent over ZMQ when a user value is updated or removed.
///
/// Internally tagged by `"kind"` (`"updated"` or `"removed"`). The `Cow`
/// fields let a sender serialize borrowed data without cloning it, while a
/// receiver gets owned data.
#[derive(Serialize, Deserialize, Debug)]
#[serde(tag = "kind", rename_all = "snake_case")]
pub enum ValueUpdateMessage<'a> {
    Updated {
        id: Cow<'a, str>,
        value: Cow<'a, UserValue>,
    },
    Removed {
        id: Cow<'a, str>,
    },
}
218
// same as ConnectCommand
/// Identifies a script run: a `name`, the program to run (`command`), its
/// `arguments`, and a `max_log_lines` setting (presumably a cap on retained
/// log lines — TODO confirm).
///
/// Derives `Hash`/`Eq`, so it can be used as a map key. `Debug` is not
/// derived; the hand-written impl prints only the command line.
#[derive(Clone, Hash, PartialEq, Eq, Serialize, Deserialize)]
pub struct ScriptIdentifier {
    pub name: String,
    pub command: Cmd,
    pub arguments: Vec<OsStringOrValueId>,
    pub max_log_lines: u64,
}
227
228impl Debug for ScriptIdentifier {
229    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
230        write!(f, "'")?;
231        match &self.command {
232            Cmd::Python => write!(f, "python")?,
233            Cmd::Bash => write!(f, "bash")?,
234            Cmd::Binary(path) => write!(f, "{path:?}")?,
235        }
236        for arg in &self.arguments {
237            write!(f, " {arg:?}")?;
238        }
239        write!(f, "'")
240    }
241}
242
/// The program a [`ScriptIdentifier`] runs.
///
/// NOTE(review): serde serializes `OsString` in a platform-specific form
/// (`{"Unix": [...]}` / `{"Windows": [...]}`), so a serialized `Binary` path
/// may not round-trip across operating systems.
#[derive(Clone, Hash, PartialEq, Eq, Serialize, Deserialize)]
pub enum Cmd {
    Python,
    Bash,
    Binary(OsString), // path from config, resolved to current_dir if relative
}
249
/// A script argument: either a literal OS string, or the id of a user value
/// (presumably looked up and substituted when the script is launched — TODO
/// confirm).
///
/// Externally tagged with the variant names as written, e.g.
/// `{"ValueId": "threshold"}`. As with `Cmd::Binary`, serde's `OsString`
/// encoding is platform-specific.
#[derive(Deserialize, Serialize, Debug, Clone, PartialEq, Eq, Hash)]
pub enum OsStringOrValueId {
    OsString(OsString),
    ValueId(String),
}
255
/// Binary payload given either as a plain byte list (`data`) or as a base64
/// string (`data_base64`).
///
/// Untagged: serde tries `List` first. `StreamCommand::SetImageFrame` flattens
/// this type, so the key sits directly on the command object.
#[derive(Serialize, Deserialize, Debug)]
#[serde(untagged)]
pub enum ByteListOrBase64 {
    List { data: Vec<u8> },
    Base64 { data_base64: String },
}
262
263impl ByteListOrBase64 {
264    pub fn into_bytes(self) -> Result<Vec<u8>, base64::DecodeError> {
265        match self {
266            Self::List { data } => Ok(data),
267            Self::Base64 { data_base64 } => base64::prelude::BASE64_STANDARD.decode(&data_base64),
268        }
269    }
270}
271
272// POINTCLOUD
273
/// A single point posted to a point cloud (see `StreamCommand::PostPointCloud`).
///
/// `color` is flattened, so its `r`/`g`/`b` keys sit directly on the point
/// object rather than under a nested `color` key.
#[derive(Debug, Serialize, Deserialize, Copy, Clone)]
pub struct SamplePoint {
    pub x: f64,
    pub y: f64,
    pub z: f64,
    /// Optional per-point colour; when `None`, no colour keys are serialized.
    #[serde(flatten)]
    pub color: Option<Rgb8>,
    /// Optional integer label; presumably a class/segment id — TODO confirm.
    pub label: Option<u32>,
}
283
/// An RGB colour with one `u8` per channel.
#[derive(Debug, Serialize, Deserialize, Copy, Clone)]
pub struct Rgb8 {
    pub r: u8,
    pub g: u8,
    pub b: u8,
}
/// A lidar scan: an origin, a facing direction, and the point cloud it is
/// added to.
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
pub struct LidarScan {
    /// A timestamp for when the scan occurs.
    /// NOTE(review): the unit is not stated here; stream data in this crate
    /// uses absolute nanoseconds — confirm this matches.
    pub timestamp: u64,
    /// The starting x position for the scan
    pub x: f64,
    /// The starting y position for the scan
    pub y: f64,
    /// The starting z position for the scan
    pub z: f64,
    /// Direction the scan is facing on the X axis
    pub dir_x: f64,
    /// Direction the scan is facing on the Y axis
    pub dir_y: f64,
    /// Direction the scan is facing on the Z axis
    pub dir_z: f64,
    /// The name of the pointcloud to add the scan to
    pub pointcloud_name: String,
}
310
311// TEST WORKFLOWS
312
/// Stores resumable state for a test workflow as a string, allowing individual tests to be rerun while maintaining the context of the original workflow run.
///
/// For python tests, this is usually a base64-encoded pickled class.
///
/// `#[serde(transparent)]`: serialized as the bare string, not as an object.
#[derive(Serialize, Deserialize, Debug, Clone)]
#[serde(transparent)]
pub struct TestWorkflowState {
    pub state: String,
}
321
/// The ordered list of test records for a workflow.
///
/// `#[serde(transparent)]`: serialized as a bare JSON array of records. It
/// also derefs to `[TestRecord]`, so slice methods work on it directly.
#[derive(Serialize, Deserialize, Debug, Clone, Default)]
#[serde(transparent)]
pub struct TestWorkflowRecords {
    pub records: Vec<TestRecord>,
}
327
/// Capability flags for a test workflow, holding the same four booleans that
/// `StreamCommand::SetTestWorkflowFlags` carries.
///
/// Derives `PartialEq`/`Eq` so two flag sets can be compared directly (for
/// example, to tell whether an incoming update changes anything).
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct TestWorkflowFlags {
    pub can_submit: bool,
    pub can_individual_rerun: bool,
    pub can_rerun_full_workflow: bool,
    pub can_run_before_full_workflow: bool,
}
335
336impl std::ops::Deref for TestWorkflowRecords {
337    type Target = [TestRecord];
338    fn deref(&self) -> &Self::Target {
339        &self.records
340    }
341}
342
343impl TestWorkflowRecords {
344    pub fn clear_outputs(&mut self) {
345        for record in &mut self.records {
346            record.status = None;
347            record.output.clear();
348        }
349    }
350}
351
/// One test's entry in a workflow's records.
///
/// `status`, `start_time` and `end_time` are omitted from serialized output
/// when `None`. `status` is parsed case-insensitively by `PassFail`'s
/// hand-written `Deserialize` impl.
#[derive(Deserialize, Serialize, Debug, Clone)]
pub struct TestRecord {
    /// Identifies the test this record belongs to.
    pub test: String,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub status: Option<PassFail>,
    /// Free-form output text; emptied by `TestWorkflowRecords::clear_outputs`.
    pub output: String,
    /// Start/end times as strings; their format is not fixed by this type.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub start_time: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub end_time: Option<String>,
}
363
364#[derive(Serialize, Debug, Clone, Copy)]
365#[serde(rename_all = "snake_case")]
366pub enum PassFail {
367    Pass,
368    Fail,
369    Error,
370    Skip,
371}
372
373impl std::ops::BitOr for PassFail {
374    type Output = Self;
375    fn bitor(self, rhs: Self) -> Self::Output {
376        match (self, rhs) {
377            (Self::Error, _) | (_, Self::Error) => Self::Error,
378            (Self::Fail, _) | (_, Self::Fail) => Self::Fail,
379            (Self::Pass, _) | (_, Self::Pass) => Self::Pass,
380            (Self::Skip, Self::Skip) => Self::Skip,
381        }
382    }
383}
384
/// Error returned when parsing a `PassFail` from a string that is not one of
/// `pass`, `fail`, `error` or `skip` (compared case-insensitively).
#[derive(thiserror::Error, Debug, Clone, Copy)]
#[error("input string does not match any pass/fail variants")]
pub struct PassFailParseError;
388
389impl FromStr for PassFail {
390    type Err = PassFailParseError;
391    fn from_str(s: &str) -> Result<Self, Self::Err> {
392        Ok(match s.to_lowercase().as_str() {
393            "pass" => Self::Pass,
394            "fail" => Self::Fail,
395            "error" => Self::Error,
396            "skip" => Self::Skip,
397            _ => return Err(PassFailParseError),
398        })
399    }
400}
401
402impl<'de> Deserialize<'de> for PassFail {
403    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
404    where
405        D: serde::Deserializer<'de>,
406    {
407        let string = std::borrow::Cow::<'de, str>::deserialize(deserializer)?;
408        let value: Self = string.parse().map_err(|_| {
409            D::Error::custom(format_args!(
410                "input string `{string}` does not match any pass/fail variants"
411            ))
412        })?;
413        Ok(value)
414    }
415}
416
417#[derive(Deserialize, Serialize, Debug)]
418#[serde(rename_all = "snake_case")]
419pub enum PaneVisbilityAction {
420    Hide,
421    Focus,
422    Restore,
423}
424
#[cfg(test)]
mod tests {
    use super::MultiTimestampStreamData;
    use super::PassFail;
    use super::SingleTimestampStreamData;

    #[test]
    fn missing_stream_data_source_defaults_to_none() {
        let data: SingleTimestampStreamData = serde_json::from_str(
            r#"{
                "stream_id": "stream-1",
                "timestamp": 42
            }"#,
        )
        .expect("deserializing stream data should succeed");

        assert_eq!(data.source, None);
    }

    #[test]
    fn present_stream_data_source_deserializes() {
        let data: SingleTimestampStreamData = serde_json::from_str(
            r#"{
                "stream_id": "stream-1",
                "timestamp": 42,
                "source": "connect_python"
            }"#,
        )
        .expect("deserializing stream data should succeed");

        assert_eq!(data.source.as_deref(), Some("connect_python"));
    }

    #[test]
    fn none_source_is_omitted_from_serialization() {
        let data: SingleTimestampStreamData =
            serde_json::from_str(r#"{ "stream_id": "s", "timestamp": 0 }"#)
                .expect("deserializing should succeed");

        // Check the top-level keys instead of searching the raw JSON text: a
        // substring search would also trip on "source" appearing inside some
        // other serialized value.
        let json = serde_json::to_value(&data).expect("serializing should succeed");
        let object = json
            .as_object()
            .expect("stream data should serialize to a JSON object");
        assert!(
            !object.contains_key("source"),
            "None source should be omitted from JSON"
        );
    }

    #[test]
    fn multi_timestamp_source_round_trips() {
        let data: MultiTimestampStreamData = serde_json::from_str(
            r#"{
                "stream_id": "stream-1",
                "timestamps": [],
                "source": "connect_python"
            }"#,
        )
        .expect("deserializing multi-timestamp stream data should succeed");
        assert_eq!(data.source.as_deref(), Some("connect_python"));

        let json = serde_json::to_value(&data).expect("serializing should succeed");
        assert_eq!(
            json.get("source").and_then(serde_json::Value::as_str),
            Some("connect_python")
        );
    }

    #[test]
    fn pass_fail_parses_case_insensitively() {
        assert!(matches!("PASS".parse::<PassFail>(), Ok(PassFail::Pass)));
        assert!(matches!("Skip".parse::<PassFail>(), Ok(PassFail::Skip)));
        assert!("passed".parse::<PassFail>().is_err());
        assert!(" pass".parse::<PassFail>().is_err());
    }

    #[test]
    fn pass_fail_serde_round_trip_and_errors() {
        let parsed: PassFail =
            serde_json::from_str(r#""Error""#).expect("mixed-case variant should deserialize");
        assert!(matches!(parsed, PassFail::Error));

        let written = serde_json::to_string(&PassFail::Skip).expect("serializing should succeed");
        assert_eq!(written, r#""skip""#);

        let err = serde_json::from_str::<PassFail>(r#""bogus""#)
            .expect_err("unknown variant should be rejected");
        assert!(
            err.to_string().contains("`bogus`"),
            "error should name the offending input: {err}"
        );
    }

    #[test]
    fn pass_fail_bitor_keeps_most_severe() {
        assert!(matches!(PassFail::Skip | PassFail::Pass, PassFail::Pass));
        assert!(matches!(PassFail::Pass | PassFail::Fail, PassFail::Fail));
        assert!(matches!(PassFail::Fail | PassFail::Error, PassFail::Error));
        assert!(matches!(PassFail::Error | PassFail::Skip, PassFail::Error));
        assert!(matches!(PassFail::Skip | PassFail::Skip, PassFail::Skip));
    }
}