#[cfg(feature = "formats")]
use crate::error::{DatasetsError, Result};
#[cfg(feature = "formats")]
use crate::utils::Dataset;
#[cfg(feature = "formats")]
use scirs2_core::ndarray::{Array1, Array2};
#[cfg(feature = "formats")]
use std::path::Path;
/// File formats this module can name and detect from a path suffix.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum FormatType {
    /// Parquet files (`.parquet`, `.pq`).
    Parquet,
    /// Arrow files (`.arrow`).
    Arrow,
    /// HDF5 files (`.h5`, `.hdf5`).
    Hdf5,
    /// CSV files (`.csv`).
    Csv,
}

impl FormatType {
    /// Detects the format from the trailing suffix of `path`.
    ///
    /// The comparison is case-insensitive and purely suffix based (the dot is
    /// part of the suffix). Returns `None` when no known suffix matches.
    pub fn from_extension(path: &str) -> Option<Self> {
        // The suffixes are mutually exclusive, so lookup order does not matter.
        const SUFFIXES: [(&str, FormatType); 6] = [
            (".parquet", FormatType::Parquet),
            (".pq", FormatType::Parquet),
            (".arrow", FormatType::Arrow),
            (".h5", FormatType::Hdf5),
            (".hdf5", FormatType::Hdf5),
            (".csv", FormatType::Csv),
        ];

        let normalized = path.to_lowercase();
        SUFFIXES
            .iter()
            .find(|(suffix, _)| normalized.ends_with(*suffix))
            .map(|&(_, format)| format)
    }

    /// Returns the canonical file extension (without the leading dot).
    pub fn extension(&self) -> &'static str {
        match *self {
            Self::Parquet => "parquet",
            Self::Arrow => "arrow",
            Self::Hdf5 => "h5",
            Self::Csv => "csv",
        }
    }
}
/// I/O settings carried by the Parquet/HDF5 readers and writers and by the
/// format converter in this file.
///
/// NOTE(review): none of the readers or writers visible in this file consult
/// these fields yet; confirm the intended meaning of each before relying on it.
#[derive(Debug, Clone)]
pub struct FormatConfig {
    /// Chunk size (default `10_000`); presumably a row count, to be confirmed.
    pub chunk_size: usize,
    /// Optional compression codec (default `Some(CompressionCodec::Snappy)`).
    pub compression: Option<CompressionCodec>,
    /// Whether memory mapping is requested (default `true`).
    pub use_mmap: bool,
    /// Buffer size (default `8 * 1024 * 1024`); presumably bytes, to be confirmed.
    pub buffer_size: usize,
}

impl Default for FormatConfig {
    /// Builds the default configuration documented on each field.
    fn default() -> Self {
        FormatConfig {
            buffer_size: 8 * 1024 * 1024,
            use_mmap: true,
            compression: Some(CompressionCodec::Snappy),
            chunk_size: 10_000,
        }
    }
}
/// Compression codecs selectable through the format configuration.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CompressionCodec {
    /// No compression.
    None,
    /// Snappy compression.
    Snappy,
    /// Gzip compression.
    Gzip,
    /// LZ4 compression.
    Lz4,
    /// Zstandard compression.
    Zstd,
}

