use anyhow::{bail, Context, Result};
use ndarray::Array2;
use std::path::Path;
/// One sample produced by `read_dataset`: a 2-D signal plus an optional label.
#[derive(Debug, Clone)]
pub struct HDF5Sample {
/// Signal values. `read_dataset` builds this with shape `(n_ch, n_t)`
/// (presumably channels x time points; confirm against the data producer).
pub signal: Array2<f32>,
/// Label read from the group's `y` dataset. `read_dataset` leaves this `None`
/// when the group has no `y` dataset or its length does not match the number
/// of samples in `X`.
pub label: Option<i64>,
}
/// Reads every sample stored in the HDF5 file at `path`.
///
/// The top-level members of the file are visited in sorted name order. Members
/// that cannot be opened as groups are skipped. Each group must contain a
/// dataset named `X` of rank 2 (`[n_samples, n_t]`, treated as one channel) or
/// rank 3 (`[n_samples, n_ch, n_t]`). Every sample becomes an `Array2<f32>` of
/// shape `(n_ch, n_t)`.
///
/// Labels are best-effort: if the group has a dataset named `y` that can be
/// read as `i64` or `i32` with exactly `n_samples` entries, each sample gets
/// its label; otherwise every sample of that group has `label == None`.
///
/// # Errors
///
/// Returns an error when the file cannot be opened, its members cannot be
/// listed, a group lacks `X`, `X` has an unsupported rank, `X` cannot be read,
/// or the amount of data read does not match the declared shape.
pub fn read_dataset<P: AsRef<Path>>(path: P) -> Result<Vec<HDF5Sample>> {
    let path = path.as_ref();
    let file = hdf5::File::open(path)
        .with_context(|| format!("opening HDF5 file: {}", path.display()))?;

    let mut group_names: Vec<String> = file
        .member_names()
        .with_context(|| "listing HDF5 groups")?;
    group_names.sort();

    let mut samples = Vec::new();
    for group_name in &group_names {
        // Non-group members (e.g. top-level datasets) are skipped on purpose.
        let group = match file.group(group_name) {
            Ok(g) => g,
            Err(_) => continue,
        };

        let x_ds = group
            .dataset("X")
            .with_context(|| format!("reading X from group {group_name}"))?;

        // Validate the shape before reading so an unsupported dataset is
        // rejected without loading its whole payload.
        let x_shape = x_ds.shape();
        let (n_samples, n_ch, n_t) = match x_shape.as_slice() {
            &[n, c, t] => (n, c, t),
            &[n, t] => (n, 1, t),
            shape if shape.len() < 2 => {
                bail!("X in {group_name} has unexpected shape: {x_shape:?}")
            }
            shape => bail!(
                "X in {group_name} has unsupported dimensionality: {}",
                shape.len()
            ),
        };

        let sample_len = n_ch
            .checked_mul(n_t)
            .with_context(|| format!("X in {group_name} has an overflowing shape: {x_shape:?}"))?;
        let expected_len = sample_len
            .checked_mul(n_samples)
            .with_context(|| format!("X in {group_name} has an overflowing shape: {x_shape:?}"))?;

        let x_data: Vec<f32> = x_ds
            .read_raw()
            .with_context(|| format!("reading X data from {group_name}"))?;
        // Guard the slicing below: a length mismatch becomes an error, not a panic.
        if x_data.len() != expected_len {
            bail!(
                "X in {group_name} yielded {} values, expected {expected_len} for shape {x_shape:?}",
                x_data.len()
            );
        }

        let labels = read_group_labels(&group, n_samples);

        samples.reserve(n_samples);
        for i in 0..n_samples {
            let start = i * sample_len;
            let signal =
                Array2::from_shape_vec((n_ch, n_t), x_data[start..start + sample_len].to_vec())
                    .with_context(|| format!("reshaping sample {i} from {group_name}"))?;
            // `labels`, when present, is guaranteed to hold `n_samples` entries.
            let label = labels.as_ref().map(|l| l[i]);
            samples.push(HDF5Sample { signal, label });
        }
    }

    Ok(samples)
}

/// Best-effort read of the `y` dataset of `group` as `n_samples` labels.
///
/// Tries `i64` first and falls back to `i32` (widened losslessly). Returns
/// `None` when `y` is missing, unreadable, or has a different length.
fn read_group_labels(group: &hdf5::Group, n_samples: usize) -> Option<Vec<i64>> {
    let y_ds = group.dataset("y").ok()?;

    let as_i64 = y_ds
        .read_raw::<i64>()
        .ok()
        .filter(|values| values.len() == n_samples);
    if as_i64.is_some() {
        return as_i64;
    }

    y_ds.read_raw::<i32>()
        .ok()
        .filter(|values| values.len() == n_samples)
        .map(|values| values.into_iter().map(i64::from).collect())
}
/// Reads the HDF5 file at `path` and returns signals and labels as two
/// parallel vectors.
///
/// Element `i` of the first vector and element `i` of the second vector come
/// from the same sample, in the order produced by `read_dataset`.
///
/// # Errors
///
/// Propagates any error returned by `read_dataset`.
pub fn read_dataset_split<P: AsRef<Path>>(
    path: P,
) -> Result<(Vec<Array2<f32>>, Vec<Option<i64>>)> {
    let samples = read_dataset(path)?;
    // The samples are owned here, so move the arrays out instead of cloning them.
    let (signals, labels): (Vec<_>, Vec<_>) = samples
        .into_iter()
        .map(|sample| (sample.signal, sample.label))
        .unzip();
    Ok((signals, labels))
}
// Unit-test module, compiled only for test builds.
#[cfg(test)]
mod tests {
// No unit tests are defined here yet.
}