#[cfg(feature = "tfrecord")]
use std::collections::HashMap;
#[cfg(feature = "tfrecord")]
use std::fs::File;
#[cfg(feature = "tfrecord")]
use std::io::{BufReader, Read};
#[cfg(feature = "tfrecord")]
use std::path::Path;
#[cfg(feature = "tfrecord")]
use crc32fast::Hasher;
#[cfg(feature = "tfrecord")]
use oxiarc_archive::GzipReader;
#[cfg(feature = "tfrecord")]
use tenflowers_core::{Result, Tensor, TensorError};
#[cfg(feature = "tfrecord")]
use crate::Dataset;
#[cfg(feature = "tfrecord")]
#[derive(Debug, Clone)]
/// Configuration for reading TFRecord files.
///
/// Built fluently via the `with_*` methods; `Default` gives uncompressed
/// input, CRC validation enabled, full-file caching, and an 8 KiB read buffer.
pub struct TFRecordConfig {
    /// If set, restrict decoding to these feature keys.
    pub feature_keys: Option<Vec<String>>,
    /// Optional `(feature_key, label_key)` pair used to split each record
    /// into an (input, target) tensor pair.
    pub feature_label_keys: Option<(String, String)>,
    /// Treat the file as gzip-compressed when true.
    pub compression: bool,
    /// Batch size hint (not applied by per-record `get` access).
    pub batch_size: usize,
    /// Cache all decoded records in memory; random access requires this.
    pub cache_records: bool,
    /// Stop reading after this many records, if set.
    pub max_records: Option<usize>,
    /// Verify each record's trailing CRC while reading.
    pub validate_crc: bool,
    /// Capacity of the buffered file reader, in bytes.
    pub buffer_size: usize,
}
#[cfg(feature = "tfrecord")]
impl Default for TFRecordConfig {
    /// Baseline settings: uncompressed input, CRC validation on, records
    /// cached in memory, batches of 1000, no record limit, 8 KiB I/O buffer.
    fn default() -> Self {
        TFRecordConfig {
            feature_keys: None,
            feature_label_keys: None,
            compression: false,
            validate_crc: true,
            cache_records: true,
            max_records: None,
            batch_size: 1000,
            buffer_size: 8192,
        }
    }
}
#[cfg(feature = "tfrecord")]
impl TFRecordConfig {
    /// Restrict decoding to the given feature keys.
    pub fn with_feature_keys(self, keys: Vec<String>) -> Self {
        Self {
            feature_keys: Some(keys),
            ..self
        }
    }

    /// Name the feature and label keys used to build (input, target) pairs.
    pub fn with_feature_label_keys(self, feature_key: String, label_key: String) -> Self {
        Self {
            feature_label_keys: Some((feature_key, label_key)),
            ..self
        }
    }

    /// Toggle gzip decompression of the input file.
    pub fn with_compression(self, compressed: bool) -> Self {
        Self {
            compression: compressed,
            ..self
        }
    }

    /// Set the batch size hint.
    pub fn with_batch_size(self, batch_size: usize) -> Self {
        Self { batch_size, ..self }
    }

    /// Toggle in-memory caching of decoded records.
    pub fn with_cache_records(self, cache: bool) -> Self {
        Self {
            cache_records: cache,
            ..self
        }
    }

    /// Cap the number of records read from the file.
    pub fn with_max_records(self, max_records: usize) -> Self {
        Self {
            max_records: Some(max_records),
            ..self
        }
    }

