use crate::error_taxonomy::{helpers as error_helpers, DatasetErrorContext};
use std::collections::HashMap;
use std::path::{Path, PathBuf};
use tenflowers_core::{DType, Result, Tensor, TensorError};
/// Descriptive information about a dataset, as returned by
/// `FormatReader::metadata`.
#[derive(Debug, Clone)]
pub struct FormatMetadata {
/// Name of the storage format.
pub format_name: String,
/// Optional format version string; `None` when not provided.
pub version: Option<String>,
/// Number of samples; the default `FormatReader::len` returns this value.
pub num_samples: usize,
/// Field descriptions. The default `FormatReader::validate_schema` compares
/// these positionally against the expected schema.
pub fields: Vec<FieldInfo>,
/// Additional string key/value pairs. Nothing in this file assigns meaning
/// to particular keys.
pub metadata: HashMap<String, String>,
/// Returned by the default `FormatReader::supports_random_access`.
pub supports_random_access: bool,
/// Returned by the default `FormatReader::supports_streaming`.
pub supports_streaming: bool,
}
/// Description of a single field (column) in a dataset schema.
///
/// Derived equality compares every member, but the default
/// `FormatReader::validate_schema` only checks `name` and `dtype`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct FieldInfo {
/// Field name; compared exactly during schema validation.
pub name: String,
/// Element type of the field.
pub dtype: DataType,
/// Optional shape of the field's values; `None` when not specified.
pub shape: Option<Vec<usize>>,
/// Nullability flag. NOTE(review): not consulted anywhere in this file.
pub nullable: bool,
/// Optional human-readable description.
pub description: Option<String>,
}
/// Type tag for a schema field.
///
/// Only some variants have a tensor equivalent; see
/// `DataType::to_tensor_dtype`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum DataType {
// Scalar variants.
Bool,
Int8,
Int16,
Int32,
Int64,
UInt8,
UInt16,
UInt32,
UInt64,
Float32,
Float64,
// Variable-length variants.
String,
Binary,
/// Composite value described by a list of nested fields.
Struct(Vec<FieldInfo>),
/// List parameterised by the boxed element type.
List(Box<DataType>),
/// Tensor value carrying a `tenflowers_core::DType` element type.
Tensor(DType),
}
impl DataType {
    /// Maps this schema type onto a tensor element type.
    ///
    /// Returns `Some` for `Float32`, `Float64`, `Int32`, `Int64` and for
    /// `Tensor(dtype)` (which yields the wrapped `dtype`). Every other
    /// variant returns `None`.
    pub fn to_tensor_dtype(&self) -> Option<DType> {
        let dtype = match self {
            DataType::Tensor(inner) => *inner,
            DataType::Float32 => DType::Float32,
            DataType::Float64 => DType::Float64,
            DataType::Int32 => DType::Int32,
            DataType::Int64 => DType::Int64,
            // No tensor counterpart is defined for the remaining variants.
            _ => return None,
        };
        Some(dtype)
    }
}
/// One sample produced by a `FormatReader`.
#[derive(Debug, Clone)]
pub struct FormatSample {
/// Feature tensor (`f32` elements).
pub features: Tensor<f32>,
/// Label tensor (`f32` elements).
pub labels: Tensor<f32>,
/// Index of the sample in its source.
/// NOTE(review): presumably the index passed to `get_sample` — confirm with implementors.
pub source_index: usize,
/// Per-sample string key/value pairs. Nothing in this file assigns meaning
/// to particular keys.
pub metadata: HashMap<String, String>,
}
/// Read access to a dataset stored in some format.
///
/// Implementors must provide `metadata`, `get_sample` and `iter`. The
/// remaining methods have default implementations built on those.
pub trait FormatReader: Send + Sync {
    /// Returns descriptive metadata for the dataset.
    fn metadata(&self) -> Result<FormatMetadata>;

    /// Returns the sample at `index`.
    fn get_sample(&self, index: usize) -> Result<FormatSample>;

    /// Returns the samples at `indices`, in the order given.
    ///
    /// The default implementation calls `get_sample` once per index and
    /// stops at the first error, which is returned.
    fn get_samples(&self, indices: &[usize]) -> Result<Vec<FormatSample>> {
        let mut samples = Vec::new();
        for &index in indices {
            samples.push(self.get_sample(index)?);
        }
        Ok(samples)
    }

    /// Returns a boxed iterator over samples that borrows from `self`.
    fn iter(&self) -> Box<dyn Iterator<Item = Result<FormatSample>> + '_>;

