use std::collections::HashMap;
use std::collections::HashSet;
use std::ops::Deref;
use std::ops::DerefMut;
use serde::Deserialize;
use serde::Serialize;
use crate::Value;
use crate::ValueRef;
use crate::ValueSeries;
/// File-local result alias whose error type defaults to [`SampleTableError`]; pass a second
/// type argument (e.g. `Result<f64, String>`) to use a different error type.
type Result<T, E = SampleTableError> = std::result::Result<T, E>;
/// Describes one channel (column) of a [`SampleTable`].
///
/// Inside a table, channels are identified by `name` alone. Lookups compare only names,
/// and the table constructors reject duplicate names.
#[derive(Clone, Serialize, Deserialize, Debug, PartialEq, Eq, Hash)]
pub struct ChannelStreamDescriptor {
    /// Channel name; must be unique within a [`SampleTable`].
    pub name: String,
    /// Optional unit label for the channel's values (the tests in this file use e.g. `"V"`).
    /// It is left out of serialized output when `None`, and becomes `None` when missing
    /// from the input.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub unit: Option<String>,
}
impl ChannelStreamDescriptor {
    /// Creates a descriptor for the channel called `name`, optionally tagged with `unit`.
    pub fn new(name: String, unit: Option<String>) -> Self {
        Self { unit, name }
    }
}
impl FromIterator<ChannelStreamDescriptor> for HashMap<String, String> {
fn from_iter<T: IntoIterator<Item = ChannelStreamDescriptor>>(iter: T) -> Self {
iter.into_iter()
.filter_map(|channel| channel.unit.map(|unit| (channel.name, unit)))
.collect::<HashMap<_, _>>()
}
}
/// Errors raised while building, validating, or transforming a [`SampleTable`].
#[derive(thiserror::Error, Debug)]
pub enum SampleTableError {
    /// A value series does not have one value per timestamp.
    /// Fields: (series length, timestamp count).
    #[error("found {0} values, expected {1}")]
    RowCountMismatch(usize, usize),
    /// The number of channels differs from the number of value series.
    /// Fields follow the message order: (channel count, value-series count).
    #[error("number of channels ({0}) must be equal to the number of value series ({1})")]
    ColumnCountMismatch(usize, usize),
    /// Two channels share the same name; holds the duplicated name.
    #[error("channel name collision: {0}")]
    NameCollision(String),
    /// No channel with the requested name exists; holds the requested name.
    #[error("channel not found: {0}")]
    ChannelNotFound(String),
    /// The caller-supplied function passed to `SampleTable::scale_channel` returned an
    /// error; holds that function's error message.
    #[error("error applying function to value: {0}")]
    ApplyError(String),
}
/// Unvalidated wire form of [`SampleTable`].
///
/// Serde deserializes into this shape first. The `TryFrom` conversion below then passes it
/// through [`SampleTable::new`], so deserialized tables are validated the same way as
/// tables built in code. The field names and order must match `SampleTable`'s serialized
/// form.
#[derive(Deserialize)]
struct RawSampleTable {
    timestamps: Vec<u64>,
    channels: Vec<ChannelStreamDescriptor>,
    values: Vec<ValueSeries>,
}
impl TryFrom<RawSampleTable> for SampleTable {
type Error = SampleTableError;
fn try_from(raw: RawSampleTable) -> Result<Self> {
SampleTable::new(raw.timestamps, raw.channels, raw.values)
}
}
/// Column-oriented table of timestamped samples.
///
/// `values[i]` holds the samples of `channels[i]`, one entry per element of `timestamps`.
/// `new` rejects duplicate channel names and non-rectangular data, and
/// `from_multiple_channels` rejects duplicate names and a channel/value count mismatch.
/// Deserialization goes through `RawSampleTable` and `SampleTable::new` (see the
/// `try_from` serde attribute), so serialized input gets the same checks as `new`.
/// `Default` yields an empty table.
#[derive(Clone, Debug, Default, Serialize, Deserialize)]
#[serde(try_from = "RawSampleTable")]
pub struct SampleTable {
    /// One timestamp per row. NOTE(review): the unit/epoch is not defined in this file;
    /// confirm with producers.
    timestamps: Vec<u64>,
    /// One descriptor per column.
    channels: Vec<ChannelStreamDescriptor>,
    /// One value series per column, parallel to `channels`.
    values: Vec<ValueSeries>,
}
/// A borrowed view of one table row, as produced by `SampleTable::iter_rows`.
///
/// Field 0 is the row's timestamp. Field 1 holds one `(channel, value)` pair per column,
/// in the table's channel order.
#[derive(Clone, Debug, PartialEq)]
pub struct RowRef<'values>(
    pub u64,
    pub Box<[(&'values ChannelStreamDescriptor, ValueRef<'values>)]>,
);
impl SampleTable {
    /// Checks that the table is rectangular.
    ///
    /// # Errors
    ///
    /// - [`SampleTableError::ColumnCountMismatch`] with `(channel count, series count)` when
    ///   there is not exactly one value series per channel.
    /// - [`SampleTableError::RowCountMismatch`] with `(series length, timestamp count)` for
    ///   the first value series whose length differs from the number of timestamps.
    fn verify_dimensions(&self) -> Result<()> {
        if self.channels.len() != self.values.len() {
            return Err(SampleTableError::ColumnCountMismatch(
                self.channels.len(),
                self.values.len(),
            ));
        }
        let expected_rows = self.timestamps.len();
        if let Some(series) = self
            .values
            .iter()
            .find(|series| series.len() != expected_rows)
        {
            return Err(SampleTableError::RowCountMismatch(
                series.len(),
                expected_rows,
            ));
        }
        Ok(())
    }

