pub mod tiny;
pub use tiny::TinyModelRepr;
/// A feature matrix with optional targets, labels and provenance.
///
/// All fields are public, so the length relationships that the builder
/// methods assert can be broken by direct mutation afterwards;
/// [`EmbeddedTestData::validate`] re-checks the feature buffer and targets.
#[derive(Clone, Debug, PartialEq)]
pub struct EmbeddedTestData {
    /// Flat buffer of feature values. Row `i` occupies the `n_features`
    /// consecutive values starting at `i * n_features` (see
    /// [`EmbeddedTestData::get_row`]).
    pub x_data: Vec<f32>,
    /// `(n_samples, n_features)`.
    pub x_shape: (usize, usize),
    /// Optional targets, one per sample.
    pub y_data: Option<Vec<f32>>,
    /// Optional feature names, one per feature.
    pub feature_names: Option<Vec<String>>,
    /// Optional sample identifiers, one per sample.
    pub sample_ids: Option<Vec<String>>,
    /// Optional description of where the data came from.
    pub provenance: Option<DataProvenance>,
    /// Compression tag; [`EmbeddedTestData::new`] sets it to
    /// `DataCompression::None`.
    // NOTE(review): the methods on this type only store the tag — presumably
    // a serializer elsewhere acts on it; confirm.
    pub compression: DataCompression,
}
impl EmbeddedTestData {
#[must_use]
pub fn new(x_data: Vec<f32>, x_shape: (usize, usize)) -> Self {
assert_eq!(
x_data.len(),
x_shape.0 * x_shape.1,
"Data length {} doesn't match shape {:?}",
x_data.len(),
x_shape
);
Self {
x_data,
x_shape,
y_data: None,
feature_names: None,
sample_ids: None,
provenance: None,
compression: DataCompression::None,
}
}
#[must_use]
pub fn with_targets(mut self, y_data: Vec<f32>) -> Self {
assert_eq!(
y_data.len(),
self.x_shape.0,
"Target length {} doesn't match n_samples {}",
y_data.len(),
self.x_shape.0
);
self.y_data = Some(y_data);
self
}
#[must_use]
pub fn with_feature_names(mut self, names: Vec<String>) -> Self {
assert_eq!(
names.len(),
self.x_shape.1,
"Feature names length {} doesn't match n_features {}",
names.len(),
self.x_shape.1
);
self.feature_names = Some(names);
self
}
#[must_use]
pub fn with_sample_ids(mut self, ids: Vec<String>) -> Self {
assert_eq!(
ids.len(),
self.x_shape.0,
"Sample IDs length {} doesn't match n_samples {}",
ids.len(),
self.x_shape.0
);
self.sample_ids = Some(ids);
self
}
#[must_use]
pub fn with_provenance(mut self, provenance: DataProvenance) -> Self {
self.provenance = Some(provenance);
self
}
#[must_use]
pub fn with_compression(mut self, compression: DataCompression) -> Self {
self.compression = compression;
self
}
#[must_use]
pub const fn n_samples(&self) -> usize {
self.x_shape.0
}
#[must_use]
pub const fn n_features(&self) -> usize {
self.x_shape.1
}
#[must_use]
pub fn size_bytes(&self) -> usize {
let base_size = self.x_data.len() * 4; let y_size = self.y_data.as_ref().map_or(0, |y| y.len() * 4);
base_size + y_size
}
#[must_use]
pub fn get_row(&self, idx: usize) -> Option<&[f32]> {
if idx >= self.n_samples() {
return None;
}
let start = idx * self.n_features();
let end = start + self.n_features();
Some(&self.x_data[start..end])
}
#[must_use]
pub fn get_target(&self, idx: usize) -> Option<f32> {
self.y_data.as_ref().and_then(|y| y.get(idx).copied())
}
pub fn validate(&self) -> Result<(), EmbedError> {
if self.x_data.len() != self.x_shape.0 * self.x_shape.1 {
return Err(EmbedError::ShapeMismatch {
expected: self.x_shape.0 * self.x_shape.1,
actual: self.x_data.len(),
});
}
for (i, &val) in self.x_data.iter().enumerate() {
if !val.is_finite() {
return Err(EmbedError::InvalidValue {
index: i,
value: val,
});
}
}
if let Some(ref y) = self.y_data {
if y.len() != self.x_shape.0 {
return Err(EmbedError::TargetMismatch {
expected: self.x_shape.0,
actual: y.len(),
});
}
for (i, &val) in y.iter().enumerate() {
if !val.is_finite() {
return Err(EmbedError::InvalidValue {
index: i,
value: val,
});
}
}
}
Ok(())
}
}
impl Default for EmbeddedTestData {
    /// An empty dataset: no feature values and a `(0, 0)` shape, built
    /// through [`EmbeddedTestData::new`].
    fn default() -> Self {
        let empty_shape = (0, 0);
        let no_values = Vec::new();
        Self::new(no_values, empty_shape)
    }
}
/// Describes where a dataset came from and how it was prepared.
#[derive(Clone, Debug, PartialEq)]
pub struct DataProvenance {
    /// Origin of the data. [`DataProvenance::is_complete`] requires this to
    /// be non-empty.
    pub source: String,
    /// Optional description of how a subset was selected.
    pub subset_criteria: Option<String>,
    /// Preprocessing steps, in the order they were appended.
    pub preprocessing: Vec<String>,
    /// Creation timestamp as a string; [`DataProvenance::new`] fills it from
    /// `chrono_lite_timestamp()`, which defines the format.
    pub created_at: String,
    /// Optional license text or identifier. [`DataProvenance::is_complete`]
    /// requires this to be set.
    pub license: Option<String>,
    /// Optional version string.
    pub version: Option<String>,
    /// Free-form key/value annotations.
    pub metadata: std::collections::HashMap<String, String>,
}
impl DataProvenance {
    /// Creates a provenance record for `source`, stamped with
    /// `chrono_lite_timestamp()`; every optional field starts empty.
    #[must_use]
    pub fn new(source: impl Into<String>) -> Self {
        let source = source.into();
        let created_at = chrono_lite_timestamp();
        Self {
            source,
            created_at,
            subset_criteria: None,
            license: None,
            version: None,
            preprocessing: Vec::new(),
            metadata: std::collections::HashMap::new(),
        }
    }