    /// Checks the reader's fields against `expected`.
    ///
    /// The field count must match, and each field is compared positionally
    /// by `name` and then `dtype`. Shape, nullability and description are
    /// not compared.
    ///
    /// # Errors
    ///
    /// Propagates any error from `metadata`, and returns a schema-mismatch
    /// error for the first difference found.
    fn validate_schema(&self, expected: &[FieldInfo]) -> Result<()> {
        // Single place that builds the mismatch error (expected first, then actual).
        let mismatch = |want: String, got: String| {
            error_helpers::schema_mismatch("validate_schema", want, got)
        };

        let actual_fields = self.metadata()?.fields;
        if actual_fields.len() != expected.len() {
            return Err(mismatch(
                format!("{} fields", expected.len()),
                format!("{} fields", actual_fields.len()),
            ));
        }

        for (got, want) in actual_fields.iter().zip(expected) {
            if got.name != want.name {
                return Err(mismatch(
                    format!("field name '{}'", want.name),
                    format!("field name '{}'", got.name),
                ));
            }
            if got.dtype != want.dtype {
                return Err(mismatch(
                    format!("field '{}' type {:?}", want.name, want.dtype),
                    format!("field '{}' type {:?}", got.name, got.dtype),
                ));
            }
        }
        Ok(())
    }

    /// Reports the `supports_random_access` flag from `metadata`, or `false`
    /// if the metadata cannot be obtained.
    fn supports_random_access(&self) -> bool {
        match self.metadata() {
            Ok(meta) => meta.supports_random_access,
            Err(_) => false,
        }
    }

    /// Reports the `supports_streaming` flag from `metadata`, or `true` if
    /// the metadata cannot be obtained.
    fn supports_streaming(&self) -> bool {
        match self.metadata() {
            Ok(meta) => meta.supports_streaming,
            Err(_) => true,
        }
    }

    /// Returns `num_samples` from `metadata`, or `0` if the metadata cannot
    /// be obtained.
    fn len(&self) -> usize {
        self.metadata().map_or(0, |meta| meta.num_samples)
    }

    /// Returns `true` when `len` is zero.
    fn is_empty(&self) -> bool {
        self.len() == 0
    }
}
/// Result of asking a `FormatFactory` whether it can read a given path.
#[derive(Debug, Clone)]
pub struct FormatDetection {
/// Name of the detected format.
pub format_name: String,
/// Detection confidence. `FormatRegistry::auto_detect` prefers the highest
/// value and rejects the best candidate when it is below `0.5`.
/// Producers should not report NaN here.
pub confidence: f32,
/// How the detection was made.
pub method: DetectionMethod,
}
/// How a `FormatDetection` was produced, as reported by the factory.
///
/// Nothing in this file branches on the variant; the descriptions below are
/// read off the variant names.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum DetectionMethod {
/// Presumably based on the file extension.
Extension,
/// Presumably based on leading bytes of the file.
MagicBytes,
/// Presumably based on inspecting the file's content.
ContentAnalysis,
/// Presumably the format was specified explicitly rather than inferred.
Explicit,
}
/// Collection of `FormatFactory` instances keyed by format name.
pub struct FormatRegistry {
// Keyed by the string each factory returns from `format_name()` at
// registration time.
factories: HashMap<String, Box<dyn FormatFactory>>,
}
/// Factory that recognises one dataset format and creates readers for it.
pub trait FormatFactory: Send + Sync {
/// Name of the format; `FormatRegistry::register` uses it as the lookup key.
fn format_name(&self) -> &str;
/// File extensions associated with the format.
/// NOTE(review): not consulted by the registry code in this file.
fn extensions(&self) -> Vec<&str>;
/// Reports whether (and how confidently) this factory can read `path`.
/// `FormatRegistry::auto_detect` ignores factories that return `Err` here.
fn can_read(&self, path: &Path) -> Result<FormatDetection>;
/// Creates a reader for the dataset at `path`.
fn create_reader(&self, path: &Path) -> Result<Box<dyn FormatReader>>;
}
impl FormatRegistry {
pub fn new() -> Self {
Self {
factories: HashMap::new(),
}
}
pub fn register(&mut self, factory: Box<dyn FormatFactory>) {
self.factories
.insert(factory.format_name().to_string(), factory);
}
pub fn auto_detect(&self, path: &Path) -> Result<Box<dyn FormatReader>> {
let mut detections = Vec::new();
for factory in self.factories.values() {
if let Ok(detection) = factory.can_read(path) {
detections.push((detection, factory));
}
}
if detections.is_empty() {
return Err(error_helpers::data_corruption(
"auto_detect",
"No compatible format found",
Some(path.to_path_buf()),
));
}
detections.sort_by(|a, b| {
b.0.confidence
.partial_cmp(&a.0.confidence)
.expect("partial_cmp should not return None for valid values")
});
let (detection, factory) = &detections[0];
if detection.confidence < 0.5 {
return Err(error_helpers::data_corruption(
"auto_detect",
format!("Low confidence detection: {:.2}", detection.confidence),
Some(path.to_path_buf()),
));
}
factory.create_reader(path)
}
pub fn create_reader(&self, format: &str, path: &Path) -> Result<Box<dyn FormatReader>> {
match self.factories.get(format) {
Some(factory) => factory.create_reader(path),
None => Err(error_helpers::invalid_configuration(
"create_reader",
"format",
format!("Unknown format: {}", format),
)),
}
}
pub fn supported_formats(&self) -> Vec<String> {
self.factories.keys().cloned().collect()
}
pub fn get_factory(&self, format: &str) -> Option<&Box<dyn FormatFactory>> {
self.factories.get(format)
}
}
impl Default for FormatRegistry {
fn default() -> Self {
Self::new()
}
}
/// Builder that configures and creates a `FormatReader` for a path.
pub struct FormatReaderBuilder {
// Location of the dataset to open.
path: PathBuf,
// Explicit format name; when `None`, `build` falls back to auto-detection.
format: Option<String>,
// Schema to validate the created reader against, if any.
expected_schema: Option<Vec<FieldInfo>>,
// Key/value options collected by `with_option`.
// NOTE(review): `build` does not read these in this file.
options: HashMap<String, String>,
}
impl FormatReaderBuilder {
    /// Starts a builder for the dataset at `path`, with no explicit format,
    /// no expected schema and no options.
    pub fn new(path: impl Into<PathBuf>) -> Self {
        let path = path.into();
        Self {
            path,
            format: None,
            expected_schema: None,
            options: HashMap::new(),
        }
    }

