use serde::{Deserialize, Serialize};
use std::collections::HashMap;
/// Class label carried by every [`Trace`].
///
/// [`Dataset`] counts and filters its traces by this label, and
/// [`Dataset::validate`] requires at least one trace of each class.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum Class {
/// Trace labelled as belonging to the "fixed" group
/// (presumably captured with a fixed input — confirm against the acquisition code).
Fixed,
/// Trace labelled as belonging to the "random" group
/// (presumably captured with random inputs — confirm against the acquisition code).
Random,
}
/// Renders the variant name verbatim (`Fixed` / `Random`).
impl std::fmt::Display for Class {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Pick the label first, then emit it with a single write.
        let label = match self {
            Class::Fixed => "Fixed",
            Class::Random => "Random",
        };
        f.write_str(label)
    }
}
/// String-backed identifier naming a stage.
///
/// A [`Marker`] uses it to tag the sample range it delimits; the wrapped
/// string is public and is what `Display` prints.
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub struct StageId(pub String);
impl StageId {
pub fn new(id: impl Into<String>) -> Self {
Self(id.into())
}
pub fn numbered(n: usize) -> Self {
Self(format!("stage_{}", n))
}
}
impl std::fmt::Display for StageId {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "{}", self.0)
}
}
/// Tags a range of sample indices within a trace with a [`StageId`].
///
/// `Trace::stage_samples` slices the samples as `start..end`, so `start` is
/// inclusive and `end` is exclusive.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Marker {
/// Stage this range belongs to.
pub stage: StageId,
/// Index of the first sample in the range (inclusive).
pub start: usize,
/// Index one past the last sample in the range (exclusive).
pub end: usize,
}
impl Marker {
    /// Creates a marker covering the half-open sample range `start..end`.
    ///
    /// # Panics
    ///
    /// When debug assertions are enabled, panics unless `start < end`.
    /// Builds without debug assertions skip the check, and the fields are
    /// public, so an inverted or empty range can still exist; [`Marker::len`]
    /// and [`Marker::is_empty`] both tolerate that.
    pub fn new(stage: StageId, start: usize, end: usize) -> Self {
        debug_assert!(start < end, "Marker start must be before end");
        Self { stage, start, end }
    }

    /// Number of samples covered: `end - start`, or `0` when `start >= end`.
    pub fn len(&self) -> usize {
        // Saturate rather than subtract directly: a marker with `start > end`
        // would otherwise panic on overflow (debug) or wrap to a huge length
        // (release), contradicting `is_empty`, which reports such a marker as empty.
        self.end.saturating_sub(self.start)
    }

    /// Returns `true` when the range covers no samples (`start >= end`).
    pub fn is_empty(&self) -> bool {
        self.start >= self.end
    }
}
/// Unit label stored on a [`Dataset`] (its `units` field).
///
/// `Display` prints a short symbol for each variant; `Default` yields
/// `Arbitrary("unknown")`.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub enum PowerUnits {
/// Displayed as `ADC` (presumably raw converter counts — confirm).
ADC,
/// Displayed as `V`.
Volts,
/// Displayed as `mV`.
Millivolts,
/// Free-form unit label, displayed verbatim.
Arbitrary(String),
}
impl Default for PowerUnits {
fn default() -> Self {
Self::Arbitrary("unknown".to_string())
}
}
/// Prints the unit symbol: `ADC`, `V`, `mV`, or the free-form label as given.
impl std::fmt::Display for PowerUnits {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Resolve the symbol to a string slice, then emit it with one write.
        let symbol: &str = match self {
            PowerUnits::ADC => "ADC",
            PowerUnits::Volts => "V",
            PowerUnits::Millivolts => "mV",
            PowerUnits::Arbitrary(custom) => custom.as_str(),
        };
        f.write_str(symbol)
    }
}
/// A single recorded trace: a class label, its samples, and optional stage markers.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Trace {
/// Class label of this trace.
pub class: Class,
/// Sample values, in order.
pub samples: Vec<f32>,
/// Optional stage markers; each one delimits a `start..end` range of `samples`.
pub markers: Option<Vec<Marker>>,
/// Numeric identifier; `Trace::new` sets it to `0`, `Trace::with_id` takes it from the caller.
pub id: u64,
}
impl Trace {
    /// Creates a trace with the given class and samples, no markers, and `id` set to `0`.
    pub fn new(class: Class, samples: Vec<f32>) -> Self {
        Self::with_id(class, samples, 0)
    }

    /// Creates a trace with the given class, samples and identifier, and no markers.
    pub fn with_id(class: Class, samples: Vec<f32>, id: u64) -> Self {
        Self {
            class,
            samples,
            markers: None,
            id,
        }
    }

