use crate::shapes::TensorLayoutBuffers;
use crate::tensor::{AsTensorMut, AsTensorRef};
use khal::Shader;
use khal::backend::{GpuBackend, GpuBackendError, GpuPass};
use crate::shaders::linalg::{GemmNaive, GemmTiled};
/// Whether a matrix operand is taken as stored or as its transpose.
#[derive(Copy, Clone, PartialEq, Eq, Debug)]
pub enum MatrixMode {
    /// The operand is used as-is.
    Normal,
    /// The operand is used transposed.
    Transposed,
}

impl MatrixMode {
    /// Flips the mode in place: `Normal` becomes `Transposed` and vice versa.
    pub fn transpose(&mut self) {
        *self = match *self {
            Self::Normal => Self::Transposed,
            Self::Transposed => Self::Normal,
        };
    }
}

/// Shorthand for [`MatrixMode::Normal`].
pub const N: MatrixMode = MatrixMode::Normal;
/// Shorthand for [`MatrixMode::Transposed`].
pub const T: MatrixMode = MatrixMode::Transposed;
/// Bundle of the two `f32` matrix-product shaders defined in `crate::shaders::linalg`.
///
/// `dispatch` chooses between the two kernels based on the output layout, while
/// `dispatch_naive` and `dispatch_tiled` force one of them.
#[derive(Shader)]
pub struct Gemm {
/// Naive kernel; launched by `dispatch_naive` and by `dispatch` when the tiled kernel is not
/// selected.
pub gemm_naive: GemmNaive,
/// Tiled kernel; launched by `dispatch_tiled` and by `dispatch` for larger outputs on
/// non-`wasm32` targets.
pub gemm_tiled: GemmTiled,
}
impl Gemm {
/// Records the matrix product `out = lhs x rhs` into `pass`, picking the kernel automatically.
///
/// The three layouts are canonicalized first. The tiled kernel is used when the crate is not
/// compiled for `wasm32` and the canonical output has `size[2] >= 32` or `size[3] >= 32`;
/// otherwise the naive kernel is used.
///
/// Without the `push_constants` feature the canonical layouts are uploaded through `shapes`
/// (using `backend`) and bound as buffers. With the feature enabled they are handed to the
/// kernel inline as a `Shapes3` value, and `backend` / `shapes` are left untouched.
///
/// # Panics
///
/// Panics with an "incompatible gemm shapes" message when the canonical layouts disagree:
/// `size[0]` and `size[1]` must be equal across `out`, `lhs` and `rhs`, `out.size[2]` must
/// equal `lhs.size[2]`, `out.size[3]` must equal `rhs.size[3]`, and `lhs.size[3]` must equal
/// `rhs.size[2]`.
///
/// # Errors
///
/// Propagates any error returned by `shapes.insert` (non-`push_constants` builds only) or by
/// the selected kernel's `call`.
pub fn dispatch(
    &self,
    // Only the storage-buffer path below reads `backend`, so it is unused with `push_constants`.
    #[cfg_attr(feature = "push_constants", allow(unused_variables))]
    backend: &GpuBackend,
    #[cfg_attr(feature = "push_constants", allow(unused_variables))]
    shapes: &mut TensorLayoutBuffers,
    pass: &mut GpuPass,
    mut out: impl AsTensorMut<f32>,
    lhs: impl AsTensorRef<f32>,
    rhs: impl AsTensorRef<f32>,
) -> Result<(), GpuBackendError> {
    let mut out = out.as_tensor_mut();
    let lhs = lhs.as_tensor_ref();
    let rhs = rhs.as_tensor_ref();

    let shape_out = out.layout().canonicalize();
    let shape_lhs = lhs.layout().canonicalize();
    let shape_rhs = rhs.layout().canonicalize();

    // Pairs of extents that must match for the product to be well-formed.
    let must_match = [
        // Leading (non-matrix) dimensions agree across all three operands.
        (shape_out.size[0], shape_lhs.size[0]),
        (shape_out.size[1], shape_lhs.size[1]),
        (shape_lhs.size[0], shape_rhs.size[0]),
        (shape_lhs.size[1], shape_rhs.size[1]),
        // Output rows come from `lhs`, output columns from `rhs`.
        (shape_out.size[2], shape_lhs.size[2]),
        (shape_out.size[3], shape_rhs.size[3]),
        // Inner dimension: columns of `lhs` against rows of `rhs`.
        (shape_lhs.size[3], shape_rhs.size[2]),
    ];
    for &(left, right) in &must_match {
        assert_eq!(
            left, right,
            "incompatible gemm shapes: {:?} = {:?} x {:?}",
            shape_out, shape_lhs, shape_rhs
        );
    }

    // The tiled kernel is never selected on wasm32; elsewhere it is used as soon as either
    // matrix extent of the output reaches 32.
    let use_tiled = cfg!(not(target_arch = "wasm32"))
        && (shape_out.size[2] >= 32 || shape_out.size[3] >= 32);
    let grid = if use_tiled {
        Self::tiled_grid(&shape_out)
    } else {
        Self::naive_grid(&shape_out)
    };

    #[cfg(not(feature = "push_constants"))]
    {
        shapes.insert(backend, shape_out)?;
        shapes.insert(backend, shape_lhs)?;
        shapes.insert(backend, shape_rhs)?;
        // The three layouts were inserted just above, so the lookups cannot miss.
        let gpu_shape_out = shapes.get(shape_out).unwrap_or_else(|| unreachable!());
        let gpu_shape_lhs = shapes.get(shape_lhs).unwrap_or_else(|| unreachable!());
        let gpu_shape_rhs = shapes.get(shape_rhs).unwrap_or_else(|| unreachable!());
        let mut buf_out = out.buffer_mut();
        if use_tiled {
            self.gemm_tiled.call(
                pass,
                grid,
                &gpu_shape_out.as_slice(),
                &gpu_shape_lhs.as_slice(),
                &gpu_shape_rhs.as_slice(),
                &mut buf_out,
                &lhs.buffer(),
                &rhs.buffer(),
            )
        } else {
            self.gemm_naive.call(
                pass,
                grid,
                &gpu_shape_out.as_slice(),
                &gpu_shape_lhs.as_slice(),
                &gpu_shape_rhs.as_slice(),
                &mut buf_out,
                &lhs.buffer(),
                &rhs.buffer(),
            )
        }
    }
    #[cfg(feature = "push_constants")]
    {
        let mut buf_out = out.buffer_mut();
        if use_tiled {
            self.gemm_tiled.call(
                pass,
                grid,
                &mut buf_out,
                &lhs.buffer(),
                &rhs.buffer(),
                crate::shaders::linalg::Shapes3 {
                    shape_out: shape_out.into(),
                    shape_lhs: shape_lhs.into(),
                    shape_rhs: shape_rhs.into(),
                },
            )
        } else {
            self.gemm_naive.call(
                pass,
                grid,
                &mut buf_out,
                &lhs.buffer(),
                &rhs.buffer(),
                crate::shaders::linalg::Shapes3 {
                    shape_out: shape_out.into(),
                    shape_lhs: shape_lhs.into(),
                    shape_rhs: shape_rhs.into(),
                },
            )
        }
    }
}
/// Records `out = lhs x rhs` into `pass`, always using the naive kernel.
///
/// The layouts are canonicalized and forwarded to the kernel as-is. Unlike `dispatch`, no
/// shape-compatibility assertions are performed here.
///
/// Without the `push_constants` feature the layouts are uploaded through `shapes` and bound as
/// buffers; with it they are passed inline as a `Shapes3` value.
///
/// # Errors
///
/// Propagates any error returned by `shapes.insert` (non-`push_constants` builds only) or by
/// the kernel's `call`.
pub fn dispatch_naive(
    &self,
    backend: &GpuBackend,
    #[cfg_attr(feature = "push_constants", allow(unused_variables))]
    shapes: &mut TensorLayoutBuffers,
    pass: &mut GpuPass,
    mut out: impl AsTensorMut<f32>,
    lhs: impl AsTensorRef<f32>,
    rhs: impl AsTensorRef<f32>,
) -> Result<(), GpuBackendError> {
    let mut out_view = out.as_tensor_mut();
    let lhs_view = lhs.as_tensor_ref();
    let rhs_view = rhs.as_tensor_ref();

    let layout_out = out_view.layout().canonicalize();
    let layout_lhs = lhs_view.layout().canonicalize();
    let layout_rhs = rhs_view.layout().canonicalize();

    let grid = Self::naive_grid(&layout_out);

    #[cfg(not(feature = "push_constants"))]
    {
        // Upload (or reuse) the GPU-side copies of the three layouts.
        shapes.insert(backend, layout_out)?;
        shapes.insert(backend, layout_lhs)?;
        shapes.insert(backend, layout_rhs)?;

        let gpu_layout_out = shapes.get(layout_out).unwrap();
        let gpu_layout_lhs = shapes.get(layout_lhs).unwrap();
        let gpu_layout_rhs = shapes.get(layout_rhs).unwrap();

        let mut dst = out_view.buffer_mut();
        self.gemm_naive.call(
            pass,
            grid,
            &gpu_layout_out.as_slice(),
            &gpu_layout_lhs.as_slice(),
            &gpu_layout_rhs.as_slice(),
            &mut dst,
            &lhs_view.buffer(),
            &rhs_view.buffer(),
        )
    }
    #[cfg(feature = "push_constants")]
    {
        let mut dst = out_view.buffer_mut();
        self.gemm_naive.call(
            pass,
            grid,
            &mut dst,
            &lhs_view.buffer(),
            &rhs_view.buffer(),
            crate::shaders::linalg::Shapes3 {
                shape_out: layout_out.into(),
                shape_lhs: layout_lhs.into(),
                shape_rhs: layout_rhs.into(),
            },
        )
    }
}
/// Records `out = lhs x rhs` into `pass`, always using the tiled kernel.
///
/// The layouts are canonicalized and forwarded to the kernel as-is. Unlike `dispatch`, no
/// shape-compatibility assertions are performed here, and the kernel is used regardless of
/// the output size or compilation target.
///
/// Without the `push_constants` feature the layouts are uploaded through `shapes` and bound as
/// buffers; with it they are passed inline as a `Shapes3` value.
///
/// # Errors
///
/// Propagates any error returned by `shapes.insert` (non-`push_constants` builds only) or by
/// the kernel's `call`.
pub fn dispatch_tiled(
    &self,
    backend: &GpuBackend,
    #[cfg_attr(feature = "push_constants", allow(unused_variables))]
    shapes: &mut TensorLayoutBuffers,
    pass: &mut GpuPass,
    mut out: impl AsTensorMut<f32>,
    lhs: impl AsTensorRef<f32>,
    rhs: impl AsTensorRef<f32>,
) -> Result<(), GpuBackendError> {
    let mut out_view = out.as_tensor_mut();
    let lhs_view = lhs.as_tensor_ref();
    let rhs_view = rhs.as_tensor_ref();

    let layout_out = out_view.layout().canonicalize();
    let layout_lhs = lhs_view.layout().canonicalize();
    let layout_rhs = rhs_view.layout().canonicalize();

    let grid = Self::tiled_grid(&layout_out);

    #[cfg(not(feature = "push_constants"))]
    {
        // Upload (or reuse) the GPU-side copies of the three layouts.
        shapes.insert(backend, layout_out)?;
        shapes.insert(backend, layout_lhs)?;
        shapes.insert(backend, layout_rhs)?;

        let gpu_layout_out = shapes.get(layout_out).unwrap();
        let gpu_layout_lhs = shapes.get(layout_lhs).unwrap();
        let gpu_layout_rhs = shapes.get(layout_rhs).unwrap();

        let mut dst = out_view.buffer_mut();
        self.gemm_tiled.call(
            pass,
            grid,
            &gpu_layout_out.as_slice(),
            &gpu_layout_lhs.as_slice(),
            &gpu_layout_rhs.as_slice(),
            &mut dst,
            &lhs_view.buffer(),
            &rhs_view.buffer(),
        )
    }
    #[cfg(feature = "push_constants")]
    {
        let mut dst = out_view.buffer_mut();
        self.gemm_tiled.call(
            pass,
            grid,
            &mut dst,
            &lhs_view.buffer(),
            &rhs_view.buffer(),
            crate::shaders::linalg::Shapes3 {
                shape_out: layout_out.into(),
                shape_lhs: layout_lhs.into(),
                shape_rhs: layout_rhs.into(),
            },
        )
    }
}
/// Launch grid for the naive kernel, taken directly from the output layout.
///
/// The grid is `[size[3], size[2], size[1]]`; `size[0]` does not contribute.
fn naive_grid(shape_out: &crate::shapes::TensorLayout) -> [u32; 3] {
    let extents = &shape_out.size;
    let (x, y, z) = (extents[3], extents[2], extents[1]);
    [x, y, z]
}
/// Launch grid for the tiled kernel.
///
/// `size[2]` and `size[3]` of the output are each split into tiles of 64 (rounding up), and
/// every tile contributes 16 grid units along its axis. The result is ordered
/// `[from size[3], from size[2], size[1]]`; `size[0]` does not contribute.
// NOTE(review): the 64/16 constants presumably mirror the tile and workgroup sizes hard-coded
// in the tiled shader — confirm against the shader source before changing either side.
fn tiled_grid(shape_out: &crate::shapes::TensorLayout) -> [u32; 3] {
    const TILE_ROWS: u32 = 64;
    const TILE_COLS: u32 = 64;
    const GROUP_ROWS: u32 = 16;
    const GROUP_COLS: u32 = 16;

    let tiles_along_rows = shape_out.size[2].div_ceil(TILE_ROWS);
    let tiles_along_cols = shape_out.size[3].div_ceil(TILE_COLS);

    let grid_x = tiles_along_cols * GROUP_COLS;
    let grid_y = tiles_along_rows * GROUP_ROWS;
    [grid_x, grid_y, shape_out.size[1]]
}
}
#[cfg(test)]
mod test {
    use crate::shapes::TensorLayoutBuffers;
    use crate::tensor::Tensor;
    use approx::assert_relative_eq;
    use khal::backend::{Backend, Encoder, GpuBackend, WebGpu};
    use khal::{BufferUsages, Shader};
    use nalgebra::{DMatrix, DVector};
    use wgpu::{Features, Limits};