    /// Toggle CRC validation of each record.
    pub fn with_validate_crc(self, validate: bool) -> Self {
        Self {
            validate_crc: validate,
            ..self
        }
    }
}
#[cfg(feature = "tfrecord")]
#[derive(Debug, Clone)]
/// Summary metadata gathered by scanning a TFRecord file once at load time.
pub struct TFRecordDatasetInfo {
    /// Path the dataset was loaded from.
    pub file_path: String,
    /// Number of records counted during the initial scan.
    pub num_records: usize,
    /// On-disk file size in bytes (the compressed size for gzip input).
    pub file_size: u64,
    /// Whether the file was read as gzip-compressed.
    pub compressed: bool,
    /// Sorted feature keys observed in the first record.
    pub feature_keys: Vec<String>,
    /// Per-key type/shape information derived from the first record.
    pub example_features: HashMap<String, FeatureInfo>,
}
#[cfg(feature = "tfrecord")]
#[derive(Debug, Clone)]
/// Type and shape description of a single feature, sampled from one record.
pub struct FeatureInfo {
    /// Which of the three TFRecord feature kinds this key holds.
    pub feature_type: FeatureType,
    /// Value count observed for this key (length of the value list), if known.
    pub shape: Option<Vec<usize>>,
    /// Human-readable dtype name ("bytes", "float32", or "int64").
    pub dtype: String,
}
#[cfg(feature = "tfrecord")]
#[derive(Debug, Clone, PartialEq)]
/// The three value kinds a TFRecord feature can carry.
pub enum FeatureType {
    /// Raw byte strings.
    Bytes,
    /// 32-bit floats.
    Float,
    /// 64-bit signed integers.
    Int64,
}
#[cfg(feature = "tfrecord")]
#[derive(Debug, Clone)]
/// One record read from a TFRecord file: the raw payload plus decoded features.
pub struct TFRecord {
    /// Raw record payload bytes as read from the file.
    pub data: Vec<u8>,
    /// Features keyed by name (see `create_mock_features` for how these are
    /// currently produced).
    pub features: HashMap<String, Feature>,
}
#[cfg(feature = "tfrecord")]
#[derive(Debug, Clone)]
/// A decoded feature value list, mirroring the three TFRecord feature kinds.
pub enum Feature {
    /// A list of byte strings.
    Bytes(Vec<Vec<u8>>),
    /// A list of 32-bit floats.
    Float(Vec<f32>),
    /// A list of 64-bit signed integers.
    Int64(Vec<i64>),
}
#[cfg(feature = "tfrecord")]
/// Fluent builder for [`TFRecordDataset`]; `build` requires a path to be set.
pub struct TFRecordDatasetBuilder {
    // File to read; `build` fails with an error if this is still `None`.
    path: Option<String>,
    // Reader configuration, starting from `TFRecordConfig::default()`.
    config: TFRecordConfig,
}
#[cfg(feature = "tfrecord")]
impl Default for TFRecordDatasetBuilder {
    /// Equivalent to [`TFRecordDatasetBuilder::new`]: no path, default config.
    fn default() -> Self {
        TFRecordDatasetBuilder {
            path: None,
            config: TFRecordConfig::default(),
        }
    }
}
#[cfg(feature = "tfrecord")]
impl TFRecordDatasetBuilder {
    /// Create a builder with no path and the default configuration.
    pub fn new() -> Self {
        TFRecordDatasetBuilder {
            config: TFRecordConfig::default(),
            path: None,
        }
    }

    /// Set the TFRecord file to read.
    pub fn path<P: AsRef<Path>>(self, path: P) -> Self {
        let stored = path.as_ref().to_string_lossy().into_owned();
        Self {
            path: Some(stored),
            ..self
        }
    }

    /// Replace the entire reader configuration.
    pub fn config(self, config: TFRecordConfig) -> Self {
        Self { config, ..self }
    }

    /// Restrict decoding to the given feature keys.
    pub fn feature_keys(mut self, keys: Vec<String>) -> Self {
        self.config.feature_keys = Some(keys);
        self
    }

    /// Toggle gzip decompression of the input file.
    pub fn compression(mut self, compressed: bool) -> Self {
        self.config.compression = compressed;
        self
    }

    /// Consume the builder and open the dataset.
    ///
    /// # Errors
    /// Fails if no path was set, or if opening/scanning the file fails.
    pub fn build(self) -> Result<TFRecordDataset> {
        match self.path {
            Some(path) => TFRecordDataset::from_file_with_config(&path, self.config),
            None => Err(TensorError::invalid_argument(
                "Path must be specified".to_string(),
            )),
        }
    }
}
#[cfg(feature = "tfrecord")]
/// A dataset backed by a TFRecord file, exposing records as tensor pairs.
pub struct TFRecordDataset {
    // Reader configuration the dataset was opened with.
    config: TFRecordConfig,
    // Metadata collected during the initial file scan.
    info: TFRecordDatasetInfo,
    // All records, loaded eagerly when `config.cache_records` is set;
    // `get` only works when this is `Some`.
    cached_records: Option<Vec<TFRecord>>,
}
#[cfg(feature = "tfrecord")]
impl TFRecordDataset {
    /// Open a TFRecord file using the default configuration.
    pub fn from_file<P: AsRef<Path>>(path: P) -> Result<Self> {
        Self::from_file_with_config(path, TFRecordConfig::default())
    }

