// connect_stream_types/sample_table.rs

1use std::collections::HashMap;
2use std::collections::HashSet;
3use std::ops::Deref;
4use std::ops::DerefMut;
5
6use serde::Deserialize;
7use serde::Serialize;
8
9use crate::Value;
10use crate::ValueRef;
11use crate::ValueSeries;
12
13type Result<T, E = SampleTableError> = std::result::Result<T, E>;
14
15#[derive(Clone, Serialize, Deserialize, Debug, PartialEq, Eq, Hash)]
16pub struct ChannelStreamDescriptor {
17    pub name: String,
18    #[serde(default, skip_serializing_if = "Option::is_none")]
19    pub unit: Option<String>,
20}
21
22impl ChannelStreamDescriptor {
23    pub fn new(name: String, unit: Option<String>) -> Self {
24        Self { name, unit }
25    }
26}
27
28impl FromIterator<ChannelStreamDescriptor> for HashMap<String, String> {
29    fn from_iter<T: IntoIterator<Item = ChannelStreamDescriptor>>(iter: T) -> Self {
30        iter.into_iter()
31            .filter_map(|channel| channel.unit.map(|unit| (channel.name, unit)))
32            .collect::<HashMap<_, _>>()
33    }
34}
35
36#[derive(thiserror::Error, Debug)]
37pub enum SampleTableError {
38    #[error("found {0} values, expected {1}")]
39    RowCountMismatch(usize, usize),
40    #[error("number of channels ({0}) must be equal to the number of value series ({1})")]
41    ColumnCountMismatch(usize, usize),
42    #[error("channel name collision: {0}")]
43    NameCollision(String),
44    #[error("channel not found: {0}")]
45    ChannelNotFound(String),
46    #[error("error applying function to value: {0}")]
47    ApplyError(String),
48}
49
50/// Raw intermediate representation used only for deserialization.
51/// Serde deserializes into this, then [`TryFrom`] validates dimensions
52/// before producing a [`SampleTable`].
53#[derive(Deserialize)]
54struct RawSampleTable {
55    timestamps: Vec<u64>,
56    channels: Vec<ChannelStreamDescriptor>,
57    values: Vec<ValueSeries>,
58}
59
60impl TryFrom<RawSampleTable> for SampleTable {
61    type Error = SampleTableError;
62
63    fn try_from(raw: RawSampleTable) -> Result<Self> {
64        SampleTable::new(raw.timestamps, raw.channels, raw.values)
65    }
66}
67
68/// A table of sample values, with timestamps indexing rows and channels indexing columns.
69/// Values are stored in column-major order.
70#[derive(Clone, Debug, Default, Serialize, Deserialize)]
71#[serde(try_from = "RawSampleTable")]
72pub struct SampleTable {
73    timestamps: Vec<u64>,
74    channels: Vec<ChannelStreamDescriptor>,
75    values: Vec<ValueSeries>,
76}
77
78/// Channel information and data referenced from a [`SampleTable`] for a given timestamp.
79#[derive(Clone, Debug, PartialEq)]
80pub struct RowRef<'values>(
81    /// The timestamp of the row.
82    pub u64,
83    /// The channel and value for each column in the row.
84    pub Box<[(&'values ChannelStreamDescriptor, ValueRef<'values>)]>,
85);
86
87impl SampleTable {
88    fn verify_dimensions(&self) -> Result<()> {
89        if self.channels.len() != self.values.len() {
90            return Err(SampleTableError::ColumnCountMismatch(
91                self.channels.len(),
92                self.values.len(),
93            ));
94        }
95
96        for channel in &self.values {
97            if channel.len() != self.timestamps.len() {
98                return Err(SampleTableError::RowCountMismatch(
99                    channel.len(),
100                    self.timestamps.len(),
101                ));
102            }
103        }
104
105        Ok(())
106    }
107
108    fn check_channel_names_unique(channels: &[ChannelStreamDescriptor]) -> Result<()> {
109        let mut seen_names = HashSet::new();
110
111        for channel in channels {
112            if !seen_names.insert(&channel.name) {
113                return Err(SampleTableError::NameCollision(channel.name.clone()));
114            }
115        }
116
117        Ok(())
118    }
119
120    /// Create a sample table from a list of timestamps, channels, and value series.
121    ///
122    /// Each [`ValueSeries`] corresponds to the index-matched channel in `channels`,
123    /// and must have the same length as `timestamps`.
124    pub fn new(
125        timestamps: Vec<u64>,
126        channels: Vec<ChannelStreamDescriptor>,
127        values: Vec<ValueSeries>,
128    ) -> Result<Self> {
129        Self::check_channel_names_unique(&channels)?;
130
131        if values.len() != channels.len() {
132            return Err(SampleTableError::ColumnCountMismatch(
133                channels.len(),
134                values.len(),
135            ));
136        }
137
138        for series in &values {
139            if series.len() != timestamps.len() {
140                return Err(SampleTableError::RowCountMismatch(
141                    series.len(),
142                    timestamps.len(),
143                ));
144            }
145        }
146
147        Ok(Self {
148            timestamps,
149            channels,
150            values,
151        })
152    }
153
154    /// Create a sample table with a single timestamp and single values for a single channel.
155    pub fn from_single_channel(
156        timestamp: u64,
157        channel: ChannelStreamDescriptor,
158        value: Value,
159    ) -> Self {
160        Self {
161            timestamps: vec![timestamp],
162            channels: vec![channel],
163            values: vec![value.into()],
164        }
165    }
166
167    /// Create a sample table with a single timestamp and single values for multiple channels.
168    pub fn from_multiple_channels(
169        timestamp: u64,
170        channels: Vec<ChannelStreamDescriptor>,
171        values: Vec<Value>,
172    ) -> Result<Self> {
173        Self::check_channel_names_unique(&channels)?;
174
175        if values.len() != channels.len() {
176            return Err(SampleTableError::ColumnCountMismatch(
177                values.len(),
178                channels.len(),
179            ));
180        }
181
182        Ok(Self {
183            timestamps: vec![timestamp],
184            values: values.into_iter().map(Value::into_series).collect(),
185            channels,
186        })
187    }
188
189    /// Lazily-evaluated iterator over the rows of the sample table.
190    pub fn iter_rows<'values>(&'values self) -> Result<impl Iterator<Item = RowRef<'values>>> {
191        self.verify_dimensions()?;
192
193        Ok(self.timestamps.iter().enumerate().map(|(row, timestamp)| {
194            RowRef(
195                *timestamp,
196                self.channels
197                    .iter()
198                    .enumerate()
199                    .map(|(col, channel)| {
200                        #[expect(clippy::expect_used, reason = "row and col are checked above")]
201                        self.values
202                            .get(col)
203                            .and_then(|v| v.get(row))
204                            .map(|v| (channel, v))
205                            .expect("value series length check failed after invariants checked")
206                    })
207                    .collect::<Vec<_>>()
208                    .into_boxed_slice(),
209            )
210        }))
211    }
212
213    /// The channels in the sample table.
214    pub fn channels(&self) -> &[ChannelStreamDescriptor] {
215        &self.channels
216    }
217
218    /// The timestamps in the sample table.
219    pub fn timestamps(&self) -> &[u64] {
220        &self.timestamps
221    }
222
223    /// The columns in the sample table.
224    pub fn columns(&self) -> &[ValueSeries] {
225        &self.values
226    }
227
228    /// The values for a given channel.
229    pub fn get_channel_values(&self, channel: &ChannelStreamDescriptor) -> Option<&ValueSeries> {
230        let offset = self.channels.iter().position(|c| c.name == channel.name)?;
231        self.values.get(offset)
232    }
233
234    /// The values for a given channel and timestamp.
235    pub fn get_timestamped_channel_values(
236        &self,
237        channel: &ChannelStreamDescriptor,
238    ) -> Option<impl Iterator<Item = (u64, ValueRef<'_>)>> {
239        let values = self.get_channel_values(channel)?;
240        Some(self.timestamps.iter().cloned().zip(values))
241    }
242
243    /// The number of points in the stream.
244    pub fn num_points(&self) -> usize {
245        self.timestamps()
246            .len()
247            .saturating_mul(self.channels().len())
248    }
249
250    /// Apply a scaling function to a channel's values in place.
251    ///
252    /// An error is returned if the channel specified is not present.
253    ///
254    /// If any individual scaling operation fails, no further values are computed and the error is returned.
255    pub fn scale_channel(
256        &mut self,
257        channel: &ChannelStreamDescriptor,
258        f: impl FnMut(ValueRef<'_>) -> Result<f64, String>,
259    ) -> Result<()> {
260        let offset = self
261            .channels
262            .iter()
263            .position(|c| c.name == channel.name)
264            .ok_or(SampleTableError::ChannelNotFound(channel.name.clone()))?;
265
266        let Some(channel_values) = self.values.get_mut(offset) else {
267            return Err(SampleTableError::ChannelNotFound(channel.name.clone()));
268        };
269
270        let apply_result = channel_values
271            .iter()
272            .map(f)
273            .collect::<Result<Vec<_>, String>>()
274            .map_err(|e| SampleTableError::ApplyError(e.to_string()))?;
275
276        *channel_values = apply_result.into();
277
278        Ok(())
279    }
280}
281
282#[derive(Clone, Serialize, Deserialize, Debug)]
283pub struct StreamData {
284    /// Unique identifier for this stream.
285    pub stream_id: String,
286
287    /// Origin of this stream data, used for attribution in metrics.
288    ///
289    /// This remains optional for wire compatibility. When absent, ingress
290    /// rollups are still tracked under an explicit `"unknown"` source.
291    #[serde(default, skip_serializing_if = "Option::is_none")]
292    pub source: Option<String>,
293
294    /// The actual data being streamed.
295    pub samples: SampleTable,
296}
297
298impl Deref for StreamData {
299    type Target = SampleTable;
300    fn deref(&self) -> &Self::Target {
301        &self.samples
302    }
303}
304
305impl DerefMut for StreamData {
306    fn deref_mut(&mut self) -> &mut Self::Target {
307        &mut self.samples
308    }
309}
310
#[cfg(test)]
mod tests {
    use super::*;

