use crate::Dataset;
use std::collections::HashMap;
use std::fs;
use std::path::{Path, PathBuf};
use std::sync::Arc;
use tenflowers_core::{Result, Tensor, TensorError};
#[cfg(feature = "serialize")]
use serde::{Deserialize, Serialize};
#[cfg(feature = "compression")]
use oxiarc_archive::GzipReader;
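/// Storage backend for a Zarr store. Only available with the `cloud`
/// feature; `Local` reads from the filesystem, while the remaining variants
/// identify remote object stores.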
#[cfg(feature = "cloud")]
#[derive(Debug, Clone, PartialEq)]
#[cfg_attr(feature = "serialize", derive(Serialize, Deserialize))]
pub enum CloudBackend {
    Local,
    S3 {
        bucket: String,
        region: String,
        endpoint: Option<String>,
    },
    Gcs { bucket: String },
    Azure { account: String, container: String },
}
#[cfg(feature = "cloud")]
impl Default for CloudBackend {
fn default() -> Self {
Self::Local
}
}
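/// Compression codec requested for chunk data. `None` is always available;
/// the remaining codecs require the `compression` feature.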
#[derive(Debug, Clone, PartialEq)]
#[cfg_attr(feature = "serialize", derive(Serialize, Deserialize))]
pub enum ZarrCompressionType {
None,
#[cfg(feature = "compression")]
Gzip,
#[cfg(feature = "compression")]
Blosc,
#[cfg(feature = "compression")]
Lz4,
#[cfg(feature = "compression")]
Zstd,
}
impl Default for ZarrCompressionType {
fn default() -> Self {
Self::None
}
}
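/// Configuration for opening a Zarr store: paths to the feature array (and
/// optional labels array), chunk caching and parallelism limits, and I/O
/// behavior such as memory mapping, async access, timeouts, and retries.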
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serialize", derive(Serialize, Deserialize))]
pub struct ZarrConfig {
pub array_path: PathBuf,
pub labels_path: Option<PathBuf>,
pub chunk_cache_size: usize,
pub lazy_loading: bool,
pub max_parallel_chunks: usize,
pub use_memory_mapping: bool,
pub dimension_order: Option<String>,
#[cfg(feature = "cloud")]
pub cloud_backend: CloudBackend,
pub compression: ZarrCompressionType,
pub async_io: bool,
pub connection_timeout: u64,
pub retry_attempts: usize,
}
impl Default for ZarrConfig {
fn default() -> Self {
Self {
array_path: PathBuf::new(),
labels_path: None,
            chunk_cache_size: 100_000_000,
            lazy_loading: true,
max_parallel_chunks: 4,
use_memory_mapping: true,
dimension_order: None,
#[cfg(feature = "cloud")]
cloud_backend: CloudBackend::default(),
compression: ZarrCompressionType::default(),
async_io: false,
connection_timeout: 30,
retry_attempts: 3,
}
}
}
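/// Metadata parsed from a `.zarray` file: array shape, NumPy-style dtype
/// string, chunk shape, compressor id, fill value, memory order, and the
/// Zarr format version.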
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serialize", derive(Serialize, Deserialize))]
pub struct ZarrArrayInfo {
pub shape: Vec<usize>,
pub dtype: String,
pub chunks: Vec<usize>,
pub compressor: Option<String>,
pub fill_value: Option<f64>,
    pub order: String,
    pub zarr_format: u32,
}
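/// Fluent builder for [`ZarrDataset`]: each method overrides one field of
/// [`ZarrConfig`], and [`ZarrDatasetBuilder::build`] opens the store.
///
/// A minimal sketch of typical use, assuming a Zarr v2 array at a
/// hypothetical local path and an illustrative crate path for the import:
///
/// ```no_run
/// # use tenflowers_data::ZarrDatasetBuilder; // hypothetical import path
/// # fn main() -> Result<(), tenflowers_core::TensorError> {
/// let dataset = ZarrDatasetBuilder::<f32>::new()
///     .array_path("/data/images.zarr")
///     .chunk_cache_size(50_000_000)
///     .max_parallel_chunks(8)
///     .build()?;
/// # Ok(())
/// # }
/// ```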
#[derive(Debug, Clone)]
pub struct ZarrDatasetBuilder<T>
where
T: Clone
+ Default
+ scirs2_core::numeric::Zero
+ scirs2_core::numeric::One
+ std::str::FromStr
+ Send
+ Sync
+ 'static,
{
config: ZarrConfig,
_phantom: std::marker::PhantomData<T>,
}
impl<T> ZarrDatasetBuilder<T>
where
T: Clone
+ Default
+ scirs2_core::numeric::Zero
+ scirs2_core::numeric::One
+ std::str::FromStr
+ Send
+ Sync
+ 'static
+ scirs2_core::num_traits::cast::NumCast,
{
pub fn new() -> Self {
Self {
config: ZarrConfig::default(),
_phantom: std::marker::PhantomData,
}
}
pub fn array_path<P: AsRef<Path>>(mut self, path: P) -> Self {
self.config.array_path = path.as_ref().to_path_buf();
self
}
pub fn labels_path<P: AsRef<Path>>(mut self, path: P) -> Self {
self.config.labels_path = Some(path.as_ref().to_path_buf());
self
}
pub fn chunk_cache_size(mut self, size: usize) -> Self {
self.config.chunk_cache_size = size;
self
}
pub fn lazy_loading(mut self, enabled: bool) -> Self {
self.config.lazy_loading = enabled;
self
}
pub fn max_parallel_chunks(mut self, count: usize) -> Self {
self.config.max_parallel_chunks = count;
self
}
pub fn use_memory_mapping(mut self, enabled: bool) -> Self {
self.config.use_memory_mapping = enabled;
self
}
pub fn dimension_order<S: AsRef<str>>(mut self, order: S) -> Self {
self.config.dimension_order = Some(order.as_ref().to_string());
self
}
#[cfg(feature = "cloud")]
pub fn cloud_backend(mut self, backend: CloudBackend) -> Self {
self.config.cloud_backend = backend;
self
}
pub fn compression(mut self, compression: ZarrCompressionType) -> Self {
self.config.compression = compression;
self
}
pub fn async_io(mut self, enabled: bool) -> Self {
self.config.async_io = enabled;
self
}
pub fn connection_timeout(mut self, timeout: u64) -> Self {
self.config.connection_timeout = timeout;
self
}
pub fn retry_attempts(mut self, attempts: usize) -> Self {
self.config.retry_attempts = attempts;
self
}
pub fn build(self) -> Result<ZarrDataset<T>> {
ZarrDataset::from_config(self.config)
}
}
impl<T> Default for ZarrDatasetBuilder<T>
where
T: Clone
+ Default
+ scirs2_core::numeric::Zero
+ scirs2_core::numeric::One
+ std::str::FromStr
+ Send
+ Sync
+ 'static
+ scirs2_core::num_traits::cast::NumCast,
{
fn default() -> Self {
Self::new()
}
}
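/// A [`Dataset`] backed by a chunked Zarr v2 array on disk. Chunks are read
/// lazily and memoized in an in-memory cache that is shared across clones
/// of the dataset.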
#[derive(Debug, Clone)]
pub struct ZarrDataset<T>
where
T: Clone + Default + scirs2_core::numeric::Zero + Send + Sync + 'static,
{
config: ZarrConfig,
array_info: ZarrArrayInfo,
labels_info: Option<ZarrArrayInfo>,
chunk_cache: Arc<std::sync::Mutex<HashMap<Vec<usize>, Vec<T>>>>,
_phantom: std::marker::PhantomData<T>,
}
impl<T> ZarrDataset<T>
where
T: Clone
+ Default
+ scirs2_core::numeric::Zero
+ scirs2_core::numeric::One
+ std::str::FromStr
+ Send
+ Sync
+ 'static
+ scirs2_core::num_traits::cast::NumCast,
{
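    /// Opens the Zarr store described by `config`, reading `.zarray`
    /// metadata for the feature array and, if configured, the labels array.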
pub fn from_config(config: ZarrConfig) -> Result<Self> {
let array_info = Self::read_array_metadata(&config.array_path)?;
let labels_info = if let Some(ref labels_path) = config.labels_path {
Some(Self::read_array_metadata(labels_path)?)
} else {
None
};
let chunk_cache = Arc::new(std::sync::Mutex::new(HashMap::new()));
Ok(Self {
config,
array_info,
labels_info,
chunk_cache,
_phantom: std::marker::PhantomData,
})
}
fn read_array_metadata(array_path: &Path) -> Result<ZarrArrayInfo> {
let zarray_path = array_path.join(".zarray");
if !zarray_path.exists() {
return Err(TensorError::invalid_argument(format!(
"Zarr metadata file not found: {zarray_path:?}"
)));
}
let metadata_content = fs::read_to_string(&zarray_path).map_err(|e| {
TensorError::invalid_argument(format!("Failed to read Zarr metadata: {e}"))
})?;
Self::parse_zarr_metadata(&metadata_content)
}
    /// Extracts a `usize` array value (e.g. `"shape": [1000, 224]`) from raw
    /// `.zarray` JSON, falling back to `default` when the key is missing or
    /// malformed.
    fn parse_usize_array(content: &str, key: &str, default: Vec<usize>) -> Vec<usize> {
        let Some(start) = content.find(key) else {
            return default;
        };
        let Some(array_start) = content[start..].find('[') else {
            return default;
        };
        let Some(array_end) = content[start + array_start..].find(']') else {
            return default;
        };
        let array_content = &content[start + array_start + 1..start + array_start + array_end];
        let numbers: Vec<usize> = array_content
            .split(',')
            .filter_map(|s| s.trim().parse().ok())
            .collect();
        if numbers.is_empty() {
            default
        } else {
            numbers
        }
    }
    /// Extracts a quoted string value (e.g. `"dtype": "<f4"`) from raw
    /// `.zarray` JSON.
    fn parse_quoted_value(content: &str, key: &str) -> Option<String> {
        let start = content.find(key)?;
        let colon = start + content[start..].find(':')?;
        let quote_start = colon + content[colon..].find('"')? + 1;
        let quote_end = quote_start + content[quote_start..].find('"')?;
        Some(content[quote_start..quote_end].to_string())
    }
    fn parse_zarr_metadata(content: &str) -> Result<ZarrArrayInfo> {
        // Lightweight, dependency-free extraction of the handful of `.zarray`
        // keys we need; missing or malformed keys fall back to representative
        // defaults instead of failing hard.
        let shape = Self::parse_usize_array(content, "\"shape\"", vec![1000, 224, 224, 3]);
        let dtype =
            Self::parse_quoted_value(content, "\"dtype\"").unwrap_or_else(|| "<f4".to_string());
        let chunks = Self::parse_usize_array(content, "\"chunks\"", vec![100, 224, 224, 3]);
let compressor = if content.contains("\"compressor\"") {
if content.contains("\"blosc\"") {
Some("blosc".to_string())
} else if content.contains("\"gzip\"") {
Some("gzip".to_string())
} else if content.contains("\"lz4\"") {
Some("lz4".to_string())
} else if content.contains("\"zstd\"") {
Some("zstd".to_string())
} else {
None
}
} else {
None
};
        let fill_value = match content.find("\"fill_value\"") {
            Some(start) => match content[start..].find(':') {
                Some(colon) => {
                    let after_colon = &content[start + colon + 1..];
                    let end_pos = after_colon
                        .find(',')
                        .or_else(|| after_colon.find('}'))
                        .unwrap_or(after_colon.len());
                    let value_str = after_colon[..end_pos].trim();
                    if value_str == "null" {
                        None
                    } else {
                        value_str.parse::<f64>().ok()
                    }
                }
                None => Some(0.0),
            },
            None => None,
        };
        // Note: these checks are sensitive to the canonical `"key": value`
        // spacing produced by most Zarr writers.
        let order = if content.contains("\"order\": \"F\"") {
            "F".to_string()
        } else {
            "C".to_string()
        };
        let zarr_format = if content.contains("\"zarr_format\": 3") {
            3
        } else {
            2
        };
Ok(ZarrArrayInfo {
shape,
dtype,
chunks,
compressor,
fill_value,
order,
zarr_format,
})
}
pub fn array_info(&self) -> &ZarrArrayInfo {
&self.array_info
}
pub fn labels_info(&self) -> Option<&ZarrArrayInfo> {
self.labels_info.as_ref()
}
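    /// Returns the decoded data for the chunk at `chunk_coords`, serving it
    /// from the in-memory cache when possible and reading (and caching) it
    /// from disk otherwise.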
pub fn load_chunk(&self, chunk_coords: &[usize]) -> Result<Vec<T>> {
{
let cache = self
.chunk_cache
.lock()
.expect("lock should not be poisoned");
if let Some(cached_data) = cache.get(chunk_coords) {
return Ok(cached_data.clone());
}
}
let chunk_data = self.load_chunk_from_disk(chunk_coords)?;
{
let mut cache = self
.chunk_cache
.lock()
.expect("lock should not be poisoned");
cache.insert(chunk_coords.to_vec(), chunk_data.clone());
}
Ok(chunk_data)
}
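    /// Reads one chunk file from disk. Per the Zarr v2 layout, a chunk's
    /// file name is its coordinates joined with dots (e.g. `0.0.0.0`).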
fn load_chunk_from_disk(&self, chunk_coords: &[usize]) -> Result<Vec<T>> {
let chunk_name = chunk_coords
.iter()
.map(|&coord| coord.to_string())
.collect::<Vec<_>>()
.join(".");
let chunk_path = self.config.array_path.join(chunk_name);
if !chunk_path.exists() {
return Err(TensorError::invalid_argument(format!(
"Chunk file not found: {chunk_path:?}"
)));
}
let chunk_bytes = fs::read(&chunk_path)
.map_err(|e| TensorError::invalid_argument(format!("Failed to read chunk: {e}")))?;
let decompressed_data = self.decompress_chunk_data(&chunk_bytes)?;
self.bytes_to_typed_data(&decompressed_data)
}
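    /// Dispatches on the NumPy-style dtype string (whose `<`/`>` prefix
    /// encodes byte order) to decode raw chunk bytes into `T`.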
    fn bytes_to_typed_data(&self, bytes: &[u8]) -> Result<Vec<T>> {
        match self.array_info.dtype.as_str() {
            "<f4" | ">f4" => self.parse_float32_data(bytes),
            "<f8" | ">f8" => self.parse_float64_data(bytes),
            "<i4" | ">i4" => self.parse_int32_data(bytes),
            "<i8" | ">i8" => self.parse_int64_data(bytes),
            "<u1" | ">u1" => self.parse_uint8_data(bytes),
            // Unknown dtypes are treated as float32 rather than rejected;
            // this keeps loading permissive for unrecognized stores.
            _ => self.parse_float32_data(bytes),
        }
    }
    /// Decodes fixed-width numeric values of type `V` from raw chunk bytes
    /// and casts them to `T`, honoring the byte order encoded in the dtype
    /// string (`<` prefix = little endian). Values that cannot be cast are
    /// replaced with `T::default()` after logging a warning.
    fn parse_numeric_data<V, F>(
        &self,
        bytes: &[u8],
        width: usize,
        label: &str,
        decode: F,
    ) -> Result<Vec<T>>
    where
        V: Copy + std::fmt::Display + scirs2_core::num_traits::cast::ToPrimitive,
        F: Fn(&[u8], bool) -> V,
    {
        if bytes.len() % width != 0 {
            return Err(TensorError::invalid_argument(format!(
                "Byte array length not divisible by {width} for {label} data"
            )));
        }
        let is_little_endian = self.array_info.dtype.starts_with('<');
        let mut data = Vec::with_capacity(bytes.len() / width);
        for chunk in bytes.chunks_exact(width) {
            let value = decode(chunk, is_little_endian);
            let converted =
                scirs2_core::num_traits::cast::NumCast::from(value).unwrap_or_else(|| {
                    eprintln!(
                        "Warning: Failed to convert {} {value} to target type, using default",
                        std::any::type_name::<V>()
                    );
                    T::default()
                });
            data.push(converted);
        }
        Ok(data)
    }
    fn parse_float32_data(&self, bytes: &[u8]) -> Result<Vec<T>> {
        self.parse_numeric_data(bytes, 4, "float32", |chunk, le| {
            let array: [u8; 4] = chunk.try_into().expect("chunks_exact yields 4-byte slices");
            if le {
                f32::from_le_bytes(array)
            } else {
                f32::from_be_bytes(array)
            }
        })
    }
    fn parse_float64_data(&self, bytes: &[u8]) -> Result<Vec<T>> {
        self.parse_numeric_data(bytes, 8, "float64", |chunk, le| {
            let array: [u8; 8] = chunk.try_into().expect("chunks_exact yields 8-byte slices");
            if le {
                f64::from_le_bytes(array)
            } else {
                f64::from_be_bytes(array)
            }
        })
    }
    fn parse_int32_data(&self, bytes: &[u8]) -> Result<Vec<T>> {
        self.parse_numeric_data(bytes, 4, "int32", |chunk, le| {
            let array: [u8; 4] = chunk.try_into().expect("chunks_exact yields 4-byte slices");
            if le {
                i32::from_le_bytes(array)
            } else {
                i32::from_be_bytes(array)
            }
        })
    }
    fn parse_int64_data(&self, bytes: &[u8]) -> Result<Vec<T>> {
        self.parse_numeric_data(bytes, 8, "int64", |chunk, le| {
            let array: [u8; 8] = chunk.try_into().expect("chunks_exact yields 8-byte slices");
            if le {
                i64::from_le_bytes(array)
            } else {
                i64::from_be_bytes(array)
            }
        })
    }
    fn parse_uint8_data(&self, bytes: &[u8]) -> Result<Vec<T>> {
        // Single bytes have no endianness; cast each one directly.
        self.parse_numeric_data(bytes, 1, "uint8", |chunk, _| chunk[0])
    }
    fn decompress_chunk_data(&self, compressed_data: &[u8]) -> Result<Vec<u8>> {
        match self.array_info.compressor.as_deref() {
            #[cfg(feature = "compression")]
            Some("gzip") => {
                let mut gzip_reader = GzipReader::new(std::io::Cursor::new(compressed_data))
                    .map_err(|e| {
                        TensorError::invalid_argument(format!(
                            "Gzip decompression init failed: {e}"
                        ))
                    })?;
                let decompressed = gzip_reader.decompress().map_err(|e| {
                    TensorError::invalid_argument(format!("Gzip decompression failed: {e}"))
                })?;
                Ok(decompressed)
            }
            // Blosc, LZ4, and Zstd decoding are not implemented yet: chunk
            // bytes are passed through unchanged, which is only correct when
            // the store is actually uncompressed. Unknown compressors and
            // uncompressed stores are passed through as well.
            Some(_) | None => Ok(compressed_data.to_vec()),
        }
    }
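    /// Maps a sample index to the coordinates of the chunk containing it,
    /// assuming the array is chunked along the first (sample) axis only.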
fn sample_to_chunk_coords(&self, index: usize) -> Vec<usize> {
let chunk_size = self.array_info.chunks[0];
vec![index / chunk_size]
}
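    /// Loads the features for sample `index` by locating its chunk, slicing
    /// out the sample, and (when a labels store is configured) attaching a
    /// placeholder label.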
pub fn get_sample_data(&self, index: usize) -> Result<(Vec<T>, Option<T>)> {
let chunk_coords = self.sample_to_chunk_coords(index);
let chunk_data = self.load_chunk(&chunk_coords)?;
let chunk_size = self.array_info.chunks[0];
let sample_offset = index % chunk_size;
let sample_size = self.array_info.shape[1..].iter().product::<usize>();
let start_idx = sample_offset * sample_size;
let end_idx = start_idx + sample_size;
if end_idx > chunk_data.len() {
return Err(TensorError::invalid_argument(format!(
"Sample index {index} out of bounds"
)));
}
let features = chunk_data[start_idx..end_idx].to_vec();
        // Label arrays are detected but not yet read from disk; when a labels
        // store is configured, a placeholder default value is returned.
        let label = if self.labels_info.is_some() {
            Some(T::default())
        } else {
            None
        };
Ok((features, label))
}
}
impl<T> Dataset<T> for ZarrDataset<T>
where
T: Clone
+ Default
+ scirs2_core::numeric::Zero
+ scirs2_core::numeric::One
+ std::str::FromStr
+ Send
+ Sync
+ 'static
+ scirs2_core::num_traits::cast::NumCast,
{
fn len(&self) -> usize {
self.array_info.shape[0]
}
fn get(&self, index: usize) -> Result<(Tensor<T>, Tensor<T>)> {
if index >= self.len() {
return Err(TensorError::invalid_argument(format!(
"Index {} out of bounds for dataset of size {}",
index,
self.len()
)));
}
let (features, label) = self.get_sample_data(index)?;
let feature_shape = self.array_info.shape[1..].to_vec();
let feature_tensor = Tensor::from_vec(features, &feature_shape)?;
let label_tensor = if let Some(label_val) = label {
Tensor::from_vec(vec![label_val], &[1])?
} else {
Tensor::from_vec(vec![T::default()], &[1])?
};
Ok((feature_tensor, label_tensor))
}
}
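/// Convenience constructors for opening Zarr-backed datasets directly from
/// paths, without going through the builder.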
pub trait ZarrDatasetExt<T>
where
T: Clone
+ Default
+ scirs2_core::numeric::Zero
+ scirs2_core::numeric::One
+ std::str::FromStr
+ Send
+ Sync
+ 'static,
{
fn from_zarr_path<P: AsRef<Path>>(path: P) -> Result<ZarrDataset<T>>;
fn from_zarr_with_labels<P: AsRef<Path>>(
array_path: P,
labels_path: P,
) -> Result<ZarrDataset<T>>;
}
impl<T> ZarrDatasetExt<T> for ZarrDataset<T>
where
T: Clone
+ Default
+ scirs2_core::numeric::Zero
+ scirs2_core::numeric::One
+ std::str::FromStr
+ Send
+ Sync
+ 'static
+ scirs2_core::num_traits::cast::NumCast,
{
fn from_zarr_path<P: AsRef<Path>>(path: P) -> Result<ZarrDataset<T>> {
ZarrDatasetBuilder::new().array_path(path).build()
}
fn from_zarr_with_labels<P: AsRef<Path>>(
array_path: P,
labels_path: P,
) -> Result<ZarrDataset<T>> {
ZarrDatasetBuilder::new()
.array_path(array_path)
.labels_path(labels_path)
.build()
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_zarr_config_creation() {
let config = ZarrConfig::default();
assert_eq!(config.chunk_cache_size, 100_000_000);
assert!(config.lazy_loading);
assert_eq!(config.max_parallel_chunks, 4);
assert!(config.use_memory_mapping);
#[cfg(feature = "cloud")]
assert_eq!(config.cloud_backend, CloudBackend::Local);
assert_eq!(config.compression, ZarrCompressionType::None);
assert!(!config.async_io);
assert_eq!(config.connection_timeout, 30);
assert_eq!(config.retry_attempts, 3);
}
#[test]
fn test_zarr_builder() {
let builder = ZarrDatasetBuilder::<f32>::new()
.array_path("/path/to/array")
.chunk_cache_size(50_000_000)
.lazy_loading(false)
.max_parallel_chunks(8)
.compression(ZarrCompressionType::None)
.async_io(true)
.connection_timeout(60)
.retry_attempts(5);
assert_eq!(builder.config.array_path, PathBuf::from("/path/to/array"));
assert_eq!(builder.config.chunk_cache_size, 50_000_000);
assert!(!builder.config.lazy_loading);
assert_eq!(builder.config.max_parallel_chunks, 8);
assert_eq!(builder.config.compression, ZarrCompressionType::None);
assert!(builder.config.async_io);
assert_eq!(builder.config.connection_timeout, 60);
assert_eq!(builder.config.retry_attempts, 5);
}
#[test]
fn test_chunk_coordinate_calculation() {
let config = ZarrConfig {
array_path: PathBuf::from("/test"),
..Default::default()
};
let array_info = ZarrArrayInfo {
shape: vec![1000, 224, 224, 3],
dtype: "<f4".to_string(),
chunks: vec![100, 224, 224, 3],
compressor: None,
fill_value: None,
order: "C".to_string(),
zarr_format: 2,
};
let dataset = ZarrDataset::<f32> {
config,
array_info,
labels_info: None,
chunk_cache: Arc::new(std::sync::Mutex::new(HashMap::new())),
_phantom: std::marker::PhantomData,
};
assert_eq!(dataset.sample_to_chunk_coords(0), vec![0]);
assert_eq!(dataset.sample_to_chunk_coords(50), vec![0]);
assert_eq!(dataset.sample_to_chunk_coords(100), vec![1]);
assert_eq!(dataset.sample_to_chunk_coords(250), vec![2]);
}
#[test]
fn test_zarr_array_info() {
let info = ZarrArrayInfo {
shape: vec![1000, 224, 224, 3],
dtype: "<f4".to_string(),
chunks: vec![100, 224, 224, 3],
compressor: Some("blosc".to_string()),
fill_value: Some(0.0),
order: "C".to_string(),
zarr_format: 2,
};
assert_eq!(info.shape, vec![1000, 224, 224, 3]);
assert_eq!(info.dtype, "<f4");
assert_eq!(
info.compressor
.as_ref()
.expect("test: value should be present"),
"blosc"
);
assert_eq!(info.zarr_format, 2);
}
#[test]
fn test_cloud_backend_configuration() {
#[cfg(feature = "cloud")]
{
let s3_backend = CloudBackend::S3 {
bucket: "my-bucket".to_string(),
region: "us-west-2".to_string(),
endpoint: None,
};
assert!(matches!(s3_backend, CloudBackend::S3 { .. }));
let gcs_backend = CloudBackend::Gcs {
bucket: "my-gcs-bucket".to_string(),
};
assert!(matches!(gcs_backend, CloudBackend::Gcs { .. }));
let local_backend = CloudBackend::Local;
assert_eq!(local_backend, CloudBackend::default());
}
}
#[test]
fn test_compression_types() {
let none_compression = ZarrCompressionType::None;
assert_eq!(none_compression, ZarrCompressionType::default());
#[cfg(feature = "compression")]
{
let gzip_compression = ZarrCompressionType::Gzip;
assert!(matches!(gzip_compression, ZarrCompressionType::Gzip));
let blosc_compression = ZarrCompressionType::Blosc;
assert!(matches!(blosc_compression, ZarrCompressionType::Blosc));
}
}
#[test]
fn test_enhanced_metadata_parsing() {
let json_content = r#"{
"chunks": [100, 224, 224, 3],
"compressor": {"id": "blosc"},
"dtype": "<f4",
"fill_value": 0.0,
"filters": null,
"order": "C",
"shape": [1000, 224, 224, 3],
"zarr_format": 2
}"#;
let result = ZarrDataset::<f32>::parse_zarr_metadata(json_content);
assert!(result.is_ok());
let metadata = result.expect("test: operation should succeed");
assert_eq!(metadata.shape, vec![1000, 224, 224, 3]);
assert_eq!(metadata.dtype, "<f4");
assert_eq!(metadata.chunks, vec![100, 224, 224, 3]);
assert_eq!(metadata.order, "C");
assert_eq!(metadata.zarr_format, 2);
}
}