impl CompressionCodec {
    /// Returns the default compression level for codecs that have one.
    ///
    /// `Gzip` yields `Some(6)` and `Zstd` yields `Some(3)`; every other codec
    /// yields `None`.
    pub fn level(&self) -> Option<i32> {
        match *self {
            Self::Gzip => Some(6),
            Self::Zstd => Some(3),
            Self::None | Self::Snappy | Self::Lz4 => None,
        }
    }
}
/// Reserved column name under which a dataset's target vector is stored.
///
/// Despite the name, the CSV reader and writer in this file use the same
/// column name to mark the target column.
#[cfg(feature = "formats")]
const PARQUET_TARGET_COLUMN: &str = "__target__";
/// Returns one column name per feature of `dataset`.
///
/// The dataset's own `featurenames` are used when present and when their count
/// matches `n_features()`; otherwise names `feature_0`, `feature_1`, ... are
/// generated.
#[cfg(feature = "formats")]
fn feature_column_names(dataset: &Dataset) -> Vec<String> {
    let width = dataset.n_features();
    if let Some(stored) = dataset.featurenames.as_ref() {
        if stored.len() == width {
            return stored.clone();
        }
    }
    (0..width).map(|idx| format!("feature_{idx}")).collect()
}
/// Converts in-memory Parquet data into a [`Dataset`].
///
/// Every column except `PARQUET_TARGET_COLUMN` becomes a feature column, in the
/// order returned by `schema().column_names()`, and its name is stored in
/// `featurenames`. The `PARQUET_TARGET_COLUMN` column, when present, becomes
/// the target vector.
///
/// # Errors
/// Returns `DatasetsError::InvalidFormat` when no feature column exists, when a
/// column cannot be read through `get_column_f64`, or when the collected values
/// do not fit an `(n_features, n_rows)` matrix.
#[cfg(feature = "formats")]
fn parquet_data_to_dataset(pdata: &scirs2_io::parquet::ParquetData) -> Result<Dataset> {
let all_columns = pdata.schema().column_names();
let n_rows = pdata.num_rows();
// Feature columns are all columns other than the reserved target column.
let feat_names: Vec<String> = all_columns
.iter()
.filter(|n| n.as_str() != PARQUET_TARGET_COLUMN)
.cloned()
.collect();
if feat_names.is_empty() {
return Err(DatasetsError::InvalidFormat(
"Parquet file contains no feature columns (only '__target__' found)".to_string(),
));
}
let n_features = feat_names.len();
// Values are appended one whole column at a time (column-major order).
let mut flat: Vec<f64> = Vec::with_capacity(n_rows * n_features);
for col_name in &feat_names {
let col = pdata.get_column_f64(col_name).map_err(|e| {
DatasetsError::InvalidFormat(format!(
"Failed to read feature column '{}': {}",
col_name, e
))
})?;
flat.extend(col.iter());
}
// Shape as (n_features, n_rows) to match the fill order, then transpose so
// that the final matrix is (n_rows, n_features).
let column_major = Array2::from_shape_vec((n_features, n_rows), flat).map_err(|e| {
DatasetsError::InvalidFormat(format!("Failed to shape feature matrix: {e}"))
})?;
let data = column_major.t().to_owned();
// The target is optional: it is only read when the reserved column exists.
let target: Option<Array1<f64>> = if all_columns
.iter()
.any(|n| n.as_str() == PARQUET_TARGET_COLUMN)
{
let col = pdata.get_column_f64(PARQUET_TARGET_COLUMN).map_err(|e| {
DatasetsError::InvalidFormat(format!("Failed to read target column: {e}"))
})?;
Some(Array1::from_vec(col.to_vec()))
} else {
None
};
let mut ds = Dataset::new(data, target);
ds.featurenames = Some(feat_names);
Ok(ds)
}
/// Writes `dataset` to a Parquet file at `path` as a single record batch.
///
/// Each feature becomes one `Float64` column named by `feature_column_names`;
/// when the dataset has a target it is appended as a final column named
/// `PARQUET_TARGET_COLUMN`. Default `ParquetWriteOptions` are used.
///
/// # Errors
/// Returns `DatasetsError::InvalidFormat` when the record batch cannot be
/// built, or when creating, writing to, or closing the writer fails.
///
/// NOTE(review): a feature that is itself named `__target__` would collide
/// with the target column name; confirm whether callers can produce that.
#[cfg(feature = "formats")]
fn write_dataset_parquet<P: AsRef<Path>>(dataset: &Dataset, path: P) -> Result<()> {
use arrow::array::Float64Array;
use arrow::datatypes::{DataType, Field, Schema};
use arrow::record_batch::RecordBatch;
use scirs2_io::parquet::{ParquetWriteOptions, ParquetWriter as IoParquetWriter};
use std::sync::Arc;
let col_names = feature_column_names(dataset);
let n_rows = dataset.n_samples();
let n_feats = dataset.n_features();
// One Float64 schema field per feature column.
let mut fields: Vec<Field> = col_names
.iter()
.map(|name| Field::new(name.as_str(), DataType::Float64, false))
.collect();
let has_target = dataset.target.is_some();
if has_target {
fields.push(Field::new(PARQUET_TARGET_COLUMN, DataType::Float64, false));
}
let schema = Arc::new(Schema::new(fields));
// Capacity leaves room for the optional target column.
let mut arrays: Vec<Arc<dyn arrow::array::Array>> = Vec::with_capacity(n_feats + 1);
// Gather each feature column by walking the rows of the data matrix.
for col_idx in 0..n_feats {
let col_data: Vec<f64> = (0..n_rows)
.map(|row| dataset.data[[row, col_idx]])
.collect();
arrays.push(Arc::new(Float64Array::from(col_data)));
}
if let Some(target) = &dataset.target {
let tgt_data: Vec<f64> = target.to_vec();
arrays.push(Arc::new(Float64Array::from(tgt_data)));
}
let batch = RecordBatch::try_new(Arc::clone(&schema), arrays)
.map_err(|e| DatasetsError::InvalidFormat(format!("Failed to build RecordBatch: {e}")))?;
let options = ParquetWriteOptions::default();
let mut writer = IoParquetWriter::from_path(path, schema, options)
.map_err(|e| DatasetsError::InvalidFormat(format!("Parquet writer creation error: {e}")))?;
writer
.write_batch(&batch)
.map_err(|e| DatasetsError::InvalidFormat(format!("Parquet write error: {e}")))?;
// Close explicitly so that a failure while finishing the file is reported.
writer
.close()
.map_err(|e| DatasetsError::InvalidFormat(format!("Parquet close error: {e}")))
}
/// Reader that loads a Parquet file into a [`Dataset`].
///
/// The stored [`FormatConfig`] is carried for API symmetry; `read` does not
/// apply any of its settings.
#[cfg(feature = "formats")]
#[derive(Debug, Clone, Default)]
pub struct ParquetReader {
    config: FormatConfig,
}

#[cfg(feature = "formats")]
impl ParquetReader {
    /// Creates a reader with `FormatConfig::default()`.
    pub fn new() -> Self {
        Self::default()
    }

    /// Creates a reader that carries the given configuration.
    pub fn with_config(config: FormatConfig) -> Self {
        Self { config }
    }

    /// Reads the Parquet file at `path` and converts it into a [`Dataset`].
    ///
    /// # Errors
    /// Returns `DatasetsError::InvalidFormat` when the file cannot be read or
    /// when its contents cannot be converted by `parquet_data_to_dataset`.
    pub fn read<P: AsRef<Path>>(&self, path: P) -> Result<Dataset> {
        // The configuration is only touched here, presumably to keep the field
        // from being reported as unused; no setting influences the read.
        let _ = &self.config;
        let pdata = scirs2_io::parquet::read_parquet(path)
            .map_err(|e| DatasetsError::InvalidFormat(format!("Parquet read error: {e}")))?;
        parquet_data_to_dataset(&pdata)
    }
}
/// Writer that stores a [`Dataset`] as a Parquet file.
///
/// The stored [`FormatConfig`] is carried for API symmetry; `write` does not
/// apply any of its settings.
#[cfg(feature = "formats")]
#[derive(Debug, Clone, Default)]
pub struct ParquetWriter {
    config: FormatConfig,
}

#[cfg(feature = "formats")]
impl ParquetWriter {
    /// Creates a writer with `FormatConfig::default()`.
    pub fn new() -> Self {
        Self::default()
    }

    /// Creates a writer that carries the given configuration.
    pub fn with_config(config: FormatConfig) -> Self {
        Self { config }
    }