    /// Shorthand for building a channel descriptor from string slices.
    fn channel(name: &str, unit: Option<&str>) -> ChannelStreamDescriptor {
        ChannelStreamDescriptor {
            name: name.to_string(),
            unit: unit.map(ToString::to_string),
        }
    }

    #[test]
    fn sample_table_creation_rejects_dimension_mismatches() {
        let mismatched_rows = SampleTable::new(
            vec![1, 2, 3],
            vec![channel("ch0", None), channel("ch1", None)],
            vec![vec![1.0, 2.0].into(), vec![3.0, 4.0].into()],
        );
        assert!(
            matches!(
                mismatched_rows.as_ref().unwrap_err(),
                SampleTableError::RowCountMismatch(..)
            ),
            "expected RowCountMismatch error, got: {mismatched_rows:?}"
        );

        // The single series has the right row count, so the column count is
        // the only thing wrong and the result does not depend on check order.
        let mismatched_columns = SampleTable::new(
            vec![1, 2, 3],
            vec![channel("ch0", None), channel("ch1", None)],
            vec![vec![1.0, 2.0, 3.0].into()],
        );
        assert!(matches!(
            mismatched_columns.unwrap_err(),
            SampleTableError::ColumnCountMismatch(2, 1)
        ));
    }

    #[test]
    fn sample_table_creation_fails_with_duplicate_channel_names() {
        // Dimensions are consistent, so the duplicate name is the only defect.
        let result = SampleTable::new(
            vec![1, 2, 3],
            vec![channel("ch0", None), channel("ch0", None)],
            vec![vec![1.0, 2.0, 3.0].into(), vec![4.0, 5.0, 6.0].into()],
        );
        assert!(result.is_err());
        assert!(matches!(
            result.unwrap_err(),
            SampleTableError::NameCollision(..)
        ));
    }