    /// Open a TFRecord file with an explicit configuration.
    ///
    /// Scans the file once to count records and sample the feature layout,
    /// then (when `config.cache_records` is set) loads every record into
    /// memory.
    ///
    /// # Errors
    /// Fails if the file is missing, its metadata cannot be read, or the
    /// initial scan/load encounters a malformed record.
    pub fn from_file_with_config<P: AsRef<Path>>(path: P, config: TFRecordConfig) -> Result<Self> {
        let path_ref = path.as_ref();
        let path_string = path_ref.to_string_lossy().to_string();
        if !path_ref.exists() {
            return Err(TensorError::invalid_argument(format!(
                "TFRecord file not found: {path_string}"
            )));
        }
        let metadata = std::fs::metadata(&path_string).map_err(|e| {
            TensorError::invalid_argument(format!("Failed to read file metadata: {e}"))
        })?;
        let (num_records, feature_keys, example_features) =
            scan_tfrecord_file(&path_string, &config)?;
        let mut dataset = Self {
            info: TFRecordDatasetInfo {
                file_path: path_string.clone(),
                num_records,
                file_size: metadata.len(),
                compressed: config.compression,
                feature_keys,
                example_features,
            },
            config,
            cached_records: None,
        };
        if dataset.config.cache_records {
            dataset.load_records(&path_string)?;
        }
        Ok(dataset)
    }

    /// Metadata collected when the file was opened.
    pub fn info(&self) -> &TFRecordDatasetInfo {
        &self.info
    }

    /// Read records into the in-memory cache, stopping once
    /// `config.max_records` records have been stored (if a limit is set).
    fn load_records(&mut self, file_path: &str) -> Result<()> {
        let mut reader = create_reader(file_path, &self.config)?;
        let mut records = Vec::new();
        loop {
            match read_next_record(&mut reader, &self.config)? {
                Some(record) => records.push(record),
                None => break,
            }
            let limit_reached = self
                .config
                .max_records
                .map_or(false, |max| records.len() >= max);
            if limit_reached {
                break;
            }
        }
        self.cached_records = Some(records);
        Ok(())
    }
}
#[cfg(feature = "tfrecord")]
impl Dataset<f32> for TFRecordDataset {
    /// Number of records: the cached count when records are loaded,
    /// otherwise the count from the initial file scan.
    fn len(&self) -> usize {
        match &self.cached_records {
            Some(cached) => cached.len(),
            None => self.info.num_records,
        }
    }