    /// Records how the subset was selected, replacing any previous value.
    #[must_use]
    pub fn with_subset(self, criteria: impl Into<String>) -> Self {
        Self {
            subset_criteria: Some(criteria.into()),
            ..self
        }
    }

    /// Appends a single preprocessing step.
    #[must_use]
    pub fn with_preprocessing(mut self, step: impl Into<String>) -> Self {
        let step: String = step.into();
        self.preprocessing.push(step);
        self
    }

    /// Appends several preprocessing steps, keeping their order.
    #[must_use]
    pub fn with_preprocessing_steps(mut self, steps: Vec<String>) -> Self {
        let mut steps = steps;
        self.preprocessing.append(&mut steps);
        self
    }

    /// Sets the license, replacing any previous value.
    #[must_use]
    pub fn with_license(self, license: impl Into<String>) -> Self {
        Self {
            license: Some(license.into()),
            ..self
        }
    }

    /// Sets the version, replacing any previous value.
    #[must_use]
    pub fn with_version(self, version: impl Into<String>) -> Self {
        Self {
            version: Some(version.into()),
            ..self
        }
    }

    /// Adds a metadata entry; an existing entry with the same key is
    /// overwritten.
    #[must_use]
    pub fn with_metadata(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
        let (key, value) = (key.into(), value.into());
        self.metadata.insert(key, value);
        self
    }

    /// Returns `true` when the source is non-empty and a license is set.
    #[must_use]
    pub fn is_complete(&self) -> bool {
        let has_source = !self.source.is_empty();
        let has_license = self.license.is_some();
        has_source && has_license
    }
}
impl Default for DataProvenance {
    /// A provenance record whose source is the placeholder `"unknown"`,
    /// built through [`DataProvenance::new`].
    fn default() -> Self {
        const PLACEHOLDER_SOURCE: &str = "unknown";
        Self::new(PLACEHOLDER_SOURCE)
    }
}
/// Tag naming the compression scheme associated with a dataset.
///
/// The default is [`DataCompression::None`].
// NOTE(review): the per-variant descriptions below are inferred from the
// variant and field names only — confirm against the code that consumes them.
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub enum DataCompression {
    /// No compression.
    #[default]
    None,
    /// Presumably Zstandard at the given level.
    Zstd { level: u8 },
    /// Presumably delta encoding followed by Zstandard at the given level.
    DeltaZstd { level: u8 },
    /// Presumably quantization to `bits` bits followed by entropy coding.
    QuantizedEntropy { bits: u8 },
    /// Presumably a sparse encoding governed by `threshold`; the meaning of
    /// the `u32` value is not visible here.
    Sparse { threshold: u32 },
}
include!("embed_error.rs");
include!("tests_embedded_data.rs");