    /// Rejects channel lists in which two descriptors share a name.
    ///
    /// # Errors
    ///
    /// [`SampleTableError::NameCollision`] carrying the first repeated name.
    fn check_channel_names_unique(channels: &[ChannelStreamDescriptor]) -> Result<()> {
        let mut seen_names = HashSet::with_capacity(channels.len());
        for channel in channels {
            if !seen_names.insert(channel.name.as_str()) {
                return Err(SampleTableError::NameCollision(channel.name.clone()));
            }
        }
        Ok(())
    }

    /// Returns the column index of the channel called `name`, if there is one.
    fn channel_index(&self, name: &str) -> Option<usize> {
        self.channels.iter().position(|c| c.name == name)
    }

    /// Creates a table from column-oriented data.
    ///
    /// `values[i]` holds the samples of `channels[i]`, one per entry of `timestamps`.
    ///
    /// # Errors
    ///
    /// - [`SampleTableError::NameCollision`] if two channels share a name (checked first).
    /// - [`SampleTableError::ColumnCountMismatch`] `(channel count, series count)` if
    ///   `values.len() != channels.len()`.
    /// - [`SampleTableError::RowCountMismatch`] `(series length, timestamp count)` for the
    ///   first series whose length differs from `timestamps.len()`.
    pub fn new(
        timestamps: Vec<u64>,
        channels: Vec<ChannelStreamDescriptor>,
        values: Vec<ValueSeries>,
    ) -> Result<Self> {
        Self::check_channel_names_unique(&channels)?;
        let table = Self {
            timestamps,
            channels,
            values,
        };
        table.verify_dimensions()?;
        Ok(table)
    }

    /// Creates a one-row, one-column table holding `value` for `channel` at `timestamp`.
    pub fn from_single_channel(
        timestamp: u64,
        channel: ChannelStreamDescriptor,
        value: Value,
    ) -> Self {
        Self {
            timestamps: vec![timestamp],
            channels: vec![channel],
            values: vec![value.into()],
        }
    }

    /// Creates a one-row table at `timestamp`, with `values[i]` belonging to `channels[i]`.
    ///
    /// # Errors
    ///
    /// - [`SampleTableError::NameCollision`] if two channels share a name.
    /// - [`SampleTableError::ColumnCountMismatch`] `(channel count, value count)` if the
    ///   two lists differ in length.
    pub fn from_multiple_channels(
        timestamp: u64,
        channels: Vec<ChannelStreamDescriptor>,
        values: Vec<Value>,
    ) -> Result<Self> {
        Self::check_channel_names_unique(&channels)?;
        if values.len() != channels.len() {
            // Same field order as `new` and the error message: (channels, value series).
            return Err(SampleTableError::ColumnCountMismatch(
                channels.len(),
                values.len(),
            ));
        }
        Ok(Self {
            timestamps: vec![timestamp],
            values: values.into_iter().map(Value::into_series).collect(),
            channels,
        })
    }

