use super::super::CudaRuntime;
use super::super::client::CudaClient;
use super::super::kernels;
use crate::algorithm::linalg::{
ComplexSchurDecomposition, GeneralizedSchurDecomposition, LinearAlgebraAlgorithms,
PolarDecomposition, SchurDecomposition, validate_linalg_dtype, validate_square_matrix,
};
use crate::error::Result;
use crate::ops::{LinalgOps, MatmulOps};
use crate::runtime::{AllocGuard, Runtime, RuntimeClient};
use crate::tensor::Tensor;
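/// Converts a real Schur decomposition (orthogonal `Z`, quasi-triangular `T`)
/// into the complex Schur form, returning the unitary and upper-triangular
/// factors as separate real and imaginary parts.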
pub fn rsf2csf_impl(
client: &CudaClient,
schur: &SchurDecomposition<CudaRuntime>,
) -> Result<ComplexSchurDecomposition<CudaRuntime>> {
validate_linalg_dtype(schur.t.dtype())?;
let n = validate_square_matrix(schur.t.shape())?;
let dtype = schur.t.dtype();
let device = client.device();
if n == 0 {
return Ok(ComplexSchurDecomposition {
z_real: Tensor::<CudaRuntime>::zeros(&[0, 0], dtype, device),
z_imag: Tensor::<CudaRuntime>::zeros(&[0, 0], dtype, device),
t_real: Tensor::<CudaRuntime>::zeros(&[0, 0], dtype, device),
t_imag: Tensor::<CudaRuntime>::zeros(&[0, 0], dtype, device),
});
}
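    // Allocate device buffers for the real and imaginary parts of the output
    // Z and T factors; the guards own the buffers until `release()`.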
let matrix_size = n * n * dtype.size_in_bytes();
let z_real_guard = AllocGuard::new(client.allocator(), matrix_size)?;
let z_imag_guard = AllocGuard::new(client.allocator(), matrix_size)?;
let t_real_guard = AllocGuard::new(client.allocator(), matrix_size)?;
let t_imag_guard = AllocGuard::new(client.allocator(), matrix_size)?;
let z_real_ptr = z_real_guard.ptr();
let z_imag_ptr = z_imag_guard.ptr();
let t_real_ptr = t_real_guard.ptr();
let t_imag_ptr = t_imag_guard.ptr();
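    // Launch the RSF-to-CSF kernel, writing Z and T into the split
    // real/imaginary buffers.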
let result = unsafe {
kernels::launch_rsf2csf(
client.context(),
client.stream(),
device.index,
dtype,
schur.z.ptr(),
schur.t.ptr(),
z_real_ptr,
z_imag_ptr,
t_real_ptr,
t_imag_ptr,
n,
)
};
result?;
client.synchronize();
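    // Hand ownership of the device buffers over to the output tensors.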
let z_real =
unsafe { CudaClient::tensor_from_raw(z_real_guard.release(), &[n, n], dtype, device) };
let z_imag =
unsafe { CudaClient::tensor_from_raw(z_imag_guard.release(), &[n, n], dtype, device) };
let t_real =
unsafe { CudaClient::tensor_from_raw(t_real_guard.release(), &[n, n], dtype, device) };
let t_imag =
unsafe { CudaClient::tensor_from_raw(t_imag_guard.release(), &[n, n], dtype, device) };
Ok(ComplexSchurDecomposition {
z_real,
z_imag,
t_real,
t_imag,
})
}
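/// Computes the generalized (QZ) Schur decomposition of the matrix pair
/// `(a, b)`: orthogonal `Q` and `Z` with `A = Q·S·Zᵀ` and `B = Q·T·Zᵀ`, where
/// `S` is quasi-upper-triangular and `T` is upper-triangular, together with
/// the real and imaginary parts of the generalized eigenvalues.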
pub fn qz_decompose_impl(
client: &CudaClient,
a: &Tensor<CudaRuntime>,
b: &Tensor<CudaRuntime>,
) -> Result<GeneralizedSchurDecomposition<CudaRuntime>> {
    validate_linalg_dtype(a.dtype())?;
    validate_linalg_dtype(b.dtype())?;
    let n = validate_square_matrix(a.shape())?;
    // `b` must be square as well; the kernel assumes it has the same order as `a`.
    validate_square_matrix(b.shape())?;
let dtype = a.dtype();
let device = client.device();
if n == 0 {
return Ok(GeneralizedSchurDecomposition {
q: Tensor::<CudaRuntime>::zeros(&[0, 0], dtype, device),
z: Tensor::<CudaRuntime>::zeros(&[0, 0], dtype, device),
s: Tensor::<CudaRuntime>::zeros(&[0, 0], dtype, device),
t: Tensor::<CudaRuntime>::zeros(&[0, 0], dtype, device),
eigenvalues_real: Tensor::<CudaRuntime>::zeros(&[0], dtype, device),
eigenvalues_imag: Tensor::<CudaRuntime>::zeros(&[0], dtype, device),
});
}
let matrix_size = n * n * dtype.size_in_bytes();
let vector_size = n * dtype.size_in_bytes();
let flag_size = std::mem::size_of::<i32>();
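    // Device buffers for the Q, Z, S, T factors, the eigenvalue vectors, and
    // an i32 status flag that is zero-initialised before the launch.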
let q_guard = AllocGuard::new(client.allocator(), matrix_size)?;
let z_guard = AllocGuard::new(client.allocator(), matrix_size)?;
let s_guard = AllocGuard::new(client.allocator(), matrix_size)?;
let t_guard = AllocGuard::new(client.allocator(), matrix_size)?;
let eig_real_guard = AllocGuard::new(client.allocator(), vector_size)?;
let eig_imag_guard = AllocGuard::new(client.allocator(), vector_size)?;
let flag_guard = AllocGuard::new(client.allocator(), flag_size)?;
let q_ptr = q_guard.ptr();
let z_ptr = z_guard.ptr();
let s_ptr = s_guard.ptr();
let t_ptr = t_guard.ptr();
let eig_real_ptr = eig_real_guard.ptr();
let eig_imag_ptr = eig_imag_guard.ptr();
let flag_ptr = flag_guard.ptr();
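    // Seed S and T with copies of A and B (the kernel reduces them in place)
    // and zero the status flag.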
CudaRuntime::copy_within_device(a.ptr(), s_ptr, matrix_size, device)?;
CudaRuntime::copy_within_device(b.ptr(), t_ptr, matrix_size, device)?;
let zero_flag = [0i32];
CudaRuntime::copy_to_device(bytemuck::cast_slice(&zero_flag), flag_ptr, device)?;
let result = unsafe {
kernels::launch_qz_decompose(
client.context(),
client.stream(),
device.index,
dtype,
s_ptr,
t_ptr,
q_ptr,
z_ptr,
eig_real_ptr,
eig_imag_ptr,
flag_ptr,
n,
)
};
result?;
client.synchronize();
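    // Wrap the finished device buffers into the output tensors.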
let q = unsafe { CudaClient::tensor_from_raw(q_guard.release(), &[n, n], dtype, device) };
let z = unsafe { CudaClient::tensor_from_raw(z_guard.release(), &[n, n], dtype, device) };
let s = unsafe { CudaClient::tensor_from_raw(s_guard.release(), &[n, n], dtype, device) };
let t = unsafe { CudaClient::tensor_from_raw(t_guard.release(), &[n, n], dtype, device) };
let eigenvalues_real =
unsafe { CudaClient::tensor_from_raw(eig_real_guard.release(), &[n], dtype, device) };
let eigenvalues_imag =
unsafe { CudaClient::tensor_from_raw(eig_imag_guard.release(), &[n], dtype, device) };
Ok(GeneralizedSchurDecomposition {
q,
z,
s,
t,
eigenvalues_real,
eigenvalues_imag,
})
}
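/// Computes the polar decomposition `A = U·P` of a square matrix, where `U`
/// is orthogonal and `P` is symmetric positive semidefinite, by way of the
/// singular value decomposition.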
pub fn polar_decompose_impl(
client: &CudaClient,
a: &Tensor<CudaRuntime>,
) -> Result<PolarDecomposition<CudaRuntime>> {
validate_linalg_dtype(a.dtype())?;
let n = validate_square_matrix(a.shape())?;
let dtype = a.dtype();
let device = client.device();
if n == 0 {
return Ok(PolarDecomposition {
u: Tensor::<CudaRuntime>::zeros(&[0, 0], dtype, device),
p: Tensor::<CudaRuntime>::zeros(&[0, 0], dtype, device),
});
}
    // Polar decomposition via the SVD: with A = U_svd · Σ · Vᵀ, the orthogonal
    // factor is U = U_svd · Vᵀ and the symmetric factor is P = V · Σ · Vᵀ.
    let svd = client.svd_decompose(a)?;
    let v = svd.vt.transpose(0, 1)?.contiguous();
    let u_polar = client.matmul(&svd.u, &svd.vt)?;
    let s_diag = LinalgOps::diagflat(client, &svd.s)?;
    let temp = client.matmul(&v, &s_diag)?;
    let p = client.matmul(&temp, &svd.vt)?;
Ok(PolarDecomposition { u: u_polar, p })
}