    /// Writes `dataset` to the Parquet file at `path`.
    ///
    /// # Errors
    /// Propagates any error returned by `write_dataset_parquet`.
    pub fn write<P: AsRef<Path>>(&self, dataset: &Dataset, path: P) -> Result<()> {
        // The configuration is only touched here, presumably to keep the field
        // from being reported as unused; no setting influences the write.
        let _ = &self.config;
        write_dataset_parquet(dataset, path)
    }
}
/// Wraps any displayable HDF5 failure in `DatasetsError::InvalidFormat`,
/// prefixing the message with `HDF5 error: `.
#[cfg(feature = "formats")]
fn hdf5_err(msg: impl std::fmt::Display) -> DatasetsError {
    let text = format!("HDF5 error: {msg}");
    DatasetsError::InvalidFormat(text)
}
/// Reads the 2-D dataset `dataset_name` from the HDF5 file at `path`.
///
/// The dataset's two shape entries are used as `(rows, columns)` of the
/// feature matrix. If the file also contains a dataset named
/// `<dataset_name>_target`, its values become the target vector.
///
/// # Errors
/// Returns `DatasetsError::InvalidFormat` when the file cannot be read, the
/// named dataset is missing, it is not 2-D, it (or the target dataset) has no
/// float representation, or the values do not fit the declared shape.
///
/// NOTE(review): the target length is not compared with the row count here;
/// confirm whether `Dataset::new` validates it.
#[cfg(feature = "formats")]
fn read_dataset_hdf5<P: AsRef<Path>>(path: P, dataset_name: &str) -> Result<Dataset> {
use scirs2_io::hdf5::read_hdf5;
let root = read_hdf5(path).map_err(hdf5_err)?;
let ds = root.datasets.get(dataset_name).ok_or_else(|| {
DatasetsError::InvalidFormat(format!("Dataset '{}' not found in HDF5 file", dataset_name))
})?;
// Only matrices are accepted as feature data.
let shape = &ds.shape;
if shape.len() != 2 {
return Err(DatasetsError::InvalidFormat(format!(
"Expected 2-D dataset for '{}', got {}-D",
dataset_name,
shape.len()
)));
}
let n_rows = shape[0];
let n_cols = shape[1];
let float_data = ds.as_float_vec().ok_or_else(|| {
DatasetsError::InvalidFormat(format!(
"Dataset '{}' contains non-numeric data",
dataset_name
))
})?;
let data = Array2::from_shape_vec((n_rows, n_cols), float_data).map_err(|e| {
DatasetsError::InvalidFormat(format!("Failed to shape feature matrix: {e}"))
})?;
// The target, if any, lives in a sibling dataset with a `_target` suffix.
let target_name = format!("{}_target", dataset_name);
let target: Option<Array1<f64>> = if let Some(tds) = root.datasets.get(&target_name) {
let tvec = tds.as_float_vec().ok_or_else(|| {
DatasetsError::InvalidFormat(format!(
"Target dataset '{}' contains non-numeric data",
target_name
))
})?;
Some(Array1::from_vec(tvec))
} else {
None
};
Ok(Dataset::new(data, target))
}
/// Writes `dataset` to the HDF5 file at `path`.
///
/// The feature matrix is stored under `dataset_name` with shape
/// `[n_samples, n_features]`; when a target exists it is stored as a 1-D array
/// under `<dataset_name>_target`.
///
/// # Errors
/// Returns `DatasetsError::InvalidFormat` when either array cannot be built or
/// when the underlying HDF5 write fails.
#[cfg(feature = "formats")]
fn write_dataset_hdf5<P: AsRef<Path>>(
    dataset: &Dataset,
    path: P,
    dataset_name: &str,
) -> Result<()> {
    use scirs2_core::ndarray::{ArrayD, IxDyn};
    use scirs2_io::hdf5::write_hdf5;
    use std::collections::HashMap;

    // Flatten the matrix in its logical iteration order before re-shaping it
    // as a dynamic-dimension array.
    let shape = [dataset.n_samples(), dataset.n_features()];
    let values: Vec<f64> = dataset.data.iter().copied().collect();
    let features = ArrayD::from_shape_vec(IxDyn(&shape), values).map_err(|e| {
        DatasetsError::InvalidFormat(format!("Failed to convert data to ArrayD: {e}"))
    })?;

    let mut arrays: HashMap<String, ArrayD<f64>> = HashMap::new();
    arrays.insert(dataset_name.to_string(), features);

    if let Some(target) = dataset.target.as_ref() {
        let values = target.to_vec();
        let len = values.len();
        let labels = ArrayD::from_shape_vec(IxDyn(&[len]), values).map_err(|e| {
            DatasetsError::InvalidFormat(format!("Failed to convert target to ArrayD: {e}"))
        })?;
        arrays.insert(format!("{}_target", dataset_name), labels);
    }

    write_hdf5(path, arrays).map_err(hdf5_err)
}
/// Reader that loads a named dataset from an HDF5 file into a [`Dataset`].
///
/// The stored [`FormatConfig`] is carried for API symmetry; `read` does not
/// apply any of its settings.
#[cfg(feature = "formats")]
#[derive(Debug, Clone, Default)]
pub struct Hdf5Reader {
    config: FormatConfig,
}

#[cfg(feature = "formats")]
impl Hdf5Reader {
    /// Creates a reader with `FormatConfig::default()`.
    pub fn new() -> Self {
        Self::default()
    }

    /// Creates a reader that carries the given configuration.
    pub fn with_config(config: FormatConfig) -> Self {
        Self { config }
    }

    /// Reads the dataset called `dataset_name` from the HDF5 file at `path`.
    ///
    /// # Errors
    /// Propagates any error returned by `read_dataset_hdf5`.
    pub fn read<P: AsRef<Path>>(&self, path: P, dataset_name: &str) -> Result<Dataset> {
        // The configuration is only touched here, presumably to keep the field
        // from being reported as unused; no setting influences the read.
        let _ = &self.config;
        read_dataset_hdf5(path, dataset_name)
    }
}
/// Writer that stores a [`Dataset`] under a chosen name in an HDF5 file.
///
/// The stored [`FormatConfig`] is carried for API symmetry; `write` does not
/// apply any of its settings.
#[cfg(feature = "formats")]
#[derive(Debug, Clone, Default)]
pub struct Hdf5Writer {
    config: FormatConfig,
}

#[cfg(feature = "formats")]
impl Hdf5Writer {
    /// Creates a writer with `FormatConfig::default()`.
    pub fn new() -> Self {
        Self::default()
    }

    /// Creates a writer that carries the given configuration.
    pub fn with_config(config: FormatConfig) -> Self {
        Self { config }
    }

    /// Writes `dataset` to the HDF5 file at `path` under `dataset_name`.
    ///
    /// # Errors
    /// Propagates any error returned by `write_dataset_hdf5`.
    pub fn write<P: AsRef<Path>>(
        &self,
        dataset: &Dataset,
        path: P,
        dataset_name: &str,
    ) -> Result<()> {
        // The configuration is only touched here, presumably to keep the field
        // from being reported as unused; no setting influences the write.
        let _ = &self.config;
        write_dataset_hdf5(dataset, path, dataset_name)
    }
}
/// Settings for reading and writing CSV text.
#[derive(Debug, Clone)]
pub struct CsvConfig {
    /// When `true`, the first row that is neither blank nor a comment is
    /// treated as the header row.
    pub has_header: bool,
    /// Character that separates fields.
    pub delimiter: char,
    /// Precision passed to the `{:.prec$e}` formatter when numbers are written.
    pub float_precision: usize,
}

