1use std::sync::Arc;
39
40use arrow::{
41 array::{Array, AsArray},
42 datatypes::DataType,
43 record_batch::RecordBatch,
44};
45
46use crate::error::{Error, Result};
47
48#[derive(Debug, Clone)]
53pub struct TensorData<T> {
54 data: Vec<T>,
56 shape: [usize; 2],
58}
59
60impl<T: Clone + Default> TensorData<T> {
61 pub fn new(rows: usize, cols: usize) -> Self {
63 Self {
64 data: vec![T::default(); rows * cols],
65 shape: [rows, cols],
66 }
67 }
68
69 pub fn from_vec(data: Vec<T>, rows: usize, cols: usize) -> Result<Self> {
80 if data.len() != rows * cols {
81 return Err(Error::data(format!(
82 "Data length {} doesn't match shape [{}, {}]",
83 data.len(),
84 rows,
85 cols
86 )));
87 }
88 Ok(Self {
89 data,
90 shape: [rows, cols],
91 })
92 }
93
94 pub fn shape(&self) -> [usize; 2] {
96 self.shape
97 }
98
99 pub fn rows(&self) -> usize {
101 self.shape[0]
102 }
103
104 pub fn cols(&self) -> usize {
106 self.shape[1]
107 }
108
109 pub fn as_slice(&self) -> &[T] {
111 &self.data
112 }
113
114 pub fn as_mut_slice(&mut self) -> &mut [T] {
116 &mut self.data
117 }
118
119 pub fn into_vec(self) -> Vec<T> {
121 self.data
122 }
123
124 pub fn as_ptr(&self) -> *const T {
128 self.data.as_ptr()
129 }
130
131 pub fn get(&self, row: usize, col: usize) -> Option<&T> {
133 if row < self.shape[0] && col < self.shape[1] {
134 Some(&self.data[row * self.shape[1] + col])
135 } else {
136 None
137 }
138 }
139
140 pub fn set(&mut self, row: usize, col: usize, value: T) {
146 assert!(row < self.shape[0] && col < self.shape[1]);
147 self.data[row * self.shape[1] + col] = value;
148 }
149}
150
/// Reads a fixed, ordered list of named columns out of Arrow record batches and packs them
/// into a row-major [`TensorData`].
#[derive(Clone, Debug)]
pub struct TensorExtractor {
    /// Names of the columns to read; their order defines the tensor's column order.
    columns: Vec<String>,
}
159
160impl TensorExtractor {
161 pub fn new(columns: &[&str]) -> Self {
163 Self {
164 columns: columns.iter().map(|s| (*s).to_string()).collect(),
165 }
166 }
167
168 pub fn from_columns(columns: Vec<String>) -> Self {
170 Self { columns }
171 }
172
173 pub fn columns(&self) -> &[String] {
175 &self.columns
176 }
177
178 pub fn extract_f32(&self, batch: &RecordBatch) -> Result<TensorData<f32>> {
189 let rows = batch.num_rows();
190 let cols = self.columns.len();
191
192 let mut data = vec![0.0f32; rows * cols];
193
194 for (col_idx, col_name) in self.columns.iter().enumerate() {
195 let col_index = batch
196 .schema()
197 .index_of(col_name)
198 .map_err(|_| Error::column_not_found(col_name))?;
199
200 let array = batch.column(col_index);
201 Self::extract_column_f32(array, &mut data, col_idx, cols, rows)?;
202 }
203
204 TensorData::from_vec(data, rows, cols)
205 }
206
207 pub fn extract_f64(&self, batch: &RecordBatch) -> Result<TensorData<f64>> {
218 let rows = batch.num_rows();
219 let cols = self.columns.len();
220
221 let mut data = vec![0.0f64; rows * cols];
222
223 for (col_idx, col_name) in self.columns.iter().enumerate() {
224 let col_index = batch
225 .schema()
226 .index_of(col_name)
227 .map_err(|_| Error::column_not_found(col_name))?;
228
229 let array = batch.column(col_index);
230 Self::extract_column_f64(array, &mut data, col_idx, cols, rows)?;
231 }
232
233 TensorData::from_vec(data, rows, cols)
234 }
235
236 pub fn extract_i64(&self, batch: &RecordBatch) -> Result<TensorData<i64>> {
247 let rows = batch.num_rows();
248 let cols = self.columns.len();
249
250 let mut data = vec![0i64; rows * cols];
251
252 for (col_idx, col_name) in self.columns.iter().enumerate() {
253 let col_index = batch
254 .schema()
255 .index_of(col_name)
256 .map_err(|_| Error::column_not_found(col_name))?;
257
258 let array = batch.column(col_index);
259 Self::extract_column_i64(array, &mut data, col_idx, cols, rows)?;
260 }
261
262 TensorData::from_vec(data, rows, cols)
263 }
264
265 fn extract_column_f32(
266 array: &Arc<dyn Array>,
267 data: &mut [f32],
268 col_idx: usize,
269 num_cols: usize,
270 num_rows: usize,
271 ) -> Result<()> {
272 match array.data_type() {
273 DataType::Float32 => {
274 let arr = array.as_primitive::<arrow::datatypes::Float32Type>();
275 for row in 0..num_rows {
276 data[row * num_cols + col_idx] = arr.value(row);
277 }
278 }
279 DataType::Float64 => {
280 let arr = array.as_primitive::<arrow::datatypes::Float64Type>();
281 for row in 0..num_rows {
282 #[allow(clippy::cast_possible_truncation)]
283 {
284 data[row * num_cols + col_idx] = arr.value(row) as f32;
285 }
286 }
287 }
288 DataType::Int8 => {
289 let arr = array.as_primitive::<arrow::datatypes::Int8Type>();
290 for row in 0..num_rows {
291 data[row * num_cols + col_idx] = f32::from(arr.value(row));
292 }
293 }
294 DataType::Int16 => {
295 let arr = array.as_primitive::<arrow::datatypes::Int16Type>();
296 for row in 0..num_rows {
297 data[row * num_cols + col_idx] = f32::from(arr.value(row));
298 }
299 }
300 DataType::Int32 => {
301 let arr = array.as_primitive::<arrow::datatypes::Int32Type>();
302 for row in 0..num_rows {
303 #[allow(clippy::cast_precision_loss)]
304 {
305 data[row * num_cols + col_idx] = arr.value(row) as f32;
306 }
307 }
308 }
309 DataType::Int64 => {
310 let arr = array.as_primitive::<arrow::datatypes::Int64Type>();
311 for row in 0..num_rows {
312 #[allow(clippy::cast_precision_loss)]
313 {
314 data[row * num_cols + col_idx] = arr.value(row) as f32;
315 }
316 }
317 }
318 DataType::UInt8 => {
319 let arr = array.as_primitive::<arrow::datatypes::UInt8Type>();
320 for row in 0..num_rows {
321 data[row * num_cols + col_idx] = f32::from(arr.value(row));
322 }
323 }
324 DataType::UInt16 => {
325 let arr = array.as_primitive::<arrow::datatypes::UInt16Type>();
326 for row in 0..num_rows {
327 data[row * num_cols + col_idx] = f32::from(arr.value(row));
328 }
329 }
330 DataType::UInt32 => {
331 let arr = array.as_primitive::<arrow::datatypes::UInt32Type>();
332 for row in 0..num_rows {
333 #[allow(clippy::cast_precision_loss)]
334 {
335 data[row * num_cols + col_idx] = arr.value(row) as f32;
336 }
337 }
338 }
339 DataType::UInt64 => {
340 let arr = array.as_primitive::<arrow::datatypes::UInt64Type>();
341 for row in 0..num_rows {
342 #[allow(clippy::cast_precision_loss)]
343 {
344 data[row * num_cols + col_idx] = arr.value(row) as f32;
345 }
346 }
347 }
348 dt => {
349 return Err(Error::data(format!(
350 "Cannot convert {:?} to f32 tensor",
351 dt
352 )));
353 }
354 }
355 Ok(())
356 }
357
358 fn extract_column_f64(
359 array: &Arc<dyn Array>,
360 data: &mut [f64],
361 col_idx: usize,
362 num_cols: usize,
363 num_rows: usize,
364 ) -> Result<()> {
365 match array.data_type() {
366 DataType::Float32 => {
367 let arr = array.as_primitive::<arrow::datatypes::Float32Type>();
368 for row in 0..num_rows {
369 data[row * num_cols + col_idx] = f64::from(arr.value(row));
370 }
371 }
372 DataType::Float64 => {
373 let arr = array.as_primitive::<arrow::datatypes::Float64Type>();
374 for row in 0..num_rows {
375 data[row * num_cols + col_idx] = arr.value(row);
376 }
377 }
378 DataType::Int8 => {
379 let arr = array.as_primitive::<arrow::datatypes::Int8Type>();
380 for row in 0..num_rows {
381 data[row * num_cols + col_idx] = f64::from(arr.value(row));
382 }
383 }
384 DataType::Int16 => {
385 let arr = array.as_primitive::<arrow::datatypes::Int16Type>();
386 for row in 0..num_rows {
387 data[row * num_cols + col_idx] = f64::from(arr.value(row));
388 }
389 }
390 DataType::Int32 => {
391 let arr = array.as_primitive::<arrow::datatypes::Int32Type>();
392 for row in 0..num_rows {
393 data[row * num_cols + col_idx] = f64::from(arr.value(row));
394 }
395 }
396 DataType::Int64 => {
397 let arr = array.as_primitive::<arrow::datatypes::Int64Type>();
398 for row in 0..num_rows {
399 #[allow(clippy::cast_precision_loss)]
400 {
401 data[row * num_cols + col_idx] = arr.value(row) as f64;
402 }
403 }
404 }
405 DataType::UInt8 => {
406 let arr = array.as_primitive::<arrow::datatypes::UInt8Type>();
407 for row in 0..num_rows {
408 data[row * num_cols + col_idx] = f64::from(arr.value(row));
409 }
410 }
411 DataType::UInt16 => {
412 let arr = array.as_primitive::<arrow::datatypes::UInt16Type>();
413 for row in 0..num_rows {
414 data[row * num_cols + col_idx] = f64::from(arr.value(row));
415 }
416 }
417 DataType::UInt32 => {
418 let arr = array.as_primitive::<arrow::datatypes::UInt32Type>();
419 for row in 0..num_rows {
420 data[row * num_cols + col_idx] = f64::from(arr.value(row));
421 }
422 }
423 DataType::UInt64 => {
424 let arr = array.as_primitive::<arrow::datatypes::UInt64Type>();
425 for row in 0..num_rows {
426 #[allow(clippy::cast_precision_loss)]
427 {
428 data[row * num_cols + col_idx] = arr.value(row) as f64;
429 }
430 }
431 }
432 dt => {
433 return Err(Error::data(format!(
434 "Cannot convert {:?} to f64 tensor",
435 dt
436 )));
437 }
438 }
439 Ok(())
440 }
441
442 fn extract_column_i64(
443 array: &Arc<dyn Array>,
444 data: &mut [i64],
445 col_idx: usize,
446 num_cols: usize,
447 num_rows: usize,
448 ) -> Result<()> {
449 match array.data_type() {
450 DataType::Int8 => {
451 let arr = array.as_primitive::<arrow::datatypes::Int8Type>();
452 for row in 0..num_rows {
453 data[row * num_cols + col_idx] = i64::from(arr.value(row));
454 }
455 }
456 DataType::Int16 => {
457 let arr = array.as_primitive::<arrow::datatypes::Int16Type>();
458 for row in 0..num_rows {
459 data[row * num_cols + col_idx] = i64::from(arr.value(row));
460 }
461 }
462 DataType::Int32 => {
463 let arr = array.as_primitive::<arrow::datatypes::Int32Type>();
464 for row in 0..num_rows {
465 data[row * num_cols + col_idx] = i64::from(arr.value(row));
466 }
467 }
468 DataType::Int64 => {
469 let arr = array.as_primitive::<arrow::datatypes::Int64Type>();
470 for row in 0..num_rows {
471 data[row * num_cols + col_idx] = arr.value(row);
472 }
473 }
474 DataType::UInt8 => {
475 let arr = array.as_primitive::<arrow::datatypes::UInt8Type>();
476 for row in 0..num_rows {
477 data[row * num_cols + col_idx] = i64::from(arr.value(row));
478 }
479 }
480 DataType::UInt16 => {
481 let arr = array.as_primitive::<arrow::datatypes::UInt16Type>();
482 for row in 0..num_rows {
483 data[row * num_cols + col_idx] = i64::from(arr.value(row));
484 }
485 }
486 DataType::UInt32 => {
487 let arr = array.as_primitive::<arrow::datatypes::UInt32Type>();
488 for row in 0..num_rows {
489 data[row * num_cols + col_idx] = i64::from(arr.value(row));
490 }
491 }
492 DataType::UInt64 => {
493 let arr = array.as_primitive::<arrow::datatypes::UInt64Type>();
494 for row in 0..num_rows {
495 #[allow(clippy::cast_possible_wrap)]
496 {
497 data[row * num_cols + col_idx] = arr.value(row) as i64;
498 }
499 }
500 }
501 dt => {
502 return Err(Error::data(format!(
503 "Cannot convert {:?} to i64 tensor",
504 dt
505 )));
506 }
507 }
508 Ok(())
509 }
510}
511
512pub fn extract_column_f32(batch: &RecordBatch, column: &str) -> Result<Vec<f32>> {
518 let extractor = TensorExtractor::new(&[column]);
519 let tensor = extractor.extract_f32(batch)?;
520 Ok(tensor.into_vec())
521}
522
523pub fn extract_column_f64(batch: &RecordBatch, column: &str) -> Result<Vec<f64>> {
529 let extractor = TensorExtractor::new(&[column]);
530 let tensor = extractor.extract_f64(batch)?;
531 Ok(tensor.into_vec())
532}
533
534pub fn extract_labels_i64(batch: &RecordBatch, column: &str) -> Result<Vec<i64>> {
542 use arrow::datatypes::{
543 Float32Type, Float64Type, Int16Type, Int32Type, Int64Type, Int8Type, UInt16Type,
544 UInt32Type, UInt64Type, UInt8Type,
545 };
546
547 let col_index = batch
548 .schema()
549 .index_of(column)
550 .map_err(|_| Error::column_not_found(column))?;
551
552 let array = batch.column(col_index);
553
554 match array.data_type() {
555 DataType::Int8 => cast_and_collect::<Int8Type>(array, "Int8Array", |v| i64::from(v)),
556 DataType::Int16 => cast_and_collect::<Int16Type>(array, "Int16Array", |v| i64::from(v)),
557 DataType::Int32 => cast_and_collect::<Int32Type>(array, "Int32Array", |v| i64::from(v)),
558 DataType::Int64 => cast_and_collect::<Int64Type>(array, "Int64Array", |v| v),
559 DataType::UInt8 => cast_and_collect::<UInt8Type>(array, "UInt8Array", |v| i64::from(v)),
560 DataType::UInt16 => cast_and_collect::<UInt16Type>(array, "UInt16Array", |v| i64::from(v)),
561 DataType::UInt32 => cast_and_collect::<UInt32Type>(array, "UInt32Array", |v| i64::from(v)),
562 DataType::UInt64 =>
563 {
564 #[allow(clippy::cast_possible_wrap)]
565 cast_and_collect::<UInt64Type>(array, "UInt64Array", |v| v as i64)
566 }
567 DataType::Float32 =>
568 {
569 #[allow(clippy::cast_possible_truncation)]
570 cast_and_collect::<Float32Type>(array, "Float32Array", |v| v as i64)
571 }
572 DataType::Float64 =>
573 {
574 #[allow(clippy::cast_possible_truncation)]
575 cast_and_collect::<Float64Type>(array, "Float64Array", |v| v as i64)
576 }
577 dt => Err(Error::data(format!("Cannot extract labels from {:?}", dt))),
578 }
579}
580
581fn cast_and_collect<T>(
584 array: &dyn arrow::array::Array,
585 type_name: &str,
586 cast: impl Fn(T::Native) -> i64,
587) -> Result<Vec<i64>>
588where
589 T: arrow::datatypes::ArrowPrimitiveType,
590 T::Native: Copy + Default,
591{
592 let arr = array
593 .as_any()
594 .downcast_ref::<arrow::array::PrimitiveArray<T>>()
595 .ok_or_else(|| Error::data(format!("Failed to downcast to {type_name}")))?;
596 Ok(arr.iter().map(|v| cast(v.unwrap_or_default())).collect())
597}
598
#[cfg(test)]
#[allow(
    clippy::cast_possible_truncation,
    clippy::cast_possible_wrap,
    clippy::uninlined_format_args,
    clippy::unwrap_used,
    clippy::expect_used,
    clippy::float_cmp
)]
mod tests {
    use arrow::{
        array::{
            ArrayRef, Float32Array, Float64Array, Int16Array, Int32Array, Int64Array, Int8Array,
            StringArray, UInt16Array, UInt32Array, UInt64Array, UInt8Array,
        },
        datatypes::{Field, Schema},
    };

