use super::super::client::get_buffer;
use super::super::shaders::linalg as kernels;
use super::super::{WgpuClient, WgpuRuntime};
use super::helpers::get_buffer_or_err;
use crate::algorithm::linalg::{
ComplexSchurDecomposition, GeneralizedSchurDecomposition, LinearAlgebraAlgorithms,
PolarDecomposition, SchurDecomposition, validate_linalg_dtype, validate_square_matrix,
};
use crate::dtype::DType;
use crate::error::{Error, Result};
use crate::ops::{LinalgOps, MatmulOps};
use crate::runtime::{AllocGuard, Runtime, RuntimeClient};
use crate::tensor::Tensor;
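/// Converts a real Schur decomposition `A = Z T Zᵀ` (where `T` is
/// quasi-upper-triangular with 2x2 blocks for complex-conjugate
/// eigenvalue pairs) into a complex Schur decomposition `A = Z T Zᴴ`
/// with `T` upper-triangular, returning each complex factor split into
/// real and imaginary parts. WebGPU path; F32 only.
///
/// A minimal usage sketch, assuming a real `SchurDecomposition` named
/// `schur` is already in scope:
/// ```ignore
/// let csf = rsf2csf(&client, &schur)?;
/// // csf.t_real + i*csf.t_imag is upper-triangular; its diagonal
/// // holds the eigenvalues of the original matrix.
/// ```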
pub fn rsf2csf(
client: &WgpuClient,
schur: &SchurDecomposition<WgpuRuntime>,
) -> Result<ComplexSchurDecomposition<WgpuRuntime>> {
let n = validate_square_matrix(schur.t.shape())?;
let dtype = schur.t.dtype();
let device = client.device();
if dtype != DType::F32 {
return Err(Error::UnsupportedDType {
dtype,
op: "rsf2csf (WebGPU)",
});
}
if n == 0 {
return Ok(ComplexSchurDecomposition {
z_real: Tensor::<WgpuRuntime>::from_slice::<f32>(&[], &[0, 0], device),
z_imag: Tensor::<WgpuRuntime>::from_slice::<f32>(&[], &[0, 0], device),
t_real: Tensor::<WgpuRuntime>::from_slice::<f32>(&[], &[0, 0], device),
t_imag: Tensor::<WgpuRuntime>::from_slice::<f32>(&[], &[0, 0], device),
});
}
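    // 1x1 fast path: a scalar real Schur form is already complex Schur
    // form, so copy T and Z into the real parts and pair them with zero
    // imaginary parts.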
if n == 1 {
let elem = dtype.size_in_bytes();
let t_real_guard = AllocGuard::new(client.allocator(), elem)?;
let t_real_ptr = t_real_guard.ptr();
WgpuRuntime::copy_within_device(schur.t.ptr(), t_real_ptr, elem, device)?;
let z_real_guard = AllocGuard::new(client.allocator(), elem)?;
let z_real_ptr = z_real_guard.ptr();
WgpuRuntime::copy_within_device(schur.z.ptr(), z_real_ptr, elem, device)?;
return Ok(ComplexSchurDecomposition {
z_real: unsafe {
WgpuClient::tensor_from_raw(z_real_guard.release(), &[1, 1], dtype, device)
},
z_imag: Tensor::<WgpuRuntime>::from_slice(&[0.0f32], &[1, 1], device),
t_real: unsafe {
WgpuClient::tensor_from_raw(t_real_guard.release(), &[1, 1], dtype, device)
},
t_imag: Tensor::<WgpuRuntime>::from_slice(&[0.0f32], &[1, 1], device),
});
}
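    // General case: allocate device buffers for the four complex-Schur output factors.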
let matrix_size = n * n * dtype.size_in_bytes();
let t_real_guard = AllocGuard::new(client.allocator(), matrix_size)?;
let t_real_ptr = t_real_guard.ptr();
let t_real_buffer = get_buffer_or_err!(t_real_ptr, "T_real");
let t_imag_guard = AllocGuard::new(client.allocator(), matrix_size)?;
let t_imag_ptr = t_imag_guard.ptr();
let t_imag_buffer = get_buffer_or_err!(t_imag_ptr, "T_imag");
let z_real_guard = AllocGuard::new(client.allocator(), matrix_size)?;
let z_real_ptr = z_real_guard.ptr();
let z_real_buffer = get_buffer_or_err!(z_real_ptr, "Z_real");
let z_imag_guard = AllocGuard::new(client.allocator(), matrix_size)?;
let z_imag_ptr = z_imag_guard.ptr();
let z_imag_buffer = get_buffer_or_err!(z_imag_ptr, "Z_imag");
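    // Seed the real parts with the input Schur factors and zero-fill the imaginary parts.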
WgpuRuntime::copy_within_device(schur.t.ptr(), t_real_ptr, matrix_size, device)?;
WgpuRuntime::copy_within_device(schur.z.ptr(), z_real_ptr, matrix_size, device)?;
let zeros = vec![0.0f32; n * n];
WgpuRuntime::copy_to_device(bytemuck::cast_slice(&zeros), t_imag_ptr, device)?;
WgpuRuntime::copy_to_device(bytemuck::cast_slice(&zeros), z_imag_ptr, device)?;
let params: [u32; 1] = [n as u32];
let params_buffer = client.create_uniform_buffer("rsf2csf_params", 4);
    client.write_buffer(&params_buffer, &params);
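    // The kernel is expected to perform the standard rsf2csf sweep:
    // complex Givens rotations that triangularize each 2x2 diagonal block
    // of T in place while accumulating the same rotations into Z.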
kernels::launch_rsf2csf(
client.pipeline_cache(),
&client.queue,
&t_real_buffer,
&t_imag_buffer,
&z_real_buffer,
&z_imag_buffer,
        &params_buffer,
dtype,
)?;
client.synchronize();
let z_real =
unsafe { WgpuClient::tensor_from_raw(z_real_guard.release(), &[n, n], dtype, device) };
let z_imag =
unsafe { WgpuClient::tensor_from_raw(z_imag_guard.release(), &[n, n], dtype, device) };
let t_real =
unsafe { WgpuClient::tensor_from_raw(t_real_guard.release(), &[n, n], dtype, device) };
let t_imag =
unsafe { WgpuClient::tensor_from_raw(t_imag_guard.release(), &[n, n], dtype, device) };
Ok(ComplexSchurDecomposition {
z_real,
z_imag,
t_real,
t_imag,
})
}
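/// Computes the generalized (QZ) Schur decomposition of the pencil
/// `(A, B)`: orthogonal `Q` and `Z` with `A = Q S Zᵀ` and `B = Q T Zᵀ`,
/// where `S` is quasi-upper-triangular and `T` is upper-triangular. The
/// generalized eigenvalues λ solving `det(A - λB) = 0` are returned
/// split into real and imaginary parts. WebGPU path; F32 only.
///
/// A minimal usage sketch, assuming square F32 tensors `a` and `b` of
/// the same size are in scope:
/// ```ignore
/// let qz = qz_decompose(&client, &a, &b)?;
/// // q * s * z^T ≈ a  and  q * t * z^T ≈ b
/// ```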
pub fn qz_decompose(
client: &WgpuClient,
a: &Tensor<WgpuRuntime>,
b: &Tensor<WgpuRuntime>,
) -> Result<GeneralizedSchurDecomposition<WgpuRuntime>> {
validate_linalg_dtype(a.dtype())?;
let n = validate_square_matrix(a.shape())?;
let n_b = validate_square_matrix(b.shape())?;
if n != n_b {
return Err(Error::ShapeMismatch {
expected: a.shape().to_vec(),
got: b.shape().to_vec(),
});
}
let dtype = a.dtype();
let device = client.device();
if dtype != DType::F32 {
return Err(Error::UnsupportedDType {
dtype,
op: "qz_decompose (WebGPU)",
});
}
if n == 0 {
return Ok(GeneralizedSchurDecomposition {
q: Tensor::<WgpuRuntime>::from_slice::<f32>(&[], &[0, 0], device),
z: Tensor::<WgpuRuntime>::from_slice::<f32>(&[], &[0, 0], device),
s: Tensor::<WgpuRuntime>::from_slice::<f32>(&[], &[0, 0], device),
t: Tensor::<WgpuRuntime>::from_slice::<f32>(&[], &[0, 0], device),
eigenvalues_real: Tensor::<WgpuRuntime>::from_slice::<f32>(&[], &[0], device),
eigenvalues_imag: Tensor::<WgpuRuntime>::from_slice::<f32>(&[], &[0], device),
});
}
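    // 1x1 fast path: Q = Z = [1], S = A, T = B, and the single generalized
    // eigenvalue is a / b (a zero b produces an inf/NaN ratio in f32).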
if n == 1 {
let elem = dtype.size_in_bytes();
let s_guard = AllocGuard::new(client.allocator(), elem)?;
let s_ptr = s_guard.ptr();
WgpuRuntime::copy_within_device(a.ptr(), s_ptr, elem, device)?;
let t_guard = AllocGuard::new(client.allocator(), elem)?;
let t_ptr = t_guard.ptr();
WgpuRuntime::copy_within_device(b.ptr(), t_ptr, elem, device)?;
let s_tensor =
unsafe { WgpuClient::tensor_from_raw(s_guard.release(), &[1], dtype, device) };
let t_tensor =
unsafe { WgpuClient::tensor_from_raw(t_guard.release(), &[1], dtype, device) };
use crate::ops::BinaryOps;
let eigenvalues_real = client.div(&s_tensor, &t_tensor)?;
let s_out = s_tensor.reshape(&[1, 1])?;
let t_out = t_tensor.reshape(&[1, 1])?;
return Ok(GeneralizedSchurDecomposition {
q: Tensor::<WgpuRuntime>::from_slice(&[1.0f32], &[1, 1], device),
z: Tensor::<WgpuRuntime>::from_slice(&[1.0f32], &[1, 1], device),
s: s_out,
t: t_out,
eigenvalues_real,
eigenvalues_imag: Tensor::<WgpuRuntime>::from_slice(&[0.0f32], &[1], device),
});
}
let matrix_size = n * n * dtype.size_in_bytes();
let vector_size = n * dtype.size_in_bytes();
let s_guard = AllocGuard::new(client.allocator(), matrix_size)?;
let s_ptr = s_guard.ptr();
let s_buffer = get_buffer_or_err!(s_ptr, "S");
let t_guard = AllocGuard::new(client.allocator(), matrix_size)?;
let t_ptr = t_guard.ptr();
let t_buffer = get_buffer_or_err!(t_ptr, "T");
let q_guard = AllocGuard::new(client.allocator(), matrix_size)?;
let q_ptr = q_guard.ptr();
let q_buffer = get_buffer_or_err!(q_ptr, "Q");
let z_guard = AllocGuard::new(client.allocator(), matrix_size)?;
let z_ptr = z_guard.ptr();
let z_buffer = get_buffer_or_err!(z_ptr, "Z");
let eval_real_guard = AllocGuard::new(client.allocator(), vector_size)?;
let eval_real_ptr = eval_real_guard.ptr();
let eval_real_buffer = get_buffer_or_err!(eval_real_ptr, "eigenvalues_real");
let eval_imag_guard = AllocGuard::new(client.allocator(), vector_size)?;
let eval_imag_ptr = eval_imag_guard.ptr();
let eval_imag_buffer = get_buffer_or_err!(eval_imag_ptr, "eigenvalues_imag");
let converged_flag_size = std::mem::size_of::<i32>();
let converged_flag_guard = AllocGuard::new(client.allocator(), converged_flag_size)?;
let converged_flag_ptr = converged_flag_guard.ptr();
let converged_flag_buffer = get_buffer_or_err!(converged_flag_ptr, "QZ convergence flag");
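    // Initialize S and T as copies of A and B; the kernel reduces them in place.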
WgpuRuntime::copy_within_device(a.ptr(), s_ptr, matrix_size, device)?;
WgpuRuntime::copy_within_device(b.ptr(), t_ptr, matrix_size, device)?;
let zero_i32: [i32; 1] = [0];
client.write_buffer(&converged_flag_buffer, &zero_i32);
let params: [u32; 1] = [n as u32];
let params_buffer = client.create_uniform_buffer("qz_params", 4);
    client.write_buffer(&params_buffer, &params);
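    // launch_qz_decompose is expected to run the QZ iteration on device,
    // accumulating Q and Z and writing out the generalized eigenvalues and
    // a convergence flag.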
kernels::launch_qz_decompose(
client.pipeline_cache(),
&client.queue,
&s_buffer,
&t_buffer,
&q_buffer,
&z_buffer,
&eval_real_buffer,
&eval_imag_buffer,
&converged_flag_buffer,
        &params_buffer,
dtype,
)?;
client.synchronize();
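    // NOTE: the convergence flag is written by the kernel but never read
    // back here, so QZ non-convergence is not surfaced to the caller.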
drop(converged_flag_guard);
let q = unsafe { WgpuClient::tensor_from_raw(q_guard.release(), &[n, n], dtype, device) };
let z = unsafe { WgpuClient::tensor_from_raw(z_guard.release(), &[n, n], dtype, device) };
let s = unsafe { WgpuClient::tensor_from_raw(s_guard.release(), &[n, n], dtype, device) };
let t = unsafe { WgpuClient::tensor_from_raw(t_guard.release(), &[n, n], dtype, device) };
let eigenvalues_real =
unsafe { WgpuClient::tensor_from_raw(eval_real_guard.release(), &[n], dtype, device) };
let eigenvalues_imag =
unsafe { WgpuClient::tensor_from_raw(eval_imag_guard.release(), &[n], dtype, device) };
Ok(GeneralizedSchurDecomposition {
q,
z,
s,
t,
eigenvalues_real,
eigenvalues_imag,
})
}
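/// Computes the polar decomposition `A = U P` via the SVD: writing
/// `A = W Σ Vᵀ`, the orthogonal factor is `U = W Vᵀ` and the symmetric
/// positive semi-definite factor is `P = V Σ Vᵀ`. WebGPU path; F32 only.
///
/// A minimal usage sketch:
/// ```ignore
/// let a = Tensor::<WgpuRuntime>::from_slice(
///     &[2.0f32, 1.0, 0.0, 3.0], &[2, 2], client.device());
/// let polar = polar_decompose(&client, &a)?;
/// // polar.u is orthogonal, polar.p is symmetric PSD, u * p ≈ a
/// ```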
pub fn polar_decompose(
client: &WgpuClient,
a: &Tensor<WgpuRuntime>,
) -> Result<PolarDecomposition<WgpuRuntime>> {
validate_linalg_dtype(a.dtype())?;
let n = validate_square_matrix(a.shape())?;
let dtype = a.dtype();
let device = client.device();
if dtype != DType::F32 {
return Err(Error::UnsupportedDType {
dtype,
op: "polar_decompose (WebGPU)",
});
}
if n == 0 {
return Ok(PolarDecomposition {
u: Tensor::<WgpuRuntime>::from_slice::<f32>(&[], &[0, 0], device),
p: Tensor::<WgpuRuntime>::from_slice::<f32>(&[], &[0, 0], device),
});
}
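    // 1x1 fast path: U = sign(a) and P = |a|. Note that a zero input
    // yields NaN, since the sign is computed as a / |a|.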
if n == 1 {
use crate::ops::{BinaryOps, UnaryOps};
let a_abs = client.abs(a)?;
let sign = client.div(a, &a_abs)?;
let u = sign.reshape(&[1, 1])?;
let p = a_abs.reshape(&[1, 1])?;
return Ok(PolarDecomposition { u, p });
}
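    // General case via SVD: A = W Σ Vᵀ gives U = W Vᵀ and P = V Σ Vᵀ.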
let svd = client.svd_decompose(a)?;
let u_polar = client.matmul(&svd.u, &svd.vt)?;
let s_diag = LinalgOps::diagflat(client, &svd.s)?;
let v = svd.vt.transpose(0, 1)?.contiguous();
let v_s = client.matmul(&v, &s_diag)?;
let p = client.matmul(&v_s, &svd.vt)?;
Ok(PolarDecomposition { u: u_polar, p })
}