use crate::array2::Array2;
use crate::array3::Axis3;
use crate::error::{Error, Result};
use crate::numeric::Float;
use crate::rand::SmallRng;
use crate::view2::{ArrayView2, ArrayViewMut2};
use crate::view3::ArrayView3;
use crate::workspace::Workspace;
use pulp::{Arch, Simd, WithSimd};
use rayon::prelude::*;
pub fn dot<T: Float>(x: &[T], y: &[T]) -> Result<T> {
if x.len() != y.len() {
return Err(Error::shape(vec![x.len()], vec![y.len()]));
}
Ok(x.iter().zip(y).map(|(&a, &b)| a * b).sum())
}
pub fn axpy<T: Float>(alpha: T, x: &[T], y: &mut [T]) -> Result<()> {
if x.len() != y.len() {
return Err(Error::shape(vec![x.len()], vec![y.len()]));
}
for (yi, &xi) in y.iter_mut().zip(x) {
*yi += alpha * xi;
}
Ok(())
}
/// Euclidean length `sqrt(sum_i x[i]^2)` of `x` (zero for an empty slice).
pub fn norm_l2<T: Float>(x: &[T]) -> T {
    let sum_of_squares: T = x.iter().map(|&component| component * component).sum();
    sum_of_squares.sqrt()
}
pub fn dot_f32(x: &[f32], y: &[f32]) -> Result<f32> {
if x.len() != y.len() {
return Err(Error::shape(vec![x.len()], vec![y.len()]));
}
Ok(Arch::new().dispatch(DotF32 { x, y }))
}
pub fn dot_f64(x: &[f64], y: &[f64]) -> Result<f64> {
if x.len() != y.len() {
return Err(Error::shape(vec![x.len()], vec![y.len()]));
}
Ok(Arch::new().dispatch(DotF64 { x, y }))
}
pub fn axpy_f32(alpha: f32, x: &[f32], y: &mut [f32]) -> Result<()> {
if x.len() != y.len() {
return Err(Error::shape(vec![x.len()], vec![y.len()]));
}
Arch::new().dispatch(AxpyF32 { alpha, x, y });
Ok(())
}
pub fn axpy_f64(alpha: f64, x: &[f64], y: &mut [f64]) -> Result<()> {
if x.len() != y.len() {
return Err(Error::shape(vec![x.len()], vec![y.len()]));
}
Arch::new().dispatch(AxpyF64 { alpha, x, y });
Ok(())
}
/// Euclidean length of an `f32` slice, computed as `sqrt(dot_f32(x, x))`.
pub fn norm_l2_f32(x: &[f32]) -> f32 {
    // A slice always matches its own length, so the shape check in `dot_f32` cannot fail.
    let squared = dot_f32(x, x).expect("matching input slices are valid");
    squared.sqrt()
}
/// Euclidean length of an `f64` slice, computed as `sqrt(dot_f64(x, x))`.
pub fn norm_l2_f64(x: &[f64]) -> f64 {
    // A slice always matches its own length, so the shape check in `dot_f64` cannot fail.
    let squared = dot_f64(x, x).expect("matching input slices are valid");
    squared.sqrt()
}
/// pulp kernel for the `f32` dot product of two slices of equal length.
///
/// Full SIMD registers are accumulated lane-wise with `mul_add_f32s`. They are reduced
/// horizontally, and then the scalar tail is added. Equal lengths (enforced by `dot_f32`)
/// guarantee that the SIMD heads and scalar tails of `x` and `y` line up.
struct DotF32<'a> {
    x: &'a [f32],
    y: &'a [f32],
}
impl WithSimd for DotF32<'_> {
    type Output = f32;
    fn with_simd<S: Simd>(self, simd: S) -> Self::Output {
        let Self { x, y } = self;
        let (x_lanes, x_rest) = S::as_simd_f32s(x);
        let (y_lanes, y_rest) = S::as_simd_f32s(y);
        let lane_sums = x_lanes
            .iter()
            .zip(y_lanes)
            .fold(simd.splat_f32s(0.0), |acc, (&xv, &yv)| {
                simd.mul_add_f32s(xv, yv, acc)
            });
        x_rest
            .iter()
            .zip(y_rest)
            .fold(simd.reduce_sum_f32s(lane_sums), |total, (&xs, &ys)| {
                total + xs * ys
            })
    }
}
/// pulp kernel for the `f64` dot product of two slices of equal length.
///
/// Full SIMD registers are accumulated lane-wise with `mul_add_f64s`. They are reduced
/// horizontally, and then the scalar tail is added. Equal lengths (enforced by `dot_f64`)
/// guarantee that the SIMD heads and scalar tails of `x` and `y` line up.
struct DotF64<'a> {
    x: &'a [f64],
    y: &'a [f64],
}
impl WithSimd for DotF64<'_> {
    type Output = f64;
    fn with_simd<S: Simd>(self, simd: S) -> Self::Output {
        let Self { x, y } = self;
        let (x_lanes, x_rest) = S::as_simd_f64s(x);
        let (y_lanes, y_rest) = S::as_simd_f64s(y);
        let lane_sums = x_lanes
            .iter()
            .zip(y_lanes)
            .fold(simd.splat_f64s(0.0), |acc, (&xv, &yv)| {
                simd.mul_add_f64s(xv, yv, acc)
            });
        x_rest
            .iter()
            .zip(y_rest)
            .fold(simd.reduce_sum_f64s(lane_sums), |total, (&xs, &ys)| {
                total + xs * ys
            })
    }
}
/// pulp kernel for in-place `y += alpha * x` over `f32` slices of equal length.
///
/// SIMD registers are updated with `mul_add_f32s(alpha, x, y)`. The scalar tail uses
/// `y += alpha * x`.
struct AxpyF32<'a> {
    alpha: f32,
    x: &'a [f32],
    y: &'a mut [f32],
}
impl WithSimd for AxpyF32<'_> {
    type Output = ();
    fn with_simd<S: Simd>(self, simd: S) -> Self::Output {
        let Self { alpha, x, y } = self;
        let (x_lanes, x_rest) = S::as_simd_f32s(x);
        let (y_lanes, y_rest) = S::as_mut_simd_f32s(y);
        let alpha_lanes = simd.splat_f32s(alpha);
        y_lanes
            .iter_mut()
            .zip(x_lanes)
            .for_each(|(dst, &src)| *dst = simd.mul_add_f32s(alpha_lanes, src, *dst));
        y_rest
            .iter_mut()
            .zip(x_rest)
            .for_each(|(dst, &src)| *dst += alpha * src);
    }
}
/// pulp kernel for in-place `y += alpha * x` over `f64` slices of equal length.
///
/// SIMD registers are updated with `mul_add_f64s(alpha, x, y)`. The scalar tail uses
/// `y += alpha * x`.
struct AxpyF64<'a> {
    alpha: f64,
    x: &'a [f64],
    y: &'a mut [f64],
}
impl WithSimd for AxpyF64<'_> {
    type Output = ();
    fn with_simd<S: Simd>(self, simd: S) -> Self::Output {
        let Self { alpha, x, y } = self;
        let (x_lanes, x_rest) = S::as_simd_f64s(x);
        let (y_lanes, y_rest) = S::as_mut_simd_f64s(y);
        let alpha_lanes = simd.splat_f64s(alpha);
        y_lanes
            .iter_mut()
            .zip(x_lanes)
            .for_each(|(dst, &src)| *dst = simd.mul_add_f64s(alpha_lanes, src, *dst));
        y_rest
            .iter_mut()
            .zip(x_rest)
            .for_each(|(dst, &src)| *dst += alpha * src);
    }
}
/// Copies the `rows x cols` sub-matrix of `a` whose top-left corner is `(row, col)`
/// into a new owned array.
///
/// # Errors
/// Returns [`Error::IndexOutOfBounds`] when the requested window does not fit inside `a`.
pub fn pack_block<T: Copy>(
    a: ArrayView2<'_, T>,
    row: usize,
    col: usize,
    rows: usize,
    cols: usize,
) -> Result<Array2<T>> {
    // Short-circuiting keeps the subtractions from underflowing.
    let rows_fit = row <= a.rows() && rows <= a.rows() - row;
    let cols_fit = col <= a.cols() && cols <= a.cols() - col;
    if !(rows_fit && cols_fit) {
        return Err(Error::IndexOutOfBounds);
    }
    Ok(Array2::from_fn([rows, cols], |r, c| a[(row + r, col + c)]))
}
/// Writes `block` into `dst` with the block's top-left element landing at `(row, col)`.
///
/// # Errors
/// Returns [`Error::IndexOutOfBounds`] when `block` does not fit inside `dst` at that
/// offset. Nothing is written in that case.
pub fn unpack_block<T: Copy>(
    block: ArrayView2<'_, T>,
    mut dst: ArrayViewMut2<'_, T>,
    row: usize,
    col: usize,
) -> Result<()> {
    let rows_fit = row <= dst.rows() && block.rows() <= dst.rows() - row;
    let cols_fit = col <= dst.cols() && block.cols() <= dst.cols() - col;
    if !(rows_fit && cols_fit) {
        return Err(Error::IndexOutOfBounds);
    }
    let positions = (0..block.rows()).flat_map(|r| (0..block.cols()).map(move |c| (r, c)));
    for (r, c) in positions {
        dst[(row + r, col + c)] = block[(r, c)];
    }
    Ok(())
}
/// General matrix multiply `C := alpha * op(A) * op(B) + beta * C`.
///
/// `op(X)` is `X` or `X^T` depending on `trans_a` / `trans_b`. A fresh scratch
/// [`Workspace`] is allocated for this call; use [`gemm_with_workspace`] to reuse one.
///
/// # Errors
/// Returns a shape error when the operand shapes are incompatible.
pub fn gemm<T: Float>(
    alpha: T,
    a: ArrayView2<'_, T>,
    trans_a: bool,
    b: ArrayView2<'_, T>,
    trans_b: bool,
    beta: T,
    c: ArrayViewMut2<'_, T>,
) -> Result<()> {
    gemm_with_workspace(
        alpha,
        a,
        trans_a,
        b,
        trans_b,
        beta,
        c,
        &mut Workspace::new(),
    )
}
/// Same as [`gemm`], but packs operand tiles into caller-provided scratch buffers.
///
/// # Errors
/// Returns a shape error when the operand shapes are incompatible.
// Eight parameters: the seven BLAS-style gemm operands plus the reusable workspace.
#[allow(clippy::too_many_arguments)]
pub fn gemm_with_workspace<T: Float>(
    alpha: T,
    a: ArrayView2<'_, T>,
    trans_a: bool,
    b: ArrayView2<'_, T>,
    trans_b: bool,
    beta: T,
    c: ArrayViewMut2<'_, T>,
    workspace: &mut Workspace<T>,
) -> Result<()> {
    // Edge length of the square cache blocks handed to the blocked kernel.
    const BLOCK_SIZE: usize = 32;
    let spec = GemmBlocked {
        alpha,
        a,
        trans_a,
        b,
        trans_b,
        beta,
        c,
        block_size: BLOCK_SIZE,
    };
    gemm_blocked_workspace(spec, workspace)
}
/// Argument bundle for [`gemm_blocked_workspace`], describing
/// `C := alpha * op(A) * op(B) + beta * C`, where `op(X)` is `X` or `X^T` per the
/// matching `trans_*` flag.
struct GemmBlocked<'a, 'b, 'c, T> {
    /// Scale applied to the product `op(A) * op(B)`.
    alpha: T,
    /// Left operand `A` (before the optional transpose).
    a: ArrayView2<'a, T>,
    /// Use `A^T` instead of `A` when true.
    trans_a: bool,
    /// Right operand `B` (before the optional transpose).
    b: ArrayView2<'b, T>,
    /// Use `B^T` instead of `B` when true.
    trans_b: bool,
    /// Scale applied to the existing contents of `C`.
    beta: T,
    /// Output matrix, updated in place.
    c: ArrayViewMut2<'c, T>,
    /// Edge length of the square cache blocks (the kernel clamps it to at least 1).
    block_size: usize,
}
/// Blocked kernel behind [`gemm`]: computes `C := alpha * op(A) * op(B) + beta * C`.
///
/// `op(A)` is `m x k` and `op(B)` is `k x n`, where `op(X)` is `X` or `X^T` according to the
/// matching `trans_*` flag.
///
/// The `m x k x n` iteration space is tiled into blocks of `block_size` (clamped to at
/// least 1). Each A tile is packed row-major into workspace buffer 0, once per `(i0, p0)`
/// pair, and each B tile into buffer 1. The product is then accumulated four rows by four
/// columns at a time by [`microkernel_4x4`].
///
/// Following the reference-BLAS convention, `beta == 0` makes `C` write-only. Its previous
/// contents, including NaN or infinity, are overwritten with zero rather than multiplied
/// by zero (which would keep a NaN). `beta == 1` leaves `C` untouched before
/// accumulation.
///
/// # Errors
/// Returns a shape error when the inner dimensions of `op(A)` and `op(B)` differ or when
/// `C` is not `m x n`.
fn gemm_blocked_workspace<T: Float>(
    spec: GemmBlocked<'_, '_, '_, T>,
    workspace: &mut Workspace<T>,
) -> Result<()> {
    let GemmBlocked {
        alpha,
        a,
        trans_a,
        b,
        trans_b,
        beta,
        mut c,
        block_size,
    } = spec;
    let (m, k_a) = if trans_a {
        (a.cols(), a.rows())
    } else {
        (a.rows(), a.cols())
    };
    let (k_b, n) = if trans_b {
        (b.cols(), b.rows())
    } else {
        (b.rows(), b.cols())
    };
    if k_a != k_b {
        return Err(Error::shape(vec![m, k_a], vec![k_b, n]));
    }
    if c.shape() != [m, n] {
        return Err(Error::shape(vec![m, n], c.shape()));
    }
    let block = block_size.max(1);
    // Apply beta up front; the blocked loops below only ever add alpha * op(A) * op(B).
    if beta != T::one() {
        let zero = T::zero();
        for i in 0..m {
            for j in 0..n {
                if beta == zero {
                    // BLAS semantics: C need not be initialised when beta == 0.
                    c[(i, j)] = zero;
                } else {
                    c[(i, j)] *= beta;
                }
            }
        }
    }
    for i0 in (0..m).step_by(block) {
        let ib = block.min(m - i0);
        for p0 in (0..k_a).step_by(block) {
            let pb = block.min(k_a - p0);
            let (a_buffer, b_buffer) = workspace.two_buffers_mut(0, 1);
            // The packed A tile depends only on (i0, p0), so pack it once and reuse it for
            // every column block instead of repacking it inside the j0 loop.
            let a_block = a_buffer.zeros(ib * pb);
            pack_op_block_into(a, trans_a, i0, p0, ib, pb, a_block);
            for j0 in (0..n).step_by(block) {
                let jb = block.min(n - j0);
                let b_block = b_buffer.zeros(pb * jb);
                pack_op_block_into(b, trans_b, p0, j0, pb, jb, b_block);
                for i in (0..ib).step_by(4) {
                    for j in (0..jb).step_by(4) {
                        let rows = 4.min(ib - i);
                        let cols = 4.min(jb - j);
                        microkernel_4x4(
                            alpha,
                            PackedBlock {
                                data: &a_block[i * pb..],
                                rows,
                                cols: pb,
                            },
                            PackedBlock {
                                data: &b_block[j..],
                                rows: pb,
                                cols: jb,
                            },
                            &mut c,
                            [i0 + i, j0 + j],
                            cols,
                        );
                    }
                }
            }
        }
    }
    Ok(())
}
/// Packs the `rows x cols` window of `op(a)` starting at `(row, col)` into `out`, row-major
/// with row stride `cols`.
///
/// `op(a)` is `a` when `trans` is false and `a^T` otherwise. In the transposed case,
/// element `(i, j)` of the window is read from `a[(col + j, row + i)]`.
fn pack_op_block_into<T: Float>(
    a: ArrayView2<'_, T>,
    trans: bool,
    row: usize,
    col: usize,
    rows: usize,
    cols: usize,
    out: &mut [T],
) {
    let source = |r: usize, c: usize| {
        if trans {
            a[(col + c, row + r)]
        } else {
            a[(row + r, col + c)]
        }
    };
    for r in 0..rows {
        let base = r * cols;
        for c in 0..cols {
            out[base + c] = source(r, c);
        }
    }
}
/// Borrowed view of a row-major packed panel consumed by [`microkernel_4x4`].
struct PackedBlock<'a, T> {
    /// Packed elements; the consumer reads element `(r, c)` at index `r * cols + c`.
    data: &'a [T],
    /// Number of panel rows the consumer should read.
    rows: usize,
    /// Row stride of `data` (the width of the packed panel).
    cols: usize,
}
/// Accumulates `alpha * A_panel * B_panel` into a tile of at most 4x4 entries of `c`.
///
/// * `a`: packed A panel. Row `r` starts at `r * a.cols`, and `a.cols` is the inner
///   dimension the loop runs over. Only `a.rows` rows are used, and at most 4.
/// * `b`: packed B panel. The entry for inner index `p` and tile column `col` is read at
///   `p * b.cols + col`, so `b.data` must start at the tile's first column.
/// * `c_origin`: `[row, col]` of the tile's top-left element inside `c`.
/// * `c_cols`: number of live tile columns. Columns at or beyond it read as zero from `b`
///   and are not written to `c`. It must be at most 4, because `accumulate_tile` indexes a
///   4-element array and panics beyond that.
fn microkernel_4x4<T: Float>(
    alpha: T,
    a: PackedBlock<'_, T>,
    b: PackedBlock<'_, T>,
    c: &mut ArrayViewMut2<'_, T>,
    c_origin: [usize; 2],
    c_cols: usize,
) {
    // One scalar accumulator per tile element (c<row><col>). They are plain locals,
    // presumably so the optimiser can keep them in registers.
    let mut c00 = T::zero();
    let mut c01 = T::zero();
    let mut c02 = T::zero();
    let mut c03 = T::zero();
    let mut c10 = T::zero();
    let mut c11 = T::zero();
    let mut c12 = T::zero();
    let mut c13 = T::zero();
    let mut c20 = T::zero();
    let mut c21 = T::zero();
    let mut c22 = T::zero();
    let mut c23 = T::zero();
    let mut c30 = T::zero();
    let mut c31 = T::zero();
    let mut c32 = T::zero();
    let mut c33 = T::zero();
    for p in 0..a.cols {
        // Load one row of the B panel. Columns past `c_cols` are padded with zero so
        // the unrolled updates below need no per-column branches.
        let b0 = b.data[p * b.cols];
        let b1 = if c_cols > 1 {
            b.data[p * b.cols + 1]
        } else {
            T::zero()
        };
        let b2 = if c_cols > 2 {
            b.data[p * b.cols + 2]
        } else {
            T::zero()
        };
        let b3 = if c_cols > 3 {
            b.data[p * b.cols + 3]
        } else {
            T::zero()
        };
        // Rank-1 update: tile row r gains a[r][p] * b_row. Missing rows are skipped.
        let a0 = a.data[p];
        c00 += a0 * b0;
        c01 += a0 * b1;
        c02 += a0 * b2;
        c03 += a0 * b3;
        if a.rows > 1 {
            let a1 = a.data[a.cols + p];
            c10 += a1 * b0;
            c11 += a1 * b1;
            c12 += a1 * b2;
            c13 += a1 * b3;
        }
        if a.rows > 2 {
            let a2 = a.data[2 * a.cols + p];
            c20 += a2 * b0;
            c21 += a2 * b1;
            c22 += a2 * b2;
            c23 += a2 * b3;
        }
        if a.rows > 3 {
            let a3 = a.data[3 * a.cols + p];
            c30 += a3 * b0;
            c31 += a3 * b1;
            c32 += a3 * b2;
            c33 += a3 * b3;
        }
    }
    // Scale by alpha and add into C, writing only the live rows and `c_cols` columns.
    accumulate_tile(alpha, c, c_origin, 0, &[c00, c01, c02, c03], c_cols);
    if a.rows > 1 {
        accumulate_tile(alpha, c, c_origin, 1, &[c10, c11, c12, c13], c_cols);
    }
    if a.rows > 2 {
        accumulate_tile(alpha, c, c_origin, 2, &[c20, c21, c22, c23], c_cols);
    }
    if a.rows > 3 {
        accumulate_tile(alpha, c, c_origin, 3, &[c30, c31, c32, c33], c_cols);
    }
}
/// Adds `alpha * values[k]` to `c[(origin[0] + row, origin[1] + k)]` for `k in 0..cols`.
///
/// Panics if `cols > 4`, since `values` holds exactly four entries.
fn accumulate_tile<T: Float>(
    alpha: T,
    c: &mut ArrayViewMut2<'_, T>,
    origin: [usize; 2],
    row: usize,
    values: &[T; 4],
    cols: usize,
) {
    let [row0, col0] = origin;
    let target_row = row0 + row;
    (0..cols).for_each(|k| c[(target_row, col0 + k)] += alpha * values[k]);
}
pub fn matmul<T: Float>(a: ArrayView2<'_, T>, b: ArrayView2<'_, T>) -> Result<Array2<T>> {
if a.cols() != b.rows() {
return Err(Error::shape(a.shape(), b.shape()));
}
let mut c = Array2::zeros([a.rows(), b.cols()]);
gemm(T::one(), a, false, b, false, T::zero(), c.view_mut())?;
Ok(c)
}
pub trait LinearOperator<T: Float> {
fn rows(&self) -> usize;
fn cols(&self) -> usize;
fn matvec(&self, x: &[T], y: &mut [T]) -> Result<()>;
fn t_matvec(&self, x: &[T], y: &mut [T]) -> Result<()>;
fn matmat(&self, x: ArrayView2<'_, T>, mut y: ArrayViewMut2<'_, T>) -> Result<()> {
if x.rows() != self.cols() || y.shape() != [self.rows(), x.cols()] {
return Err(Error::shape(vec![self.cols(), x.cols()], x.shape()));
}
for col in 0..x.cols() {
let mut input = vec![T::zero(); x.rows()];
let mut output = vec![T::zero(); y.rows()];
for row in 0..x.rows() {
input[row] = x[(row, col)];
}
self.matvec(&input, &mut output)?;
for row in 0..y.rows() {
y[(row, col)] = output[row];
}
}
Ok(())
}
fn t_matmat(&self, x: ArrayView2<'_, T>, mut y: ArrayViewMut2<'_, T>) -> Result<()> {
if x.rows() != self.rows() || y.shape() != [self.cols(), x.cols()] {
return Err(Error::shape(vec![self.rows(), x.cols()], x.shape()));
}
for col in 0..x.cols() {
let mut input = vec![T::zero(); x.rows()];
let mut output = vec![T::zero(); y.rows()];
for row in 0..x.rows() {
input[row] = x[(row, col)];
}
self.t_matvec(&input, &mut output)?;
for row in 0..y.rows() {
y[(row, col)] = output[row];
}
}
Ok(())
}
}
/// Owned dense matrices act as linear operators directly. Block products use [`gemm`].
impl<T: Float> LinearOperator<T> for Array2<T> {
    fn rows(&self) -> usize {
        self.rows()
    }
    fn cols(&self) -> usize {
        self.cols()
    }
    /// `y = A x`, computed as one row dot product per output entry.
    fn matvec(&self, x: &[T], y: &mut [T]) -> Result<()> {
        let (n_rows, n_cols) = (self.rows(), self.cols());
        if x.len() != n_cols || y.len() != n_rows {
            return Err(Error::shape(
                vec![n_cols, n_rows],
                vec![x.len(), y.len()],
            ));
        }
        for (i, out) in y.iter_mut().enumerate() {
            let mut acc = T::zero();
            for (j, &xj) in x.iter().enumerate() {
                acc += self[(i, j)] * xj;
            }
            *out = acc;
        }
        Ok(())
    }
    /// `y = A^T x`, accumulated row by row into a zeroed `y`.
    fn t_matvec(&self, x: &[T], y: &mut [T]) -> Result<()> {
        let (n_rows, n_cols) = (self.rows(), self.cols());
        if x.len() != n_rows || y.len() != n_cols {
            return Err(Error::shape(
                vec![n_rows, n_cols],
                vec![x.len(), y.len()],
            ));
        }
        y.fill(T::zero());
        for (i, &xi) in x.iter().enumerate() {
            for (j, out) in y.iter_mut().enumerate() {
                *out += self[(i, j)] * xi;
            }
        }
        Ok(())
    }
    /// `Y = A X` through the blocked gemm kernel.
    fn matmat(&self, x: ArrayView2<'_, T>, y: ArrayViewMut2<'_, T>) -> Result<()> {
        let (one, zero) = (T::one(), T::zero());
        gemm(one, self.view(), false, x, false, zero, y)
    }
    /// `Y = A^T X` through the blocked gemm kernel.
    fn t_matmat(&self, x: ArrayView2<'_, T>, y: ArrayViewMut2<'_, T>) -> Result<()> {
        let (one, zero) = (T::one(), T::zero());
        gemm(one, self.view(), true, x, false, zero, y)
    }
}
/// Borrowed dense matrices act as linear operators directly. Block products use [`gemm`].
impl<T: Float> LinearOperator<T> for ArrayView2<'_, T> {
    fn rows(&self) -> usize {
        self.rows()
    }
    fn cols(&self) -> usize {
        self.cols()
    }
    /// `y = A x`, computed as one row dot product per output entry.
    fn matvec(&self, x: &[T], y: &mut [T]) -> Result<()> {
        let (n_rows, n_cols) = (self.rows(), self.cols());
        if x.len() != n_cols || y.len() != n_rows {
            return Err(Error::shape(
                vec![n_cols, n_rows],
                vec![x.len(), y.len()],
            ));
        }
        for (i, out) in y.iter_mut().enumerate() {
            let mut acc = T::zero();
            for (j, &xj) in x.iter().enumerate() {
                acc += self[(i, j)] * xj;
            }
            *out = acc;
        }
        Ok(())
    }
    /// `y = A^T x`, accumulated row by row into a zeroed `y`.
    fn t_matvec(&self, x: &[T], y: &mut [T]) -> Result<()> {
        let (n_rows, n_cols) = (self.rows(), self.cols());
        if x.len() != n_rows || y.len() != n_cols {
            return Err(Error::shape(
                vec![n_rows, n_cols],
                vec![x.len(), y.len()],
            ));
        }
        y.fill(T::zero());
        for (i, &xi) in x.iter().enumerate() {
            for (j, out) in y.iter_mut().enumerate() {
                *out += self[(i, j)] * xi;
            }
        }
        Ok(())
    }
    /// `Y = A X` through the blocked gemm kernel.
    fn matmat(&self, x: ArrayView2<'_, T>, y: ArrayViewMut2<'_, T>) -> Result<()> {
        let (one, zero) = (T::one(), T::zero());
        gemm(one, *self, false, x, false, zero, y)
    }
    /// `Y = A^T X` through the blocked gemm kernel.
    fn t_matmat(&self, x: ArrayView2<'_, T>, y: ArrayViewMut2<'_, T>) -> Result<()> {
        let (one, zero) = (T::one(), T::zero());
        gemm(one, *self, true, x, false, zero, y)
    }
}
/// Lazy transpose of a wrapped operator: `Transpose::new(a)` behaves as `A^T`.
#[derive(Clone, Copy, Debug)]
pub struct Transpose<A> {
    inner: A,
}
impl<A> Transpose<A> {
    /// Wraps `inner` without copying or transposing any data.
    pub fn new(inner: A) -> Self {
        Transpose { inner }
    }
}
impl<T: Float, A: LinearOperator<T>> LinearOperator<T> for Transpose<A> {
fn rows(&self) -> usize {
self.inner.cols()
}
fn cols(&self) -> usize {
self.inner.rows()
}
fn matvec(&self, x: &[T], y: &mut [T]) -> Result<()> {
self.inner.t_matvec(x, y)
}
fn t_matvec(&self, x: &[T], y: &mut [T]) -> Result<()> {
self.inner.matvec(x, y)
}
}
/// Operator for `A - 1 * means^T`, i.e. `inner` with `means[j]` subtracted from every entry
/// of column `j`. The centering is applied inside the products, without forming the
/// centered matrix.
#[derive(Clone, Debug)]
pub struct CenteredOperator<A, T> {
    inner: A,
    means: Vec<T>,
}
impl<A, T: Float> CenteredOperator<A, T> {
    /// Wraps `inner`. `means` needs one entry per column, which is checked on each product.
    pub fn new(inner: A, means: Vec<T>) -> Self {
        Self { means, inner }
    }
    /// Column means subtracted by this operator.
    pub fn means(&self) -> &[T] {
        self.means.as_slice()
    }
}
impl<T: Float, A: LinearOperator<T>> LinearOperator<T> for CenteredOperator<A, T> {
fn rows(&self) -> usize {
self.inner.rows()
}
fn cols(&self) -> usize {
self.inner.cols()
}
fn matvec(&self, x: &[T], y: &mut [T]) -> Result<()> {
if self.means.len() != self.cols() {
return Err(Error::shape(vec![self.cols()], vec![self.means.len()]));
}
self.inner.matvec(x, y)?;
let correction: T = self.means.iter().zip(x).map(|(&mean, &xj)| mean * xj).sum();
for yi in y {
*yi -= correction;
}
Ok(())
}
fn t_matvec(&self, x: &[T], y: &mut [T]) -> Result<()> {
if self.means.len() != self.cols() {
return Err(Error::shape(vec![self.cols()], vec![self.means.len()]));
}
self.inner.t_matvec(x, y)?;
let total: T = x.iter().copied().sum();
for (yj, &mean) in y.iter_mut().zip(&self.means) {
*yj -= mean * total;
}
Ok(())
}
}
/// Operator for `A * diag(scales)`, i.e. `inner` with column `j` multiplied by `scales[j]`.
#[derive(Clone, Debug)]
pub struct ColumnScaledOperator<A, T> {
    inner: A,
    scales: Vec<T>,
}
impl<A, T: Float> ColumnScaledOperator<A, T> {
    /// Wraps `inner`. `scales` needs one entry per column, which is checked on each product.
    pub fn new(inner: A, scales: Vec<T>) -> Self {
        Self { scales, inner }
    }
    /// Per-column scale factors.
    pub fn scales(&self) -> &[T] {
        self.scales.as_slice()
    }
}
impl<T: Float, A: LinearOperator<T>> LinearOperator<T> for ColumnScaledOperator<A, T> {
fn rows(&self) -> usize {
self.inner.rows()
}
fn cols(&self) -> usize {
self.inner.cols()
}
fn matvec(&self, x: &[T], y: &mut [T]) -> Result<()> {
if self.scales.len() != self.cols() {
return Err(Error::shape(vec![self.cols()], vec![self.scales.len()]));
}
if x.len() != self.cols() {
return Err(Error::shape(vec![self.cols()], vec![x.len()]));
}
let scaled = x
.iter()
.zip(&self.scales)
.map(|(&value, &scale)| value * scale)
.collect::<Vec<_>>();
self.inner.matvec(&scaled, y)
}
fn t_matvec(&self, x: &[T], y: &mut [T]) -> Result<()> {
if self.scales.len() != self.cols() {
return Err(Error::shape(vec![self.cols()], vec![self.scales.len()]));
}
self.inner.t_matvec(x, y)?;
for (yj, &scale) in y.iter_mut().zip(&self.scales) {
*yj *= scale;
}
Ok(())
}
}
/// Operator for `diag(scales) * A`, i.e. `inner` with row `i` multiplied by `scales[i]`.
#[derive(Clone, Debug)]
pub struct RowScaledOperator<A, T> {
    inner: A,
    scales: Vec<T>,
}
impl<A, T: Float> RowScaledOperator<A, T> {
    /// Wraps `inner`. `scales` needs one entry per row, which is checked on each product.
    pub fn new(inner: A, scales: Vec<T>) -> Self {
        Self { scales, inner }
    }
    /// Per-row scale factors.
    pub fn scales(&self) -> &[T] {
        self.scales.as_slice()
    }
}
impl<T: Float, A: LinearOperator<T>> LinearOperator<T> for RowScaledOperator<A, T> {
fn rows(&self) -> usize {
self.inner.rows()
}
fn cols(&self) -> usize {
self.inner.cols()
}
fn matvec(&self, x: &[T], y: &mut [T]) -> Result<()> {
if self.scales.len() != self.rows() {
return Err(Error::shape(vec![self.rows()], vec![self.scales.len()]));
}
self.inner.matvec(x, y)?;
for (yi, &scale) in y.iter_mut().zip(&self.scales) {
*yi *= scale;
}
Ok(())
}
fn t_matvec(&self, x: &[T], y: &mut [T]) -> Result<()> {
if self.scales.len() != self.rows() {
return Err(Error::shape(vec![self.rows()], vec![self.scales.len()]));
}
if x.len() != self.rows() {
return Err(Error::shape(vec![self.rows()], vec![x.len()]));
}
let scaled = x
.iter()
.zip(&self.scales)
.map(|(&value, &scale)| value * scale)
.collect::<Vec<_>>();
self.inner.t_matvec(&scaled, y)
}
}
/// Operator for `(A - 1 * means^T) * diag(scales)^-1`: column `j` of `inner` is
/// centered by `means[j]` and then divided by `scales[j]`. Nothing is materialised; the
/// transformation is applied inside the products.
#[derive(Clone, Debug)]
pub struct StandardizedOperator<A, T> {
    inner: A,
    means: Vec<T>,
    scales: Vec<T>,
}
impl<A, T: Float> StandardizedOperator<A, T> {
    /// Wraps `inner`. `means` and `scales` need one entry per column and `scales` must be
    /// non-zero. Both conditions are checked on each product.
    pub fn new(inner: A, means: Vec<T>, scales: Vec<T>) -> Self {
        Self {
            scales,
            means,
            inner,
        }
    }
    /// Column means subtracted before scaling.
    pub fn means(&self) -> &[T] {
        self.means.as_slice()
    }
    /// Column divisors applied after centering.
    pub fn scales(&self) -> &[T] {
        self.scales.as_slice()
    }
}
impl<T: Float, A: LinearOperator<T>> LinearOperator<T> for StandardizedOperator<A, T> {
fn rows(&self) -> usize {
self.inner.rows()
}
fn cols(&self) -> usize {
self.inner.cols()
}
fn matvec(&self, x: &[T], y: &mut [T]) -> Result<()> {
validate_standardized_parts(self.cols(), &self.means, &self.scales)?;
if x.len() != self.cols() {
return Err(Error::shape(vec![self.cols()], vec![x.len()]));
}
let scaled = x
.iter()
.zip(&self.scales)
.map(|(&value, &scale)| value / scale)
.collect::<Vec<_>>();
self.inner.matvec(&scaled, y)?;
let correction: T = self
.means
.iter()
.zip(&scaled)
.map(|(&mean, &xj)| mean * xj)
.sum();
for yi in y {
*yi -= correction;
}
Ok(())
}
fn t_matvec(&self, x: &[T], y: &mut [T]) -> Result<()> {
validate_standardized_parts(self.cols(), &self.means, &self.scales)?;
self.inner.t_matvec(x, y)?;
let total: T = x.iter().copied().sum();
for ((yj, &mean), &scale) in y.iter_mut().zip(&self.means).zip(&self.scales) {
*yj = (*yj - mean * total) / scale;
}
Ok(())
}
}
fn validate_standardized_parts<T: Float>(cols: usize, means: &[T], scales: &[T]) -> Result<()> {
if means.len() != cols {
return Err(Error::shape(vec![cols], vec![means.len()]));
}
if scales.len() != cols {
return Err(Error::shape(vec![cols], vec![scales.len()]));
}
if scales.iter().any(|&scale| scale == T::zero()) {
return Err(Error::NumericalFailure("standardization scale is zero"));
}
Ok(())
}
/// Tuning parameters for the randomized SVD routines in this module.
#[derive(Clone, Debug)]
pub struct RandomizedSvdOptions {
    /// Number of singular triplets requested; must lie in `1..=min(rows, cols)`.
    pub rank: usize,
    /// Extra sketch columns beyond `rank` (the sketch width is capped at `min(rows, cols)`).
    pub oversampling: usize,
    /// Number of power iterations applied to the sketch.
    pub power_iterations: usize,
    /// Seed for the Gaussian test matrix; `None` selects a fixed built-in seed.
    pub seed: Option<u64>,
    /// Optional bound on the Frobenius reconstruction error (checked by
    /// `randomized_svd_with_error`).
    pub tolerance: Option<f64>,
    /// Whether to materialise `U`; when false, `u` is returned as a `0 x 0` array.
    pub compute_u: bool,
    /// Whether to materialise `Vt`; when false, `vt` is returned as a `0 x 0` array.
    pub compute_vt: bool,
}
impl Default for RandomizedSvdOptions {
    /// Rank 2, oversampling 8, one power iteration, the default seed, no tolerance, and
    /// both factors computed.
    fn default() -> Self {
        RandomizedSvdOptions {
            compute_u: true,
            compute_vt: true,
            oversampling: 8,
            power_iterations: 1,
            rank: 2,
            seed: None,
            tolerance: None,
        }
    }
}
/// Truncated singular value decomposition `A ≈ U * diag(s) * Vt`.
#[derive(Clone, Debug, PartialEq)]
pub struct SvdResult<T> {
    /// Left singular vectors, one per column. `randomized_svd` leaves this `0 x 0` when
    /// `compute_u` is false.
    pub u: Array2<T>,
    /// Singular values. The producers in this module emit them in descending order.
    pub s: Vec<T>,
    /// Right singular vectors, one per row. `randomized_svd` leaves this `0 x 0` when
    /// `compute_vt` is false.
    pub vt: Array2<T>,
}
/// Eigen-decomposition of a symmetric matrix.
#[derive(Clone, Debug, PartialEq)]
pub struct EighResult<T> {
    /// Eigenvalues; `eigh_small` returns them in descending order.
    pub eigenvalues: Vec<T>,
    /// Eigenvectors stored as columns; column `j` pairs with `eigenvalues[j]`.
    pub eigenvectors: Array2<T>,
}
/// Thin QR factorisation produced by `qr`.
#[derive(Clone, Debug, PartialEq)]
pub struct QrResult<T> {
    /// Orthonormal columns, one per linearly independent input column.
    pub q: Array2<T>,
    /// Coefficients with one row per column of `q` and one column per input column, so
    /// that `q * r` reproduces the input.
    pub r: Array2<T>,
}
/// Thin, rank-revealing QR factorisation by modified Gram–Schmidt with one
/// re-orthogonalisation pass ("twice is enough").
///
/// Input columns are processed left to right. A column is treated as linearly dependent
/// when its residual norm after orthogonalisation is at most `1e-12` times its largest
/// absolute entry. A dependent column adds no column to `q` and no row to `r`, although
/// its projections onto earlier columns are still recorded in `r`. The threshold is
/// relative, so the result does not depend on the overall scale of `a`: tiny but
/// independent columns are kept, and large but numerically dependent columns are dropped.
///
/// # Errors
/// * [`Error::NumericalFailure`] if a residual column norm is NaN or infinite.
/// * [`Error::NumericalFailure`] if `a` has columns but none of them is independent.
pub fn qr<T: Float>(a: ArrayView2<'_, T>) -> Result<QrResult<T>> {
    let relative_tolerance = T::from_f64(1e-12);
    let mut columns: Vec<Vec<T>> = Vec::new();
    let mut r_rows: Vec<Vec<T>> = Vec::new();
    for j in 0..a.cols() {
        let mut column = (0..a.rows()).map(|i| a[(i, j)]).collect::<Vec<_>>();
        // Max-abs scale: cannot overflow, unlike the 2-norm of a column with huge entries.
        let scale = column.iter().fold(T::zero(), |largest, &value| {
            let magnitude = value.abs();
            if magnitude > largest {
                magnitude
            } else {
                largest
            }
        });
        for _ in 0..2 {
            for (idx, prev) in columns.iter().enumerate() {
                let mut projection = T::zero();
                for (&p, &v) in prev.iter().zip(&column) {
                    projection += p * v;
                }
                r_rows[idx][j] += projection;
                for (v, &p) in column.iter_mut().zip(prev) {
                    *v -= projection * p;
                }
            }
        }
        let mut norm_sq = T::zero();
        for &value in &column {
            norm_sq += value * value;
        }
        let norm = norm_sq.sqrt();
        if !norm.is_finite() {
            return Err(Error::NumericalFailure("non-finite QR column norm"));
        }
        // An all-zero column has scale 0 and norm 0, so it is also skipped here.
        if norm <= relative_tolerance * scale {
            continue;
        }
        for value in &mut column {
            *value /= norm;
        }
        columns.push(column);
        let mut r_row = vec![T::zero(); a.cols()];
        r_row[j] = norm;
        r_rows.push(r_row);
    }
    if columns.is_empty() && a.cols() > 0 {
        return Err(Error::NumericalFailure(
            "matrix has no independent QR columns",
        ));
    }
    let q = Array2::from_fn([a.rows(), columns.len()], |i, j| columns[j][i]);
    let r = Array2::from_fn([r_rows.len(), a.cols()], |i, j| r_rows[i][j]);
    Ok(QrResult { q, r })
}
/// Orthonormal factor `Q` of [`qr`], discarding `R`.
///
/// # Errors
/// Propagates the errors of [`qr`].
pub fn thin_qr<T: Float>(a: ArrayView2<'_, T>) -> Result<Array2<T>> {
    qr(a).map(|factorization| factorization.q)
}
/// Re-orthonormalises the columns of `q` (dropping dependent ones) via [`thin_qr`].
///
/// # Errors
/// Propagates the errors of [`thin_qr`].
pub fn reorthogonalize<T: Float>(q: ArrayView2<'_, T>) -> Result<Array2<T>> {
    let orthonormal = thin_qr(q)?;
    Ok(orthonormal)
}
/// Returns an orthonormal basis whose span approximates the range of `a`.
///
/// The procedure multiplies `a` by a seeded Gaussian test matrix with
/// `min(rank + oversampling, rows, cols)` columns. It then runs `power_iterations` rounds
/// of `Y <- A (A^T Q)`, where `Q = thin_qr(Y)`, and returns `thin_qr(Y)`. The basis can
/// have fewer columns than the sketch if `thin_qr` drops dependent ones.
///
/// # Errors
/// [`Error::RankTooLarge`] unless `1 <= rank <= min(rows, cols)`. Operator and QR errors
/// are propagated.
pub fn randomized_range_finder<T: Float, A: LinearOperator<T>>(
    a: &A,
    rank: usize,
    oversampling: usize,
    power_iterations: usize,
    seed: Option<u64>,
) -> Result<Array2<T>> {
    let (rows, cols) = (a.rows(), a.cols());
    let sketch_width = (rank + oversampling).min(cols).min(rows);
    let max_rank = rows.min(cols);
    if rank == 0 || rank > max_rank {
        return Err(Error::RankTooLarge {
            requested: rank,
            max: max_rank,
        });
    }
    let mut rng = SmallRng::new(seed.unwrap_or(0x5eed_1234_abcd_9876));
    let test_matrix = Array2::from_fn([cols, sketch_width], |_, _| rng.normal::<T>());
    let mut sample = Array2::zeros([rows, sketch_width]);
    a.matmat(test_matrix.view(), sample.view_mut())?;
    for _ in 0..power_iterations {
        let basis = thin_qr(sample.view())?;
        let mut pulled_back = Array2::zeros([cols, basis.cols()]);
        a.t_matmat(basis.view(), pulled_back.view_mut())?;
        sample = Array2::zeros([rows, pulled_back.cols()]);
        a.matmat(pulled_back.view(), sample.view_mut())?;
    }
    thin_qr(sample.view())
}
/// Randomized truncated SVD of a linear operator.
///
/// The steps are:
/// 1. Find an orthonormal range basis `Q` with [`randomized_range_finder`].
/// 2. Form the small matrix `B = Q^T A` (via `A^T Q`).
/// 3. Decompose `B` with [`svd_small`].
/// 4. Lift the left factor back with `U = Q * U_B`.
///
/// At most `options.rank` triplets are returned. `u` / `vt` are `0 x 0` when not
/// requested.
///
/// # Errors
/// [`Error::RankTooLarge`] unless `1 <= options.rank <= min(rows, cols)`. Errors from the
/// range finder, operator products, and small SVD are propagated.
pub fn randomized_svd<T: Float, A: LinearOperator<T>>(
    a: &A,
    options: RandomizedSvdOptions,
) -> Result<SvdResult<T>> {
    let max_rank = a.rows().min(a.cols());
    if options.rank == 0 || options.rank > max_rank {
        return Err(Error::RankTooLarge {
            requested: options.rank,
            max: max_rank,
        });
    }
    let basis = randomized_range_finder(
        a,
        options.rank,
        options.oversampling,
        options.power_iterations,
        options.seed,
    )?;
    let mut projected_t = Array2::zeros([a.cols(), basis.cols()]);
    a.t_matmat(basis.view(), projected_t.view_mut())?;
    let projected = Array2::clone_contiguous(projected_t.transpose_view());
    let SvdResult {
        u: small_u,
        s: small_s,
        vt: small_vt,
    } = svd_small(projected.view())?;
    let rank = options.rank.min(small_s.len());
    let u = if options.compute_u {
        let lifted = matmul(basis.view(), small_u.view())?;
        Array2::from_fn([a.rows(), rank], |i, j| lifted[(i, j)])
    } else {
        Array2::zeros([0, 0])
    };
    let vt = if options.compute_vt {
        Array2::from_fn([rank, a.cols()], |i, j| small_vt[(i, j)])
    } else {
        Array2::zeros([0, 0])
    };
    let s = small_s.into_iter().take(rank).collect();
    Ok(SvdResult { u, s, vt })
}
/// Runs [`randomized_svd`] on a dense matrix and also returns its Frobenius reconstruction
/// error `||A - U diag(s) Vt||_F`.
///
/// `U` and `Vt` are always computed internally, because the error needs them. Afterwards
/// they are replaced by `0 x 0` arrays if `options.compute_u` / `options.compute_vt` were
/// false.
///
/// # Errors
/// Errors from [`randomized_svd`] and [`approx_reconstruction_error`] are propagated.
/// [`Error::NotConverged`] is returned when `options.tolerance` is set and the error
/// exceeds it or is NaN.
pub fn randomized_svd_with_error<T: Float>(
    a: ArrayView2<'_, T>,
    options: RandomizedSvdOptions,
) -> Result<(SvdResult<T>, T)> {
    let RandomizedSvdOptions {
        compute_u,
        compute_vt,
        tolerance,
        ..
    } = options;
    // Struct update reuses the caller's options directly; no clone is needed.
    let work_options = RandomizedSvdOptions {
        compute_u: true,
        compute_vt: true,
        ..options
    };
    let mut result = randomized_svd(&a, work_options)?;
    let error = approx_reconstruction_error(a, result.u.view(), &result.s, result.vt.view())?;
    if let Some(limit) = tolerance {
        let observed = error.to_f64();
        // `observed > limit` alone is false for NaN, which would accept a garbage result.
        if observed.is_nan() || observed > limit {
            return Err(Error::NotConverged);
        }
    }
    if !compute_u {
        result.u = Array2::zeros([0, 0]);
    }
    if !compute_vt {
        result.vt = Array2::zeros([0, 0]);
    }
    Ok((result, error))
}
/// Sequentially applies [`randomized_svd`] to every matrix slice of `a` along `axis`.
///
/// # Errors
/// Stops at the first slice whose extraction or decomposition fails and returns that
/// error.
pub fn batch_randomized_svd<T: Float>(
    a: ArrayView3<'_, T>,
    axis: Axis3,
    options: RandomizedSvdOptions,
) -> Result<Vec<SvdResult<T>>> {
    let axis_index = axis.index();
    let count = a.shape()[axis_index];
    (0..count)
        .map(|index| {
            let matrix = a.matrix_at(axis_index, index)?;
            randomized_svd(&matrix, options.clone())
        })
        .collect()
}
/// Parallel (rayon) version of [`batch_randomized_svd`]. Results keep slice order.
///
/// # Errors
/// Returns an error from one of the failing slices if any slice fails.
pub fn batch_randomized_svd_parallel<T: Float>(
    a: ArrayView3<'_, T>,
    axis: Axis3,
    options: RandomizedSvdOptions,
) -> Result<Vec<SvdResult<T>>> {
    let axis_index = axis.index();
    let decompose = |index: usize| -> Result<SvdResult<T>> {
        let matrix = a.matrix_at(axis_index, index)?;
        randomized_svd(&matrix, options.clone())
    };
    (0..a.shape()[axis_index])
        .into_par_iter()
        .map(decompose)
        .collect()
}
pub fn approx_reconstruction_error<T: Float>(
a: ArrayView2<'_, T>,
u: ArrayView2<'_, T>,
s: &[T],
vt: ArrayView2<'_, T>,
) -> Result<T> {
if u.rows() != a.rows() || u.cols() != s.len() {
return Err(Error::shape(vec![a.rows(), s.len()], u.shape()));
}
if vt.rows() != s.len() || vt.cols() != a.cols() {
return Err(Error::shape(vec![s.len(), a.cols()], vt.shape()));
}
let mut residual = T::zero();
for i in 0..a.rows() {
for j in 0..a.cols() {
let mut approx = T::zero();
for r in 0..s.len() {
approx += u[(i, r)] * s[r] * vt[(r, j)];
}
let diff = a[(i, j)] - approx;
residual += diff * diff;
}
}
Ok(residual.sqrt())
}
/// Fraction of the total energy `sum_k s_k^2` carried by each singular value.
///
/// If the total is zero, all ratios are zero.
pub fn explained_variance_ratio<T: Float>(s: &[T]) -> Vec<T> {
    let energy: T = s.iter().copied().map(|sigma| sigma * sigma).sum();
    if energy == T::zero() {
        vec![T::zero(); s.len()]
    } else {
        s.iter().map(|&sigma| sigma * sigma / energy).collect()
    }
}
pub fn eigh_small<T: Float>(a: ArrayView2<'_, T>) -> Result<EighResult<T>> {
if a.rows() != a.cols() {
return Err(Error::shape([a.rows(), a.rows()], a.shape()));
}
for i in 0..a.rows() {
for j in (i + 1)..a.cols() {
if (a[(i, j)] - a[(j, i)]).abs() > T::from_f64(1e-9) {
return Err(Error::NumericalFailure("matrix is not symmetric"));
}
}
}
let mut eig = jacobi_symmetric(Array2::from_fn(a.shape(), |i, j| a[(i, j)].to_f64()))?;
eig.sort_by(|left, right| {
right
.0
.partial_cmp(&left.0)
.unwrap_or(core::cmp::Ordering::Equal)
});
let eigenvalues = eig
.iter()
.map(|(value, _)| T::from_f64(*value))
.collect::<Vec<_>>();
let eigenvectors = Array2::from_fn([a.rows(), a.cols()], |i, j| T::from_f64(eig[j].1[i]));
Ok(EighResult {
eigenvalues,
eigenvectors,
})
}
/// Thin SVD of a small matrix obtained from the eigen-decomposition of `A A^T`.
///
/// The steps are:
/// * `U` is the set of eigenvectors of the left Gram matrix, sorted by descending
///   eigenvalue.
/// * `s = sqrt(max(lambda, 0))`.
/// * Row `r` of `Vt` is `U[:, r]^T A / s[r]`. Rows with `s[r] <= 1e-12` are left at
///   zero.
///
/// Because the singular values come from squared quantities, small ones lose relative
/// precision.
///
/// # Errors
/// Propagates errors from the Jacobi eigen-solver.
pub fn svd_small<T: Float>(a: ArrayView2<'_, T>) -> Result<SvdResult<T>> {
    let (rows, cols) = (a.rows(), a.cols());
    let mut spectrum = jacobi_symmetric(gram_left(a))?;
    spectrum.sort_by(|lhs, rhs| {
        rhs.0
            .partial_cmp(&lhs.0)
            .unwrap_or(core::cmp::Ordering::Equal)
    });
    let rank = spectrum.len().min(rows).min(cols);
    let s: Vec<T> = spectrum
        .iter()
        .take(rank)
        .map(|&(lambda, _)| T::from_f64(lambda.max(0.0).sqrt()))
        .collect();
    let u = Array2::from_fn([rows, rank], |i, j| T::from_f64(spectrum[j].1[i]));
    let cutoff = T::from_f64(1e-12);
    let mut vt = Array2::zeros([rank, cols]);
    for (r, &sigma) in s.iter().enumerate() {
        if sigma <= cutoff {
            continue;
        }
        for col in 0..cols {
            let mut projection = T::zero();
            for row in 0..rows {
                projection += u[(row, r)] * a[(row, col)];
            }
            vt[(r, col)] = projection / sigma;
        }
    }
    Ok(SvdResult { u, s, vt })
}
/// Left Gram matrix `A A^T` (size `rows x rows`), accumulated in `f64`.
fn gram_left<T: Float>(a: ArrayView2<'_, T>) -> Array2<f64> {
    let n = a.rows();
    Array2::from_fn([n, n], |i, j| {
        (0..a.cols()).fold(0.0, |acc, k| acc + a[(i, k)].to_f64() * a[(j, k)].to_f64())
    })
}
/// Eigen-decomposes a symmetric `f64` matrix with the classical Jacobi rotation method.
///
/// Each iteration scans the strict upper triangle for the entry `a[(p, q)]` of largest
/// magnitude and applies a plane rotation that zeroes it. The rotations are accumulated
/// into `v`, which starts as the identity. Iteration stops once every strict-upper-triangle
/// entry is below `1e-12` in absolute value.
///
/// Returns one `(eigenvalue, eigenvector)` pair per index `k`: the eigenvalue is the final
/// diagonal entry `a[(k, k)]` and the eigenvector is column `k` of `v`. The pairs are not
/// sorted.
///
/// NOTE(review): symmetry is assumed, not checked. Only the upper triangle is scanned, and
/// each rotation writes `a[(k, p)]` and `a[(p, k)]` (and likewise for `q`) to the same
/// value.
///
/// # Errors
/// * Shape error if `a` is not square.
/// * [`Error::NotConverged`] if the threshold is not reached within `64 * n^2` rotations,
///   with `n` clamped to at least 1.
fn jacobi_symmetric(mut a: Array2<f64>) -> Result<Vec<(f64, Vec<f64>)>> {
    if a.rows() != a.cols() {
        return Err(Error::shape([a.rows(), a.rows()], a.shape()));
    }
    let n = a.rows();
    // Accumulated product of all rotations; its columns become the eigenvectors.
    let mut v = Array2::from_fn([n, n], |i, j| if i == j { 1.0 } else { 0.0 });
    let max_iter = 64usize.saturating_mul(n.max(1)).saturating_mul(n.max(1));
    for _ in 0..max_iter {
        // Locate the largest-magnitude entry of the strict upper triangle: the pivot (p, q).
        let mut p = 0;
        let mut q = 0;
        let mut max = 0.0;
        for i in 0..n {
            for j in (i + 1)..n {
                let value = a[(i, j)].abs();
                if value > max {
                    max = value;
                    p = i;
                    q = j;
                }
            }
        }
        if max < 1e-12 {
            // Converged: eigenvalues are on the diagonal, eigenvectors are the columns of `v`.
            let mut result = Vec::with_capacity(n);
            for col in 0..n {
                let mut vector = Vec::with_capacity(n);
                for row in 0..n {
                    vector.push(v[(row, col)]);
                }
                result.push((a[(col, col)], vector));
            }
            return Ok(result);
        }
        let app = a[(p, p)];
        let aqq = a[(q, q)];
        let apq = a[(p, q)];
        // t = tan(theta) is the smaller-magnitude root of t^2 + 2*tau*t - 1 = 0, which keeps
        // the rotation angle within [-pi/4, pi/4]. `f64::signum(+0.0)` is 1.0, so when
        // app == aqq (tau == 0) this gives t = 1, a 45-degree rotation.
        let tau = (aqq - app) / (2.0 * apq);
        let t = tau.signum() / (tau.abs() + (1.0 + tau * tau).sqrt());
        let c = 1.0 / (1.0 + t * t).sqrt();
        let s = t * c;
        // Rotate rows/columns p and q outside the 2x2 pivot block, mirroring each write
        // into both triangles.
        for k in 0..n {
            if k != p && k != q {
                let akp = a[(k, p)];
                let akq = a[(k, q)];
                let new_kp = c * akp - s * akq;
                let new_kq = s * akp + c * akq;
                a[(k, p)] = new_kp;
                a[(p, k)] = new_kp;
                a[(k, q)] = new_kq;
                a[(q, k)] = new_kq;
            }
        }
        // New diagonal of the pivot block; the pivot itself is set to exactly zero.
        a[(p, p)] = c * c * app - 2.0 * s * c * apq + s * s * aqq;
        a[(q, q)] = s * s * app + 2.0 * s * c * apq + c * c * aqq;
        a[(p, q)] = 0.0;
        a[(q, p)] = 0.0;
        // Apply the same rotation to the eigenvector accumulator.
        for k in 0..n {
            let vkp = v[(k, p)];
            let vkq = v[(k, q)];
            v[(k, p)] = c * vkp - s * vkq;
            v[(k, q)] = s * vkp + c * vkq;
        }
    }
    Err(Error::NotConverged)
}