#![cfg(feature = "gpu")]
use anyhow::{Result, anyhow};
use cubecl::client::ComputeClient;
use cubecl::cuda::{CudaDevice, CudaRuntime};
use cubecl::prelude::*;
use cubecl::std::tensor::TensorHandle;
use cubek_matmul::definition::MatmulElems;
use cubek_matmul::launch::{MatmulInputHandleRef, Strategy};
type R = CudaRuntime;
/// Context for running matrix multiplications on a CUDA device via cubecl.
///
/// Created by [`GpuContext::new`], which binds to CUDA device 0 and selects
/// the automatic matmul strategy.
pub struct GpuContext {
    // Compute client for the CUDA runtime; all uploads, launches, and
    // read-backs in the `matmul_tn*` methods go through this client.
    client: ComputeClient<R>,
    // Kernel-selection strategy handed to every `launch_ref` call
    // (`Strategy::Auto` as constructed by `new`).
    strategy: Strategy,
}
/// Copy an `m x n` result out of a row-strided buffer into a dense,
/// contiguous `Vec` of length `m * n`.
///
/// `row_stride` is the number of elements between the starts of consecutive
/// rows in `raw`; any padding past column `n` is dropped. Panics if `raw` is
/// too short for the requested rows.
fn extract_result(raw: &[f32], m: usize, n: usize, row_stride: usize) -> Vec<f32> {
    (0..m)
        .flat_map(|row| {
            let start = row * row_stride;
            raw[start..start + n].iter().copied()
        })
        .collect()
}
impl GpuContext {
pub fn new() -> Result<Self> {
let device = CudaDevice::new(0);
let client = R::client(&device);
Ok(Self {
client,
strategy: Strategy::Auto,
})
}
pub fn matmul_tn(
&self,
a_data: &[f32],
n_rows: usize,
a_cols: usize,
b_data: &[f32],
b_cols: usize,
) -> Result<Vec<f32>> {
let m = a_cols;
let k = n_rows;
let n = b_cols;
let f32_storage = f32::as_type_native_unchecked();
let a_handle = self.client.create_from_slice(f32::as_bytes(a_data));
let b_handle = self.client.create_from_slice(f32::as_bytes(b_data));
let a_tensor = TensorHandle::<R>::new(a_handle, vec![m, k], vec![k, 1], f32_storage);
let b_tensor = TensorHandle::<R>::new(b_handle, vec![k, n], vec![1, k], f32_storage);
let out_tensor = TensorHandle::<R>::empty(&self.client, vec![m, n], f32_storage);
let out_row_stride = out_tensor.strides()[0];
let mut dtypes = MatmulElems::from_single_dtype(f32_storage);
cubek_matmul::launch::launch_ref::<R>(
&self.strategy,
&self.client,
&MatmulInputHandleRef::new(a_tensor.as_ref(), f32_storage),
&MatmulInputHandleRef::new(b_tensor.as_ref(), f32_storage),
&out_tensor.as_ref(),
&mut dtypes,
)
.map_err(|e| anyhow!("GPU matmul failed: {:?}", e))?;
let out_bytes = self.client.read_one(out_tensor.handle);
let raw: &[f32] = f32::from_bytes(&out_bytes);
Ok(extract_result(raw, m, n, out_row_stride))
}
#[allow(dead_code)]
pub fn matmul_tn_tiled(
&self,
a_data: &[f32],
n_rows: usize,
a_cols: usize,
b_data: &[f32],
b_cols: usize,
tile_cols: usize,
) -> Result<Vec<f32>> {
let n = b_cols;
let f32_storage = f32::as_type_native_unchecked();
let mut result = vec![0.0f32; a_cols * n];
for tile_start in (0..a_cols).step_by(tile_cols) {
let tile_end = (tile_start + tile_cols).min(a_cols);
let tile_m = tile_end - tile_start;
let tile_data: Vec<f32> = (tile_start..tile_end)
.flat_map(|col| {
let offset = col * n_rows;
a_data[offset..offset + n_rows].iter().copied()
})
.collect();
let b_handle = self.client.create_from_slice(f32::as_bytes(b_data));
let b_tensor =
TensorHandle::<R>::new(b_handle, vec![n_rows, n], vec![1, n_rows], f32_storage);
let a_handle = self.client.create_from_slice(f32::as_bytes(&tile_data));
let a_tensor = TensorHandle::<R>::new(
a_handle,
vec![tile_m, n_rows],
vec![n_rows, 1],
f32_storage,
);
let out_tensor = TensorHandle::<R>::empty(&self.client, vec![tile_m, n], f32_storage);
let out_row_stride = out_tensor.strides()[0];
let mut dtypes = MatmulElems::from_single_dtype(f32_storage);
cubek_matmul::launch::launch_ref::<R>(
&self.strategy,
&self.client,
&MatmulInputHandleRef::new(a_tensor.as_ref(), f32_storage),
&MatmulInputHandleRef::new(b_tensor.as_ref(), f32_storage),
&out_tensor.as_ref(),
&mut dtypes,
)
.map_err(|e| anyhow!("GPU matmul tile failed: {:?}", e))?;
let tile_bytes = self.client.read_one(out_tensor.handle);
let raw: &[f32] = f32::from_bytes(&tile_bytes);
for i in 0..tile_m {
for j in 0..n {
result[(tile_start + i) * n + j] = raw[i * out_row_stride + j];
}
}
}
Ok(result)
}
pub fn matmul_tn_flex32(
&self,
a_data: &[f32],
n_rows: usize,
a_cols: usize,
b_data: &[f32],
b_cols: usize,
) -> Result<Vec<f32>> {
let m = a_cols;
let k = n_rows;
let n = b_cols;
let f32_storage = f32::as_type_native_unchecked();
let a_handle = self.client.create_from_slice(f32::as_bytes(a_data));
let b_handle = self.client.create_from_slice(f32::as_bytes(b_data));
let a_tensor = TensorHandle::<R>::new(a_handle, vec![m, k], vec![k, 1], f32_storage);
let b_tensor = TensorHandle::<R>::new(b_handle, vec![k, n], vec![1, k], f32_storage);
let out_tensor = TensorHandle::<R>::empty(&self.client, vec![m, n], f32_storage);
let out_row_stride = out_tensor.strides()[0];
let mut dtypes = MatmulElems::new_deprecated::<flex32>();
cubek_matmul::launch::launch_ref::<R>(
&self.strategy,
&self.client,
&MatmulInputHandleRef::new(a_tensor.as_ref(), f32_storage),
&MatmulInputHandleRef::new(b_tensor.as_ref(), f32_storage),
&out_tensor.as_ref(),
&mut dtypes,
)
.map_err(|e| anyhow!("GPU flex32 matmul failed: {:?}", e))?;
let out_bytes = self.client.read_one(out_tensor.handle);
let raw: &[f32] = f32::from_bytes(&out_bytes);
Ok(extract_result(raw, m, n, out_row_stride))
}
#[allow(dead_code)]
pub fn matmul_tn_tiled_flex32(
&self,
a_data: &[f32],
n_rows: usize,
a_cols: usize,
b_data: &[f32],
b_cols: usize,
tile_cols: usize,
) -> Result<Vec<f32>> {
let n = b_cols;
let f32_storage = f32::as_type_native_unchecked();
let mut result = vec![0.0f32; a_cols * n];
for tile_start in (0..a_cols).step_by(tile_cols) {
let tile_end = (tile_start + tile_cols).min(a_cols);
let tile_m = tile_end - tile_start;
let tile_data: Vec<f32> = (tile_start..tile_end)
.flat_map(|col| {
let offset = col * n_rows;
a_data[offset..offset + n_rows].iter().copied()
})
.collect();
let b_handle = self.client.create_from_slice(f32::as_bytes(b_data));
let b_tensor =
TensorHandle::<R>::new(b_handle, vec![n_rows, n], vec![1, n_rows], f32_storage);
let a_handle = self.client.create_from_slice(f32::as_bytes(&tile_data));
let a_tensor = TensorHandle::<R>::new(
a_handle,
vec![tile_m, n_rows],
vec![n_rows, 1],
f32_storage,
);
let out_tensor = TensorHandle::<R>::empty(&self.client, vec![tile_m, n], f32_storage);
let out_row_stride = out_tensor.strides()[0];
let mut dtypes = MatmulElems::new_deprecated::<flex32>();
cubek_matmul::launch::launch_ref::<R>(
&self.strategy,
&self.client,
&MatmulInputHandleRef::new(a_tensor.as_ref(), f32_storage),
&MatmulInputHandleRef::new(b_tensor.as_ref(), f32_storage),
&out_tensor.as_ref(),
&mut dtypes,
)
.map_err(|e| anyhow!("GPU flex32 matmul tile failed: {:?}", e))?;
let tile_bytes = self.client.read_one(out_tensor.handle);
let raw: &[f32] = f32::from_bytes(&tile_bytes);
for i in 0..tile_m {
for j in 0..n {
result[(tile_start + i) * n + j] = raw[i * out_row_stride + j];
}
}
}
Ok(result)
}
}