use std::sync::Arc;
use super::super::client::get_buffer;
use super::super::shaders::linalg as kernels;
use super::super::{WgpuClient, WgpuRuntime};
use super::helpers::get_buffer_or_err;
use crate::algorithm::linalg::{SvdDecomposition, validate_linalg_dtype, validate_matrix_2d};
use crate::dtype::DType;
use crate::error::{Error, Result};
use crate::ops::{
    BinaryOps, CompareOps, ConditionalOps, LinalgOps, MatmulOps, ReduceOps, ScalarOps, UnaryOps,
};
use crate::runtime::{AllocGuard, RuntimeClient};
use crate::tensor::Tensor;
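
/// Looks up the `wgpu::Buffer` backing a tensor's device pointer.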
fn get_tensor_buffer(tensor: &Tensor<WgpuRuntime>) -> Result<Arc<wgpu::Buffer>> {
    let ptr = tensor.ptr();
    get_buffer(ptr).ok_or_else(|| Error::Internal("Failed to get buffer from tensor".to_string()))
}
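
/// Computes the singular value decomposition `A = U * diag(S) * V^T` of a 2-D
/// F32 matrix on the GPU via the Jacobi SVD kernel. Wide inputs (`m < n`) are
/// transposed so the kernel always sees a tall working matrix, and the factors
/// are mapped back afterwards.
///
/// A minimal usage sketch (assumes a `WgpuClient` named `client` is already
/// set up; the tensor constructor mirrors the `Tensor::from_slice` calls used
/// elsewhere in this module):
///
/// ```ignore
/// let a = Tensor::<WgpuRuntime>::from_slice(&[3.0f32, 0.0, 0.0, 2.0], &[2, 2], client.device());
/// let svd = svd_decompose(&client, &a)?;
/// // svd.u: [2, k], svd.s: [k], svd.vt: [k, 2], with k = min(m, n) = 2.
/// ```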
pub fn svd_decompose(
    client: &WgpuClient,
    a: &Tensor<WgpuRuntime>,
) -> Result<SvdDecomposition<WgpuRuntime>> {
    validate_linalg_dtype(a.dtype())?;
    let (m, n) = validate_matrix_2d(a.shape())?;
    let dtype = a.dtype();
    let device = client.device();
    if dtype != DType::F32 {
        return Err(Error::UnsupportedDType {
            dtype,
            op: "WGPU svd_decompose (only F32 supported)",
        });
    }
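
    // Work on a tall matrix: when the input is wide (m < n), decompose A^T
    // instead and swap the resulting factors back at the end.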
    let transposed = m < n;
    let (work_m, work_n) = if transposed { (n, m) } else { (m, n) };
    let k = work_m.min(work_n);
    let work_tensor = if transposed {
        a.transpose(0, 1)?.contiguous()
    } else {
        a.contiguous()
    };
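
    // Allocate device buffers: the working matrix B (overwritten by the kernel
    // with the left singular vectors), the right singular vectors V, the
    // singular values S, and a convergence flag used by the kernel.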
    let b_size = work_m * work_n * dtype.size_in_bytes();
    let b_guard = AllocGuard::new(client.allocator(), b_size)?;
    let b_ptr = b_guard.ptr();
    let b_buffer = get_buffer_or_err!(b_ptr, "B (working matrix)");

    let v_size = work_n * work_n * dtype.size_in_bytes();
    let v_guard = AllocGuard::new(client.allocator(), v_size)?;
    let v_ptr = v_guard.ptr();
    let v_buffer = get_buffer_or_err!(v_ptr, "V (right singular vectors)");

    let s_size = work_n * dtype.size_in_bytes();
    let s_guard = AllocGuard::new(client.allocator(), s_size)?;
    let s_ptr = s_guard.ptr();
    let s_buffer = get_buffer_or_err!(s_ptr, "S (singular values)");

    let converged_flag_size = std::mem::size_of::<i32>();
    let converged_flag_guard = AllocGuard::new(client.allocator(), converged_flag_size)?;
    let converged_flag_ptr = converged_flag_guard.ptr();
    let converged_flag_buffer = get_buffer_or_err!(converged_flag_ptr, "SVD convergence flag");
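
    // Copy the (possibly transposed) input into the working buffer B.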
    let work_buffer = get_tensor_buffer(&work_tensor)?;
    let mut encoder = client
        .wgpu_device
        .create_command_encoder(&wgpu::CommandEncoderDescriptor {
            label: Some("svd_copy_input"),
        });
    encoder.copy_buffer_to_buffer(&work_buffer, 0, &b_buffer, 0, b_size as u64);
    client.queue.submit(std::iter::once(encoder.finish()));
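
    // Upload the working dimensions as an 8-byte uniform (two u32 values) and
    // launch the Jacobi SVD kernel.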
    let params: [u32; 2] = [work_m as u32, work_n as u32];
    let params_buffer = client.create_uniform_buffer("svd_params", 8);
    client.write_buffer(&params_buffer, &params);
    kernels::launch_svd_jacobi(
        client.pipeline_cache(),
        &client.queue,
        &b_buffer,
        &v_buffer,
        &s_buffer,
        &converged_flag_buffer,
        &params_buffer,
        dtype,
    )?;
    client.synchronize();
    drop(converged_flag_guard);
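
    // Wrap the raw allocations as tensors; releasing each guard hands the
    // allocation's ownership to the new tensor.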
    let b_tensor =
        unsafe { WgpuClient::tensor_from_raw(b_guard.release(), &[work_m, work_n], dtype, device) };
    let v_tensor =
        unsafe { WgpuClient::tensor_from_raw(v_guard.release(), &[work_n, work_n], dtype, device) };
    let s_tensor =
        unsafe { WgpuClient::tensor_from_raw(s_guard.release(), &[work_n], dtype, device) };
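
    // Map the factors back to the original orientation: if we decomposed
    // A^T = U' * S * V'^T, then A = V' * S * U'^T, so U comes from V' and
    // V^T from the transposed columns of B; otherwise U comes from B's first
    // k columns and V^T from V transposed.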
    let (u, vt) = if transposed {
        let u = v_tensor.narrow(0, 0, m)?.narrow(1, 0, k)?.contiguous();
        let vt = b_tensor.narrow(1, 0, k)?.transpose(0, 1)?.contiguous();
        (u, vt)
    } else {
        let u = b_tensor.narrow(1, 0, k)?.contiguous();
        let vt = v_tensor.transpose(0, 1)?.narrow(0, 0, k)?.contiguous();
        (u, vt)
    };
    let s = s_tensor.narrow(0, 0, k)?.contiguous();
    Ok(SvdDecomposition { u, s, vt })
}
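
/// Computes the Moore-Penrose pseudoinverse `A+ = V * diag(1/S) * U^T` from
/// the SVD of `a`, zeroing the reciprocals of singular values at or below
/// `rcond * max(S)`. When `rcond` is `None`, it defaults to
/// `max(m, n) * f32::EPSILON`.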
pub fn pinverse(
    client: &WgpuClient,
    a: &Tensor<WgpuRuntime>,
    rcond: Option<f64>,
) -> Result<Tensor<WgpuRuntime>> {
    validate_linalg_dtype(a.dtype())?;
    let (m, n) = validate_matrix_2d(a.shape())?;
    let dtype = a.dtype();
    let device = client.device();
    if dtype != DType::F32 {
        return Err(Error::UnsupportedDType {
            dtype,
            op: "pinverse",
        });
    }
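
    // The pseudoinverse of an empty matrix is the empty [n, m] matrix.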
    if m == 0 || n == 0 {
        let out_guard = AllocGuard::new(client.allocator(), 0)?;
        return Ok(unsafe {
            WgpuClient::tensor_from_raw(out_guard.release(), &[n, m], dtype, device)
        });
    }

    let svd = svd_decompose(client, a)?;
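
    // Build the cutoff rcond * max(S); singular values at or below it are
    // treated as zero rather than inverted.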
    let max_sv_tensor = client.max(&svd.s, &[0], false)?;
    let default_rcond = (m.max(n) as f64) * (f32::EPSILON as f64);
    let rcond_val = rcond.unwrap_or(default_rcond);
    let cutoff_tensor = client.mul_scalar(&max_sv_tensor, rcond_val)?;
    let mask = client.gt(&svd.s, &cutoff_tensor)?;
    let s_reciprocal = client.recip(&svd.s)?;
    let zero = Tensor::<WgpuRuntime>::from_slice(&[0.0f32], &[], device);
    let s_inv = client.where_cond(&mask, &s_reciprocal, &zero)?;
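
    // Assemble A+ = V * diag(1/S) * U^T.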
    let s_inv_mat = LinalgOps::diagflat(client, &s_inv)?;
    let v = svd.vt.transpose(0, 1)?.contiguous();
    let ut = svd.u.transpose(0, 1)?.contiguous();
    let v_sinv = client.matmul(&v, &s_inv_mat)?;
    let pinv = client.matmul(&v_sinv, &ut)?;
    Ok(pinv)
}
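
/// Computes the 2-norm condition number `cond(A) = max(S) / min(S)` from the
/// singular values of `a`. Empty matrices, and matrices whose smallest
/// singular value does not exceed `f32::EPSILON`, yield infinity.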
pub fn cond(client: &WgpuClient, a: &Tensor<WgpuRuntime>) -> Result<Tensor<WgpuRuntime>> {
    validate_linalg_dtype(a.dtype())?;
    let (m, n) = validate_matrix_2d(a.shape())?;
    let dtype = a.dtype();
    let device = client.device();
    if dtype != DType::F32 {
        return Err(Error::UnsupportedDType { dtype, op: "cond" });
    }
    if m == 0 || n == 0 {
        return Ok(Tensor::<WgpuRuntime>::from_slice(
            &[f32::INFINITY],
            &[],
            device,
        ));
    }
    let svd = svd_decompose(client, a)?;
    let max_sv_tensor = client.max(&svd.s, &[0], false)?;
    let min_sv_tensor = client.min(&svd.s, &[0], false)?;
    let ratio = client.div(&max_sv_tensor, &min_sv_tensor)?;
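
    // Treat min(S) at or below f32::EPSILON as numerically singular and report
    // an infinite condition number instead of the computed ratio.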
    let eps = Tensor::<WgpuRuntime>::from_slice(&[f32::EPSILON], &[], device);
    let infinity = Tensor::<WgpuRuntime>::from_slice(&[f32::INFINITY], &[], device);
    let valid_mask = client.gt(&min_sv_tensor, &eps)?;
    let cond_result = client.where_cond(&valid_mask, &ratio, &infinity)?;
    Ok(cond_result)
}