use std::sync::Arc;
use arrow_array::types::ArrowPrimitiveType;
use arrow_array::{Array, FixedSizeListArray, ListArray, PrimitiveArray, StructArray};
use arrow_buffer::ScalarBuffer;
use arrow_schema::Field;
use nabled_core::scalar::NabledReal;
use ndarray::{Array1, Array2, ArrayD, ArrayView1, ArrayView2, ArrayViewD};
use ndarrow::{AsNdarray, NdarrowElement};
use num_complex::Complex64;
use thiserror::Error;
pub mod batched;
pub mod cholesky;
pub mod eigen;
pub mod iterative;
pub mod jacobian;
pub mod lu;
pub mod matrix;
pub mod matrix_functions;
pub mod optimization;
pub mod orthogonalization;
pub mod pca;
pub mod polar;
pub mod qr;
pub mod regression;
pub mod schur;
pub mod sparse;
pub mod stats;
pub mod svd;
pub mod tensor;
pub mod triangular;
pub mod vector;
/// Error type returned by the Arrow interop helpers in this module.
///
/// Most variants wrap an error from another part of the crate (or from `ndarrow`) using
/// `#[error(transparent)]`. Their `Display` and `source()` therefore forward to the wrapped
/// error. `#[from]` generates the `From` impls that let `?` convert those errors into this type.
#[derive(Debug, Error)]
pub enum ArrowInteropError {
/// Error from the `ndarrow` Arrow <-> ndarray bridge: view construction, buffer conversion,
/// and the null/shape checks raised by the helpers in this module.
#[error(transparent)]
Bridge(#[from] ndarrow::NdarrowError),
/// Error forwarded from `crate::linalg::vector`.
#[error(transparent)]
Vector(#[from] crate::linalg::vector::VectorError),
/// Error forwarded from `crate::linalg::matrix`.
#[error(transparent)]
Matrix(#[from] crate::linalg::matrix::MatrixError),
/// Error forwarded from `crate::linalg::lu`.
#[error(transparent)]
LU(#[from] crate::linalg::lu::LUError),
/// Error forwarded from `crate::linalg::cholesky`.
#[error(transparent)]
Cholesky(#[from] crate::linalg::cholesky::CholeskyError),
/// Error forwarded from `crate::linalg::qr`.
#[error(transparent)]
QR(#[from] crate::linalg::qr::QRError),
/// Error forwarded from `crate::linalg::svd`.
#[error(transparent)]
SVD(#[from] crate::linalg::svd::SVDError),
/// Error forwarded from `crate::linalg::eigen`.
#[error(transparent)]
Eigen(#[from] crate::linalg::eigen::EigenError),
/// Error forwarded from `crate::linalg::schur`.
#[error(transparent)]
Schur(#[from] crate::linalg::schur::SchurError),
/// Error forwarded from `crate::linalg::polar`.
#[error(transparent)]
Polar(#[from] crate::linalg::polar::PolarError),
/// Error forwarded from `crate::linalg::matrix_functions`.
#[error(transparent)]
MatrixFunctions(#[from] crate::linalg::matrix_functions::MatrixFunctionError),
/// Error forwarded from `crate::linalg::orthogonalization`.
#[error(transparent)]
Orthogonalization(#[from] crate::linalg::orthogonalization::OrthogonalizationError),
/// Error forwarded from `crate::linalg::sparse` (e.g. rejected CSR components when building
/// a `CsrMatrixView`).
#[error(transparent)]
Sparse(#[from] crate::linalg::sparse::SparseError),
/// Error forwarded from `crate::linalg::tensor`.
#[error(transparent)]
Tensor(#[from] crate::linalg::tensor::TensorError),
/// Error forwarded from `crate::linalg::triangular`.
#[error(transparent)]
Triangular(#[from] crate::linalg::triangular::TriangularError),
/// Error forwarded from `crate::ml::iterative`.
#[error(transparent)]
Iterative(#[from] crate::ml::iterative::IterativeError),
/// Error forwarded from `crate::ml::jacobian`.
#[error(transparent)]
Jacobian(#[from] crate::ml::jacobian::JacobianError),
/// Error forwarded from `crate::ml::optimization`.
#[error(transparent)]
Optimization(#[from] crate::ml::optimization::OptimizationError),
/// Error forwarded from `crate::ml::pca`.
#[error(transparent)]
PCA(#[from] crate::ml::pca::PCAError),
/// Error forwarded from `crate::ml::regression`.
#[error(transparent)]
Regression(#[from] crate::ml::regression::RegressionError),
/// Error forwarded from `crate::ml::stats`.
#[error(transparent)]
Stats(#[from] crate::ml::stats::StatsError),
/// An Arrow tensor whose shape is not acceptable; the `String` describes the problem.
/// NOTE(review): not constructed in this file — presumably raised by the submodules; verify.
#[error("invalid Arrow tensor shape: {0}")]
InvalidShape(String),
}
pub(crate) fn primitive_array_view<T>(
array: &PrimitiveArray<T>,
) -> Result<ArrayView1<'_, T::Native>, ArrowInteropError>
where
T: ArrowPrimitiveType,
T::Native: NdarrowElement,
{
Ok(array.as_ndarray()?)
}
/// Borrows a `FixedSizeListArray` of primitive values as a 2-D ndarray view.
///
/// This is a thin wrapper over `ndarrow::fixed_size_list_as_array2`; the view borrows
/// from `array`.
///
/// # Errors
///
/// Errors from `ndarrow` are converted into [`ArrowInteropError`].
pub(crate) fn fixed_size_list_view<T>(
    array: &FixedSizeListArray,
) -> Result<ArrayView2<'_, T::Native>, ArrowInteropError>
where
    T: ArrowPrimitiveType,
    T::Native: NdarrowElement,
{
    let matrix = ndarrow::fixed_size_list_as_array2::<T>(array)?;
    Ok(matrix)
}
/// Borrows a fixed-shape tensor column as an n-dimensional ndarray view.
///
/// `field` and the storage `array` are passed straight to
/// `ndarrow::fixed_shape_tensor_as_array_viewd`. That function presumably reads the tensor
/// shape from the field's extension metadata (verify against ndarrow). The view borrows
/// from both inputs for `'a`.
///
/// # Errors
///
/// Errors from `ndarrow` are converted into [`ArrowInteropError`].
pub(crate) fn fixed_shape_tensor_viewd<'a, T>(
    field: &'a Field,
    array: &'a FixedSizeListArray,
) -> Result<ArrayViewD<'a, T::Native>, ArrowInteropError>
where
    T: ArrowPrimitiveType,
    T::Native: NdarrowElement,
{
    ndarrow::fixed_shape_tensor_as_array_viewd::<T>(field, array).map_err(Into::into)
}
pub(crate) fn variable_shape_tensor_batch_view<'a, T>(
field: &Field,
array: &'a StructArray,
) -> Result<ndarrow::VariableShapeTensorBatchView<'a, T>, ArrowInteropError>
where
T: ArrowPrimitiveType,
T::Native: NdarrowElement,
{
if array.null_count() > 0 {
return Err(ndarrow::NdarrowError::NullsPresent { null_count: array.null_count() }.into());
}
Ok(ndarrow::variable_shape_tensor_batch_view::<T>(field, array)?)
}
/// Flattens an owned 1-D array into a `Vec` holding exactly its elements in logical order.
///
/// A non-contiguous array is first copied into standard layout. The backing buffer is then
/// taken out:
/// - If the first element sits at index 0 (or the array is empty), the buffer is reused and
///   truncated to the element count. Its capacity may stay larger than its length.
/// - Otherwise the `len` elements starting at that index are copied into a fresh, tight `Vec`.
fn owned_array1_into_vec<T: Copy>(array: Array1<T>) -> Vec<T> {
    let n = array.len();
    let contiguous = if array.is_standard_layout() {
        array
    } else {
        array.as_standard_layout().into_owned()
    };
    match contiguous.into_raw_vec_and_offset() {
        (mut buf, None | Some(0)) => {
            buf.truncate(n);
            buf
        }
        (buf, Some(start)) => buf[start..start + n].to_vec(),
    }
}
/// Flattens an owned 2-D array into `(ncols, values)`, with `values` in row-major order.
///
/// The row count is not returned; callers recover it as `values.len() / ncols` when
/// `ncols > 0`. Layout handling:
/// - A non-standard (non-C-contiguous) array is copied into standard layout first.
/// - If the first logical element is at index 0 of the raw buffer (or the array is empty),
///   that buffer is reused and truncated.
/// - Otherwise the contiguous window is copied into a fresh `Vec`.
fn owned_array2_into_vec<T: Copy>(array: Array2<T>) -> (usize, Vec<T>) {
    let (width, total) = (array.ncols(), array.len());
    let contiguous = if array.is_standard_layout() {
        array
    } else {
        array.as_standard_layout().into_owned()
    };
    let (mut buf, first) = contiguous.into_raw_vec_and_offset();
    let first = first.unwrap_or_default();
    let flat = if first > 0 {
        buf[first..first + total].to_vec()
    } else {
        buf.truncate(total);
        buf
    };
    (width, flat)
}
/// Converts an owned 1-D ndarray into a non-null Arrow `PrimitiveArray`.
///
/// The elements are flattened in logical order by `owned_array1_into_vec`. The result has
/// no validity buffer, so every slot is valid.
pub(crate) fn primitive_array_from_owned<T>(array: Array1<T::Native>) -> PrimitiveArray<T>
where
    T: ArrowPrimitiveType,
    T::Native: NdarrowElement,
{
    let flat: Vec<T::Native> = owned_array1_into_vec(array);
    PrimitiveArray::<T>::new(flat.into(), None)
}
pub(crate) fn fixed_size_list_from_owned<T>(
array: Array2<T::Native>,
) -> Result<FixedSizeListArray, ArrowInteropError>
where
T: ArrowPrimitiveType,
T::Native: NdarrowElement,
{
let (cols, values) = owned_array2_into_vec(array);
let value_length = i32::try_from(cols).map_err(|_| ndarrow::NdarrowError::ShapeMismatch {
message: format!("matrix column count {cols} exceeds Arrow i32 value_length limits"),
})?;
let values = PrimitiveArray::<T>::new(ScalarBuffer::from(values), None);
Ok(FixedSizeListArray::new(
Arc::new(Field::new("item", T::DATA_TYPE, false)),
value_length,
Arc::new(values),
None,
))
}
pub(crate) fn complex64_vector_view<'a>(
field: &Field,
array: &'a FixedSizeListArray,
) -> Result<ArrayView1<'a, Complex64>, ArrowInteropError> {
Ok(ndarrow::complex64_as_array_view1(field, array)?)
}
/// Converts an owned `Complex64` vector into an Arrow column named `field_name`.
///
/// Returns the column's `Field` together with its `FixedSizeListArray` storage, as produced
/// by `ndarrow::array1_complex64_to_extension`.
///
/// # Errors
///
/// Errors from `ndarrow` are converted into [`ArrowInteropError`].
pub(crate) fn complex64_vector_from_owned(
    field_name: &str,
    array: Array1<Complex64>,
) -> Result<(Field, FixedSizeListArray), ArrowInteropError> {
    let (column_field, storage) = ndarrow::array1_complex64_to_extension(field_name, array)?;
    Ok((column_field, storage))
}
/// Borrows a `FixedSizeListArray` of `Complex64` rows as a 2-D ndarray view.
///
/// This wraps `ndarrow::complex64_as_array_view2`; the view borrows from `array`.
///
/// # Errors
///
/// Errors from `ndarrow` are converted into [`ArrowInteropError`].
pub(crate) fn complex64_matrix_view(
    array: &FixedSizeListArray,
) -> Result<ArrayView2<'_, Complex64>, ArrowInteropError> {
    ndarrow::complex64_as_array_view2(array).map_err(Into::into)
}
/// Converts an owned `Complex64` matrix into an Arrow `FixedSizeListArray`.
///
/// This wraps `ndarrow::array2_complex64_to_fixed_size_list`.
///
/// # Errors
///
/// Errors from `ndarrow` are converted into [`ArrowInteropError`].
pub(crate) fn complex64_matrix_from_owned(
    array: Array2<Complex64>,
) -> Result<FixedSizeListArray, ArrowInteropError> {
    let rows = ndarrow::array2_complex64_to_fixed_size_list(array)?;
    Ok(rows)
}
pub(crate) fn complex64_fixed_shape_tensor_viewd<'a>(
field: &Field,
array: &'a FixedSizeListArray,
) -> Result<ArrayViewD<'a, Complex64>, ArrowInteropError> {
Ok(ndarrow::complex64_fixed_shape_tensor_as_array_viewd(field, array)?)
}
/// Converts an owned n-dimensional `Complex64` array into a fixed-shape tensor column.
///
/// Returns the column `Field` (named `field_name`) and its `FixedSizeListArray` storage, as
/// produced by `ndarrow::arrayd_complex64_to_fixed_shape_tensor`.
///
/// # Errors
///
/// Errors from `ndarrow` are converted into [`ArrowInteropError`].
pub(crate) fn complex64_fixed_shape_tensor_from_owned(
    field_name: &str,
    array: ArrayD<Complex64>,
) -> Result<(Field, FixedSizeListArray), ArrowInteropError> {
    let (tensor_field, storage) =
        ndarrow::arrayd_complex64_to_fixed_shape_tensor(field_name, array)?;
    Ok((tensor_field, storage))
}
pub(crate) fn fixed_shape_tensor_from_owned<T>(
field_name: &str,
array: ArrayD<T::Native>,
) -> Result<(Field, FixedSizeListArray), ArrowInteropError>
where
T: ArrowPrimitiveType,
T::Native: NdarrowElement,
{
Ok(ndarrow::arrayd_to_fixed_shape_tensor(field_name, array)?)
}
pub(crate) fn csr_matrix_view_from_columns<'a, T>(
indices: &'a ListArray,
values: &'a ListArray,
ncols: usize,
) -> Result<crate::linalg::sparse::CsrMatrixView<'a, i32, T::Native, u32>, ArrowInteropError>
where
T: ArrowPrimitiveType,
T::Native: NabledReal + NdarrowElement,
{
let view = ndarrow::csr_view_from_columns::<T>(indices, values, ncols)?;
Ok(crate::linalg::sparse::CsrMatrixView::new(
view.nrows,
view.ncols,
view.row_ptrs,
view.col_indices,
view.values,
)?)
}
/// Builds a borrowed CSR matrix view from a single CSR extension column (`StructArray`).
///
/// `ndarrow::csr_view_from_extension` interprets `array` using `field`. Its components are
/// then passed to `CsrMatrixView::new`, which can still reject them. The view borrows from
/// `array` for `'a`.
///
/// # Errors
///
/// Errors from `ndarrow` or from `CsrMatrixView::new` are converted into
/// [`ArrowInteropError`].
pub(crate) fn csr_matrix_view_from_extension<'a, T>(
    field: &Field,
    array: &'a StructArray,
) -> Result<crate::linalg::sparse::CsrMatrixView<'a, i32, T::Native, u32>, ArrowInteropError>
where
    T: ArrowPrimitiveType,
    T::Native: NabledReal + NdarrowElement,
{
    let parts = ndarrow::csr_view_from_extension::<T>(field, array)?;
    let matrix = crate::linalg::sparse::CsrMatrixView::new(
        parts.nrows,
        parts.ncols,
        parts.row_ptrs,
        parts.col_indices,
        parts.values,
    )?;
    Ok(matrix)
}
/// Borrows a column of CSR matrices (a `StructArray`) as an `ndarrow` batch view.
///
/// # Errors
///
/// - [`ArrowInteropError::Bridge`] wrapping `NdarrowError::NullsPresent` if the top-level
///   struct array contains any nulls. This is checked before `ndarrow` is called.
/// - Any error from `ndarrow::csr_matrix_batch_view`, converted into [`ArrowInteropError`].
pub(crate) fn csr_matrix_batch_view<'a, T>(
    field: &Field,
    array: &'a StructArray,
) -> Result<ndarrow::CsrMatrixBatchView<'a, T>, ArrowInteropError>
where
    T: ArrowPrimitiveType,
    T::Native: NabledReal + NdarrowElement,
{
    match array.null_count() {
        0 => Ok(ndarrow::csr_matrix_batch_view::<T>(field, array)?),
        null_count => Err(ndarrow::NdarrowError::NullsPresent { null_count }.into()),
    }
}
/// Converts one `ndarrow::CsrView` (e.g. a row of a CSR batch) into a `CsrMatrixView`.
///
/// The view's dimensions, row pointers, column indices and values are passed to
/// `CsrMatrixView::new` unchanged. The result keeps the input view's borrow lifetime.
///
/// # Errors
///
/// Errors from `CsrMatrixView::new` are converted into [`ArrowInteropError`].
pub(crate) fn csr_matrix_view_from_batch_row<T>(
    view: ndarrow::CsrView<'_, T::Native>,
) -> Result<crate::linalg::sparse::CsrMatrixView<'_, i32, T::Native, u32>, ArrowInteropError>
where
    T: ArrowPrimitiveType,
    T::Native: NabledReal + NdarrowElement,
{
    let ndarrow::CsrView { nrows, ncols, row_ptrs, col_indices, values, .. } = view;
    let matrix =
        crate::linalg::sparse::CsrMatrixView::new(nrows, ncols, row_ptrs, col_indices, values)?;
    Ok(matrix)
}