    /// Fetch the record at `index` as a `(features, label)` tensor pair.
    ///
    /// # Errors
    /// Fails when `index` is out of bounds or when records were not cached
    /// (uncached random access is unsupported).
    fn get(&self, index: usize) -> Result<(Tensor<f32>, Tensor<f32>)> {
        let total = self.len();
        if index >= total {
            return Err(TensorError::invalid_argument(format!(
                "Index {index} out of bounds for dataset of length {total}"
            )));
        }
        match &self.cached_records {
            Some(records) => extract_features_and_labels(&records[index], &self.config),
            None => Err(TensorError::invalid_argument(
                "Records not cached - enable cache_records for efficient access".to_string(),
            )),
        }
    }
}
#[cfg(feature = "tfrecord")]
/// Open `file_path` for reading, transparently decompressing when
/// `config.compression` is set.
///
/// For gzip input the entire file is read and decompressed into memory up
/// front (the decompressed bytes are served from a `Cursor`), so very large
/// compressed files are fully buffered.
///
/// # Errors
/// Fails if the file cannot be opened/read or gzip decompression fails.
fn create_reader(file_path: &str, config: &TFRecordConfig) -> Result<Box<dyn Read>> {
    let file = File::open(file_path)
        .map_err(|e| TensorError::invalid_argument(format!("Failed to open file: {e}")))?;
    let mut reader = BufReader::with_capacity(config.buffer_size, file);
    if config.compression {
        // `Read` is already in scope via the module-level import; the former
        // local `use std::io::Read;` was redundant and triggered a
        // redundant-import warning.
        let mut raw_data = Vec::new();
        reader.read_to_end(&mut raw_data).map_err(|e| {
            TensorError::invalid_argument(format!("Failed to read tfrecord file: {e}"))
        })?;
        let mut gzip_reader = GzipReader::new(std::io::Cursor::new(raw_data)).map_err(|e| {
            TensorError::invalid_argument(format!("Failed to init gzip reader: {e}"))
        })?;
        let decompressed = gzip_reader.decompress().map_err(|e| {
            TensorError::invalid_argument(format!("Failed to decompress tfrecord: {e}"))
        })?;
        Ok(Box::new(std::io::Cursor::new(decompressed)))
    } else {
        Ok(Box::new(reader))
    }
}
#[cfg(feature = "tfrecord")]
/// Scan a TFRecord file once, returning the record count (capped at
/// `config.max_records`), the sorted feature keys of the first record, and
/// per-key [`FeatureInfo`] describing the first record's layout.
///
/// Only the first record is inspected for features; later records are merely
/// counted and assumed to share the same schema.
///
/// Fix: the previous version never checked the first record against
/// `max_records` and incremented before testing the limit, so the reported
/// count could exceed `max_records` by one — which made `len()` disagree
/// with the cache built by `load_records` when caching was disabled.
fn scan_tfrecord_file(
    file_path: &str,
    config: &TFRecordConfig,
) -> Result<(usize, Vec<String>, HashMap<String, FeatureInfo>)> {
    let mut reader = create_reader(file_path, config)?;
    let limit = config.max_records.unwrap_or(usize::MAX);
    let mut record_count = 0;
    let mut feature_keys = Vec::new();
    let mut example_features = HashMap::new();
    if let Some(record) = read_next_record(&mut reader, config)? {
        record_count = 1;
        for (key, feature) in &record.features {
            if !feature_keys.contains(key) {
                feature_keys.push(key.clone());
            }
            let feature_info = match feature {
                Feature::Bytes(values) => FeatureInfo {
                    feature_type: FeatureType::Bytes,
                    shape: Some(vec![values.len()]),
                    dtype: "bytes".to_string(),
                },
                Feature::Float(values) => FeatureInfo {
                    feature_type: FeatureType::Float,
                    shape: Some(vec![values.len()]),
                    dtype: "float32".to_string(),
                },
                Feature::Int64(values) => FeatureInfo {
                    feature_type: FeatureType::Int64,
                    shape: Some(vec![values.len()]),
                    dtype: "int64".to_string(),
                },
            };
            example_features.insert(key.clone(), feature_info);
        }
        // Count remaining records, checking the limit BEFORE each read so the
        // total never exceeds `max_records`.
        while record_count < limit && read_next_record(&mut reader, config)?.is_some() {
            record_count += 1;
        }
    }
    feature_keys.sort();
    Ok((record_count, feature_keys, example_features))
}
#[cfg(feature = "tfrecord")]
/// Read one length-prefixed record from `reader`, or `Ok(None)` on clean EOF.
///
/// Layout consumed here: an 8-byte little-endian payload length, `length`
/// payload bytes, then a 4-byte little-endian CRC32 over the payload.
///
/// NOTE(review): the official TFRecord framing also carries a 4-byte masked
/// CRC32-C of the *length* field between the length and the payload, and the
/// payload checksum is a masked CRC32-C rather than the plain CRC32 computed
/// by `crc32fast` below. Files written by TensorFlow itself would mis-parse
/// or fail CRC validation here — confirm which writer produces the files this
/// reader targets.
fn read_next_record(reader: &mut dyn Read, config: &TFRecordConfig) -> Result<Option<TFRecord>> {
    let mut length_buf = [0u8; 8];
    match reader.read_exact(&mut length_buf) {
        Ok(()) => {}
        // EOF exactly at a record boundary means the stream is exhausted,
        // not corrupt — report end-of-data instead of an error.
        Err(e) if e.kind() == std::io::ErrorKind::UnexpectedEof => return Ok(None),
        Err(e) => {
            return Err(TensorError::invalid_argument(format!(
                "Failed to read record length: {e}"
            )))
        }
    }
    let length = u64::from_le_bytes([
        length_buf[0],
        length_buf[1],
        length_buf[2],
        length_buf[3],
        length_buf[4],
        length_buf[5],
        length_buf[6],
        length_buf[7],
    ]);
    // NOTE(review): `length` is trusted as-is; a corrupt or malicious file can
    // request an arbitrarily large allocation here before any validation runs.
    let mut data = vec![0u8; length as usize];
    reader
        .read_exact(&mut data)
        .map_err(|e| TensorError::invalid_argument(format!("Failed to read record data: {e}")))?;
    let mut crc_buf = [0u8; 4];
    reader
        .read_exact(&mut crc_buf)
        .map_err(|e| TensorError::invalid_argument(format!("Failed to read CRC: {e}")))?;
    if config.validate_crc {
        // Compare the stored checksum against a plain CRC32 of the payload
        // (see the NOTE above re: masked CRC32-C in the official format).
        let expected_crc = u32::from_le_bytes(crc_buf);
        let mut hasher = Hasher::new();
        hasher.update(&data);
        let actual_crc = hasher.finalize();
        if actual_crc != expected_crc {
            return Err(TensorError::invalid_argument(
                "CRC validation failed".to_string(),
            ));
        }
    }
    // Feature decoding is delegated to a placeholder; the payload is not
    // parsed as a protobuf `Example` (see `create_mock_features`).
    let features = create_mock_features(&data);
    Ok(Some(TFRecord { data, features }))
}
#[cfg(feature = "tfrecord")]
/// Build a stand-in feature map for a raw record payload.
///
/// This is a placeholder: the payload is not parsed as a `tf.train.Example`
/// protobuf. Every record yields a `"feature"` entry holding the payload
/// length as a single f32 and a `"label"` entry fixed at int64 zero.
fn create_mock_features(data: &[u8]) -> HashMap<String, Feature> {
    HashMap::from([
        (
            "feature".to_string(),
            Feature::Float(vec![data.len() as f32]),
        ),
        ("label".to_string(), Feature::Int64(vec![0])),
    ])
}
#[cfg(feature = "tfrecord")]
/// Convert one record into a `(features, label)` tensor pair.
///
/// When `config.feature_label_keys` names explicit keys, those two features
/// are extracted. Otherwise every feature value in the record is flattened
/// into the feature tensor — in sorted key order, so the layout is stable
/// across runs — and a zero scalar label is returned. Bytes features
/// contribute a placeholder `1.0` rather than decoded content.
///
/// # Errors
/// Fails when a named key is missing or tensor construction fails.
fn extract_features_and_labels(
    record: &TFRecord,
    config: &TFRecordConfig,
) -> Result<(Tensor<f32>, Tensor<f32>)> {
    let (feature_data, label_data) =
        if let Some((ref feature_key, ref label_key)) = config.feature_label_keys {
            let features = extract_feature_values(&record.features, feature_key)?;
            let labels = extract_feature_values(&record.features, label_key)?;
            (features, labels)
        } else {
            // Fix: iterate keys in sorted order. `HashMap` iteration order is
            // randomized per process, which previously made the flattened
            // feature layout nondeterministic from run to run.
            let mut keys: Vec<&String> = record.features.keys().collect();
            keys.sort();
            let all_features: Vec<f32> = keys
                .into_iter()
                .flat_map(|key| match &record.features[key] {
                    Feature::Float(values) => values.clone(),
                    Feature::Int64(values) => values.iter().map(|&x| x as f32).collect(),
                    // Placeholder for byte features; content is not decoded.
                    Feature::Bytes(_) => vec![1.0f32],
                })
                .collect();
            (all_features, vec![0.0f32])
        };
    // Capture the length first instead of cloning the whole vector just to
    // read its length (the previous `feature_data.clone()` was wasted work).
    let feature_len = feature_data.len();
    let feature_tensor = Tensor::from_vec(feature_data, &[feature_len])?;
    // NOTE(review): the label tensor is built with scalar shape `[]` even
    // though `label_data` may hold several values when `feature_label_keys`
    // selects a multi-valued feature — confirm `Tensor::from_vec` accepts
    // that, or whether the shape should be `[label_data.len()]`.
    let label_tensor = Tensor::from_vec(label_data, &[])?;
    Ok((feature_tensor, label_tensor))
}
#[cfg(feature = "tfrecord")]
/// Look up `key` in a record's feature map and coerce its values to f32.
///
/// Int64 values are cast to f32; Bytes features yield a placeholder `1.0`.
/// Returns an error when the key is absent.
fn extract_feature_values(features: &HashMap<String, Feature>, key: &str) -> Result<Vec<f32>> {
    match features.get(key) {
        Some(Feature::Float(values)) => Ok(values.clone()),
        Some(Feature::Int64(values)) => Ok(values.iter().map(|&v| v as f32).collect()),
        Some(Feature::Bytes(_)) => Ok(vec![1.0f32]),
        None => Err(TensorError::invalid_argument(format!(
            "Feature key '{key}' not found in record"
        ))),
    }
}
// Feature-gated stubs: when the `tfrecord` feature is disabled these empty
// shells keep downstream `use` paths compiling without pulling in the
// TFRecord reader or its dependencies. They expose no fields or methods.
#[cfg(not(feature = "tfrecord"))]
pub struct TFRecordConfig;
#[cfg(not(feature = "tfrecord"))]
pub struct TFRecordDatasetInfo;
#[cfg(not(feature = "tfrecord"))]
pub struct TFRecordDatasetBuilder;
#[cfg(not(feature = "tfrecord"))]
pub struct TFRecordDataset;
#[cfg(not(feature = "tfrecord"))]
pub struct TFRecord;
#[cfg(not(feature = "tfrecord"))]
pub struct FeatureInfo;
// Minimal single-variant stand-ins so the enum names remain usable in
// signatures; only the full `tfrecord` build carries all variants.
#[cfg(not(feature = "tfrecord"))]
pub enum FeatureType {
    Bytes,
}
#[cfg(not(feature = "tfrecord"))]
pub enum Feature {
    Bytes(Vec<Vec<u8>>),
}
#[cfg(test)]
#[cfg(feature = "tfrecord")]
mod tests {
    use super::*;

