1use std::sync::Arc;
39
40use arrow::{
41 array::{
42 Array, AsArray, Float32Array, Float64Array, Int16Array, Int32Array, Int64Array, Int8Array,
43 UInt16Array, UInt32Array, UInt64Array, UInt8Array,
44 },
45 datatypes::DataType,
46 record_batch::RecordBatch,
47};
48
49use crate::error::{Error, Result};
50
/// A dense two-dimensional tensor whose elements live in one flat,
/// row-major `Vec<T>`.
#[derive(Clone, Debug)]
pub struct TensorData<T> {
    /// Flat element storage; element `(r, c)` sits at index `r * shape[1] + c`.
    data: Vec<T>,
    /// Dimensions as `[rows, cols]`.
    shape: [usize; 2],
}
62
63impl<T: Clone + Default> TensorData<T> {
64 pub fn new(rows: usize, cols: usize) -> Self {
66 Self {
67 data: vec![T::default(); rows * cols],
68 shape: [rows, cols],
69 }
70 }
71
72 pub fn from_vec(data: Vec<T>, rows: usize, cols: usize) -> Result<Self> {
83 if data.len() != rows * cols {
84 return Err(Error::data(format!(
85 "Data length {} doesn't match shape [{}, {}]",
86 data.len(),
87 rows,
88 cols
89 )));
90 }
91 Ok(Self {
92 data,
93 shape: [rows, cols],
94 })
95 }
96
97 pub fn shape(&self) -> [usize; 2] {
99 self.shape
100 }
101
102 pub fn rows(&self) -> usize {
104 self.shape[0]
105 }
106
107 pub fn cols(&self) -> usize {
109 self.shape[1]
110 }
111
112 pub fn as_slice(&self) -> &[T] {
114 &self.data
115 }
116
117 pub fn as_mut_slice(&mut self) -> &mut [T] {
119 &mut self.data
120 }
121
122 pub fn into_vec(self) -> Vec<T> {
124 self.data
125 }
126
127 pub fn as_ptr(&self) -> *const T {
131 self.data.as_ptr()
132 }
133
134 pub fn get(&self, row: usize, col: usize) -> Option<&T> {
136 if row < self.shape[0] && col < self.shape[1] {
137 Some(&self.data[row * self.shape[1] + col])
138 } else {
139 None
140 }
141 }
142
143 pub fn set(&mut self, row: usize, col: usize, value: T) {
149 assert!(row < self.shape[0] && col < self.shape[1]);
150 self.data[row * self.shape[1] + col] = value;
151 }
152}
153
/// Reads a fixed, ordered list of named columns from Arrow record batches
/// and packs them into a row-major [`TensorData`].
#[derive(Clone, Debug)]
pub struct TensorExtractor {
    /// Column names to read; the name at position `i` fills tensor column `i`.
    columns: Vec<String>,
}
162
163impl TensorExtractor {
164 pub fn new(columns: &[&str]) -> Self {
166 Self {
167 columns: columns.iter().map(|s| (*s).to_string()).collect(),
168 }
169 }
170
171 pub fn from_columns(columns: Vec<String>) -> Self {
173 Self { columns }
174 }
175
176 pub fn columns(&self) -> &[String] {
178 &self.columns
179 }
180
181 pub fn extract_f32(&self, batch: &RecordBatch) -> Result<TensorData<f32>> {
192 let rows = batch.num_rows();
193 let cols = self.columns.len();
194
195 let mut data = vec![0.0f32; rows * cols];
196
197 for (col_idx, col_name) in self.columns.iter().enumerate() {
198 let col_index = batch
199 .schema()
200 .index_of(col_name)
201 .map_err(|_| Error::column_not_found(col_name))?;
202
203 let array = batch.column(col_index);
204 Self::extract_column_f32(array, &mut data, col_idx, cols, rows)?;
205 }
206
207 TensorData::from_vec(data, rows, cols)
208 }
209
210 pub fn extract_f64(&self, batch: &RecordBatch) -> Result<TensorData<f64>> {
221 let rows = batch.num_rows();
222 let cols = self.columns.len();
223
224 let mut data = vec![0.0f64; rows * cols];
225
226 for (col_idx, col_name) in self.columns.iter().enumerate() {
227 let col_index = batch
228 .schema()
229 .index_of(col_name)
230 .map_err(|_| Error::column_not_found(col_name))?;
231
232 let array = batch.column(col_index);
233 Self::extract_column_f64(array, &mut data, col_idx, cols, rows)?;
234 }
235
236 TensorData::from_vec(data, rows, cols)
237 }
238
239 pub fn extract_i64(&self, batch: &RecordBatch) -> Result<TensorData<i64>> {
250 let rows = batch.num_rows();
251 let cols = self.columns.len();
252
253 let mut data = vec![0i64; rows * cols];
254
255 for (col_idx, col_name) in self.columns.iter().enumerate() {
256 let col_index = batch
257 .schema()
258 .index_of(col_name)
259 .map_err(|_| Error::column_not_found(col_name))?;
260
261 let array = batch.column(col_index);
262 Self::extract_column_i64(array, &mut data, col_idx, cols, rows)?;
263 }
264
265 TensorData::from_vec(data, rows, cols)
266 }
267
268 fn extract_column_f32(
269 array: &Arc<dyn Array>,
270 data: &mut [f32],
271 col_idx: usize,
272 num_cols: usize,
273 num_rows: usize,
274 ) -> Result<()> {
275 match array.data_type() {
276 DataType::Float32 => {
277 let arr = array.as_primitive::<arrow::datatypes::Float32Type>();
278 for row in 0..num_rows {
279 data[row * num_cols + col_idx] = arr.value(row);
280 }
281 }
282 DataType::Float64 => {
283 let arr = array.as_primitive::<arrow::datatypes::Float64Type>();
284 for row in 0..num_rows {
285 #[allow(clippy::cast_possible_truncation)]
286 {
287 data[row * num_cols + col_idx] = arr.value(row) as f32;
288 }
289 }
290 }
291 DataType::Int8 => {
292 let arr = array.as_primitive::<arrow::datatypes::Int8Type>();
293 for row in 0..num_rows {
294 data[row * num_cols + col_idx] = f32::from(arr.value(row));
295 }
296 }
297 DataType::Int16 => {
298 let arr = array.as_primitive::<arrow::datatypes::Int16Type>();
299 for row in 0..num_rows {
300 data[row * num_cols + col_idx] = f32::from(arr.value(row));
301 }
302 }
303 DataType::Int32 => {
304 let arr = array.as_primitive::<arrow::datatypes::Int32Type>();
305 for row in 0..num_rows {
306 #[allow(clippy::cast_precision_loss)]
307 {
308 data[row * num_cols + col_idx] = arr.value(row) as f32;
309 }
310 }
311 }
312 DataType::Int64 => {
313 let arr = array.as_primitive::<arrow::datatypes::Int64Type>();
314 for row in 0..num_rows {
315 #[allow(clippy::cast_precision_loss)]
316 {
317 data[row * num_cols + col_idx] = arr.value(row) as f32;
318 }
319 }
320 }
321 DataType::UInt8 => {
322 let arr = array.as_primitive::<arrow::datatypes::UInt8Type>();
323 for row in 0..num_rows {
324 data[row * num_cols + col_idx] = f32::from(arr.value(row));
325 }
326 }
327 DataType::UInt16 => {
328 let arr = array.as_primitive::<arrow::datatypes::UInt16Type>();
329 for row in 0..num_rows {
330 data[row * num_cols + col_idx] = f32::from(arr.value(row));
331 }
332 }
333 DataType::UInt32 => {
334 let arr = array.as_primitive::<arrow::datatypes::UInt32Type>();
335 for row in 0..num_rows {
336 #[allow(clippy::cast_precision_loss)]
337 {
338 data[row * num_cols + col_idx] = arr.value(row) as f32;
339 }
340 }
341 }
342 DataType::UInt64 => {
343 let arr = array.as_primitive::<arrow::datatypes::UInt64Type>();
344 for row in 0..num_rows {
345 #[allow(clippy::cast_precision_loss)]
346 {
347 data[row * num_cols + col_idx] = arr.value(row) as f32;
348 }
349 }
350 }
351 dt => {
352 return Err(Error::data(format!(
353 "Cannot convert {:?} to f32 tensor",
354 dt
355 )));
356 }
357 }
358 Ok(())
359 }
360
361 fn extract_column_f64(
362 array: &Arc<dyn Array>,
363 data: &mut [f64],
364 col_idx: usize,
365 num_cols: usize,
366 num_rows: usize,
367 ) -> Result<()> {
368 match array.data_type() {
369 DataType::Float32 => {
370 let arr = array.as_primitive::<arrow::datatypes::Float32Type>();
371 for row in 0..num_rows {
372 data[row * num_cols + col_idx] = f64::from(arr.value(row));
373 }
374 }
375 DataType::Float64 => {
376 let arr = array.as_primitive::<arrow::datatypes::Float64Type>();
377 for row in 0..num_rows {
378 data[row * num_cols + col_idx] = arr.value(row);
379 }
380 }
381 DataType::Int8 => {
382 let arr = array.as_primitive::<arrow::datatypes::Int8Type>();
383 for row in 0..num_rows {
384 data[row * num_cols + col_idx] = f64::from(arr.value(row));
385 }
386 }
387 DataType::Int16 => {
388 let arr = array.as_primitive::<arrow::datatypes::Int16Type>();
389 for row in 0..num_rows {
390 data[row * num_cols + col_idx] = f64::from(arr.value(row));
391 }
392 }
393 DataType::Int32 => {
394 let arr = array.as_primitive::<arrow::datatypes::Int32Type>();
395 for row in 0..num_rows {
396 data[row * num_cols + col_idx] = f64::from(arr.value(row));
397 }
398 }
399 DataType::Int64 => {
400 let arr = array.as_primitive::<arrow::datatypes::Int64Type>();
401 for row in 0..num_rows {
402 #[allow(clippy::cast_precision_loss)]
403 {
404 data[row * num_cols + col_idx] = arr.value(row) as f64;
405 }
406 }
407 }
408 DataType::UInt8 => {
409 let arr = array.as_primitive::<arrow::datatypes::UInt8Type>();
410 for row in 0..num_rows {
411 data[row * num_cols + col_idx] = f64::from(arr.value(row));
412 }
413 }
414 DataType::UInt16 => {
415 let arr = array.as_primitive::<arrow::datatypes::UInt16Type>();
416 for row in 0..num_rows {
417 data[row * num_cols + col_idx] = f64::from(arr.value(row));
418 }
419 }
420 DataType::UInt32 => {
421 let arr = array.as_primitive::<arrow::datatypes::UInt32Type>();
422 for row in 0..num_rows {
423 data[row * num_cols + col_idx] = f64::from(arr.value(row));
424 }
425 }
426 DataType::UInt64 => {
427 let arr = array.as_primitive::<arrow::datatypes::UInt64Type>();
428 for row in 0..num_rows {
429 #[allow(clippy::cast_precision_loss)]
430 {
431 data[row * num_cols + col_idx] = arr.value(row) as f64;
432 }
433 }
434 }
435 dt => {
436 return Err(Error::data(format!(
437 "Cannot convert {:?} to f64 tensor",
438 dt
439 )));
440 }
441 }
442 Ok(())
443 }
444
445 fn extract_column_i64(
446 array: &Arc<dyn Array>,
447 data: &mut [i64],
448 col_idx: usize,
449 num_cols: usize,
450 num_rows: usize,
451 ) -> Result<()> {
452 match array.data_type() {
453 DataType::Int8 => {
454 let arr = array.as_primitive::<arrow::datatypes::Int8Type>();
455 for row in 0..num_rows {
456 data[row * num_cols + col_idx] = i64::from(arr.value(row));
457 }
458 }
459 DataType::Int16 => {
460 let arr = array.as_primitive::<arrow::datatypes::Int16Type>();
461 for row in 0..num_rows {
462 data[row * num_cols + col_idx] = i64::from(arr.value(row));
463 }
464 }
465 DataType::Int32 => {
466 let arr = array.as_primitive::<arrow::datatypes::Int32Type>();
467 for row in 0..num_rows {
468 data[row * num_cols + col_idx] = i64::from(arr.value(row));
469 }
470 }
471 DataType::Int64 => {
472 let arr = array.as_primitive::<arrow::datatypes::Int64Type>();
473 for row in 0..num_rows {
474 data[row * num_cols + col_idx] = arr.value(row);
475 }
476 }
477 DataType::UInt8 => {
478 let arr = array.as_primitive::<arrow::datatypes::UInt8Type>();
479 for row in 0..num_rows {
480 data[row * num_cols + col_idx] = i64::from(arr.value(row));
481 }
482 }
483 DataType::UInt16 => {
484 let arr = array.as_primitive::<arrow::datatypes::UInt16Type>();
485 for row in 0..num_rows {
486 data[row * num_cols + col_idx] = i64::from(arr.value(row));
487 }
488 }
489 DataType::UInt32 => {
490 let arr = array.as_primitive::<arrow::datatypes::UInt32Type>();
491 for row in 0..num_rows {
492 data[row * num_cols + col_idx] = i64::from(arr.value(row));
493 }
494 }
495 DataType::UInt64 => {
496 let arr = array.as_primitive::<arrow::datatypes::UInt64Type>();
497 for row in 0..num_rows {
498 #[allow(clippy::cast_possible_wrap)]
499 {
500 data[row * num_cols + col_idx] = arr.value(row) as i64;
501 }
502 }
503 }
504 dt => {
505 return Err(Error::data(format!(
506 "Cannot convert {:?} to i64 tensor",
507 dt
508 )));
509 }
510 }
511 Ok(())
512 }
513}
514
515pub fn extract_column_f32(batch: &RecordBatch, column: &str) -> Result<Vec<f32>> {
521 let extractor = TensorExtractor::new(&[column]);
522 let tensor = extractor.extract_f32(batch)?;
523 Ok(tensor.into_vec())
524}
525
526pub fn extract_column_f64(batch: &RecordBatch, column: &str) -> Result<Vec<f64>> {
532 let extractor = TensorExtractor::new(&[column]);
533 let tensor = extractor.extract_f64(batch)?;
534 Ok(tensor.into_vec())
535}
536
537pub fn extract_labels_i64(batch: &RecordBatch, column: &str) -> Result<Vec<i64>> {
545 let col_index = batch
546 .schema()
547 .index_of(column)
548 .map_err(|_| Error::column_not_found(column))?;
549
550 let array = batch.column(col_index);
551
552 match array.data_type() {
553 DataType::Int8 => {
554 let arr = array
555 .as_any()
556 .downcast_ref::<Int8Array>()
557 .ok_or_else(|| Error::data("Failed to downcast to Int8Array"))?;
558 Ok(arr.iter().map(|v| i64::from(v.unwrap_or(0))).collect())
559 }
560 DataType::Int16 => {
561 let arr = array
562 .as_any()
563 .downcast_ref::<Int16Array>()
564 .ok_or_else(|| Error::data("Failed to downcast to Int16Array"))?;
565 Ok(arr.iter().map(|v| i64::from(v.unwrap_or(0))).collect())
566 }
567 DataType::Int32 => {
568 let arr = array
569 .as_any()
570 .downcast_ref::<Int32Array>()
571 .ok_or_else(|| Error::data("Failed to downcast to Int32Array"))?;
572 Ok(arr.iter().map(|v| i64::from(v.unwrap_or(0))).collect())
573 }
574 DataType::Int64 => {
575 let arr = array
576 .as_any()
577 .downcast_ref::<Int64Array>()
578 .ok_or_else(|| Error::data("Failed to downcast to Int64Array"))?;
579 Ok(arr.iter().map(|v| v.unwrap_or(0)).collect())
580 }
581 DataType::UInt8 => {
582 let arr = array
583 .as_any()
584 .downcast_ref::<UInt8Array>()
585 .ok_or_else(|| Error::data("Failed to downcast to UInt8Array"))?;
586 Ok(arr.iter().map(|v| i64::from(v.unwrap_or(0))).collect())
587 }
588 DataType::UInt16 => {
589 let arr = array
590 .as_any()
591 .downcast_ref::<UInt16Array>()
592 .ok_or_else(|| Error::data("Failed to downcast to UInt16Array"))?;
593 Ok(arr.iter().map(|v| i64::from(v.unwrap_or(0))).collect())
594 }
595 DataType::UInt32 => {
596 let arr = array
597 .as_any()
598 .downcast_ref::<UInt32Array>()
599 .ok_or_else(|| Error::data("Failed to downcast to UInt32Array"))?;
600 Ok(arr.iter().map(|v| i64::from(v.unwrap_or(0))).collect())
601 }
602 DataType::UInt64 => {
603 let arr = array
604 .as_any()
605 .downcast_ref::<UInt64Array>()
606 .ok_or_else(|| Error::data("Failed to downcast to UInt64Array"))?;
607 #[allow(clippy::cast_possible_wrap)]
608 Ok(arr.iter().map(|v| v.unwrap_or(0) as i64).collect())
609 }
610 DataType::Float32 => {
611 let arr = array
612 .as_any()
613 .downcast_ref::<Float32Array>()
614 .ok_or_else(|| Error::data("Failed to downcast to Float32Array"))?;
615 #[allow(clippy::cast_possible_truncation)]
616 Ok(arr.iter().map(|v| v.unwrap_or(0.0) as i64).collect())
617 }
618 DataType::Float64 => {
619 let arr = array
620 .as_any()
621 .downcast_ref::<Float64Array>()
622 .ok_or_else(|| Error::data("Failed to downcast to Float64Array"))?;
623 #[allow(clippy::cast_possible_truncation)]
624 Ok(arr.iter().map(|v| v.unwrap_or(0.0) as i64).collect())
625 }
626 dt => Err(Error::data(format!("Cannot extract labels from {:?}", dt))),
627 }
628}
629
#[cfg(test)]
#[allow(
    clippy::cast_possible_truncation,
    clippy::cast_possible_wrap,
    clippy::uninlined_format_args,
    clippy::unwrap_used,
    clippy::expect_used,
    clippy::float_cmp
)]
mod tests {
    use arrow::array::StringArray;
    use arrow::datatypes::{Field, Schema};

