use core::cell::Cell;
use core::ffi::c_void;
use core::marker::PhantomData;
use baracuda_cutlass::{Error, Result};
use baracuda_driver::{Context, DeviceBuffer, Stream};
use baracuda_kernels_sys::{
cublasCreate_v2, cublasDestroy_v2, cublasGemmEx, cublasHandle_t, cublasSetStream_v2,
CUBLAS_COMPUTE_32F, CUBLAS_COMPUTE_64F, CUBLAS_GEMM_DEFAULT, CUBLAS_OP_N, CUBLAS_OP_T,
CUDA_R_16BF, CUDA_R_16F, CUDA_R_32F, CUDA_R_64F,
};
use baracuda_kernels_types::{
ArchSku, BackendKind, Element, ElementKind, KernelSku, LossKind, LossReduction, MathPrecision,
OpCategory, PlanPreference, PrecisionGuarantee, TensorMut, TensorRef, Workspace,
};
use super::common::check_supported_dtype;
/// Default `ignore_index` installed by [`FusedLinearCrossEntropyDescriptor::new`].
pub const FLCE_DEFAULT_IGNORE_INDEX: i64 = -100;
/// Upper bound on the chunk size returned by `pick_chunk_size`.
const MAX_CHUNK_ROWS: i32 = 2048;
/// Problem description for the fused linear + cross-entropy plan.
///
/// The dimensions are validated by `FusedLinearCrossEntropyPlan::select`
/// (`bt >= 0`, `h >= 1`, `v >= 1`) and checked against the tensors handed to `run`.
#[derive(Copy, Clone, Debug)]
#[non_exhaustive]
pub struct FusedLinearCrossEntropyDescriptor {
/// Number of rows: `run` requires `input` to be `[bt, h]` and `target` to be `[bt]`.
pub bt: i32,
/// Inner dimension shared by `input` (`[bt, h]`) and `weight` (`[v, h]`).
pub h: i32,
/// Number of rows of `weight` (`[v, h]`), i.e. the width of the logits scratch.
pub v: i32,
/// Output reduction: `None` expects an output of shape `[bt]`, `Mean`/`Sum` of shape `[1]`.
pub reduction: LossReduction,
/// Element kind; `select` rejects the descriptor unless it equals `T::KIND`.
pub element: ElementKind,
/// Target value forwarded to the counting and per-row kernels as the ignore index.
pub ignore_index: i64,
}
impl FusedLinearCrossEntropyDescriptor {
    /// Builds a descriptor for a `[bt, h]` input and a `[v, h]` weight.
    ///
    /// The reduction defaults to [`LossReduction::Mean`] and the ignore index to
    /// [`FLCE_DEFAULT_IGNORE_INDEX`]; use the `with_*` builders to override either.
    #[inline]
    pub fn new(bt: i32, h: i32, v: i32, element: ElementKind) -> Self {
        let reduction = LossReduction::Mean;
        let ignore_index = FLCE_DEFAULT_IGNORE_INDEX;
        Self {
            bt,
            h,
            v,
            reduction,
            element,
            ignore_index,
        }
    }

    /// Returns a copy of the descriptor with `reduction` replaced.
    #[inline]
    #[must_use]
    pub fn with_reduction(self, reduction: LossReduction) -> Self {
        Self { reduction, ..self }
    }