impl Default for CsvConfig {
    /// Header row expected, comma delimiter, precision of 17.
    fn default() -> Self {
        CsvConfig {
            float_precision: 17,
            delimiter: ',',
            has_header: true,
        }
    }
}
/// Finishes the row currently being assembled and files it as header or data.
///
/// The pending `current_field` is appended to `current_row`. A row is skipped
/// when all of its fields are empty or when its first field, ignoring leading
/// whitespace, starts with `#`. Otherwise it becomes the header (only when it
/// is the first kept row and `has_header` is set) or is appended to
/// `out_rows`. Both buffers are left empty for the next row.
///
/// Returns the updated "still waiting for the first kept row" flag.
fn commit_row(
    current_field: &mut String,
    current_row: &mut Vec<String>,
    out_headers: &mut Vec<String>,
    out_rows: &mut Vec<Vec<String>>,
    has_header: bool,
    first_row: bool,
) -> bool {
    // Move the buffers out instead of cloning them and then clearing.
    current_row.push(std::mem::take(current_field));

    let is_comment = current_row
        .first()
        .map(|f| f.trim_start().starts_with('#'))
        .unwrap_or(false);
    let is_blank = current_row.iter().all(|f| f.is_empty());
    if is_blank || is_comment {
        current_row.clear();
        return first_row;
    }

    let row = std::mem::take(current_row);
    if first_row && has_header {
        *out_headers = row;
    } else {
        out_rows.push(row);
    }
    false
}

/// Splits CSV `text` into a header row and data rows.
///
/// Fields are separated by `delimiter` and rows by `\n`; `\r` outside quotes
/// is dropped. Inside double quotes every character is literal, and `""`
/// stands for one quote character. Rows whose fields are all empty and rows
/// whose first field starts with `#` are skipped. When `has_header` is `true`,
/// the first kept row is returned as the header; otherwise the header vector
/// is empty. A quote left open at the end of the text is not reported.
pub(crate) fn parse_csv_text(
    text: &str,
    has_header: bool,
    delimiter: char,
) -> (Vec<String>, Vec<Vec<String>>) {
    let mut out_headers: Vec<String> = Vec::new();
    let mut out_rows: Vec<Vec<String>> = Vec::new();
    let mut current_field = String::new();
    let mut current_row: Vec<String> = Vec::new();
    let mut in_quotes = false;
    let mut first_row = true;

    // One character of lookahead is enough to recognise the `""` escape, so
    // the input does not need to be copied into a `Vec<char>` first.
    let mut chars = text.chars().peekable();
    while let Some(ch) = chars.next() {
        if in_quotes {
            if ch != '"' {
                current_field.push(ch);
            } else if chars.peek() == Some(&'"') {
                chars.next();
                current_field.push('"');
            } else {
                in_quotes = false;
            }
        } else if ch == '"' {
            in_quotes = true;
        } else if ch == delimiter {
            current_row.push(std::mem::take(&mut current_field));
        } else if ch == '\n' {
            first_row = commit_row(
                &mut current_field,
                &mut current_row,
                &mut out_headers,
                &mut out_rows,
                has_header,
                first_row,
            );
        } else if ch != '\r' {
            current_field.push(ch);
        }
    }

    // Flush a final row that is not terminated by a newline.
    if !current_row.is_empty() || !current_field.is_empty() {
        commit_row(
            &mut current_field,
            &mut current_row,
            &mut out_headers,
            &mut out_rows,
            has_header,
            first_row,
        );
    }
    (out_headers, out_rows)
}
/// Encodes one CSV field, quoting it only when necessary.
///
/// A field containing `delimiter`, a double quote, `\n` or `\r` is wrapped in
/// double quotes with every embedded quote doubled; any other field is
/// returned unchanged.
pub(crate) fn csv_encode_field(field: &str, delimiter: char) -> String {
    let is_special = |c: char| c == delimiter || matches!(c, '"' | '\n' | '\r');
    if !field.chars().any(is_special) {
        return field.to_owned();
    }

    let mut quoted = String::with_capacity(field.len() + 2);
    quoted.push('"');
    for c in field.chars() {
        if c == '"' {
            // Escape an embedded quote by doubling it.
            quoted.push('"');
        }
        quoted.push(c);
    }
    quoted.push('"');
    quoted
}
pub(crate) fn format_csv_rows(headers: &[String], rows: &[Vec<String>], delimiter: char) -> String {
let mut out = String::new();
if !headers.is_empty() {
let encoded: Vec<String> = headers
.iter()
.map(|h| csv_encode_field(h, delimiter))
.collect();
out.push_str(&encoded.join(&delimiter.to_string()));
out.push('\n');
}
for row in rows {
let encoded: Vec<String> = row.iter().map(|f| csv_encode_field(f, delimiter)).collect();
out.push_str(&encoded.join(&delimiter.to_string()));
out.push('\n');
}
out
}
/// Reads a numeric CSV file into a [`Dataset`].
///
/// The file is parsed with `parse_csv_text` using `config.has_header` and
/// `config.delimiter`. A leading UTF-8 byte-order mark is ignored. Column
/// names come from the header row, or are generated as `feature_<i>` when no
/// header was parsed. A column named `PARQUET_TARGET_COLUMN` supplies the
/// target vector; every other column is a feature and its name is stored in
/// `featurenames`. A file with a header but no data rows yields a dataset with
/// zero rows and no target.
///
/// # Errors
/// Returns `DatasetsError::InvalidFormat` when the file cannot be read, when
/// it holds neither header nor data rows, when a row's field count differs
/// from the column count, or when a cell is not a valid `f64`.
#[cfg(feature = "formats")]
fn read_dataset_csv<P: AsRef<Path>>(path: P, config: &CsvConfig) -> Result<Dataset> {
    let raw = std::fs::read_to_string(path)
        .map_err(|e| DatasetsError::InvalidFormat(format!("CSV read error: {e}")))?;
    // Some tools prefix UTF-8 files with a byte-order mark; left in place it
    // would become part of the first header name or of the first numeric cell.
    let text = raw.strip_prefix('\u{feff}').unwrap_or(&raw);
    let (headers, rows) = parse_csv_text(text, config.has_header, config.delimiter);

    // The header decides the column count when present, else the first row.
    let n_cols = match (headers.is_empty(), rows.first()) {
        (false, _) => headers.len(),
        (true, Some(first)) => first.len(),
        (true, None) => {
            return Err(DatasetsError::InvalidFormat(
                "CSV file is empty or contains only comments".to_string(),
            ));
        }
    };
    let col_names: Vec<String> = if headers.is_empty() {
        (0..n_cols).map(|i| format!("feature_{i}")).collect()
    } else {
        headers
    };

    let target_col_idx: Option<usize> = col_names.iter().position(|n| n == PARQUET_TARGET_COLUMN);
    let feat_indices: Vec<usize> = (0..n_cols).filter(|&i| Some(i) != target_col_idx).collect();
    let feat_names: Vec<String> = feat_indices.iter().map(|&i| col_names[i].clone()).collect();
    let n_features = feat_indices.len();
    let n_rows = rows.len();

    // Feature values are collected row by row (row-major order).
    let mut flat: Vec<f64> = Vec::with_capacity(n_rows * n_features);
    let mut target_vals: Vec<f64> = Vec::with_capacity(n_rows);
    for (row_idx, row) in rows.iter().enumerate() {
        if row.len() != n_cols {
            return Err(DatasetsError::InvalidFormat(format!(
                "CSV row {} has {} fields, expected {}",
                row_idx + 1,
                row.len(),
                n_cols
            )));
        }
        for &col_idx in &feat_indices {
            let s = row[col_idx].trim();
            let v = s.parse::<f64>().map_err(|e| {
                DatasetsError::InvalidFormat(format!(
                    "CSV cell at row {}, col {} is not numeric ('{s}'): {e}",
                    row_idx + 1,
                    col_idx
                ))
            })?;
            flat.push(v);
        }
        if let Some(t_idx) = target_col_idx {
            let s = row[t_idx].trim();
            let v = s.parse::<f64>().map_err(|e| {
                DatasetsError::InvalidFormat(format!(
                    "CSV target cell at row {} is not numeric ('{s}'): {e}",
                    row_idx + 1,
                ))
            })?;
            target_vals.push(v);
        }
    }

    let data = if n_rows == 0 {
        Array2::zeros((0, n_features))
    } else {
        Array2::from_shape_vec((n_rows, n_features), flat).map_err(|e| {
            DatasetsError::InvalidFormat(format!("Failed to shape CSV feature matrix: {e}"))
        })?
    };
    // A target column without any data rows yields no target at all.
    let target: Option<Array1<f64>> = if target_col_idx.is_some() && !target_vals.is_empty() {
        Some(Array1::from_vec(target_vals))
    } else {
        None
    };

    let mut ds = Dataset::new(data, target);
    ds.featurenames = Some(feat_names);
    Ok(ds)
}
/// Writes `dataset` to a CSV file at `path`.
///
/// The header row holds the names from `feature_column_names`, followed by
/// `PARQUET_TARGET_COLUMN` when the dataset has a target. Every number is
/// formatted in scientific notation with `config.float_precision` as the
/// precision, and fields are joined with `config.delimiter`.
///
/// # Errors
/// Returns `DatasetsError::InvalidFormat` when the file cannot be created,
/// written, or flushed.
///
/// NOTE(review): the target is indexed by row, so a target shorter than
/// `n_samples()` would panic; confirm that `Dataset` rules this out.
#[cfg(feature = "formats")]
fn write_dataset_csv<P: AsRef<Path>>(dataset: &Dataset, path: P, config: &CsvConfig) -> Result<()> {
    use std::fs::File;
    use std::io::Write;

    let digits = config.float_precision;
    let render = |value: f64| format!("{:.prec$e}", value, prec = digits);

    let mut headers: Vec<String> = feature_column_names(dataset);
    if dataset.target.is_some() {
        headers.push(PARQUET_TARGET_COLUMN.to_string());
    }

    let n_feats = dataset.n_features();
    let data_rows: Vec<Vec<String>> = (0..dataset.n_samples())
        .map(|row_idx| {
            let mut cells: Vec<String> = (0..n_feats)
                .map(|col_idx| render(dataset.data[[row_idx, col_idx]]))
                .collect();
            if let Some(target) = &dataset.target {
                cells.push(render(target[row_idx]));
            }
            cells
        })
        .collect();

    let csv_text = format_csv_rows(&headers, &data_rows, config.delimiter);
    let mut file = File::create(path)
        .map_err(|e| DatasetsError::InvalidFormat(format!("CSV create error: {e}")))?;
    file.write_all(csv_text.as_bytes())
        .map_err(|e| DatasetsError::InvalidFormat(format!("CSV write error: {e}")))?;
    file.flush()
        .map_err(|e| DatasetsError::InvalidFormat(format!("CSV flush error: {e}")))
}
/// Reader that loads a numeric CSV file into a [`Dataset`].
#[cfg(feature = "formats")]
#[derive(Debug, Clone, Default)]
pub struct CsvReader {
    config: CsvConfig,
}

