Skip to main content

alimentar/
tensor.rs

//! Tensor conversion utilities for ML framework integration.
//!
//! Provides utilities for converting Arrow data to tensor-friendly formats
//! suitable for ML training. Selected columns are copied into a single
//! contiguous, row-major buffer ([`TensorData`]) whose slice or raw pointer
//! can be handed directly to an ML framework.
//!
//! # Example
//!
//! ```
//! use alimentar::{ArrowDataset, Dataset};
//! use alimentar::tensor::{TensorData, TensorExtractor};
//!
//! # fn main() -> alimentar::Result<()> {
//! # use std::sync::Arc;
//! # use arrow::{array::{Int32Array, Float64Array}, datatypes::{DataType, Field, Schema}, record_batch::RecordBatch};
//! # let schema = Arc::new(Schema::new(vec![
//! #     Field::new("x", DataType::Float64, false),
//! #     Field::new("y", DataType::Float64, false),
//! # ]));
//! # let batch = RecordBatch::try_new(
//! #     schema,
//! #     vec![
//! #         Arc::new(Float64Array::from(vec![1.0, 2.0, 3.0])),
//! #         Arc::new(Float64Array::from(vec![4.0, 5.0, 6.0])),
//! #     ],
//! # )?;
//! # let dataset = ArrowDataset::from_batch(batch)?;
//! // Extract features as f32 tensor data
//! let extractor = TensorExtractor::new(&["x", "y"]);
//! let tensor_data = extractor.extract_f32(dataset.get_batch(0).unwrap())?;
//!
//! println!("Shape: {:?}", tensor_data.shape());
//! println!("Data: {:?}", tensor_data.as_slice());
//! # Ok(())
//! # }
//! ```

