use super::super::CudaRuntime;
use super::super::client::CudaClient;
use super::super::kernels::linalg_launchers::{
compute_schur_func_gpu, launch_validate_eigenvalues,
};
use crate::algorithm::linalg::{
LinearAlgebraAlgorithms, matrix_functions_core, validate_linalg_dtype, validate_square_matrix,
};
use crate::dtype::DType;
use crate::error::{Error, Result};
use crate::ops::{BinaryOps, MatmulOps, ReduceOps, ScalarOps, UnaryOps, UtilityOps};
use crate::runtime::{AllocGuard, RuntimeClient};
use crate::tensor::Tensor;
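/// Raw device pointer of a tensor's underlying storage, for handing to
/// CUDA driver-API calls.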
fn get_tensor_ptr(tensor: &Tensor<CudaRuntime>) -> u64 {
tensor.ptr()
}
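/// Copies a single-element tensor back to the host and widens it to f64.
/// Used for host-side convergence tests and scalar sanity checks.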
fn read_scalar_f64(client: &CudaClient, tensor: &Tensor<CudaRuntime>) -> Result<f64> {
    if tensor.numel() != 1 {
        return Err(Error::InvalidArgument {
            arg: "tensor",
            reason: "read_scalar_f64 requires a single-element tensor".to_string(),
        });
    }
    let tensor = if tensor.is_contiguous() {
        tensor.clone()
    } else {
        tensor.contiguous()
    };
    // Ensure any kernel that produces this value has finished before the
    // synchronous copy below reads device memory.
    client
        .stream()
        .synchronize()
        .map_err(|e| Error::Internal(format!("CUDA stream synchronize failed: {:?}", e)))?;
    // Copy the scalar at its native width; copying size_of::<f64>() bytes out
    // of an F32 tensor would read 4 bytes past the element.
    let mut result: f64 = 0.0;
    unsafe {
        match tensor.dtype() {
            DType::F32 => {
                let mut v: f32 = 0.0;
                cudarc::driver::sys::cuMemcpyDtoH_v2(
                    &mut v as *mut f32 as *mut std::ffi::c_void,
                    get_tensor_ptr(&tensor),
                    std::mem::size_of::<f32>(),
                );
                result = v as f64;
            }
            _ => {
                cudarc::driver::sys::cuMemcpyDtoH_v2(
                    &mut result as *mut f64 as *mut std::ffi::c_void,
                    get_tensor_ptr(&tensor),
                    std::mem::size_of::<f64>(),
                );
            }
        }
    }
    Ok(result)
}
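/// Matrix exponential expm(A) via the real Schur decomposition A = Z T Zᵀ:
/// exp is applied to the quasi-triangular factor T by the GPU Schur-function
/// kernel, then the result is transformed back as Z f(T) Zᵀ.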
pub fn expm_impl(client: &CudaClient, a: &Tensor<CudaRuntime>) -> Result<Tensor<CudaRuntime>> {
validate_linalg_dtype(a.dtype())?;
let n = validate_square_matrix(a.shape())?;
let dtype = a.dtype();
let device = client.device();
if n == 0 {
return Ok(Tensor::<CudaRuntime>::zeros(&[0, 0], dtype, device));
}
if n == 1 {
return client.exp(a);
}
let schur = client.schur_decompose(a)?;
let f_t = Tensor::<CudaRuntime>::zeros(&[n, n], dtype, device);
let t_ptr = get_tensor_ptr(&schur.t);
let f_ptr = get_tensor_ptr(&f_t);
unsafe {
compute_schur_func_gpu(
client.context(),
client.stream(),
client.device().index,
dtype,
t_ptr,
f_ptr,
n,
"exp",
)?;
}
let temp = client.matmul(&schur.z, &f_t)?;
let z_t = schur.z.transpose(0, 1)?;
client.matmul(&temp, &z_t)
}
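/// Matrix logarithm logm(A) via the real Schur decomposition. The
/// eigenvalues of the Schur factor are validated on the GPU first, since
/// logm is undefined for matrices with non-positive real eigenvalues.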
pub fn logm_impl(client: &CudaClient, a: &Tensor<CudaRuntime>) -> Result<Tensor<CudaRuntime>> {
validate_linalg_dtype(a.dtype())?;
let n = validate_square_matrix(a.shape())?;
let dtype = a.dtype();
let device = client.device();
if n == 0 {
return Ok(Tensor::<CudaRuntime>::zeros(&[0, 0], dtype, device));
}
if n == 1 {
let log_result = client.log(a)?;
let log_scalar = read_scalar_f64(client, &log_result)?;
if log_scalar.is_nan() {
return Err(Error::InvalidArgument {
arg: "a",
reason: "logm requires matrix with no non-positive real eigenvalues".to_string(),
});
}
return Ok(log_result);
}
let eps = match dtype {
DType::F32 => f32::EPSILON as f64,
DType::F64 => f64::EPSILON,
_ => f64::EPSILON,
};
let schur = client.schur_decompose(a)?;
    // Two f64 slots on the device: [0] = violation flag, [1] = the offending
    // eigenvalue. The host writes and reads the pair as f64, so the buffer
    // must be sized for f64; sizing it by `dtype` would under-allocate for
    // F32 and the 16-byte copies below would overrun it.
    let result_guard = AllocGuard::new(&client.allocator, 2 * std::mem::size_of::<f64>())?;
    let result_buffer = result_guard.ptr();
    let zero_data: [f64; 2] = [0.0, 0.0];
unsafe {
cudarc::driver::sys::cuMemcpyHtoDAsync_v2(
result_buffer,
zero_data.as_ptr() as *const std::ffi::c_void,
2 * std::mem::size_of::<f64>(),
client.stream().cu_stream(),
);
}
unsafe {
launch_validate_eigenvalues(
client.context(),
client.stream(),
client.device().index,
dtype,
get_tensor_ptr(&schur.t),
result_buffer,
n,
eps,
"log",
)?;
}
client
.stream()
.synchronize()
.map_err(|e| Error::Internal(format!("CUDA stream synchronize failed: {:?}", e)))?;
let mut result_data: [f64; 2] = [0.0, 0.0];
unsafe {
cudarc::driver::sys::cuMemcpyDtoH_v2(
result_data.as_mut_ptr() as *mut std::ffi::c_void,
result_buffer,
2 * std::mem::size_of::<f64>(),
);
}
if result_data[0] > 0.5 {
return Err(Error::InvalidArgument {
arg: "a",
reason: format!(
"logm requires matrix with no non-positive real eigenvalues, found {}",
result_data[1]
),
});
}
let f_t = Tensor::<CudaRuntime>::zeros(&[n, n], dtype, device);
unsafe {
compute_schur_func_gpu(
client.context(),
client.stream(),
client.device().index,
dtype,
get_tensor_ptr(&schur.t),
get_tensor_ptr(&f_t),
n,
"log",
)?;
}
let temp = client.matmul(&schur.z, &f_t)?;
let z_t = schur.z.transpose(0, 1)?;
client.matmul(&temp, &z_t)
}
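/// Matrix square root sqrtm(A). Eigenvalues of the Schur factor are checked
/// on the GPU (negative real eigenvalues are rejected), then the root itself
/// is computed with the Denman-Beavers iteration.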
pub fn sqrtm_impl(client: &CudaClient, a: &Tensor<CudaRuntime>) -> Result<Tensor<CudaRuntime>> {
validate_linalg_dtype(a.dtype())?;
let n = validate_square_matrix(a.shape())?;
let dtype = a.dtype();
let device = client.device();
if n == 0 {
return Ok(Tensor::<CudaRuntime>::zeros(&[0, 0], dtype, device));
}
if n == 1 {
let sqrt_result = client.sqrt(a)?;
let sqrt_scalar = read_scalar_f64(client, &sqrt_result)?;
if sqrt_scalar.is_nan() {
return Err(Error::InvalidArgument {
arg: "a",
reason: "sqrtm requires matrix with no negative real eigenvalues".to_string(),
});
}
return Ok(sqrt_result);
}
let eps = match dtype {
DType::F32 => f32::EPSILON as f64,
DType::F64 => f64::EPSILON,
_ => f64::EPSILON,
};
let schur = client.schur_decompose(a)?;
    // Two f64 slots on the device: [0] = violation flag, [1] = the offending
    // eigenvalue. As in logm_impl, the host-side protocol is f64 pairs, so
    // size the buffer for f64 rather than `dtype` to avoid overrunning the
    // allocation for F32 inputs.
    let result_guard = AllocGuard::new(&client.allocator, 2 * std::mem::size_of::<f64>())?;
    let result_buffer = result_guard.ptr();
    let zero_data: [f64; 2] = [0.0, 0.0];
unsafe {
cudarc::driver::sys::cuMemcpyHtoDAsync_v2(
result_buffer,
zero_data.as_ptr() as *const std::ffi::c_void,
2 * std::mem::size_of::<f64>(),
client.stream().cu_stream(),
);
}
unsafe {
launch_validate_eigenvalues(
client.context(),
client.stream(),
client.device().index,
dtype,
get_tensor_ptr(&schur.t),
result_buffer,
n,
eps,
"sqrt",
)?;
}
client
.stream()
.synchronize()
.map_err(|e| Error::Internal(format!("CUDA stream synchronize failed: {:?}", e)))?;
let mut result_data: [f64; 2] = [0.0, 0.0];
unsafe {
cudarc::driver::sys::cuMemcpyDtoH_v2(
result_data.as_mut_ptr() as *mut std::ffi::c_void,
result_buffer,
2 * std::mem::size_of::<f64>(),
);
}
if result_data[0] > 0.5 {
return Err(Error::InvalidArgument {
arg: "a",
reason: format!(
"sqrtm requires matrix with no negative real eigenvalues, found {}",
result_data[1]
),
});
}
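    // Denman-Beavers iteration: starting from Y_0 = A and Z_0 = I,
    //   Y_{k+1} = (Y_k + Z_k^-1) / 2,   Z_{k+1} = (Z_k + Y_k^-1) / 2,
    // Y_k converges to A^(1/2) (and Z_k to A^(-1/2)). Convergence is
    // measured by the relative Frobenius-norm change in Y.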
let mut y = a.clone();
let mut z = client.eye(n, None, dtype)?;
let max_iter = 50;
let tol = eps * (n as f64);
for _iter in 0..max_iter {
        let y_inv = LinearAlgebraAlgorithms::inverse(client, &y).map_err(|_| {
            Error::Internal("sqrtm: matrix inversion failed during iteration".to_string())
        })?;
        let z_inv = LinearAlgebraAlgorithms::inverse(client, &z).map_err(|_| {
            Error::Internal("sqrtm: matrix inversion failed during iteration".to_string())
        })?;
let y_plus_zinv = client.add(&y, &z_inv)?;
let y_new = client.div_scalar(&y_plus_zinv, 2.0)?;
let z_plus_yinv = client.add(&z, &y_inv)?;
let z_new = client.div_scalar(&z_plus_yinv, 2.0)?;
let diff = client.sub(&y_new, &y)?;
let diff_sq = client.mul(&diff, &diff)?;
let diff_sum = client.sum(&diff_sq, &[], false)?;
let y_sq = client.mul(&y, &y)?;
let y_sum = client.sum(&y_sq, &[], false)?;
let diff_norm: f64 = {
let sum_val = read_scalar_f64(client, &diff_sum)?;
sum_val.sqrt()
};
let y_norm: f64 = {
let sum_val = read_scalar_f64(client, &y_sum)?;
sum_val.sqrt().max(1.0)
};
y = y_new;
z = z_new;
if diff_norm / y_norm < tol {
break;
}
}
Ok(y)
}
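/// Matrix sign function signm(A), computed as the limit of the Newton
/// iteration X_{k+1} = (X_k + X_k^-1) / 2 with X_0 = A. Undefined when A has
/// eigenvalues on the imaginary axis (zero eigenvalues in particular).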
pub fn signm_impl(client: &CudaClient, a: &Tensor<CudaRuntime>) -> Result<Tensor<CudaRuntime>> {
    validate_linalg_dtype(a.dtype())?;
    let n = validate_square_matrix(a.shape())?;
let dtype = a.dtype();
let device = client.device();
if n == 0 {
return Ok(Tensor::<CudaRuntime>::zeros(&[0, 0], dtype, device));
}
if n == 1 {
let abs_a = client.abs(a)?;
let abs_scalar = read_scalar_f64(client, &abs_a)?;
if abs_scalar < f64::EPSILON {
return Err(Error::InvalidArgument {
arg: "a",
reason: "signm requires matrix with no zero eigenvalues".to_string(),
});
}
let sign_result = client.div(a, &abs_a)?;
return Ok(sign_result);
}
let eps = match dtype {
DType::F32 => f32::EPSILON as f64,
DType::F64 => f64::EPSILON,
_ => f64::EPSILON,
};
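    // Newton iteration for the matrix sign function, stopping once the
    // Frobenius norm of the update falls below `tol`.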
let mut x = a.clone();
let max_iter = 100;
let tol = eps * (n as f64).sqrt();
for _iter in 0..max_iter {
        let x_inv = LinearAlgebraAlgorithms::inverse(client, &x).map_err(|_| {
            Error::Internal("signm: matrix became singular during iteration".to_string())
        })?;
let x_plus_inv = client.add(&x, &x_inv)?;
let x_new = client.div_scalar(&x_plus_inv, 2.0)?;
let diff = client.sub(&x_new, &x)?;
let diff_sq = client.mul(&diff, &diff)?;
let diff_sum = client.sum(&diff_sq, &[], false)?;
let diff_norm: f64 = {
let sum_val = read_scalar_f64(client, &diff_sum)?;
sum_val.sqrt()
};
x = x_new;
if diff_norm < tol {
break;
}
}
Ok(x)
}
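/// Fractional matrix power A^p. Cheap exact cases (p in {0, 1, -1, 1/2} and
/// small integers) are dispatched directly; everything else goes through
/// expm(p * logm(A)).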
pub fn fractional_matrix_power_impl(
client: &CudaClient,
a: &Tensor<CudaRuntime>,
p: f64,
) -> Result<Tensor<CudaRuntime>> {
    validate_linalg_dtype(a.dtype())?;
    let n = validate_square_matrix(a.shape())?;
let dtype = a.dtype();
let device = client.device();
if n == 0 {
return Ok(Tensor::<CudaRuntime>::zeros(&[0, 0], dtype, device));
}
    if p.abs() < f64::EPSILON {
        return client.eye(n, None, dtype);
    }
    if (p - 1.0).abs() < f64::EPSILON {
        return Ok(a.clone());
    }
    if (p + 1.0).abs() < f64::EPSILON {
        return LinearAlgebraAlgorithms::inverse(client, a);
    }
    if (p - 0.5).abs() < f64::EPSILON {
        return sqrtm_impl(client, a);
    }
    // Dispatch exact integer exponents to repeated squaring *before* the
    // log/exp fallback: logm fails on matrices with non-positive real
    // eigenvalues even though their integer powers are well defined. The
    // 1x1 case needs no branch of its own; logm_impl and expm_impl handle
    // it internally.
    if p.fract() == 0.0 && p.abs() < 100.0 {
        return integer_matrix_power_gpu(client, a, n, p as i64);
    }
let log_a = logm_impl(client, a)?;
let p_log_a = client.mul_scalar(&log_a, p)?;
expm_impl(client, &p_log_a)
}
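/// Integer matrix power by binary exponentiation (repeated squaring), using
/// O(log |p|) matmuls; a negative exponent inverts the matrix first.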
fn integer_matrix_power_gpu(
client: &CudaClient,
a: &Tensor<CudaRuntime>,
n: usize,
p: i64,
) -> Result<Tensor<CudaRuntime>> {
let dtype = a.dtype();
if p == 0 {
return client.eye(n, None, dtype);
}
let (mut base, mut exp) = if p < 0 {
let inv = LinearAlgebraAlgorithms::inverse(client, a)?;
(inv, (-p) as u64)
} else {
(a.clone(), p as u64)
};
let mut result = client.eye(n, None, dtype)?;
while exp > 0 {
if exp & 1 == 1 {
result = client.matmul(&result, &base)?;
}
base = client.matmul(&base, &base)?;
exp >>= 1;
}
Ok(result)
}
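/// General matrix function f(A) for a user-supplied scalar function: the
/// Schur factors are copied to the host, f(T) is evaluated in f64 on the
/// quasi-triangular factor, and the result is transformed back as Z f(T) Zᵀ.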
pub fn funm_impl<F>(
client: &CudaClient,
a: &Tensor<CudaRuntime>,
f: F,
) -> Result<Tensor<CudaRuntime>>
where
F: Fn(f64) -> f64,
{
    let n = validate_square_matrix(a.shape())?;
let dtype = a.dtype();
let device = client.device();
if n == 0 {
return Ok(Tensor::<CudaRuntime>::zeros(&[0, 0], dtype, device));
}
if n == 1 {
let data: Vec<f64> = a.to_vec();
let val = data[0];
let result = f(val);
if result.is_nan() || result.is_infinite() {
return Err(Error::InvalidArgument {
arg: "f",
reason: "function returned NaN or Inf for eigenvalue".to_string(),
});
}
return Ok(Tensor::<CudaRuntime>::full_scalar(
&[1, 1],
dtype,
result,
device,
));
}
let schur = client.schur_decompose(a)?;
let t_data: Vec<f64> = schur.t.to_vec();
let z_data: Vec<f64> = schur.z.to_vec();
let f_t = matrix_functions_core::funm_quasi_triangular_f64(&t_data, n, &f)?;
let f_t_tensor = Tensor::<CudaRuntime>::from_slice(&f_t, &[n, n], device);
let z_tensor = Tensor::<CudaRuntime>::from_slice(&z_data, &[n, n], device);
let temp = client.matmul(&z_tensor, &f_t_tensor)?;
let z_t = z_tensor.transpose(0, 1)?;
client.matmul(&temp, &z_t)
}