    /// The default config matches the documented baseline values.
    #[test]
    fn test_tfrecord_config_default() {
        let cfg = TFRecordConfig::default();
        assert_eq!(cfg.batch_size, 1000);
        assert!(!cfg.compression);
        assert!(cfg.cache_records);
        assert!(cfg.validate_crc);
        assert!(cfg.feature_keys.is_none());
    }

    /// Chained `with_*` calls each apply their setting.
    #[test]
    fn test_tfrecord_config_builder() {
        let cfg = TFRecordConfig::default()
            .with_compression(true)
            .with_batch_size(500)
            .with_feature_keys(vec!["image".to_string(), "label".to_string()])
            .with_max_records(1000);
        assert!(cfg.compression);
        assert_eq!(cfg.batch_size, 500);
        let keys = cfg
            .feature_keys
            .as_ref()
            .expect("test: value should be present");
        assert_eq!(keys.len(), 2);
        assert_eq!(cfg.max_records, Some(1000));
    }

    /// Dataset-builder methods forward settings into the inner config.
    #[test]
    fn test_tfrecord_dataset_builder() {
        let builder = TFRecordDatasetBuilder::new()
            .compression(true)
            .feature_keys(vec!["data".to_string()]);
        assert!(builder.config.compression);
        let keys = builder
            .config
            .feature_keys
            .as_ref()
            .expect("test: value should be present");
        assert_eq!(keys.len(), 1);
    }

    /// `FeatureType`'s derived `PartialEq` distinguishes the variants.
    #[test]
    fn test_feature_type_equality() {
        assert_eq!(FeatureType::Bytes, FeatureType::Bytes);
        assert_eq!(FeatureType::Float, FeatureType::Float);
        assert_eq!(FeatureType::Int64, FeatureType::Int64);
        assert_ne!(FeatureType::Bytes, FeatureType::Float);
    }
}