#[cfg(feature = "hdf5")]
use std::path::Path;
#[cfg(feature = "hdf5")]
use hdf5::File;
#[cfg(feature = "hdf5")]
use tenflowers_core::{Result, Tensor, TensorError};
#[cfg(feature = "hdf5")]
use crate::Dataset;
#[cfg(feature = "hdf5")]
/// Options controlling how an HDF5 file is interpreted as a dataset.
#[derive(Debug, Clone)]
pub struct HDF5Config {
    /// Dataset name holding features; `None` = first dataset discovered in the file.
    pub feature_dataset: Option<String>,
    /// Dataset name holding labels; `None` = second discovered dataset, if any.
    pub label_dataset: Option<String>,
    /// Group to read from. NOTE(review): this field is never consulted by the
    /// loader — discovery always happens at the file root; confirm intent.
    pub group_path: String,
    /// Eagerly load all samples into memory at open time (required for `get`).
    pub cache_data: bool,
    /// Cap on the number of cached samples; only applied while caching.
    pub max_samples: Option<usize>,
}
#[cfg(feature = "hdf5")]
impl Default for HDF5Config {
    /// Defaults: auto-detect feature/label datasets, read from the root
    /// group, cache everything, no sample cap.
    fn default() -> Self {
        HDF5Config {
            group_path: String::from("/"),
            cache_data: true,
            feature_dataset: None,
            label_dataset: None,
            max_samples: None,
        }
    }
}
#[cfg(feature = "hdf5")]
impl HDF5Config {
    /// Explicitly choose the dataset providing features.
    ///
    /// Accepts anything convertible to `String` (`&str`, `String`, …), so
    /// call sites no longer need an explicit `.to_string()`.
    pub fn with_feature_dataset(mut self, dataset: impl Into<String>) -> Self {
        self.feature_dataset = Some(dataset.into());
        self
    }

    /// Explicitly choose the dataset providing labels.
    pub fn with_label_dataset(mut self, dataset: impl Into<String>) -> Self {
        self.label_dataset = Some(dataset.into());
        self
    }

    /// Set the HDF5 group path to read from.
    pub fn with_group_path(mut self, path: impl Into<String>) -> Self {
        self.group_path = path.into();
        self
    }

    /// Enable or disable eager in-memory caching of all samples.
    pub fn with_cache_data(mut self, cache: bool) -> Self {
        self.cache_data = cache;
        self
    }

    /// Limit the number of samples loaded while caching.
    pub fn with_max_samples(mut self, max_samples: usize) -> Self {
        self.max_samples = Some(max_samples);
        self
    }
}
#[cfg(feature = "hdf5")]
/// Metadata captured from an HDF5 file when it is opened.
#[derive(Debug, Clone)]
pub struct HDF5DatasetInfo {
    /// Path the file was opened from.
    pub file_path: String,
    /// Resolved name of the feature dataset.
    pub feature_dataset: String,
    /// Resolved name of the label dataset, if one was configured or found.
    pub label_dataset: Option<String>,
    /// Leading dimension of the feature dataset's shape (0 for scalar shapes).
    pub num_samples: usize,
    /// Full shape of the feature dataset, including the sample dimension.
    pub feature_shape: Vec<usize>,
    /// Full shape of the label dataset, when present.
    pub label_shape: Option<Vec<usize>>,
    /// On-disk file size in bytes (from filesystem metadata).
    pub file_size: u64,
    /// Names of all root-level datasets discovered in the file.
    pub available_datasets: Vec<String>,
}
#[cfg(feature = "hdf5")]
/// Fluent builder producing an [`HDF5Dataset`] via `build()`.
pub struct HDF5DatasetBuilder {
    // File path to open; `build` fails with an error if this is never set.
    path: Option<String>,
    // Accumulated configuration handed to the dataset constructor.
    config: HDF5Config,
}
#[cfg(feature = "hdf5")]
impl Default for HDF5DatasetBuilder {
    /// Same as [`HDF5DatasetBuilder::new`]: no path, default config.
    fn default() -> Self {
        HDF5DatasetBuilder {
            path: None,
            config: HDF5Config::default(),
        }
    }
}
#[cfg(feature = "hdf5")]
impl HDF5DatasetBuilder {
    /// Create a builder with no path and the default [`HDF5Config`].
    pub fn new() -> Self {
        Self {
            path: None,
            config: HDF5Config::default(),
        }
    }