    use super::*;

    /// Three rows across four non-nullable numeric columns:
    /// `f32_col` = 1..=3, `f64_col` = 4..=6, `i32_col` = 7..=9, `i64_col` = 10..=12.
    fn create_numeric_batch() -> RecordBatch {
        let schema = Schema::new(vec![
            Field::new("f32_col", DataType::Float32, false),
            Field::new("f64_col", DataType::Float64, false),
            Field::new("i32_col", DataType::Int32, false),
            Field::new("i64_col", DataType::Int64, false),
        ]);
        let columns: Vec<Arc<dyn Array>> = vec![
            Arc::new(Float32Array::from(vec![1.0f32, 2.0, 3.0])),
            Arc::new(Float64Array::from(vec![4.0f64, 5.0, 6.0])),
            Arc::new(Int32Array::from(vec![7, 8, 9])),
            Arc::new(Int64Array::from(vec![10i64, 11, 12])),
        ];
        RecordBatch::try_new(Arc::new(schema), columns).unwrap()
    }

    /// Wraps `array` in a batch with one non-nullable column called `name`.
    fn single_column_batch(name: &str, data_type: DataType, array: Arc<dyn Array>) -> RecordBatch {
        let schema = Arc::new(Schema::new(vec![Field::new(name, data_type, false)]));
        RecordBatch::try_new(schema, vec![array]).unwrap()
    }

