use wgpu::CommandEncoderDescriptor;
use super::super::client::get_buffer;
use super::super::shaders::linalg as kernels;
use super::super::{WgpuClient, WgpuRuntime};
use crate::algorithm::linalg::{validate_linalg_dtype, validate_square_matrix};
use crate::dtype::DType;
use crate::error::{Error, Result};
use crate::runtime::{AllocGuard, RuntimeClient};
use crate::tensor::Tensor;
pub fn solve_triangular_lower(
client: &WgpuClient,
l: &Tensor<WgpuRuntime>,
b: &Tensor<WgpuRuntime>,
unit_diagonal: bool,
) -> Result<Tensor<WgpuRuntime>> {
validate_linalg_dtype(l.dtype())?;
if l.dtype() != b.dtype() {
return Err(Error::DTypeMismatch {
lhs: l.dtype(),
rhs: b.dtype(),
});
}
let n = validate_square_matrix(l.shape())?;
let dtype = l.dtype();
let device = client.device();
if dtype != DType::F32 {
return Err(Error::UnsupportedDType {
dtype,
op: "WGPU solve_triangular_lower (only F32 supported)",
});
}
let b_shape = b.shape();
let (num_rhs, b_is_vector) = if b_shape.len() == 1 {
if b_shape[0] != n {
return Err(Error::ShapeMismatch {
expected: vec![n],
got: b_shape.to_vec(),
});
}
(1, true)
} else if b_shape.len() == 2 {
if b_shape[0] != n {
return Err(Error::ShapeMismatch {
expected: vec![n, b_shape[1]],
got: b_shape.to_vec(),
});
}
(b_shape[1], false)
} else {
return Err(Error::Internal(format!(
"solve_triangular_lower requires b to be 1D or 2D, got {}D with shape {:?}",
b_shape.len(),
b_shape
)));
};
let l_buffer =
get_buffer(l.ptr()).ok_or_else(|| Error::Internal("Failed to get L buffer".to_string()))?;
let b_contig = b.contiguous();
let b_buffer = get_buffer(b_contig.ptr())
.ok_or_else(|| Error::Internal("Failed to get b buffer".to_string()))?;
let x_total_size = n * num_rhs * dtype.size_in_bytes();
let x_out_guard = AllocGuard::new(client.allocator(), x_total_size)?;
let x_out_ptr = x_out_guard.ptr();
let x_out_buffer = get_buffer(x_out_ptr)
.ok_or_else(|| Error::Internal("Failed to get x_out buffer".to_string()))?;
if b_is_vector {
let params: [u32; 2] = [n as u32, if unit_diagonal { 1 } else { 0 }];
let params_buffer = client.create_uniform_buffer("forward_params", 8);
client.write_buffer(¶ms_buffer, ¶ms);
kernels::launch_forward_sub(
client.pipeline_cache(),
&client.queue,
&l_buffer,
&b_buffer,
&x_out_buffer,
¶ms_buffer,
dtype,
)?;
} else {
let col_size = n * dtype.size_in_bytes();
let col_guard = AllocGuard::new(client.allocator(), col_size)?;
let col_ptr = col_guard.ptr();
let col_buffer = get_buffer(col_ptr)
.ok_or_else(|| Error::Internal("Failed to get col buffer".to_string()))?;
let x_col_guard = AllocGuard::new(client.allocator(), col_size)?;
let x_col_ptr = x_col_guard.ptr();
let x_col_buffer = get_buffer(x_col_ptr)
.ok_or_else(|| Error::Internal("Failed to get x_col buffer".to_string()))?;
for rhs in 0..num_rhs {
let extract_params: [u32; 4] = [n as u32, num_rhs as u32, rhs as u32, 0];
let extract_params_buffer = client.create_uniform_buffer("extract_params", 16);
client.write_buffer(&extract_params_buffer, &extract_params);
kernels::launch_extract_column(
client.pipeline_cache(),
&client.queue,
&b_buffer,
&col_buffer,
&extract_params_buffer,
n,
dtype,
)?;
let forward_params: [u32; 2] = [n as u32, if unit_diagonal { 1 } else { 0 }];
let forward_params_buffer = client.create_uniform_buffer("forward_params", 8);
client.write_buffer(&forward_params_buffer, &forward_params);
kernels::launch_forward_sub(
client.pipeline_cache(),
&client.queue,
&l_buffer,
&col_buffer,
&x_col_buffer,
&forward_params_buffer,
dtype,
)?;
let x_col_offset = rhs * n * dtype.size_in_bytes();
{
let mut encoder =
client
.wgpu_device
.create_command_encoder(&CommandEncoderDescriptor {
label: Some("copy_x_col_to_output"),
});
encoder.copy_buffer_to_buffer(
&x_col_buffer,
0,
&x_out_buffer,
x_col_offset as u64,
col_size as u64,
);
client.queue.submit(std::iter::once(encoder.finish()));
}
}
drop(col_guard);
drop(x_col_guard);
}
client.synchronize();
if b_is_vector {
let x = unsafe { WgpuClient::tensor_from_raw(x_out_guard.release(), &[n], dtype, device) };
Ok(x)
} else {
let x_col_major = unsafe {
WgpuClient::tensor_from_raw(x_out_guard.release(), &[num_rhs, n], dtype, device)
};
let x = x_col_major.transpose(0, 1)?;
Ok(x.contiguous())
}
}
pub fn solve_triangular_upper(
client: &WgpuClient,
u: &Tensor<WgpuRuntime>,
b: &Tensor<WgpuRuntime>,
) -> Result<Tensor<WgpuRuntime>> {
validate_linalg_dtype(u.dtype())?;
if u.dtype() != b.dtype() {
return Err(Error::DTypeMismatch {
lhs: u.dtype(),
rhs: b.dtype(),
});
}
let n = validate_square_matrix(u.shape())?;
let dtype = u.dtype();
let device = client.device();
if dtype != DType::F32 {
return Err(Error::UnsupportedDType {
dtype,
op: "WGPU solve_triangular_upper (only F32 supported)",
});
}
let b_shape = b.shape();
let (num_rhs, b_is_vector) = if b_shape.len() == 1 {
if b_shape[0] != n {
return Err(Error::ShapeMismatch {
expected: vec![n],
got: b_shape.to_vec(),
});
}
(1, true)
} else if b_shape.len() == 2 {
if b_shape[0] != n {
return Err(Error::ShapeMismatch {
expected: vec![n, b_shape[1]],
got: b_shape.to_vec(),
});
}
(b_shape[1], false)
} else {
return Err(Error::Internal(format!(
"solve_triangular_upper requires b to be 1D or 2D, got {}D with shape {:?}",
b_shape.len(),
b_shape
)));
};
let u_buffer =
get_buffer(u.ptr()).ok_or_else(|| Error::Internal("Failed to get U buffer".to_string()))?;
let b_contig = b.contiguous();
let b_buffer = get_buffer(b_contig.ptr())
.ok_or_else(|| Error::Internal("Failed to get b buffer".to_string()))?;
let x_total_size = n * num_rhs * dtype.size_in_bytes();
let x_out_guard = AllocGuard::new(client.allocator(), x_total_size)?;
let x_out_ptr = x_out_guard.ptr();
let x_out_buffer = get_buffer(x_out_ptr)
.ok_or_else(|| Error::Internal("Failed to get x_out buffer".to_string()))?;
if b_is_vector {
let params: [u32; 1] = [n as u32];
let params_buffer = client.create_uniform_buffer("backward_params", 4);
client.write_buffer(¶ms_buffer, ¶ms);
kernels::launch_backward_sub(
client.pipeline_cache(),
&client.queue,
&u_buffer,
&b_buffer,
&x_out_buffer,
¶ms_buffer,
dtype,
)?;
} else {
let col_size = n * dtype.size_in_bytes();
let col_guard = AllocGuard::new(client.allocator(), col_size)?;
let col_ptr = col_guard.ptr();
let col_buffer = get_buffer(col_ptr)
.ok_or_else(|| Error::Internal("Failed to get col buffer".to_string()))?;
let x_col_guard = AllocGuard::new(client.allocator(), col_size)?;
let x_col_ptr = x_col_guard.ptr();
let x_col_buffer = get_buffer(x_col_ptr)
.ok_or_else(|| Error::Internal("Failed to get x_col buffer".to_string()))?;
for rhs in 0..num_rhs {
let extract_params: [u32; 4] = [n as u32, num_rhs as u32, rhs as u32, 0];
let extract_params_buffer = client.create_uniform_buffer("extract_params", 16);
client.write_buffer(&extract_params_buffer, &extract_params);
kernels::launch_extract_column(
client.pipeline_cache(),
&client.queue,
&b_buffer,
&col_buffer,
&extract_params_buffer,
n,
dtype,
)?;
let backward_params: [u32; 1] = [n as u32];
let backward_params_buffer = client.create_uniform_buffer("backward_params", 4);
client.write_buffer(&backward_params_buffer, &backward_params);
kernels::launch_backward_sub(
client.pipeline_cache(),
&client.queue,
&u_buffer,
&col_buffer,
&x_col_buffer,
&backward_params_buffer,
dtype,
)?;
let x_col_offset = rhs * n * dtype.size_in_bytes();
{
let mut encoder =
client
.wgpu_device
.create_command_encoder(&CommandEncoderDescriptor {
label: Some("copy_x_col_to_output"),
});
encoder.copy_buffer_to_buffer(
&x_col_buffer,
0,
&x_out_buffer,
x_col_offset as u64,
col_size as u64,
);
client.queue.submit(std::iter::once(encoder.finish()));
}
}
drop(col_guard);
drop(x_col_guard);
}
client.synchronize();
if b_is_vector {
let x = unsafe { WgpuClient::tensor_from_raw(x_out_guard.release(), &[n], dtype, device) };
Ok(x)
} else {
let x_col_major = unsafe {
WgpuClient::tensor_from_raw(x_out_guard.release(), &[num_rhs, n], dtype, device)
};
let x = x_col_major.transpose(0, 1)?;
Ok(x.contiguous())
}
}