    /// Replaces any existing markers with `markers` and returns the trace (builder style).
    pub fn with_markers(mut self, markers: Vec<Marker>) -> Self {
        self.markers = Some(markers);
        self
    }

    /// Number of samples in the trace.
    pub fn len(&self) -> usize {
        self.samples.len()
    }

    /// Returns `true` when the trace holds no samples.
    pub fn is_empty(&self) -> bool {
        self.samples.is_empty()
    }

    /// Returns the samples delimited by the first marker whose stage equals `stage`.
    ///
    /// Returns `None` when the trace has no markers, when no marker matches,
    /// or when the matching marker's `start..end` range is inverted or extends
    /// past the end of the samples.
    pub fn stage_samples(&self, stage: &StageId) -> Option<&[f32]> {
        let marker = self
            .markers
            .as_ref()?
            .iter()
            .find(|m| m.stage == *stage)?;
        // Checked slicing: markers are plain public data (and may come from
        // deserialization), so a bad range must yield `None` rather than panic.
        self.samples.get(marker.start..marker.end)
    }

    /// Stage identifiers of all markers, in marker order (empty when there are no markers).
    ///
    /// Duplicates are kept if several markers share a stage.
    pub fn stage_ids(&self) -> Vec<&StageId> {
        self.markers
            .iter()
            .flatten()
            .map(|m| &m.stage)
            .collect()
    }
}
/// Optional descriptive metadata attached to a [`Dataset`].
///
/// Every field is optional; `Default` leaves the three strings as `None` and
/// the map empty.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct Meta {
/// Free-text description, if any.
pub description: Option<String>,
/// Device label, if any (presumably the device the traces came from — confirm).
pub device: Option<String>,
/// Algorithm label, if any (presumably the algorithm under measurement — confirm).
pub algorithm: Option<String>,
/// Additional key/value pairs; `#[serde(default)]` makes a missing field
/// deserialize as an empty map.
#[serde(default)]
pub extra: HashMap<String, String>,
}
impl Meta {
    /// Creates empty metadata (same as `Meta::default()`).
    pub fn new() -> Self {
        Default::default()
    }

    /// Sets the description and returns the updated metadata.
    pub fn with_description(self, desc: impl Into<String>) -> Self {
        Self {
            description: Some(desc.into()),
            ..self
        }
    }

    /// Sets the device label and returns the updated metadata.
    pub fn with_device(self, device: impl Into<String>) -> Self {
        Self {
            device: Some(device.into()),
            ..self
        }
    }

    /// Sets the algorithm label and returns the updated metadata.
    pub fn with_algorithm(self, algo: impl Into<String>) -> Self {
        Self {
            algorithm: Some(algo.into()),
            ..self
        }
    }

    /// Inserts one extra key/value pair, overwriting any previous value for `key`.
    pub fn with_extra(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
        let (owned_key, owned_value) = (key.into(), value.into());
        self.extra.insert(owned_key, owned_value);
        self
    }
}
/// A collection of traces together with unit, sample-rate and descriptive metadata.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Dataset {
/// All traces, of both classes, in insertion order.
pub traces: Vec<Trace>,
/// Unit label for the sample values.
pub units: PowerUnits,
/// Sample rate in hertz, if set.
pub sample_rate_hz: Option<f64>,
/// Descriptive metadata.
pub meta: Meta,
}
impl Dataset {
pub fn new(traces: Vec<Trace>) -> Self {
Self {
traces,
units: PowerUnits::default(),
sample_rate_hz: None,
meta: Meta::default(),
}
}
pub fn with_units(mut self, units: PowerUnits) -> Self {
self.units = units;
self
}
pub fn with_sample_rate(mut self, rate_hz: f64) -> Self {
self.sample_rate_hz = Some(rate_hz);
self
}
pub fn with_meta(mut self, meta: Meta) -> Self {
self.meta = meta;
self
}
pub fn len(&self) -> usize {
self.traces.len()
}
pub fn is_empty(&self) -> bool {
self.traces.is_empty()
}
pub fn fixed_count(&self) -> usize {
self.traces
.iter()
.filter(|t| t.class == Class::Fixed)
.count()
}
pub fn random_count(&self) -> usize {
self.traces
.iter()
.filter(|t| t.class == Class::Random)
.count()
}
pub fn fixed_traces(&self) -> impl Iterator<Item = &Trace> {
self.traces.iter().filter(|t| t.class == Class::Fixed)
}
pub fn random_traces(&self) -> impl Iterator<Item = &Trace> {
self.traces.iter().filter(|t| t.class == Class::Random)
}
pub fn trace_length(&self) -> Option<usize> {
self.traces.first().map(|t| t.len())
}
pub fn is_aligned(&self) -> bool {
if let Some(first_len) = self.trace_length() {
self.traces.iter().all(|t| t.len() == first_len)
} else {
true }
}
pub fn stage_ids(&self) -> Vec<StageId> {
let mut ids: Vec<StageId> = self
.traces
.iter()
.flat_map(|t| t.stage_ids().into_iter().cloned())
.collect();
ids.sort_by(|a, b| a.0.cmp(&b.0));
ids.dedup();
ids
}
pub fn validate(&self) -> Result<(), DatasetError> {
if self.traces.is_empty() {
return Err(DatasetError::Empty);
}
if self.fixed_count() == 0 {
return Err(DatasetError::NoFixedTraces);
}
if self.random_count() == 0 {
return Err(DatasetError::NoRandomTraces);
}
if !self.is_aligned() {
return Err(DatasetError::UnequalTraceLengths);
}
if self.traces.iter().any(|t| t.is_empty()) {
return Err(DatasetError::EmptyTraces);
}
Ok(())
}
}
/// Reasons `Dataset::validate` can reject a dataset.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum DatasetError {
    /// The dataset contains no traces at all.
    Empty,
    /// No trace is labelled `Class::Fixed`.
    NoFixedTraces,
    /// No trace is labelled `Class::Random`.
    NoRandomTraces,
    /// At least one trace differs in sample count from the first trace.
    UnequalTraceLengths,
    /// At least one trace contains no samples.
    EmptyTraces,
}

