use std::sync::Arc;
use candle_core::{
CpuStorage, CudaStorage, CustomOp2, DType, Error, Layout, Result, Shape, Tensor,
};
use half::bf16;
use kaio::prelude::{GpuBuffer, KaioDevice};
use kaio_ops::matmul_tc_bf16 as kaio_matmul_tc_bf16;
use crate::bridge;
/// Candle custom op (see the `CustomOp2` impl for this type) that forwards a
/// two-input matmul to `kaio_ops::matmul_tc_bf16`.
///
/// Construct it directly, or go through the `matmul_tc_bf16` free function in
/// this file, which builds one and applies it to two tensors.
pub struct MatmulTcBf16Op {
    /// KAIO device used to allocate the output buffer and launch the kernel.
    /// The CUDA forward pass checks it against the candle device of the first
    /// input via `bridge::ensure_ordinal_match`.
    pub device: Arc<KaioDevice>,
}
impl CustomOp2 for MatmulTcBf16Op {
fn name(&self) -> &'static str {
"kaio::matmul_tc_bf16"
}
fn cpu_fwd(
&self,
_s1: &CpuStorage,
_l1: &Layout,
_s2: &CpuStorage,
_l2: &Layout,
) -> Result<(CpuStorage, Shape)> {
Err(Error::Msg(
"kaio-candle::matmul_tc_bf16: CPU fallback not supported. \
This op requires a CUDA device (bf16 variant requires SM 8.0+ \
for bf16 mma). KAIO's value prop is GPU-specific PTX — falling \
back to CPU would silently route around every perf claim. Call \
`.to_device(&Device::new_cuda(0)?)` on your tensors first."
.to_string(),
))
}
fn cuda_fwd(
&self,
s1: &CudaStorage,
l1: &Layout,
s2: &CudaStorage,
l2: &Layout,
) -> Result<(CudaStorage, Shape)> {
let (m_a, k_a) = bridge::ensure_rank2_contiguous_zero_offset("matmul_tc_bf16", 0, l1)?;
let (k_b, n_b) = bridge::ensure_rank2_contiguous_zero_offset("matmul_tc_bf16", 1, l2)?;
if k_a != k_b {
return Err(Error::Msg(format!(
"kaio-candle::matmul_tc_bf16: K mismatch between inputs — \
input #0 has shape [{m_a}, {k_a}] (K = {k_a}), \
input #1 has shape [{k_b}, {n_b}] (K = {k_b}). \
Inner dimensions must match."
)));
}
let m = u32::try_from(m_a)
.map_err(|_| Error::Msg(format!("matmul_tc_bf16: M ({m_a}) exceeds u32")))?;
let n = u32::try_from(n_b)
.map_err(|_| Error::Msg(format!("matmul_tc_bf16: N ({n_b}) exceeds u32")))?;
let k = u32::try_from(k_a)
.map_err(|_| Error::Msg(format!("matmul_tc_bf16: K ({k_a}) exceeds u32")))?;
let candle_dev = s1.device.clone();
bridge::ensure_ordinal_match(&candle_dev, &self.device)?;
let a_slice = bridge::slice_ref_from_storage::<bf16>(s1)?;
let b_slice = bridge::slice_ref_from_storage::<bf16>(s2)?;
let a_buf: &GpuBuffer<bf16> = bridge::buffer_ref_from_slice_readonly(a_slice);
let b_buf: &GpuBuffer<bf16> = bridge::buffer_ref_from_slice_readonly(b_slice);
let mut out_buf: GpuBuffer<f32> = self
.device
.alloc_zeros::<f32>(m_a * n_b)
.map_err(bridge::kaio_err)?;
bridge::sync_before_launch(&candle_dev, &self.device)?;
kaio_matmul_tc_bf16(&self.device, a_buf, b_buf, &mut out_buf, m, n, k)
.map_err(bridge::kaio_err)?;
bridge::sync_after_launch(&candle_dev, &self.device)?;
let out_slice = out_buf.into_cuda_slice();
let out_storage = bridge::storage_from_slice::<f32>(out_slice, candle_dev);
Ok((out_storage, Shape::from_dims(&[m_a, n_b])))
}
fn bwd(
&self,
a: &Tensor,
b: &Tensor,
_res: &Tensor,
grad_res: &Tensor,
) -> Result<(Option<Tensor>, Option<Tensor>)> {
let grad_bf16 = grad_res.to_dtype(DType::BF16)?;
let b_t = b.t()?.contiguous()?;
let grad_a = matmul_tc_bf16(&self.device, &grad_bf16, &b_t)?;
let a_t = a.t()?.contiguous()?;
let grad_b = matmul_tc_bf16(&self.device, &a_t, &grad_bf16)?;
Ok((
Some(grad_a.to_dtype(DType::BF16)?),
Some(grad_b.to_dtype(DType::BF16)?),
))
}
}
/// Applies [`MatmulTcBf16Op`] to `a` and `b` through candle's custom-op
/// mechanism and returns the resulting tensor.
///
/// `device` is the KAIO device the op will run on; the op stores its own
/// clone of the `Arc`.
///
/// # Errors
///
/// Returns whatever error `Tensor::apply_op2` reports for this op.
pub fn matmul_tc_bf16(device: &Arc<KaioDevice>, a: &Tensor, b: &Tensor) -> Result<Tensor> {
    let op = MatmulTcBf16Op {
        device: Arc::clone(device),
    };
    a.apply_op2(b, op)
}