    /// Returns a copy of the descriptor with `ignore_index` replaced.
    #[inline]
    #[must_use]
    pub fn with_ignore_index(self, ignore_index: i64) -> Self {
        Self {
            ignore_index,
            ..self
        }
    }
}
/// Tensor arguments for `FusedLinearCrossEntropyPlan::run`.
pub struct FusedLinearCrossEntropyArgs<'a, T: Element> {
/// Input activations, shape `[bt, h]`; must be contiguous.
pub input: TensorRef<'a, T, 2>,
/// Projection weight, shape `[v, h]`; must be contiguous.
pub weight: TensorRef<'a, T, 2>,
/// Per-row targets, shape `[bt]`.
pub target: TensorRef<'a, i64, 1>,
/// Loss output: shape `[bt]` for `LossReduction::None`, `[1]` for `Mean`/`Sum`.
pub out: TensorMut<'a, T, 1>,
/// Optional gradient w.r.t. `input`, shape `[bt, h]`; overwritten chunk by chunk.
pub grad_input: Option<TensorMut<'a, T, 2>>,
/// Optional gradient w.r.t. `weight`, shape `[v, h]`. `run` accumulates into it
/// (GEMM with `beta = 1`) and never zeroes it.
/// NOTE(review): presumably the caller must zero it beforehand — confirm.
pub grad_weight: Option<TensorMut<'a, T, 2>>,
}
/// Plan for the chunked fused linear + cross-entropy computation for element type `T`.
pub struct FusedLinearCrossEntropyPlan<T: Element> {
// Descriptor copied at `select` time; `run` validates its arguments against it.
desc: FusedLinearCrossEntropyDescriptor,
// SKU reported by `sku()` / `precision_guarantee()`.
sku: KernelSku,
// Rows processed per chunk, chosen by `pick_chunk_size`.
chunk_size: i32,
// cuBLAS handle, created on first use by `ensure_handle` (null until then) and
// destroyed in `Drop`. A `Cell` so that `&self` methods can cache it.
handle: Cell<cublasHandle_t>,
_marker: PhantomData<T>,
}
impl<T: Element> FusedLinearCrossEntropyPlan<T> {
/// Validates `desc` for element type `T` and builds a plan.
///
/// The stream and plan preference are currently unused. No cuBLAS handle is created
/// here; it is created lazily on the first `run`.
///
/// # Errors
/// * `Error::Unsupported` if `desc.element` differs from `T::KIND`.
/// * Whatever `check_supported_dtype::<T>()` reports.
/// * `Error::InvalidProblem` if `bt < 0`, `h < 1` or `v < 1`.
pub fn select(
    _stream: &Stream,
    desc: &FusedLinearCrossEntropyDescriptor,
    _pref: PlanPreference,
) -> Result<Self> {
    if T::KIND != desc.element {
        return Err(Error::Unsupported(
            "baracuda-kernels::FusedLinearCrossEntropyPlan: descriptor.element != T",
        ));
    }
    check_supported_dtype::<T>()?;

    let dims_valid = desc.bt >= 0 && desc.h >= 1 && desc.v >= 1;
    if !dims_valid {
        return Err(Error::InvalidProblem(
            "baracuda-kernels::FusedLinearCrossEntropyPlan: bt must be ≥ 0; h, v must be ≥ 1",
        ));
    }

    // f64 accumulates in f64; every other element kind accumulates in f32.
    let accumulator = match T::KIND {
        ElementKind::F64 => ElementKind::F64,
        _ => ElementKind::F32,
    };

    Ok(Self {
        desc: *desc,
        sku: KernelSku {
            category: OpCategory::Loss,
            op: LossKind::FusedLinearCrossEntropy as u16,
            element: T::KIND,
            aux_element: None,
            layout: None,
            epilogue: None,
            arch: ArchSku::Sm80,
            backend: BackendKind::Bespoke,
            precision_guarantee: PrecisionGuarantee {
                math_precision: MathPrecision::F32,
                accumulator,
                bit_stable_on_same_hardware: false,
                deterministic: true,
            },
        },
        chunk_size: pick_chunk_size(desc.bt, desc.h, desc.v),
        handle: Cell::new(core::ptr::null_mut()),
        _marker: PhantomData,
    })
}
/// Returns the kernel SKU recorded when the plan was selected.
#[inline]
pub fn sku(&self) -> KernelSku {
self.sku
}
/// Returns the precision guarantee stored in the plan's SKU.
#[inline]
pub fn precision_guarantee(&self) -> PrecisionGuarantee {
self.sku.precision_guarantee
}
/// Returns the number of rows processed per chunk, as chosen at `select` time.
#[inline]
pub fn chunk_size(&self) -> i32 {
self.chunk_size
}
/// Size in bytes of the caller-provided workspace; always 0.
///
/// `run` ignores its `Workspace` argument and allocates its own scratch buffers;
/// see `conceptual_scratch_bytes` for their size.
#[inline]
pub fn workspace_size(&self) -> usize {
0
}
/// Estimates the device scratch, in bytes, that `run` allocates internally.
///
/// The total is the `[chunk_size, v]` logits buffer in `T`, one `f32` per row of the
/// batch, and 8 further bytes (the size of the single `i64` counter used when
/// counting non-ignored targets).
#[inline]
pub fn conceptual_scratch_bytes(&self) -> usize {
    let rows_per_chunk = self.chunk_size as usize;
    let classes = self.desc.v as usize;
    let total_rows = self.desc.bt as usize;

    let logits = rows_per_chunk * classes * core::mem::size_of::<T>();
    let per_row_loss = total_rows * core::mem::size_of::<f32>();
    let counter = core::mem::size_of::<i64>();

    logits + per_row_loss + counter
}
fn ensure_handle(&self) -> Result<cublasHandle_t> {
let h = self.handle.get();
if !h.is_null() {
return Ok(h);
}
let mut handle: cublasHandle_t = core::ptr::null_mut();
let mut last_status = 0;
for attempt in 0..5 {
let status = unsafe { cublasCreate_v2(&mut handle as *mut _) };
if status == 0 {
last_status = 0;
break;
}
last_status = status;
std::thread::sleep(std::time::Duration::from_millis(
10u64 * (attempt as u64 + 1),
));
}
if last_status != 0 {
return Err(Error::CutlassInternal(-last_status));
}
self.handle.set(handle);
Ok(handle)
}
fn bind_stream(&self, h: cublasHandle_t, stream: &Stream) -> Result<()> {
let status = unsafe { cublasSetStream_v2(h, stream.as_raw() as *mut c_void) };
if status != 0 {
return Err(Error::CutlassInternal(-status));
}
Ok(())
}
/// Executes the fused linear + cross-entropy pass on `stream`.
///
/// The batch is processed in chunks of `chunk_size()` rows. For every chunk:
/// 1. cuBLAS computes `logits = input_chunk · weightᵀ` into a `[chunk, v]` scratch buffer;
/// 2. the per-row loss kernel is launched on that scratch and writes one `f32` per row
///    into a `[bt]` loss buffer;
/// 3. if `grad_input` is supplied, a GEMM of the scratch against `weight` overwrites the
///    matching rows of `grad_input` (`beta = 0`);
/// 4. if `grad_weight` is supplied, a GEMM of the scratch against the input chunk is
///    accumulated into `grad_weight` (`beta = 1`).
///
/// After the loop the per-row losses are either cast into `out` (`LossReduction::None`)
/// or reduced into its single element (`Mean` / `Sum`).
///
/// If `bt == 0`, or every target equals the ignore index, `out` and `grad_input` are
/// zero-filled and nothing else is launched; `grad_weight` is left untouched.
///
/// `_workspace` is ignored: scratch buffers are allocated internally on each call.
///
/// # Errors
/// * `Error::InvalidProblem` on any shape mismatch against the descriptor, or when a
///   scratch allocation / device-to-host copy fails.
/// * `Error::Unsupported` if `input` or `weight` is not contiguous, or `T` has no kernel.
/// * `Error::CutlassInternal` carrying a kernel launcher status, or the negated cuBLAS
///   status.
pub fn run(
    &self,
    stream: &Stream,
    _workspace: Workspace<'_>,
    args: FusedLinearCrossEntropyArgs<'_, T>,
) -> Result<()> {
    let bt = self.desc.bt;
    let h = self.desc.h;
    let v = self.desc.v;

    // ---- Argument validation -------------------------------------------------------
    if args.input.shape != [bt, h] {
        return Err(Error::InvalidProblem(
            "baracuda-kernels::FusedLinearCrossEntropyPlan: input shape != [bt, h]",
        ));
    }
    if args.weight.shape != [v, h] {
        return Err(Error::InvalidProblem(
            "baracuda-kernels::FusedLinearCrossEntropyPlan: weight shape != [v, h]",
        ));
    }
    if args.target.shape != [bt] {
        return Err(Error::InvalidProblem(
            "baracuda-kernels::FusedLinearCrossEntropyPlan: target shape != [bt]",
        ));
    }
    if !args.input.is_contiguous() || !args.weight.is_contiguous() {
        return Err(Error::Unsupported(
            "baracuda-kernels::FusedLinearCrossEntropyPlan: input and weight must be \
             contiguous (Phase 47 v1 limitation)",
        ));
    }
    let expected_out_n = match self.desc.reduction {
        LossReduction::None => bt,
        LossReduction::Mean | LossReduction::Sum => 1,
    };
    if args.out.shape != [expected_out_n] {
        return Err(Error::InvalidProblem(
            "baracuda-kernels::FusedLinearCrossEntropyPlan: out shape mismatch (expected \
             [BT] for None or [1] for Mean/Sum)",
        ));
    }
    if let Some(ref gi) = args.grad_input {
        if gi.shape != [bt, h] {
            return Err(Error::InvalidProblem(
                "baracuda-kernels::FusedLinearCrossEntropyPlan: grad_input shape != [bt, h]",
            ));
        }
    }
    if let Some(ref gw) = args.grad_weight {
        if gw.shape != [v, h] {
            return Err(Error::InvalidProblem(
                "baracuda-kernels::FusedLinearCrossEntropyPlan: grad_weight shape != [v, h]",
            ));
        }
    }
    // NOTE(review): `target`, `out`, `grad_input` and `grad_weight` are addressed below
    // as densely packed (unit element stride, leading dimension `h`), but only `input`
    // and `weight` are checked for contiguity — confirm whether callers guarantee this.

    // ---- Degenerate batches --------------------------------------------------------
    if bt == 0 {
        // An empty batch still owns a one-element `out` under Mean/Sum. Zero it, exactly
        // as the all-ignored case below does, rather than leaving it unwritten.
        return self.zero_outputs(stream, &args);
    }

    let ctx = stream.context();

    // Count before allocating scratch so an all-ignored batch costs no allocations.
    let n_non_ignore = self.count_non_ignore(ctx, stream, &args.target)?;
    if n_non_ignore == 0 {
        return self.zero_outputs(stream, &args);
    }

    // ---- Scratch -------------------------------------------------------------------
    // `pick_chunk_size` never returns less than 1; the clamp only guarantees that the
    // chunk loop below always makes progress.
    let chunk_size = self.chunk_size.max(1);
    let logits_elems = (chunk_size as usize) * (v as usize);
    let mut logits_scratch: DeviceBuffer<T> =
        DeviceBuffer::zeros(ctx, logits_elems).map_err(|_| {
            Error::InvalidProblem(
                "baracuda-kernels::FusedLinearCrossEntropyPlan: logits scratch alloc failed",
            )
        })?;
    let mut loss_1d: DeviceBuffer<f32> =
        DeviceBuffer::zeros(ctx, bt as usize).map_err(|_| {
            Error::InvalidProblem(
                "baracuda-kernels::FusedLinearCrossEntropyPlan: loss_1d alloc failed",
            )
        })?;
    // NOTE(review): both scratch buffers are dropped when this function returns, while
    // the work enqueued below may still be pending on `stream`. This is only sound if
    // dropping a `DeviceBuffer` waits for (or is ordered after) that work — confirm.

    let scale_per_row: f32 = match self.desc.reduction {
        LossReduction::Mean => 1.0f32 / (n_non_ignore as f32),
        LossReduction::None | LossReduction::Sum => 1.0f32,
    };

    let handle = self.ensure_handle()?;
    self.bind_stream(handle, stream)?;

    // ---- Raw pointers and strides ----------------------------------------------------
    let raw_stream = stream.as_raw() as *mut c_void;
    let input_ptr_base = args.input.data.as_raw().0 as *const c_void;
    let weight_ptr = args.weight.data.as_raw().0 as *const c_void;
    let target_ptr = args.target.data.as_raw().0 as *const c_void;
    let logits_ptr = logits_scratch.as_slice_mut().as_raw().0 as *mut c_void;
    let loss_1d_ptr = loss_1d.as_slice_mut().as_raw().0 as *mut c_void;
    let grad_input_ptr_base = args
        .grad_input
        .as_ref()
        .map(|gi| gi.data.as_raw().0 as *mut c_void)
        .unwrap_or(core::ptr::null_mut());
    let grad_weight_ptr = args
        .grad_weight
        .as_ref()
        .map(|gw| gw.data.as_raw().0 as *mut c_void)
        .unwrap_or(core::ptr::null_mut());
    let elem_t = core::mem::size_of::<T>() as isize;
    let input_row_stride_elems = args.input.stride[0] as isize;
    let grad_input_row_stride_elems = args
        .grad_input
        .as_ref()
        .map(|gi| gi.stride[0] as isize)
        .unwrap_or(0);

    // GEMM scalars are passed by host pointer and must match the compute type chosen in
    // `gemm_ex`: f64 for `T = f64`, f32 for every other element kind. They are loop
    // invariant, so they are built once here.
    let is_f64 = T::KIND == ElementKind::F64;
    let one_f32 = 1.0f32;
    let zero_f32 = 0.0f32;
    let one_f64 = 1.0f64;
    let zero_f64 = 0.0f64;
    let one_ptr: *const c_void = if is_f64 {
        &one_f64 as *const f64 as *const c_void
    } else {
        &one_f32 as *const f32 as *const c_void
    };
    let zero_ptr: *const c_void = if is_f64 {
        &zero_f64 as *const f64 as *const c_void
    } else {
        &zero_f32 as *const f32 as *const c_void
    };

    // ---- Chunk loop ------------------------------------------------------------------
    // `start` only advances by `n_rows <= bt - start`, so it cannot overflow even for
    // `bt` close to `i32::MAX`.
    let mut start: i32 = 0;
    while start < bt {
        let n_rows = core::cmp::min(chunk_size, bt - start);

        // SAFETY: `start < bt` and `input` was validated as a contiguous `[bt, h]`
        // tensor, so the byte offset stays inside the input allocation.
        let input_chunk_ptr = unsafe {
            (input_ptr_base as *const u8)
                .offset(start as isize * input_row_stride_elems * elem_t)
                as *const c_void
        };

        // logits[n_rows, v] = input_chunk[n_rows, h] · weight[v, h]ᵀ
        // (cuBLAS is column-major, so this is issued as op_T(weight) · input_chunk.)
        self.gemm_ex(
            handle,
            CUBLAS_OP_T,
            CUBLAS_OP_N,
            v,
            n_rows,
            h,
            one_ptr,
            weight_ptr,
            v,
            h,
            input_chunk_ptr,
            h,
            zero_ptr,
            logits_ptr,
            v,
        )?;

        // SAFETY: `loss_1d` holds `bt` f32 values and `start < bt`.
        let loss_1d_chunk_ptr = unsafe {
            (loss_1d_ptr as *mut u8)
                .offset(start as isize * core::mem::size_of::<f32>() as isize)
                as *mut c_void
        };
        // SAFETY: `target` has shape `[bt]` and `start < bt`; a unit element stride is
        // assumed (see the review note after validation).
        let target_chunk_ptr = unsafe {
            (target_ptr as *const u8)
                .offset(start as isize * core::mem::size_of::<i64>() as isize)
                as *const c_void
        };
        let row_stride_logits = v as i64;

        // Per-row loss kernel. It presumably rewrites the logits scratch in place into
        // the gradient w.r.t. the logits (scaled by `scale_per_row`), since the GEMMs
        // below consume the same buffer as their gradient operand — confirm against the
        // kernel source.
        // SAFETY: all pointers refer to live device buffers sized for `n_rows` rows.
        let status = unsafe {
            match T::KIND {
                ElementKind::F32 => {
                    baracuda_kernels_sys::baracuda_kernels_loss_flce_per_row_f32_run(
                        n_rows,
                        v,
                        row_stride_logits,
                        self.desc.ignore_index,
                        scale_per_row,
                        logits_ptr,
                        target_chunk_ptr,
                        loss_1d_chunk_ptr,
                        raw_stream,
                    )
                }
                ElementKind::F16 => {
                    baracuda_kernels_sys::baracuda_kernels_loss_flce_per_row_f16_run(
                        n_rows,
                        v,
                        row_stride_logits,
                        self.desc.ignore_index,
                        scale_per_row,
                        logits_ptr,
                        target_chunk_ptr,
                        loss_1d_chunk_ptr,
                        raw_stream,
                    )
                }
                ElementKind::Bf16 => {
                    baracuda_kernels_sys::baracuda_kernels_loss_flce_per_row_bf16_run(
                        n_rows,
                        v,
                        row_stride_logits,
                        self.desc.ignore_index,
                        scale_per_row,
                        logits_ptr,
                        target_chunk_ptr,
                        loss_1d_chunk_ptr,
                        raw_stream,
                    )
                }
                ElementKind::F64 => {
                    baracuda_kernels_sys::baracuda_kernels_loss_flce_per_row_f64_run(
                        n_rows,
                        v,
                        row_stride_logits,
                        self.desc.ignore_index,
                        scale_per_row,
                        logits_ptr,
                        target_chunk_ptr,
                        loss_1d_chunk_ptr,
                        raw_stream,
                    )
                }
                _ => {
                    return Err(Error::Unsupported(
                        "baracuda-kernels::FusedLinearCrossEntropyPlan::run unwired dtype",
                    ))
                }
            }
        };
        if status != 0 {
            return Err(Error::CutlassInternal(status));
        }

        if args.grad_input.is_some() {
            // SAFETY: `grad_input` has shape `[bt, h]` and `start < bt`.
            let grad_input_chunk_ptr = unsafe {
                (grad_input_ptr_base as *mut u8)
                    .offset(start as isize * grad_input_row_stride_elems * elem_t)
                    as *mut c_void
            };
            // grad_input_chunk[n_rows, h] = scratch[n_rows, v] · weight[v, h]  (overwrite)
            self.gemm_ex(
                handle,
                CUBLAS_OP_N,
                CUBLAS_OP_N,
                h,
                n_rows,
                v,
                one_ptr,
                weight_ptr,
                h,
                h,
                logits_ptr,
                v,
                zero_ptr,
                grad_input_chunk_ptr,
                h,
            )?;
        }

        if args.grad_weight.is_some() {
            // grad_weight[v, h] += scratch[n_rows, v]ᵀ · input_chunk[n_rows, h]
            // `beta = 1` accumulates across chunks (and onto whatever the buffer held).
            self.gemm_ex(
                handle,
                CUBLAS_OP_N,
                CUBLAS_OP_T,
                h,
                v,
                n_rows,
                one_ptr,
                input_chunk_ptr,
                h,
                h,
                logits_ptr,
                v,
                one_ptr,
                grad_weight_ptr,
                h,
            )?;
        }

        start += n_rows;
    }

    // ---- Finalize ------------------------------------------------------------------
    let out_ptr = args.out.data.as_raw().0 as *mut c_void;
    let loss_src = loss_1d_ptr as *const c_void;
    let bt_i64 = bt as i64;
    let status = match self.desc.reduction {
        // Per-row output: cast the f32 losses into `out`.
        // SAFETY: `loss_1d` holds `bt` values and `out` was validated as `[bt]`.
        LossReduction::None => unsafe {
            match T::KIND {
                ElementKind::F32 => {
                    baracuda_kernels_sys::baracuda_kernels_loss_flce_per_row_cast_f32_run(
                        bt_i64, loss_src, out_ptr, raw_stream,
                    )
                }
                ElementKind::F16 => {
                    baracuda_kernels_sys::baracuda_kernels_loss_flce_per_row_cast_f16_run(
                        bt_i64, loss_src, out_ptr, raw_stream,
                    )
                }
                ElementKind::Bf16 => {
                    baracuda_kernels_sys::baracuda_kernels_loss_flce_per_row_cast_bf16_run(
                        bt_i64, loss_src, out_ptr, raw_stream,
                    )
                }
                ElementKind::F64 => {
                    baracuda_kernels_sys::baracuda_kernels_loss_flce_per_row_cast_f64_run(
                        bt_i64, loss_src, out_ptr, raw_stream,
                    )
                }
                _ => return Err(Error::Unsupported("unwired dtype")),
            }
        },
        // Scalar output: reduce the per-row losses, scaled by 1/N for Mean.
        LossReduction::Mean | LossReduction::Sum => {
            let denom_inv = match self.desc.reduction {
                LossReduction::Mean => 1.0f32 / (n_non_ignore as f32),
                _ => 1.0f32,
            };
            // SAFETY: `loss_1d` holds `bt` values and `out` was validated as `[1]`.
            unsafe {
                match T::KIND {
                    ElementKind::F32 => {
                        baracuda_kernels_sys::baracuda_kernels_loss_flce_scalar_finalize_f32_run(
                            bt_i64, denom_inv, loss_src, out_ptr, raw_stream,
                        )
                    }
                    ElementKind::F16 => {
                        baracuda_kernels_sys::baracuda_kernels_loss_flce_scalar_finalize_f16_run(
                            bt_i64, denom_inv, loss_src, out_ptr, raw_stream,
                        )
                    }
                    ElementKind::Bf16 => {
                        baracuda_kernels_sys::baracuda_kernels_loss_flce_scalar_finalize_bf16_run(
                            bt_i64, denom_inv, loss_src, out_ptr, raw_stream,
                        )
                    }
                    ElementKind::F64 => {
                        baracuda_kernels_sys::baracuda_kernels_loss_flce_scalar_finalize_f64_run(
                            bt_i64, denom_inv, loss_src, out_ptr, raw_stream,
                        )
                    }
                    _ => return Err(Error::Unsupported("unwired dtype")),
                }
            }
        }
    };
    if status != 0 {
        return Err(Error::CutlassInternal(status));
    }
    Ok(())
}
#[allow(clippy::too_many_arguments)]
fn gemm_ex(
&self,
handle: cublasHandle_t,
transa: i32,
transb: i32,
m: i32,
n: i32,
k: i32,
alpha: *const c_void,
a: *const c_void,
_m_marker: i32, lda: i32,
b: *const c_void,
ldb: i32,
beta: *const c_void,
c: *mut c_void,
ldc: i32,
) -> Result<()> {
let (data_type, compute_type) = match T::KIND {
ElementKind::F16 => (CUDA_R_16F, CUBLAS_COMPUTE_32F),
ElementKind::Bf16 => (CUDA_R_16BF, CUBLAS_COMPUTE_32F),
ElementKind::F32 => (CUDA_R_32F, CUBLAS_COMPUTE_32F),
ElementKind::F64 => (CUDA_R_64F, CUBLAS_COMPUTE_64F),
_ => {
return Err(Error::Unsupported(
"baracuda-kernels::FusedLinearCrossEntropyPlan::gemm_ex: unwired dtype",
))
}
};
let status = unsafe {
cublasGemmEx(
handle, transa, transb, m, n, k,
alpha, a, data_type, lda, b, data_type, ldb, beta, c, data_type, ldc,
compute_type, CUBLAS_GEMM_DEFAULT,
)
};
if status != 0 {
return Err(Error::CutlassInternal(-status));
}
Ok(())
}
fn count_non_ignore(
&self,
ctx: &Context,
stream: &Stream,
target: &TensorRef<'_, i64, 1>,
) -> Result<usize> {
let bt = self.desc.bt;
if bt == 0 {
return Ok(0);
}
let mut count_dev: DeviceBuffer<i64> = DeviceBuffer::zeros(ctx, 1).map_err(|_| {
Error::InvalidProblem(
"baracuda-kernels::FusedLinearCrossEntropyPlan: count buffer alloc failed",
)
})?;
let status = unsafe {
baracuda_kernels_sys::baracuda_kernels_loss_flce_count_non_ignore_run(
bt,
self.desc.ignore_index,
target.data.as_raw().0 as *const c_void,
count_dev.as_slice_mut().as_raw().0 as *mut c_void,
stream.as_raw() as *mut c_void,
)
};
if status != 0 {
return Err(Error::CutlassInternal(status));
}
let mut host = [0i64; 1];
count_dev.copy_to_host(&mut host).map_err(|_| {
Error::InvalidProblem(
"baracuda-kernels::FusedLinearCrossEntropyPlan: count D2H failed",
)
})?;
Ok(host[0] as usize)
}
fn zero_outputs(
&self,
stream: &Stream,
args: &FusedLinearCrossEntropyArgs<'_, T>,
) -> Result<()> {
use baracuda_driver::memory::memset_u8_async;
let out_bytes = args.out.numel() as usize * core::mem::size_of::<T>();
if out_bytes > 0 {
memset_u8_async(args.out.data.as_raw(), 0, out_bytes, stream).map_err(|_| {
Error::InvalidProblem(
"baracuda-kernels::FusedLinearCrossEntropyPlan: zero out failed",
)
})?;
}
if let Some(ref gi) = args.grad_input {
let bytes = gi.numel() as usize * core::mem::size_of::<T>();
if bytes > 0 {
memset_u8_async(gi.data.as_raw(), 0, bytes, stream).map_err(|_| {
Error::InvalidProblem(
"baracuda-kernels::FusedLinearCrossEntropyPlan: zero grad_input failed",
)
})?;
}
}
let _ = args.grad_weight.as_ref();
Ok(())
}
}
/// Destroys the lazily created cuBLAS handle, if one was ever created.
impl<T: Element> Drop for FusedLinearCrossEntropyPlan<T> {
    fn drop(&mut self) {
        // Take the handle out of the cell, leaving null behind.
        let handle = self.handle.replace(core::ptr::null_mut());
        if handle.is_null() {
            return;
        }
        // SAFETY: a non-null handle can only have come from a successful
        // `cublasCreate_v2` in `ensure_handle`, and it is destroyed exactly once here.
        // The returned status is ignored because `drop` cannot report it.
        unsafe {
            cublasDestroy_v2(handle);
        }
    }
}
/// Chooses how many rows of the batch are processed per chunk.
///
/// The heuristic divides `bt` by `ceil(v / h)`, rounds the quotient up to a power of
/// two and clamps the result to `1..=MAX_CHUNK_ROWS`. A non-positive `bt` yields 1.
///
/// Callers must pass `h >= 1` and `v >= 1`; `select` validates both before calling.
///
/// The intermediate arithmetic is done in `i64` because `v + h - 1` and
/// `bt + inc_factor - 1` can exceed `i32::MAX` for large but valid dimensions. The
/// quotient is clamped *before* the power-of-two rounding so that a very large batch
/// can never push the rounding step past the `i32` range.
fn pick_chunk_size(bt: i32, h: i32, v: i32) -> i32 {
    if bt <= 0 {
        return 1;
    }
    let (rows, hidden, classes) = (i64::from(bt), i64::from(h), i64::from(v));
    let inc_factor = (classes + hidden - 1) / hidden;
    let raw = (rows + inc_factor - 1) / inc_factor;
    // The clamp keeps the value within `1..=MAX_CHUNK_ROWS`, so the narrowing cast is
    // lossless.
    let capped = raw.clamp(1, i64::from(MAX_CHUNK_ROWS)) as i32;
    core::cmp::min(next_pow2_i32(capped), MAX_CHUNK_ROWS).max(1)
}
/// Rounds `x` up to the next power of two.
///
/// * `x <= 1` (including zero and negative values) yields 1.
/// * A value that is already a power of two is returned unchanged.
/// * Values above `2^30` saturate at `2^30`, the largest power of two representable in
///   an `i32`, instead of wrapping around to `i32::MIN`.
fn next_pow2_i32(x: i32) -> i32 {
    const MAX_POW2: i32 = 1 << 30;
    if x <= 1 {
        return 1;
    }
    if x >= MAX_POW2 {
        return MAX_POW2;
    }
    // `1 < x < 2^30`, so the value is a positive `u32` and the rounded result fits an `i32`.
    (x as u32).next_power_of_two() as i32
}
/// Problem description for `FusedLinearCrossEntropyBackwardPlan`.
///
/// `select` validates the dimensions (`bt >= 0`, `h >= 1`, `v >= 1`) and the element
/// kind; the backward `run` shown in this file does not read them again.
#[derive(Copy, Clone, Debug)]
#[non_exhaustive]
pub struct FusedLinearCrossEntropyBackwardDescriptor {
/// Row count; must be `>= 0`.
pub bt: i32,
/// Must be `>= 1`.
pub h: i32,
/// Must be `>= 1`.
pub v: i32,
/// Element kind; `select` rejects the descriptor unless it equals `T::KIND`.
pub element: ElementKind,
}
impl FusedLinearCrossEntropyBackwardDescriptor {
/// Builds a backward descriptor from the given dimensions and element kind.
/// No validation happens here; it is done by `FusedLinearCrossEntropyBackwardPlan::select`.
#[inline]
pub fn new(bt: i32, h: i32, v: i32, element: ElementKind) -> Self {
Self { bt, h, v, element }
}
}
/// Arguments for `FusedLinearCrossEntropyBackwardPlan::run`.
pub struct FusedLinearCrossEntropyBackwardArgs<'a, T: Element> {
/// Scalar that every supplied gradient is multiplied by, in place. A value of exactly
/// `1.0` makes `run` a no-op.
pub dy_scalar: f32,
/// Optional gradient buffer to scale in place.
pub grad_input: Option<TensorMut<'a, T, 2>>,
/// Optional gradient buffer to scale in place.
pub grad_weight: Option<TensorMut<'a, T, 2>>,
}
/// Plan that rescales previously computed gradients in place by a scalar.
pub struct FusedLinearCrossEntropyBackwardPlan<T: Element> {
// Kept for reference only; not read after `select`, hence the lint exemption.
#[allow(dead_code)]
desc: FusedLinearCrossEntropyBackwardDescriptor,
// SKU reported by `sku()` / `precision_guarantee()`.
sku: KernelSku,
_marker: PhantomData<T>,
}
impl<T: Element> FusedLinearCrossEntropyBackwardPlan<T> {
/// Validates `desc` for element type `T` and builds a backward plan.
///
/// The stream and plan preference are currently unused.
///
/// # Errors
/// * `Error::Unsupported` if `desc.element` differs from `T::KIND`.
/// * Whatever `check_supported_dtype::<T>()` reports.
/// * `Error::InvalidProblem` if `bt < 0`, `h < 1` or `v < 1`.
pub fn select(
    _stream: &Stream,
    desc: &FusedLinearCrossEntropyBackwardDescriptor,
    _pref: PlanPreference,
) -> Result<Self> {
    if T::KIND != desc.element {
        return Err(Error::Unsupported(
            "baracuda-kernels::FusedLinearCrossEntropyBackwardPlan: descriptor.element != T",
        ));
    }
    check_supported_dtype::<T>()?;

    let dims_valid = desc.bt >= 0 && desc.h >= 1 && desc.v >= 1;
    if !dims_valid {
        return Err(Error::InvalidProblem(
            "baracuda-kernels::FusedLinearCrossEntropyBackwardPlan: invalid shape",
        ));
    }

    // f64 accumulates in f64; every other element kind accumulates in f32.
    let accumulator = match T::KIND {
        ElementKind::F64 => ElementKind::F64,
        _ => ElementKind::F32,
    };

    Ok(Self {
        desc: *desc,
        sku: KernelSku {
            category: OpCategory::Loss,
            op: LossKind::FusedLinearCrossEntropy as u16,
            element: T::KIND,
            aux_element: None,
            layout: None,
            epilogue: None,
            arch: ArchSku::Sm80,
            backend: BackendKind::Bespoke,
            precision_guarantee: PrecisionGuarantee {
                math_precision: MathPrecision::F32,
                accumulator,
                bit_stable_on_same_hardware: true,
                deterministic: true,
            },
        },
        _marker: PhantomData,
    })
}
/// Returns the kernel SKU recorded when the plan was selected.
#[inline]
pub fn sku(&self) -> KernelSku {
self.sku
}
/// Returns the precision guarantee stored in the plan's SKU.
#[inline]
pub fn precision_guarantee(&self) -> PrecisionGuarantee {
self.sku.precision_guarantee
}
/// Size in bytes of the caller-provided workspace; always 0 (`run` ignores its
/// `Workspace` argument).
#[inline]
pub fn workspace_size(&self) -> usize {
0
}
pub fn run(
&self,
stream: &Stream,
_workspace: Workspace<'_>,
args: FusedLinearCrossEntropyBackwardArgs<'_, T>,
) -> Result<()> {
let dy_scalar_f32 = args.dy_scalar;
if dy_scalar_f32 == 1.0 {
return Ok(());
}
if let Some(ref gi) = args.grad_input {
let numel = gi.numel();
let status = unsafe {
match T::KIND {
ElementKind::F32 => baracuda_kernels_sys::baracuda_kernels_loss_flce_inplace_scale_f32_run(
numel, dy_scalar_f32, gi.data.as_raw().0 as *mut c_void,
stream.as_raw() as *mut c_void),
ElementKind::F16 => baracuda_kernels_sys::baracuda_kernels_loss_flce_inplace_scale_f16_run(
numel, dy_scalar_f32, gi.data.as_raw().0 as *mut c_void,
stream.as_raw() as *mut c_void),
ElementKind::Bf16 => baracuda_kernels_sys::baracuda_kernels_loss_flce_inplace_scale_bf16_run(
numel, dy_scalar_f32, gi.data.as_raw().0 as *mut c_void,
stream.as_raw() as *mut c_void),
ElementKind::F64 => baracuda_kernels_sys::baracuda_kernels_loss_flce_inplace_scale_f64_run(
numel, dy_scalar_f32, gi.data.as_raw().0 as *mut c_void,
stream.as_raw() as *mut c_void),
_ => return Err(Error::Unsupported("unwired dtype")),
}
};
if status != 0 {
return Err(Error::CutlassInternal(status));
}
}
if let Some(ref gw) = args.grad_weight {
let numel = gw.numel();
let status = unsafe {
match T::KIND {
ElementKind::F32 => baracuda_kernels_sys::baracuda_kernels_loss_flce_inplace_scale_f32_run(
numel, dy_scalar_f32, gw.data.as_raw().0 as *mut c_void,
stream.as_raw() as *mut c_void),
ElementKind::F16 => baracuda_kernels_sys::baracuda_kernels_loss_flce_inplace_scale_f16_run(
numel, dy_scalar_f32, gw.data.as_raw().0 as *mut c_void,
stream.as_raw() as *mut c_void),
ElementKind::Bf16 => baracuda_kernels_sys::baracuda_kernels_loss_flce_inplace_scale_bf16_run(
numel, dy_scalar_f32, gw.data.as_raw().0 as *mut c_void,
stream.as_raw() as *mut c_void),
ElementKind::F64 => baracuda_kernels_sys::baracuda_kernels_loss_flce_inplace_scale_f64_run(
numel, dy_scalar_f32, gw.data.as_raw().0 as *mut c_void,
stream.as_raw() as *mut c_void),
_ => return Err(Error::Unsupported("unwired dtype")),
}
};
if status != 0 {
return Err(Error::CutlassInternal(status));
}
}
Ok(())
}
}
/// Unit tests pinning the chunk-size heuristic for representative problem shapes.
#[cfg(test)]
mod chunk_size_tests {
    use super::*;

    /// bt = 16384, h = 4096, v = 32000: the heuristic hits the 2048-row cap.
    #[test]
    fn llama3_class_picks_2048() {
        assert_eq!(pick_chunk_size(4096 * 4, 4096, 32000), 2048);
    }

    /// When `v <= h` the chunk covers the whole (power-of-two) batch.
    #[test]
    fn small_problem_caps_at_bt() {
        assert_eq!(pick_chunk_size(128, 4096, 1000), 128);
    }

    /// An empty batch still yields a chunk size of 1.
    #[test]
    fn empty_bt() {
        assert_eq!(pick_chunk_size(0, 128, 256), 1);
    }

    /// bt = 16384, h = 4096, v = 128 * 1024: 16384 / 32 = 512 rows per chunk.
    #[test]
    fn vocab_128k_llama3() {
        assert_eq!(pick_chunk_size(16384, 4096, 128 * 1024), 512);
    }
}