use wgpu::CommandEncoderDescriptor;
use super::super::client::get_buffer;
use super::super::shaders::linalg as kernels;
use super::super::{WgpuClient, WgpuRuntime};
use super::decompositions::qr_decompose_internal;
use crate::algorithm::linalg::{validate_linalg_dtype, validate_matrix_2d};
use crate::dtype::DType;
use crate::error::{Error, Result};
use crate::ops::MatmulOps;
use crate::runtime::{AllocGuard, RuntimeClient};
use crate::tensor::Tensor;
pub fn lstsq(
client: &WgpuClient,
a: &Tensor<WgpuRuntime>,
b: &Tensor<WgpuRuntime>,
) -> Result<Tensor<WgpuRuntime>> {
validate_linalg_dtype(a.dtype())?;
if a.dtype() != b.dtype() {
return Err(Error::DTypeMismatch {
lhs: a.dtype(),
rhs: b.dtype(),
});
}
let (m, n) = validate_matrix_2d(a.shape())?;
let dtype = a.dtype();
let device = client.device();
if dtype != DType::F32 {
return Err(Error::UnsupportedDType {
dtype,
op: "WGPU lstsq (only F32 supported)",
});
}
let b_shape = b.shape();
let (num_rhs, b_is_vector) = if b_shape.len() == 1 {
if b_shape[0] != m {
return Err(Error::ShapeMismatch {
expected: vec![m],
got: b_shape.to_vec(),
});
}
(1, true)
} else if b_shape.len() == 2 {
if b_shape[0] != m {
return Err(Error::ShapeMismatch {
expected: vec![m, b_shape[1]],
got: b_shape.to_vec(),
});
}
(b_shape[1], false)
} else {
return Err(Error::Internal(format!(
"lstsq requires b to be 1D or 2D, got {}D with shape {:?}",
b_shape.len(),
b_shape
)));
};
if m < n {
return Err(Error::Internal(format!(
"lstsq: underdetermined system not supported (A is {}x{}, requires m >= n)",
m, n
)));
}
let qr = qr_decompose_internal(client, a, false)?;
let q_t = qr.q.transpose(0, 1)?.contiguous();
let b_mat = if b_is_vector {
b.reshape(&[m, 1])?.contiguous()
} else {
b.contiguous()
};
let qtb = client.matmul(&q_t, &b_mat)?;
let r_buffer = get_buffer(qr.r.ptr())
.ok_or_else(|| Error::Internal("Failed to get R buffer".to_string()))?;
let x_total_size = n * num_rhs * dtype.size_in_bytes();
let x_out_guard = AllocGuard::new(client.allocator(), x_total_size)?;
let x_out_ptr = x_out_guard.ptr();
let x_out_buffer = get_buffer(x_out_ptr)
.ok_or_else(|| Error::Internal("Failed to get x_out buffer".to_string()))?;
if b_is_vector {
let qtb_flat = qtb.reshape(&[m])?;
let qtb_n = qtb_flat.narrow(0, 0, n)?.contiguous();
let qtb_buffer = get_buffer(qtb_n.ptr())
.ok_or_else(|| Error::Internal("Failed to get qtb buffer".to_string()))?;
let params: [u32; 1] = [n as u32];
let params_buffer = client.create_uniform_buffer("backward_params", 4);
client.write_buffer(¶ms_buffer, ¶ms);
kernels::launch_backward_sub(
client.pipeline_cache(),
&client.queue,
&r_buffer,
&qtb_buffer,
&x_out_buffer,
¶ms_buffer,
dtype,
)?;
} else {
let qtb_contig = qtb.contiguous();
let qtb_buffer = get_buffer(qtb_contig.ptr())
.ok_or_else(|| Error::Internal("Failed to get qtb buffer".to_string()))?;
let col_size = n * dtype.size_in_bytes();
let col_guard = AllocGuard::new(client.allocator(), col_size)?;
let col_ptr = col_guard.ptr();
let col_buffer = get_buffer(col_ptr)
.ok_or_else(|| Error::Internal("Failed to get col buffer".to_string()))?;
let x_col_guard = AllocGuard::new(client.allocator(), col_size)?;
let x_col_ptr = x_col_guard.ptr();
let x_col_buffer = get_buffer(x_col_ptr)
.ok_or_else(|| Error::Internal("Failed to get x_col buffer".to_string()))?;
for rhs in 0..num_rhs {
let extract_params: [u32; 4] = [n as u32, num_rhs as u32, rhs as u32, 0];
let extract_params_buffer = client.create_uniform_buffer("extract_params", 16);
client.write_buffer(&extract_params_buffer, &extract_params);
kernels::launch_extract_column(
client.pipeline_cache(),
&client.queue,
&qtb_buffer,
&col_buffer,
&extract_params_buffer,
n, dtype,
)?;
let backward_params: [u32; 1] = [n as u32];
let backward_params_buffer = client.create_uniform_buffer("backward_params", 4);
client.write_buffer(&backward_params_buffer, &backward_params);
kernels::launch_backward_sub(
client.pipeline_cache(),
&client.queue,
&r_buffer,
&col_buffer,
&x_col_buffer,
&backward_params_buffer,
dtype,
)?;
let x_col_offset = rhs * n * dtype.size_in_bytes();
{
let mut encoder =
client
.wgpu_device
.create_command_encoder(&CommandEncoderDescriptor {
label: Some("copy_x_col_to_output"),
});
encoder.copy_buffer_to_buffer(
&x_col_buffer,
0,
&x_out_buffer,
x_col_offset as u64,
col_size as u64,
);
client.queue.submit(std::iter::once(encoder.finish()));
}
}
drop(col_guard);
drop(x_col_guard);
}
client.synchronize();
if b_is_vector {
let x = unsafe { WgpuClient::tensor_from_raw(x_out_guard.release(), &[n], dtype, device) };
Ok(x)
} else {
let x_col_major = unsafe {
WgpuClient::tensor_from_raw(x_out_guard.release(), &[num_rhs, n], dtype, device)
};
let x = x_col_major.transpose(0, 1)?;
Ok(x.contiguous())
}
}