use super::helpers::*;
use crate::error::Error;
use crate::error::Result;
use crate::ops::{matmul_bias_output_shape, matmul_output_shape, validate_matmul_bias_dtypes};
use crate::runtime::ensure_contiguous;
use crate::runtime::wgpu::shaders::{gemv_bt, matmul};
use crate::runtime::wgpu::{WgpuClient, WgpuRuntime};
use crate::tensor::Tensor;
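
/// Returns `true` when `tensor` is a 2-D view laid out as the transpose of a
/// contiguous matrix (column-major strides).
///
/// Transposing a contiguous `[n, k]` matrix yields shape `[k, n]` with
/// strides `[1, k]`, so `strides[0] == 1 && strides[1] == shape[0]` is
/// exactly that pattern. The `gemv_bt` kernel can then read B in its
/// pre-transpose, row-major layout without a copy.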
fn is_simple_transpose_2d(tensor: &Tensor<WgpuRuntime>) -> bool {
    let shape = tensor.shape();
    let strides = tensor.strides();
    if shape.len() != 2 {
        return false;
    }
    strides[0] == 1 && strides[1] == shape[0] as isize
}
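
/// Returns `true` when the last two dims of a 3-D `tensor` are the transpose
/// of a contiguous matrix, with batches densely packed along dim 0.
///
/// Transposing a contiguous `[b, n, k]` tensor on its last two dims yields
/// shape `[b, k, n]` with strides `[n * k, 1, k]`, which is the pattern
/// checked here.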
fn is_batched_transpose_last2(tensor: &Tensor<WgpuRuntime>) -> bool {
    let shape = tensor.shape();
    let strides = tensor.strides();
    if shape.len() != 3 {
        return false;
    }
    let k = shape[1];
    let n = shape[2];
    strides[1] == 1 && strides[2] == k as isize && strides[0] == (n * k) as isize
}
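
/// Matrix multiply on the WebGPU backend, dispatched to the kernel that best
/// fits the operand shapes:
///
/// - 2-D with `m <= 16` and a simply-transposed B: the `gemv_bt` kernel,
///   which reads B as-is and skips the contiguous copy.
/// - 2-D otherwise: the main kernel when the output exceeds `256 * 256`
///   elements, the simple variant below that.
/// - 3-D: batched versions of the two paths above.
/// - Higher ranks: leading dims are flattened into one batch dim (with
///   size-1 batches broadcast) and the batched kernel is reused.
///
/// A usage sketch (`client`, `a`, and `b` stand in for values produced by
/// the surrounding runtime; illustrative only, not a doctest):
///
/// ```ignore
/// // a: [m, k], b: [k, n] -> out: [m, n]
/// let out = native_matmul(&client, &a, &b)?;
/// ```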
pub(crate) fn native_matmul(
    client: &WgpuClient,
    a: &Tensor<WgpuRuntime>,
    b: &Tensor<WgpuRuntime>,
) -> Result<Tensor<WgpuRuntime>> {
    let dtype = a.dtype();
    let out_shape = matmul_output_shape(a.shape(), b.shape())
        .ok_or_else(|| Error::shape_mismatch(a.shape(), b.shape()))?;
    let a_shape = a.shape();
    let b_shape = b.shape();
    if a_shape.len() == 2 && b_shape.len() == 2 {
        let m = a_shape[0];
        let k = a_shape[1];
        let n = b_shape[1];
        // Skinny outputs over a transposed B go straight to the GEMV-style
        // kernel, which reads B in its pre-transpose layout; only A needs to
        // be made contiguous.
        if m <= 16 && is_simple_transpose_2d(b) {
            let a_contig = ensure_contiguous(a);
            let out = alloc_output(client, &out_shape, dtype);
            let a_buf = get_tensor_buffer(&a_contig)?;
            let b_buf = get_tensor_buffer(b)?;
            let out_buf = get_tensor_buffer(&out)?;
            let params = MatmulParams {
                m: m as u32,
                k: k as u32,
                n: n as u32,
                batch_size: 1,
            };
            let params_buf = create_params_buffer(client, &params);
            gemv_bt::launch_gemv_bt(
                client.pipeline_cache(),
                client.wgpu_queue(),
                &a_buf,
                &b_buf,
                &out_buf,
                &params_buf,
                m,
                n,
                dtype,
            )?;
            return Ok(out);
        }
        let a_contig = ensure_contiguous(a);
        let b_contig = ensure_contiguous(b);
        let out = alloc_output(client, &out_shape, dtype);
        let a_buf = get_tensor_buffer(&a_contig)?;
        let b_buf = get_tensor_buffer(&b_contig)?;
        let out_buf = get_tensor_buffer(&out)?;
        let params = MatmulParams {
            m: m as u32,
            k: k as u32,
            n: n as u32,
            batch_size: 1,
        };
        let params_buf = create_params_buffer(client, &params);
        // Outputs above the 256 * 256 threshold use the main kernel;
        // smaller ones use the simple variant.
        if m * n > 256 * 256 {
            matmul::launch_matmul(
                client.pipeline_cache(),
                client.wgpu_queue(),
                &a_buf,
                &b_buf,
                &out_buf,
                &params_buf,
                m,
                n,
                dtype,
            )?;
        } else {
            matmul::launch_matmul_simple(
                client.pipeline_cache(),
                client.wgpu_queue(),
                &a_buf,
                &b_buf,
                &out_buf,
                &params_buf,
                m,
                n,
                dtype,
            )?;
        }
        return Ok(out);
    }
    if a_shape.len() == 3 && b_shape.len() == 3 {
        let batch_size = a_shape[0];
        let m = a_shape[1];
        let k = a_shape[2];
        let n = b_shape[2];
        if b_shape[0] != batch_size {
            return Err(Error::ShapeMismatch {
                expected: vec![batch_size, k, n],
                got: b_shape.to_vec(),
            });
        }
        // Batched analogue of the 2-D fast path: B's slices are transposed
        // views over densely packed batches, so the GEMV kernel reads them
        // as-is.
        if m <= 16 && is_batched_transpose_last2(b) {
            let a_contig = ensure_contiguous(a);
            let out = alloc_output(client, &out_shape, dtype);
            let a_buf = get_tensor_buffer(&a_contig)?;
            let b_buf = get_tensor_buffer(b)?;
            let out_buf = get_tensor_buffer(&out)?;
            let params = MatmulParams {
                m: m as u32,
                k: k as u32,
                n: n as u32,
                batch_size: batch_size as u32,
            };
            let params_buf = create_params_buffer(client, &params);
            gemv_bt::launch_batched_gemv_bt(
                client.pipeline_cache(),
                client.wgpu_queue(),
                &a_buf,
                &b_buf,
                &out_buf,
                &params_buf,
                m,
                n,
                batch_size,
                dtype,
            )?;
            return Ok(out);
        }
        let a_contig = ensure_contiguous(a);
        let b_contig = ensure_contiguous(b);
        let out = alloc_output(client, &out_shape, dtype);
        let a_buf = get_tensor_buffer(&a_contig)?;
        let b_buf = get_tensor_buffer(&b_contig)?;
        let out_buf = get_tensor_buffer(&out)?;
        let params = MatmulParams {
            m: m as u32,
            k: k as u32,
            n: n as u32,
            batch_size: batch_size as u32,
        };
        let params_buf = create_params_buffer(client, &params);
        matmul::launch_batched_matmul(
            client.pipeline_cache(),
            client.wgpu_queue(),
            &a_buf,
            &b_buf,
            &out_buf,
            &params_buf,
            m,
            n,
            batch_size,
            dtype,
        )?;
        return Ok(out);
    }
    // General N-D case: flatten leading dims into a single batch dim, run
    // the batched 3-D kernel, and reshape the result back at the end.
    let ndim_a = a_shape.len();
    let ndim_b = b_shape.len();
    if ndim_a < 2 || ndim_b < 2 {
        return Err(Error::BackendLimitation {
            backend: "WebGPU",
            operation: "matmul",
            reason: format!(
                "requires at least 2D tensors, got shapes {:?} and {:?}",
                a_shape, b_shape
            ),
        });
    }
    let m = a_shape[ndim_a - 2];
    let k = a_shape[ndim_a - 1];
    let n = b_shape[ndim_b - 1];
    let batch_a: usize = a_shape[..ndim_a - 2].iter().product();
    let batch_b: usize = b_shape[..ndim_b - 2].iter().product();
    let batch_size = batch_a.max(batch_b);
    let a_3d = ensure_contiguous(a)
        .reshape(&[batch_a, m, k])
        .map_err(|_| Error::shape_mismatch(a_shape, b_shape))?;
    let b_3d = ensure_contiguous(b)
        .reshape(&[batch_b, k, n])
        .map_err(|_| Error::shape_mismatch(a_shape, b_shape))?;
    // Only the degenerate broadcast (one side has a single batch) is handled
    // here; mismatched multi-batch shapes fall through to an error.
    let (a_batched, b_batched) = if batch_a == batch_b {
        (a_3d, b_3d)
    } else if batch_a == 1 {
        (
            a_3d.broadcast_to(&[batch_size, m, k])
                .map_err(|_| Error::shape_mismatch(a_shape, b_shape))?
                .contiguous(),
            b_3d,
        )
    } else if batch_b == 1 {
        (
            a_3d,
            b_3d.broadcast_to(&[batch_size, k, n])
                .map_err(|_| Error::shape_mismatch(a_shape, b_shape))?
                .contiguous(),
        )
    } else {
        return Err(Error::shape_mismatch(a_shape, b_shape));
    };
    let a_buf = get_tensor_buffer(&a_batched)?;
    let b_buf = get_tensor_buffer(&b_batched)?;
    let out_flat = alloc_output(client, &[batch_size, m, n], dtype);
    let out_buf = get_tensor_buffer(&out_flat)?;
    let params = MatmulParams {
        m: m as u32,
        k: k as u32,
        n: n as u32,
        batch_size: batch_size as u32,
    };
    let params_buf = create_params_buffer(client, &params);
    matmul::launch_batched_matmul(
        client.pipeline_cache(),
        client.wgpu_queue(),
        &a_buf,
        &b_buf,
        &out_buf,
        &params_buf,
        m,
        n,
        batch_size,
        dtype,
    )?;
    let result = out_flat
        .reshape(&out_shape)
        .map_err(|_| Error::shape_mismatch(a_shape, b_shape))?;
    Ok(result)
}
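
/// Fused matrix multiply plus bias add on the WebGPU backend.
///
/// Supports plain 2-D (`[m, k] x [k, n]`) and batched 3-D inputs; the bias is
/// made contiguous and applied in the same kernel launch as the GEMM. Higher
/// ranks are rejected with `Error::BackendLimitation`.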
pub(crate) fn native_matmul_bias(
    client: &WgpuClient,
    a: &Tensor<WgpuRuntime>,
    b: &Tensor<WgpuRuntime>,
    bias: &Tensor<WgpuRuntime>,
) -> Result<Tensor<WgpuRuntime>> {
    let dtype = validate_matmul_bias_dtypes(a.dtype(), b.dtype(), bias.dtype())?;
    let out_shape = matmul_bias_output_shape(a.shape(), b.shape(), bias.shape())
        .ok_or_else(|| Error::shape_mismatch(a.shape(), b.shape()))?;
    let a_shape = a.shape();
    let b_shape = b.shape();
    if a_shape.len() == 2 && b_shape.len() == 2 {
        let m = a_shape[0];
        let k = a_shape[1];
        let n = b_shape[1];
        let a_contig = ensure_contiguous(a);
        let b_contig = ensure_contiguous(b);
        let bias_contig = ensure_contiguous(bias);
        let out = alloc_output(client, &out_shape, dtype);
        let a_buf = get_tensor_buffer(&a_contig)?;
        let b_buf = get_tensor_buffer(&b_contig)?;
        let bias_buf = get_tensor_buffer(&bias_contig)?;
        let out_buf = get_tensor_buffer(&out)?;
        let params = MatmulParams {
            m: m as u32,
            k: k as u32,
            n: n as u32,
            batch_size: 1,
        };
        let params_buf = create_params_buffer(client, &params);
        matmul::launch_matmul_bias(
            client.pipeline_cache(),
            client.wgpu_queue(),
            &a_buf,
            &b_buf,
            &bias_buf,
            &out_buf,
            &params_buf,
            m,
            n,
            dtype,
        )?;
        return Ok(out);
    }
    if a_shape.len() == 3 && b_shape.len() == 3 {
        let batch_size = a_shape[0];
        let m = a_shape[1];
        let k = a_shape[2];
        let n = b_shape[2];
        if b_shape[0] != batch_size {
            return Err(Error::ShapeMismatch {
                expected: vec![batch_size, k, n],
                got: b_shape.to_vec(),
            });
        }
        let a_contig = ensure_contiguous(a);
        let b_contig = ensure_contiguous(b);
        let bias_contig = ensure_contiguous(bias);
        let out = alloc_output(client, &out_shape, dtype);
        let a_buf = get_tensor_buffer(&a_contig)?;
        let b_buf = get_tensor_buffer(&b_contig)?;
        let bias_buf = get_tensor_buffer(&bias_contig)?;
        let out_buf = get_tensor_buffer(&out)?;
        let params = MatmulParams {
            m: m as u32,
            k: k as u32,
            n: n as u32,
            batch_size: batch_size as u32,
        };
        let params_buf = create_params_buffer(client, &params);
        matmul::launch_batched_matmul_bias(
            client.pipeline_cache(),
            client.wgpu_queue(),
            &a_buf,
            &b_buf,
            &bias_buf,
            &out_buf,
            &params_buf,
            m,
            n,
            batch_size,
            dtype,
        )?;
        return Ok(out);
    }
    Err(Error::BackendLimitation {
        backend: "WebGPU",
        operation: "matmul_bias",
        reason: format!(
            "only supports 2D and 3D tensors, got shapes {:?} and {:?}",
            a.shape(),
            b.shape()
        ),
    })
}