use wgpu::CommandEncoderDescriptor;
use super::super::client::get_buffer;
use super::super::shaders::linalg as kernels;
use super::super::{WgpuClient, WgpuRuntime};
use crate::algorithm::linalg::{validate_linalg_dtype, validate_square_matrix};
use crate::dtype::DType;
use crate::error::{Error, Result};
use crate::runtime::{AllocGuard, RuntimeClient};
use crate::tensor::Tensor;
pub use super::lstsq::lstsq;
pub use super::triangular_solve::{solve_triangular_lower, solve_triangular_upper};
pub fn solve(
client: &WgpuClient,
a: &Tensor<WgpuRuntime>,
b: &Tensor<WgpuRuntime>,
) -> Result<Tensor<WgpuRuntime>> {
validate_linalg_dtype(a.dtype())?;
if a.dtype() != b.dtype() {
return Err(Error::DTypeMismatch {
lhs: a.dtype(),
rhs: b.dtype(),
});
}
let n = validate_square_matrix(a.shape())?;
let dtype = a.dtype();
let device = client.device();
if dtype != DType::F32 {
return Err(Error::UnsupportedDType {
dtype,
op: "WGPU solve (only F32 supported)",
});
}
let b_shape = b.shape();
let (b_rows, num_rhs) = if b_shape.len() == 1 {
(b_shape[0], 1)
} else if b_shape.len() == 2 {
(b_shape[0], b_shape[1])
} else {
return Err(Error::ShapeMismatch {
expected: vec![n],
got: b_shape.to_vec(),
});
};
if b_rows != n {
return Err(Error::ShapeMismatch {
expected: vec![n],
got: vec![b_rows],
});
}
use super::decompositions::lu_decompose;
let lu_result = lu_decompose(client, a)?;
let lu_buffer = get_buffer(lu_result.lu.ptr())
.ok_or_else(|| Error::Internal("Failed to get lu buffer".to_string()))?;
let pivots_buffer = get_buffer(lu_result.pivots.ptr())
.ok_or_else(|| Error::Internal("Failed to get pivots buffer".to_string()))?;
let col_size = n * dtype.size_in_bytes();
let pb_guard = AllocGuard::new(client.allocator(), col_size)?;
let pb_ptr = pb_guard.ptr();
let pb_buffer =
get_buffer(pb_ptr).ok_or_else(|| Error::Internal("Failed to get pb buffer".to_string()))?;
let y_guard = AllocGuard::new(client.allocator(), col_size)?;
let y_ptr = y_guard.ptr();
let y_buffer =
get_buffer(y_ptr).ok_or_else(|| Error::Internal("Failed to get y buffer".to_string()))?;
let col_guard = AllocGuard::new(client.allocator(), col_size)?;
let col_ptr = col_guard.ptr();
let col_buffer = get_buffer(col_ptr)
.ok_or_else(|| Error::Internal("Failed to get col buffer".to_string()))?;
let b_contig = b.contiguous();
let b_buffer = get_buffer(b_contig.ptr())
.ok_or_else(|| Error::Internal("Failed to get b buffer".to_string()))?;
let x_total_size = n * num_rhs * dtype.size_in_bytes();
let x_out_guard = AllocGuard::new(client.allocator(), x_total_size)?;
let x_out_ptr = x_out_guard.ptr();
let x_out_buffer = get_buffer(x_out_ptr)
.ok_or_else(|| Error::Internal("Failed to get x_out buffer".to_string()))?;
for rhs in 0..num_rhs {
if num_rhs == 1 {
let copy_params: [u32; 4] = [n as u32, 0, 0, 0];
let copy_params_buffer = client.create_uniform_buffer("copy_params", 16);
client.write_buffer(©_params_buffer, ©_params);
kernels::launch_matrix_copy(
client.pipeline_cache(),
&client.queue,
&b_buffer,
&col_buffer,
©_params_buffer,
n,
dtype,
)?;
} else {
let extract_params: [u32; 4] = [n as u32, num_rhs as u32, rhs as u32, 0];
let extract_params_buffer = client.create_uniform_buffer("extract_params", 16);
client.write_buffer(&extract_params_buffer, &extract_params);
kernels::launch_extract_column(
client.pipeline_cache(),
&client.queue,
&b_buffer,
&col_buffer,
&extract_params_buffer,
n,
dtype,
)?;
}
let perm_params: [u32; 4] = [n as u32, 0, 0, 0];
let perm_params_buffer = client.create_uniform_buffer("perm_params", 16);
client.write_buffer(&perm_params_buffer, &perm_params);
kernels::launch_apply_lu_permutation(
client.pipeline_cache(),
&client.queue,
&col_buffer,
&pb_buffer,
&pivots_buffer,
&perm_params_buffer,
dtype,
)?;
let forward_params: [u32; 2] = [n as u32, 1];
let forward_params_buffer = client.create_uniform_buffer("forward_params", 8);
client.write_buffer(&forward_params_buffer, &forward_params);
kernels::launch_forward_sub(
client.pipeline_cache(),
&client.queue,
&lu_buffer,
&pb_buffer,
&y_buffer,
&forward_params_buffer,
dtype,
)?;
let backward_params: [u32; 4] = [n as u32, 0, 0, 0];
let backward_params_buffer = client.create_uniform_buffer("backward_params", 16);
client.write_buffer(&backward_params_buffer, &backward_params);
let x_col_offset = rhs * n * dtype.size_in_bytes();
let x_col_guard = AllocGuard::new(client.allocator(), col_size)?;
let x_col_ptr = x_col_guard.ptr();
let x_col_buffer = get_buffer(x_col_ptr)
.ok_or_else(|| Error::Internal("Failed to get x_col buffer".to_string()))?;
kernels::launch_backward_sub(
client.pipeline_cache(),
&client.queue,
&lu_buffer,
&y_buffer,
&x_col_buffer,
&backward_params_buffer,
dtype,
)?;
{
let mut encoder =
client
.wgpu_device
.create_command_encoder(&CommandEncoderDescriptor {
label: Some("copy_x_col_to_output"),
});
encoder.copy_buffer_to_buffer(
&x_col_buffer,
0,
&x_out_buffer,
x_col_offset as u64,
col_size as u64,
);
client.queue.submit(std::iter::once(encoder.finish()));
}
drop(x_col_guard);
}
client.synchronize();
drop(pb_guard);
drop(y_guard);
drop(col_guard);
if b_shape.len() == 1 {
let x = unsafe { WgpuClient::tensor_from_raw(x_out_guard.release(), &[n], dtype, device) };
Ok(x)
} else {
let x_col_major = unsafe {
WgpuClient::tensor_from_raw(x_out_guard.release(), &[num_rhs, n], dtype, device)
};
let x = x_col_major.transpose(0, 1)?;
Ok(x.contiguous())
}
}