    #[test]
    fn from_multiple_channels_validates_input() {
        let table = SampleTable::from_multiple_channels(
            7,
            vec![channel("a", None), channel("b", None)],
            vec![Value::Double(1.0), Value::Double(2.0)],
        )
        .expect("creating sample table should succeed");

        assert_eq!(table.timestamps().to_vec(), vec![7u64]);
        assert_eq!(table.channels().len(), 2);
        assert_eq!(table.columns().len(), 2);
        assert_eq!(table.num_points(), 2);

        let mismatched_columns = SampleTable::from_multiple_channels(
            7,
            vec![channel("a", None), channel("b", None)],
            vec![Value::Double(1.0)],
        );
        assert!(matches!(
            mismatched_columns.unwrap_err(),
            SampleTableError::ColumnCountMismatch(..)
        ));

        let duplicate_names = SampleTable::from_multiple_channels(
            7,
            vec![channel("a", None), channel("a", None)],
            vec![Value::Double(1.0), Value::Double(2.0)],
        );
        assert!(matches!(
            duplicate_names.unwrap_err(),
            SampleTableError::NameCollision(ref name) if name == "a"
        ));
    }

    #[test]
    fn sample_table_column_lookup_by_name() {
        let table = SampleTable::new(
            vec![1, 2, 3],
            vec![channel("ch0", None)],
            vec![vec![1.0, 2.0, 3.0].into()],
        )
        .expect("creating sample table should succeed");

        assert!(table.get_channel_values(&channel("ch0", None)).is_some());
        assert!(
            table
                .get_channel_values(&channel("missing", None))
                .is_none()
        );
    }