    /// Sets the format name to use, which makes `build` skip auto-detection.
    pub fn with_format(self, format: impl Into<String>) -> Self {
        Self {
            format: Some(format.into()),
            ..self
        }
    }

    /// Sets the schema the created reader is validated against.
    pub fn with_schema(self, schema: Vec<FieldInfo>) -> Self {
        Self {
            expected_schema: Some(schema),
            ..self
        }
    }

    /// Stores an option, replacing any earlier value for the same key.
    pub fn with_option(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
        let key = key.into();
        let value = value.into();
        self.options.insert(key, value);
        self
    }

    /// Creates the reader through `registry`.
    ///
    /// Uses the explicit format when one was set and auto-detection
    /// otherwise, then validates the reader against the expected schema if
    /// one was supplied.
    ///
    /// # Errors
    ///
    /// Propagates errors from reader creation, detection and schema
    /// validation.
    pub fn build(self, registry: &FormatRegistry) -> Result<Box<dyn FormatReader>> {
        let Self {
            path,
            format,
            expected_schema,
            ..
        } = self;

        let reader = match format.as_deref() {
            Some(name) => registry.create_reader(name, &path)?,
            None => registry.auto_detect(&path)?,
        };

        if let Some(schema) = expected_schema.as_deref() {
            reader.validate_schema(schema)?;
        }
        Ok(reader)
    }
}
/// Returns the lower-cased file extension of `path`.
///
/// Returns `None` when the path has no extension or the extension is not
/// valid UTF-8. The result is not checked against any registered format.
pub fn detect_format_from_extension(path: &Path) -> Option<String> {
    let extension = path.extension()?.to_str()?;
    Some(extension.to_lowercase())
}
/// Reads exactly the first `num_bytes` bytes of the file at `path`.
///
/// # Errors
///
/// Returns a file-not-found error when the file cannot be opened, whatever
/// the underlying I/O error was. Returns a data-corruption error when fewer
/// than `num_bytes` bytes can be read.
pub fn read_magic_bytes(path: &Path, num_bytes: usize) -> Result<Vec<u8>> {
    use std::fs::File;
    use std::io::Read;

    // The I/O error from `open` is intentionally not forwarded: the helper
    // only takes the operation name and the path.
    let mut file = File::open(path)
        .map_err(|_| error_helpers::file_not_found("read_magic_bytes", path.to_path_buf()))?;

    let mut magic = vec![0u8; num_bytes];
    if let Err(io_err) = file.read_exact(&mut magic) {
        return Err(error_helpers::data_corruption(
            "read_magic_bytes",
            format!("Failed to read magic bytes: {}", io_err),
            Some(path.to_path_buf()),
        ));
    }
    Ok(magic)
}
// Unit tests for the plain data types and helpers in this module. No test
// here registers a factory or creates a reader.
#[cfg(test)]
mod tests {
use super::*;
// `to_tensor_dtype` maps numeric variants and yields `None` for `String`.
#[test]
fn test_data_type_conversion() {
assert_eq!(DataType::Float32.to_tensor_dtype(), Some(DType::Float32));
assert_eq!(DataType::Float64.to_tensor_dtype(), Some(DType::Float64));
assert_eq!(DataType::Int32.to_tensor_dtype(), Some(DType::Int32));
assert_eq!(DataType::String.to_tensor_dtype(), None);
}
// A `FormatMetadata` value can be built by hand and its fields read back.
#[test]
fn test_format_metadata_creation() {
let metadata = FormatMetadata {
format_name: "test_format".to_string(),
version: Some("1.0".to_string()),
num_samples: 100,
fields: vec![FieldInfo {
name: "feature".to_string(),
dtype: DataType::Float32,
shape: Some(vec![10]),
nullable: false,
description: None,
}],
metadata: HashMap::new(),
supports_random_access: true,
supports_streaming: true,
};
assert_eq!(metadata.format_name, "test_format");
assert_eq!(metadata.num_samples, 100);
assert_eq!(metadata.fields.len(), 1);
}
// A freshly created registry lists no formats.
#[test]
fn test_format_registry() {
let registry = FormatRegistry::new();
assert!(registry.supported_formats().is_empty());
}
// Extensions are returned lower-cased; a path without one gives `None`.
#[test]
fn test_format_detection_from_extension() {
assert_eq!(
detect_format_from_extension(Path::new("data.json")),
Some("json".to_string())
);
assert_eq!(
detect_format_from_extension(Path::new("data.csv")),
Some("csv".to_string())
);
assert_eq!(
detect_format_from_extension(Path::new("data.CSV")),
Some("csv".to_string())
);
assert_eq!(detect_format_from_extension(Path::new("data")), None);
}
// Builder setters store the format and options; `build` is not exercised.
#[test]
fn test_reader_builder() {
let builder = FormatReaderBuilder::new("test.json")
.with_format("json")
.with_option("encoding", "utf-8");
assert_eq!(builder.format, Some("json".to_string()));
assert_eq!(builder.options.get("encoding"), Some(&"utf-8".to_string()));
}
// A `FieldInfo` value can be built by hand and its fields read back.
#[test]
fn test_field_info_creation() {
let field = FieldInfo {
name: "test_field".to_string(),
dtype: DataType::Float32,
shape: Some(vec![3, 224, 224]),
nullable: false,
description: Some("Test field".to_string()),
};
assert_eq!(field.name, "test_field");
assert_eq!(field.dtype, DataType::Float32);
assert_eq!(field.shape, Some(vec![3, 224, 224]));
assert!(!field.nullable);
}
// Derived equality on `DataType`, including the nested `List` variant.
#[test]
fn test_data_type_equality() {
assert_eq!(DataType::Float32, DataType::Float32);
assert_ne!(DataType::Float32, DataType::Float64);
assert_eq!(
DataType::List(Box::new(DataType::Int32)),
DataType::List(Box::new(DataType::Int32))
);
}
// A `FormatDetection` value can be built by hand and its fields read back.
#[test]
fn test_detection_method() {
let detection = FormatDetection {
format_name: "json".to_string(),
confidence: 0.95,
method: DetectionMethod::Extension,
};
assert_eq!(detection.format_name, "json");
assert_eq!(detection.confidence, 0.95);
assert_eq!(detection.method, DetectionMethod::Extension);
}
// A `FormatSample` keeps the index and metadata it was constructed with.
#[test]
fn test_format_sample_metadata() {
let mut metadata = HashMap::new();
metadata.insert("source".to_string(), "test".to_string());
let sample = FormatSample {
features: Tensor::<f32>::zeros(&[10]),
labels: Tensor::<f32>::zeros(&[1]),
source_index: 42,
metadata,
};
assert_eq!(sample.source_index, 42);
assert_eq!(sample.metadata.get("source"), Some(&"test".to_string()));
}
}