#[cfg(feature = "formats")]
impl CsvReader {
    /// Creates a reader with `CsvConfig::default()`.
    pub fn new() -> Self {
        Self::default()
    }

    /// Creates a reader that uses the given CSV settings.
    pub fn with_config(config: CsvConfig) -> Self {
        Self { config }
    }

    /// Reads the CSV file at `path` using this reader's settings.
    ///
    /// # Errors
    /// Propagates any error returned by `read_dataset_csv`.
    pub fn read<P: AsRef<Path>>(&self, path: P) -> Result<Dataset> {
        read_dataset_csv(path, &self.config)
    }
}
/// Writer that stores a [`Dataset`] as a CSV file.
#[cfg(feature = "formats")]
#[derive(Debug, Clone, Default)]
pub struct CsvWriter {
    config: CsvConfig,
}

#[cfg(feature = "formats")]
impl CsvWriter {
    /// Creates a writer with `CsvConfig::default()`.
    pub fn new() -> Self {
        Self::default()
    }

    /// Creates a writer that uses the given CSV settings.
    pub fn with_config(config: CsvConfig) -> Self {
        Self { config }
    }

    /// Writes `dataset` to the CSV file at `path` using this writer's settings.
    ///
    /// # Errors
    /// Propagates any error returned by `write_dataset_csv`.
    pub fn write<P: AsRef<Path>>(&self, dataset: &Dataset, path: P) -> Result<()> {
        write_dataset_csv(dataset, path, &self.config)
    }
}
/// Converts datasets between the supported file formats and reads files whose
/// format is detected from the path suffix.
///
/// The converter's [`FormatConfig`] is handed to the Parquet and HDF5 readers
/// and writers it creates. CSV I/O always uses `CsvConfig::default()`, and
/// HDF5 data is always read from and written to the dataset named `data`.
#[cfg(feature = "formats")]
#[derive(Debug, Clone, Default)]
pub struct FormatConverter {
    config: FormatConfig,
}