    #[test]
    fn missing_stream_data_source_defaults_to_unknown() {
        // On the wire type a missing `source` is `None`. Per the `StreamData`
        // docs, the "unknown" label is applied when ingress rollups are
        // tracked, not during deserialization.
        let data: StreamData = serde_json::from_str(
            r#"{
                "stream_id": "stream-1",
                "samples": {
                    "timestamps": [42],
                    "channels": [{
                        "name": "temp"
                    }],
                    "values": [[1.0]]
                }
            }"#,
        )
        .expect("deserializing stream data should succeed");

        assert!(data.source.is_none());
    }

    #[test]
    fn single_frame_round_trips() {
        let data = StreamData {
            stream_id: "s1".to_string(),
            source: Some("test".to_string()),
            samples: SampleTable::from_single_channel(
                100,
                channel("ch0", Some("V")),
                Value::Double(1.5),
            ),
        };

        let json = serde_json::to_string(&data).expect("serialize");
        let parsed: StreamData = serde_json::from_str(&json).expect("deserialize");

        assert_eq!(parsed.stream_id, "s1");
        assert_eq!(parsed.source.unwrap(), "test");

        let SampleTable {
            timestamps,
            channels,
            values,
        } = parsed.samples;

        assert_eq!(timestamps.len(), 1, "expected 1 timestamp");
        assert_eq!(channels.len(), 1, "expected 1 channel");
        assert_eq!(values.len(), 1, "expected 1 value");

        assert_eq!(
            *timestamps.first().expect("timestamp should be present"),
            100
        );
        assert_eq!(
            channels.first().expect("channel should be present").name,
            "ch0"
        );
        assert_eq!(
            channels
                .first()
                .expect("channel should be present")
                .unit
                .as_deref(),
            Some("V")
        );
        assert_eq!(values.first().expect("value should be present").len(), 1);
        assert_eq!(
            values
                .first()
                .expect("value should be present")
                .get(0)
                .expect("value should be present"),
            Value::Double(1.5)
        );
    }

    #[test]
    fn tabular_frame_round_trips() {
        let data = StreamData {
            stream_id: "daq".to_string(),
            source: Some("labjack_t7".to_string()),
            samples: SampleTable::new(
                vec![100, 200, 300],
                vec![channel("ch0", Some("V")), channel("ch1", None)],
                vec![
                    vec![1.0f64, 1.1f64, 1.2f64].into(),
                    vec![2.0f64, 2.1f64, 2.2f64].into(),
                ],
            )
            .expect("creating sample table should succeed"),
        };

        let json = serde_json::to_string(&data).expect("serialize");
        let parsed: StreamData = serde_json::from_str(&json).expect("deserialize");

        assert_eq!(parsed.stream_id, "daq");
        assert_eq!(parsed.source.unwrap(), "labjack_t7");

        let SampleTable {
            timestamps,
            channels,
            values,
        } = parsed.samples;

        assert_eq!(timestamps, &[100, 200, 300]);
        assert_eq!(values.iter().map(|v| v.len()).sum::<usize>(), 3 * 2);
        assert!(values.iter().map(|v| v.len()).all(|len| len == 3));
        assert_eq!(channels.len(), 2);

        assert_eq!(
            channels.first().expect("column should be present").name,
            "ch0"
        );

        assert_eq!(
            channels
                .first()
                .expect("column should be present")
                .unit
                .as_deref(),
            Some("V")
        );
    }

    #[test]
    fn serialized_json_uses_expected_field_names() {
        let data = StreamData {
            stream_id: "s".to_string(),
            source: Some("t".to_string()),
            samples: SampleTable::new(vec![1], vec![channel("c", None)], vec![vec![0.0f64].into()])
                .expect("creating sample table should succeed"),
        };

        let json = serde_json::to_string(&data).expect("serialize");

        assert!(
            json.contains(r#""samples""#),
            "missing 'samples' key: {json}"
        );
        assert!(
            json.contains(r#""timestamps""#),
            "missing 'timestamps' key: {json}"
        );
        assert!(
            json.contains(r#""channels""#),
            "missing 'channels' key: {json}"
        );
    }

    #[test]
    fn deserialization_rejects_dimension_mismatches() {
        let bad_payloads = [
            (
                "mismatched row count",
                r#"{
                    "stream_id": "stream-1",
                    "samples": {
                        "timestamps": [1, 2],
                        "channels": [{"name": "ch"}],
                        "values": [[1.0]]
                    }
                }"#,
            ),
            (
                "mismatched column count",
                r#"{
                    "stream_id": "stream-1",
                    "samples": {
                        "timestamps": [1],
                        "channels": [{"name": "a"}, {"name": "b"}],
                        "values": [[1.0]]
                    }
                }"#,
            ),
        ];

        for (name, payload) in bad_payloads {
            let result = serde_json::from_str::<StreamData>(payload);
            assert!(
                result.is_err(),
                "deserializing sample table with {name} should fail"
            );
        }
    }

    #[test]
    fn scale_channel_doubles() {
        let channel = channel("voltage", Some("V"));

        let mut table = SampleTable::new(
            vec![1, 2, 3],
            vec![channel.clone()],
            vec![vec![1.0, 2.0, 3.0].into()],
        )
        .expect("creating sample table should succeed");

        table
            .scale_channel(&channel, |v| match v {
                ValueRef::Double(v) => Ok(*v * 2.0),
                _ => Err("unexpected type".to_string()),
            })
            .expect("scale_channel should succeed");

        let SampleTable { values, .. } = table;
        let series = values.first().expect("should have one series");
        assert_eq!(series.get(0), Some(ValueRef::Double(&2.0)));
        assert_eq!(series.get(1), Some(ValueRef::Double(&4.0)));
        assert_eq!(series.get(2), Some(ValueRef::Double(&6.0)));
    }

    #[test]
    fn scale_channel_integers_become_doubles() {
        let channel = channel("count", None);

        let mut table = SampleTable::new(
            vec![1, 2],
            vec![channel.clone()],
            vec![vec![10i64, 20i64].into()],
        )
        .expect("creating sample table should succeed");

        table
            .scale_channel(&channel, |v| match v {
                ValueRef::Integer(v) => Ok(*v as f64 * 0.5),
                _ => Err("unexpected type".to_string()),
            })
            .expect("scale_channel should succeed");

        let SampleTable { values, .. } = table;
        let series = values.first().expect("should have one series");
        // Integer channel should now be Double after scaling
        assert_eq!(series.get(0), Some(ValueRef::Double(&5.0)));
        assert_eq!(series.get(1), Some(ValueRef::Double(&10.0)));
    }

    #[test]
    fn scale_channel_not_found() {
        let mut table = SampleTable::new(
            vec![1],
            vec![channel("exists", None)],
            vec![vec![1.0].into()],
        )
        .expect("creating sample table should succeed");

        let missing = channel("missing", None);

        let err = table
            .scale_channel(&missing, |v| match v {
                ValueRef::Double(v) => Ok(*v),
                _ => Err("unexpected type".to_string()),
            })
            .expect_err("scale_channel should fail for missing channel");

        assert!(
            matches!(err, SampleTableError::ChannelNotFound(ref name) if name == "missing"),
            "expected ChannelNotFound error, got: {err}"
        );
    }

    #[test]
    fn scale_channel_closure_error_propagates() {
        let channel = channel("ch", None);

        let mut table = SampleTable::new(
            vec![1, 2, 3],
            vec![channel.clone()],
            vec![vec![1.0, 2.0, 3.0].into()],
        )
        .expect("creating sample table should succeed");

        const ERR_MSG: &str = "value too large";

        let err = table
            .scale_channel(&channel, |v| match v {
                ValueRef::Double(v) if *v > 1.5 => Err(ERR_MSG.to_string()),
                ValueRef::Double(v) => Ok(*v),
                _ => Err("unexpected type".to_string()),
            })
            .expect_err("scale_channel should propagate closure error");

        assert!(
            matches!(err, SampleTableError::ApplyError(ref msg) if msg.contains(ERR_MSG)),
            "expected ApplyError, got: {err}"
        );
    }
}