use super::super::CudaRuntime;
use super::super::client::CudaClient;
use super::super::kernels;
use crate::algorithm::linalg::{EigenDecomposition, validate_linalg_dtype, validate_square_matrix};
use crate::dtype::DType;
use crate::error::Result;
use crate::runtime::{AllocGuard, Allocator, Runtime, RuntimeClient};
use crate::tensor::Tensor;
/// Computes the eigendecomposition of a square matrix on a CUDA device using
/// the symmetric Jacobi kernel.
///
/// The input is validated with `validate_linalg_dtype` and
/// `validate_square_matrix`; `n` is the dimension returned by the latter.
/// Three cases are handled:
///
/// * `n == 0`: returns empty tensors of shape `[0]` and `[0, 0]`.
/// * `n == 1`: the eigenvalue is the single matrix entry and the eigenvector
///   is the scalar `1.0`.
/// * `n >= 2`: the matrix is copied into a scratch buffer, the
///   `launch_eig_jacobi_symmetric` kernel is run, and the results are then
///   reordered on-device by an argsort over the absolute eigenvalues.
///
/// NOTE(review): `a.ptr()` is copied as `n * n` densely packed elements, so
/// this presumably expects `a` to be contiguous — confirm against callers.
///
/// # Errors
///
/// Propagates any error returned by dtype/shape validation, device
/// allocation, device copies, or kernel launches.
///
/// # Panics
///
/// The `n == 1` path hits `unreachable!()` if the dtype is neither `F32` nor
/// `F64`; this relies on `validate_linalg_dtype` having rejected every other
/// dtype (assumption — confirm).
pub fn eig_decompose_symmetric_impl(
    client: &CudaClient,
    a: &Tensor<CudaRuntime>,
) -> Result<EigenDecomposition<CudaRuntime>> {
    validate_linalg_dtype(a.dtype())?;
    let n = validate_square_matrix(a.shape())?;
    let dtype = a.dtype();
    let device = client.device();

    // Empty matrix: nothing to compute, hand back zero-byte allocations.
    if n == 0 {
        let eigenvalues_ptr = client.allocator().allocate(0)?;
        let eigenvectors_ptr = client.allocator().allocate(0)?;
        // SAFETY: both pointers were just obtained from the client's allocator
        // and the shapes describe zero elements.
        let eigenvalues =
            unsafe { CudaClient::tensor_from_raw(eigenvalues_ptr, &[0], dtype, device) };
        let eigenvectors =
            unsafe { CudaClient::tensor_from_raw(eigenvectors_ptr, &[0, 0], dtype, device) };
        return Ok(EigenDecomposition {
            eigenvalues,
            eigenvectors,
        });
    }

    // 1x1 matrix: the eigenvalue is the entry itself, the eigenvector is [1].
    if n == 1 {
        let scalar_size = dtype.size_in_bytes();

        // Guard both buffers so that a failure in the second allocation or in
        // either copy below does not leak device memory on the `?` early return.
        let eigenvalues_guard = AllocGuard::new(client.allocator(), scalar_size)?;
        let eigenvectors_guard = AllocGuard::new(client.allocator(), scalar_size)?;

        CudaRuntime::copy_within_device(a.ptr(), eigenvalues_guard.ptr(), scalar_size, device)?;
        match dtype {
            DType::F32 => {
                let one: [u8; 4] = 1.0f32.to_ne_bytes();
                CudaRuntime::copy_to_device(&one, eigenvectors_guard.ptr(), device)?;
            }
            DType::F64 => {
                let one: [u8; 8] = 1.0f64.to_ne_bytes();
                CudaRuntime::copy_to_device(&one, eigenvectors_guard.ptr(), device)?;
            }
            _ => unreachable!(),
        }

        // SAFETY: each pointer refers to a live allocation of one element of
        // `dtype`; the guards are released so the buffers are handed to the
        // tensors rather than freed when the guards go out of scope.
        let eigenvalues = unsafe {
            CudaClient::tensor_from_raw(eigenvalues_guard.release(), &[1], dtype, device)
        };
        let eigenvectors = unsafe {
            CudaClient::tensor_from_raw(eigenvectors_guard.release(), &[1, 1], dtype, device)
        };
        return Ok(EigenDecomposition {
            eigenvalues,
            eigenvectors,
        });
    }

    // General case (n >= 2).
    let elem_size = dtype.size_in_bytes();
    let vector_size = n * elem_size;
    let matrix_size = n * n * elem_size;
    let flag_size = std::mem::size_of::<i32>();

    // Scratch copy of the input, unsorted outputs, and the convergence flag.
    let work_guard = AllocGuard::new(client.allocator(), matrix_size)?;
    let eigenvectors_guard = AllocGuard::new(client.allocator(), matrix_size)?;
    let eigenvalues_guard = AllocGuard::new(client.allocator(), vector_size)?;
    let converged_flag_guard = AllocGuard::new(client.allocator(), flag_size)?;
    let work_ptr = work_guard.ptr();
    let eigenvectors_ptr = eigenvectors_guard.ptr();
    let eigenvalues_ptr = eigenvalues_guard.ptr();
    let converged_flag_ptr = converged_flag_guard.ptr();

    // The kernel receives the scratch copy, so the caller's tensor is never
    // passed to it directly.
    CudaRuntime::copy_within_device(a.ptr(), work_ptr, matrix_size, device)?;

    // Start with the convergence flag cleared.
    let zero_i32: [u8; 4] = [0; 4];
    CudaRuntime::copy_to_device(&zero_i32, converged_flag_ptr, device)?;

    // SAFETY: every pointer comes from a live guard above, sized for `n`
    // (or `n * n`) elements of `dtype`, or for one `i32` in the flag's case.
    let jacobi_status = unsafe {
        kernels::launch_eig_jacobi_symmetric(
            client.context(),
            client.stream(),
            device.index,
            dtype,
            work_ptr,
            eigenvectors_ptr,
            eigenvalues_ptr,
            converged_flag_ptr,
            n,
        )
    };
    jacobi_status?;
    client.synchronize();
    // NOTE(review): the convergence flag is never read back in this function,
    // so a non-converged run is not reported here — confirm this is intended.

    // |eigenvalue| is used as the sort key.
    let abs_eigenvalues_guard = AllocGuard::new(client.allocator(), vector_size)?;
    let abs_eigenvalues_ptr = abs_eigenvalues_guard.ptr();
    // SAFETY: source and destination are live buffers of `n` elements of `dtype`.
    let abs_status = unsafe {
        kernels::launch_unary_op(
            client.context(),
            client.stream(),
            device.index,
            "abs",
            dtype,
            eigenvalues_ptr,
            abs_eigenvalues_ptr,
            n,
        )
    };
    abs_status?;

    // Permutation that orders the eigenvalues by magnitude, stored as i64.
    // NOTE(review): the trailing arguments are presumably
    // (outer, sort_len, inner, descending) — confirm against the kernel wrapper.
    let indices_size = n * std::mem::size_of::<i64>();
    let indices_guard = AllocGuard::new(client.allocator(), indices_size)?;
    let indices_ptr = indices_guard.ptr();
    // SAFETY: the key buffer holds `n` elements of `dtype` and the index
    // buffer has room for `n` i64 values.
    let argsort_status = unsafe {
        kernels::launch_argsort(
            client.context(),
            client.stream(),
            device.index,
            dtype,
            abs_eigenvalues_ptr,
            indices_ptr,
            1,
            n,
            1,
            true,
        )
    };
    argsort_status?;

    // Gather the eigenvalues through the permutation.
    let eigenvalues_sorted_guard = AllocGuard::new(client.allocator(), vector_size)?;
    let eigenvalues_sorted_ptr = eigenvalues_sorted_guard.ptr();
    // SAFETY: input and output hold `n` elements; the index buffer holds `n` indices.
    let eigenvalues_select_status = unsafe {
        kernels::launch_index_select(
            client.context(),
            client.stream(),
            device.index,
            dtype,
            eigenvalues_ptr,
            indices_ptr,
            eigenvalues_sorted_ptr,
            1,
            n,
            1,
            n,
        )
    };
    eigenvalues_select_status?;

    // Gather the eigenvectors through the same permutation.
    // NOTE(review): with (outer = n, dim = n, inner = 1) this presumably
    // selects along the last axis, i.e. reorders columns — confirm.
    let eigenvectors_sorted_guard = AllocGuard::new(client.allocator(), matrix_size)?;
    let eigenvectors_sorted_ptr = eigenvectors_sorted_guard.ptr();
    // SAFETY: input and output hold `n * n` elements; the index buffer holds
    // `n` indices.
    let eigenvectors_select_status = unsafe {
        kernels::launch_index_select(
            client.context(),
            client.stream(),
            device.index,
            dtype,
            eigenvectors_ptr,
            indices_ptr,
            eigenvectors_sorted_ptr,
            n,
            n,
            1,
            n,
        )
    };
    eigenvectors_select_status?;

    // SAFETY: the sorted buffers are live and sized for the shapes given; the
    // guards are released so the tensors receive the allocations. All other
    // guards remain in place and go out of scope when this function returns.
    let eigenvalues = unsafe {
        CudaClient::tensor_from_raw(eigenvalues_sorted_guard.release(), &[n], dtype, device)
    };
    let eigenvectors = unsafe {
        CudaClient::tensor_from_raw(eigenvectors_sorted_guard.release(), &[n, n], dtype, device)
    };
    Ok(EigenDecomposition {
        eigenvalues,
        eigenvectors,
    })
}