    /// Set the HDF5 file to open (any path-like value).
    pub fn path<P: AsRef<Path>>(mut self, path: P) -> Self {
        self.path = Some(path.as_ref().to_string_lossy().to_string());
        self
    }

    /// Replace the entire configuration at once.
    pub fn config(mut self, config: HDF5Config) -> Self {
        self.config = config;
        self
    }

    /// Name of the dataset holding features.
    ///
    /// Accepts anything convertible to `String` (`&str`, `String`, …),
    /// matching the `HDF5Config::with_*` setters.
    pub fn feature_dataset(mut self, dataset: impl Into<String>) -> Self {
        self.config.feature_dataset = Some(dataset.into());
        self
    }

    /// Name of the dataset holding labels.
    pub fn label_dataset(mut self, dataset: impl Into<String>) -> Self {
        self.config.label_dataset = Some(dataset.into());
        self
    }

    /// Open the configured file and build the dataset.
    ///
    /// # Errors
    /// Returns an error if no path was supplied, or if opening/inspecting
    /// the HDF5 file fails.
    pub fn build(self) -> Result<HDF5Dataset> {
        let path = self
            .path
            .ok_or_else(|| TensorError::invalid_argument("Path must be specified".to_string()))?;
        HDF5Dataset::from_file_with_config(&path, self.config)
    }
}
#[cfg(feature = "hdf5")]
/// A dataset backed by an HDF5 file; samples are served from an in-memory
/// cache populated at open time when `config.cache_data` is set.
pub struct HDF5Dataset {
    // Configuration used when the file was opened.
    config: HDF5Config,
    // Metadata discovered at open time.
    info: HDF5DatasetInfo,
    // One flat f32 vector per sample; `None` until/unless data was cached.
    cached_features: Option<Vec<Vec<f32>>>,
    // Flat label buffer; indexed one scalar per sample by `get`.
    cached_labels: Option<Vec<f32>>,
}
#[cfg(feature = "hdf5")]
impl HDF5Dataset {
    /// Open `path` with the default configuration (auto-detected datasets,
    /// caching enabled).
    pub fn from_file<P: AsRef<Path>>(path: P) -> Result<Self> {
        Self::from_file_with_config(path, HDF5Config::default())
    }

    /// Open `path`, discover its root-level datasets, record metadata, and —
    /// when `config.cache_data` is set — eagerly load all samples into memory.
    ///
    /// The feature dataset defaults to the first discovered dataset and the
    /// label dataset to the second, unless set explicitly in `config`.
    ///
    /// NOTE(review): `config.group_path` is never consulted here — discovery
    /// only looks at the file root. Confirm whether group support is planned.
    pub fn from_file_with_config<P: AsRef<Path>>(path: P, config: HDF5Config) -> Result<Self> {
        let path_str = path.as_ref().to_string_lossy().to_string();
        if !path.as_ref().exists() {
            return Err(TensorError::invalid_argument(format!(
                "HDF5 file not found: {path_str}"
            )));
        }
        // File size is recorded purely for reporting in `HDF5DatasetInfo`.
        let file_size = std::fs::metadata(&path_str)
            .map_err(|e| {
                TensorError::invalid_argument(format!("Failed to read file metadata: {e}"))
            })?
            .len();
        let file = File::open(&path_str)
            .map_err(|e| TensorError::invalid_argument(format!("Failed to open HDF5 file: {e}")))?;
        let available_datasets = discover_datasets(&file)?;
        if available_datasets.is_empty() {
            return Err(TensorError::invalid_argument(
                "No datasets found in HDF5 file".to_string(),
            ));
        }
        // Fall back to the first dataset in the file when none was configured.
        let feature_dataset = config
            .feature_dataset
            .clone()
            .or_else(|| available_datasets.first().cloned())
            .ok_or_else(|| TensorError::invalid_argument("No feature dataset found".to_string()))?;
        // Labels are optional: default to the second dataset if one exists.
        let label_dataset = config.label_dataset.clone().or_else(|| {
            if available_datasets.len() > 1 {
                available_datasets.get(1).cloned()
            } else {
                None
            }
        });
        let (num_samples, feature_shape) = get_dataset_shape(&file, &feature_dataset)?;
        let label_shape = if let Some(ref label_name) = label_dataset {
            Some(get_dataset_shape(&file, label_name)?.1)
        } else {
            None
        };
        let info = HDF5DatasetInfo {
            file_path: path_str.clone(),
            feature_dataset: feature_dataset.clone(),
            label_dataset: label_dataset.clone(),
            num_samples,
            feature_shape,
            label_shape,
            file_size,
            available_datasets,
        };
        let mut dataset = Self {
            config,
            info,
            cached_features: None,
            cached_labels: None,
        };
        // NOTE(review): `max_samples` is only honored inside `load_data`, so
        // it has no effect when `cache_data` is false — `len()` then reports
        // the full `num_samples`. Confirm whether that is intended.
        if dataset.config.cache_data {
            dataset.load_data(&path_str)?;
        }
        Ok(dataset)
    }