    use super::*;

    /// Three rows of four non-null columns: `f32_col`, `f64_col`, `i32_col`, `i64_col`.
    fn create_numeric_batch() -> RecordBatch {
        let schema = Arc::new(Schema::new(vec![
            Field::new("f32_col", DataType::Float32, false),
            Field::new("f64_col", DataType::Float64, false),
            Field::new("i32_col", DataType::Int32, false),
            Field::new("i64_col", DataType::Int64, false),
        ]));

        RecordBatch::try_new(
            schema,
            vec![
                Arc::new(Float32Array::from(vec![1.0f32, 2.0, 3.0])),
                Arc::new(Float64Array::from(vec![4.0f64, 5.0, 6.0])),
                Arc::new(Int32Array::from(vec![7, 8, 9])),
                Arc::new(Int64Array::from(vec![10i64, 11, 12])),
            ],
        )
        .unwrap()
    }

    /// Builds a batch whose only (nullable) column is `array`, named `name`.
    fn single_column_batch(name: &str, array: ArrayRef) -> RecordBatch {
        let field = Field::new(name, array.data_type().clone(), true);
        RecordBatch::try_new(Arc::new(Schema::new(vec![field])), vec![array]).unwrap()
    }

    /// A batch with a single `Utf8` column named `text`, which no extractor accepts.
    fn text_batch() -> RecordBatch {
        single_column_batch(
            "text",
            Arc::new(StringArray::from(vec!["hello", "world"])),
        )
    }

