#![cfg(all(feature = "cuda", feature = "triton-kernels"))]
use candle_core::cuda_backend::CudaStorage;
use candle_core::{op::BackpropOp, DType, Storage, Tensor};
use cudarc::driver::PushKernelArg;
use crate::triton_meta::parse_meta;
use crate::triton_ptx;
/// Module name passed as the `module_name` argument to
/// `get_or_load_custom_func` when loading the Triton layer-norm PTX.
const MODULE_NAME: &str = "triton_layer_norm";
pub fn layer_norm_triton(
x: &Tensor,
gamma: &Tensor,
beta: &Tensor,
eps: f32,
) -> candle_core::Result<Tensor> {
let dtype = x.dtype();
let dims = x.dims();
let dim = *dims.last().unwrap();
let num_rows = x.elem_count() / dim;
if dtype != DType::F32 {
candle_core::bail!(
"triton layer_norm: only F32 currently has a triton-rs port (got {dtype:?})"
);
}
let meta = parse_meta(triton_ptx::layer_norm_f32::META)?;
let cuda_dev = x.device().as_cuda_device()?;
let kernel_name: &'static str = Box::leak(meta.name.into_boxed_str());
let func = cuda_dev.get_or_load_custom_func(
kernel_name,
MODULE_NAME,
triton_ptx::layer_norm_f32::PTX,
)?;
let grid_size = num_rows as u32;
let block_size = (meta.num_warps * 32) as u32;
let dim_i32 = dim as i32;
let inv_dim: f32 = 1.0 / dim as f32;
let elem_count = num_rows * dim;
let global_scratch: cudarc::driver::CudaSlice<u8> =
cuda_dev.alloc_zeros::<u8>(meta.global_scratch_size.max(1))?;
let profile_scratch: cudarc::driver::CudaSlice<u8> =
cuda_dev.alloc_zeros::<u8>(meta.profile_scratch_size.max(1))?;
let (x_s, x_l) = x.storage_and_layout();
let (gamma_s, gamma_l) = gamma.storage_and_layout();
let (beta_s, beta_l) = beta.storage_and_layout();
let x_cuda = match &*x_s {
Storage::Cuda(cs) => cs,
_ => candle_core::bail!("x must be on CUDA"),
};
let gamma_cuda = match &*gamma_s {
Storage::Cuda(cs) => cs,
_ => candle_core::bail!("gamma must be on CUDA"),
};
let beta_cuda = match &*beta_s {
Storage::Cuda(cs) => cs,
_ => candle_core::bail!("beta must be on CUDA"),
};
let xs = x_cuda.as_cuda_slice::<f32>()?;
let g = gamma_cuda.as_cuda_slice::<f32>()?;
let bt = beta_cuda.as_cuda_slice::<f32>()?;
let out = unsafe { cuda_dev.alloc::<f32>(elem_count)? };
let xs = xs.slice(x_l.start_offset()..);
let g = g.slice(gamma_l.start_offset()..);
let bt = bt.slice(beta_l.start_offset()..);
let mut builder = func.builder();
builder.arg(&xs);
builder.arg(&g);
builder.arg(&bt);
builder.arg(&out);
builder.arg(&dim_i32);
builder.arg(&inv_dim);
builder.arg(&eps);
builder.arg(&global_scratch);
builder.arg(&profile_scratch);
let cfg = cudarc::driver::LaunchConfig {
grid_dim: (grid_size, 1, 1),
block_dim: (block_size, 1, 1),
shared_mem_bytes: meta.shared_mem as u32,
};
unsafe { builder.launch(cfg) }
.map_err(|e| candle_core::Error::Msg(format!("triton layer_norm launch: {e}")))?;
let output_storage = CudaStorage::wrap_cuda_slice(out, cuda_dev.clone());
drop(x_s);
drop(gamma_s);
drop(beta_s);
let shape = x.shape().clone();
Ok(Tensor::from_storage(
Storage::Cuda(output_storage),
shape,
BackpropOp::none(),
false,
))
}