use super::super::CudaRuntime;
use super::super::client::CudaClient;
use super::super::kernels;
use crate::algorithm::linalg::{SchurDecomposition, validate_linalg_dtype, validate_square_matrix};
use crate::error::Result;
use crate::runtime::{AllocGuard, Allocator, Runtime, RuntimeClient};
use crate::tensor::Tensor;
/// Computes a Schur decomposition of the square matrix `a` on the GPU,
/// returning the factor pair `{ z, t }` as device tensors.
///
/// # Errors
/// Returns an error if the dtype is not supported for linalg, if `a` is not
/// square, if any device allocation or copy fails, or if the kernel launch
/// itself reports failure.
pub fn schur_decompose_impl(
    client: &CudaClient,
    a: &Tensor<CudaRuntime>,
) -> Result<SchurDecomposition<CudaRuntime>> {
    validate_linalg_dtype(a.dtype())?;
    let order = validate_square_matrix(a.shape())?;
    let dtype = a.dtype();
    let device = client.device();

    // Degenerate 0x0 input: hand back two empty tensors without launching anything.
    if order == 0 {
        let empty_z = client.allocator().allocate(0)?;
        let empty_t = client.allocator().allocate(0)?;
        let z = unsafe { CudaClient::tensor_from_raw(empty_z, &[0, 0], dtype, device) };
        let t = unsafe { CudaClient::tensor_from_raw(empty_t, &[0, 0], dtype, device) };
        return Ok(SchurDecomposition { z, t });
    }

    let bytes = order * order * dtype.size_in_bytes();
    // Guarded allocations: each buffer is freed automatically if any `?` below
    // returns early, and kept only once explicitly `release()`d at the end.
    let t_alloc = AllocGuard::new(client.allocator(), bytes)?;
    let z_alloc = AllocGuard::new(client.allocator(), bytes)?;
    let flag_alloc = AllocGuard::new(client.allocator(), std::mem::size_of::<i32>())?;

    // Seed the T buffer with a copy of A (presumably transformed in place by the
    // kernel — TODO confirm) and clear the i32 status flag to zero.
    CudaRuntime::copy_within_device(a.ptr(), t_alloc.ptr(), bytes, device)?;
    let cleared = [0i32];
    CudaRuntime::copy_to_device(bytemuck::cast_slice(&cleared), flag_alloc.ptr(), device)?;

    // SAFETY: all three pointers reference live device allocations of the sizes
    // the kernel expects (two n*n matrices and one i32 flag), sized from `order`.
    unsafe {
        kernels::launch_schur_decompose(
            client.context(),
            client.stream(),
            device.index,
            dtype,
            t_alloc.ptr(),
            z_alloc.ptr(),
            flag_alloc.ptr(),
            order,
        )
    }?;
    client.synchronize();

    // NOTE(review): the status flag is zeroed and handed to the kernel but never
    // copied back or inspected here; confirm whether a nonzero flag (e.g. a
    // convergence failure) should surface as an error instead of being ignored.

    // Hand ownership of the matrix buffers to the output tensors; only the flag
    // guard still frees its allocation when it drops.
    let z = unsafe { CudaClient::tensor_from_raw(z_alloc.release(), &[order, order], dtype, device) };
    let t = unsafe { CudaClient::tensor_from_raw(t_alloc.release(), &[order, order], dtype, device) };
    Ok(SchurDecomposition { z, t })
}