#[cfg(feature = "formats")]
impl FormatConverter {
    /// Creates a converter with `FormatConfig::default()`.
    pub fn new() -> Self {
        Self::default()
    }

    /// Creates a converter that forwards `config` to the Parquet and HDF5
    /// readers and writers it creates.
    pub fn with_config(config: FormatConfig) -> Self {
        Self { config }
    }

    /// Reads `input_path` as `input_format` and writes the resulting dataset
    /// to `output_path` as `output_format`.
    ///
    /// # Errors
    /// Returns `DatasetsError::InvalidFormat` when either format is
    /// `FormatType::Arrow`, and propagates any reader or writer error.
    pub fn convert<P1: AsRef<Path>, P2: AsRef<Path>>(
        &self,
        input_path: P1,
        input_format: FormatType,
        output_path: P2,
        output_format: FormatType,
    ) -> Result<()> {
        let dataset = match input_format {
            FormatType::Parquet => {
                ParquetReader::with_config(self.config.clone()).read(input_path)?
            }
            FormatType::Hdf5 => {
                Hdf5Reader::with_config(self.config.clone()).read(input_path, "data")?
            }
            FormatType::Csv => CsvReader::new().read(input_path)?,
            FormatType::Arrow => {
                return Err(DatasetsError::InvalidFormat(
                    "Arrow format not yet supported".to_string(),
                ))
            }
        };
        match output_format {
            FormatType::Parquet => {
                ParquetWriter::with_config(self.config.clone()).write(&dataset, output_path)?
            }
            FormatType::Hdf5 => {
                Hdf5Writer::with_config(self.config.clone()).write(&dataset, output_path, "data")?
            }
            FormatType::Csv => CsvWriter::new().write(&dataset, output_path)?,
            FormatType::Arrow => {
                return Err(DatasetsError::InvalidFormat(
                    "Arrow format not yet supported".to_string(),
                ))
            }
        }
        Ok(())
    }