38use std::sync::Arc;
39
40use arrow::{
41    array::{Array, AsArray},
42    datatypes::DataType,
43    record_batch::RecordBatch,
44};
45
46use crate::error::{Error, Result};
47
48/// Tensor data in a contiguous memory layout.
49///
50/// This struct holds tensor data in row-major (C-style) order,
51/// suitable for direct transfer to ML frameworks.
52#[derive(Debug, Clone)]
53pub struct TensorData<T> {
54    /// The underlying data buffer
55    data: Vec<T>,
56    /// Shape of the tensor [rows, cols]
57    shape: [usize; 2],
58}
59
60impl<T: Clone + Default> TensorData<T> {
61    /// Creates a new tensor with the given shape, filled with default values.
62    pub fn new(rows: usize, cols: usize) -> Self {
63        Self {
64            data: vec![T::default(); rows * cols],
65            shape: [rows, cols],
66        }
67    }
68
69    /// Creates a tensor from existing data and shape.
70    ///
71    /// # Errors
72    ///
73    /// Returns an error if the data length doesn't match rows * cols.
74    ///
75    /// #[requires(data.len() == rows * cols)]
76    /// #[ensures(result.is_ok() ==> result.shape() == [rows, cols])]
77    /// #[ensures(result.is_ok() ==> result.as_slice().len() == rows * cols)]
78    /// #[invariant(self.data.len() == self.shape[0] * self.shape[1])]
79    pub fn from_vec(data: Vec<T>, rows: usize, cols: usize) -> Result<Self> {
80        if data.len() != rows * cols {
81            return Err(Error::data(format!(
82                "Data length {} doesn't match shape [{}, {}]",
83                data.len(),
84                rows,
85                cols
86            )));
87        }
88        Ok(Self {
89            data,
90            shape: [rows, cols],
91        })
92    }
93
94    /// Returns the shape of the tensor as [rows, cols].
95    pub fn shape(&self) -> [usize; 2] {
96        self.shape
97    }
98
99    /// Returns the number of rows.
100    pub fn rows(&self) -> usize {
101        self.shape[0]
102    }
103
104    /// Returns the number of columns.
105    pub fn cols(&self) -> usize {
106        self.shape[1]
107    }
108
109    /// Returns the underlying data as a slice.
110    pub fn as_slice(&self) -> &[T] {
111        &self.data
112    }
113
114    /// Returns the underlying data as a mutable slice.
115    pub fn as_mut_slice(&mut self) -> &mut [T] {
116        &mut self.data
117    }
118
119    /// Consumes the tensor and returns the underlying data.
120    pub fn into_vec(self) -> Vec<T> {
121        self.data
122    }
123
124    /// Returns a raw pointer to the underlying data.
125    ///
126    /// Useful for FFI integration with ML frameworks.
127    pub fn as_ptr(&self) -> *const T {
128        self.data.as_ptr()
129    }
130
131    /// Gets an element at the given row and column.
132    pub fn get(&self, row: usize, col: usize) -> Option<&T> {
133        if row < self.shape[0] && col < self.shape[1] {
134            Some(&self.data[row * self.shape[1] + col])
135        } else {
136            None
137        }
138    }
139
140    /// Sets an element at the given row and column.
141    ///
142    /// # Panics
143    ///
144    /// Panics if the indices are out of bounds.
145    pub fn set(&mut self, row: usize, col: usize, value: T) {
146        assert!(row < self.shape[0] && col < self.shape[1]);
147        self.data[row * self.shape[1] + col] = value;
148    }
149}
150
151/// Extracts tensor data from Arrow RecordBatches.
152///
153/// This struct configures which columns to extract and how to convert them.
154#[derive(Debug, Clone)]
155pub struct TensorExtractor {
156    /// Column names to extract
157    columns: Vec<String>,
158}
159
160impl TensorExtractor {
161    /// Creates a new extractor for the specified columns.
162    pub fn new(columns: &[&str]) -> Self {
163        Self {
164            columns: columns.iter().map(|s| (*s).to_string()).collect(),
165        }
166    }
167
168    /// Creates an extractor from owned column names.
169    pub fn from_columns(columns: Vec<String>) -> Self {
170        Self { columns }
171    }
172
173    /// Returns the column names being extracted.
174    pub fn columns(&self) -> &[String] {
175        &self.columns
176    }
177
178    /// Extracts data as f32 tensor.
179    ///
180    /// Numeric columns are converted to f32. Non-numeric columns cause an
181    /// error.
182    ///
183    /// # Errors
184    ///
185    /// Returns an error if:
186    /// - A requested column doesn't exist
187    /// - A column contains non-numeric data
188    pub fn extract_f32(&self, batch: &RecordBatch) -> Result<TensorData<f32>> {
189        let rows = batch.num_rows();
190        let cols = self.columns.len();
191
192        let mut data = vec![0.0f32; rows * cols];
193
194        for (col_idx, col_name) in self.columns.iter().enumerate() {
195            let col_index = batch
196                .schema()
197                .index_of(col_name)
198                .map_err(|_| Error::column_not_found(col_name))?;
199
200            let array = batch.column(col_index);
201            Self::extract_column_f32(array, &mut data, col_idx, cols, rows)?;
202        }
203
204        TensorData::from_vec(data, rows, cols)
205    }
206
207    /// Extracts data as f64 tensor.
208    ///
209    /// Numeric columns are converted to f64. Non-numeric columns cause an
210    /// error.
211    ///
212    /// # Errors
213    ///
214    /// Returns an error if:
215    /// - A requested column doesn't exist
216    /// - A column contains non-numeric data
217    pub fn extract_f64(&self, batch: &RecordBatch) -> Result<TensorData<f64>> {
218        let rows = batch.num_rows();
219        let cols = self.columns.len();
220
221        let mut data = vec![0.0f64; rows * cols];
222
223        for (col_idx, col_name) in self.columns.iter().enumerate() {
224            let col_index = batch
225                .schema()
226                .index_of(col_name)
227                .map_err(|_| Error::column_not_found(col_name))?;
228
229            let array = batch.column(col_index);
230            Self::extract_column_f64(array, &mut data, col_idx, cols, rows)?;
231        }
232
233        TensorData::from_vec(data, rows, cols)
234    }
235
236    /// Extracts data as i64 tensor.
237    ///
238    /// Integer columns are converted to i64. Non-integer columns cause an
239    /// error.
240    ///
241    /// # Errors
242    ///
243    /// Returns an error if:
244    /// - A requested column doesn't exist
245    /// - A column contains non-integer data
246    pub fn extract_i64(&self, batch: &RecordBatch) -> Result<TensorData<i64>> {
247        let rows = batch.num_rows();
248        let cols = self.columns.len();
249
250        let mut data = vec![0i64; rows * cols];
251
252        for (col_idx, col_name) in self.columns.iter().enumerate() {
253            let col_index = batch
254                .schema()
255                .index_of(col_name)
256                .map_err(|_| Error::column_not_found(col_name))?;
257
258            let array = batch.column(col_index);
259            Self::extract_column_i64(array, &mut data, col_idx, cols, rows)?;
260        }
261
262        TensorData::from_vec(data, rows, cols)
263    }
264
265    fn extract_column_f32(
266        array: &Arc<dyn Array>,
267        data: &mut [f32],
268        col_idx: usize,
269        num_cols: usize,
270        num_rows: usize,
271    ) -> Result<()> {
272        match array.data_type() {
273            DataType::Float32 => {
274                let arr = array.as_primitive::<arrow::datatypes::Float32Type>();
275                for row in 0..num_rows {
276                    data[row * num_cols + col_idx] = arr.value(row);
277                }
278            }
279            DataType::Float64 => {
280                let arr = array.as_primitive::<arrow::datatypes::Float64Type>();
281                for row in 0..num_rows {
282                    #[allow(clippy::cast_possible_truncation)]
283                    {
284                        data[row * num_cols + col_idx] = arr.value(row) as f32;
285                    }
286                }
287            }
288            DataType::Int8 => {
289                let arr = array.as_primitive::<arrow::datatypes::Int8Type>();
290                for row in 0..num_rows {
291                    data[row * num_cols + col_idx] = f32::from(arr.value(row));
292                }
293            }
294            DataType::Int16 => {
295                let arr = array.as_primitive::<arrow::datatypes::Int16Type>();
296                for row in 0..num_rows {
297                    data[row * num_cols + col_idx] = f32::from(arr.value(row));
298                }
299            }
300            DataType::Int32 => {
301                let arr = array.as_primitive::<arrow::datatypes::Int32Type>();
302                for row in 0..num_rows {
303                    #[allow(clippy::cast_precision_loss)]
304                    {
305                        data[row * num_cols + col_idx] = arr.value(row) as f32;
306                    }
307                }
308            }
309            DataType::Int64 => {
310                let arr = array.as_primitive::<arrow::datatypes::Int64Type>();
311                for row in 0..num_rows {
312                    #[allow(clippy::cast_precision_loss)]
313                    {
314                        data[row * num_cols + col_idx] = arr.value(row) as f32;
315                    }
316                }
317            }
318            DataType::UInt8 => {
319                let arr = array.as_primitive::<arrow::datatypes::UInt8Type>();
320                for row in 0..num_rows {
321                    data[row * num_cols + col_idx] = f32::from(arr.value(row));
322                }
323            }
324            DataType::UInt16 => {
325                let arr = array.as_primitive::<arrow::datatypes::UInt16Type>();
326                for row in 0..num_rows {
327                    data[row * num_cols + col_idx] = f32::from(arr.value(row));
328                }
329            }
330            DataType::UInt32 => {
331                let arr = array.as_primitive::<arrow::datatypes::UInt32Type>();
332                for row in 0..num_rows {
333                    #[allow(clippy::cast_precision_loss)]
334                    {
335                        data[row * num_cols + col_idx] = arr.value(row) as f32;
336                    }
337                }
338            }
339            DataType::UInt64 => {
340                let arr = array.as_primitive::<arrow::datatypes::UInt64Type>();
341                for row in 0..num_rows {
342                    #[allow(clippy::cast_precision_loss)]
343                    {
344                        data[row * num_cols + col_idx] = arr.value(row) as f32;
345                    }
346                }
347            }
348            dt => {
349                return Err(Error::data(format!(
350                    "Cannot convert {:?} to f32 tensor",
351                    dt
352                )));
353            }
354        }
355        Ok(())
356    }
357
358    fn extract_column_f64(
359        array: &Arc<dyn Array>,
360        data: &mut [f64],
361        col_idx: usize,
362        num_cols: usize,
363        num_rows: usize,
364    ) -> Result<()> {
365        match array.data_type() {
366            DataType::Float32 => {
367                let arr = array.as_primitive::<arrow::datatypes::Float32Type>();
368                for row in 0..num_rows {
369                    data[row * num_cols + col_idx] = f64::from(arr.value(row));
370                }
371            }
372            DataType::Float64 => {
373                let arr = array.as_primitive::<arrow::datatypes::Float64Type>();
374                for row in 0..num_rows {
375                    data[row * num_cols + col_idx] = arr.value(row);
376                }
377            }
378            DataType::Int8 => {
379                let arr = array.as_primitive::<arrow::datatypes::Int8Type>();
380                for row in 0..num_rows {
381                    data[row * num_cols + col_idx] = f64::from(arr.value(row));
382                }
383            }
384            DataType::Int16 => {
385                let arr = array.as_primitive::<arrow::datatypes::Int16Type>();
386                for row in 0..num_rows {
387                    data[row * num_cols + col_idx] = f64::from(arr.value(row));
388                }
389            }
390            DataType::Int32 => {
391                let arr = array.as_primitive::<arrow::datatypes::Int32Type>();
392                for row in 0..num_rows {
393                    data[row * num_cols + col_idx] = f64::from(arr.value(row));
394                }
395            }
396            DataType::Int64 => {
397                let arr = array.as_primitive::<arrow::datatypes::Int64Type>();
398                for row in 0..num_rows {
399                    #[allow(clippy::cast_precision_loss)]
400                    {
401                        data[row * num_cols + col_idx] = arr.value(row) as f64;
402                    }
403                }
404            }
405            DataType::UInt8 => {
406                let arr = array.as_primitive::<arrow::datatypes::UInt8Type>();
407                for row in 0..num_rows {
408                    data[row * num_cols + col_idx] = f64::from(arr.value(row));
409                }
410            }
411            DataType::UInt16 => {
412                let arr = array.as_primitive::<arrow::datatypes::UInt16Type>();
413                for row in 0..num_rows {
414                    data[row * num_cols + col_idx] = f64::from(arr.value(row));
415                }
416            }
417            DataType::UInt32 => {
418                let arr = array.as_primitive::<arrow::datatypes::UInt32Type>();
419                for row in 0..num_rows {
420                    data[row * num_cols + col_idx] = f64::from(arr.value(row));
421                }
422            }
423            DataType::UInt64 => {
424                let arr = array.as_primitive::<arrow::datatypes::UInt64Type>();
425                for row in 0..num_rows {
426                    #[allow(clippy::cast_precision_loss)]
427                    {
428                        data[row * num_cols + col_idx] = arr.value(row) as f64;
429                    }
430                }
431            }
432            dt => {
433                return Err(Error::data(format!(
434                    "Cannot convert {:?} to f64 tensor",
435                    dt
436                )));
437            }
438        }
439        Ok(())
440    }
441
442    fn extract_column_i64(
443        array: &Arc<dyn Array>,
444        data: &mut [i64],
445        col_idx: usize,
446        num_cols: usize,
447        num_rows: usize,
448    ) -> Result<()> {
449        match array.data_type() {
450            DataType::Int8 => {
451                let arr = array.as_primitive::<arrow::datatypes::Int8Type>();
452                for row in 0..num_rows {
453                    data[row * num_cols + col_idx] = i64::from(arr.value(row));
454                }
455            }
456            DataType::Int16 => {
457                let arr = array.as_primitive::<arrow::datatypes::Int16Type>();
458                for row in 0..num_rows {
459                    data[row * num_cols + col_idx] = i64::from(arr.value(row));
460                }
461            }
462            DataType::Int32 => {
463                let arr = array.as_primitive::<arrow::datatypes::Int32Type>();
464                for row in 0..num_rows {
465                    data[row * num_cols + col_idx] = i64::from(arr.value(row));
466                }
467            }
468            DataType::Int64 => {
469                let arr = array.as_primitive::<arrow::datatypes::Int64Type>();
470                for row in 0..num_rows {
471                    data[row * num_cols + col_idx] = arr.value(row);
472                }
473            }
474            DataType::UInt8 => {
475                let arr = array.as_primitive::<arrow::datatypes::UInt8Type>();
476                for row in 0..num_rows {
477                    data[row * num_cols + col_idx] = i64::from(arr.value(row));
478                }
479            }
480            DataType::UInt16 => {
481                let arr = array.as_primitive::<arrow::datatypes::UInt16Type>();
482                for row in 0..num_rows {
483                    data[row * num_cols + col_idx] = i64::from(arr.value(row));
484                }
485            }
486            DataType::UInt32 => {
487                let arr = array.as_primitive::<arrow::datatypes::UInt32Type>();
488                for row in 0..num_rows {
489                    data[row * num_cols + col_idx] = i64::from(arr.value(row));
490                }
491            }
492            DataType::UInt64 => {
493                let arr = array.as_primitive::<arrow::datatypes::UInt64Type>();
494                for row in 0..num_rows {
495                    #[allow(clippy::cast_possible_wrap)]
496                    {
497                        data[row * num_cols + col_idx] = arr.value(row) as i64;
498                    }
499                }
500            }
501            dt => {
502                return Err(Error::data(format!(
503                    "Cannot convert {:?} to i64 tensor",
504                    dt
505                )));
506            }
507        }
508        Ok(())
509    }
510}
511
512/// Extracts a single numeric column as a 1D vector.
513///
514/// # Errors
515///
516/// Returns an error if the column doesn't exist or is non-numeric.
517pub fn extract_column_f32(batch: &RecordBatch, column: &str) -> Result<Vec<f32>> {
518    let extractor = TensorExtractor::new(&[column]);
519    let tensor = extractor.extract_f32(batch)?;
520    Ok(tensor.into_vec())
521}
522
523/// Extracts a single numeric column as a 1D vector.
524///
525/// # Errors
526///
527/// Returns an error if the column doesn't exist or is non-numeric.
528pub fn extract_column_f64(batch: &RecordBatch, column: &str) -> Result<Vec<f64>> {
529    let extractor = TensorExtractor::new(&[column]);
530    let tensor = extractor.extract_f64(batch)?;
531    Ok(tensor.into_vec())
532}
533
534/// Extracts label column as integer indices.
535///
536/// String labels are converted to indices based on unique values.
537///
538/// # Errors
539///
540/// Returns an error if the column doesn't exist.
541pub fn extract_labels_i64(batch: &RecordBatch, column: &str) -> Result<Vec<i64>> {
542    use arrow::datatypes::{
543        Float32Type, Float64Type, Int16Type, Int32Type, Int64Type, Int8Type, UInt16Type,
544        UInt32Type, UInt64Type, UInt8Type,
545    };
546
547    let col_index = batch
548        .schema()
549        .index_of(column)
550        .map_err(|_| Error::column_not_found(column))?;
551
552    let array = batch.column(col_index);
553
554    match array.data_type() {
555        DataType::Int8 => cast_and_collect::<Int8Type>(array, "Int8Array", |v| i64::from(v)),
556        DataType::Int16 => cast_and_collect::<Int16Type>(array, "Int16Array", |v| i64::from(v)),
557        DataType::Int32 => cast_and_collect::<Int32Type>(array, "Int32Array", |v| i64::from(v)),
558        DataType::Int64 => cast_and_collect::<Int64Type>(array, "Int64Array", |v| v),
559        DataType::UInt8 => cast_and_collect::<UInt8Type>(array, "UInt8Array", |v| i64::from(v)),
560        DataType::UInt16 => cast_and_collect::<UInt16Type>(array, "UInt16Array", |v| i64::from(v)),
561        DataType::UInt32 => cast_and_collect::<UInt32Type>(array, "UInt32Array", |v| i64::from(v)),
562        DataType::UInt64 =>
563        {
564            #[allow(clippy::cast_possible_wrap)]
565            cast_and_collect::<UInt64Type>(array, "UInt64Array", |v| v as i64)
566        }
567        DataType::Float32 =>
568        {
569            #[allow(clippy::cast_possible_truncation)]
570            cast_and_collect::<Float32Type>(array, "Float32Array", |v| v as i64)
571        }
572        DataType::Float64 =>
573        {
574            #[allow(clippy::cast_possible_truncation)]
575            cast_and_collect::<Float64Type>(array, "Float64Array", |v| v as i64)
576        }
577        dt => Err(Error::data(format!("Cannot extract labels from {:?}", dt))),
578    }
579}
580
581/// Downcast `array` to `PrimitiveArray<T>` and map each element (or
582/// default on null) through `cast`.
583fn cast_and_collect<T>(
584    array: &dyn arrow::array::Array,
585    type_name: &str,
586    cast: impl Fn(T::Native) -> i64,
587) -> Result<Vec<i64>>
588where
589    T: arrow::datatypes::ArrowPrimitiveType,
590    T::Native: Copy + Default,
591{
592    let arr = array
593        .as_any()
594        .downcast_ref::<arrow::array::PrimitiveArray<T>>()
595        .ok_or_else(|| Error::data(format!("Failed to downcast to {type_name}")))?;
596    Ok(arr.iter().map(|v| cast(v.unwrap_or_default())).collect())
597}
598
599#[cfg(test)]
600#[allow(
601    clippy::cast_possible_truncation,
602    clippy::cast_possible_wrap,
603    clippy::uninlined_format_args,
604    clippy::unwrap_used,
605    clippy::expect_used,
606    clippy::float_cmp
607)]
608mod tests {
609    use arrow::{
610        array::{
611            Float32Array, Float64Array, Int16Array, Int32Array, Int64Array, Int8Array, UInt16Array,
612            UInt32Array, UInt64Array, UInt8Array,
613        },
614        datatypes::{Field, Schema},
615    };
616
617    use super::*;
618
619    fn create_numeric_batch() -> RecordBatch {
620        let schema = Arc::new(Schema::new(vec![
621            Field::new("f32_col", DataType::Float32, false),
622            Field::new("f64_col", DataType::Float64, false),
623            Field::new("i32_col", DataType::Int32, false),
624            Field::new("i64_col", DataType::Int64, false),
625        ]));
626
627        RecordBatch::try_new(
628            schema,
629            vec![
630                Arc::new(Float32Array::from(vec![1.0f32, 2.0, 3.0])),
631                Arc::new(Float64Array::from(vec![4.0f64, 5.0, 6.0])),
632                Arc::new(Int32Array::from(vec![7, 8, 9])),
633                Arc::new(Int64Array::from(vec![10i64, 11, 12])),
634            ],
635        )
636        .unwrap()
637    }
638
639    #[test]
640    fn test_tensor_data_new() {
641        let tensor: TensorData<f32> = TensorData::new(3, 4);
642        assert_eq!(tensor.shape(), [3, 4]);
643        assert_eq!(tensor.rows(), 3);
644        assert_eq!(tensor.cols(), 4);
645        assert_eq!(tensor.as_slice().len(), 12);
646    }
647
648    #[test]
649    fn test_tensor_data_from_vec() {
650        let data = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0];
651        let tensor = TensorData::from_vec(data, 2, 3).unwrap();
652        assert_eq!(tensor.shape(), [2, 3]);
653        assert_eq!(tensor.as_slice(), &[1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
654    }
655
656    #[test]
657    fn test_tensor_data_from_vec_invalid_shape() {
658        let data = vec![1.0f32, 2.0, 3.0, 4.0];
659        let result = TensorData::from_vec(data, 2, 3);
660        assert!(result.is_err());
661    }
662
663    #[test]
664    fn test_tensor_data_get_set() {
665        let mut tensor = TensorData::from_vec(vec![1.0f32, 2.0, 3.0, 4.0], 2, 2).unwrap();
666
667        assert_eq!(tensor.get(0, 0), Some(&1.0f32));
668        assert_eq!(tensor.get(0, 1), Some(&2.0f32));
669        assert_eq!(tensor.get(1, 0), Some(&3.0f32));
670        assert_eq!(tensor.get(1, 1), Some(&4.0f32));
671        assert_eq!(tensor.get(2, 0), None);
672
673        tensor.set(0, 1, 99.0);
674        assert_eq!(tensor.get(0, 1), Some(&99.0f32));
675    }
676
677    #[test]
678    fn test_tensor_data_into_vec() {
679        let data = vec![1.0f32, 2.0, 3.0];
680        let tensor = TensorData::from_vec(data.clone(), 1, 3).unwrap();
681        assert_eq!(tensor.into_vec(), data);
682    }
683
684    #[test]
685    fn test_tensor_data_as_ptr() {
686        let tensor = TensorData::from_vec(vec![1.0f32, 2.0, 3.0], 1, 3).unwrap();
687        let ptr = tensor.as_ptr();
688        assert!(!ptr.is_null());
689    }
690
691    #[test]
692    fn test_tensor_data_as_mut_slice() {
693        let mut tensor = TensorData::from_vec(vec![1.0f32, 2.0, 3.0], 1, 3).unwrap();
694        let slice = tensor.as_mut_slice();
695        slice[0] = 10.0;
696        assert_eq!(tensor.as_slice()[0], 10.0);
697    }
698
699    #[test]
700    fn test_tensor_data_clone() {
701        let tensor = TensorData::from_vec(vec![1.0f32, 2.0, 3.0], 1, 3).unwrap();
702        let cloned = tensor.clone();
703        assert_eq!(cloned.shape(), tensor.shape());
704        assert_eq!(cloned.as_slice(), tensor.as_slice());
705    }
706
707    #[test]
708    fn test_tensor_data_debug() {
709        let tensor = TensorData::from_vec(vec![1.0f32], 1, 1).unwrap();
710        let debug = format!("{:?}", tensor);
711        assert!(debug.contains("TensorData"));
712    }
713
714    #[test]
715    fn test_extractor_new() {
716        let extractor = TensorExtractor::new(&["a", "b", "c"]);
717        assert_eq!(extractor.columns().len(), 3);
718        assert_eq!(extractor.columns()[0], "a");
719    }
720
721    #[test]
722    fn test_extractor_from_columns() {
723        let extractor = TensorExtractor::from_columns(vec!["x".to_string(), "y".to_string()]);
724        assert_eq!(extractor.columns().len(), 2);
725    }
726
727    #[test]
728    fn test_extractor_clone() {
729        let extractor = TensorExtractor::new(&["a", "b"]);
730        let cloned = extractor.clone();
731        assert_eq!(cloned.columns(), extractor.columns());
732    }
733
734    #[test]
735    fn test_extractor_debug() {
736        let extractor = TensorExtractor::new(&["col"]);
737        let debug = format!("{:?}", extractor);
738        assert!(debug.contains("TensorExtractor"));
739    }
740
741    #[test]
742    fn test_extract_f32() {
743        let batch = create_numeric_batch();
744        let extractor = TensorExtractor::new(&["f32_col", "i32_col"]);
745        let tensor = extractor.extract_f32(&batch).unwrap();
746
747        assert_eq!(tensor.shape(), [3, 2]);
748        assert_eq!(tensor.get(0, 0), Some(&1.0f32));
749        assert_eq!(tensor.get(0, 1), Some(&7.0f32));
750        assert_eq!(tensor.get(2, 0), Some(&3.0f32));
751        assert_eq!(tensor.get(2, 1), Some(&9.0f32));
752    }
753
754    #[test]
755    fn test_extract_f64() {
756        let batch = create_numeric_batch();
757        let extractor = TensorExtractor::new(&["f64_col", "i64_col"]);
758        let tensor = extractor.extract_f64(&batch).unwrap();
759
760        assert_eq!(tensor.shape(), [3, 2]);
761        assert_eq!(tensor.get(0, 0), Some(&4.0f64));
762        assert_eq!(tensor.get(0, 1), Some(&10.0f64));
763    }
764
765    #[test]
766    fn test_extract_i64() {
767        let batch = create_numeric_batch();
768        let extractor = TensorExtractor::new(&["i32_col", "i64_col"]);
769        let tensor = extractor.extract_i64(&batch).unwrap();
770
771        assert_eq!(tensor.shape(), [3, 2]);
772        assert_eq!(tensor.get(0, 0), Some(&7i64));
773        assert_eq!(tensor.get(0, 1), Some(&10i64));
774    }
775
776    #[test]
777    fn test_extract_column_not_found() {
778        let batch = create_numeric_batch();
779        let extractor = TensorExtractor::new(&["nonexistent"]);
780        let result = extractor.extract_f32(&batch);
781        assert!(result.is_err());
782    }
783
784    #[test]
785    fn test_extract_column_f32_helper() {
786        let batch = create_numeric_batch();
787        let data = extract_column_f32(&batch, "f32_col").unwrap();
788        assert_eq!(data, vec![1.0f32, 2.0, 3.0]);
789    }
790
791    #[test]
792    fn test_extract_column_f64_helper() {
793        let batch = create_numeric_batch();
794        let data = extract_column_f64(&batch, "f64_col").unwrap();
795        assert_eq!(data, vec![4.0f64, 5.0, 6.0]);
796    }
797
798    #[test]
799    fn test_extract_labels_i64() {
800        let batch = create_numeric_batch();
801        let labels = extract_labels_i64(&batch, "i32_col").unwrap();
802        assert_eq!(labels, vec![7i64, 8, 9]);
803    }
804
805    #[test]
806    fn test_extract_labels_i64_from_float() {
807        let schema = Arc::new(Schema::new(vec![Field::new(
808            "label",
809            DataType::Float64,
810            false,
811        )]));
812        let batch = RecordBatch::try_new(
813            schema,
814            vec![Arc::new(Float64Array::from(vec![0.0, 1.0, 2.0]))],
815        )
816        .unwrap();
817
818        let labels = extract_labels_i64(&batch, "label").unwrap();
819        assert_eq!(labels, vec![0i64, 1, 2]);
820    }
821
822    #[test]
823    fn test_extract_labels_column_not_found() {
824        let batch = create_numeric_batch();
825        let result = extract_labels_i64(&batch, "nonexistent");
826        assert!(result.is_err());
827    }
828
829    #[test]
830    fn test_extract_all_int_types() {
831        let schema = Arc::new(Schema::new(vec![
832            Field::new("i8", DataType::Int8, false),
833            Field::new("i16", DataType::Int16, false),
834            Field::new("u8", DataType::UInt8, false),
835            Field::new("u16", DataType::UInt16, false),
836            Field::new("u32", DataType::UInt32, false),
837            Field::new("u64", DataType::UInt64, false),
838        ]));
839
840        let batch = RecordBatch::try_new(
841            schema,
842            vec![
843                Arc::new(Int8Array::from(vec![1i8])),
844                Arc::new(Int16Array::from(vec![2i16])),
845                Arc::new(UInt8Array::from(vec![3u8])),
846                Arc::new(UInt16Array::from(vec![4u16])),
847                Arc::new(UInt32Array::from(vec![5u32])),
848                Arc::new(UInt64Array::from(vec![6u64])),
849            ],
850        )
851        .unwrap();
852
853        // Test f32 extraction
854        let extractor = TensorExtractor::new(&["i8", "i16", "u8", "u16", "u32", "u64"]);
855        let tensor = extractor.extract_f32(&batch).unwrap();
856        assert_eq!(tensor.as_slice(), &[1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0]);
857
858        // Test f64 extraction
859        let tensor = extractor.extract_f64(&batch).unwrap();
860        assert_eq!(tensor.as_slice(), &[1.0f64, 2.0, 3.0, 4.0, 5.0, 6.0]);
861
862        // Test i64 extraction
863        let tensor = extractor.extract_i64(&batch).unwrap();
864        assert_eq!(tensor.as_slice(), &[1i64, 2, 3, 4, 5, 6]);
865    }
866
867    #[test]
868    fn test_extract_f32_from_f64() {
869        let schema = Arc::new(Schema::new(vec![Field::new(
870            "value",
871            DataType::Float64,
872            false,
873        )]));
874        let batch = RecordBatch::try_new(
875            schema,
876            vec![Arc::new(Float64Array::from(vec![1.5f64, 2.5, 3.5]))],
877        )
878        .unwrap();
879
880        let extractor = TensorExtractor::new(&["value"]);
881        let tensor = extractor.extract_f32(&batch).unwrap();
882        assert_eq!(tensor.as_slice(), &[1.5f32, 2.5, 3.5]);
883    }
884
885    #[test]
886    fn test_extract_f64_from_f32() {
887        let schema = Arc::new(Schema::new(vec![Field::new(
888            "value",
889            DataType::Float32,
890            false,
891        )]));
892        let batch = RecordBatch::try_new(
893            schema,
894            vec![Arc::new(Float32Array::from(vec![1.5f32, 2.5, 3.5]))],
895        )
896        .unwrap();
897
898        let extractor = TensorExtractor::new(&["value"]);
899        let tensor = extractor.extract_f64(&batch).unwrap();
900        // f32 -> f64 conversion is exact for these values
901        assert_eq!(tensor.as_slice(), &[1.5f64, 2.5, 3.5]);
902    }
903
904    #[test]
905    fn test_extract_unsupported_type_f32() {
906        use arrow::array::StringArray;
907
908        let schema = Arc::new(Schema::new(vec![Field::new("text", DataType::Utf8, false)]));
909        let batch = RecordBatch::try_new(
910            schema,
911            vec![Arc::new(StringArray::from(vec!["hello", "world"]))],
912        )
913        .unwrap();
914
915        let extractor = TensorExtractor::new(&["text"]);
916        let result = extractor.extract_f32(&batch);
917        assert!(result.is_err());
918    }
919
920    #[test]
921    fn test_extract_unsupported_type_f64() {
922        use arrow::array::StringArray;
923
924        let schema = Arc::new(Schema::new(vec![Field::new("text", DataType::Utf8, false)]));
925        let batch = RecordBatch::try_new(
926            schema,
927            vec![Arc::new(StringArray::from(vec!["hello", "world"]))],
928        )
929        .unwrap();
930
931        let extractor = TensorExtractor::new(&["text"]);
932        let result = extractor.extract_f64(&batch);
933        assert!(result.is_err());
934    }
935
936    #[test]
937    fn test_extract_unsupported_type_i64() {
938        use arrow::array::StringArray;
939
940        let schema = Arc::new(Schema::new(vec![Field::new("text", DataType::Utf8, false)]));
941        let batch = RecordBatch::try_new(
942            schema,
943            vec![Arc::new(StringArray::from(vec!["hello", "world"]))],
944        )
945        .unwrap();
946
947        let extractor = TensorExtractor::new(&["text"]);
948        let result = extractor.extract_i64(&batch);
949        assert!(result.is_err());
950    }
951
952    #[test]
953    fn test_extract_labels_unsupported_type() {
954        use arrow::array::StringArray;
955
956        let schema = Arc::new(Schema::new(vec![Field::new("text", DataType::Utf8, false)]));
957        let batch = RecordBatch::try_new(
958            schema,
959            vec![Arc::new(StringArray::from(vec!["hello", "world"]))],
960        )
961        .unwrap();
962
963        let result = extract_labels_i64(&batch, "text");
964        assert!(result.is_err());
965    }
966
967    #[test]
968    fn test_extract_labels_all_uint_types() {
969        let schema = Arc::new(Schema::new(vec![
970            Field::new("u8", DataType::UInt8, false),
971            Field::new("u16", DataType::UInt16, false),
972            Field::new("u32", DataType::UInt32, false),
973            Field::new("u64", DataType::UInt64, false),
974        ]));
975
976        let batch = RecordBatch::try_new(
977            schema,
978            vec![
979                Arc::new(UInt8Array::from(vec![1u8])),
980                Arc::new(UInt16Array::from(vec![2u16])),
981                Arc::new(UInt32Array::from(vec![3u32])),
982                Arc::new(UInt64Array::from(vec![4u64])),
983            ],
984        )
985        .unwrap();
986
987        assert_eq!(extract_labels_i64(&batch, "u8").unwrap(), vec![1i64]);
988        assert_eq!(extract_labels_i64(&batch, "u16").unwrap(), vec![2i64]);
989        assert_eq!(extract_labels_i64(&batch, "u32").unwrap(), vec![3i64]);
990        assert_eq!(extract_labels_i64(&batch, "u64").unwrap(), vec![4i64]);
991    }
992
993    #[test]
994    fn test_extract_labels_all_int_types() {
995        let schema = Arc::new(Schema::new(vec![
996            Field::new("i8", DataType::Int8, false),
997            Field::new("i16", DataType::Int16, false),
998            Field::new("f32", DataType::Float32, false),
999        ]));
1000
1001        let batch = RecordBatch::try_new(
1002            schema,
1003            vec![
1004                Arc::new(Int8Array::from(vec![1i8])),
1005                Arc::new(Int16Array::from(vec![2i16])),
1006                Arc::new(Float32Array::from(vec![3.0f32])),
1007            ],
1008        )
1009        .unwrap();
1010
1011        assert_eq!(extract_labels_i64(&batch, "i8").unwrap(), vec![1i64]);
1012        assert_eq!(extract_labels_i64(&batch, "i16").unwrap(), vec![2i64]);
1013        assert_eq!(extract_labels_i64(&batch, "f32").unwrap(), vec![3i64]);
1014    }
1015}