use super::super::CudaRuntime;
use super::super::client::CudaClient;
use crate::algorithm::linalg::LinearAlgebraAlgorithms;
use crate::ops::MatmulOps;
use crate::runtime::cuda::{CudaDevice, is_cuda_available};
use crate::runtime::{Runtime, RuntimeClient};
use crate::tensor::Tensor;
/// Build a client on CUDA device 0, or `None` when no CUDA device is present.
/// Tests call this first so they become silent no-ops on GPU-less machines.
fn create_client() -> Option<CudaClient> {
    if is_cuda_available() {
        Some(CudaRuntime::default_client(&CudaDevice::new(0)))
    } else {
        None
    }
}
#[test]
fn test_trace() {
    let Some(client) = create_client() else {
        return;
    };
    let device = client.device();
    // Row-major 2x2: [[1, 2], [3, 4]]; trace = 1 + 4 = 5.
    let matrix = Tensor::<CudaRuntime>::from_slice(&[1.0f32, 2.0, 3.0, 4.0], &[2, 2], device);
    let trace = LinearAlgebraAlgorithms::trace(&client, &matrix).unwrap();
    let host: Vec<f32> = trace.to_vec();
    assert!((host[0] - 5.0).abs() < 1e-5, "trace = {}, expected 5", host[0]);
}
#[test]
fn test_diag() {
    let Some(client) = create_client() else {
        return;
    };
    let device = client.device();
    // Row-major 2x3: [[1, 2, 3], [4, 5, 6]]; main diagonal is [1, 5].
    let matrix =
        Tensor::<CudaRuntime>::from_slice(&[1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0], &[2, 3], device);
    let diag = LinearAlgebraAlgorithms::diag(&client, &matrix).unwrap();
    let host: Vec<f32> = diag.to_vec();
    assert_eq!(host.len(), 2);
    for (got, want) in host.iter().zip([1.0f32, 5.0]) {
        assert!((got - want).abs() < 1e-5, "diag entry {} != {}", got, want);
    }
}
#[test]
fn test_diagflat() {
    let Some(client) = create_client() else {
        return;
    };
    let device = client.device();
    let vector = Tensor::<CudaRuntime>::from_slice(&[1.0f32, 2.0, 3.0], &[3], device);
    let matrix = LinearAlgebraAlgorithms::diagflat(&client, &vector).unwrap();
    assert_eq!(matrix.shape(), &[3, 3]);
    let host: Vec<f32> = matrix.to_vec();
    // diagflat builds a zero matrix with the vector on the main diagonal.
    // Check every entry, not just a sample: all six off-diagonal values
    // must be exactly zero for the result to be a valid diagonal matrix.
    for row in 0..3 {
        for col in 0..3 {
            let want = if row == col { (row + 1) as f32 } else { 0.0 };
            let got = host[row * 3 + col];
            assert!(
                (got - want).abs() < 1e-5,
                "diagflat[{},{}] = {} expected {}",
                row,
                col,
                got,
                want
            );
        }
    }
}
#[test]
fn test_lu_decomposition() {
    let Some(client) = create_client() else {
        return;
    };
    let device = client.device();
    let matrix = Tensor::<CudaRuntime>::from_slice(&[4.0f32, 3.0, 6.0, 3.0], &[2, 2], device);
    // Shape-only smoke test: the packed LU factors are n x n and the pivot
    // vector has one entry per row.
    let factors = client.lu_decompose(&matrix).unwrap();
    assert_eq!(factors.lu.shape(), &[2, 2]);
    assert_eq!(factors.pivots.shape(), &[2]);
}
#[test]
fn test_cholesky() {
    let Some(client) = create_client() else {
        return;
    };
    let device = client.device();
    // Symmetric positive-definite A = [[4, 2], [2, 5]].
    let matrix = Tensor::<CudaRuntime>::from_slice(&[4.0f32, 2.0, 2.0, 5.0], &[2, 2], device);
    let chol = client.cholesky_decompose(&matrix).unwrap();
    assert_eq!(chol.l.shape(), &[2, 2]);
    let l_data: Vec<f32> = chol.l.to_vec();
    // The lower Cholesky factor of A is exactly L = [[2, 0], [1, 2]]:
    // L00 = sqrt(4) = 2, L10 = 2/2 = 1, L11 = sqrt(5 - 1) = 2.
    // Previously only the zeroed upper entry was checked; verify all four
    // values so an incorrect factorization cannot slip through.
    let expected = [2.0f32, 0.0, 1.0, 2.0];
    for (i, (got, want)) in l_data.iter().zip(expected).enumerate() {
        assert!((got - want).abs() < 1e-5, "L[{}] = {} expected {}", i, got, want);
    }
}
#[test]
fn test_det() {
    let Some(client) = create_client() else {
        return;
    };
    let device = client.device();
    // det([[1, 2], [3, 4]]) = 1*4 - 2*3 = -2.
    let matrix = Tensor::<CudaRuntime>::from_slice(&[1.0f32, 2.0, 3.0, 4.0], &[2, 2], device);
    let det = LinearAlgebraAlgorithms::det(&client, &matrix).unwrap();
    let host: Vec<f32> = det.to_vec();
    assert!((host[0] + 2.0).abs() < 1e-4, "det = {}, expected -2", host[0]);
}
#[test]
fn test_solve() {
    let Some(client) = create_client() else {
        return;
    };
    let device = client.device();
    // A = [[2, 1], [1, 2]], b = [3, 3]  =>  x = [1, 1].
    let a = Tensor::<CudaRuntime>::from_slice(&[2.0f32, 1.0, 1.0, 2.0], &[2, 2], device);
    let rhs = Tensor::<CudaRuntime>::from_slice(&[3.0f32, 3.0], &[2], device);
    let solution = LinearAlgebraAlgorithms::solve(&client, &a, &rhs).unwrap();
    let host: Vec<f32> = solution.to_vec();
    for (i, got) in host.iter().enumerate() {
        assert!((got - 1.0).abs() < 1e-4, "x[{}] = {} expected 1", i, got);
    }
}
#[test]
fn test_inverse() {
    let Some(client) = create_client() else {
        return;
    };
    let device = client.device();
    // A = [[4, 7], [2, 6]], det = 10, so A^-1 = [[0.6, -0.7], [-0.2, 0.4]].
    let matrix = Tensor::<CudaRuntime>::from_slice(&[4.0f32, 7.0, 2.0, 6.0], &[2, 2], device);
    let inverse = LinearAlgebraAlgorithms::inverse(&client, &matrix).unwrap();
    let host: Vec<f32> = inverse.to_vec();
    let expected = [0.6f32, -0.7, -0.2, 0.4];
    for (i, (got, want)) in host.iter().zip(expected).enumerate() {
        assert!((got - want).abs() < 1e-4, "inv[{}] = {} expected {}", i, got, want);
    }
}
#[test]
fn test_inverse_identity() {
    let Some(client) = create_client() else {
        return;
    };
    let device = client.device();
    let matrix = Tensor::<CudaRuntime>::from_slice(&[4.0f32, 7.0, 2.0, 6.0], &[2, 2], device);
    let inverse = LinearAlgebraAlgorithms::inverse(&client, &matrix).unwrap();
    // A * A^-1 must reproduce the 2x2 identity within tolerance.
    let product = client.matmul(&matrix, &inverse).unwrap();
    let host: Vec<f32> = product.to_vec();
    let identity = [1.0f32, 0.0, 0.0, 1.0];
    for (i, (got, want)) in host.iter().zip(identity).enumerate() {
        assert!(
            (got - want).abs() < 1e-4,
            "(A * inv(A))[{}] = {} expected {}",
            i,
            got,
            want
        );
    }
}
#[test]
fn test_matrix_rank_full() {
    let Some(client) = create_client() else {
        return;
    };
    let device = client.device();
    // [[1, 2], [3, 4]] has nonzero determinant, hence full rank 2.
    let matrix = Tensor::<CudaRuntime>::from_slice(&[1.0f32, 2.0, 3.0, 4.0], &[2, 2], device);
    let rank = LinearAlgebraAlgorithms::matrix_rank(&client, &matrix, None).unwrap();
    let host: Vec<i64> = rank.to_vec();
    assert_eq!(host[0], 2);
}
#[test]
fn test_matrix_rank_deficient() {
    let Some(client) = create_client() else {
        return;
    };
    let device = client.device();
    // Second row is 2x the first, so the matrix has rank 1.
    let matrix = Tensor::<CudaRuntime>::from_slice(&[1.0f32, 2.0, 2.0, 4.0], &[2, 2], device);
    let rank = LinearAlgebraAlgorithms::matrix_rank(&client, &matrix, None).unwrap();
    let host: Vec<i64> = rank.to_vec();
    assert_eq!(host[0], 1);
}
#[test]
fn test_qr_decomposition() {
    let Some(client) = create_client() else {
        return;
    };
    let device = client.device();
    let matrix = Tensor::<CudaRuntime>::from_slice(&[1.0f32, 2.0, 3.0, 4.0], &[2, 2], device);
    let qr = client.qr_decompose(&matrix).unwrap();
    // Q and R are individually hard to pin down (sign conventions vary),
    // but their product must reconstruct A.
    let reconstructed = client.matmul(&qr.q, &qr.r).unwrap();
    let original: Vec<f32> = matrix.to_vec();
    let rebuilt: Vec<f32> = reconstructed.to_vec();
    for (i, (expected, got)) in original.iter().zip(&rebuilt).enumerate() {
        assert!(
            (expected - got).abs() < 1e-4,
            "Mismatch at {}: {} vs {}",
            i,
            expected,
            got
        );
    }
}
#[test]
fn test_solve_multi_rhs() {
    let Some(client) = create_client() else {
        return;
    };
    let device = client.device();
    // A = [[2, 1], [1, 2]] solved against two right-hand sides at once:
    // B = [[3, 4], [3, 5]]  =>  X = [[1, 1], [1, 2]] (row-major).
    let a = Tensor::<CudaRuntime>::from_slice(&[2.0f32, 1.0, 1.0, 2.0], &[2, 2], device);
    let b = Tensor::<CudaRuntime>::from_slice(&[3.0f32, 4.0, 3.0, 5.0], &[2, 2], device);
    let solution = LinearAlgebraAlgorithms::solve(&client, &a, &b).unwrap();
    assert_eq!(solution.shape(), &[2, 2]);
    let host: Vec<f32> = solution.to_vec();
    let expected = [1.0f32, 1.0, 1.0, 2.0];
    for (i, (got, want)) in host.iter().zip(expected).enumerate() {
        assert!(
            (got - want).abs() < 1e-4,
            "X[{},{}] = {} expected {}",
            i / 2,
            i % 2,
            got,
            want
        );
    }
}
#[test]
fn test_lstsq_overdetermined() {
    let Some(client) = create_client() else {
        return;
    };
    let device = client.device();
    // Fit y = c0 + c1 * t through the points (1,1), (2,2), (3,3):
    // exact fit with intercept c0 = 0 and slope c1 = 1.
    let design =
        Tensor::<CudaRuntime>::from_slice(&[1.0f32, 1.0, 1.0, 2.0, 1.0, 3.0], &[3, 2], device);
    let targets = Tensor::<CudaRuntime>::from_slice(&[1.0f32, 2.0, 3.0], &[3], device);
    let coeffs = LinearAlgebraAlgorithms::lstsq(&client, &design, &targets).unwrap();
    assert_eq!(coeffs.shape(), &[2]);
    let host: Vec<f32> = coeffs.to_vec();
    assert!(host[0].abs() < 0.1, "x[0] = {} expected ~0", host[0]);
    assert!((host[1] - 1.0).abs() < 0.1, "x[1] = {} expected ~1", host[1]);
}
#[test]
fn test_lstsq_multi_rhs() {
    let Some(client) = create_client() else {
        return;
    };
    let device = client.device();
    // Same design matrix as the single-RHS test, with two target columns:
    // column 0 is [1, 2, 3] (slope 1) and column 1 is [2, 4, 6] (slope 2),
    // both with zero intercept, so X ~ [[0, 0], [1, 2]] (row-major).
    let design =
        Tensor::<CudaRuntime>::from_slice(&[1.0f32, 1.0, 1.0, 2.0, 1.0, 3.0], &[3, 2], device);
    let targets =
        Tensor::<CudaRuntime>::from_slice(&[1.0f32, 2.0, 2.0, 4.0, 3.0, 6.0], &[3, 2], device);
    let coeffs = LinearAlgebraAlgorithms::lstsq(&client, &design, &targets).unwrap();
    assert_eq!(coeffs.shape(), &[2, 2]);
    let host: Vec<f32> = coeffs.to_vec();
    let expected = [0.0f32, 0.0, 1.0, 2.0];
    for (i, (got, want)) in host.iter().zip(expected).enumerate() {
        assert!(
            (got - want).abs() < 0.1,
            "X[{},{}] = {} expected ~{}",
            i / 2,
            i % 2,
            got,
            want
        );
    }
}