    /// Reads the file at `path`, choosing the reader from the path suffix.
    ///
    /// # Errors
    /// Returns `DatasetsError::InvalidFormat` when the suffix is not
    /// recognised or names the Arrow format, and propagates any reader error.
    pub fn read_auto<P: AsRef<Path>>(&self, path: P) -> Result<Dataset> {
        // Only the suffix matters for detection, so a lossy conversion lets
        // paths that are not valid UTF-8 be handled instead of rejected.
        let format = FormatType::from_extension(&path.as_ref().to_string_lossy())
            .ok_or_else(|| DatasetsError::InvalidFormat("Could not detect format".to_string()))?;
        match format {
            FormatType::Parquet => ParquetReader::with_config(self.config.clone()).read(path),
            FormatType::Hdf5 => Hdf5Reader::with_config(self.config.clone()).read(path, "data"),
            FormatType::Csv => CsvReader::new().read(path),
            FormatType::Arrow => Err(DatasetsError::InvalidFormat(format!(
                "Unsupported format: {:?}",
                format
            ))),
        }
    }
}
/// Reads the Parquet file at `path` into a [`Dataset`] using a
/// default-configured [`ParquetReader`].
#[cfg(feature = "formats")]
pub fn read_parquet<P: AsRef<Path>>(path: P) -> Result<Dataset> {
ParquetReader::new().read(path)
}
/// Writes `dataset` to a Parquet file at `path` using a default-configured
/// [`ParquetWriter`].
#[cfg(feature = "formats")]
pub fn write_parquet<P: AsRef<Path>>(dataset: &Dataset, path: P) -> Result<()> {
ParquetWriter::new().write(dataset, path)
}
/// Reads the dataset `dataset_name` from the HDF5 file at `path` using a
/// default-configured [`Hdf5Reader`].
#[cfg(feature = "formats")]
pub fn read_hdf5<P: AsRef<Path>>(path: P, dataset_name: &str) -> Result<Dataset> {
Hdf5Reader::new().read(path, dataset_name)
}
/// Writes `dataset` under `dataset_name` to the HDF5 file at `path` using a
/// default-configured [`Hdf5Writer`].
#[cfg(feature = "formats")]
pub fn write_hdf5<P: AsRef<Path>>(dataset: &Dataset, path: P, dataset_name: &str) -> Result<()> {
Hdf5Writer::new().write(dataset, path, dataset_name)
}
/// Reads the file at `path`, detecting its format from the path suffix via a
/// default [`FormatConverter`].
#[cfg(feature = "formats")]
pub fn read_auto<P: AsRef<Path>>(path: P) -> Result<Dataset> {
FormatConverter::new().read_auto(path)
}
/// Reads the CSV file at `path` into a [`Dataset`] using a default-configured
/// [`CsvReader`].
#[cfg(feature = "formats")]
pub fn read_csv<P: AsRef<Path>>(path: P) -> Result<Dataset> {
CsvReader::new().read(path)
}
/// Writes `dataset` to a CSV file at `path` using a default-configured
/// [`CsvWriter`].
#[cfg(feature = "formats")]
pub fn write_csv<P: AsRef<Path>>(dataset: &Dataset, path: P) -> Result<()> {
CsvWriter::new().write(dataset, path)
}
// Unit tests for format detection, configuration defaults, the CSV text
// helpers, and (behind the `formats` feature) file round-trips through the
// system temporary directory.
#[cfg(test)]
mod tests {
use super::*;
// Known suffixes map to their format; an unknown suffix yields `None`.
#[test]
fn test_format_detection() {
assert_eq!(
FormatType::from_extension("data.parquet"),
Some(FormatType::Parquet)
);
assert_eq!(
FormatType::from_extension("data.h5"),
Some(FormatType::Hdf5)
);
assert_eq!(
FormatType::from_extension("data.csv"),
Some(FormatType::Csv)
);
assert_eq!(FormatType::from_extension("data.txt"), None);
}
// Each format reports its canonical extension.
#[test]
fn test_format_extension() {
assert_eq!(FormatType::Parquet.extension(), "parquet");
assert_eq!(FormatType::Hdf5.extension(), "h5");
assert_eq!(FormatType::Csv.extension(), "csv");
}
// Only Gzip and Zstd carry a default compression level.
#[test]
fn test_compression_codec() {
assert_eq!(CompressionCodec::None.level(), None);
assert_eq!(CompressionCodec::Snappy.level(), None);
assert_eq!(CompressionCodec::Gzip.level(), Some(6));
assert_eq!(CompressionCodec::Zstd.level(), Some(3));
}
// Default configuration values.
#[test]
fn test_format_config() {
let config = FormatConfig::default();
assert_eq!(config.chunk_size, 10_000);
assert_eq!(config.compression, Some(CompressionCodec::Snappy));
assert!(config.use_mmap);
}
// Parquet round-trip of a 4x3 matrix without a target.
#[cfg(feature = "formats")]
#[test]
fn test_parquet_roundtrip_no_target() {
use scirs2_core::ndarray::Array2;
let data = Array2::from_shape_vec(
(4, 3),
vec![
1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0,
],
)
.expect("shape");
let ds = Dataset::new(data.clone(), None);
let mut tmp = std::env::temp_dir();
tmp.push("scirs2_test_parquet_roundtrip_no_target.parquet");
write_parquet(&ds, &tmp).expect("parquet write");
let recovered = read_parquet(&tmp).expect("parquet read");
assert_eq!(recovered.n_samples(), 4, "n_samples mismatch");
assert_eq!(recovered.n_features(), 3, "n_features mismatch");
assert!(recovered.target.is_none(), "unexpected target");
for row in 0..4 {
for col in 0..3 {
let expected = data[[row, col]];
let actual = recovered.data[[row, col]];
assert!(
(expected - actual).abs() < 1e-10,
"mismatch at [{row},{col}]: expected {expected}, got {actual}"
);
}
}
// Best-effort cleanup; a failure to remove the file is ignored.
let _ = std::fs::remove_file(&tmp);
}
// Parquet round-trip that also preserves the target vector.
#[cfg(feature = "formats")]
#[test]
fn test_parquet_roundtrip_with_target() {
use scirs2_core::ndarray::{Array1, Array2};
let data =
Array2::from_shape_vec((3, 2), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).expect("shape");
let target = Some(Array1::from_vec(vec![0.0, 1.0, 0.0]));
let ds = Dataset::new(data.clone(), target.clone());
let mut tmp = std::env::temp_dir();
tmp.push("scirs2_test_parquet_roundtrip_with_target.parquet");
write_parquet(&ds, &tmp).expect("parquet write");
let recovered = read_parquet(&tmp).expect("parquet read");
assert_eq!(recovered.n_samples(), 3);
assert_eq!(recovered.n_features(), 2);
assert!(
recovered.target.is_some(),
"target missing after round-trip"
);
let rtarget = recovered.target.as_ref().expect("target");
assert_eq!(rtarget.len(), 3);
for (i, (&expected, &actual)) in target
.as_ref()
.expect("t")
.iter()
.zip(rtarget.iter())
.enumerate()
{
assert!(
(expected - actual).abs() < 1e-10,
"target mismatch at [{i}]: expected {expected}, got {actual}"
);
}
let _ = std::fs::remove_file(&tmp);
}
// Custom feature names survive a Parquet round-trip.
#[cfg(feature = "formats")]
#[test]
fn test_parquet_roundtrip_feature_names() {
use scirs2_core::ndarray::Array2;
let data = Array2::from_shape_vec((2, 2), vec![10.0, 20.0, 30.0, 40.0]).expect("shape");
let mut ds = Dataset::new(data, None);
ds.featurenames = Some(vec!["alpha".to_string(), "beta".to_string()]);
let mut tmp = std::env::temp_dir();
tmp.push("scirs2_test_parquet_feature_names.parquet");
write_parquet(&ds, &tmp).expect("parquet write");
let recovered = read_parquet(&tmp).expect("parquet read");
let names = recovered.featurenames.as_ref().expect("featurenames");
assert_eq!(names, &["alpha", "beta"]);
let _ = std::fs::remove_file(&tmp);
}
// HDF5 round-trip of a 3x4 matrix without a target.
#[cfg(feature = "formats")]
#[test]
fn test_hdf5_roundtrip_no_target() {
use scirs2_core::ndarray::Array2;
let data = Array2::from_shape_vec(
(3, 4),
vec![
1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0,
],
)
.expect("shape");
let ds = Dataset::new(data.clone(), None);
let mut tmp = std::env::temp_dir();
tmp.push("scirs2_test_hdf5_roundtrip_no_target.h5");
write_hdf5(&ds, &tmp, "mydata").expect("hdf5 write");
let recovered = read_hdf5(&tmp, "mydata").expect("hdf5 read");
assert_eq!(recovered.n_samples(), 3, "n_samples mismatch");
assert_eq!(recovered.n_features(), 4, "n_features mismatch");
assert!(recovered.target.is_none());
for row in 0..3 {
for col in 0..4 {
let expected = data[[row, col]];
let actual = recovered.data[[row, col]];
assert!(
(expected - actual).abs() < 1e-10,
"mismatch [{row},{col}]: {expected} != {actual}"
);
}
}
let _ = std::fs::remove_file(&tmp);
// Also remove a `<file>.json` sidecar, presumably produced by the HDF5 writer.
let sidecar = format!("{}.json", tmp.to_string_lossy());
let _ = std::fs::remove_file(&sidecar);
}
// Plain CSV text splits into a header row and three data rows.
#[test]
fn test_csv_parse_roundtrip_strings() {
let text = "a,b,c\n1,2,3\n4,5,6\n7,8,9\n";
let (headers, rows) = parse_csv_text(text, true, ',');
assert_eq!(headers, ["a", "b", "c"]);
assert_eq!(rows.len(), 3);
assert_eq!(rows[0], ["1", "2", "3"]);
assert_eq!(rows[2], ["7", "8", "9"]);
}
// A field containing the delimiter is quoted on output and restored on input.
#[test]
fn test_csv_quoted_field_with_comma() {
let original_field = "hello, world";
let headers = vec!["phrase".to_string(), "num".to_string()];
let rows = vec![vec![original_field.to_string(), "42".to_string()]];
let csv_text = format_csv_rows(&headers, &rows, ',');
assert!(csv_text.contains('"'), "expected quoting in: {csv_text:?}");
let (h2, r2) = parse_csv_text(&csv_text, true, ',');
assert_eq!(h2, ["phrase", "num"]);
assert_eq!(r2.len(), 1);
assert_eq!(r2[0][0], original_field, "quoted field not preserved");
assert_eq!(r2[0][1], "42");
}
// Comment lines and blank lines are skipped, before and after the header.
#[test]
fn test_csv_skip_comments_and_blanks() {
let text = "# file-level comment\n\nx,y\n# row comment\n1,2\n\n3,4\n";
let (headers, rows) = parse_csv_text(text, true, ',');
assert_eq!(headers, ["x", "y"]);
assert_eq!(rows.len(), 2);
assert_eq!(rows[0], ["1", "2"]);
assert_eq!(rows[1], ["3", "4"]);
}
// CSV file round-trip of a 5x3 matrix without a target.
#[cfg(feature = "formats")]
#[test]
fn test_csv_dataset_roundtrip_3col_5row() {
use scirs2_core::ndarray::Array2;
let vals: Vec<f64> = (0..15).map(|x| x as f64 * 1.1).collect();
let data = Array2::from_shape_vec((5, 3), vals).expect("shape");
let ds = Dataset::new(data.clone(), None);
let mut tmp = std::env::temp_dir();
tmp.push("scirs2_test_csv_roundtrip_3col_5row.csv");
write_csv(&ds, &tmp).expect("csv write");
let recovered = read_csv(&tmp).expect("csv read");
assert_eq!(recovered.n_samples(), 5, "n_samples mismatch");
assert_eq!(recovered.n_features(), 3, "n_features mismatch");
assert!(recovered.target.is_none());
for row in 0..5 {
for col in 0..3 {
let expected = data[[row, col]];
let actual = recovered.data[[row, col]];
assert!(
(expected - actual).abs() < 1e-10,
"mismatch at [{row},{col}]: expected {expected}, got {actual}"
);
}
}
let _ = std::fs::remove_file(&tmp);
}
// CSV file round-trip that also preserves the target vector.
#[cfg(feature = "formats")]
#[test]
fn test_csv_dataset_roundtrip_with_target() {
use scirs2_core::ndarray::{Array1, Array2};
let data =
Array2::from_shape_vec((3, 2), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).expect("shape");
let target = Some(Array1::from_vec(vec![0.0, 1.0, 0.0]));
let ds = Dataset::new(data.clone(), target.clone());
let mut tmp = std::env::temp_dir();
tmp.push("scirs2_test_csv_roundtrip_with_target.csv");
write_csv(&ds, &tmp).expect("csv write");
let recovered = read_csv(&tmp).expect("csv read");
assert_eq!(recovered.n_samples(), 3);
assert_eq!(recovered.n_features(), 2);
assert!(
recovered.target.is_some(),
"target missing after CSV round-trip"
);
let rtarget = recovered.target.as_ref().expect("target");
assert_eq!(rtarget.len(), 3);
for (i, (&e, &a)) in target
.as_ref()
.expect("target")
.iter()
.zip(rtarget.iter())
.enumerate()
{
assert!(
(e - a).abs() < 1e-10,
"target mismatch at [{i}]: expected {e}, got {a}"
);
}
let _ = std::fs::remove_file(&tmp);
}
// Values without a short decimal form survive a CSV round-trip within 1e-10.
#[cfg(feature = "formats")]
#[test]
fn test_csv_float_precision() {
use scirs2_core::ndarray::Array2;
let vals = vec![
std::f64::consts::PI,
std::f64::consts::E,
1.0 / 3.0,
1.0 / 7.0,
std::f64::consts::SQRT_2,
0.1 + 0.2, ];
let data = Array2::from_shape_vec((2, 3), vals.clone()).expect("shape");
let ds = Dataset::new(data, None);
let mut tmp = std::env::temp_dir();
tmp.push("scirs2_test_csv_float_precision.csv");
write_csv(&ds, &tmp).expect("csv write");
let recovered = read_csv(&tmp).expect("csv read");
assert_eq!(recovered.n_samples(), 2);
assert_eq!(recovered.n_features(), 3);
for (idx, &original) in vals.iter().enumerate() {
// Map the flat index back to (row, col) of the 2x3 matrix.
let row = idx / 3;
let col = idx % 3;
let got = recovered.data[[row, col]];
assert!(
(original - got).abs() < 1e-10,
"float precision failure at [{row},{col}]: original={original}, got={got}"
);
}
let _ = std::fs::remove_file(&tmp);
}
// A header-only CSV file yields zero rows, three features, and no target.
#[cfg(feature = "formats")]
#[test]
fn test_csv_header_only_no_panic() {
let mut tmp = std::env::temp_dir();
tmp.push("scirs2_test_csv_header_only.csv");
std::fs::write(&tmp, "feature_0,feature_1,feature_2\n").expect("write header-only csv");
let recovered = read_csv(&tmp).expect("csv read must not fail on header-only input");
assert_eq!(recovered.n_samples(), 0, "expected zero rows");
assert_eq!(recovered.n_features(), 3, "expected 3 feature columns");
assert!(recovered.target.is_none());
let _ = std::fs::remove_file(&tmp);
}
// HDF5 round-trip that also preserves the target vector.
#[cfg(feature = "formats")]
#[test]
fn test_hdf5_roundtrip_with_target() {
use scirs2_core::ndarray::{Array1, Array2};
let data =
Array2::from_shape_vec((2, 3), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).expect("shape");
let target = Some(Array1::from_vec(vec![1.0, 0.0]));
let ds = Dataset::new(data.clone(), target.clone());
let mut tmp = std::env::temp_dir();
tmp.push("scirs2_test_hdf5_roundtrip_with_target.h5");
write_hdf5(&ds, &tmp, "experiment").expect("hdf5 write");
let recovered = read_hdf5(&tmp, "experiment").expect("hdf5 read");
assert_eq!(recovered.n_samples(), 2);
assert_eq!(recovered.n_features(), 3);
assert!(recovered.target.is_some());
let rtarget = recovered.target.as_ref().expect("target");
assert_eq!(rtarget.len(), 2);
for (i, (&e, &a)) in target
.as_ref()
.expect("t")
.iter()
.zip(rtarget.iter())
.enumerate()
{
assert!((e - a).abs() < 1e-10, "target mismatch [{i}]: {e} != {a}");
}
let _ = std::fs::remove_file(&tmp);
// Also remove a `<file>.json` sidecar, presumably produced by the HDF5 writer.
let sidecar = format!("{}.json", tmp.to_string_lossy());
let _ = std::fs::remove_file(&sidecar);
}
}