use arrow_array::types::{ArrowPrimitiveType, Float64Type};
use arrow_array::{Array, FixedSizeListArray, StructArray};
use arrow_schema::Field;
use arrow_schema::extension::{
EXTENSION_TYPE_METADATA_KEY, EXTENSION_TYPE_NAME_KEY, ExtensionType, VariableShapeTensor,
};
use nabled_core::scalar::NabledReal;
use ndarray::{ArrayD, ArrayView1, ArrayView2, ArrayView3, ArrayViewD, Ix3};
use ndarrow::NdarrowElement;
use num_complex::Complex64;
use serde::{Deserialize, Serialize};
use super::{
ArrowInteropError, complex64_fixed_shape_tensor_from_owned, complex64_fixed_shape_tensor_viewd,
complex64_matrix_from_owned, complex64_matrix_view, fixed_shape_tensor_from_owned,
fixed_shape_tensor_viewd, fixed_size_list_from_owned, fixed_size_list_view,
variable_shape_tensor_batch_view,
};
/// JSON payload stored under the extension-metadata key for `arrow.variable_shape_tensor`.
///
/// Every key is optional. Missing keys deserialize to `None`, and `None` values are left out
/// when serializing.
// Keep the field declaration order stable: serde emits JSON keys in this order.
#[derive(Debug, Default, Deserialize, Serialize)]
#[serde(default)]
struct VariableShapeTensorWireMetadata {
    /// Optional `dim_names` entry.
    #[serde(skip_serializing_if = "Option::is_none")]
    dim_names: Option<Vec<String>>,
    /// Optional `permutations` entry.
    #[serde(skip_serializing_if = "Option::is_none")]
    permutations: Option<Vec<usize>>,
    /// Optional `uniform_shape` entry; inner `None` entries serialize as JSON `null`.
    #[serde(skip_serializing_if = "Option::is_none")]
    uniform_shape: Option<Vec<Option<i32>>>,
}
/// Borrows a fixed-shape tensor column of `T` as a single three-dimensional view.
///
/// # Errors
///
/// Propagates failures from `fixed_shape_tensor_viewd`. Reports `ArrowInteropError::InvalidShape`
/// when the resulting view cannot be converted to `Ix3`.
fn fixed_shape_tensor_view3<'a, T>(
    field: &'a Field,
    array: &'a FixedSizeListArray,
) -> Result<ArrayView3<'a, T::Native>, ArrowInteropError>
where
    T: ArrowPrimitiveType,
    T::Native: NdarrowElement,
{
    fixed_shape_tensor_viewd::<T>(field, array)?
        .into_dimensionality::<Ix3>()
        .map_err(|shape_error| ArrowInteropError::InvalidShape(shape_error.to_string()))
}
/// Borrows a complex fixed-shape tensor column as a single three-dimensional view.
///
/// # Errors
///
/// Propagates failures from `complex64_fixed_shape_tensor_viewd`. Reports
/// `ArrowInteropError::InvalidShape` when the resulting view cannot be converted to `Ix3`.
fn complex64_fixed_shape_tensor_view3<'a>(
    field: &'a Field,
    array: &'a FixedSizeListArray,
) -> Result<ArrayView3<'a, Complex64>, ArrowInteropError> {
    complex64_fixed_shape_tensor_viewd(field, array)?
        .into_dimensionality::<Ix3>()
        .map_err(|shape_error| ArrowInteropError::InvalidShape(shape_error.to_string()))
}
fn variable_shape_uniform_shape(
field: &Field,
) -> Result<Option<Vec<Option<i32>>>, ArrowInteropError> {
let raw_metadata =
field.extension_type_metadata().ok_or_else(|| ndarrow::NdarrowError::InvalidMetadata {
message: "arrow.variable_shape_tensor metadata missing".to_owned(),
})?;
let metadata: VariableShapeTensorWireMetadata =
serde_json::from_str(raw_metadata).map_err(|error| {
ndarrow::NdarrowError::InvalidMetadata {
message: format!("arrow.variable_shape_tensor metadata parse failed: {error}"),
}
})?;
Ok(metadata.uniform_shape)
}
/// Removes the trailing entry of a `uniform_shape` list, if there is one.
///
/// `None` stays `None`, and an empty list stays empty. The last-axis reductions in this
/// module (`sum`, `l2_norm`, `batched_dot`) pass their input metadata through this;
/// `normalize` does not.
fn reduced_uniform_shape(uniform_shape: Option<Vec<Option<i32>>>) -> Option<Vec<Option<i32>>> {
    uniform_shape.map(|mut dims| {
        let _ = dims.pop();
        dims
    })
}
/// Rebuilds `field` so its metadata names the `arrow.variable_shape_tensor` extension and
/// carries a JSON payload that contains only `uniform_shape`.
///
/// Other metadata entries on `field` are kept. The two extension keys are overwritten.
///
/// # Errors
///
/// Returns an `InvalidMetadata` error (converted into `ArrowInteropError`) if the payload
/// cannot be serialized.
fn variable_shape_field_with_metadata(
    field: &Field,
    uniform_shape: Option<Vec<Option<i32>>>,
) -> Result<Field, ArrowInteropError> {
    let wire = VariableShapeTensorWireMetadata { dim_names: None, permutations: None, uniform_shape };
    let encoded = serde_json::to_string(&wire).map_err(|encode_error| {
        ndarrow::NdarrowError::InvalidMetadata {
            message: format!(
                "arrow.variable_shape_tensor metadata serialization failed: {encode_error}"
            ),
        }
    })?;
    let mut merged = field.metadata().clone();
    merged.extend([
        (EXTENSION_TYPE_NAME_KEY.to_owned(), VariableShapeTensor::NAME.to_owned()),
        (EXTENSION_TYPE_METADATA_KEY.to_owned(), encoded),
    ]);
    let rebuilt = Field::new(field.name(), field.data_type().clone(), field.is_nullable());
    Ok(rebuilt.with_metadata(merged))
}
fn collect_variable_shape_real_rows<T, F>(
field: &Field,
array: &StructArray,
uniform_shape: Option<Vec<Option<i32>>>,
mut op: F,
) -> Result<(Field, StructArray), ArrowInteropError>
where
T: ArrowPrimitiveType,
T::Native: NabledReal + NdarrowElement,
F: FnMut(&ArrayViewD<'_, T::Native>) -> Result<ArrayD<T::Native>, ArrowInteropError>,
{
let batch = variable_shape_tensor_batch_view::<T>(field, array)?;
let mut outputs = Vec::with_capacity(batch.len());
for row in 0..batch.len() {
let tensor_view = batch.row(row)?.as_array_viewd()?;
outputs.push(op(&tensor_view)?);
}
Ok(ndarrow::arrays_to_variable_shape_tensor(field.name(), outputs, uniform_shape)?)
}
/// Applies `op` to every row of a complex variable-shape tensor column and encodes the
/// complex results as a new variable-shape tensor column.
///
/// After encoding, the extension name and a metadata payload holding `uniform_shape` are
/// written onto the output field with `variable_shape_field_with_metadata`.
///
/// # Errors
///
/// Stops at the first row that fails to decode or for which `op` fails. Also forwards
/// encoding and metadata failures.
fn collect_variable_shape_complex_rows<F>(
    field: &Field,
    array: &StructArray,
    uniform_shape: Option<Vec<Option<i32>>>,
    mut op: F,
) -> Result<(Field, StructArray), ArrowInteropError>
where
    F: FnMut(&ArrayViewD<'_, Complex64>) -> Result<ArrayD<Complex64>, ArrowInteropError>,
{
    let transformed = ndarrow::complex64_variable_shape_tensor_iter(field, array)?
        .map(|item| -> Result<ArrayD<Complex64>, ArrowInteropError> {
            let (_, view) = item?;
            op(&view)
        })
        .collect::<Result<Vec<_>, _>>()?;
    let (raw_field, tensor_array) = ndarrow::arrays_complex64_to_variable_shape_tensor(
        field.name(),
        transformed,
        uniform_shape.clone(),
    )?;
    let annotated = variable_shape_field_with_metadata(&raw_field, uniform_shape)?;
    Ok((annotated, tensor_array))
}
fn collect_variable_shape_complex_norm_rows<F>(
field: &Field,
array: &StructArray,
uniform_shape: Option<Vec<Option<i32>>>,
mut op: F,
) -> Result<(Field, StructArray), ArrowInteropError>
where
F: FnMut(&ArrayViewD<'_, Complex64>) -> Result<ArrayD<f64>, ArrowInteropError>,
{
let mut outputs = Vec::with_capacity(array.len());
for row in ndarrow::complex64_variable_shape_tensor_iter(field, array)? {
let (_, tensor_view) = row?;
outputs.push(op(&tensor_view)?);
}
Ok(ndarrow::arrays_to_variable_shape_tensor(field.name(), outputs, uniform_shape)?)
}
/// The three-way CP-ALS result paired with its report; this is what `cp_als3_with_report`
/// returns.
pub type ArrowCpAls3WithReport<T> =
    (crate::linalg::tensor::CpAls3Result<T>, crate::linalg::tensor::CpAlsReport<T>);
/// The N-dimensional CP-ALS result paired with its report; this is what
/// `cp_als_nd_with_report` returns.
pub type ArrowCpAlsNdWithReport<T> =
    (crate::linalg::tensor::CpAlsNdResult<T>, crate::linalg::tensor::CpAlsReport<T>);
pub fn sum_last_axis<T>(
field: &Field,
array: &FixedSizeListArray,
) -> Result<(Field, FixedSizeListArray), ArrowInteropError>
where
T: ArrowPrimitiveType,
T::Native: NabledReal + NdarrowElement + Default,
{
let tensor_view = fixed_shape_tensor_viewd::<T>(field, array)?;
let output = crate::linalg::tensor::sum_last_axis_view(&tensor_view)?;
fixed_shape_tensor_from_owned::<T>(field.name(), output)
}
pub fn l2_norm_last_axis<T>(
field: &Field,
array: &FixedSizeListArray,
) -> Result<(Field, FixedSizeListArray), ArrowInteropError>
where
T: ArrowPrimitiveType,
T::Native: NabledReal + NdarrowElement + Default,
{
let tensor_view = fixed_shape_tensor_viewd::<T>(field, array)?;
let output = crate::linalg::tensor::l2_norm_last_axis_view(&tensor_view)?;
fixed_shape_tensor_from_owned::<T>(field.name(), output)
}
pub fn normalize_last_axis<T>(
field: &Field,
array: &FixedSizeListArray,
) -> Result<(Field, FixedSizeListArray), ArrowInteropError>
where
T: ArrowPrimitiveType,
T::Native: NabledReal + NdarrowElement + Default,
{
let tensor_view = fixed_shape_tensor_viewd::<T>(field, array)?;
let output = crate::linalg::tensor::normalize_last_axis_view(&tensor_view)?;
fixed_shape_tensor_from_owned::<T>(field.name(), output)
}
pub fn batched_dot_last_axis<T>(
left_field: &Field,
left: &FixedSizeListArray,
right_field: &Field,
right: &FixedSizeListArray,
) -> Result<(Field, FixedSizeListArray), ArrowInteropError>
where
T: ArrowPrimitiveType,
T::Native: NabledReal + NdarrowElement + Default,
{
let left_view = fixed_shape_tensor_viewd::<T>(left_field, left)?;
let right_view = fixed_shape_tensor_viewd::<T>(right_field, right)?;
let output = crate::linalg::tensor::batched_dot_last_axis_view(&left_view, &right_view)?;
fixed_shape_tensor_from_owned::<T>(left_field.name(), output)
}
pub fn permute_axes<T>(
field: &Field,
array: &FixedSizeListArray,
permutation: &[usize],
) -> Result<(Field, FixedSizeListArray), ArrowInteropError>
where
T: ArrowPrimitiveType,
T::Native: NabledReal + NdarrowElement + Default,
{
let tensor_view = fixed_shape_tensor_viewd::<T>(field, array)?;
let output = crate::linalg::tensor::permute_axes_view(&tensor_view, permutation)?;
fixed_shape_tensor_from_owned::<T>(field.name(), output)
}
pub fn contract_axes<T>(
left_field: &Field,
left: &FixedSizeListArray,
right_field: &Field,
right: &FixedSizeListArray,
left_axes: &[usize],
right_axes: &[usize],
) -> Result<(Field, FixedSizeListArray), ArrowInteropError>
where
T: ArrowPrimitiveType,
T::Native: NabledReal + NdarrowElement + Default,
{
let left_view = fixed_shape_tensor_viewd::<T>(left_field, left)?;
let right_view = fixed_shape_tensor_viewd::<T>(right_field, right)?;
let output =
crate::linalg::tensor::contract_axes_view(&left_view, &right_view, left_axes, right_axes)?;
fixed_shape_tensor_from_owned::<T>(left_field.name(), output)
}
pub fn batched_matmul_last_two<T>(
left_field: &Field,
left: &FixedSizeListArray,
right_field: &Field,
right: &FixedSizeListArray,
) -> Result<(Field, FixedSizeListArray), ArrowInteropError>
where
T: ArrowPrimitiveType,
T::Native: NabledReal + NdarrowElement + Default,
{
let left_view = fixed_shape_tensor_viewd::<T>(left_field, left)?;
let right_view = fixed_shape_tensor_viewd::<T>(right_field, right)?;
let output = crate::linalg::tensor::batched_matmul_last_two_view(&left_view, &right_view)?;
fixed_shape_tensor_from_owned::<T>(left_field.name(), output)
}
pub fn cube_matvec<T>(
cube_field: &Field,
cube: &FixedSizeListArray,
vectors: &FixedSizeListArray,
) -> Result<FixedSizeListArray, ArrowInteropError>
where
T: ArrowPrimitiveType,
T::Native: NabledReal + NdarrowElement,
{
let cube_view = fixed_shape_tensor_view3::<T>(cube_field, cube)?;
let vectors_view = fixed_size_list_view::<T>(vectors)?;
let output = crate::linalg::tensor::cube_matvec_view(&cube_view, &vectors_view)?;
fixed_size_list_from_owned::<T>(output)
}
pub fn cube_matmat<T>(
left_field: &Field,
left: &FixedSizeListArray,
right_field: &Field,
right: &FixedSizeListArray,
) -> Result<(Field, FixedSizeListArray), ArrowInteropError>
where
T: ArrowPrimitiveType,
T::Native: NabledReal + NdarrowElement + Default,
{
let left_view = fixed_shape_tensor_view3::<T>(left_field, left)?;
let right_view = fixed_shape_tensor_view3::<T>(right_field, right)?;
let output = crate::linalg::tensor::cube_matmat_view(&left_view, &right_view)?;
fixed_shape_tensor_from_owned::<T>(left_field.name(), output.into_dyn())
}
pub fn flatten_cubes<T>(
field: &Field,
array: &FixedSizeListArray,
) -> Result<FixedSizeListArray, ArrowInteropError>
where
T: ArrowPrimitiveType,
T::Native: NabledReal + NdarrowElement,
{
let cube_view = fixed_shape_tensor_view3::<T>(field, array)?;
let output = crate::linalg::tensor::flatten_cubes_view(&cube_view)?;
fixed_size_list_from_owned::<T>(output)
}
pub fn sum_last_axis_variable<T>(
field: &Field,
array: &StructArray,
) -> Result<(Field, StructArray), ArrowInteropError>
where
T: ArrowPrimitiveType,
T::Native: NabledReal + NdarrowElement + Default,
{
let uniform_shape = reduced_uniform_shape(variable_shape_uniform_shape(field)?);
collect_variable_shape_real_rows::<T, _>(field, array, uniform_shape, |tensor_view| {
Ok(crate::linalg::tensor::sum_last_axis_view(tensor_view)?)
})
}
pub fn l2_norm_last_axis_variable<T>(
field: &Field,
array: &StructArray,
) -> Result<(Field, StructArray), ArrowInteropError>
where
T: ArrowPrimitiveType,
T::Native: NabledReal + NdarrowElement,
{
let uniform_shape = reduced_uniform_shape(variable_shape_uniform_shape(field)?);
collect_variable_shape_real_rows::<T, _>(field, array, uniform_shape, |tensor_view| {
Ok(crate::linalg::tensor::l2_norm_last_axis_view(tensor_view)?)
})
}
pub fn normalize_last_axis_variable<T>(
field: &Field,
array: &StructArray,
) -> Result<(Field, StructArray), ArrowInteropError>
where
T: ArrowPrimitiveType,
T::Native: NabledReal + NdarrowElement,
{
let uniform_shape = variable_shape_uniform_shape(field)?;
collect_variable_shape_real_rows::<T, _>(field, array, uniform_shape, |tensor_view| {
Ok(crate::linalg::tensor::normalize_last_axis_view(tensor_view)?)
})
}
pub fn batched_dot_last_axis_variable<T>(
left_field: &Field,
left: &StructArray,
right_field: &Field,
right: &StructArray,
) -> Result<(Field, StructArray), ArrowInteropError>
where
T: ArrowPrimitiveType,
T::Native: NabledReal + NdarrowElement,
{
if left.len() != right.len() {
return Err(ArrowInteropError::InvalidShape(format!(
"variable-shape tensor batch row count mismatch: {} vs {}",
left.len(),
right.len()
)));
}
let uniform_shape = reduced_uniform_shape(variable_shape_uniform_shape(left_field)?);
let left_batch = variable_shape_tensor_batch_view::<T>(left_field, left)?;
let right_batch = variable_shape_tensor_batch_view::<T>(right_field, right)?;
let mut outputs = Vec::with_capacity(left_batch.len());
for row in 0..left_batch.len() {
let left_view = left_batch.row(row)?.as_array_viewd()?;
let right_view = right_batch.row(row)?.as_array_viewd()?;
outputs.push(crate::linalg::tensor::batched_dot_last_axis_view(&left_view, &right_view)?);
}
Ok(ndarrow::arrays_to_variable_shape_tensor(left_field.name(), outputs, uniform_shape)?)
}
pub fn sum_last_axis_complex(
field: &Field,
array: &FixedSizeListArray,
) -> Result<(Field, FixedSizeListArray), ArrowInteropError> {
let tensor_view = complex64_fixed_shape_tensor_viewd(field, array)?;
let output = crate::linalg::tensor::sum_last_axis_complex_view(&tensor_view)?;
complex64_fixed_shape_tensor_from_owned(field.name(), output)
}
pub fn sum_last_axis_variable_complex(
field: &Field,
array: &StructArray,
) -> Result<(Field, StructArray), ArrowInteropError> {
let uniform_shape = reduced_uniform_shape(variable_shape_uniform_shape(field)?);
collect_variable_shape_complex_rows(field, array, uniform_shape, |tensor_view| {
Ok(crate::linalg::tensor::sum_last_axis_complex_view(tensor_view)?)
})
}
pub fn l2_norm_last_axis_complex(
field: &Field,
array: &FixedSizeListArray,
) -> Result<(Field, FixedSizeListArray), ArrowInteropError> {
let tensor_view = complex64_fixed_shape_tensor_viewd(field, array)?;
let output = crate::linalg::tensor::l2_norm_last_axis_complex_view(&tensor_view)?;
fixed_shape_tensor_from_owned::<Float64Type>(field.name(), output)
}
pub fn l2_norm_last_axis_variable_complex(
field: &Field,
array: &StructArray,
) -> Result<(Field, StructArray), ArrowInteropError> {
let uniform_shape = reduced_uniform_shape(variable_shape_uniform_shape(field)?);
collect_variable_shape_complex_norm_rows(field, array, uniform_shape, |tensor_view| {
Ok(crate::linalg::tensor::l2_norm_last_axis_complex_view(tensor_view)?)
})
}
pub fn normalize_last_axis_complex(
field: &Field,
array: &FixedSizeListArray,
) -> Result<(Field, FixedSizeListArray), ArrowInteropError> {
let tensor_view = complex64_fixed_shape_tensor_viewd(field, array)?;
let output = crate::linalg::tensor::normalize_last_axis_complex_view(&tensor_view)?;
complex64_fixed_shape_tensor_from_owned(field.name(), output)
}
pub fn normalize_last_axis_variable_complex(
field: &Field,
array: &StructArray,
) -> Result<(Field, StructArray), ArrowInteropError> {
let uniform_shape = variable_shape_uniform_shape(field)?;
collect_variable_shape_complex_rows(field, array, uniform_shape, |tensor_view| {
Ok(crate::linalg::tensor::normalize_last_axis_complex_view(tensor_view)?)
})
}
pub fn batched_dot_last_axis_complex(
left_field: &Field,
left: &FixedSizeListArray,
right_field: &Field,
right: &FixedSizeListArray,
) -> Result<(Field, FixedSizeListArray), ArrowInteropError> {
let left_view = complex64_fixed_shape_tensor_viewd(left_field, left)?;
let right_view = complex64_fixed_shape_tensor_viewd(right_field, right)?;
let output =
crate::linalg::tensor::batched_dot_last_axis_complex_view(&left_view, &right_view)?;
complex64_fixed_shape_tensor_from_owned(left_field.name(), output)
}
pub fn batched_dot_last_axis_variable_complex(
left_field: &Field,
left: &StructArray,
right_field: &Field,
right: &StructArray,
) -> Result<(Field, StructArray), ArrowInteropError> {
if left.len() != right.len() {
return Err(ArrowInteropError::InvalidShape(format!(
"variable-shape tensor batch row count mismatch: {} vs {}",
left.len(),
right.len()
)));
}
let uniform_shape = reduced_uniform_shape(variable_shape_uniform_shape(left_field)?);
let mut outputs = Vec::with_capacity(left.len());
let mut right_iter = ndarrow::complex64_variable_shape_tensor_iter(right_field, right)?;
for left_row in ndarrow::complex64_variable_shape_tensor_iter(left_field, left)? {
let (_, left_view) = left_row?;
let (_, right_view) = right_iter.next().ok_or_else(|| {
ArrowInteropError::InvalidShape(
"variable-shape tensor batch iterator ended early".to_owned(),
)
})??;
outputs.push(crate::linalg::tensor::batched_dot_last_axis_complex_view(
&left_view,
&right_view,
)?);
}
Ok(ndarrow::arrays_complex64_to_variable_shape_tensor(
left_field.name(),
outputs,
uniform_shape,
)?)
}
pub fn permute_axes_complex(
field: &Field,
array: &FixedSizeListArray,
permutation: &[usize],
) -> Result<(Field, FixedSizeListArray), ArrowInteropError> {
let tensor_view = complex64_fixed_shape_tensor_viewd(field, array)?;
let output = crate::linalg::tensor::permute_axes_complex_view(&tensor_view, permutation)?;
complex64_fixed_shape_tensor_from_owned(field.name(), output)
}
pub fn contract_axes_complex(
left_field: &Field,
left: &FixedSizeListArray,
right_field: &Field,
right: &FixedSizeListArray,
left_axes: &[usize],
right_axes: &[usize],
) -> Result<(Field, FixedSizeListArray), ArrowInteropError> {
let left_view = complex64_fixed_shape_tensor_viewd(left_field, left)?;
let right_view = complex64_fixed_shape_tensor_viewd(right_field, right)?;
let output = crate::linalg::tensor::contract_axes_complex_view(
&left_view,
&right_view,
left_axes,
right_axes,
)?;
complex64_fixed_shape_tensor_from_owned(left_field.name(), output)
}
pub fn batched_matmul_last_two_complex(
left_field: &Field,
left: &FixedSizeListArray,
right_field: &Field,
right: &FixedSizeListArray,
) -> Result<(Field, FixedSizeListArray), ArrowInteropError> {
let left_view = complex64_fixed_shape_tensor_viewd(left_field, left)?;
let right_view = complex64_fixed_shape_tensor_viewd(right_field, right)?;
let output =
crate::linalg::tensor::batched_matmul_last_two_complex_view(&left_view, &right_view)?;
complex64_fixed_shape_tensor_from_owned(left_field.name(), output)
}
pub fn cube_matvec_complex(
cube_field: &Field,
cube: &FixedSizeListArray,
vectors: &FixedSizeListArray,
) -> Result<FixedSizeListArray, ArrowInteropError> {
let cube_view = complex64_fixed_shape_tensor_view3(cube_field, cube)?;
let vectors_view = complex64_matrix_view(vectors)?;
let output = crate::linalg::tensor::cube_matvec_complex_view(&cube_view, &vectors_view)?;
complex64_matrix_from_owned(output)
}
pub fn cube_matmat_complex(
left_field: &Field,
left: &FixedSizeListArray,
right_field: &Field,
right: &FixedSizeListArray,
) -> Result<(Field, FixedSizeListArray), ArrowInteropError> {
let left_view = complex64_fixed_shape_tensor_view3(left_field, left)?;
let right_view = complex64_fixed_shape_tensor_view3(right_field, right)?;
let output = crate::linalg::tensor::cube_matmat_complex_view(&left_view, &right_view)?;
complex64_fixed_shape_tensor_from_owned(left_field.name(), output.into_dyn())
}
pub fn einsum<T>(
expression: &str,
left_field: &Field,
left: &FixedSizeListArray,
right_field: &Field,
right: &FixedSizeListArray,
) -> Result<(Field, FixedSizeListArray), ArrowInteropError>
where
T: ArrowPrimitiveType,
T::Native: NabledReal + NdarrowElement,
{
let left_view = fixed_shape_tensor_viewd::<T>(left_field, left)?;
let right_view = fixed_shape_tensor_viewd::<T>(right_field, right)?;
let output = crate::linalg::tensor::einsum_view(expression, &left_view, &right_view)?;
fixed_shape_tensor_from_owned::<T>(left_field.name(), output)
}
pub fn einsum_complex(
expression: &str,
left_field: &Field,
left: &FixedSizeListArray,
right_field: &Field,
right: &FixedSizeListArray,
) -> Result<(Field, FixedSizeListArray), ArrowInteropError> {
let left_view = complex64_fixed_shape_tensor_viewd(left_field, left)?;
let right_view = complex64_fixed_shape_tensor_viewd(right_field, right)?;
let output = crate::linalg::tensor::einsum_complex_view(expression, &left_view, &right_view)?;
complex64_fixed_shape_tensor_from_owned(left_field.name(), output)
}
pub fn cp_als3<T>(
field: &Field,
array: &FixedSizeListArray,
rank: usize,
config: &crate::linalg::tensor::CpAlsConfig<T::Native>,
) -> Result<crate::linalg::tensor::CpAls3Result<T::Native>, ArrowInteropError>
where
T: ArrowPrimitiveType,
T::Native: crate::linalg::tensor::CpAlsScalar + NdarrowElement,
{
let cube_view = fixed_shape_tensor_view3::<T>(field, array)?;
Ok(crate::linalg::tensor::cp_als3_view(&cube_view, rank, config)?)
}
pub fn cp_als3_with_report<T>(
field: &Field,
array: &FixedSizeListArray,
rank: usize,
config: &crate::linalg::tensor::CpAlsConfig<T::Native>,
) -> Result<ArrowCpAls3WithReport<T::Native>, ArrowInteropError>
where
T: ArrowPrimitiveType,
T::Native: crate::linalg::tensor::CpAlsScalar + NdarrowElement,
{
let cube_view = fixed_shape_tensor_view3::<T>(field, array)?;
Ok(crate::linalg::tensor::cp_als3_view_with_report(&cube_view, rank, config)?)
}
/// Computes error metrics for a three-way CP-ALS `result` against the original tensor
/// column, by borrowing the result's weights and factors as views.
///
/// # Errors
///
/// Forwards any error from `cp_als3_diagnostics_from_factors_view`.
pub fn cp_als3_diagnostics<T>(
    field: &Field,
    array: &FixedSizeListArray,
    result: &crate::linalg::tensor::CpAls3Result<T::Native>,
) -> Result<crate::linalg::tensor::CpErrorMetrics<T::Native>, ArrowInteropError>
where
    T: ArrowPrimitiveType,
    T::Native: NabledReal + NdarrowElement,
{
    let weights = result.weights.view();
    let factor_0 = result.factor_0.view();
    let factor_1 = result.factor_1.view();
    let factor_2 = result.factor_2.view();
    cp_als3_diagnostics_from_factors_view::<T>(
        field, array, &weights, &factor_0, &factor_1, &factor_2,
    )
}
pub fn cp_als3_diagnostics_from_factors_view<T>(
field: &Field,
array: &FixedSizeListArray,
weights: &ArrayView1<'_, T::Native>,
factor_0: &ArrayView2<'_, T::Native>,
factor_1: &ArrayView2<'_, T::Native>,
factor_2: &ArrayView2<'_, T::Native>,
) -> Result<crate::linalg::tensor::CpErrorMetrics<T::Native>, ArrowInteropError>
where
T: ArrowPrimitiveType,
T::Native: NabledReal + NdarrowElement,
{
let cube_view = fixed_shape_tensor_view3::<T>(field, array)?;
Ok(crate::linalg::tensor::cp_als3_diagnostics_from_factors_view(
&cube_view, weights, factor_0, factor_1, factor_2,
)?)
}
/// Rebuilds a dense tensor column named `field_name` from a three-way CP-ALS `result`.
///
/// # Errors
///
/// Forwards any error from `cp_als3_reconstruct_from_factors_view`.
pub fn cp_als3_reconstruct<T>(
    field_name: &str,
    result: &crate::linalg::tensor::CpAls3Result<T::Native>,
) -> Result<(Field, FixedSizeListArray), ArrowInteropError>
where
    T: ArrowPrimitiveType,
    T::Native: NabledReal + NdarrowElement,
{
    let weights = result.weights.view();
    let factor_0 = result.factor_0.view();
    let factor_1 = result.factor_1.view();
    let factor_2 = result.factor_2.view();
    cp_als3_reconstruct_from_factors_view::<T>(field_name, &weights, &factor_0, &factor_1, &factor_2)
}
pub fn cp_als3_reconstruct_from_factors_view<T>(
field_name: &str,
weights: &ArrayView1<'_, T::Native>,
factor_0: &ArrayView2<'_, T::Native>,
factor_1: &ArrayView2<'_, T::Native>,
factor_2: &ArrayView2<'_, T::Native>,
) -> Result<(Field, FixedSizeListArray), ArrowInteropError>
where
T: ArrowPrimitiveType,
T::Native: NabledReal + NdarrowElement,
{
let output = crate::linalg::tensor::cp_als3_reconstruct_from_factors_view(
weights, factor_0, factor_1, factor_2,
)?;
fixed_shape_tensor_from_owned::<T>(field_name, output.into_dyn())
}
pub fn cp_als_nd<T>(
field: &Field,
array: &FixedSizeListArray,
rank: usize,
config: &crate::linalg::tensor::CpAlsConfig<T::Native>,
) -> Result<crate::linalg::tensor::CpAlsNdResult<T::Native>, ArrowInteropError>
where
T: ArrowPrimitiveType,
T::Native: crate::linalg::tensor::CpAlsScalar + NdarrowElement,
{
let tensor_view = fixed_shape_tensor_viewd::<T>(field, array)?;
Ok(crate::linalg::tensor::cp_als_nd_view(&tensor_view, rank, config)?)
}
pub fn cp_als_nd_with_report<T>(
field: &Field,
array: &FixedSizeListArray,
rank: usize,
config: &crate::linalg::tensor::CpAlsConfig<T::Native>,
) -> Result<ArrowCpAlsNdWithReport<T::Native>, ArrowInteropError>
where
T: ArrowPrimitiveType,
T::Native: crate::linalg::tensor::CpAlsScalar + NdarrowElement,
{
let tensor_view = fixed_shape_tensor_viewd::<T>(field, array)?;
Ok(crate::linalg::tensor::cp_als_nd_view_with_report(&tensor_view, rank, config)?)
}
/// Computes error metrics for an N-dimensional CP-ALS `result` against the original tensor
/// column, by borrowing the result's weights and factors as views.
///
/// # Errors
///
/// Forwards any error from `cp_als_nd_diagnostics_from_factors_view`.
pub fn cp_als_nd_diagnostics<T>(
    field: &Field,
    array: &FixedSizeListArray,
    result: &crate::linalg::tensor::CpAlsNdResult<T::Native>,
) -> Result<crate::linalg::tensor::CpErrorMetrics<T::Native>, ArrowInteropError>
where
    T: ArrowPrimitiveType,
    T::Native: NabledReal + NdarrowElement,
{
    let weights = result.weights.view();
    let factor_views: Vec<_> = result.factors.iter().map(|factor| factor.view()).collect();
    cp_als_nd_diagnostics_from_factors_view::<T>(field, array, &weights, &factor_views)
}
pub fn cp_als_nd_diagnostics_from_factors_view<T>(
field: &Field,
array: &FixedSizeListArray,
weights: &ArrayView1<'_, T::Native>,
factors: &[ArrayView2<'_, T::Native>],
) -> Result<crate::linalg::tensor::CpErrorMetrics<T::Native>, ArrowInteropError>
where
T: ArrowPrimitiveType,
T::Native: NabledReal + NdarrowElement,
{
let tensor_view = fixed_shape_tensor_viewd::<T>(field, array)?;
Ok(crate::linalg::tensor::cp_als_nd_diagnostics_from_factors_view(
&tensor_view,
weights,
factors,
)?)
}
/// Rebuilds a dense tensor column named `field_name` from an N-dimensional CP-ALS `result`.
///
/// # Errors
///
/// Forwards any error from `cp_als_nd_reconstruct_from_factors_view`.
pub fn cp_als_nd_reconstruct<T>(
    field_name: &str,
    result: &crate::linalg::tensor::CpAlsNdResult<T::Native>,
) -> Result<(Field, FixedSizeListArray), ArrowInteropError>
where
    T: ArrowPrimitiveType,
    T::Native: NabledReal + NdarrowElement,
{
    let weights = result.weights.view();
    let factor_views: Vec<_> = result.factors.iter().map(|factor| factor.view()).collect();
    cp_als_nd_reconstruct_from_factors_view::<T>(field_name, &weights, &factor_views)
}
pub fn cp_als_nd_reconstruct_from_factors_view<T>(
field_name: &str,
weights: &ArrayView1<'_, T::Native>,
factors: &[ArrayView2<'_, T::Native>],
) -> Result<(Field, FixedSizeListArray), ArrowInteropError>
where
T: ArrowPrimitiveType,
T::Native: NabledReal + NdarrowElement,
{
let output = crate::linalg::tensor::cp_als_nd_reconstruct_from_factors_view(weights, factors)?;
fixed_shape_tensor_from_owned::<T>(field_name, output)
}
pub fn hosvd_nd<T>(
field: &Field,
array: &FixedSizeListArray,
ranks: &[usize],
) -> Result<crate::linalg::tensor::HosvdNdResult<T::Native>, ArrowInteropError>
where
T: ArrowPrimitiveType,
T::Native: crate::linalg::tensor::HosvdNdScalar + NdarrowElement,
{
let tensor_view = fixed_shape_tensor_viewd::<T>(field, array)?;
Ok(crate::linalg::tensor::hosvd_nd_view(&tensor_view, ranks)?)
}
pub fn hooi_nd<T>(
field: &Field,
array: &FixedSizeListArray,
ranks: &[usize],
config: &crate::linalg::tensor::HooiConfig<T::Native>,
) -> Result<crate::linalg::tensor::HosvdNdResult<T::Native>, ArrowInteropError>
where
T: ArrowPrimitiveType,
T::Native: crate::linalg::tensor::HooiNdScalar + NdarrowElement,
{
let tensor_view = fixed_shape_tensor_viewd::<T>(field, array)?;
Ok(crate::linalg::tensor::hooi_nd_view(&tensor_view, ranks, config)?)
}
pub fn tucker_project<T>(
field: &Field,
array: &FixedSizeListArray,
factors: &[ndarray::Array2<T::Native>],
) -> Result<(Field, FixedSizeListArray), ArrowInteropError>
where
T: ArrowPrimitiveType,
T::Native: NabledReal + NdarrowElement,
{
tucker_project_from_factors_view::<T>(
field,
array,
&factors.iter().map(|factor| factor.view()).collect::<Vec<_>>(),
)
}
pub fn tucker_project_from_factors_view<T>(
field: &Field,
array: &FixedSizeListArray,
factors: &[ArrayView2<'_, T::Native>],
) -> Result<(Field, FixedSizeListArray), ArrowInteropError>
where
T: ArrowPrimitiveType,
T::Native: NabledReal + NdarrowElement,
{
let tensor_view = fixed_shape_tensor_viewd::<T>(field, array)?;
let output = crate::linalg::tensor::tucker_project_from_factors_view(&tensor_view, factors)?;
fixed_shape_tensor_from_owned::<T>(field.name(), output)
}
pub fn tucker_expand<T>(
field: &Field,
array: &FixedSizeListArray,
factors: &[ndarray::Array2<T::Native>],
) -> Result<(Field, FixedSizeListArray), ArrowInteropError>
where
T: ArrowPrimitiveType,
T::Native: NabledReal + NdarrowElement,
{
tucker_expand_from_factors_view::<T>(
field,
array,
&factors.iter().map(|factor| factor.view()).collect::<Vec<_>>(),
)
}
pub fn tucker_expand_from_factors_view<T>(
field: &Field,
array: &FixedSizeListArray,
factors: &[ArrayView2<'_, T::Native>],
) -> Result<(Field, FixedSizeListArray), ArrowInteropError>
where
T: ArrowPrimitiveType,
T::Native: NabledReal + NdarrowElement,
{
let tensor_view = fixed_shape_tensor_viewd::<T>(field, array)?;
let output = crate::linalg::tensor::tucker_expand_from_factors_view(&tensor_view, factors)?;
fixed_shape_tensor_from_owned::<T>(field.name(), output)
}
/// Rebuilds a dense tensor column named `field_name` from a HOSVD/HOOI `result`, by
/// borrowing its core and factors as views.
///
/// # Errors
///
/// Forwards any error from `hosvd_nd_reconstruct_from_factors_view`.
pub fn hosvd_nd_reconstruct<T>(
    field_name: &str,
    result: &crate::linalg::tensor::HosvdNdResult<T::Native>,
) -> Result<(Field, FixedSizeListArray), ArrowInteropError>
where
    T: ArrowPrimitiveType,
    T::Native: NabledReal + NdarrowElement,
{
    let core = result.core.view();
    let factor_views: Vec<_> = result.factors.iter().map(|factor| factor.view()).collect();
    hosvd_nd_reconstruct_from_factors_view::<T>(field_name, &core, &factor_views)
}
pub fn hosvd_nd_reconstruct_from_factors_view<T>(
field_name: &str,
core: &ArrayViewD<'_, T::Native>,
factors: &[ArrayView2<'_, T::Native>],
) -> Result<(Field, FixedSizeListArray), ArrowInteropError>
where
T: ArrowPrimitiveType,
T::Native: NabledReal + NdarrowElement,
{
let output = crate::linalg::tensor::hosvd_nd_reconstruct_from_factors_view(core, factors)?;
fixed_shape_tensor_from_owned::<T>(field_name, output)
}
pub fn tt_svd<T>(
field: &Field,
array: &FixedSizeListArray,
config: &crate::linalg::tensor::TtSvdConfig<T::Native>,
) -> Result<crate::linalg::tensor::TensorTrainResult<T::Native>, ArrowInteropError>
where
T: ArrowPrimitiveType,
T::Native: crate::linalg::tensor::TtSvdScalar + NdarrowElement,
{
let tensor_view = fixed_shape_tensor_viewd::<T>(field, array)?;
Ok(crate::linalg::tensor::tt_svd_view(&tensor_view, config)?)
}
pub fn tt_orthogonalize_left<T: crate::linalg::tensor::TtSvdScalar>(
result: &crate::linalg::tensor::TensorTrainResult<T>,
) -> Result<crate::linalg::tensor::TensorTrainResult<T>, ArrowInteropError> {
Ok(crate::linalg::tensor::tt_orthogonalize_left(result)?)
}
pub fn tt_orthogonalize_left_from_cores_view<T: crate::linalg::tensor::TtSvdScalar>(
cores: &[ArrayView3<'_, T>],
) -> Result<crate::linalg::tensor::TensorTrainResult<T>, ArrowInteropError> {
Ok(crate::linalg::tensor::tt_orthogonalize_left_from_cores_view(cores)?)
}
pub fn tt_orthogonalize_right<T: crate::linalg::tensor::TtSvdScalar>(
result: &crate::linalg::tensor::TensorTrainResult<T>,
) -> Result<crate::linalg::tensor::TensorTrainResult<T>, ArrowInteropError> {
Ok(crate::linalg::tensor::tt_orthogonalize_right(result)?)
}
pub fn tt_orthogonalize_right_from_cores_view<T: crate::linalg::tensor::TtSvdScalar>(
cores: &[ArrayView3<'_, T>],
) -> Result<crate::linalg::tensor::TensorTrainResult<T>, ArrowInteropError> {
Ok(crate::linalg::tensor::tt_orthogonalize_right_from_cores_view(cores)?)
}
pub fn tt_round<T: crate::linalg::tensor::TtSvdScalar>(
result: &crate::linalg::tensor::TensorTrainResult<T>,
config: &crate::linalg::tensor::TtRoundConfig<T>,
) -> Result<crate::linalg::tensor::TensorTrainResult<T>, ArrowInteropError> {
Ok(crate::linalg::tensor::tt_round(result, config)?)
}
pub fn tt_round_from_cores_view<T: crate::linalg::tensor::TtSvdScalar>(
cores: &[ArrayView3<'_, T>],
config: &crate::linalg::tensor::TtRoundConfig<T>,
) -> Result<crate::linalg::tensor::TensorTrainResult<T>, ArrowInteropError> {
Ok(crate::linalg::tensor::tt_round_from_cores_view(cores, config)?)
}
pub fn tt_inner<T: NabledReal>(
left: &crate::linalg::tensor::TensorTrainResult<T>,
right: &crate::linalg::tensor::TensorTrainResult<T>,
) -> Result<T, ArrowInteropError> {
Ok(crate::linalg::tensor::tt_inner(left, right)?)
}
pub fn tt_inner_from_cores_view<T: NabledReal>(
left: &[ArrayView3<'_, T>],
right: &[ArrayView3<'_, T>],
) -> Result<T, ArrowInteropError> {
Ok(crate::linalg::tensor::tt_inner_from_cores_view(left, right)?)
}
pub fn tt_norm<T: NabledReal>(
result: &crate::linalg::tensor::TensorTrainResult<T>,
) -> Result<T, ArrowInteropError> {
Ok(crate::linalg::tensor::tt_norm(result)?)
}
pub fn tt_norm_from_cores_view<T: NabledReal>(
cores: &[ArrayView3<'_, T>],
) -> Result<T, ArrowInteropError> {
Ok(crate::linalg::tensor::tt_norm_from_cores_view(cores)?)
}
pub fn tt_add<T: NabledReal>(
left: &crate::linalg::tensor::TensorTrainResult<T>,
right: &crate::linalg::tensor::TensorTrainResult<T>,
) -> Result<crate::linalg::tensor::TensorTrainResult<T>, ArrowInteropError> {
Ok(crate::linalg::tensor::tt_add(left, right)?)
}
pub fn tt_add_from_cores_view<T: NabledReal>(
left: &[ArrayView3<'_, T>],
right: &[ArrayView3<'_, T>],
) -> Result<crate::linalg::tensor::TensorTrainResult<T>, ArrowInteropError> {
Ok(crate::linalg::tensor::tt_add_from_cores_view(left, right)?)
}
pub fn tt_hadamard<T: NabledReal>(
left: &crate::linalg::tensor::TensorTrainResult<T>,
right: &crate::linalg::tensor::TensorTrainResult<T>,
) -> Result<crate::linalg::tensor::TensorTrainResult<T>, ArrowInteropError> {
Ok(crate::linalg::tensor::tt_hadamard(left, right)?)
}
pub fn tt_hadamard_from_cores_view<T: NabledReal>(
left: &[ArrayView3<'_, T>],
right: &[ArrayView3<'_, T>],
) -> Result<crate::linalg::tensor::TensorTrainResult<T>, ArrowInteropError> {
Ok(crate::linalg::tensor::tt_hadamard_from_cores_view(left, right)?)
}
pub fn tt_hadamard_round<T: crate::linalg::tensor::TtSvdScalar>(
left: &crate::linalg::tensor::TensorTrainResult<T>,
right: &crate::linalg::tensor::TensorTrainResult<T>,
config: &crate::linalg::tensor::TtRoundConfig<T>,
) -> Result<crate::linalg::tensor::TensorTrainResult<T>, ArrowInteropError> {
Ok(crate::linalg::tensor::tt_hadamard_round(left, right, config)?)
}
pub fn tt_hadamard_round_from_cores_view<T: crate::linalg::tensor::TtSvdScalar>(
left: &[ArrayView3<'_, T>],
right: &[ArrayView3<'_, T>],
config: &crate::linalg::tensor::TtRoundConfig<T>,
) -> Result<crate::linalg::tensor::TensorTrainResult<T>, ArrowInteropError> {
Ok(crate::linalg::tensor::tt_hadamard_round_from_cores_view(left, right, config)?)
}
pub fn tt_svd_reconstruct<T>(
field_name: &str,
result: &crate::linalg::tensor::TensorTrainResult<T::Native>,
) -> Result<(Field, FixedSizeListArray), ArrowInteropError>
where
T: ArrowPrimitiveType,
T::Native: NabledReal + NdarrowElement,
{
let output = crate::linalg::tensor::tt_svd_reconstruct(result)?;
fixed_shape_tensor_from_owned::<T>(field_name, output)
}
pub fn tt_svd_reconstruct_from_cores_view<T>(
field_name: &str,
cores: &[ArrayView3<'_, T::Native>],
) -> Result<(Field, FixedSizeListArray), ArrowInteropError>
where
T: ArrowPrimitiveType,
T::Native: NabledReal + NdarrowElement,
{
let output = crate::linalg::tensor::tt_svd_reconstruct_from_cores_view(cores)?;
fixed_shape_tensor_from_owned::<T>(field_name, output)
}