use crate::error::{LinalgError, LinalgResult};
use scirs2_core::ndarray::{Array1, Array2, ArrayView2, Axis};
use scirs2_core::numeric::{Float, NumAssign, Zero};
use std::fmt::Debug;
use std::marker::PhantomData;
#[cfg(any(
feature = "cuda",
feature = "opencl",
feature = "rocm",
feature = "metal",
feature = "vulkan"
))]
use super::super::GpuContext;
/// Tunable parameters controlling how GPU-accelerated decompositions run.
#[derive(Debug, Clone)]
pub struct GpuDecompositionConfig {
    /// Edge length (in elements) of the square tiles used by tiled algorithms.
    pub tile_size: usize,
    /// Minimum element count (`rows * cols`) before the GPU path is attempted.
    pub min_gpu_size: usize,
    /// Whether matrices larger than GPU memory may be processed in chunks.
    pub out_of_core: bool,
    /// Upper bound, in bytes, on GPU memory the decompositions may claim.
    pub max_gpu_memory: usize,
    /// Enables reduced-precision arithmetic where the backend supports it.
    pub mixed_precision: bool,
    /// Number of concurrent device streams/queues to use.
    pub num_streams: usize,
    /// Convergence / zero threshold used by the iterative routines.
    pub tolerance: f64,
}