    /// Metadata gathered when the file was opened.
    pub fn info(&self) -> &HDF5DatasetInfo {
        &self.info
    }

    /// Read the feature (and optional label) datasets fully into memory,
    /// splitting the flat feature buffer into one `Vec<f32>` per sample.
    fn load_data(&mut self, file_path: &str) -> Result<()> {
        let file = File::open(file_path)
            .map_err(|e| TensorError::invalid_argument(format!("Failed to open HDF5 file: {e}")))?;
        let feature_dataset = file.dataset(&self.info.feature_dataset).map_err(|e| {
            TensorError::invalid_argument(format!("Failed to open feature dataset: {e}"))
        })?;
        // The whole dataset is read as one flat f32 buffer.
        let feature_data: Vec<f32> = feature_dataset.read_raw().map_err(|e| {
            TensorError::invalid_argument(format!("Failed to read feature data: {e}"))
        })?;
        let mut features = Vec::new();
        // Elements per sample = product of all trailing dimensions; a 1-D
        // dataset is treated as one scalar feature per sample.
        let feature_size_per_sample = if self.info.feature_shape.len() > 1 {
            self.info.feature_shape[1..].iter().product()
        } else {
            1
        };
        for i in 0..self.info.num_samples {
            let start_idx = i * feature_size_per_sample;
            let end_idx = start_idx + feature_size_per_sample;
            // Slices past the end of the raw buffer are silently skipped
            // rather than treated as an error.
            if end_idx <= feature_data.len() {
                features.push(feature_data[start_idx..end_idx].to_vec());
            }
            // Stop early once the configured sample cap is reached.
            if let Some(max_samples) = self.config.max_samples {
                if features.len() >= max_samples {
                    break;
                }
            }
        }
        self.cached_features = Some(features);
        if let Some(ref label_dataset_name) = self.info.label_dataset {
            let label_dataset = file.dataset(label_dataset_name).map_err(|e| {
                TensorError::invalid_argument(format!("Failed to open label dataset: {e}"))
            })?;
            // Labels are kept flat; `get` assumes one scalar label per
            // sample — TODO confirm for multi-dimensional label datasets.
            let labels: Vec<f32> = label_dataset.read_raw().map_err(|e| {
                TensorError::invalid_argument(format!("Failed to read label data: {e}"))
            })?;
            self.cached_labels = Some(labels);
        }
        Ok(())
    }
}
#[cfg(feature = "hdf5")]
impl Dataset<f32> for HDF5Dataset {
    /// Number of samples: the cached count when data is loaded, otherwise
    /// the count reported by the file's metadata.
    fn len(&self) -> usize {
        self.cached_features
            .as_ref()
            .map_or(self.info.num_samples, Vec::len)
    }