    /// Iterates the table row by row, pairing each timestamp with its `(channel, value)`
    /// cells in channel order.
    ///
    /// # Errors
    ///
    /// The table's dimensions are checked again before iterating, so this returns the same
    /// dimension errors as `verify_dimensions` if the table is not rectangular.
    pub fn iter_rows<'values>(&'values self) -> Result<impl Iterator<Item = RowRef<'values>>> {
        self.verify_dimensions()?;
        Ok(self.timestamps.iter().enumerate().map(|(row, &timestamp)| {
            // `verify_dimensions` guarantees one series per channel, so `zip` drops nothing.
            let cells = self
                .channels
                .iter()
                .zip(&self.values)
                .map(|(channel, series)| {
                    #[expect(clippy::expect_used, reason = "row and col are checked above")]
                    series
                        .get(row)
                        .map(|value| (channel, value))
                        .expect("value series length check failed after invariants checked")
                })
                .collect::<Box<[_]>>();
            RowRef(timestamp, cells)
        }))
    }

    /// Channel descriptors, in column order.
    pub fn channels(&self) -> &[ChannelStreamDescriptor] {
        &self.channels
    }

    /// Row timestamps, in row order.
    pub fn timestamps(&self) -> &[u64] {
        &self.timestamps
    }

    /// Value series, one per channel and in the same order as [`Self::channels`].
    pub fn columns(&self) -> &[ValueSeries] {
        &self.values
    }

    /// Returns the value series of the channel whose name equals `channel.name`.
    ///
    /// Only the name is compared; `channel.unit` is ignored.
    pub fn get_channel_values(&self, channel: &ChannelStreamDescriptor) -> Option<&ValueSeries> {
        self.values.get(self.channel_index(&channel.name)?)
    }

    /// Like [`Self::get_channel_values`], but pairs each value with its row timestamp.
    pub fn get_timestamped_channel_values(
        &self,
        channel: &ChannelStreamDescriptor,
    ) -> Option<impl Iterator<Item = (u64, ValueRef<'_>)>> {
        let values = self.get_channel_values(channel)?;
        Some(self.timestamps.iter().copied().zip(values))
    }

    /// Number of cells (rows × channels), saturating at `usize::MAX`.
    pub fn num_points(&self) -> usize {
        self.timestamps()
            .len()
            .saturating_mul(self.channels().len())
    }

    /// Replaces every value of the channel named `channel.name` with `f(value)`.
    ///
    /// The results are gathered as `f64` and converted back into a series, so the column
    /// ends up holding doubles whatever its original value type was. The table is left
    /// untouched if `f` fails on any value.
    ///
    /// # Errors
    ///
    /// - [`SampleTableError::ChannelNotFound`] if no channel has that name.
    /// - [`SampleTableError::ApplyError`] carrying the first error returned by `f`.
    pub fn scale_channel(
        &mut self,
        channel: &ChannelStreamDescriptor,
        f: impl FnMut(ValueRef<'_>) -> Result<f64, String>,
    ) -> Result<()> {
        // Only captures a shared reference, so the closure is `Copy` and can be used twice.
        let not_found = || SampleTableError::ChannelNotFound(channel.name.clone());
        let offset = self.channel_index(&channel.name).ok_or_else(not_found)?;
        let channel_values = self.values.get_mut(offset).ok_or_else(not_found)?;
        let scaled = channel_values
            .iter()
            .map(f)
            .collect::<Result<Vec<f64>, String>>()
            .map_err(SampleTableError::ApplyError)?;
        *channel_values = scaled.into();
        Ok(())
    }
}
/// A stream identifier, an optional source label, and the stream's samples.
///
/// Derefs to its [`SampleTable`], so table methods can be called on the stream directly.
#[derive(Clone, Serialize, Deserialize, Debug)]
pub struct StreamData {
    /// Identifier of the stream.
    pub stream_id: String,
    /// Optional source label. It is left out of serialized output when `None`, and becomes
    /// `None` when missing from the input.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub source: Option<String>,
    /// The samples; validated during deserialization via `SampleTable`'s `try_from`.
    pub samples: SampleTable,
}
impl Deref for StreamData {
    type Target = SampleTable;