impl Default for GpuDecompositionConfig {
    fn default() -> Self {
        // Budget 256 MiB on 32-bit targets and 4 GiB on 64-bit ones;
        // the 4 GiB constant would not even fit in a 32-bit usize.
        #[cfg(target_pointer_width = "32")]
        let max_gpu_memory = 256 * 1024 * 1024;
        #[cfg(target_pointer_width = "64")]
        let max_gpu_memory = 4usize * 1024 * 1024 * 1024;
        Self {
            tile_size: 256,
            min_gpu_size: 10_000,
            out_of_core: true,
            max_gpu_memory,
            mixed_precision: false,
            num_streams: 2,
            tolerance: 1e-14,
        }
    }
}
/// Result of an LU factorization with partial pivoting: `P·A = L·U`.
#[derive(Debug, Clone)]
pub struct LuDecomposition<T> {
/// Unit lower-triangular factor, shape `(m, min(m, n))`.
pub l: Array2<T>,
/// Upper-triangular factor, shape `(min(m, n), n)`.
pub u: Array2<T>,
/// Row permutation: `p[i]` is the source row of permuted row `i`.
pub p: Array1<usize>,
/// Number of row swaps performed during pivoting.
pub num_swaps: usize,
}
/// Result of a QR factorization: `A = Q·R`.
#[derive(Debug, Clone)]
pub struct QrDecomposition<T> {
/// Orthonormal-column factor, shape `(m, min(m, n))` (thin Q).
pub q: Array2<T>,
/// Upper-triangular factor, shape `(min(m, n), n)`.
pub r: Array2<T>,
}
/// Result of a Cholesky factorization of a symmetric positive-definite
/// matrix: `A = L·Lᵀ`.
#[derive(Debug, Clone)]
pub struct CholeskyDecomposition<T> {
/// Lower-triangular factor.
pub l: Array2<T>,
}
/// Result of a singular value decomposition: `A = U·diag(S)·Vᵀ`.
#[derive(Debug, Clone)]
pub struct SvdDecomposition<T> {
/// Left singular vectors, stored as columns.
pub u: Array2<T>,
/// Singular values.
pub s: Array1<T>,
/// Right singular vectors, stored as rows (V transposed).
pub vt: Array2<T>,
}
/// Result of an eigendecomposition.
#[derive(Debug, Clone)]
pub struct EigenDecomposition<T> {
/// Eigenvalue estimates.
pub eigenvalues: Array1<T>,
/// Corresponding eigenvectors, one per column.
pub eigenvectors: Array2<T>,
}
pub struct GpuDecompositions<T>
where
T: Float + NumAssign + Zero + Send + Sync + Debug + 'static,
{
config: GpuDecompositionConfig,
_phantom: PhantomData<T>,
}
impl<T> GpuDecompositions<T>
where
T: Float + NumAssign + Zero + Send + Sync + Debug + 'static,
{
/// Creates a decomposition engine with the default [`GpuDecompositionConfig`].
pub fn new() -> Self {
    Self::with_config(Default::default())
}

/// Creates a decomposition engine driven by the caller-supplied `config`.
pub fn with_config(config: GpuDecompositionConfig) -> Self {
    let _phantom = PhantomData;
    Self { config, _phantom }
}
/// LU-factorizes `a`, synchronizing the GPU context around the work.
///
/// Matrices with fewer than `min_gpu_size` elements bypass the device and
/// run on the CPU directly; inputs larger than one tile go through the
/// tiled path.
#[cfg(any(
    feature = "cuda",
    feature = "opencl",
    feature = "rocm",
    feature = "metal",
    feature = "vulkan"
))]
pub fn lu(
    &self,
    context: &dyn GpuContext,
    a: &ArrayView2<T>,
) -> LinalgResult<LuDecomposition<T>> {
    let (rows, cols) = a.dim();
    let elements = rows * cols;
    // Too small to amortize transfer cost: stay on the host.
    if elements < self.config.min_gpu_size {
        return self.cpu_lu(a);
    }
    context.synchronize()?;
    let tile_elems = self.config.tile_size * self.config.tile_size;
    let decomposition = if elements > tile_elems {
        self.tiled_lu(a)?
    } else {
        self.cpu_lu(a)?
    };
    context.synchronize()?;
    Ok(decomposition)
}
/// LU factorization with partial pivoting on the CPU.
///
/// Returns `L` (unit lower-triangular, `(m, min(m, n))`), `U`
/// (upper-triangular, `(min(m, n), n)`) and the row permutation `p` such
/// that `P·A = L·U`.
///
/// # Errors
/// Returns `SingularMatrixError` when the chosen pivot is smaller than
/// machine epsilon in absolute value.
pub fn cpu_lu(&self, a: &ArrayView2<T>) -> LinalgResult<LuDecomposition<T>> {
    let (rows, cols) = a.dim();
    let rank_bound = rows.min(cols);
    // Factor in place: the strictly-lower part accumulates the L
    // multipliers, the diagonal and above hold U.
    let mut factors = a.to_owned();
    let mut perm: Vec<usize> = (0..rows).collect();
    let mut swaps = 0usize;
    for step in 0..rank_bound {
        // Partial pivoting: pick the largest |entry| in column `step`.
        let mut pivot_row = step;
        let mut pivot_mag = factors[[step, step]].abs();
        for row in (step + 1)..rows {
            let mag = factors[[row, step]].abs();
            if mag > pivot_mag {
                pivot_mag = mag;
                pivot_row = row;
            }
        }
        if pivot_row != step {
            perm.swap(step, pivot_row);
            for col in 0..cols {
                factors.swap([step, col], [pivot_row, col]);
            }
            swaps += 1;
        }
        if factors[[step, step]].abs() < T::epsilon() {
            return Err(LinalgError::SingularMatrixError(
                "Matrix is singular during LU decomposition".to_string(),
            ));
        }
        // Eliminate below the pivot, storing the multipliers in place.
        for row in (step + 1)..rows {
            factors[[row, step]] = factors[[row, step]] / factors[[step, step]];
            let multiplier = factors[[row, step]];
            for col in (step + 1)..cols {
                factors[[row, col]] = factors[[row, col]] - multiplier * factors[[step, col]];
            }
        }
    }
    // Unpack the in-place result into explicit L and U factors.
    let mut l = Array2::zeros((rows, rank_bound));
    for row in 0..rows {
        for col in 0..rank_bound.min(row + 1) {
            l[[row, col]] = if row == col {
                T::one()
            } else {
                factors[[row, col]]
            };
        }
    }
    let mut u = Array2::zeros((rank_bound, cols));
    for row in 0..rank_bound {
        for col in row..cols {
            u[[row, col]] = factors[[row, col]];
        }
    }
    Ok(LuDecomposition {
        l,
        u,
        p: Array1::from_vec(perm),
        num_swaps: swaps,
    })
}
/// Tiled LU factorization intended for large / out-of-core inputs.
///
/// NOTE(review): currently a stub that delegates to the dense CPU
/// implementation; actual tiling is not implemented yet.
fn tiled_lu(&self, a: &ArrayView2<T>) -> LinalgResult<LuDecomposition<T>> {
self.cpu_lu(a)
}
/// QR-factorizes `a`, synchronizing the GPU context around the work.
///
/// Matrices with fewer than `min_gpu_size` elements bypass the device.
#[cfg(any(
    feature = "cuda",
    feature = "opencl",
    feature = "rocm",
    feature = "metal",
    feature = "vulkan"
))]
pub fn qr(
    &self,
    context: &dyn GpuContext,
    a: &ArrayView2<T>,
) -> LinalgResult<QrDecomposition<T>> {
    let (rows, cols) = a.dim();
    if rows * cols < self.config.min_gpu_size {
        return self.cpu_qr(a);
    }
    context.synchronize()?;
    let decomposition = self.cpu_qr(a)?;
    context.synchronize()?;
    Ok(decomposition)
}
/// Householder QR factorization on the CPU.
///
/// Returns a thin decomposition `A = Q·R`: `q` is `(m, min(m, n))` with
/// orthonormal columns and `r` is upper-triangular `(min(m, n), n)`.
///
/// # Errors
/// Returns `ComputationError` if the constant `2.0` cannot be represented
/// in `T`.
pub fn cpu_qr(&self, a: &ArrayView2<T>) -> LinalgResult<QrDecomposition<T>> {
let (m, n) = a.dim();
let min_dim = m.min(n);
// `r` is reduced in place; `q` accumulates the reflectors applied on the
// right (each H is symmetric and orthogonal, so Q = H_0 · H_1 · ...).
let mut r = a.to_owned();
let mut q = Array2::eye(m);
for k in 0..min_dim {
// Copy the trailing part of column k that the reflector must zero.
let mut v = Array1::zeros(m - k);
for i in k..m {
v[i - k] = r[[i, k]];
}
let norm_v = v.iter().fold(T::zero(), |acc, &x| acc + x * x).sqrt();
// Column segment already (numerically) zero: nothing to eliminate.
if norm_v < T::epsilon() {
continue;
}
// Choose the sign that avoids cancellation in v[0] + sign·‖v‖.
let sign = if v[0] >= T::zero() {
T::one()
} else {
-T::one()
};
v[0] += sign * norm_v;
// Normalize the Householder vector so that H = I - 2·v·vᵀ.
let norm_v_new = v.iter().fold(T::zero(), |acc, &x| acc + x * x).sqrt();
if norm_v_new > T::epsilon() {
for val in v.iter_mut() {
*val /= norm_v_new;
}
}
let two = T::from(2.0).ok_or_else(|| {
LinalgError::ComputationError("Failed to convert 2.0".to_string())
})?;
// Apply H on the left to the remaining columns of R:
// r[k.., j] -= 2·v·(vᵀ·r[k.., j]).
for j in k..n {
let mut dot = T::zero();
for i in k..m {
dot += v[i - k] * r[[i, j]];
}
for i in k..m {
r[[i, j]] -= two * v[i - k] * dot;
}
}
// Apply H on the right to Q: q[i, k..] -= 2·(q[i, k..]·v)·vᵀ.
for i in 0..m {
let mut dot = T::zero();
for j in k..m {
dot += q[[i, j]] * v[j - k];
}
for j in k..m {
q[[i, j]] -= two * dot * v[j - k];
}
}
}
// Extract the upper-triangular block of the reduced matrix.
let mut r_out = Array2::zeros((min_dim, n));
for i in 0..min_dim {
for j in i..n {
r_out[[i, j]] = r[[i, j]];
}
}
// Thin Q: keep only the first min(m, n) columns.
let q_out = q.slice(scirs2_core::ndarray::s![.., 0..min_dim]).to_owned();
Ok(QrDecomposition { q: q_out, r: r_out })
}
/// Cholesky-factorizes the square matrix `a`, synchronizing the GPU
/// context around the work.
///
/// # Errors
/// Returns `ShapeError` for non-square input. Matrices with fewer than
/// `min_gpu_size` elements bypass the device.
#[cfg(any(
    feature = "cuda",
    feature = "opencl",
    feature = "rocm",
    feature = "metal",
    feature = "vulkan"
))]
pub fn cholesky(
    &self,
    context: &dyn GpuContext,
    a: &ArrayView2<T>,
) -> LinalgResult<CholeskyDecomposition<T>> {
    let (rows, cols) = a.dim();
    if rows != cols {
        return Err(LinalgError::ShapeError(
            "Cholesky decomposition requires a square matrix".to_string(),
        ));
    }
    if rows * cols < self.config.min_gpu_size {
        return self.cpu_cholesky(a);
    }
    context.synchronize()?;
    let decomposition = self.cpu_cholesky(a)?;
    context.synchronize()?;
    Ok(decomposition)
}
/// Cholesky factorization `A = L·Lᵀ` on the CPU (row-by-row scheme).
///
/// Only the lower triangle of `a` is read, so the input is implicitly
/// treated as symmetric.
///
/// # Errors
/// * `ShapeError` if `a` is not square.
/// * `NonPositiveDefiniteError` if a diagonal pivot is non-positive.
/// * `SingularMatrixError` if a (near-)zero diagonal entry would be
///   divided by.
pub fn cpu_cholesky(&self, a: &ArrayView2<T>) -> LinalgResult<CholeskyDecomposition<T>> {
    let n = a.nrows();
    // Guard against rectangular input: the loops below index a[[i, j]]
    // for all i, j < nrows and would read out of bounds (panic) when
    // ncols < nrows. Only the GPU wrapper checked this previously.
    if a.ncols() != n {
        return Err(LinalgError::ShapeError(
            "Cholesky decomposition requires a square matrix".to_string(),
        ));
    }
    let mut l = Array2::zeros((n, n));
    for i in 0..n {
        for j in 0..=i {
            // Subtract the already-computed partial inner products.
            let mut sum = a[[i, j]];
            for k in 0..j {
                sum -= l[[i, k]] * l[[j, k]];
            }
            if i == j {
                // Diagonal pivot must be strictly positive for SPD input.
                if sum <= T::zero() {
                    return Err(LinalgError::NonPositiveDefiniteError(
                        "Matrix is not positive definite during Cholesky decomposition"
                            .to_string(),
                    ));
                }
                l[[i, j]] = sum.sqrt();
            } else {
                if l[[j, j]].abs() < T::epsilon() {
                    return Err(LinalgError::SingularMatrixError(
                        "Matrix is singular during Cholesky decomposition".to_string(),
                    ));
                }
                l[[i, j]] = sum / l[[j, j]];
            }
        }
    }
    Ok(CholeskyDecomposition { l })
}
/// Computes the SVD of `a`, synchronizing the GPU context around the work.
///
/// `full_matrices` selects between full and thin `U` / `Vᵀ` factors.
/// Matrices with fewer than `min_gpu_size` elements bypass the device.
#[cfg(any(
    feature = "cuda",
    feature = "opencl",
    feature = "rocm",
    feature = "metal",
    feature = "vulkan"
))]
pub fn svd(
    &self,
    context: &dyn GpuContext,
    a: &ArrayView2<T>,
    full_matrices: bool,
) -> LinalgResult<SvdDecomposition<T>> {
    let (rows, cols) = a.dim();
    if rows * cols < self.config.min_gpu_size {
        return self.cpu_svd(a, full_matrices);
    }
    context.synchronize()?;
    let decomposition = self.cpu_svd(a, full_matrices)?;
    context.synchronize()?;
    Ok(decomposition)
}
/// SVD via one-sided power iteration with deflation on the CPU.
///
/// For each of the `min(m, n)` components, the dominant singular triplet
/// of the deflated working matrix is estimated by alternating `u ← A·v`
/// and `v ← Aᵀ·u`; the converged component is then subtracted from the
/// working matrix. With `full_matrices == true` the full `(m, m)` U and
/// `(n, n)` Vᵀ are returned, otherwise both are trimmed to `min(m, n)`.
///
/// NOTE(review): if the inner loop exits without passing the convergence
/// test (tiny-norm `break`, or `max_iter` exhausted), `s[k]` stays zero
/// and column `k` of `u` / row `k` of `vt` keep their identity values, so
/// the deflation below subtracts nothing — confirm this fallback is
/// intended.
pub fn cpu_svd(
&self,
a: &ArrayView2<T>,
full_matrices: bool,
) -> LinalgResult<SvdDecomposition<T>> {
let (m, n) = a.dim();
let min_dim = m.min(n);
// Start from identity factors; converged triplets overwrite column k /
// row k below.
let mut u = Array2::eye(m);
let mut vt = Array2::eye(n);
let mut s = Array1::zeros(min_dim);
// `work` is deflated in place after each extracted component.
let mut work = a.to_owned();
for k in 0..min_dim {
// Initial right-vector guess: the k-th standard basis vector.
let mut v: Array1<T> = Array1::from_vec(
(0..n)
.map(|i| if i == k { T::one() } else { T::zero() })
.collect(),
);
let max_iter = 100;
let tol = T::from(self.config.tolerance).unwrap_or_else(|| T::epsilon());
for _ in 0..max_iter {
// u_new = work · v
let mut u_new = Array1::zeros(m);
for i in 0..m {
for j in 0..n {
u_new[i] += work[[i, j]] * v[j];
}
}
let norm_u = u_new.iter().fold(T::zero(), |acc, &x| acc + x * x).sqrt();
if norm_u < tol {
break;
}
for val in u_new.iter_mut() {
*val /= norm_u;
}
// v_new = workᵀ · u_new; its norm estimates the singular value.
let mut v_new = Array1::zeros(n);
for j in 0..n {
for i in 0..m {
v_new[j] += work[[i, j]] * u_new[i];
}
}
let norm_v = v_new.iter().fold(T::zero(), |acc, &x| acc + x * x).sqrt();
if norm_v < tol {
break;
}
// Distance between the previous direction and the normalized new
// one (precedence: `a - (b / norm_v)`), measured before v_new is
// normalized in place below.
let diff: T = v
.iter()
.zip(v_new.iter())
.map(|(&a, &b)| {
let d = a - b / norm_v;
d * d
})
.fold(T::zero(), |acc, x| acc + x)
.sqrt();
for val in v_new.iter_mut() {
*val /= norm_v;
}
v = v_new;
if diff < tol {
// Converged: record the triplet for component k.
s[k] = norm_v;
for i in 0..m {
u[[i, k]] = u_new[i];
}
for j in 0..n {
vt[[k, j]] = v[j];
}
break;
}
}
// Deflate: work -= s[k] · u_k · v_kᵀ (no-op while s[k] is zero).
for i in 0..m {
for j in 0..n {
work[[i, j]] -= s[k] * u[[i, k]] * vt[[k, j]];
}
}
}
if full_matrices {
Ok(SvdDecomposition { u, s, vt })
} else {
// Thin factors: first min(m, n) columns of U / rows of Vᵀ.
let u_trimmed = u.slice(scirs2_core::ndarray::s![.., 0..min_dim]).to_owned();
let vt_trimmed = vt
.slice(scirs2_core::ndarray::s![0..min_dim, ..])
.to_owned();
Ok(SvdDecomposition {
u: u_trimmed,
s,
vt: vt_trimmed,
})
}
}
/// Computes the eigendecomposition of the square matrix `a`,
/// synchronizing the GPU context around the work.
///
/// # Errors
/// Returns `ShapeError` for non-square input. Matrices with fewer than
/// `min_gpu_size` elements bypass the device.
#[cfg(any(
    feature = "cuda",
    feature = "opencl",
    feature = "rocm",
    feature = "metal",
    feature = "vulkan"
))]
pub fn eigh(
    &self,
    context: &dyn GpuContext,
    a: &ArrayView2<T>,
) -> LinalgResult<EigenDecomposition<T>> {
    let (rows, cols) = a.dim();
    if rows != cols {
        return Err(LinalgError::ShapeError(
            "Eigendecomposition requires a square matrix".to_string(),
        ));
    }
    if rows * cols < self.config.min_gpu_size {
        return self.cpu_eigh(a);
    }
    context.synchronize()?;
    let decomposition = self.cpu_eigh(a)?;
    context.synchronize()?;
    Ok(decomposition)
}
/// Eigendecomposition via power iteration with deflation on the CPU.
///
/// Each eigenpair is estimated by power iteration on the deflated working
/// matrix, with the eigenvalue taken as the ratio `vᵀ·A·v / vᵀ·v`.
///
/// NOTE(review): the deflation `work -= λ·v·vᵀ` is only valid for
/// symmetric input (the public `eigh` wrapper checks squareness but not
/// symmetry) — confirm callers guarantee this. Also, on the convergence
/// `break` below, `v` has not yet been replaced by the freshly normalized
/// `v_new`, so the stored eigenvector lags one iteration behind.
pub fn cpu_eigh(&self, a: &ArrayView2<T>) -> LinalgResult<EigenDecomposition<T>> {
let n = a.nrows();
let mut eigenvalues = Array1::zeros(n);
let mut eigenvectors = Array2::eye(n);
// `work` is deflated in place after each extracted eigenpair.
let mut work = a.to_owned();
let tol = T::from(self.config.tolerance).unwrap_or_else(|| T::epsilon());
let max_iter = 100;
for k in 0..n {
// Initial guess: the k-th standard basis vector.
let mut v = Array1::from_vec(
(0..n)
.map(|i| if i == k { T::one() } else { T::zero() })
.collect(),
);
for _ in 0..max_iter {
// v_new = work · v
let mut v_new = Array1::zeros(n);
for i in 0..n {
for j in 0..n {
v_new[i] += work[[i, j]] * v[j];
}
}
// Eigenvalue estimate: (vᵀ·work·v) / (vᵀ·v).
let mut numerator = T::zero();
let mut denominator = T::zero();
for i in 0..n {
numerator += v[i] * v_new[i];
denominator += v[i] * v[i];
}
let eigenvalue = numerator / denominator;
let norm = v_new.iter().fold(T::zero(), |acc, &x| acc + x * x).sqrt();
// A vanishing image means v is (numerically) in the null space.
if norm < tol {
break;
}
for val in v_new.iter_mut() {
*val /= norm;
}
// Converged once the eigenvalue estimate stops moving.
let diff = (eigenvalue - eigenvalues[k]).abs();
eigenvalues[k] = eigenvalue;
if diff < tol {
break;
}
v = v_new;
}
for i in 0..n {
eigenvectors[[i, k]] = v[i];
}
// Deflate the extracted component: work -= λ_k · v·vᵀ.
for i in 0..n {
for j in 0..n {
work[[i, j]] -= eigenvalues[k] * v[i] * v[j];
}
}
}
Ok(EigenDecomposition {
eigenvalues,
eigenvectors,
})
}
/// Computes a rank-`k` truncated SVD of `a`.
///
/// The rank is clamped to `min(k, m, n)`. Problems with more than
/// `min_gpu_size` elements use a randomized sketch; smaller ones compute
/// a thin SVD and slice off the leading components.
pub fn truncated_svd(&self, a: &ArrayView2<T>, k: usize) -> LinalgResult<SvdDecomposition<T>> {
    let (m, n) = a.dim();
    let rank = k.min(m).min(n);
    // Large inputs: sketch instead of paying for the full decomposition.
    if m * n > self.config.min_gpu_size {
        return self.randomized_svd(a, rank);
    }
    let full = self.cpu_svd(a, false)?;
    Ok(SvdDecomposition {
        u: full.u.slice(scirs2_core::ndarray::s![.., 0..rank]).to_owned(),
        s: full.s.slice(scirs2_core::ndarray::s![0..rank]).to_owned(),
        vt: full.vt.slice(scirs2_core::ndarray::s![0..rank, ..]).to_owned(),
    })
}
/// Randomized range-finder SVD.
///
/// Sketches `Y = A·Ω` with a random `(n, k+10)` test matrix, builds an
/// orthonormal basis `Q` for `Y` via QR, forms the small matrix
/// `B = Qᵀ·A`, takes its thin SVD, and lifts the left factor back with
/// `U = Q·U_B`. All factors are truncated to rank `k`.
///
/// NOTE(review): assumes `k <= min(m, n)`; `truncated_svd` clamps the
/// rank before calling, so any other caller must do the same.
fn randomized_svd(&self, a: &ArrayView2<T>, k: usize) -> LinalgResult<SvdDecomposition<T>> {
let (m, n) = a.dim();
// Oversample by 10 columns to better capture the leading subspace.
let l = (k + 10).min(m.min(n));
let omega = self.random_matrix(n, l);
// Y = A · Ω  (an m x l sketch of the column space of A).
let mut y = Array2::zeros((m, l));
for i in 0..m {
for j in 0..l {
for kk in 0..n {
y[[i, j]] += a[[i, kk]] * omega[[kk, j]];
}
}
}
// Orthonormal basis Q for the range of the sketch.
let qr = self.cpu_qr(&y.view())?;
let q = qr.q;
let q_cols = q.ncols();
// B = Qᵀ · A  (small q_cols x n problem).
let mut b = Array2::zeros((q_cols, n));
for i in 0..q_cols {
for j in 0..n {
for kk in 0..m {
b[[i, j]] += q[[kk, i]] * a[[kk, j]];
}
}
}
let b_svd = self.cpu_svd(&b.view(), false)?;
let u_b_cols = b_svd.u.ncols();
// Lift the small left factor back to the original space: U = Q · U_B.
let mut u = Array2::zeros((m, u_b_cols));
for i in 0..m {
for j in 0..u_b_cols {
for kk in 0..q_cols {
u[[i, j]] += q[[i, kk]] * b_svd.u[[kk, j]];
}
}
}
// Keep only the requested rank-k leading components.
let u_truncated = u.slice(scirs2_core::ndarray::s![.., 0..k]).to_owned();
let s_truncated = b_svd.s.slice(scirs2_core::ndarray::s![0..k]).to_owned();
let vt_truncated = b_svd
.vt
.slice(scirs2_core::ndarray::s![0..k, ..])
.to_owned();
Ok(SvdDecomposition {
u: u_truncated,
s: s_truncated,
vt: vt_truncated,
})
}
/// Fills a `rows x cols` matrix with deterministic pseudo-random values
/// in roughly `[-0.5, 0.5]`, used as the sketching matrix for
/// `randomized_svd`.
///
/// A fixed-seed 64-bit LCG (glibc `rand` constants) keeps results
/// reproducible; statistical quality is unimportant here — the sketch
/// only needs to be generic (full rank with high probability).
fn random_matrix(&self, rows: usize, cols: usize) -> Array2<T> {
    let mut result = Array2::zeros((rows, cols));
    let mut seed = 42u64;
    for i in 0..rows {
        for j in 0..cols {
            seed = seed.wrapping_mul(1103515245).wrapping_add(12345);
            // Take 32 bits from the middle of the state before scaling.
            // The previous code divided the full 48-bit value `seed >> 16`
            // by `u32::MAX`, yielding values up to ~65536 instead of the
            // intended [-0.5, 0.5] range.
            let val = ((seed >> 16) as u32 as f64) / (u32::MAX as f64) - 0.5;
            result[[i, j]] = T::from(val).unwrap_or_else(|| T::zero());
        }
    }
    result
}
}
/// `Default` delegates to [`GpuDecompositions::new`], i.e. an engine with
/// the default configuration.
impl<T> Default for GpuDecompositions<T>
where
T: Float + NumAssign + Zero + Send + Sync + Debug + 'static,
{
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
use super::*;
use scirs2_core::ndarray::array;
// cpu_lu: verify the defining identity P·A == L·U on a 2x2 matrix that
// forces a pivot swap (|6| > |4| in column 0).
#[test]
fn test_cpu_lu() {
let decomp = GpuDecompositions::<f64>::new();
let a = array![[4.0, 3.0], [6.0, 3.0]];
let result = decomp.cpu_lu(&a.view()).expect("LU failed");
let l = &result.l;
let u = &result.u;
// Reconstruct L·U by hand.
let mut lu = Array2::<f64>::zeros((2, 2));
for i in 0..2 {
for j in 0..2 {
for k in 0..2 {
lu[[i, j]] += l[[i, k]] * u[[k, j]];
}
}
}
// Apply the returned row permutation to A (p[i] = source row of row i).
let mut pa = Array2::<f64>::zeros((2, 2));
for i in 0..2 {
for j in 0..2 {
pa[[i, j]] = a[[result.p[i], j]];
}
}
for i in 0..2 {
for j in 0..2 {
let diff: f64 = pa[[i, j]] - lu[[i, j]];
assert!(diff.abs() < 1e-10);
}
}
}
// cpu_qr: the thin Q of a tall 3x2 matrix must have orthonormal columns,
// i.e. QᵀQ == I (2x2).
#[test]
fn test_cpu_qr() {
let decomp = GpuDecompositions::<f64>::new();
let a = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]];
let result = decomp.cpu_qr(&a.view()).expect("QR failed");
let q = &result.q;
let qt = q.t();
// Compute QᵀQ by hand.
let mut qtq = Array2::<f64>::zeros((q.ncols(), q.ncols()));
for i in 0..q.ncols() {
for j in 0..q.ncols() {
for k in 0..q.nrows() {
qtq[[i, j]] += qt[[i, k]] * q[[k, j]];
}
}
}
// Compare against the identity matrix entry-wise.
for i in 0..q.ncols() {
for j in 0..q.ncols() {
let expected = if i == j { 1.0 } else { 0.0 };
let diff: f64 = qtq[[i, j]] - expected;
assert!(diff.abs() < 1e-10);
}
}
}
// cpu_cholesky: verify A == L·Lᵀ for a small SPD matrix.
#[test]
fn test_cpu_cholesky() {
let decomp = GpuDecompositions::<f64>::new();
let a = array![[4.0, 2.0], [2.0, 3.0]];
let result = decomp.cpu_cholesky(&a.view()).expect("Cholesky failed");
let l = &result.l;
// Reconstruct L·Lᵀ by hand.
let mut llt = Array2::<f64>::zeros((2, 2));
for i in 0..2 {
for j in 0..2 {
for k in 0..2 {
llt[[i, j]] += l[[i, k]] * l[[j, k]];
}
}
}
for i in 0..2 {
for j in 0..2 {
let diff: f64 = a[[i, j]] - llt[[i, j]];
assert!(diff.abs() < 1e-10);
}
}
}
// truncated_svd: shape checks only — a 3x3 input (9 < min_gpu_size)
// takes the full-SVD-then-slice path; rank 2 must trim all factors.
#[test]
fn test_truncated_svd() {
let decomp = GpuDecompositions::<f64>::new();
let a = array![[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]];
let result = decomp
.truncated_svd(&a.view(), 2)
.expect("Truncated SVD failed");
assert_eq!(result.u.ncols(), 2);
assert_eq!(result.s.len(), 2);
assert_eq!(result.vt.nrows(), 2);
}
}