impl DatasetError {
    /// Human-readable message for this error, as printed by `Display`.
    fn message(&self) -> &'static str {
        match self {
            DatasetError::Empty => "Dataset is empty",
            DatasetError::NoFixedTraces => "No Fixed class traces",
            DatasetError::NoRandomTraces => "No Random class traces",
            DatasetError::UnequalTraceLengths => "Traces have different lengths",
            DatasetError::EmptyTraces => "Some traces are empty",
        }
    }
}

impl std::fmt::Display for DatasetError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str(self.message())
    }
}

impl std::error::Error for DatasetError {}
#[cfg(test)]
mod tests {
    use super::*;

    /// A freshly built trace keeps its class and reports its sample count.
    #[test]
    fn test_trace_creation() {
        let samples = vec![1.0, 2.0, 3.0];
        let built = Trace::new(Class::Fixed, samples);

        assert_eq!(built.class, Class::Fixed);
        assert_eq!(built.len(), 3);
        assert!(!built.is_empty());
    }

    /// Attaching markers makes them visible through `markers` and `stage_ids`.
    #[test]
    fn test_trace_with_markers() {
        let round_marker = Marker::new(StageId::new("round_1"), 0, 10);
        let marked = Trace::new(Class::Random, vec![0.0; 20]).with_markers(vec![round_marker]);

        assert!(marked.markers.is_some());
        assert_eq!(marked.stage_ids().len(), 1);
    }

    /// Per-class counts and alignment for a small, equal-length dataset.
    #[test]
    fn test_dataset_class_counts() {
        let dataset = Dataset::new(vec![
            Trace::new(Class::Fixed, vec![1.0, 2.0]),
            Trace::new(Class::Fixed, vec![1.1, 2.1]),
            Trace::new(Class::Random, vec![0.5, 1.5]),
        ]);

        assert_eq!(dataset.fixed_count(), 2);
        assert_eq!(dataset.random_count(), 1);
        assert!(dataset.is_aligned());
    }

    /// `validate` accepts a well-formed dataset and reports each structural failure.
    #[test]
    fn test_dataset_validation() {
        // One trace per class, equal lengths: valid.
        let well_formed = Dataset::new(vec![
            Trace::new(Class::Fixed, vec![1.0]),
            Trace::new(Class::Random, vec![2.0]),
        ]);
        assert!(well_formed.validate().is_ok());

        // No traces at all.
        let outcome = Dataset::new(vec![]).validate();
        assert_eq!(outcome, Err(DatasetError::Empty));

        // Only a Random trace.
        let outcome = Dataset::new(vec![Trace::new(Class::Random, vec![1.0])]).validate();
        assert_eq!(outcome, Err(DatasetError::NoFixedTraces));

        // Only a Fixed trace.
        let outcome = Dataset::new(vec![Trace::new(Class::Fixed, vec![1.0])]).validate();
        assert_eq!(outcome, Err(DatasetError::NoRandomTraces));

        // Both classes present but with different sample counts.
        let mismatched = Dataset::new(vec![
            Trace::new(Class::Fixed, vec![1.0, 2.0]),
            Trace::new(Class::Random, vec![1.0]),
        ]);
        assert_eq!(
            mismatched.validate(),
            Err(DatasetError::UnequalTraceLengths)
        );
    }
}