use arrow_array::types::{ArrowPrimitiveType, Float64Type};
use arrow_array::{Array, FixedSizeListArray, StructArray};
use arrow_schema::Field;
use nabled_core::scalar::NabledReal;
use ndarray::{ArrayD, Ix3};
use ndarrow::NdarrowElement;
use num_complex::Complex64;
use serde::Deserialize;
use super::{
ArrowInteropError, complex64_fixed_shape_tensor_from_owned, complex64_fixed_shape_tensor_viewd,
complex64_matrix_from_owned, complex64_matrix_view, fixed_shape_tensor_from_owned,
fixed_shape_tensor_viewd, fixed_size_list_from_owned, fixed_size_list_view,
variable_shape_tensor_batch_view,
};
/// Subset of the `arrow.variable_shape_tensor` extension metadata (a JSON object)
/// that this module reads.
///
/// Only `uniform_shape` is deserialized; other keys in the object are ignored
/// because the struct does not opt into `deny_unknown_fields`.
#[derive(Debug, Deserialize)]
struct VariableShapeTensorWireMetadata {
    /// Per-dimension sizes; JSON `null` entries become `None`. A missing key yields
    /// `None` for the whole field thanks to `#[serde(default)]`.
#[serde(default)]
uniform_shape: Option<Vec<Option<i32>>>,
}
/// Views a fixed-shape tensor column as a rank-3 `ndarray` view.
///
/// The column is first viewed at dynamic rank via `fixed_shape_tensor_viewd`; when
/// that view is not three-dimensional, the resulting `ndarray::ShapeError` is
/// reported as `ArrowInteropError::InvalidShape` carrying the error's text.
fn fixed_shape_tensor_view3<'a, T>(
    field: &'a Field,
    array: &'a FixedSizeListArray,
) -> Result<ndarray::ArrayView3<'a, T::Native>, ArrowInteropError>
where
    T: ArrowPrimitiveType,
    T::Native: NdarrowElement,
{
    fixed_shape_tensor_viewd::<T>(field, array)?
        .into_dimensionality::<Ix3>()
        .map_err(|shape_error: ndarray::ShapeError| {
            ArrowInteropError::InvalidShape(shape_error.to_string())
        })
}
/// Views a `Complex64` fixed-shape tensor column as a rank-3 `ndarray` view.
///
/// Any rank other than three surfaces as `ArrowInteropError::InvalidShape` with the
/// text of the underlying `ndarray::ShapeError`.
fn complex64_fixed_shape_tensor_view3<'a>(
    field: &'a Field,
    array: &'a FixedSizeListArray,
) -> Result<ndarray::ArrayView3<'a, Complex64>, ArrowInteropError> {
    complex64_fixed_shape_tensor_viewd(field, array)?
        .into_dimensionality::<Ix3>()
        .map_err(|shape_error: ndarray::ShapeError| {
            ArrowInteropError::InvalidShape(shape_error.to_string())
        })
}
fn variable_shape_uniform_shape(
field: &Field,
) -> Result<Option<Vec<Option<i32>>>, ArrowInteropError> {
let raw_metadata =
field.extension_type_metadata().ok_or_else(|| ndarrow::NdarrowError::InvalidMetadata {
message: "arrow.variable_shape_tensor metadata missing".to_owned(),
})?;
let metadata: VariableShapeTensorWireMetadata =
serde_json::from_str(raw_metadata).map_err(|error| {
ndarrow::NdarrowError::InvalidMetadata {
message: format!("arrow.variable_shape_tensor metadata parse failed: {error}"),
}
})?;
Ok(metadata.uniform_shape)
}
/// Drops the final dimension from a `uniform_shape`, leaving `None` untouched.
///
/// Used when a kernel's output lacks the input's last axis. An empty shape stays
/// empty.
fn reduced_uniform_shape(uniform_shape: Option<Vec<Option<i32>>>) -> Option<Vec<Option<i32>>> {
    uniform_shape.map(|mut dims| {
        dims.pop();
        dims
    })
}
/// Applies `op` to every row of a real-valued variable-shape tensor column and
/// packs the per-row outputs into a new variable-shape tensor.
///
/// Rows are processed in order; the first error from viewing a row or from `op`
/// aborts the whole operation. The output is built with `field.name()` and the
/// supplied `uniform_shape`.
fn collect_variable_shape_real_rows<T, F>(
    field: &Field,
    array: &StructArray,
    uniform_shape: Option<Vec<Option<i32>>>,
    mut op: F,
) -> Result<(Field, StructArray), ArrowInteropError>
where
    T: ArrowPrimitiveType,
    T::Native: NabledReal + NdarrowElement,
    F: FnMut(&ndarray::ArrayViewD<'_, T::Native>) -> Result<ArrayD<T::Native>, ArrowInteropError>,
{
    let batch = variable_shape_tensor_batch_view::<T>(field, array)?;
    let row_count = batch.len();
    let mut outputs = Vec::with_capacity(row_count);
    (0..row_count).try_for_each(|index| {
        let row_view = batch.row(index)?.as_array_viewd()?;
        outputs.push(op(&row_view)?);
        Ok::<(), ArrowInteropError>(())
    })?;
    Ok(ndarrow::arrays_to_variable_shape_tensor(field.name(), outputs, uniform_shape)?)
}
/// Applies `op` to every row of a `Complex64` variable-shape tensor column and
/// packs the complex outputs into a new variable-shape tensor.
///
/// The first row-decoding or `op` error short-circuits. The output is named after
/// `field.name()` and carries the supplied `uniform_shape`.
fn collect_variable_shape_complex_rows<F>(
    field: &Field,
    array: &StructArray,
    uniform_shape: Option<Vec<Option<i32>>>,
    mut op: F,
) -> Result<(Field, StructArray), ArrowInteropError>
where
    F: FnMut(&ndarray::ArrayViewD<'_, Complex64>) -> Result<ArrayD<Complex64>, ArrowInteropError>,
{
    let mut outputs = Vec::with_capacity(array.len());
    ndarrow::complex64_variable_shape_tensor_iter(field, array)?.try_for_each(|item| {
        let (_, tensor) = item?;
        outputs.push(op(&tensor)?);
        Ok::<(), ArrowInteropError>(())
    })?;
    Ok(ndarrow::arrays_complex64_to_variable_shape_tensor(field.name(), outputs, uniform_shape)?)
}
/// Applies `op` to every row of a `Complex64` variable-shape tensor column, where
/// `op` produces real (`f64`) arrays, and packs them into a real variable-shape
/// tensor.
///
/// The first row-decoding or `op` error short-circuits. The output is named after
/// `field.name()` and carries the supplied `uniform_shape`.
fn collect_variable_shape_complex_norm_rows<F>(
    field: &Field,
    array: &StructArray,
    uniform_shape: Option<Vec<Option<i32>>>,
    mut op: F,
) -> Result<(Field, StructArray), ArrowInteropError>
where
    F: FnMut(&ndarray::ArrayViewD<'_, Complex64>) -> Result<ArrayD<f64>, ArrowInteropError>,
{
    let mut outputs = Vec::with_capacity(array.len());
    ndarrow::complex64_variable_shape_tensor_iter(field, array)?.try_for_each(|item| {
        let (_, tensor) = item?;
        outputs.push(op(&tensor)?);
        Ok::<(), ArrowInteropError>(())
    })?;
    Ok(ndarrow::arrays_to_variable_shape_tensor(field.name(), outputs, uniform_shape)?)
}
/// Value returned by `cp_als3_with_report`: the rank-3 CP-ALS result paired with
/// its `CpAlsReport`.
pub type ArrowCpAls3WithReport<T> =
(crate::linalg::tensor::CpAls3Result<T>, crate::linalg::tensor::CpAlsReport<T>);
/// Value returned by `cp_als_nd_with_report`: the N-dimensional CP-ALS result
/// paired with its `CpAlsReport`.
pub type ArrowCpAlsNdWithReport<T> =
(crate::linalg::tensor::CpAlsNdResult<T>, crate::linalg::tensor::CpAlsReport<T>);
pub fn sum_last_axis<T>(
field: &Field,
array: &FixedSizeListArray,
) -> Result<(Field, FixedSizeListArray), ArrowInteropError>
where
T: ArrowPrimitiveType,
T::Native: NabledReal + NdarrowElement + Default,
{
let tensor_view = fixed_shape_tensor_viewd::<T>(field, array)?;
let output = crate::linalg::tensor::sum_last_axis_view(&tensor_view)?;
fixed_shape_tensor_from_owned::<T>(field.name(), output)
}
pub fn l2_norm_last_axis<T>(
field: &Field,
array: &FixedSizeListArray,
) -> Result<(Field, FixedSizeListArray), ArrowInteropError>
where
T: ArrowPrimitiveType,
T::Native: NabledReal + NdarrowElement + Default,
{
let tensor_view = fixed_shape_tensor_viewd::<T>(field, array)?;
let output = crate::linalg::tensor::l2_norm_last_axis_view(&tensor_view)?;
fixed_shape_tensor_from_owned::<T>(field.name(), output)
}
pub fn normalize_last_axis<T>(
field: &Field,
array: &FixedSizeListArray,
) -> Result<(Field, FixedSizeListArray), ArrowInteropError>
where
T: ArrowPrimitiveType,
T::Native: NabledReal + NdarrowElement + Default,
{
let tensor_view = fixed_shape_tensor_viewd::<T>(field, array)?;
let output = crate::linalg::tensor::normalize_last_axis_view(&tensor_view)?;
fixed_shape_tensor_from_owned::<T>(field.name(), output)
}
/// Runs `crate::linalg::tensor::batched_dot_last_axis_view` on two fixed-shape
/// tensor columns.
///
/// Both inputs are viewed at dynamic rank (left first, then right). The result is
/// re-encoded under `left_field.name()`.
///
/// # Errors
/// Propagates view, kernel, and re-encoding failures as `ArrowInteropError`.
pub fn batched_dot_last_axis<T>(
    left_field: &Field,
    left: &FixedSizeListArray,
    right_field: &Field,
    right: &FixedSizeListArray,
) -> Result<(Field, FixedSizeListArray), ArrowInteropError>
where
    T: ArrowPrimitiveType,
    T::Native: NabledReal + NdarrowElement + Default,
{
    let (lhs, rhs) = (
        fixed_shape_tensor_viewd::<T>(left_field, left)?,
        fixed_shape_tensor_viewd::<T>(right_field, right)?,
    );
    let dots = crate::linalg::tensor::batched_dot_last_axis_view(&lhs, &rhs)?;
    fixed_shape_tensor_from_owned::<T>(left_field.name(), dots)
}
pub fn permute_axes<T>(
field: &Field,
array: &FixedSizeListArray,
permutation: &[usize],
) -> Result<(Field, FixedSizeListArray), ArrowInteropError>
where
T: ArrowPrimitiveType,
T::Native: NabledReal + NdarrowElement + Default,
{
let tensor_view = fixed_shape_tensor_viewd::<T>(field, array)?;
let output = crate::linalg::tensor::permute_axes_view(&tensor_view, permutation)?;
fixed_shape_tensor_from_owned::<T>(field.name(), output)
}
pub fn contract_axes<T>(
left_field: &Field,
left: &FixedSizeListArray,
right_field: &Field,
right: &FixedSizeListArray,
left_axes: &[usize],
right_axes: &[usize],
) -> Result<(Field, FixedSizeListArray), ArrowInteropError>
where
T: ArrowPrimitiveType,
T::Native: NabledReal + NdarrowElement + Default,
{
let left_view = fixed_shape_tensor_viewd::<T>(left_field, left)?;
let right_view = fixed_shape_tensor_viewd::<T>(right_field, right)?;
let output =
crate::linalg::tensor::contract_axes_view(&left_view, &right_view, left_axes, right_axes)?;
fixed_shape_tensor_from_owned::<T>(left_field.name(), output)
}
/// Runs `crate::linalg::tensor::batched_matmul_last_two_view` on two fixed-shape
/// tensor columns.
///
/// The product is re-encoded under `left_field.name()`.
///
/// # Errors
/// Propagates view, kernel, and re-encoding failures as `ArrowInteropError`.
pub fn batched_matmul_last_two<T>(
    left_field: &Field,
    left: &FixedSizeListArray,
    right_field: &Field,
    right: &FixedSizeListArray,
) -> Result<(Field, FixedSizeListArray), ArrowInteropError>
where
    T: ArrowPrimitiveType,
    T::Native: NabledReal + NdarrowElement + Default,
{
    let (lhs, rhs) = (
        fixed_shape_tensor_viewd::<T>(left_field, left)?,
        fixed_shape_tensor_viewd::<T>(right_field, right)?,
    );
    let product = crate::linalg::tensor::batched_matmul_last_two_view(&lhs, &rhs)?;
    fixed_shape_tensor_from_owned::<T>(left_field.name(), product)
}
/// Runs `crate::linalg::tensor::cube_matvec_view` with a rank-3 tensor column and a
/// fixed-size-list column of vectors.
///
/// The cube column must view as rank 3 (otherwise `InvalidShape`). The kernel
/// output is returned as a plain `FixedSizeListArray` via
/// `fixed_size_list_from_owned`.
///
/// # Errors
/// Propagates view, kernel, and re-encoding failures as `ArrowInteropError`.
pub fn cube_matvec<T>(
    cube_field: &Field,
    cube: &FixedSizeListArray,
    vectors: &FixedSizeListArray,
) -> Result<FixedSizeListArray, ArrowInteropError>
where
    T: ArrowPrimitiveType,
    T::Native: NabledReal + NdarrowElement,
{
    let cubes = fixed_shape_tensor_view3::<T>(cube_field, cube)?;
    let rows = fixed_size_list_view::<T>(vectors)?;
    let product = crate::linalg::tensor::cube_matvec_view(&cubes, &rows)?;
    fixed_size_list_from_owned::<T>(product)
}
/// Runs `crate::linalg::tensor::cube_matmat_view` on two rank-3 tensor columns.
///
/// The owned result is converted to dynamic rank and re-encoded under
/// `left_field.name()`.
///
/// # Errors
/// Returns `InvalidShape` if either column is not rank 3; otherwise propagates
/// kernel and re-encoding failures.
pub fn cube_matmat<T>(
    left_field: &Field,
    left: &FixedSizeListArray,
    right_field: &Field,
    right: &FixedSizeListArray,
) -> Result<(Field, FixedSizeListArray), ArrowInteropError>
where
    T: ArrowPrimitiveType,
    T::Native: NabledReal + NdarrowElement + Default,
{
    let lhs = fixed_shape_tensor_view3::<T>(left_field, left)?;
    let rhs = fixed_shape_tensor_view3::<T>(right_field, right)?;
    let product = crate::linalg::tensor::cube_matmat_view(&lhs, &rhs)?.into_dyn();
    fixed_shape_tensor_from_owned::<T>(left_field.name(), product)
}
/// Runs `crate::linalg::tensor::flatten_cubes_view` on a rank-3 tensor column and
/// returns the result as a plain `FixedSizeListArray`.
///
/// # Errors
/// Returns `InvalidShape` if the column is not rank 3; otherwise propagates kernel
/// and re-encoding failures.
pub fn flatten_cubes<T>(
    field: &Field,
    array: &FixedSizeListArray,
) -> Result<FixedSizeListArray, ArrowInteropError>
where
    T: ArrowPrimitiveType,
    T::Native: NabledReal + NdarrowElement,
{
    let flattened = crate::linalg::tensor::flatten_cubes_view(
        &fixed_shape_tensor_view3::<T>(field, array)?,
    )?;
    fixed_size_list_from_owned::<T>(flattened)
}
pub fn sum_last_axis_variable<T>(
field: &Field,
array: &StructArray,
) -> Result<(Field, StructArray), ArrowInteropError>
where
T: ArrowPrimitiveType,
T::Native: NabledReal + NdarrowElement + Default,
{
let uniform_shape = reduced_uniform_shape(variable_shape_uniform_shape(field)?);
collect_variable_shape_real_rows::<T, _>(field, array, uniform_shape, |tensor_view| {
Ok(crate::linalg::tensor::sum_last_axis_view(tensor_view)?)
})
}
pub fn l2_norm_last_axis_variable<T>(
field: &Field,
array: &StructArray,
) -> Result<(Field, StructArray), ArrowInteropError>
where
T: ArrowPrimitiveType,
T::Native: NabledReal + NdarrowElement,
{
let uniform_shape = reduced_uniform_shape(variable_shape_uniform_shape(field)?);
collect_variable_shape_real_rows::<T, _>(field, array, uniform_shape, |tensor_view| {
Ok(crate::linalg::tensor::l2_norm_last_axis_view(tensor_view)?)
})
}
pub fn normalize_last_axis_variable<T>(
field: &Field,
array: &StructArray,
) -> Result<(Field, StructArray), ArrowInteropError>
where
T: ArrowPrimitiveType,
T::Native: NabledReal + NdarrowElement,
{
let uniform_shape = variable_shape_uniform_shape(field)?;
collect_variable_shape_real_rows::<T, _>(field, array, uniform_shape, |tensor_view| {
Ok(crate::linalg::tensor::normalize_last_axis_view(tensor_view)?)
})
}
/// Row-wise `crate::linalg::tensor::batched_dot_last_axis_view` over two
/// variable-shape tensor columns.
///
/// Row `i` of `left` is paired with row `i` of `right`. The output uses
/// `left_field`'s name and `uniform_shape` minus its final entry.
///
/// # Errors
/// Returns `InvalidShape` if the two columns have different row counts; otherwise
/// propagates metadata, row-view, kernel, and re-encoding failures.
pub fn batched_dot_last_axis_variable<T>(
    left_field: &Field,
    left: &StructArray,
    right_field: &Field,
    right: &StructArray,
) -> Result<(Field, StructArray), ArrowInteropError>
where
    T: ArrowPrimitiveType,
    T::Native: NabledReal + NdarrowElement,
{
    let (left_rows, right_rows) = (left.len(), right.len());
    if left_rows != right_rows {
        return Err(ArrowInteropError::InvalidShape(format!(
            "variable-shape tensor batch row count mismatch: {} vs {}",
            left_rows, right_rows
        )));
    }
    let output_shape = reduced_uniform_shape(variable_shape_uniform_shape(left_field)?);
    let lhs_batch = variable_shape_tensor_batch_view::<T>(left_field, left)?;
    let rhs_batch = variable_shape_tensor_batch_view::<T>(right_field, right)?;

    let pair_count = lhs_batch.len();
    let mut outputs = Vec::with_capacity(pair_count);
    (0..pair_count).try_for_each(|index| {
        let lhs = lhs_batch.row(index)?.as_array_viewd()?;
        let rhs = rhs_batch.row(index)?.as_array_viewd()?;
        outputs.push(crate::linalg::tensor::batched_dot_last_axis_view(&lhs, &rhs)?);
        Ok::<(), ArrowInteropError>(())
    })?;
    Ok(ndarrow::arrays_to_variable_shape_tensor(left_field.name(), outputs, output_shape)?)
}
pub fn sum_last_axis_complex(
field: &Field,
array: &FixedSizeListArray,
) -> Result<(Field, FixedSizeListArray), ArrowInteropError> {
let tensor_view = complex64_fixed_shape_tensor_viewd(field, array)?;
let output = crate::linalg::tensor::sum_last_axis_complex_view(&tensor_view)?;
complex64_fixed_shape_tensor_from_owned(field.name(), output)
}
pub fn sum_last_axis_variable_complex(
field: &Field,
array: &StructArray,
) -> Result<(Field, StructArray), ArrowInteropError> {
let uniform_shape = reduced_uniform_shape(variable_shape_uniform_shape(field)?);
collect_variable_shape_complex_rows(field, array, uniform_shape, |tensor_view| {
Ok(crate::linalg::tensor::sum_last_axis_complex_view(tensor_view)?)
})
}
pub fn l2_norm_last_axis_complex(
field: &Field,
array: &FixedSizeListArray,
) -> Result<(Field, FixedSizeListArray), ArrowInteropError> {
let tensor_view = complex64_fixed_shape_tensor_viewd(field, array)?;
let output = crate::linalg::tensor::l2_norm_last_axis_complex_view(&tensor_view)?;
fixed_shape_tensor_from_owned::<Float64Type>(field.name(), output)
}
pub fn l2_norm_last_axis_variable_complex(
field: &Field,
array: &StructArray,
) -> Result<(Field, StructArray), ArrowInteropError> {
let uniform_shape = reduced_uniform_shape(variable_shape_uniform_shape(field)?);
collect_variable_shape_complex_norm_rows(field, array, uniform_shape, |tensor_view| {
Ok(crate::linalg::tensor::l2_norm_last_axis_complex_view(tensor_view)?)
})
}
pub fn normalize_last_axis_complex(
field: &Field,
array: &FixedSizeListArray,
) -> Result<(Field, FixedSizeListArray), ArrowInteropError> {
let tensor_view = complex64_fixed_shape_tensor_viewd(field, array)?;
let output = crate::linalg::tensor::normalize_last_axis_complex_view(&tensor_view)?;
complex64_fixed_shape_tensor_from_owned(field.name(), output)
}
pub fn normalize_last_axis_variable_complex(
field: &Field,
array: &StructArray,
) -> Result<(Field, StructArray), ArrowInteropError> {
let uniform_shape = variable_shape_uniform_shape(field)?;
collect_variable_shape_complex_rows(field, array, uniform_shape, |tensor_view| {
Ok(crate::linalg::tensor::normalize_last_axis_complex_view(tensor_view)?)
})
}
/// Runs `crate::linalg::tensor::batched_dot_last_axis_complex_view` on two
/// `Complex64` fixed-shape tensor columns.
///
/// The result is re-encoded under `left_field.name()`.
///
/// # Errors
/// Propagates view, kernel, and re-encoding failures as `ArrowInteropError`.
pub fn batched_dot_last_axis_complex(
    left_field: &Field,
    left: &FixedSizeListArray,
    right_field: &Field,
    right: &FixedSizeListArray,
) -> Result<(Field, FixedSizeListArray), ArrowInteropError> {
    let (lhs, rhs) = (
        complex64_fixed_shape_tensor_viewd(left_field, left)?,
        complex64_fixed_shape_tensor_viewd(right_field, right)?,
    );
    let dots = crate::linalg::tensor::batched_dot_last_axis_complex_view(&lhs, &rhs)?;
    complex64_fixed_shape_tensor_from_owned(left_field.name(), dots)
}
/// Row-wise `crate::linalg::tensor::batched_dot_last_axis_complex_view` over two
/// `Complex64` variable-shape tensor columns.
///
/// Row `i` of `left` is paired with row `i` of `right`. The output uses
/// `left_field`'s name and `uniform_shape` minus its final entry.
///
/// # Errors
/// Returns `InvalidShape` on a row-count mismatch, or if the right-hand iterator
/// runs out before the left one. Otherwise propagates metadata, row-decoding,
/// kernel, and re-encoding failures.
pub fn batched_dot_last_axis_variable_complex(
    left_field: &Field,
    left: &StructArray,
    right_field: &Field,
    right: &StructArray,
) -> Result<(Field, StructArray), ArrowInteropError> {
    let (left_rows, right_rows) = (left.len(), right.len());
    if left_rows != right_rows {
        return Err(ArrowInteropError::InvalidShape(format!(
            "variable-shape tensor batch row count mismatch: {} vs {}",
            left_rows, right_rows
        )));
    }
    let output_shape = reduced_uniform_shape(variable_shape_uniform_shape(left_field)?);
    // The right-hand iterator is created first, matching the original error precedence.
    let mut rhs_rows = ndarrow::complex64_variable_shape_tensor_iter(right_field, right)?;
    let lhs_rows = ndarrow::complex64_variable_shape_tensor_iter(left_field, left)?;

    let mut outputs = Vec::with_capacity(left_rows);
    for lhs_item in lhs_rows {
        let (_, lhs) = lhs_item?;
        let rhs_item = match rhs_rows.next() {
            Some(item) => item,
            None => {
                return Err(ArrowInteropError::InvalidShape(
                    "variable-shape tensor batch iterator ended early".to_owned(),
                ));
            }
        };
        let (_, rhs) = rhs_item?;
        outputs.push(crate::linalg::tensor::batched_dot_last_axis_complex_view(&lhs, &rhs)?);
    }
    Ok(ndarrow::arrays_complex64_to_variable_shape_tensor(
        left_field.name(),
        outputs,
        output_shape,
    )?)
}
pub fn permute_axes_complex(
field: &Field,
array: &FixedSizeListArray,
permutation: &[usize],
) -> Result<(Field, FixedSizeListArray), ArrowInteropError> {
let tensor_view = complex64_fixed_shape_tensor_viewd(field, array)?;
let output = crate::linalg::tensor::permute_axes_complex_view(&tensor_view, permutation)?;
complex64_fixed_shape_tensor_from_owned(field.name(), output)
}
pub fn contract_axes_complex(
left_field: &Field,
left: &FixedSizeListArray,
right_field: &Field,
right: &FixedSizeListArray,
left_axes: &[usize],
right_axes: &[usize],
) -> Result<(Field, FixedSizeListArray), ArrowInteropError> {
let left_view = complex64_fixed_shape_tensor_viewd(left_field, left)?;
let right_view = complex64_fixed_shape_tensor_viewd(right_field, right)?;
let output = crate::linalg::tensor::contract_axes_complex_view(
&left_view,
&right_view,
left_axes,
right_axes,
)?;
complex64_fixed_shape_tensor_from_owned(left_field.name(), output)
}
/// Runs `crate::linalg::tensor::batched_matmul_last_two_complex_view` on two
/// `Complex64` fixed-shape tensor columns.
///
/// The product is re-encoded under `left_field.name()`.
///
/// # Errors
/// Propagates view, kernel, and re-encoding failures as `ArrowInteropError`.
pub fn batched_matmul_last_two_complex(
    left_field: &Field,
    left: &FixedSizeListArray,
    right_field: &Field,
    right: &FixedSizeListArray,
) -> Result<(Field, FixedSizeListArray), ArrowInteropError> {
    let (lhs, rhs) = (
        complex64_fixed_shape_tensor_viewd(left_field, left)?,
        complex64_fixed_shape_tensor_viewd(right_field, right)?,
    );
    let product = crate::linalg::tensor::batched_matmul_last_two_complex_view(&lhs, &rhs)?;
    complex64_fixed_shape_tensor_from_owned(left_field.name(), product)
}
/// Runs `crate::linalg::tensor::cube_matvec_complex_view` with a `Complex64`
/// rank-3 tensor column and a `Complex64` matrix column.
///
/// The output is returned via `complex64_matrix_from_owned`.
///
/// # Errors
/// Returns `InvalidShape` if the cube column is not rank 3; otherwise propagates
/// view, kernel, and re-encoding failures.
pub fn cube_matvec_complex(
    cube_field: &Field,
    cube: &FixedSizeListArray,
    vectors: &FixedSizeListArray,
) -> Result<FixedSizeListArray, ArrowInteropError> {
    let cubes = complex64_fixed_shape_tensor_view3(cube_field, cube)?;
    let matrix = complex64_matrix_view(vectors)?;
    let product = crate::linalg::tensor::cube_matvec_complex_view(&cubes, &matrix)?;
    complex64_matrix_from_owned(product)
}
/// Runs `crate::linalg::tensor::cube_matmat_complex_view` on two `Complex64` rank-3
/// tensor columns.
///
/// The result is converted to dynamic rank and re-encoded under
/// `left_field.name()`.
///
/// # Errors
/// Returns `InvalidShape` if either column is not rank 3; otherwise propagates
/// kernel and re-encoding failures.
pub fn cube_matmat_complex(
    left_field: &Field,
    left: &FixedSizeListArray,
    right_field: &Field,
    right: &FixedSizeListArray,
) -> Result<(Field, FixedSizeListArray), ArrowInteropError> {
    let lhs = complex64_fixed_shape_tensor_view3(left_field, left)?;
    let rhs = complex64_fixed_shape_tensor_view3(right_field, right)?;
    let product = crate::linalg::tensor::cube_matmat_complex_view(&lhs, &rhs)?.into_dyn();
    complex64_fixed_shape_tensor_from_owned(left_field.name(), product)
}
/// Evaluates `expression` over two fixed-shape tensor columns using
/// `crate::linalg::tensor::einsum_view`.
///
/// The result is re-encoded under `left_field.name()`. The accepted expression
/// syntax is whatever `einsum_view` supports.
///
/// # Errors
/// Propagates view, kernel (including expression errors), and re-encoding failures.
pub fn einsum<T>(
    expression: &str,
    left_field: &Field,
    left: &FixedSizeListArray,
    right_field: &Field,
    right: &FixedSizeListArray,
) -> Result<(Field, FixedSizeListArray), ArrowInteropError>
where
    T: ArrowPrimitiveType,
    T::Native: NabledReal + NdarrowElement,
{
    let (lhs, rhs) = (
        fixed_shape_tensor_viewd::<T>(left_field, left)?,
        fixed_shape_tensor_viewd::<T>(right_field, right)?,
    );
    let evaluated = crate::linalg::tensor::einsum_view(expression, &lhs, &rhs)?;
    fixed_shape_tensor_from_owned::<T>(left_field.name(), evaluated)
}
/// Evaluates `expression` over two `Complex64` fixed-shape tensor columns using
/// `crate::linalg::tensor::einsum_complex_view`.
///
/// The result is re-encoded under `left_field.name()`.
///
/// # Errors
/// Propagates view, kernel (including expression errors), and re-encoding failures.
pub fn einsum_complex(
    expression: &str,
    left_field: &Field,
    left: &FixedSizeListArray,
    right_field: &Field,
    right: &FixedSizeListArray,
) -> Result<(Field, FixedSizeListArray), ArrowInteropError> {
    let (lhs, rhs) = (
        complex64_fixed_shape_tensor_viewd(left_field, left)?,
        complex64_fixed_shape_tensor_viewd(right_field, right)?,
    );
    let evaluated = crate::linalg::tensor::einsum_complex_view(expression, &lhs, &rhs)?;
    complex64_fixed_shape_tensor_from_owned(left_field.name(), evaluated)
}
pub fn cp_als3<T>(
field: &Field,
array: &FixedSizeListArray,
rank: usize,
config: &crate::linalg::tensor::CpAlsConfig<T::Native>,
) -> Result<crate::linalg::tensor::CpAls3Result<T::Native>, ArrowInteropError>
where
T: ArrowPrimitiveType,
T::Native: crate::linalg::tensor::CpAlsScalar + NdarrowElement,
{
let cube_view = fixed_shape_tensor_view3::<T>(field, array)?;
Ok(crate::linalg::tensor::cp_als3_view(&cube_view, rank, config)?)
}
pub fn cp_als3_with_report<T>(
field: &Field,
array: &FixedSizeListArray,
rank: usize,
config: &crate::linalg::tensor::CpAlsConfig<T::Native>,
) -> Result<ArrowCpAls3WithReport<T::Native>, ArrowInteropError>
where
T: ArrowPrimitiveType,
T::Native: crate::linalg::tensor::CpAlsScalar + NdarrowElement,
{
let cube_view = fixed_shape_tensor_view3::<T>(field, array)?;
Ok(crate::linalg::tensor::cp_als3_view_with_report(&cube_view, rank, config)?)
}
pub fn cp_als3_diagnostics<T>(
field: &Field,
array: &FixedSizeListArray,
result: &crate::linalg::tensor::CpAls3Result<T::Native>,
) -> Result<crate::linalg::tensor::CpErrorMetrics<T::Native>, ArrowInteropError>
where
T: ArrowPrimitiveType,
T::Native: NabledReal + NdarrowElement,
{
let cube_view = fixed_shape_tensor_view3::<T>(field, array)?;
Ok(crate::linalg::tensor::cp_als3_diagnostics_view(&cube_view, result)?)
}
/// Rebuilds a tensor from a rank-3 `result` via
/// `crate::linalg::tensor::cp_als3_reconstruct`.
///
/// The result is encoded as a fixed-shape tensor named `field_name`.
///
/// # Errors
/// Propagates reconstruction and re-encoding failures as `ArrowInteropError`.
pub fn cp_als3_reconstruct<T>(
    field_name: &str,
    result: &crate::linalg::tensor::CpAls3Result<T::Native>,
) -> Result<(Field, FixedSizeListArray), ArrowInteropError>
where
    T: ArrowPrimitiveType,
    T::Native: NabledReal + NdarrowElement,
{
    let rebuilt = crate::linalg::tensor::cp_als3_reconstruct(result)?.into_dyn();
    fixed_shape_tensor_from_owned::<T>(field_name, rebuilt)
}
pub fn cp_als_nd<T>(
field: &Field,
array: &FixedSizeListArray,
rank: usize,
config: &crate::linalg::tensor::CpAlsConfig<T::Native>,
) -> Result<crate::linalg::tensor::CpAlsNdResult<T::Native>, ArrowInteropError>
where
T: ArrowPrimitiveType,
T::Native: crate::linalg::tensor::CpAlsScalar + NdarrowElement,
{
let tensor_view = fixed_shape_tensor_viewd::<T>(field, array)?;
Ok(crate::linalg::tensor::cp_als_nd_view(&tensor_view, rank, config)?)
}
pub fn cp_als_nd_with_report<T>(
field: &Field,
array: &FixedSizeListArray,
rank: usize,
config: &crate::linalg::tensor::CpAlsConfig<T::Native>,
) -> Result<ArrowCpAlsNdWithReport<T::Native>, ArrowInteropError>
where
T: ArrowPrimitiveType,
T::Native: crate::linalg::tensor::CpAlsScalar + NdarrowElement,
{
let tensor_view = fixed_shape_tensor_viewd::<T>(field, array)?;
Ok(crate::linalg::tensor::cp_als_nd_view_with_report(&tensor_view, rank, config)?)
}
pub fn cp_als_nd_diagnostics<T>(
field: &Field,
array: &FixedSizeListArray,
result: &crate::linalg::tensor::CpAlsNdResult<T::Native>,
) -> Result<crate::linalg::tensor::CpErrorMetrics<T::Native>, ArrowInteropError>
where
T: ArrowPrimitiveType,
T::Native: NabledReal + NdarrowElement,
{
let tensor_view = fixed_shape_tensor_viewd::<T>(field, array)?;
Ok(crate::linalg::tensor::cp_als_nd_diagnostics_view(&tensor_view, result)?)
}
/// Rebuilds a tensor from an N-dimensional `result` via
/// `crate::linalg::tensor::cp_als_nd_reconstruct`.
///
/// The result is encoded as a fixed-shape tensor named `field_name`.
///
/// # Errors
/// Propagates reconstruction and re-encoding failures as `ArrowInteropError`.
pub fn cp_als_nd_reconstruct<T>(
    field_name: &str,
    result: &crate::linalg::tensor::CpAlsNdResult<T::Native>,
) -> Result<(Field, FixedSizeListArray), ArrowInteropError>
where
    T: ArrowPrimitiveType,
    T::Native: NabledReal + NdarrowElement,
{
    fixed_shape_tensor_from_owned::<T>(
        field_name,
        crate::linalg::tensor::cp_als_nd_reconstruct(result)?,
    )
}
pub fn hosvd_nd<T>(
field: &Field,
array: &FixedSizeListArray,
ranks: &[usize],
) -> Result<crate::linalg::tensor::HosvdNdResult<T::Native>, ArrowInteropError>
where
T: ArrowPrimitiveType,
T::Native: crate::linalg::tensor::HosvdNdScalar + NdarrowElement,
{
let tensor_view = fixed_shape_tensor_viewd::<T>(field, array)?;
Ok(crate::linalg::tensor::hosvd_nd_view(&tensor_view, ranks)?)
}
pub fn hooi_nd<T>(
field: &Field,
array: &FixedSizeListArray,
ranks: &[usize],
config: &crate::linalg::tensor::HooiConfig<T::Native>,
) -> Result<crate::linalg::tensor::HosvdNdResult<T::Native>, ArrowInteropError>
where
T: ArrowPrimitiveType,
T::Native: crate::linalg::tensor::HooiNdScalar + NdarrowElement,
{
let tensor_view = fixed_shape_tensor_viewd::<T>(field, array)?;
Ok(crate::linalg::tensor::hooi_nd_view(&tensor_view, ranks, config)?)
}
pub fn tucker_project<T>(
field: &Field,
array: &FixedSizeListArray,
factors: &[ndarray::Array2<T::Native>],
) -> Result<(Field, FixedSizeListArray), ArrowInteropError>
where
T: ArrowPrimitiveType,
T::Native: NabledReal + NdarrowElement,
{
let tensor_view = fixed_shape_tensor_viewd::<T>(field, array)?;
let output = crate::linalg::tensor::tucker_project_view(&tensor_view, factors)?;
fixed_shape_tensor_from_owned::<T>(field.name(), output)
}
pub fn tucker_expand<T>(
field: &Field,
array: &FixedSizeListArray,
factors: &[ndarray::Array2<T::Native>],
) -> Result<(Field, FixedSizeListArray), ArrowInteropError>
where
T: ArrowPrimitiveType,
T::Native: NabledReal + NdarrowElement,
{
let tensor_view = fixed_shape_tensor_viewd::<T>(field, array)?;
let output = crate::linalg::tensor::tucker_expand_view(&tensor_view, factors)?;
fixed_shape_tensor_from_owned::<T>(field.name(), output)
}
/// Rebuilds a tensor from a `HosvdNdResult` via
/// `crate::linalg::tensor::hosvd_nd_reconstruct`.
///
/// The result is encoded as a fixed-shape tensor named `field_name`.
///
/// # Errors
/// Propagates reconstruction and re-encoding failures as `ArrowInteropError`.
pub fn hosvd_nd_reconstruct<T>(
    field_name: &str,
    result: &crate::linalg::tensor::HosvdNdResult<T::Native>,
) -> Result<(Field, FixedSizeListArray), ArrowInteropError>
where
    T: ArrowPrimitiveType,
    T::Native: NabledReal + NdarrowElement,
{
    fixed_shape_tensor_from_owned::<T>(
        field_name,
        crate::linalg::tensor::hosvd_nd_reconstruct(result)?,
    )
}
pub fn tt_svd<T>(
field: &Field,
array: &FixedSizeListArray,
config: &crate::linalg::tensor::TtSvdConfig<T::Native>,
) -> Result<crate::linalg::tensor::TensorTrainResult<T::Native>, ArrowInteropError>
where
T: ArrowPrimitiveType,
T::Native: crate::linalg::tensor::TtSvdScalar + NdarrowElement,
{
let tensor_view = fixed_shape_tensor_viewd::<T>(field, array)?;
Ok(crate::linalg::tensor::tt_svd_view(&tensor_view, config)?)
}
pub fn tt_orthogonalize_left<T: crate::linalg::tensor::TtSvdScalar>(
result: &crate::linalg::tensor::TensorTrainResult<T>,
) -> Result<crate::linalg::tensor::TensorTrainResult<T>, ArrowInteropError> {
Ok(crate::linalg::tensor::tt_orthogonalize_left(result)?)
}
pub fn tt_orthogonalize_right<T: crate::linalg::tensor::TtSvdScalar>(
result: &crate::linalg::tensor::TensorTrainResult<T>,
) -> Result<crate::linalg::tensor::TensorTrainResult<T>, ArrowInteropError> {
Ok(crate::linalg::tensor::tt_orthogonalize_right(result)?)
}
pub fn tt_round<T: crate::linalg::tensor::TtSvdScalar>(
result: &crate::linalg::tensor::TensorTrainResult<T>,
config: &crate::linalg::tensor::TtRoundConfig<T>,
) -> Result<crate::linalg::tensor::TensorTrainResult<T>, ArrowInteropError> {
Ok(crate::linalg::tensor::tt_round(result, config)?)
}
pub fn tt_inner<T: NabledReal>(
left: &crate::linalg::tensor::TensorTrainResult<T>,
right: &crate::linalg::tensor::TensorTrainResult<T>,
) -> Result<T, ArrowInteropError> {
Ok(crate::linalg::tensor::tt_inner(left, right)?)
}
pub fn tt_norm<T: NabledReal>(
result: &crate::linalg::tensor::TensorTrainResult<T>,
) -> Result<T, ArrowInteropError> {
Ok(crate::linalg::tensor::tt_norm(result)?)
}
pub fn tt_add<T: NabledReal>(
left: &crate::linalg::tensor::TensorTrainResult<T>,
right: &crate::linalg::tensor::TensorTrainResult<T>,
) -> Result<crate::linalg::tensor::TensorTrainResult<T>, ArrowInteropError> {
Ok(crate::linalg::tensor::tt_add(left, right)?)
}
pub fn tt_hadamard<T: NabledReal>(
left: &crate::linalg::tensor::TensorTrainResult<T>,
right: &crate::linalg::tensor::TensorTrainResult<T>,
) -> Result<crate::linalg::tensor::TensorTrainResult<T>, ArrowInteropError> {
Ok(crate::linalg::tensor::tt_hadamard(left, right)?)
}
pub fn tt_hadamard_round<T: crate::linalg::tensor::TtSvdScalar>(
left: &crate::linalg::tensor::TensorTrainResult<T>,
right: &crate::linalg::tensor::TensorTrainResult<T>,
config: &crate::linalg::tensor::TtRoundConfig<T>,
) -> Result<crate::linalg::tensor::TensorTrainResult<T>, ArrowInteropError> {
Ok(crate::linalg::tensor::tt_hadamard_round(left, right, config)?)
}
/// Rebuilds a dense tensor from a `TensorTrainResult` via
/// `crate::linalg::tensor::tt_svd_reconstruct`.
///
/// The result is encoded as a fixed-shape tensor named `field_name`.
///
/// # Errors
/// Propagates reconstruction and re-encoding failures as `ArrowInteropError`.
pub fn tt_svd_reconstruct<T>(
    field_name: &str,
    result: &crate::linalg::tensor::TensorTrainResult<T::Native>,
) -> Result<(Field, FixedSizeListArray), ArrowInteropError>
where
    T: ArrowPrimitiveType,
    T::Native: NabledReal + NdarrowElement,
{
    fixed_shape_tensor_from_owned::<T>(
        field_name,
        crate::linalg::tensor::tt_svd_reconstruct(result)?,
    )
}