    /// Return `(features, label)` tensors for `index`.
    ///
    /// The label is a scalar tensor; it falls back to `0.0` when no label
    /// dataset is present or the label buffer is shorter than the index.
    /// Fails unless data was cached at open time.
    fn get(&self, index: usize) -> Result<(Tensor<f32>, Tensor<f32>)> {
        let len = self.len();
        if index >= len {
            return Err(TensorError::invalid_argument(format!(
                "Index {index} out of bounds for dataset of length {len}"
            )));
        }
        let features = self.cached_features.as_ref().ok_or_else(|| {
            TensorError::invalid_argument(
                "Data not cached - enable cache_data for efficient access".to_string(),
            )
        })?;
        let row = &features[index];
        let feature_tensor = Tensor::from_vec(row.clone(), &[row.len()])?;
        let label_value = self
            .cached_labels
            .as_ref()
            .and_then(|labels| labels.get(index).copied())
            .unwrap_or(0.0f32);
        let label_tensor = Tensor::from_vec(vec![label_value], &[])?;
        Ok((feature_tensor, label_tensor))
    }
}
#[cfg(feature = "hdf5")]
/// Collect the names of all root-level members of `file` that can be opened
/// as datasets (groups and other object kinds are skipped).
fn discover_datasets(file: &File) -> Result<Vec<String>> {
    let members = file
        .member_names()
        .map_err(|e| TensorError::invalid_argument(format!("Failed to list file members: {e}")))?;
    Ok(members
        .into_iter()
        .filter(|name| file.dataset(name).is_ok())
        .collect())
}
#[cfg(feature = "hdf5")]
/// Return `(num_samples, full_shape)` for `dataset_name`, where the sample
/// count is the leading dimension (0 when the shape is empty/scalar).
fn get_dataset_shape(file: &File, dataset_name: &str) -> Result<(usize, Vec<usize>)> {
    let dataset = file.dataset(dataset_name).map_err(|e| {
        TensorError::invalid_argument(format!("Failed to open dataset {dataset_name}: {e}"))
    })?;
    let shape = dataset.shape();
    let num_samples = shape.first().copied().unwrap_or(0);
    Ok((num_samples, shape))
}
// Stub types so downstream code can reference these names even when the
// `hdf5` feature is disabled; they carry no data or functionality.
#[cfg(not(feature = "hdf5"))]
pub struct HDF5Config;
#[cfg(not(feature = "hdf5"))]
pub struct HDF5DatasetInfo;
#[cfg(not(feature = "hdf5"))]
pub struct HDF5DatasetBuilder;
#[cfg(not(feature = "hdf5"))]
pub struct HDF5Dataset;
#[cfg(test)]
#[cfg(feature = "hdf5")]
mod tests {
    use super::*;

    /// Default config points at the root group with caching on and no
    /// datasets pre-selected.
    #[test]
    fn test_hdf5_config_default() {
        let config = HDF5Config::default();
        assert_eq!(config.group_path, "/");
        assert!(config.cache_data);
        assert!(config.feature_dataset.is_none());
        assert!(config.label_dataset.is_none());
    }

    /// The `with_*` setters record their values on the config.
    #[test]
    fn test_hdf5_config_builder() {
        let config = HDF5Config::default()
            .with_feature_dataset("features".to_string())
            .with_label_dataset("labels".to_string())
            .with_group_path("/data".to_string())
            .with_max_samples(1000);
        assert_eq!(config.feature_dataset.as_deref(), Some("features"));
        assert_eq!(config.label_dataset.as_deref(), Some("labels"));
        assert_eq!(config.group_path, "/data");
        assert_eq!(config.max_samples, Some(1000));
    }

    /// Builder setters forward into the embedded config.
    #[test]
    fn test_hdf5_dataset_builder() {
        let builder = HDF5DatasetBuilder::new()
            .feature_dataset("data".to_string())
            .label_dataset("targets".to_string());
        assert_eq!(builder.config.feature_dataset.as_deref(), Some("data"));
        assert_eq!(builder.config.label_dataset.as_deref(), Some("targets"));
    }
}