use crate::enums::error::MinarrowError;
use crate::enums::operators::ArithmeticOperator;
use crate::kernels::broadcast::array_view::broadcast_arrayview_to_tableview;
use crate::{SuperArrayV, SuperTableV, TableV};
#[cfg(all(feature = "chunked", feature = "views"))]
pub fn broadcast_superarrayview_to_tableview(
op: ArithmeticOperator,
super_array_view: &SuperArrayV,
table_view: &TableV,
) -> Result<SuperTableV, MinarrowError> {
if super_array_view.len != table_view.len {
return Err(MinarrowError::ShapeError {
message: format!(
"SuperArrayView length ({}) does not match TableView length ({})",
super_array_view.len, table_view.len
),
});
}
let mut current_offset = 0;
let mut table_slices = Vec::new();
for array_slice in super_array_view.slices.iter() {
let chunk_len = array_slice.len();
let table_slice = table_view.from_self(current_offset, chunk_len);
table_slices.push(table_slice);
current_offset += chunk_len;
}
let aligned_super_table = SuperTableV {
slices: table_slices,
len: table_view.len,
};
let mut result_slices = Vec::new();
for i in 0..super_array_view.slices.len() {
let array_slice = &super_array_view.slices[i];
let table_slice = &aligned_super_table.slices[i];
let slice_result_table = broadcast_arrayview_to_tableview(op, array_slice, table_slice)?;
let n_rows = slice_result_table.n_rows;
result_slices.push(TableV::from_table(slice_result_table, 0, n_rows));
}
Ok(SuperTableV {
slices: result_slices,
len: super_array_view.len,
})
}
#[cfg(all(test, feature = "chunked", feature = "views"))]
mod tests {
use super::*;
use crate::ffi::arrow_dtype::ArrowType;
use crate::{Array, Field, FieldArray, IntegerArray, NumericArray, SuperArray, Table, vec64};
fn create_super_array_view(chunks: Vec<&[i32]>) -> SuperArrayV {
let field_arrays: Vec<FieldArray> = chunks
.iter()
.map(|chunk| {
let arr = Array::from_int32(IntegerArray::from_slice(chunk));
let field = Field::new("test_col".to_string(), ArrowType::Int32, false, None);
FieldArray::new(field, arr)
})
.collect();
let super_array = SuperArray::from_chunks(field_arrays);
SuperArrayV::from(super_array)
}
#[test]
fn test_superarrayview_to_tableview_single_chunk() {
let super_array_view = create_super_array_view(vec![&[1, 2, 3]]);
let arr1 = Array::from_int32(IntegerArray::from_slice(&vec64![10, 20, 30]));
let table = Table::build(
vec![FieldArray::new(
Field::new("col1".to_string(), ArrowType::Int32, false, None),
arr1,
)],
3,
"test".to_string(),
);
let table_view = TableV::from_table(table, 0, 3);
let result = broadcast_superarrayview_to_tableview(
ArithmeticOperator::Add,
&super_array_view,
&table_view,
)
.unwrap();
assert_eq!(result.len, 3);
assert_eq!(result.slices.len(), 1);
let result_table = result.slices[0].to_table();
if let Array::NumericArray(NumericArray::Int32(arr)) = &result_table.cols[0].array {
assert_eq!(arr.data.as_slice(), &[11, 22, 33]);
} else {
panic!("Expected Int32 array");
}
}
#[test]
fn test_superarrayview_to_tableview_length_mismatch() {
let super_array_view = create_super_array_view(vec![&[1, 2, 3]]);
let arr1 = Array::from_int32(IntegerArray::from_slice(&vec64![10, 20, 30, 40, 50]));
let table = Table::build(
vec![FieldArray::new(
Field::new("col1".to_string(), ArrowType::Int32, false, None),
arr1,
)],
5,
"test".to_string(),
);
let table_view = TableV::from_table(table, 0, 5);
let result = broadcast_superarrayview_to_tableview(
ArithmeticOperator::Add,
&super_array_view,
&table_view,
);
assert!(result.is_err());
if let Err(MinarrowError::ShapeError { message }) = result {
assert!(message.contains("does not match"));
} else {
panic!("Expected ShapeError with length mismatch message");
}
}
}