use crate::error_taxonomy::helpers as error_helpers;
use crate::Dataset;
use std::path::{Path, PathBuf};
use std::sync::Arc;
use tenflowers_core::{DType, Result, Tensor, TensorError};
#[cfg(feature = "parquet")]
use arrow::array::*;
#[cfg(feature = "parquet")]
use arrow::datatypes::{DataType as ArrowDataType, Field, Schema};
#[cfg(feature = "parquet")]
use arrow::record_batch::RecordBatch;
#[cfg(feature = "parquet")]
use parquet::arrow::arrow_reader::ParquetRecordBatchReaderBuilder;
#[cfg(feature = "parquet")]
use parquet::file::reader::FileReader;
#[derive(Debug, Clone)]
pub struct ArrowConfig {
pub feature_columns: Option<Vec<String>>,
pub label_column: String,
pub batch_size: usize,
pub zero_copy: bool,
pub max_rows: Option<usize>,
pub validate_schema: bool,
}
impl Default for ArrowConfig {
fn default() -> Self {
Self {
feature_columns: None,
label_column: "label".to_string(),
batch_size: 1024,
zero_copy: true,
max_rows: None,
validate_schema: true,
}
}
}
#[cfg(feature = "parquet")]
pub struct ArrowDataset<T> {
batches: Vec<RecordBatch>,
feature_columns: Vec<String>,
label_column: String,
config: ArrowConfig,
total_rows: usize,
_phantom: std::marker::PhantomData<T>,
}
#[cfg(feature = "parquet")]
impl<T> ArrowDataset<T> {
pub fn from_parquet(path: impl AsRef<Path>, config: ArrowConfig) -> Result<Self> {
let path = path.as_ref();
if !path.exists() {
return Err(error_helpers::file_not_found(
"ArrowDataset::from_parquet",
path,
));
}
let file = std::fs::File::open(path)
.map_err(|e| error_helpers::file_not_found("ArrowDataset::from_parquet", path))?;
let builder = ParquetRecordBatchReaderBuilder::try_new(file).map_err(|e| {
TensorError::io_error_simple(format!("Failed to open Parquet file: {}", e))
})?;
let metadata = builder.metadata();
let schema = builder.schema();
if config.validate_schema {
Self::validate_schema(schema, &config)?;
}
let feature_columns = if let Some(ref cols) = config.feature_columns {
cols.clone()
} else {
schema
.fields()
.iter()
.map(|f| f.name().clone())
.filter(|name| name != &config.label_column)
.collect()
};
let reader = builder
.with_batch_size(config.batch_size)
.build()
.map_err(|e| {
TensorError::io_error_simple(format!("Failed to build Parquet reader: {}", e))
})?;
let mut batches = Vec::new();
let mut total_rows = 0;
for batch_result in reader {
let batch = batch_result.map_err(|e| {
TensorError::io_error_simple(format!("Failed to read batch: {}", e))
})?;
total_rows += batch.num_rows();
batches.push(batch);
if let Some(max_rows) = config.max_rows {
if total_rows >= max_rows {
break;
}
}
}
Ok(Self {
batches,
feature_columns,
label_column: config.label_column.clone(),
config,
total_rows,
_phantom: std::marker::PhantomData,
})
}
pub fn from_batches(
batches: Vec<RecordBatch>,
feature_columns: Vec<String>,
label_column: String,
config: ArrowConfig,
) -> Result<Self> {
let total_rows = batches.iter().map(|b| b.num_rows()).sum();
Ok(Self {
batches,
feature_columns,
label_column,
config,
total_rows,
_phantom: std::marker::PhantomData,
})
}
fn validate_schema(schema: &Schema, config: &ArrowConfig) -> Result<()> {
if schema.column_with_name(&config.label_column).is_none() {
return Err(error_helpers::schema_mismatch(
"ArrowDataset::validate_schema",
format!("Label column '{}'", config.label_column),
"not found in schema",
));
}
if let Some(ref feature_cols) = config.feature_columns {
for col in feature_cols {
if schema.column_with_name(col).is_none() {
return Err(error_helpers::schema_mismatch(
"ArrowDataset::validate_schema",
format!("Feature column '{}'", col),
"not found in schema",
));
}
}
}
Ok(())
}
pub fn schema(&self) -> Option<Arc<Schema>> {
self.batches.first().map(|b| b.schema())
}
pub fn num_batches(&self) -> usize {
self.batches.len()
}
pub fn total_rows(&self) -> usize {
self.total_rows
}
pub fn feature_columns(&self) -> &[String] {
&self.feature_columns
}
pub fn label_column(&self) -> &str {
&self.label_column
}
pub fn get_batch(&self, batch_index: usize) -> Option<&RecordBatch> {
self.batches.get(batch_index)
}
fn array_to_tensor<U: scirs2_core::numeric::NumCast + Clone + Default>(
&self,
array: &dyn Array,
) -> Result<Tensor<U>> {
match array.data_type() {
ArrowDataType::Float32 => {
let arr = array
.as_any()
.downcast_ref::<Float32Array>()
.ok_or_else(|| {
TensorError::unsupported_operation_simple(
"Failed to downcast to Float32Array".to_string(),
)
})?;
let values: Vec<U> = arr
.values()
.iter()
.map(|&v| U::from(v).unwrap_or_default())
.collect();
Tensor::from_vec(values, &[arr.len()])
}
ArrowDataType::Float64 => {
let arr = array
.as_any()
.downcast_ref::<Float64Array>()
.ok_or_else(|| {
TensorError::unsupported_operation_simple(
"Failed to downcast to Float64Array".to_string(),
)
})?;
let values: Vec<U> = arr
.values()
.iter()
.map(|&v| U::from(v).unwrap_or_default())
.collect();
Tensor::from_vec(values, &[arr.len()])
}
ArrowDataType::Int32 => {
let arr = array.as_any().downcast_ref::<Int32Array>().ok_or_else(|| {
TensorError::unsupported_operation_simple(
"Failed to downcast to Int32Array".to_string(),
)
})?;
let values: Vec<U> = arr
.values()
.iter()
.map(|&v| U::from(v).unwrap_or_default())
.collect();
Tensor::from_vec(values, &[arr.len()])
}
ArrowDataType::Int64 => {
let arr = array.as_any().downcast_ref::<Int64Array>().ok_or_else(|| {
TensorError::unsupported_operation_simple(
"Failed to downcast to Int64Array".to_string(),
)
})?;
let values: Vec<U> = arr
.values()
.iter()
.map(|&v| U::from(v).unwrap_or_default())
.collect();
Tensor::from_vec(values, &[arr.len()])
}
ArrowDataType::UInt32 => {
let arr = array
.as_any()
.downcast_ref::<UInt32Array>()
.ok_or_else(|| {
TensorError::unsupported_operation_simple(
"Failed to downcast to UInt32Array".to_string(),
)
})?;
let values: Vec<U> = arr
.values()
.iter()
.map(|&v| U::from(v).unwrap_or_default())
.collect();
Tensor::from_vec(values, &[arr.len()])
}
dt => Err(TensorError::unsupported_operation_simple(format!(
"Unsupported Arrow data type: {:?}",
dt
))),
}
}
fn find_batch_and_local_index(&self, global_index: usize) -> Result<(usize, usize)> {
let mut cumulative = 0;
for (batch_idx, batch) in self.batches.iter().enumerate() {
let batch_size = batch.num_rows();
if global_index < cumulative + batch_size {
return Ok((batch_idx, global_index - cumulative));
}
cumulative += batch_size;
}
Err(error_helpers::index_out_of_bounds(
"ArrowDataset::find_batch_and_local_index",
global_index,
self.total_rows,
))
}
}
#[cfg(feature = "parquet")]
impl<T> ArrowDataset<T> {
pub fn get_batch_range(
&self,
start: usize,
end: usize,
) -> Result<(Vec<Tensor<T>>, Vec<Tensor<T>>)>
where
T: scirs2_core::numeric::NumCast + Clone + Default,
{
if start >= end || end > self.total_rows {
return Err(error_helpers::index_out_of_bounds(
"ArrowDataset::get_batch_range",
end,
self.total_rows,
));
}
let mut features_batch = Vec::new();
let mut labels_batch = Vec::new();
for i in start..end {
let (features, labels) = self.get(i)?;
features_batch.push(features);
labels_batch.push(labels);
}
Ok((features_batch, labels_batch))
}
pub fn get_raw_batch(&self, batch_index: usize) -> Option<&RecordBatch> {
self.batches.get(batch_index)
}
pub fn batches_iter(&self) -> impl Iterator<Item = &RecordBatch> {
self.batches.iter()
}
pub fn get_column_view(&self, column_name: &str) -> Result<Vec<ArrowTensorView<'_, T>>>
where
T: 'static,
{
let mut views = Vec::new();
for batch in &self.batches {
let column = batch.column_by_name(column_name).ok_or_else(|| {
error_helpers::schema_mismatch(
"ArrowDataset::get_column_view",
column_name,
"column not found",
)
})?;
}
Ok(views)
}
}
#[cfg(feature = "parquet")]
impl<T: scirs2_core::numeric::NumCast + Clone + Default> Dataset<T> for ArrowDataset<T> {
fn get(&self, index: usize) -> Result<(Tensor<T>, Tensor<T>)> {
let (batch_idx, local_idx) = self.find_batch_and_local_index(index)?;
let batch = &self.batches[batch_idx];
let mut feature_values = Vec::new();
for col_name in &self.feature_columns {
let column = batch.column_by_name(col_name).ok_or_else(|| {
error_helpers::schema_mismatch("ArrowDataset::get", col_name, "column not found")
})?;
match column.data_type() {
ArrowDataType::Float32 => {
let arr = column
.as_any()
.downcast_ref::<Float32Array>()
.expect("data type already checked as Float32");
if !arr.is_null(local_idx) {
let val = T::from(arr.value(local_idx)).unwrap_or_default();
feature_values.push(val);
}
}
ArrowDataType::Float64 => {
let arr = column
.as_any()
.downcast_ref::<Float64Array>()
.expect("data type already checked as Float64");
if !arr.is_null(local_idx) {
let val = T::from(arr.value(local_idx)).unwrap_or_default();
feature_values.push(val);
}
}
ArrowDataType::Int32 => {
let arr = column
.as_any()
.downcast_ref::<Int32Array>()
.expect("data type already checked as Int32");
if !arr.is_null(local_idx) {
let val = T::from(arr.value(local_idx)).unwrap_or_default();
feature_values.push(val);
}
}
ArrowDataType::Int64 => {
let arr = column
.as_any()
.downcast_ref::<Int64Array>()
.expect("data type already checked as Int64");
if !arr.is_null(local_idx) {
let val = T::from(arr.value(local_idx)).unwrap_or_default();
feature_values.push(val);
}
}
_ => {
return Err(TensorError::unsupported_operation_simple(format!(
"Unsupported data type for column {}",
col_name
)));
}
}
}
let features = Tensor::from_vec(feature_values, &[self.feature_columns.len()])?;
let label_column = batch.column_by_name(&self.label_column).ok_or_else(|| {
error_helpers::schema_mismatch(
"ArrowDataset::get",
&self.label_column,
"label column not found",
)
})?;
let label_value = match label_column.data_type() {
ArrowDataType::Float32 => {
let arr = label_column
.as_any()
.downcast_ref::<Float32Array>()
.expect("label data type already checked as Float32");
T::from(arr.value(local_idx)).unwrap_or_default()
}
ArrowDataType::Float64 => {
let arr = label_column
.as_any()
.downcast_ref::<Float64Array>()
.expect("label data type already checked as Float64");
T::from(arr.value(local_idx)).unwrap_or_default()
}
ArrowDataType::Int32 => {
let arr = label_column
.as_any()
.downcast_ref::<Int32Array>()
.expect("label data type already checked as Int32");
T::from(arr.value(local_idx)).unwrap_or_default()
}
ArrowDataType::Int64 => {
let arr = label_column
.as_any()
.downcast_ref::<Int64Array>()
.expect("label data type already checked as Int64");
T::from(arr.value(local_idx)).unwrap_or_default()
}
_ => {
return Err(TensorError::unsupported_operation_simple(
"Unsupported label data type".to_string(),
));
}
};
let label = Tensor::from_vec(vec![label_value], &[1])?;
Ok((features, label))
}
fn len(&self) -> usize {
self.total_rows
}
}
#[cfg(feature = "parquet")]
pub struct ArrowDatasetBuilder {
config: ArrowConfig,
path: Option<PathBuf>,
}
#[cfg(feature = "parquet")]
impl ArrowDatasetBuilder {
pub fn new() -> Self {
Self {
config: ArrowConfig::default(),
path: None,
}
}
pub fn path(mut self, path: impl AsRef<Path>) -> Self {
self.path = Some(path.as_ref().to_path_buf());
self
}
pub fn feature_columns(mut self, columns: Vec<String>) -> Self {
self.config.feature_columns = Some(columns);
self
}
pub fn label_column(mut self, column: impl Into<String>) -> Self {
self.config.label_column = column.into();
self
}
pub fn batch_size(mut self, size: usize) -> Self {
self.config.batch_size = size;
self
}
pub fn zero_copy(mut self, enabled: bool) -> Self {
self.config.zero_copy = enabled;
self
}
pub fn max_rows(mut self, max: usize) -> Self {
self.config.max_rows = Some(max);
self
}
pub fn validate_schema(mut self, validate: bool) -> Self {
self.config.validate_schema = validate;
self
}
pub fn build<T: scirs2_core::numeric::NumCast + Clone + Default>(
self,
) -> Result<ArrowDataset<T>> {
let path = self.path.ok_or_else(|| {
error_helpers::invalid_configuration(
"ArrowDatasetBuilder::build",
"path",
"path must be set",
)
})?;
ArrowDataset::from_parquet(path, self.config)
}
}
#[cfg(feature = "parquet")]
impl Default for ArrowDatasetBuilder {
fn default() -> Self {
Self::new()
}
}
#[cfg(feature = "parquet")]
pub struct ArrowTensorView<'a, T> {
data: &'a [T],
shape: Vec<usize>,
stride: Vec<usize>,
}
#[cfg(feature = "parquet")]
impl<'a, T: Clone + Default> ArrowTensorView<'a, T> {
pub fn new(data: &'a [T], shape: Vec<usize>) -> Self {
let stride = Self::compute_strides(&shape);
Self {
data,
shape,
stride,
}
}
pub fn new_with_strides(data: &'a [T], shape: Vec<usize>, stride: Vec<usize>) -> Self {
Self {
data,
shape,
stride,
}
}
fn compute_strides(shape: &[usize]) -> Vec<usize> {
let mut strides = vec![1; shape.len()];
for i in (0..shape.len().saturating_sub(1)).rev() {
strides[i] = strides[i + 1] * shape[i + 1];
}
strides
}
pub fn data(&self) -> &[T] {
self.data
}
pub fn shape(&self) -> &[usize] {
&self.shape
}
pub fn strides(&self) -> &[usize] {
&self.stride
}
pub fn get(&self, indices: &[usize]) -> Option<&T> {
if indices.len() != self.shape.len() {
return None;
}
let mut flat_index = 0;
for (i, (&idx, &dim)) in indices.iter().zip(&self.shape).enumerate() {
if idx >= dim {
return None;
}
flat_index += idx * self.stride[i];
}
self.data.get(flat_index)
}
pub fn slice(&self, start: usize, end: usize) -> Option<ArrowTensorView<'a, T>> {
if start >= end || end > self.data.len() {
return None;
}
let len = end - start;
Some(ArrowTensorView {
data: &self.data[start..end],
shape: vec![len],
stride: vec![1],
})
}
pub fn reshape(&self, new_shape: Vec<usize>) -> Option<ArrowTensorView<'a, T>> {
let old_size: usize = self.shape.iter().product();
let new_size: usize = new_shape.iter().product();
if old_size != new_size {
return None;
}
Some(ArrowTensorView {
data: self.data,
shape: new_shape.clone(),
stride: Self::compute_strides(&new_shape),
})
}
pub fn to_tensor(&self) -> Result<Tensor<T>> {
Tensor::from_vec(self.data.to_vec(), &self.shape)
}
pub fn is_contiguous(&self) -> bool {
let expected_strides = Self::compute_strides(&self.shape);
self.stride == expected_strides
}
pub fn len(&self) -> usize {
self.data.len()
}
pub fn is_empty(&self) -> bool {
self.data.is_empty()
}
}
#[cfg(feature = "parquet")]
pub trait ArrowArrayExt {
fn as_tensor_view<T>(&self) -> Option<ArrowTensorView<'_, T>>;
fn to_tensor<T: scirs2_core::numeric::NumCast + Clone + Default>(&self) -> Result<Tensor<T>>;
}
#[cfg(feature = "parquet")]
impl ArrowArrayExt for dyn Array {
fn as_tensor_view<T>(&self) -> Option<ArrowTensorView<'_, T>> {
None }
fn to_tensor<T: scirs2_core::numeric::NumCast + Clone + Default>(&self) -> Result<Tensor<T>> {
Err(TensorError::unsupported_operation_simple(
"to_tensor not implemented for this array type".to_string(),
))
}
}
#[cfg(feature = "parquet")]
mod unified_integration {
use super::*;
use crate::formats::unified_reader::{
DataType as UnifiedDataType, DetectionMethod, FieldInfo, FormatDetection, FormatFactory,
FormatMetadata, FormatReader, FormatSample,
};
pub struct ArrowFormatReader<T> {
dataset: Arc<ArrowDataset<T>>,
current_index: usize,
}
impl<T: scirs2_core::numeric::NumCast + Clone + Default> ArrowFormatReader<T> {
pub fn new(dataset: ArrowDataset<T>) -> Self {
Self {
dataset: Arc::new(dataset),
current_index: 0,
}
}
}
impl<T: scirs2_core::numeric::NumCast + Clone + Default + Send + Sync + 'static> FormatReader
for ArrowFormatReader<T>
{
fn metadata(&self) -> Result<FormatMetadata> {
let schema = self
.dataset
.schema()
.ok_or_else(|| TensorError::io_error_simple("No schema available".to_string()))?;
let fields: Vec<FieldInfo> = schema
.fields()
.iter()
.map(|f| FieldInfo {
name: f.name().clone(),
dtype: arrow_to_unified_type(f.data_type()),
shape: None,
nullable: f.is_nullable(),
description: None,
})
.collect();
let mut metadata = std::collections::HashMap::new();
metadata.insert(
"num_features".to_string(),
self.dataset.feature_columns().len().to_string(),
);
Ok(FormatMetadata {
format_name: "Arrow/Parquet".to_string(),
version: None,
num_samples: self.dataset.total_rows(),
fields,
metadata,
supports_random_access: true,
supports_streaming: true,
})
}
fn get_sample(&self, index: usize) -> Result<FormatSample> {
Err(TensorError::unsupported_operation_simple(
"get_sample not supported for generic Arrow reader - use create_reader_f32"
.to_string(),
))
}
fn get_samples(&self, indices: &[usize]) -> Result<Vec<FormatSample>> {
indices.iter().map(|&idx| self.get_sample(idx)).collect()
}
fn iter(&self) -> Box<dyn Iterator<Item = Result<FormatSample>> + '_> {
Box::new((0..self.dataset.len()).map(|idx| self.get_sample(idx)))
}
fn validate_schema(&self, expected: &[FieldInfo]) -> Result<()> {
let metadata = self.metadata()?;
if metadata.fields.len() != expected.len() {
return Err(error_helpers::schema_mismatch(
"ArrowFormatReader::validate_schema",
format!("{} fields", expected.len()),
format!("{} fields", metadata.fields.len()),
));
}
for (actual, expected) in metadata.fields.iter().zip(expected.iter()) {
if actual.name != expected.name {
return Err(error_helpers::schema_mismatch(
"ArrowFormatReader::validate_schema",
&expected.name,
&actual.name,
));
}
}
Ok(())
}
fn supports_random_access(&self) -> bool {
true
}
fn supports_streaming(&self) -> bool {
true
}
}
pub struct ArrowFormatFactory;
impl FormatFactory for ArrowFormatFactory {
fn format_name(&self) -> &str {
"Arrow/Parquet"
}
fn extensions(&self) -> Vec<&str> {
vec!["parquet", "arrow", "ipc"]
}
fn can_read(&self, path: &Path) -> Result<FormatDetection> {
if !path.exists() {
return Ok(FormatDetection {
confidence: 0.0,
format_name: self.format_name().to_string(),
method: DetectionMethod::Extension,
});
}
let extension = path.extension().and_then(|e| e.to_str()).unwrap_or("");
let confidence = if self.extensions().contains(&extension) {
0.9
} else {
if let Ok(magic) = crate::formats::unified_reader::read_magic_bytes(path, 4) {
if magic == b"PAR1" || magic.starts_with(b"ARRO") {
0.95 } else {
0.0
}
} else {
0.0
}
};
Ok(FormatDetection {
confidence,
format_name: self.format_name().to_string(),
method: if extension.is_empty() {
DetectionMethod::MagicBytes
} else {
DetectionMethod::Extension
},
})
}
fn create_reader(&self, path: &Path) -> Result<Box<dyn FormatReader>> {
let config = ArrowConfig::default();
let dataset = ArrowDataset::<f32>::from_parquet(path, config)?;
Ok(Box::new(ArrowFormatReader::new(dataset)))
}
}
fn arrow_to_unified_type(arrow_type: &ArrowDataType) -> UnifiedDataType {
match arrow_type {
ArrowDataType::Boolean => UnifiedDataType::Bool,
ArrowDataType::Int8 => UnifiedDataType::Int8,
ArrowDataType::Int16 => UnifiedDataType::Int16,
ArrowDataType::Int32 => UnifiedDataType::Int32,
ArrowDataType::Int64 => UnifiedDataType::Int64,
ArrowDataType::UInt8 => UnifiedDataType::UInt8,
ArrowDataType::UInt16 => UnifiedDataType::UInt16,
ArrowDataType::UInt32 => UnifiedDataType::UInt32,
ArrowDataType::UInt64 => UnifiedDataType::UInt64,
ArrowDataType::Float32 => UnifiedDataType::Float32,
ArrowDataType::Float64 => UnifiedDataType::Float64,
ArrowDataType::Utf8 | ArrowDataType::LargeUtf8 => UnifiedDataType::String,
ArrowDataType::Binary | ArrowDataType::LargeBinary => UnifiedDataType::Binary,
ArrowDataType::List(field) | ArrowDataType::LargeList(field) => {
UnifiedDataType::List(Box::new(arrow_to_unified_type(field.data_type())))
}
_ => UnifiedDataType::Binary, }
}
}
#[cfg(feature = "parquet")]
pub use unified_integration::{ArrowFormatFactory, ArrowFormatReader};
#[cfg(test)]
#[cfg(feature = "parquet")]
mod tests {
use super::*;
use std::sync::Arc;
#[test]
fn test_arrow_config_default() {
let config = ArrowConfig::default();
assert_eq!(config.label_column, "label");
assert_eq!(config.batch_size, 1024);
assert!(config.zero_copy);
assert!(config.validate_schema);
assert!(config.max_rows.is_none());
}
#[test]
fn test_arrow_dataset_builder() {
let builder = ArrowDatasetBuilder::new()
.label_column("target")
.batch_size(512)
.zero_copy(false)
.max_rows(1000)
.validate_schema(false);
assert_eq!(builder.config.label_column, "target");
assert_eq!(builder.config.batch_size, 512);
assert!(!builder.config.zero_copy);
assert_eq!(builder.config.max_rows, Some(1000));
assert!(!builder.config.validate_schema);
}
#[test]
fn test_arrow_config_custom() {
let config = ArrowConfig {
feature_columns: Some(vec!["col1".to_string(), "col2".to_string()]),
label_column: "target".to_string(),
batch_size: 2048,
zero_copy: false,
max_rows: Some(10000),
validate_schema: false,
};
assert_eq!(
config
.feature_columns
.as_ref()
.expect("test: value should be present")
.len(),
2
);
assert_eq!(config.label_column, "target");
assert_eq!(config.batch_size, 2048);
assert!(!config.zero_copy);
assert_eq!(config.max_rows, Some(10000));
}
#[test]
fn test_arrow_tensor_view() {
let data = vec![1.0_f32, 2.0, 3.0, 4.0, 5.0, 6.0];
let view = ArrowTensorView::new(&data, vec![2, 3]);
assert_eq!(view.data(), &data);
assert_eq!(view.shape(), &[2, 3]);
assert_eq!(view.len(), 6);
assert!(!view.is_empty());
assert!(view.is_contiguous());
let tensor = view
.to_tensor()
.expect("test: tensor conversion should succeed");
assert_eq!(tensor.shape().dims(), &[2, 3]);
}
#[test]
fn test_arrow_tensor_view_get() {
let data = vec![1.0_f32, 2.0, 3.0, 4.0, 5.0, 6.0];
let view = ArrowTensorView::new(&data, vec![2, 3]);
assert_eq!(view.get(&[0, 0]), Some(&1.0));
assert_eq!(view.get(&[0, 2]), Some(&3.0));
assert_eq!(view.get(&[1, 0]), Some(&4.0));
assert_eq!(view.get(&[1, 2]), Some(&6.0));
assert_eq!(view.get(&[2, 0]), None);
assert_eq!(view.get(&[0, 3]), None);
assert_eq!(view.get(&[0]), None); }
#[test]
fn test_arrow_tensor_view_slice() {
let data = vec![1.0_f32, 2.0, 3.0, 4.0, 5.0, 6.0];
let view = ArrowTensorView::new(&data, vec![6]);
let slice = view.slice(1, 4).expect("test: slice should succeed");
assert_eq!(slice.data(), &[2.0, 3.0, 4.0]);
assert_eq!(slice.shape(), &[3]);
assert!(view.slice(5, 3).is_none()); assert!(view.slice(0, 10).is_none()); }
#[test]
fn test_arrow_tensor_view_reshape() {
let data = vec![1.0_f32, 2.0, 3.0, 4.0, 5.0, 6.0];
let view = ArrowTensorView::new(&data, vec![6]);
let reshaped = view
.reshape(vec![2, 3])
.expect("test: reshape should succeed");
assert_eq!(reshaped.shape(), &[2, 3]);
assert_eq!(reshaped.data(), view.data());
let reshaped2 = view
.reshape(vec![3, 2])
.expect("test: reshape should succeed");
assert_eq!(reshaped2.shape(), &[3, 2]);
assert!(view.reshape(vec![2, 4]).is_none());
assert!(view.reshape(vec![5]).is_none());
}
#[test]
fn test_arrow_tensor_view_strides() {
let data = vec![1.0_f32, 2.0, 3.0, 4.0, 5.0, 6.0];
let view = ArrowTensorView::new(&data, vec![2, 3]);
assert_eq!(view.strides(), &[3, 1]);
let reshaped = view
.reshape(vec![3, 2])
.expect("test: reshape should succeed");
assert_eq!(reshaped.strides(), &[2, 1]);
}
#[test]
fn test_arrow_tensor_view_custom_strides() {
let data = vec![1.0_f32, 2.0, 3.0, 4.0, 5.0, 6.0];
let custom_strides = vec![2, 1];
let view = ArrowTensorView::new_with_strides(&data, vec![2, 3], custom_strides.clone());
assert_eq!(view.strides(), &custom_strides);
assert!(!view.is_contiguous()); }
#[test]
fn test_arrow_tensor_view_contiguity() {
let data = vec![1.0_f32, 2.0, 3.0, 4.0, 5.0, 6.0];
let view = ArrowTensorView::new(&data, vec![2, 3]);
assert!(view.is_contiguous());
let non_contiguous = ArrowTensorView::new_with_strides(&data, vec![2, 3], vec![4, 1]);
assert!(!non_contiguous.is_contiguous());
}
}