    /// Runs the shared GEMM check on the WebGPU backend.
    #[futures_test::test]
    #[serial_test::serial]
    async fn gpu_gemm_webgpu() {
        let backend = GpuBackend::WebGpu(
            WebGpu::new(Features::default(), Limits::default())
                .await
                .unwrap(),
        );
        gpu_gemm_generic(&backend).await;
    }

    /// Runs the shared GEMM check on the CPU backend.
    #[cfg(feature = "cpu")]
    #[futures_test::test]
    async fn gpu_gemm_cpu() {
        gpu_gemm_generic(&GpuBackend::Cpu).await;
    }

    /// Runs the shared GEMM check on CUDA device 0.
    #[cfg(feature = "cuda")]
    #[futures_test::test]
    async fn gpu_gemm_cuda() {
        let backend = GpuBackend::Cuda(khal::backend::cuda::Cuda::new(0).unwrap());
        gpu_gemm_generic(&backend).await;
    }

    /// Runs the shared GEMM check on the Metal backend.
    #[cfg(feature = "metal")]
    #[futures_test::test]
    #[serial_test::serial]
    async fn gpu_gemm_metal() {
        let backend = GpuBackend::Metal(khal::backend::metal::Metal::new().unwrap());
        gpu_gemm_generic(&backend).await;
    }

