Skip to main content

alimentar/
tensor.rs

1//! Tensor conversion utilities for ML framework integration.
2//!
3//! Provides utilities for converting Arrow data to tensor-friendly formats
//! suitable for ML training. Values are copied once into a contiguous,
//! row-major buffer that can then be handed to ML frameworks directly
//! (e.g. via `TensorData::as_ptr`) without further copying.
6//!
7//! # Example
8//!
9//! ```
10//! use alimentar::{ArrowDataset, Dataset};
11//! use alimentar::tensor::{TensorData, TensorExtractor};
12//!
13//! # fn main() -> alimentar::Result<()> {
14//! # use std::sync::Arc;
15//! # use arrow::{array::{Int32Array, Float64Array}, datatypes::{DataType, Field, Schema}, record_batch::RecordBatch};
16//! # let schema = Arc::new(Schema::new(vec![
17//! #     Field::new("x", DataType::Float64, false),
18//! #     Field::new("y", DataType::Float64, false),
19//! # ]));
20//! # let batch = RecordBatch::try_new(
21//! #     schema,
22//! #     vec![
23//! #         Arc::new(Float64Array::from(vec![1.0, 2.0, 3.0])),
24//! #         Arc::new(Float64Array::from(vec![4.0, 5.0, 6.0])),
25//! #     ],
26//! # )?;
27//! # let dataset = ArrowDataset::from_batch(batch)?;
28//! // Extract features as f32 tensor data
29//! let extractor = TensorExtractor::new(&["x", "y"]);
30//! let tensor_data = extractor.extract_f32(dataset.get_batch(0).unwrap())?;
31//!
32//! println!("Shape: {:?}", tensor_data.shape());
33//! println!("Data: {:?}", tensor_data.as_slice());
34//! # Ok(())
35//! # }
36//! ```
37
38use std::sync::Arc;
39
40use arrow::{
41    array::{
42        Array, AsArray, Float32Array, Float64Array, Int16Array, Int32Array, Int64Array, Int8Array,
43        UInt16Array, UInt32Array, UInt64Array, UInt8Array,
44    },
45    datatypes::DataType,
46    record_batch::RecordBatch,
47};
48
49use crate::error::{Error, Result};
50
/// Dense two-dimensional tensor stored contiguously in row-major (C) order.
///
/// Element `(row, col)` lives at index `row * cols + col` of the backing
/// buffer, which makes the buffer suitable for direct hand-off to ML
/// frameworks.
#[derive(Clone, Debug)]
pub struct TensorData<T> {
    /// Contiguous element buffer, row-major.
    data: Vec<T>,
    /// Tensor dimensions as `[rows, cols]`.
    shape: [usize; 2],
}
62
63impl<T: Clone + Default> TensorData<T> {
64    /// Creates a new tensor with the given shape, filled with default values.
65    pub fn new(rows: usize, cols: usize) -> Self {
66        Self {
67            data: vec![T::default(); rows * cols],
68            shape: [rows, cols],
69        }
70    }
71
72    /// Creates a tensor from existing data and shape.
73    ///
74    /// # Errors
75    ///
76    /// Returns an error if the data length doesn't match rows * cols.
77    ///
78    /// #[requires(data.len() == rows * cols)]
79    /// #[ensures(result.is_ok() ==> result.shape() == [rows, cols])]
80    /// #[ensures(result.is_ok() ==> result.as_slice().len() == rows * cols)]
81    /// #[invariant(self.data.len() == self.shape[0] * self.shape[1])]
82    pub fn from_vec(data: Vec<T>, rows: usize, cols: usize) -> Result<Self> {
83        if data.len() != rows * cols {
84            return Err(Error::data(format!(
85                "Data length {} doesn't match shape [{}, {}]",
86                data.len(),
87                rows,
88                cols
89            )));
90        }
91        Ok(Self {
92            data,
93            shape: [rows, cols],
94        })
95    }
96
97    /// Returns the shape of the tensor as [rows, cols].
98    pub fn shape(&self) -> [usize; 2] {
99        self.shape
100    }
101
102    /// Returns the number of rows.
103    pub fn rows(&self) -> usize {
104        self.shape[0]
105    }
106
107    /// Returns the number of columns.
108    pub fn cols(&self) -> usize {
109        self.shape[1]
110    }
111
112    /// Returns the underlying data as a slice.
113    pub fn as_slice(&self) -> &[T] {
114        &self.data
115    }
116
117    /// Returns the underlying data as a mutable slice.
118    pub fn as_mut_slice(&mut self) -> &mut [T] {
119        &mut self.data
120    }
121
122    /// Consumes the tensor and returns the underlying data.
123    pub fn into_vec(self) -> Vec<T> {
124        self.data
125    }
126
127    /// Returns a raw pointer to the underlying data.
128    ///
129    /// Useful for FFI integration with ML frameworks.
130    pub fn as_ptr(&self) -> *const T {
131        self.data.as_ptr()
132    }
133
134    /// Gets an element at the given row and column.
135    pub fn get(&self, row: usize, col: usize) -> Option<&T> {
136        if row < self.shape[0] && col < self.shape[1] {
137            Some(&self.data[row * self.shape[1] + col])
138        } else {
139            None
140        }
141    }
142
143    /// Sets an element at the given row and column.
144    ///
145    /// # Panics
146    ///
147    /// Panics if the indices are out of bounds.
148    pub fn set(&mut self, row: usize, col: usize, value: T) {
149        assert!(row < self.shape[0] && col < self.shape[1]);
150        self.data[row * self.shape[1] + col] = value;
151    }
152}
153
/// Pulls a fixed, ordered set of named columns out of Arrow `RecordBatch`es
/// and packs them into a row-major tensor buffer.
#[derive(Clone, Debug)]
pub struct TensorExtractor {
    /// Names of the columns to extract, in output-column order.
    columns: Vec<String>,
}
162
163impl TensorExtractor {
164    /// Creates a new extractor for the specified columns.
165    pub fn new(columns: &[&str]) -> Self {
166        Self {
167            columns: columns.iter().map(|s| (*s).to_string()).collect(),
168        }
169    }
170
171    /// Creates an extractor from owned column names.
172    pub fn from_columns(columns: Vec<String>) -> Self {
173        Self { columns }
174    }
175
176    /// Returns the column names being extracted.
177    pub fn columns(&self) -> &[String] {
178        &self.columns
179    }
180
181    /// Extracts data as f32 tensor.
182    ///
183    /// Numeric columns are converted to f32. Non-numeric columns cause an
184    /// error.
185    ///
186    /// # Errors
187    ///
188    /// Returns an error if:
189    /// - A requested column doesn't exist
190    /// - A column contains non-numeric data
191    pub fn extract_f32(&self, batch: &RecordBatch) -> Result<TensorData<f32>> {
192        let rows = batch.num_rows();
193        let cols = self.columns.len();
194
195        let mut data = vec![0.0f32; rows * cols];
196
197        for (col_idx, col_name) in self.columns.iter().enumerate() {
198            let col_index = batch
199                .schema()
200                .index_of(col_name)
201                .map_err(|_| Error::column_not_found(col_name))?;
202
203            let array = batch.column(col_index);
204            Self::extract_column_f32(array, &mut data, col_idx, cols, rows)?;
205        }
206
207        TensorData::from_vec(data, rows, cols)
208    }
209
210    /// Extracts data as f64 tensor.
211    ///
212    /// Numeric columns are converted to f64. Non-numeric columns cause an
213    /// error.
214    ///
215    /// # Errors
216    ///
217    /// Returns an error if:
218    /// - A requested column doesn't exist
219    /// - A column contains non-numeric data
220    pub fn extract_f64(&self, batch: &RecordBatch) -> Result<TensorData<f64>> {
221        let rows = batch.num_rows();
222        let cols = self.columns.len();
223
224        let mut data = vec![0.0f64; rows * cols];
225
226        for (col_idx, col_name) in self.columns.iter().enumerate() {
227            let col_index = batch
228                .schema()
229                .index_of(col_name)
230                .map_err(|_| Error::column_not_found(col_name))?;
231
232            let array = batch.column(col_index);
233            Self::extract_column_f64(array, &mut data, col_idx, cols, rows)?;
234        }
235
236        TensorData::from_vec(data, rows, cols)
237    }
238
239    /// Extracts data as i64 tensor.
240    ///
241    /// Integer columns are converted to i64. Non-integer columns cause an
242    /// error.
243    ///
244    /// # Errors
245    ///
246    /// Returns an error if:
247    /// - A requested column doesn't exist
248    /// - A column contains non-integer data
249    pub fn extract_i64(&self, batch: &RecordBatch) -> Result<TensorData<i64>> {
250        let rows = batch.num_rows();
251        let cols = self.columns.len();
252
253        let mut data = vec![0i64; rows * cols];
254
255        for (col_idx, col_name) in self.columns.iter().enumerate() {
256            let col_index = batch
257                .schema()
258                .index_of(col_name)
259                .map_err(|_| Error::column_not_found(col_name))?;
260
261            let array = batch.column(col_index);
262            Self::extract_column_i64(array, &mut data, col_idx, cols, rows)?;
263        }
264
265        TensorData::from_vec(data, rows, cols)
266    }
267
268    fn extract_column_f32(
269        array: &Arc<dyn Array>,
270        data: &mut [f32],
271        col_idx: usize,
272        num_cols: usize,
273        num_rows: usize,
274    ) -> Result<()> {
275        match array.data_type() {
276            DataType::Float32 => {
277                let arr = array.as_primitive::<arrow::datatypes::Float32Type>();
278                for row in 0..num_rows {
279                    data[row * num_cols + col_idx] = arr.value(row);
280                }
281            }
282            DataType::Float64 => {
283                let arr = array.as_primitive::<arrow::datatypes::Float64Type>();
284                for row in 0..num_rows {
285                    #[allow(clippy::cast_possible_truncation)]
286                    {
287                        data[row * num_cols + col_idx] = arr.value(row) as f32;
288                    }
289                }
290            }
291            DataType::Int8 => {
292                let arr = array.as_primitive::<arrow::datatypes::Int8Type>();
293                for row in 0..num_rows {
294                    data[row * num_cols + col_idx] = f32::from(arr.value(row));
295                }
296            }
297            DataType::Int16 => {
298                let arr = array.as_primitive::<arrow::datatypes::Int16Type>();
299                for row in 0..num_rows {
300                    data[row * num_cols + col_idx] = f32::from(arr.value(row));
301                }
302            }
303            DataType::Int32 => {
304                let arr = array.as_primitive::<arrow::datatypes::Int32Type>();
305                for row in 0..num_rows {
306                    #[allow(clippy::cast_precision_loss)]
307                    {
308                        data[row * num_cols + col_idx] = arr.value(row) as f32;
309                    }
310                }
311            }
312            DataType::Int64 => {
313                let arr = array.as_primitive::<arrow::datatypes::Int64Type>();
314                for row in 0..num_rows {
315                    #[allow(clippy::cast_precision_loss)]
316                    {
317                        data[row * num_cols + col_idx] = arr.value(row) as f32;
318                    }
319                }
320            }
321            DataType::UInt8 => {
322                let arr = array.as_primitive::<arrow::datatypes::UInt8Type>();
323                for row in 0..num_rows {
324                    data[row * num_cols + col_idx] = f32::from(arr.value(row));
325                }
326            }
327            DataType::UInt16 => {
328                let arr = array.as_primitive::<arrow::datatypes::UInt16Type>();
329                for row in 0..num_rows {
330                    data[row * num_cols + col_idx] = f32::from(arr.value(row));
331                }
332            }
333            DataType::UInt32 => {
334                let arr = array.as_primitive::<arrow::datatypes::UInt32Type>();
335                for row in 0..num_rows {
336                    #[allow(clippy::cast_precision_loss)]
337                    {
338                        data[row * num_cols + col_idx] = arr.value(row) as f32;
339                    }
340                }
341            }
342            DataType::UInt64 => {
343                let arr = array.as_primitive::<arrow::datatypes::UInt64Type>();
344                for row in 0..num_rows {
345                    #[allow(clippy::cast_precision_loss)]
346                    {
347                        data[row * num_cols + col_idx] = arr.value(row) as f32;
348                    }
349                }
350            }
351            dt => {
352                return Err(Error::data(format!(
353                    "Cannot convert {:?} to f32 tensor",
354                    dt
355                )));
356            }
357        }
358        Ok(())
359    }
360
361    fn extract_column_f64(
362        array: &Arc<dyn Array>,
363        data: &mut [f64],
364        col_idx: usize,
365        num_cols: usize,
366        num_rows: usize,
367    ) -> Result<()> {
368        match array.data_type() {
369            DataType::Float32 => {
370                let arr = array.as_primitive::<arrow::datatypes::Float32Type>();
371                for row in 0..num_rows {
372                    data[row * num_cols + col_idx] = f64::from(arr.value(row));
373                }
374            }
375            DataType::Float64 => {
376                let arr = array.as_primitive::<arrow::datatypes::Float64Type>();
377                for row in 0..num_rows {
378                    data[row * num_cols + col_idx] = arr.value(row);
379                }
380            }
381            DataType::Int8 => {
382                let arr = array.as_primitive::<arrow::datatypes::Int8Type>();
383                for row in 0..num_rows {
384                    data[row * num_cols + col_idx] = f64::from(arr.value(row));
385                }
386            }
387            DataType::Int16 => {
388                let arr = array.as_primitive::<arrow::datatypes::Int16Type>();
389                for row in 0..num_rows {
390                    data[row * num_cols + col_idx] = f64::from(arr.value(row));
391                }
392            }
393            DataType::Int32 => {
394                let arr = array.as_primitive::<arrow::datatypes::Int32Type>();
395                for row in 0..num_rows {
396                    data[row * num_cols + col_idx] = f64::from(arr.value(row));
397                }
398            }
399            DataType::Int64 => {
400                let arr = array.as_primitive::<arrow::datatypes::Int64Type>();
401                for row in 0..num_rows {
402                    #[allow(clippy::cast_precision_loss)]
403                    {
404                        data[row * num_cols + col_idx] = arr.value(row) as f64;
405                    }
406                }
407            }
408            DataType::UInt8 => {
409                let arr = array.as_primitive::<arrow::datatypes::UInt8Type>();
410                for row in 0..num_rows {
411                    data[row * num_cols + col_idx] = f64::from(arr.value(row));
412                }
413            }
414            DataType::UInt16 => {
415                let arr = array.as_primitive::<arrow::datatypes::UInt16Type>();
416                for row in 0..num_rows {
417                    data[row * num_cols + col_idx] = f64::from(arr.value(row));
418                }
419            }
420            DataType::UInt32 => {
421                let arr = array.as_primitive::<arrow::datatypes::UInt32Type>();
422                for row in 0..num_rows {
423                    data[row * num_cols + col_idx] = f64::from(arr.value(row));
424                }
425            }
426            DataType::UInt64 => {
427                let arr = array.as_primitive::<arrow::datatypes::UInt64Type>();
428                for row in 0..num_rows {
429                    #[allow(clippy::cast_precision_loss)]
430                    {
431                        data[row * num_cols + col_idx] = arr.value(row) as f64;
432                    }
433                }
434            }
435            dt => {
436                return Err(Error::data(format!(
437                    "Cannot convert {:?} to f64 tensor",
438                    dt
439                )));
440            }
441        }
442        Ok(())
443    }
444
445    fn extract_column_i64(
446        array: &Arc<dyn Array>,
447        data: &mut [i64],
448        col_idx: usize,
449        num_cols: usize,
450        num_rows: usize,
451    ) -> Result<()> {
452        match array.data_type() {
453            DataType::Int8 => {
454                let arr = array.as_primitive::<arrow::datatypes::Int8Type>();
455                for row in 0..num_rows {
456                    data[row * num_cols + col_idx] = i64::from(arr.value(row));
457                }
458            }
459            DataType::Int16 => {
460                let arr = array.as_primitive::<arrow::datatypes::Int16Type>();
461                for row in 0..num_rows {
462                    data[row * num_cols + col_idx] = i64::from(arr.value(row));
463                }
464            }
465            DataType::Int32 => {
466                let arr = array.as_primitive::<arrow::datatypes::Int32Type>();
467                for row in 0..num_rows {
468                    data[row * num_cols + col_idx] = i64::from(arr.value(row));
469                }
470            }
471            DataType::Int64 => {
472                let arr = array.as_primitive::<arrow::datatypes::Int64Type>();
473                for row in 0..num_rows {
474                    data[row * num_cols + col_idx] = arr.value(row);
475                }
476            }
477            DataType::UInt8 => {
478                let arr = array.as_primitive::<arrow::datatypes::UInt8Type>();
479                for row in 0..num_rows {
480                    data[row * num_cols + col_idx] = i64::from(arr.value(row));
481                }
482            }
483            DataType::UInt16 => {
484                let arr = array.as_primitive::<arrow::datatypes::UInt16Type>();
485                for row in 0..num_rows {
486                    data[row * num_cols + col_idx] = i64::from(arr.value(row));
487                }
488            }
489            DataType::UInt32 => {
490                let arr = array.as_primitive::<arrow::datatypes::UInt32Type>();
491                for row in 0..num_rows {
492                    data[row * num_cols + col_idx] = i64::from(arr.value(row));
493                }
494            }
495            DataType::UInt64 => {
496                let arr = array.as_primitive::<arrow::datatypes::UInt64Type>();
497                for row in 0..num_rows {
498                    #[allow(clippy::cast_possible_wrap)]
499                    {
500                        data[row * num_cols + col_idx] = arr.value(row) as i64;
501                    }
502                }
503            }
504            dt => {
505                return Err(Error::data(format!(
506                    "Cannot convert {:?} to i64 tensor",
507                    dt
508                )));
509            }
510        }
511        Ok(())
512    }
513}
514
515/// Extracts a single numeric column as a 1D vector.
516///
517/// # Errors
518///
519/// Returns an error if the column doesn't exist or is non-numeric.
520pub fn extract_column_f32(batch: &RecordBatch, column: &str) -> Result<Vec<f32>> {
521    let extractor = TensorExtractor::new(&[column]);
522    let tensor = extractor.extract_f32(batch)?;
523    Ok(tensor.into_vec())
524}
525
526/// Extracts a single numeric column as a 1D vector.
527///
528/// # Errors
529///
530/// Returns an error if the column doesn't exist or is non-numeric.
531pub fn extract_column_f64(batch: &RecordBatch, column: &str) -> Result<Vec<f64>> {
532    let extractor = TensorExtractor::new(&[column]);
533    let tensor = extractor.extract_f64(batch)?;
534    Ok(tensor.into_vec())
535}
536
537/// Extracts label column as integer indices.
538///
539/// String labels are converted to indices based on unique values.
540///
541/// # Errors
542///
543/// Returns an error if the column doesn't exist.
544pub fn extract_labels_i64(batch: &RecordBatch, column: &str) -> Result<Vec<i64>> {
545    let col_index = batch
546        .schema()
547        .index_of(column)
548        .map_err(|_| Error::column_not_found(column))?;
549
550    let array = batch.column(col_index);
551
552    match array.data_type() {
553        DataType::Int8 => {
554            let arr = array
555                .as_any()
556                .downcast_ref::<Int8Array>()
557                .ok_or_else(|| Error::data("Failed to downcast to Int8Array"))?;
558            Ok(arr.iter().map(|v| i64::from(v.unwrap_or(0))).collect())
559        }
560        DataType::Int16 => {
561            let arr = array
562                .as_any()
563                .downcast_ref::<Int16Array>()
564                .ok_or_else(|| Error::data("Failed to downcast to Int16Array"))?;
565            Ok(arr.iter().map(|v| i64::from(v.unwrap_or(0))).collect())
566        }
567        DataType::Int32 => {
568            let arr = array
569                .as_any()
570                .downcast_ref::<Int32Array>()
571                .ok_or_else(|| Error::data("Failed to downcast to Int32Array"))?;
572            Ok(arr.iter().map(|v| i64::from(v.unwrap_or(0))).collect())
573        }
574        DataType::Int64 => {
575            let arr = array
576                .as_any()
577                .downcast_ref::<Int64Array>()
578                .ok_or_else(|| Error::data("Failed to downcast to Int64Array"))?;
579            Ok(arr.iter().map(|v| v.unwrap_or(0)).collect())
580        }
581        DataType::UInt8 => {
582            let arr = array
583                .as_any()
584                .downcast_ref::<UInt8Array>()
585                .ok_or_else(|| Error::data("Failed to downcast to UInt8Array"))?;
586            Ok(arr.iter().map(|v| i64::from(v.unwrap_or(0))).collect())
587        }
588        DataType::UInt16 => {
589            let arr = array
590                .as_any()
591                .downcast_ref::<UInt16Array>()
592                .ok_or_else(|| Error::data("Failed to downcast to UInt16Array"))?;
593            Ok(arr.iter().map(|v| i64::from(v.unwrap_or(0))).collect())
594        }
595        DataType::UInt32 => {
596            let arr = array
597                .as_any()
598                .downcast_ref::<UInt32Array>()
599                .ok_or_else(|| Error::data("Failed to downcast to UInt32Array"))?;
600            Ok(arr.iter().map(|v| i64::from(v.unwrap_or(0))).collect())
601        }
602        DataType::UInt64 => {
603            let arr = array
604                .as_any()
605                .downcast_ref::<UInt64Array>()
606                .ok_or_else(|| Error::data("Failed to downcast to UInt64Array"))?;
607            #[allow(clippy::cast_possible_wrap)]
608            Ok(arr.iter().map(|v| v.unwrap_or(0) as i64).collect())
609        }
610        DataType::Float32 => {
611            let arr = array
612                .as_any()
613                .downcast_ref::<Float32Array>()
614                .ok_or_else(|| Error::data("Failed to downcast to Float32Array"))?;
615            #[allow(clippy::cast_possible_truncation)]
616            Ok(arr.iter().map(|v| v.unwrap_or(0.0) as i64).collect())
617        }
618        DataType::Float64 => {
619            let arr = array
620                .as_any()
621                .downcast_ref::<Float64Array>()
622                .ok_or_else(|| Error::data("Failed to downcast to Float64Array"))?;
623            #[allow(clippy::cast_possible_truncation)]
624            Ok(arr.iter().map(|v| v.unwrap_or(0.0) as i64).collect())
625        }
626        dt => Err(Error::data(format!("Cannot extract labels from {:?}", dt))),
627    }
628}
629
630#[cfg(test)]
631#[allow(
632    clippy::cast_possible_truncation,
633    clippy::cast_possible_wrap,
634    clippy::uninlined_format_args,
635    clippy::unwrap_used,
636    clippy::expect_used,
637    clippy::float_cmp
638)]
639mod tests {
640    use arrow::datatypes::{Field, Schema};
641
642    use super::*;
643
644    fn create_numeric_batch() -> RecordBatch {
645        let schema = Arc::new(Schema::new(vec![
646            Field::new("f32_col", DataType::Float32, false),
647            Field::new("f64_col", DataType::Float64, false),
648            Field::new("i32_col", DataType::Int32, false),
649            Field::new("i64_col", DataType::Int64, false),
650        ]));
651
652        RecordBatch::try_new(
653            schema,
654            vec![
655                Arc::new(Float32Array::from(vec![1.0f32, 2.0, 3.0])),
656                Arc::new(Float64Array::from(vec![4.0f64, 5.0, 6.0])),
657                Arc::new(Int32Array::from(vec![7, 8, 9])),
658                Arc::new(Int64Array::from(vec![10i64, 11, 12])),
659            ],
660        )
661        .unwrap()
662    }
663
664    #[test]
665    fn test_tensor_data_new() {
666        let tensor: TensorData<f32> = TensorData::new(3, 4);
667        assert_eq!(tensor.shape(), [3, 4]);
668        assert_eq!(tensor.rows(), 3);
669        assert_eq!(tensor.cols(), 4);
670        assert_eq!(tensor.as_slice().len(), 12);
671    }
672
673    #[test]
674    fn test_tensor_data_from_vec() {
675        let data = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0];
676        let tensor = TensorData::from_vec(data, 2, 3).unwrap();
677        assert_eq!(tensor.shape(), [2, 3]);
678        assert_eq!(tensor.as_slice(), &[1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
679    }
680
681    #[test]
682    fn test_tensor_data_from_vec_invalid_shape() {
683        let data = vec![1.0f32, 2.0, 3.0, 4.0];
684        let result = TensorData::from_vec(data, 2, 3);
685        assert!(result.is_err());
686    }
687
688    #[test]
689    fn test_tensor_data_get_set() {
690        let mut tensor = TensorData::from_vec(vec![1.0f32, 2.0, 3.0, 4.0], 2, 2).unwrap();
691
692        assert_eq!(tensor.get(0, 0), Some(&1.0f32));
693        assert_eq!(tensor.get(0, 1), Some(&2.0f32));
694        assert_eq!(tensor.get(1, 0), Some(&3.0f32));
695        assert_eq!(tensor.get(1, 1), Some(&4.0f32));
696        assert_eq!(tensor.get(2, 0), None);
697
698        tensor.set(0, 1, 99.0);
699        assert_eq!(tensor.get(0, 1), Some(&99.0f32));
700    }
701
702    #[test]
703    fn test_tensor_data_into_vec() {
704        let data = vec![1.0f32, 2.0, 3.0];
705        let tensor = TensorData::from_vec(data.clone(), 1, 3).unwrap();
706        assert_eq!(tensor.into_vec(), data);
707    }
708
709    #[test]
710    fn test_tensor_data_as_ptr() {
711        let tensor = TensorData::from_vec(vec![1.0f32, 2.0, 3.0], 1, 3).unwrap();
712        let ptr = tensor.as_ptr();
713        assert!(!ptr.is_null());
714    }
715
716    #[test]
717    fn test_tensor_data_as_mut_slice() {
718        let mut tensor = TensorData::from_vec(vec![1.0f32, 2.0, 3.0], 1, 3).unwrap();
719        let slice = tensor.as_mut_slice();
720        slice[0] = 10.0;
721        assert_eq!(tensor.as_slice()[0], 10.0);
722    }
723
724    #[test]
725    fn test_tensor_data_clone() {
726        let tensor = TensorData::from_vec(vec![1.0f32, 2.0, 3.0], 1, 3).unwrap();
727        let cloned = tensor.clone();
728        assert_eq!(cloned.shape(), tensor.shape());
729        assert_eq!(cloned.as_slice(), tensor.as_slice());
730    }
731
732    #[test]
733    fn test_tensor_data_debug() {
734        let tensor = TensorData::from_vec(vec![1.0f32], 1, 1).unwrap();
735        let debug = format!("{:?}", tensor);
736        assert!(debug.contains("TensorData"));
737    }
738
739    #[test]
740    fn test_extractor_new() {
741        let extractor = TensorExtractor::new(&["a", "b", "c"]);
742        assert_eq!(extractor.columns().len(), 3);
743        assert_eq!(extractor.columns()[0], "a");
744    }
745
746    #[test]
747    fn test_extractor_from_columns() {
748        let extractor = TensorExtractor::from_columns(vec!["x".to_string(), "y".to_string()]);
749        assert_eq!(extractor.columns().len(), 2);
750    }
751
752    #[test]
753    fn test_extractor_clone() {
754        let extractor = TensorExtractor::new(&["a", "b"]);
755        let cloned = extractor.clone();
756        assert_eq!(cloned.columns(), extractor.columns());
757    }
758
759    #[test]
760    fn test_extractor_debug() {
761        let extractor = TensorExtractor::new(&["col"]);
762        let debug = format!("{:?}", extractor);
763        assert!(debug.contains("TensorExtractor"));
764    }
765
766    #[test]
767    fn test_extract_f32() {
768        let batch = create_numeric_batch();
769        let extractor = TensorExtractor::new(&["f32_col", "i32_col"]);
770        let tensor = extractor.extract_f32(&batch).unwrap();
771
772        assert_eq!(tensor.shape(), [3, 2]);
773        assert_eq!(tensor.get(0, 0), Some(&1.0f32));
774        assert_eq!(tensor.get(0, 1), Some(&7.0f32));
775        assert_eq!(tensor.get(2, 0), Some(&3.0f32));
776        assert_eq!(tensor.get(2, 1), Some(&9.0f32));
777    }
778
779    #[test]
780    fn test_extract_f64() {
781        let batch = create_numeric_batch();
782        let extractor = TensorExtractor::new(&["f64_col", "i64_col"]);
783        let tensor = extractor.extract_f64(&batch).unwrap();
784
785        assert_eq!(tensor.shape(), [3, 2]);
786        assert_eq!(tensor.get(0, 0), Some(&4.0f64));
787        assert_eq!(tensor.get(0, 1), Some(&10.0f64));
788    }
789
790    #[test]
791    fn test_extract_i64() {
792        let batch = create_numeric_batch();
793        let extractor = TensorExtractor::new(&["i32_col", "i64_col"]);
794        let tensor = extractor.extract_i64(&batch).unwrap();
795
796        assert_eq!(tensor.shape(), [3, 2]);
797        assert_eq!(tensor.get(0, 0), Some(&7i64));
798        assert_eq!(tensor.get(0, 1), Some(&10i64));
799    }
800
801    #[test]
802    fn test_extract_column_not_found() {
803        let batch = create_numeric_batch();
804        let extractor = TensorExtractor::new(&["nonexistent"]);
805        let result = extractor.extract_f32(&batch);
806        assert!(result.is_err());
807    }
808
809    #[test]
810    fn test_extract_column_f32_helper() {
811        let batch = create_numeric_batch();
812        let data = extract_column_f32(&batch, "f32_col").unwrap();
813        assert_eq!(data, vec![1.0f32, 2.0, 3.0]);
814    }
815
816    #[test]
817    fn test_extract_column_f64_helper() {
818        let batch = create_numeric_batch();
819        let data = extract_column_f64(&batch, "f64_col").unwrap();
820        assert_eq!(data, vec![4.0f64, 5.0, 6.0]);
821    }
822
823    #[test]
824    fn test_extract_labels_i64() {
825        let batch = create_numeric_batch();
826        let labels = extract_labels_i64(&batch, "i32_col").unwrap();
827        assert_eq!(labels, vec![7i64, 8, 9]);
828    }
829
830    #[test]
831    fn test_extract_labels_i64_from_float() {
832        let schema = Arc::new(Schema::new(vec![Field::new(
833            "label",
834            DataType::Float64,
835            false,
836        )]));
837        let batch = RecordBatch::try_new(
838            schema,
839            vec![Arc::new(Float64Array::from(vec![0.0, 1.0, 2.0]))],
840        )
841        .unwrap();
842
843        let labels = extract_labels_i64(&batch, "label").unwrap();
844        assert_eq!(labels, vec![0i64, 1, 2]);
845    }
846
847    #[test]
848    fn test_extract_labels_column_not_found() {
849        let batch = create_numeric_batch();
850        let result = extract_labels_i64(&batch, "nonexistent");
851        assert!(result.is_err());
852    }
853
854    #[test]
855    fn test_extract_all_int_types() {
856        let schema = Arc::new(Schema::new(vec![
857            Field::new("i8", DataType::Int8, false),
858            Field::new("i16", DataType::Int16, false),
859            Field::new("u8", DataType::UInt8, false),
860            Field::new("u16", DataType::UInt16, false),
861            Field::new("u32", DataType::UInt32, false),
862            Field::new("u64", DataType::UInt64, false),
863        ]));
864
865        let batch = RecordBatch::try_new(
866            schema,
867            vec![
868                Arc::new(Int8Array::from(vec![1i8])),
869                Arc::new(Int16Array::from(vec![2i16])),
870                Arc::new(UInt8Array::from(vec![3u8])),
871                Arc::new(UInt16Array::from(vec![4u16])),
872                Arc::new(UInt32Array::from(vec![5u32])),
873                Arc::new(UInt64Array::from(vec![6u64])),
874            ],
875        )
876        .unwrap();
877
878        // Test f32 extraction
879        let extractor = TensorExtractor::new(&["i8", "i16", "u8", "u16", "u32", "u64"]);
880        let tensor = extractor.extract_f32(&batch).unwrap();
881        assert_eq!(tensor.as_slice(), &[1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0]);
882
883        // Test f64 extraction
884        let tensor = extractor.extract_f64(&batch).unwrap();
885        assert_eq!(tensor.as_slice(), &[1.0f64, 2.0, 3.0, 4.0, 5.0, 6.0]);
886
887        // Test i64 extraction
888        let tensor = extractor.extract_i64(&batch).unwrap();
889        assert_eq!(tensor.as_slice(), &[1i64, 2, 3, 4, 5, 6]);
890    }
891
892    #[test]
893    fn test_extract_f32_from_f64() {
894        let schema = Arc::new(Schema::new(vec![Field::new(
895            "value",
896            DataType::Float64,
897            false,
898        )]));
899        let batch = RecordBatch::try_new(
900            schema,
901            vec![Arc::new(Float64Array::from(vec![1.5f64, 2.5, 3.5]))],
902        )
903        .unwrap();
904
905        let extractor = TensorExtractor::new(&["value"]);
906        let tensor = extractor.extract_f32(&batch).unwrap();
907        assert_eq!(tensor.as_slice(), &[1.5f32, 2.5, 3.5]);
908    }
909
910    #[test]
911    fn test_extract_f64_from_f32() {
912        let schema = Arc::new(Schema::new(vec![Field::new(
913            "value",
914            DataType::Float32,
915            false,
916        )]));
917        let batch = RecordBatch::try_new(
918            schema,
919            vec![Arc::new(Float32Array::from(vec![1.5f32, 2.5, 3.5]))],
920        )
921        .unwrap();
922
923        let extractor = TensorExtractor::new(&["value"]);
924        let tensor = extractor.extract_f64(&batch).unwrap();
925        // f32 -> f64 conversion is exact for these values
926        assert_eq!(tensor.as_slice(), &[1.5f64, 2.5, 3.5]);
927    }
928
929    #[test]
930    fn test_extract_unsupported_type_f32() {
931        use arrow::array::StringArray;
932
933        let schema = Arc::new(Schema::new(vec![Field::new("text", DataType::Utf8, false)]));
934        let batch = RecordBatch::try_new(
935            schema,
936            vec![Arc::new(StringArray::from(vec!["hello", "world"]))],
937        )
938        .unwrap();
939
940        let extractor = TensorExtractor::new(&["text"]);
941        let result = extractor.extract_f32(&batch);
942        assert!(result.is_err());
943    }
944
945    #[test]
946    fn test_extract_unsupported_type_f64() {
947        use arrow::array::StringArray;
948
949        let schema = Arc::new(Schema::new(vec![Field::new("text", DataType::Utf8, false)]));
950        let batch = RecordBatch::try_new(
951            schema,
952            vec![Arc::new(StringArray::from(vec!["hello", "world"]))],
953        )
954        .unwrap();
955
956        let extractor = TensorExtractor::new(&["text"]);
957        let result = extractor.extract_f64(&batch);
958        assert!(result.is_err());
959    }
960
961    #[test]
962    fn test_extract_unsupported_type_i64() {
963        use arrow::array::StringArray;
964
965        let schema = Arc::new(Schema::new(vec![Field::new("text", DataType::Utf8, false)]));
966        let batch = RecordBatch::try_new(
967            schema,
968            vec![Arc::new(StringArray::from(vec!["hello", "world"]))],
969        )
970        .unwrap();
971
972        let extractor = TensorExtractor::new(&["text"]);
973        let result = extractor.extract_i64(&batch);
974        assert!(result.is_err());
975    }
976
977    #[test]
978    fn test_extract_labels_unsupported_type() {
979        use arrow::array::StringArray;
980
981        let schema = Arc::new(Schema::new(vec![Field::new("text", DataType::Utf8, false)]));
982        let batch = RecordBatch::try_new(
983            schema,
984            vec![Arc::new(StringArray::from(vec!["hello", "world"]))],
985        )
986        .unwrap();
987
988        let result = extract_labels_i64(&batch, "text");
989        assert!(result.is_err());
990    }
991
992    #[test]
993    fn test_extract_labels_all_uint_types() {
994        let schema = Arc::new(Schema::new(vec![
995            Field::new("u8", DataType::UInt8, false),
996            Field::new("u16", DataType::UInt16, false),
997            Field::new("u32", DataType::UInt32, false),
998            Field::new("u64", DataType::UInt64, false),
999        ]));
1000
1001        let batch = RecordBatch::try_new(
1002            schema,
1003            vec![
1004                Arc::new(UInt8Array::from(vec![1u8])),
1005                Arc::new(UInt16Array::from(vec![2u16])),
1006                Arc::new(UInt32Array::from(vec![3u32])),
1007                Arc::new(UInt64Array::from(vec![4u64])),
1008            ],
1009        )
1010        .unwrap();
1011
1012        assert_eq!(extract_labels_i64(&batch, "u8").unwrap(), vec![1i64]);
1013        assert_eq!(extract_labels_i64(&batch, "u16").unwrap(), vec![2i64]);
1014        assert_eq!(extract_labels_i64(&batch, "u32").unwrap(), vec![3i64]);
1015        assert_eq!(extract_labels_i64(&batch, "u64").unwrap(), vec![4i64]);
1016    }
1017
1018    #[test]
1019    fn test_extract_labels_all_int_types() {
1020        let schema = Arc::new(Schema::new(vec![
1021            Field::new("i8", DataType::Int8, false),
1022            Field::new("i16", DataType::Int16, false),
1023            Field::new("f32", DataType::Float32, false),
1024        ]));
1025
1026        let batch = RecordBatch::try_new(
1027            schema,
1028            vec![
1029                Arc::new(Int8Array::from(vec![1i8])),
1030                Arc::new(Int16Array::from(vec![2i16])),
1031                Arc::new(Float32Array::from(vec![3.0f32])),
1032            ],
1033        )
1034        .unwrap();
1035
1036        assert_eq!(extract_labels_i64(&batch, "i8").unwrap(), vec![1i64]);
1037        assert_eq!(extract_labels_i64(&batch, "i16").unwrap(), vec![2i64]);
1038        assert_eq!(extract_labels_i64(&batch, "f32").unwrap(), vec![3i64]);
1039    }
1040}