    /// Gives shared access to the wrapped [`SampleTable`].
    fn deref(&self) -> &SampleTable {
        let Self { samples, .. } = self;
        samples
    }
}
impl DerefMut for StreamData {
fn deref_mut(&mut self) -> &mut Self::Target {
&mut self.samples
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a channel descriptor from borrowed test literals.
    fn channel(name: &str, unit: Option<&str>) -> ChannelStreamDescriptor {
        ChannelStreamDescriptor {
            name: name.to_string(),
            unit: unit.map(ToString::to_string),
        }
    }

    #[test]
    fn sample_table_creation_rejects_dimension_mismatches() {
        // Three timestamps, but each series has only two values.
        let mismatched_rows = SampleTable::new(
            vec![1, 2, 3],
            vec![channel("ch0", None), channel("ch1", None)],
            vec![vec![1.0, 2.0].into(), vec![3.0, 4.0].into()],
        );
        assert!(
            matches!(
                mismatched_rows.as_ref().unwrap_err(),
                SampleTableError::RowCountMismatch(..)
            ),
            "expected RowCountMismatch error, got: {mismatched_rows:?}"
        );
        // Two channels, but only one value series.
        let mismatched_columns = SampleTable::new(
            vec![1, 2, 3],
            vec![channel("ch0", None), channel("ch1", None)],
            vec![vec![1.0, 2.0].into()],
        );
        assert!(matches!(
            mismatched_columns.unwrap_err(),
            SampleTableError::ColumnCountMismatch(2, 1)
        ));
    }

    #[test]
    fn sample_table_creation_fails_with_duplicate_channel_names() {
        // The dimensions are otherwise valid, so the name clash is the only possible failure.
        let result = SampleTable::new(
            vec![1, 2, 3],
            vec![channel("ch0", None), channel("ch0", None)],
            vec![vec![1.0, 2.0, 3.0].into(), vec![4.0, 5.0, 6.0].into()],
        );
        assert!(
            matches!(result, Err(SampleTableError::NameCollision(ref name)) if name == "ch0"),
            "expected NameCollision error, got: {result:?}"
        );
    }

    #[test]
    fn sample_table_column_lookup_by_name() {
        let table = SampleTable::new(
            vec![1, 2, 3],
            vec![channel("ch0", None)],
            vec![vec![1.0, 2.0, 3.0].into()],
        )
        .expect("creating sample table should succeed");
        assert!(table.get_channel_values(&channel("ch0", None)).is_some());
        assert!(
            table
                .get_channel_values(&channel("missing", None))
                .is_none()
        );
    }

    #[test]
    fn iter_rows_yields_one_row_per_timestamp() {
        let table = SampleTable::new(
            vec![10, 20],
            vec![channel("a", None), channel("b", None)],
            vec![vec![1.0, 2.0].into(), vec![3.0, 4.0].into()],
        )
        .expect("creating sample table should succeed");
        let rows: Vec<RowRef<'_>> = table
            .iter_rows()
            .expect("dimensions were validated at construction")
            .collect();
        assert_eq!(rows.len(), 2);
        let timestamps: Vec<u64> = rows.iter().map(|RowRef(ts, _)| *ts).collect();
        assert_eq!(timestamps, [10, 20]);
        let RowRef(_, second_row) = rows.get(1).expect("second row should be present");
        let names: Vec<&str> = second_row
            .iter()
            .map(|(c, _)| c.name.as_str())
            .collect();
        assert_eq!(names, ["a", "b"]);
        let values: Vec<ValueRef<'_>> = second_row.iter().map(|(_, v)| v.clone()).collect();
        assert_eq!(values, [ValueRef::Double(&2.0), ValueRef::Double(&4.0)]);
    }

    #[test]
    fn from_multiple_channels_builds_single_row() {
        let table = SampleTable::from_multiple_channels(
            7,
            vec![channel("a", None), channel("b", Some("V"))],
            vec![Value::Double(1.0), Value::Double(2.0)],
        )
        .expect("matching channel and value counts should succeed");
        assert_eq!(table.timestamps(), &[7]);
        assert_eq!(table.columns().len(), 2);
        assert_eq!(table.num_points(), 2);

        let mismatched = SampleTable::from_multiple_channels(
            7,
            vec![channel("a", None), channel("b", None)],
            vec![Value::Double(1.0)],
        );
        assert!(matches!(
            mismatched,
            Err(SampleTableError::ColumnCountMismatch(..))
        ));
    }

    #[test]
    fn missing_stream_data_source_defaults_to_none() {
        let data: StreamData = serde_json::from_str(
            r#"{
                "stream_id": "stream-1",
                "samples": {
                    "timestamps": [42],
                    "channels": [{
                        "name": "temp"
                    }],
                    "values": [[1.0]]
                }
            }"#,
        )
        .expect("deserializing stream data should succeed");
        assert!(data.source.is_none());
    }