    /// Multiplies a random 256x256 matrix by a random 256-vector through `Gemm::dispatch` and
    /// compares the read-back result with nalgebra's CPU product.
    async fn gpu_gemm_generic(backend: &GpuBackend) {
        let gemm = super::Gemm::from_backend(backend).unwrap();
        let mut shapes = TensorLayoutBuffers::new(backend);

        const NROWS: u32 = 256;
        const NCOLS: u32 = 256;

        // Host-side operands, plus zeroed storage for the output and its read-back copy.
        let matrix_cpu = DMatrix::<f32>::new_random(NROWS as usize, NCOLS as usize);
        let vector_cpu = DVector::<f32>::new_random(NCOLS as usize);
        let zeros_cpu = DVector::<f32>::zeros(NROWS as usize);
        let mut readback = DVector::<f32>::zeros(NROWS as usize);

        // Device-side operands; both vectors get an extra axis inserted at index 1.
        let m = Tensor::matrix_from_na(backend, &matrix_cpu, BufferUsages::STORAGE).unwrap();
        let mut v = Tensor::vector(backend, &vector_cpu, BufferUsages::STORAGE).unwrap();
        v.unsqueeze(1);
        let mut result = Tensor::vector(
            backend,
            &zeros_cpu,
            BufferUsages::STORAGE | BufferUsages::COPY_SRC,
        )
        .unwrap();
        result.unsqueeze(1);

        let start = std::time::Instant::now();
        let mut encoder = backend.begin_encoding();
        {
            // The pass is dropped at the end of this scope, before the encoder is submitted.
            let mut pass = encoder.begin_pass("gemm", None);
            gemm.dispatch(backend, &mut shapes, &mut pass, &mut result, &m, &v)
                .unwrap();
        }
        backend.submit(encoder).unwrap();
        backend.synchronize().unwrap();
        println!("GEMM before read: {}", start.elapsed().as_secs_f32());

        backend
            .slow_read_buffer(result.buffer(), readback.as_mut_slice())
            .await
            .unwrap();
        println!("GEMM time: {}", start.elapsed().as_secs_f32());

        let expected = &matrix_cpu * &vector_cpu;
        assert_relative_eq!(readback, expected, epsilon = 1.0e-3);
    }
}