    /// A one-column `Utf8` batch, used to exercise the unsupported-type paths.
    fn text_batch() -> RecordBatch {
        single_column_batch(
            "text",
            DataType::Utf8,
            Arc::new(StringArray::from(vec!["hello", "world"])),
        )
    }

    #[test]
    fn test_tensor_data_new() {
        let tensor: TensorData<f32> = TensorData::new(3, 4);
        assert_eq!(tensor.shape(), [3, 4]);
        assert_eq!((tensor.rows(), tensor.cols()), (3, 4));
        assert_eq!(tensor.as_slice().len(), 12);
    }

    #[test]
    fn test_tensor_data_from_vec() {
        let tensor = TensorData::from_vec(vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0], 2, 3).unwrap();
        assert_eq!(tensor.shape(), [2, 3]);
        assert_eq!(tensor.as_slice(), &[1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
    }

    #[test]
    fn test_tensor_data_from_vec_invalid_shape() {
        assert!(TensorData::from_vec(vec![1.0f32, 2.0, 3.0, 4.0], 2, 3).is_err());
    }

    #[test]
    fn test_tensor_data_get_set() {
        let mut tensor = TensorData::from_vec(vec![1.0f32, 2.0, 3.0, 4.0], 2, 2).unwrap();

        for (row, col, expected) in [(0, 0, 1.0f32), (0, 1, 2.0), (1, 0, 3.0), (1, 1, 4.0)] {
            assert_eq!(tensor.get(row, col), Some(&expected));
        }
        assert_eq!(tensor.get(2, 0), None);

        tensor.set(0, 1, 99.0);
        assert_eq!(tensor.get(0, 1), Some(&99.0f32));
    }

    #[test]
    fn test_tensor_data_into_vec() {
        let original = vec![1.0f32, 2.0, 3.0];
        let tensor = TensorData::from_vec(original.clone(), 1, 3).unwrap();
        assert_eq!(tensor.into_vec(), original);
    }

    #[test]
    fn test_tensor_data_as_ptr() {
        let tensor = TensorData::from_vec(vec![1.0f32, 2.0, 3.0], 1, 3).unwrap();
        assert!(!tensor.as_ptr().is_null());
    }

    #[test]
    fn test_tensor_data_as_mut_slice() {
        let mut tensor = TensorData::from_vec(vec![1.0f32, 2.0, 3.0], 1, 3).unwrap();
        tensor.as_mut_slice()[0] = 10.0;
        assert_eq!(tensor.as_slice()[0], 10.0);
    }

    #[test]
    fn test_tensor_data_clone() {
        let tensor = TensorData::from_vec(vec![1.0f32, 2.0, 3.0], 1, 3).unwrap();
        let copy = tensor.clone();
        assert_eq!(copy.shape(), tensor.shape());
        assert_eq!(copy.as_slice(), tensor.as_slice());
    }

    #[test]
    fn test_tensor_data_debug() {
        let tensor = TensorData::from_vec(vec![1.0f32], 1, 1).unwrap();
        assert!(format!("{:?}", tensor).contains("TensorData"));
    }

    #[test]
    fn test_extractor_new() {
        let extractor = TensorExtractor::new(&["a", "b", "c"]);
        let names = extractor.columns();
        assert_eq!(names.len(), 3);
        assert_eq!(names[0], "a");
    }

    #[test]
    fn test_extractor_from_columns() {
        let names = vec!["x".to_string(), "y".to_string()];
        assert_eq!(TensorExtractor::from_columns(names).columns().len(), 2);
    }

    #[test]
    fn test_extractor_clone() {
        let extractor = TensorExtractor::new(&["a", "b"]);
        assert_eq!(extractor.clone().columns(), extractor.columns());
    }

    #[test]
    fn test_extractor_debug() {
        let extractor = TensorExtractor::new(&["col"]);
        assert!(format!("{:?}", extractor).contains("TensorExtractor"));
    }

    #[test]
    fn test_extract_f32() {
        let batch = create_numeric_batch();
        let tensor = TensorExtractor::new(&["f32_col", "i32_col"])
            .extract_f32(&batch)
            .unwrap();

        assert_eq!(tensor.shape(), [3, 2]);
        for (row, col, expected) in [(0, 0, 1.0f32), (0, 1, 7.0), (2, 0, 3.0), (2, 1, 9.0)] {
            assert_eq!(tensor.get(row, col), Some(&expected));
        }
    }

    #[test]
    fn test_extract_f64() {
        let batch = create_numeric_batch();
        let tensor = TensorExtractor::new(&["f64_col", "i64_col"])
            .extract_f64(&batch)
            .unwrap();

        assert_eq!(tensor.shape(), [3, 2]);
        assert_eq!(tensor.get(0, 0), Some(&4.0f64));
        assert_eq!(tensor.get(0, 1), Some(&10.0f64));
    }

    #[test]
    fn test_extract_i64() {
        let batch = create_numeric_batch();
        let tensor = TensorExtractor::new(&["i32_col", "i64_col"])
            .extract_i64(&batch)
            .unwrap();

        assert_eq!(tensor.shape(), [3, 2]);
        assert_eq!(tensor.get(0, 0), Some(&7i64));
        assert_eq!(tensor.get(0, 1), Some(&10i64));
    }

    #[test]
    fn test_extract_column_not_found() {
        let batch = create_numeric_batch();
        assert!(TensorExtractor::new(&["nonexistent"])
            .extract_f32(&batch)
            .is_err());
    }

    #[test]
    fn test_extract_column_f32_helper() {
        let batch = create_numeric_batch();
        assert_eq!(
            extract_column_f32(&batch, "f32_col").unwrap(),
            vec![1.0f32, 2.0, 3.0]
        );
    }

    #[test]
    fn test_extract_column_f64_helper() {
        let batch = create_numeric_batch();
        assert_eq!(
            extract_column_f64(&batch, "f64_col").unwrap(),
            vec![4.0f64, 5.0, 6.0]
        );
    }

    #[test]
    fn test_extract_labels_i64() {
        let batch = create_numeric_batch();
        assert_eq!(
            extract_labels_i64(&batch, "i32_col").unwrap(),
            vec![7i64, 8, 9]
        );
    }

    #[test]
    fn test_extract_labels_i64_from_float() {
        let batch = single_column_batch(
            "label",
            DataType::Float64,
            Arc::new(Float64Array::from(vec![0.0, 1.0, 2.0])),
        );
        assert_eq!(
            extract_labels_i64(&batch, "label").unwrap(),
            vec![0i64, 1, 2]
        );
    }

    #[test]
    fn test_extract_labels_column_not_found() {
        let batch = create_numeric_batch();
        assert!(extract_labels_i64(&batch, "nonexistent").is_err());
    }

    #[test]
    fn test_extract_all_int_types() {
        let schema = Schema::new(vec![
            Field::new("i8", DataType::Int8, false),
            Field::new("i16", DataType::Int16, false),
            Field::new("u8", DataType::UInt8, false),
            Field::new("u16", DataType::UInt16, false),
            Field::new("u32", DataType::UInt32, false),
            Field::new("u64", DataType::UInt64, false),
        ]);
        let columns: Vec<Arc<dyn Array>> = vec![
            Arc::new(Int8Array::from(vec![1i8])),
            Arc::new(Int16Array::from(vec![2i16])),
            Arc::new(UInt8Array::from(vec![3u8])),
            Arc::new(UInt16Array::from(vec![4u16])),
            Arc::new(UInt32Array::from(vec![5u32])),
            Arc::new(UInt64Array::from(vec![6u64])),
        ];
        let batch = RecordBatch::try_new(Arc::new(schema), columns).unwrap();
        let extractor = TensorExtractor::new(&["i8", "i16", "u8", "u16", "u32", "u64"]);

        let as_f32 = extractor.extract_f32(&batch).unwrap();
        assert_eq!(as_f32.as_slice(), &[1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0]);

        let as_f64 = extractor.extract_f64(&batch).unwrap();
        assert_eq!(as_f64.as_slice(), &[1.0f64, 2.0, 3.0, 4.0, 5.0, 6.0]);

        let as_i64 = extractor.extract_i64(&batch).unwrap();
        assert_eq!(as_i64.as_slice(), &[1i64, 2, 3, 4, 5, 6]);
    }

    #[test]
    fn test_extract_f32_from_f64() {
        let batch = single_column_batch(
            "value",
            DataType::Float64,
            Arc::new(Float64Array::from(vec![1.5f64, 2.5, 3.5])),
        );
        let tensor = TensorExtractor::new(&["value"]).extract_f32(&batch).unwrap();
        assert_eq!(tensor.as_slice(), &[1.5f32, 2.5, 3.5]);
    }

    #[test]
    fn test_extract_f64_from_f32() {
        let batch = single_column_batch(
            "value",
            DataType::Float32,
            Arc::new(Float32Array::from(vec![1.5f32, 2.5, 3.5])),
        );
        let tensor = TensorExtractor::new(&["value"]).extract_f64(&batch).unwrap();
        assert_eq!(tensor.as_slice(), &[1.5f64, 2.5, 3.5]);
    }

    #[test]
    fn test_extract_unsupported_type_f32() {
        let batch = text_batch();
        assert!(TensorExtractor::new(&["text"]).extract_f32(&batch).is_err());
    }

    #[test]
    fn test_extract_unsupported_type_f64() {
        let batch = text_batch();
        assert!(TensorExtractor::new(&["text"]).extract_f64(&batch).is_err());
    }

    #[test]
    fn test_extract_unsupported_type_i64() {
        let batch = text_batch();
        assert!(TensorExtractor::new(&["text"]).extract_i64(&batch).is_err());
    }

    #[test]
    fn test_extract_labels_unsupported_type() {
        let batch = text_batch();
        assert!(extract_labels_i64(&batch, "text").is_err());
    }

    #[test]
    fn test_extract_labels_all_uint_types() {
        let schema = Schema::new(vec![
            Field::new("u8", DataType::UInt8, false),
            Field::new("u16", DataType::UInt16, false),
            Field::new("u32", DataType::UInt32, false),
            Field::new("u64", DataType::UInt64, false),
        ]);
        let columns: Vec<Arc<dyn Array>> = vec![
            Arc::new(UInt8Array::from(vec![1u8])),
            Arc::new(UInt16Array::from(vec![2u16])),
            Arc::new(UInt32Array::from(vec![3u32])),
            Arc::new(UInt64Array::from(vec![4u64])),
        ];
        let batch = RecordBatch::try_new(Arc::new(schema), columns).unwrap();

        for (name, expected) in [("u8", 1i64), ("u16", 2), ("u32", 3), ("u64", 4)] {
            assert_eq!(extract_labels_i64(&batch, name).unwrap(), vec![expected]);
        }
    }

    #[test]
    fn test_extract_labels_all_int_types() {
        let schema = Schema::new(vec![
            Field::new("i8", DataType::Int8, false),
            Field::new("i16", DataType::Int16, false),
            Field::new("f32", DataType::Float32, false),
        ]);
        let columns: Vec<Arc<dyn Array>> = vec![
            Arc::new(Int8Array::from(vec![1i8])),
            Arc::new(Int16Array::from(vec![2i16])),
            Arc::new(Float32Array::from(vec![3.0f32])),
        ];
        let batch = RecordBatch::try_new(Arc::new(schema), columns).unwrap();

        for (name, expected) in [("i8", 1i64), ("i16", 2), ("f32", 3)] {
            assert_eq!(extract_labels_i64(&batch, name).unwrap(), vec![expected]);
        }
    }
}