use arrow_array::FixedSizeListArray;
use arrow_array::types::{Float32Type, Float64Type};
use arrow_schema::Field;
use ndarray::Ix3;
use super::{ArrowInteropError, fixed_shape_tensor_viewd};
fn fixed_shape_tensor_view3_f32<'a>(
field: &'a Field,
array: &'a FixedSizeListArray,
) -> Result<ndarray::ArrayView3<'a, f32>, ArrowInteropError> {
let view = fixed_shape_tensor_viewd::<Float32Type>(field, array)?;
view.into_dimensionality::<Ix3>()
.map_err(|error: ndarray::ShapeError| ArrowInteropError::InvalidShape(error.to_string()))
}
fn fixed_shape_tensor_view3_f64<'a>(
field: &'a Field,
array: &'a FixedSizeListArray,
) -> Result<ndarray::ArrayView3<'a, f64>, ArrowInteropError> {
let view = fixed_shape_tensor_viewd::<Float64Type>(field, array)?;
view.into_dimensionality::<Ix3>()
.map_err(|error: ndarray::ShapeError| ArrowInteropError::InvalidShape(error.to_string()))
}
/// Runs a batched QR decomposition over an Arrow fixed-shape-tensor column of `f32` matrices.
///
/// The column is viewed as a 3-D array with `fixed_shape_tensor_view3_f32`. The whole view is
/// then passed, together with `config`, to `crate::linalg::batched::qr_view`.
/// NOTE(review): presumably one result per list element, in column order — confirm in
/// `qr_view`.
///
/// # Errors
///
/// Returns an [`ArrowInteropError`] if the column cannot be viewed as a 3-D `f32` tensor.
/// Also returns one if `qr_view` fails; its error is converted through `From`.
pub fn qr_f32(
    field: &Field,
    array: &FixedSizeListArray,
    config: &crate::linalg::qr::QRConfig<f32>,
) -> Result<Vec<crate::linalg::qr::QRResult<f32>>, ArrowInteropError> {
    let batch = fixed_shape_tensor_view3_f32(field, array)?;
    let factorizations = crate::linalg::batched::qr_view(&batch, config)?;
    Ok(factorizations)
}
/// Runs a batched QR decomposition over an Arrow fixed-shape-tensor column of `f64` matrices.
///
/// The column is viewed as a 3-D array with `fixed_shape_tensor_view3_f64`. The whole view is
/// then passed, together with `config`, to `crate::linalg::batched::qr_view`.
/// NOTE(review): presumably one result per list element, in column order — confirm in
/// `qr_view`.
///
/// # Errors
///
/// Returns an [`ArrowInteropError`] if the column cannot be viewed as a 3-D `f64` tensor.
/// Also returns one if `qr_view` fails; its error is converted through `From`.
pub fn qr_f64(
    field: &Field,
    array: &FixedSizeListArray,
    config: &crate::linalg::qr::QRConfig<f64>,
) -> Result<Vec<crate::linalg::qr::QRResult<f64>>, ArrowInteropError> {
    let batch = fixed_shape_tensor_view3_f64(field, array)?;
    let factorizations = crate::linalg::batched::qr_view(&batch, config)?;
    Ok(factorizations)
}
/// Runs a batched singular value decomposition over a fixed-shape-tensor column of `f32`
/// matrices.
///
/// The column is viewed as a 3-D array with `fixed_shape_tensor_view3_f32` and handed to
/// `crate::linalg::batched::svd_view`.
///
/// # Errors
///
/// Returns an [`ArrowInteropError`] if the column cannot be viewed as a 3-D `f32` tensor.
/// Also returns one if `svd_view` fails; its error is converted through `From`.
pub fn svd_f32(
    field: &Field,
    array: &FixedSizeListArray,
) -> Result<Vec<crate::linalg::svd::NdarraySVD<f32>>, ArrowInteropError> {
    let batch = fixed_shape_tensor_view3_f32(field, array)?;
    let decompositions = crate::linalg::batched::svd_view(&batch)?;
    Ok(decompositions)
}
/// Runs a batched singular value decomposition over a fixed-shape-tensor column of `f64`
/// matrices.
///
/// The column is viewed as a 3-D array with `fixed_shape_tensor_view3_f64` and handed to
/// `crate::linalg::batched::svd_view`.
///
/// # Errors
///
/// Returns an [`ArrowInteropError`] if the column cannot be viewed as a 3-D `f64` tensor.
/// Also returns one if `svd_view` fails; its error is converted through `From`.
pub fn svd_f64(
    field: &Field,
    array: &FixedSizeListArray,
) -> Result<Vec<crate::linalg::svd::NdarraySVD<f64>>, ArrowInteropError> {
    let batch = fixed_shape_tensor_view3_f64(field, array)?;
    let decompositions = crate::linalg::batched::svd_view(&batch)?;
    Ok(decompositions)
}
/// Runs a batched LU factorization over a fixed-shape-tensor column of `f32` matrices.
///
/// The column is viewed as a 3-D array with `fixed_shape_tensor_view3_f32` and handed to
/// `crate::linalg::batched::lu_view`.
///
/// # Errors
///
/// Returns an [`ArrowInteropError`] if the column cannot be viewed as a 3-D `f32` tensor.
/// Also returns one if `lu_view` fails; its error is converted through `From`.
pub fn lu_f32(
    field: &Field,
    array: &FixedSizeListArray,
) -> Result<Vec<crate::linalg::lu::NdarrayLUResult<f32>>, ArrowInteropError> {
    let batch = fixed_shape_tensor_view3_f32(field, array)?;
    let factorizations = crate::linalg::batched::lu_view(&batch)?;
    Ok(factorizations)
}
/// Runs a batched LU factorization over a fixed-shape-tensor column of `f64` matrices.
///
/// The column is viewed as a 3-D array with `fixed_shape_tensor_view3_f64` and handed to
/// `crate::linalg::batched::lu_view`.
///
/// # Errors
///
/// Returns an [`ArrowInteropError`] if the column cannot be viewed as a 3-D `f64` tensor.
/// Also returns one if `lu_view` fails; its error is converted through `From`.
pub fn lu_f64(
    field: &Field,
    array: &FixedSizeListArray,
) -> Result<Vec<crate::linalg::lu::NdarrayLUResult<f64>>, ArrowInteropError> {
    let batch = fixed_shape_tensor_view3_f64(field, array)?;
    let factorizations = crate::linalg::batched::lu_view(&batch)?;
    Ok(factorizations)
}
/// Runs a batched Cholesky factorization over a fixed-shape-tensor column of `f32` matrices.
///
/// The column is viewed as a 3-D array with `fixed_shape_tensor_view3_f32` and handed to
/// `crate::linalg::batched::cholesky_view`.
/// NOTE(review): the input matrices are presumably expected to be symmetric positive-definite —
/// confirm how `cholesky_view` reports violations.
///
/// # Errors
///
/// Returns an [`ArrowInteropError`] if the column cannot be viewed as a 3-D `f32` tensor.
/// Also returns one if `cholesky_view` fails; its error is converted through `From`.
pub fn cholesky_f32(
    field: &Field,
    array: &FixedSizeListArray,
) -> Result<Vec<crate::linalg::cholesky::NdarrayCholeskyResult<f32>>, ArrowInteropError> {
    let batch = fixed_shape_tensor_view3_f32(field, array)?;
    let factors = crate::linalg::batched::cholesky_view(&batch)?;
    Ok(factors)
}
/// Runs a batched Cholesky factorization over a fixed-shape-tensor column of `f64` matrices.
///
/// The column is viewed as a 3-D array with `fixed_shape_tensor_view3_f64` and handed to
/// `crate::linalg::batched::cholesky_view`.
/// NOTE(review): the input matrices are presumably expected to be symmetric positive-definite —
/// confirm how `cholesky_view` reports violations.
///
/// # Errors
///
/// Returns an [`ArrowInteropError`] if the column cannot be viewed as a 3-D `f64` tensor.
/// Also returns one if `cholesky_view` fails; its error is converted through `From`.
pub fn cholesky_f64(
    field: &Field,
    array: &FixedSizeListArray,
) -> Result<Vec<crate::linalg::cholesky::NdarrayCholeskyResult<f64>>, ArrowInteropError> {
    let batch = fixed_shape_tensor_view3_f64(field, array)?;
    let factors = crate::linalg::batched::cholesky_view(&batch)?;
    Ok(factors)
}
/// Runs a batched symmetric eigendecomposition over a fixed-shape-tensor column of `f32`
/// matrices.
///
/// The column is viewed as a 3-D array with `fixed_shape_tensor_view3_f32` and handed to
/// `crate::linalg::batched::symmetric_eigen_view`.
/// NOTE(review): whether non-symmetric input is rejected or silently symmetrized is decided by
/// `symmetric_eigen_view` — confirm there.
///
/// # Errors
///
/// Returns an [`ArrowInteropError`] if the column cannot be viewed as a 3-D `f32` tensor.
/// Also returns one if `symmetric_eigen_view` fails; its error is converted through `From`.
pub fn symmetric_eigen_f32(
    field: &Field,
    array: &FixedSizeListArray,
) -> Result<Vec<crate::linalg::eigen::NdarrayEigenResult<f32>>, ArrowInteropError> {
    let batch = fixed_shape_tensor_view3_f32(field, array)?;
    let eigensystems = crate::linalg::batched::symmetric_eigen_view(&batch)?;
    Ok(eigensystems)
}
/// Runs a batched symmetric eigendecomposition over a fixed-shape-tensor column of `f64`
/// matrices.
///
/// The column is viewed as a 3-D array with `fixed_shape_tensor_view3_f64` and handed to
/// `crate::linalg::batched::symmetric_eigen_view`.
/// NOTE(review): whether non-symmetric input is rejected or silently symmetrized is decided by
/// `symmetric_eigen_view` — confirm there.
///
/// # Errors
///
/// Returns an [`ArrowInteropError`] if the column cannot be viewed as a 3-D `f64` tensor.
/// Also returns one if `symmetric_eigen_view` fails; its error is converted through `From`.
pub fn symmetric_eigen_f64(
    field: &Field,
    array: &FixedSizeListArray,
) -> Result<Vec<crate::linalg::eigen::NdarrayEigenResult<f64>>, ArrowInteropError> {
    let batch = fixed_shape_tensor_view3_f64(field, array)?;
    let eigensystems = crate::linalg::batched::symmetric_eigen_view(&batch)?;
    Ok(eigensystems)
}