    #[test]
    fn single_frame_round_trips() {
        let data = StreamData {
            stream_id: "s1".to_string(),
            source: Some("test".to_string()),
            samples: SampleTable::from_single_channel(
                100,
                channel("ch0", Some("V")),
                Value::Double(1.5),
            ),
        };
        let json = serde_json::to_string(&data).expect("serialize");
        let parsed: StreamData = serde_json::from_str(&json).expect("deserialize");
        assert_eq!(parsed.stream_id, "s1");
        assert_eq!(parsed.source.unwrap(), "test");
        let SampleTable {
            timestamps,
            channels,
            values,
        } = parsed.samples;
        assert_eq!(timestamps.len(), 1, "expected 1 timestamp");
        assert_eq!(channels.len(), 1, "expected 1 channel");
        assert_eq!(values.len(), 1, "expected 1 value");
        assert_eq!(
            *timestamps.first().expect("timestamp should be present"),
            100
        );
        assert_eq!(
            channels.first().expect("channel should be present").name,
            "ch0"
        );
        assert_eq!(
            channels
                .first()
                .expect("channel should be present")
                .unit
                .as_deref(),
            Some("V")
        );
        assert_eq!(values.first().expect("value should be present").len(), 1);
        assert_eq!(
            values
                .first()
                .expect("value should be present")
                .get(0)
                .expect("value should be present"),
            Value::Double(1.5)
        );
    }

    #[test]
    fn tabular_frame_round_trips() {
        let data = StreamData {
            stream_id: "daq".to_string(),
            source: Some("labjack_t7".to_string()),
            samples: SampleTable::new(
                vec![100, 200, 300],
                vec![channel("ch0", Some("V")), channel("ch1", None)],
                vec![
                    vec![1.0f64, 1.1f64, 1.2f64].into(),
                    vec![2.0f64, 2.1f64, 2.2f64].into(),
                ],
            )
            .expect("creating sample table should succeed"),
        };
        let json = serde_json::to_string(&data).expect("serialize");
        let parsed: StreamData = serde_json::from_str(&json).expect("deserialize");
        assert_eq!(parsed.stream_id, "daq");
        assert_eq!(parsed.source.unwrap(), "labjack_t7");
        let SampleTable {
            timestamps,
            channels,
            values,
        } = parsed.samples;
        assert_eq!(timestamps, &[100, 200, 300]);
        assert_eq!(values.iter().map(|v| v.len()).sum::<usize>(), 3 * 2);
        assert!(values.iter().map(|v| v.len()).all(|len| len == 3));
        assert_eq!(channels.len(), 2);
        assert_eq!(
            channels.first().expect("column should be present").name,
            "ch0"
        );
        assert_eq!(
            channels
                .first()
                .expect("column should be present")
                .unit
                .as_deref(),
            Some("V")
        );
    }