    #[test]
    fn test_tensor_data_new() {
        let tensor: TensorData<f32> = TensorData::new(3, 4);
        assert_eq!(tensor.shape(), [3, 4]);
        assert_eq!(tensor.rows(), 3);
        assert_eq!(tensor.cols(), 4);
        assert_eq!(tensor.as_slice().len(), 12);
    }

    #[test]
    fn test_tensor_data_from_vec() {
        let data = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0];
        let tensor = TensorData::from_vec(data, 2, 3).unwrap();
        assert_eq!(tensor.shape(), [2, 3]);
        assert_eq!(tensor.as_slice(), &[1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
    }

    #[test]
    fn test_tensor_data_from_vec_invalid_shape() {
        let data = vec![1.0f32, 2.0, 3.0, 4.0];
        let result = TensorData::from_vec(data, 2, 3);
        assert!(result.is_err());
    }

    #[test]
    fn test_tensor_data_get_set() {
        let mut tensor = TensorData::from_vec(vec![1.0f32, 2.0, 3.0, 4.0], 2, 2).unwrap();

        assert_eq!(tensor.get(0, 0), Some(&1.0f32));
        assert_eq!(tensor.get(0, 1), Some(&2.0f32));
        assert_eq!(tensor.get(1, 0), Some(&3.0f32));
        assert_eq!(tensor.get(1, 1), Some(&4.0f32));
        // Out of bounds on either axis.
        assert_eq!(tensor.get(2, 0), None);
        assert_eq!(tensor.get(0, 2), None);

        tensor.set(0, 1, 99.0);
        assert_eq!(tensor.get(0, 1), Some(&99.0f32));
    }

    #[test]
    #[should_panic]
    fn test_tensor_data_set_out_of_bounds_panics() {
        let mut tensor: TensorData<f32> = TensorData::new(2, 2);
        tensor.set(0, 2, 1.0);
    }

    #[test]
    fn test_tensor_data_into_vec() {
        let data = vec![1.0f32, 2.0, 3.0];
        let tensor = TensorData::from_vec(data.clone(), 1, 3).unwrap();
        assert_eq!(tensor.into_vec(), data);
    }

    #[test]
    fn test_tensor_data_as_ptr() {
        let tensor = TensorData::from_vec(vec![1.0f32, 2.0, 3.0], 1, 3).unwrap();
        // The raw pointer must address the same buffer that `as_slice` exposes.
        assert_eq!(tensor.as_ptr(), tensor.as_slice().as_ptr());
    }

    #[test]
    fn test_tensor_data_as_mut_slice() {
        let mut tensor = TensorData::from_vec(vec![1.0f32, 2.0, 3.0], 1, 3).unwrap();
        let slice = tensor.as_mut_slice();
        slice[0] = 10.0;
        assert_eq!(tensor.as_slice()[0], 10.0);
    }

    #[test]
    fn test_tensor_data_clone() {
        let tensor = TensorData::from_vec(vec![1.0f32, 2.0, 3.0], 1, 3).unwrap();
        let cloned = tensor.clone();
        assert_eq!(cloned.shape(), tensor.shape());
        assert_eq!(cloned.as_slice(), tensor.as_slice());
    }

    #[test]
    fn test_tensor_data_debug() {
        let tensor = TensorData::from_vec(vec![1.0f32], 1, 1).unwrap();
        let debug = format!("{:?}", tensor);
        assert!(debug.contains("TensorData"));
    }

    #[test]
    fn test_extractor_new() {
        let extractor = TensorExtractor::new(&["a", "b", "c"]);
        assert_eq!(extractor.columns().len(), 3);
        assert_eq!(extractor.columns()[0], "a");
    }

    #[test]
    fn test_extractor_from_columns() {
        let extractor = TensorExtractor::from_columns(vec!["x".to_string(), "y".to_string()]);
        assert_eq!(extractor.columns().len(), 2);
    }

    #[test]
    fn test_extractor_clone() {
        let extractor = TensorExtractor::new(&["a", "b"]);
        let cloned = extractor.clone();
        assert_eq!(cloned.columns(), extractor.columns());
    }

    #[test]
    fn test_extractor_debug() {
        let extractor = TensorExtractor::new(&["col"]);
        let debug = format!("{:?}", extractor);
        assert!(debug.contains("TensorExtractor"));
    }

    #[test]
    fn test_extract_f32() {
        let batch = create_numeric_batch();
        let extractor = TensorExtractor::new(&["f32_col", "i32_col"]);
        let tensor = extractor.extract_f32(&batch).unwrap();

        assert_eq!(tensor.shape(), [3, 2]);
        assert_eq!(tensor.get(0, 0), Some(&1.0f32));
        assert_eq!(tensor.get(0, 1), Some(&7.0f32));
        assert_eq!(tensor.get(2, 0), Some(&3.0f32));
        assert_eq!(tensor.get(2, 1), Some(&9.0f32));
    }

    #[test]
    fn test_extract_f64() {
        let batch = create_numeric_batch();
        let extractor = TensorExtractor::new(&["f64_col", "i64_col"]);
        let tensor = extractor.extract_f64(&batch).unwrap();

        assert_eq!(tensor.shape(), [3, 2]);
        assert_eq!(tensor.get(0, 0), Some(&4.0f64));
        assert_eq!(tensor.get(0, 1), Some(&10.0f64));
    }

    #[test]
    fn test_extract_i64() {
        let batch = create_numeric_batch();
        let extractor = TensorExtractor::new(&["i32_col", "i64_col"]);
        let tensor = extractor.extract_i64(&batch).unwrap();

        assert_eq!(tensor.shape(), [3, 2]);
        assert_eq!(tensor.get(0, 0), Some(&7i64));
        assert_eq!(tensor.get(0, 1), Some(&10i64));
    }

    #[test]
    fn test_extract_empty_column_list() {
        let batch = create_numeric_batch();
        let tensor = TensorExtractor::new(&[]).extract_f32(&batch).unwrap();
        assert_eq!(tensor.shape(), [3, 0]);
        assert!(tensor.as_slice().is_empty());
    }

    #[test]
    fn test_extract_column_not_found() {
        let batch = create_numeric_batch();
        let extractor = TensorExtractor::new(&["nonexistent"]);
        let result = extractor.extract_f32(&batch);
        assert!(result.is_err());
    }

    #[test]
    fn test_extract_column_f32_helper() {
        let batch = create_numeric_batch();
        let data = extract_column_f32(&batch, "f32_col").unwrap();
        assert_eq!(data, vec![1.0f32, 2.0, 3.0]);
    }

    #[test]
    fn test_extract_column_f64_helper() {
        let batch = create_numeric_batch();
        let data = extract_column_f64(&batch, "f64_col").unwrap();
        assert_eq!(data, vec![4.0f64, 5.0, 6.0]);
    }

    #[test]
    fn test_extract_labels_i64() {
        let batch = create_numeric_batch();
        let labels = extract_labels_i64(&batch, "i32_col").unwrap();
        assert_eq!(labels, vec![7i64, 8, 9]);
    }

    #[test]
    fn test_extract_labels_i64_from_float() {
        let batch = single_column_batch(
            "label",
            Arc::new(Float64Array::from(vec![0.0, 1.0, 2.0])),
        );
        let labels = extract_labels_i64(&batch, "label").unwrap();
        assert_eq!(labels, vec![0i64, 1, 2]);
    }

    #[test]
    fn test_extract_labels_float_truncates_toward_zero() {
        let batch = single_column_batch(
            "label",
            Arc::new(Float64Array::from(vec![1.9, -1.9, 2.0])),
        );
        let labels = extract_labels_i64(&batch, "label").unwrap();
        assert_eq!(labels, vec![1i64, -1, 2]);
    }

    #[test]
    fn test_extract_labels_column_not_found() {
        let batch = create_numeric_batch();
        let result = extract_labels_i64(&batch, "nonexistent");
        assert!(result.is_err());
    }

    #[test]
    fn test_extract_all_int_types() {
        let schema = Arc::new(Schema::new(vec![
            Field::new("i8", DataType::Int8, false),
            Field::new("i16", DataType::Int16, false),
            Field::new("u8", DataType::UInt8, false),
            Field::new("u16", DataType::UInt16, false),
            Field::new("u32", DataType::UInt32, false),
            Field::new("u64", DataType::UInt64, false),
        ]));

        let batch = RecordBatch::try_new(
            schema,
            vec![
                Arc::new(Int8Array::from(vec![1i8])),
                Arc::new(Int16Array::from(vec![2i16])),
                Arc::new(UInt8Array::from(vec![3u8])),
                Arc::new(UInt16Array::from(vec![4u16])),
                Arc::new(UInt32Array::from(vec![5u32])),
                Arc::new(UInt64Array::from(vec![6u64])),
            ],
        )
        .unwrap();

        let extractor = TensorExtractor::new(&["i8", "i16", "u8", "u16", "u32", "u64"]);
        let tensor = extractor.extract_f32(&batch).unwrap();
        assert_eq!(tensor.as_slice(), &[1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0]);

        let tensor = extractor.extract_f64(&batch).unwrap();
        assert_eq!(tensor.as_slice(), &[1.0f64, 2.0, 3.0, 4.0, 5.0, 6.0]);

        let tensor = extractor.extract_i64(&batch).unwrap();
        assert_eq!(tensor.as_slice(), &[1i64, 2, 3, 4, 5, 6]);
    }

    #[test]
    fn test_extract_nulls_become_zero() {
        let batch = single_column_batch(
            "v",
            Arc::new(Int32Array::from(vec![Some(1), None, Some(3)])),
        );
        let extractor = TensorExtractor::new(&["v"]);

        assert_eq!(
            extractor.extract_f32(&batch).unwrap().as_slice(),
            &[1.0f32, 0.0, 3.0]
        );
        assert_eq!(
            extractor.extract_f64(&batch).unwrap().as_slice(),
            &[1.0f64, 0.0, 3.0]
        );
        assert_eq!(
            extractor.extract_i64(&batch).unwrap().as_slice(),
            &[1i64, 0, 3]
        );
        assert_eq!(extract_labels_i64(&batch, "v").unwrap(), vec![1i64, 0, 3]);
    }

    #[test]
    fn test_extract_f32_from_f64() {
        let batch = single_column_batch(
            "value",
            Arc::new(Float64Array::from(vec![1.5f64, 2.5, 3.5])),
        );
        let extractor = TensorExtractor::new(&["value"]);
        let tensor = extractor.extract_f32(&batch).unwrap();
        assert_eq!(tensor.as_slice(), &[1.5f32, 2.5, 3.5]);
    }

    #[test]
    fn test_extract_f64_from_f32() {
        let batch = single_column_batch(
            "value",
            Arc::new(Float32Array::from(vec![1.5f32, 2.5, 3.5])),
        );
        let extractor = TensorExtractor::new(&["value"]);
        let tensor = extractor.extract_f64(&batch).unwrap();
        assert_eq!(tensor.as_slice(), &[1.5f64, 2.5, 3.5]);
    }

    #[test]
    fn test_extract_unsupported_type() {
        let batch = text_batch();
        let extractor = TensorExtractor::new(&["text"]);
        assert!(extractor.extract_f32(&batch).is_err());
        assert!(extractor.extract_f64(&batch).is_err());
        assert!(extractor.extract_i64(&batch).is_err());
    }

    #[test]
    fn test_extract_labels_unsupported_type() {
        let batch = text_batch();
        let result = extract_labels_i64(&batch, "text");
        assert!(result.is_err());
    }

    #[test]
    fn test_extract_labels_all_uint_types() {
        let schema = Arc::new(Schema::new(vec![
            Field::new("u8", DataType::UInt8, false),
            Field::new("u16", DataType::UInt16, false),
            Field::new("u32", DataType::UInt32, false),
            Field::new("u64", DataType::UInt64, false),
        ]));

        let batch = RecordBatch::try_new(
            schema,
            vec![
                Arc::new(UInt8Array::from(vec![1u8])),
                Arc::new(UInt16Array::from(vec![2u16])),
                Arc::new(UInt32Array::from(vec![3u32])),
                Arc::new(UInt64Array::from(vec![4u64])),
            ],
        )
        .unwrap();

        assert_eq!(extract_labels_i64(&batch, "u8").unwrap(), vec![1i64]);
        assert_eq!(extract_labels_i64(&batch, "u16").unwrap(), vec![2i64]);
        assert_eq!(extract_labels_i64(&batch, "u32").unwrap(), vec![3i64]);
        assert_eq!(extract_labels_i64(&batch, "u64").unwrap(), vec![4i64]);
    }

    #[test]
    fn test_extract_labels_small_ints_and_f32() {
        let schema = Arc::new(Schema::new(vec![
            Field::new("i8", DataType::Int8, false),
            Field::new("i16", DataType::Int16, false),
            Field::new("f32", DataType::Float32, false),
        ]));

        let batch = RecordBatch::try_new(
            schema,
            vec![
                Arc::new(Int8Array::from(vec![1i8])),
                Arc::new(Int16Array::from(vec![2i16])),
                Arc::new(Float32Array::from(vec![3.0f32])),
            ],
        )
        .unwrap();

        assert_eq!(extract_labels_i64(&batch, "i8").unwrap(), vec![1i64]);
        assert_eq!(extract_labels_i64(&batch, "i16").unwrap(), vec![2i64]);
        assert_eq!(extract_labels_i64(&batch, "f32").unwrap(), vec![3i64]);
    }
}