    #[test]
    fn serialized_json_uses_expected_field_names() {
        let data = StreamData {
            stream_id: "s".to_string(),
            source: Some("t".to_string()),
            samples: SampleTable::new(vec![1], vec![channel("c", None)], vec![vec![0.0f64].into()])
                .expect("creating sample table should succeed"),
        };
        let json = serde_json::to_string(&data).expect("serialize");
        assert!(
            json.contains(r#""samples""#),
            "missing 'samples' key: {json}"
        );
        assert!(
            json.contains(r#""timestamps""#),
            "missing 'timestamps' key: {json}"
        );
        assert!(
            json.contains(r#""channels""#),
            "missing 'channels' key: {json}"
        );
        assert!(
            json.contains(r#""values""#),
            "missing 'values' key: {json}"
        );
    }

    #[test]
    fn deserialization_rejects_dimension_mismatches() {
        let bad_payloads = [
            (
                "mismatched row count",
                r#"{
                    "stream_id": "stream-1",
                    "samples": {
                        "timestamps": [1, 2],
                        "channels": [{"name": "ch"}],
                        "values": [[1.0]]
                    }
                }"#,
            ),
            (
                "mismatched column count",
                r#"{
                    "stream_id": "stream-1",
                    "samples": {
                        "timestamps": [1],
                        "channels": [{"name": "a"}, {"name": "b"}],
                        "values": [[1.0]]
                    }
                }"#,
            ),
        ];
        for (name, payload) in bad_payloads {
            let result = serde_json::from_str::<StreamData>(payload);
            assert!(
                result.is_err(),
                "deserializing sample table with {name} should fail"
            );
        }
    }

    #[test]
    fn deserialization_rejects_duplicate_channel_names() {
        let payload = r#"{
            "stream_id": "stream-1",
            "samples": {
                "timestamps": [1],
                "channels": [{"name": "a"}, {"name": "a"}],
                "values": [[1.0], [2.0]]
            }
        }"#;
        let err = serde_json::from_str::<StreamData>(payload)
            .expect_err("duplicate channel names should be rejected");
        assert!(
            err.to_string().contains("channel name collision"),
            "unexpected error: {err}"
        );
    }

    #[test]
    fn scale_channel_doubles() {
        let channel = channel("voltage", Some("V"));
        let mut table = SampleTable::new(
            vec![1, 2, 3],
            vec![channel.clone()],
            vec![vec![1.0, 2.0, 3.0].into()],
        )
        .expect("creating sample table should succeed");
        table
            .scale_channel(&channel, |v| match v {
                ValueRef::Double(v) => Ok(*v * 2.0),
                _ => Err("unexpected type".to_string()),
            })
            .expect("scale_channel should succeed");
        let SampleTable { values, .. } = table;
        let series = values.first().expect("should have one series");
        assert_eq!(series.get(0), Some(ValueRef::Double(&2.0)));
        assert_eq!(series.get(1), Some(ValueRef::Double(&4.0)));
        assert_eq!(series.get(2), Some(ValueRef::Double(&6.0)));
    }

    #[test]
    fn scale_channel_integers_become_doubles() {
        let channel = channel("count", None);
        let mut table = SampleTable::new(
            vec![1, 2],
            vec![channel.clone()],
            vec![vec![10i64, 20i64].into()],
        )
        .expect("creating sample table should succeed");
        table
            .scale_channel(&channel, |v| match v {
                ValueRef::Integer(v) => Ok(*v as f64 * 0.5),
                _ => Err("unexpected type".to_string()),
            })
            .expect("scale_channel should succeed");
        let SampleTable { values, .. } = table;
        let series = values.first().expect("should have one series");
        assert_eq!(series.get(0), Some(ValueRef::Double(&5.0)));
        assert_eq!(series.get(1), Some(ValueRef::Double(&10.0)));
    }

    #[test]
    fn scale_channel_not_found() {
        let mut table = SampleTable::new(
            vec![1],
            vec![channel("exists", None)],
            vec![vec![1.0].into()],
        )
        .expect("creating sample table should succeed");
        let missing = channel("missing", None);
        let err = table
            .scale_channel(&missing, |v| match v {
                ValueRef::Double(v) => Ok(*v),
                _ => Err("unexpected type".to_string()),
            })
            .expect_err("scale_channel should fail for missing channel");
        assert!(
            matches!(err, SampleTableError::ChannelNotFound(ref name) if name == "missing"),
            "expected ChannelNotFound error, got: {err}"
        );
    }

    #[test]
    fn scale_channel_closure_error_propagates() {
        let channel = channel("ch", None);
        let mut table = SampleTable::new(
            vec![1, 2, 3],
            vec![channel.clone()],
            vec![vec![1.0, 2.0, 3.0].into()],
        )
        .expect("creating sample table should succeed");
        const ERR_MSG: &str = "value too large";
        let err = table
            .scale_channel(&channel, |v| match v {
                ValueRef::Double(v) if *v > 1.5 => Err(ERR_MSG.to_string()),
                ValueRef::Double(v) => Ok(*v),
                _ => Err("unexpected type".to_string()),
            })
            .expect_err("scale_channel should propagate closure error");
        assert!(
            matches!(err, SampleTableError::ApplyError(ref msg) if msg.contains(ERR_MSG)),
            "expected ApplyError, got: {err}"
        );
    }
}