use cubecl::prelude::*;
use crate::{
CubeRuntime, kernel::into_contiguous, ops::numeric::empty_device_dtype, tensor::CubeTensor,
};
use burn_backend::{Shape, TensorMetadata};
const SHARED_ALPHA_CAPACITY: u32 = 8192;
/// Maps a position `s` in the extended label sequence
/// l' = [blank, l1, blank, l2, ..., blank] back to a class index.
///
/// Even positions are the blank class; odd position `s` holds the
/// `((s - 1) / 2)`-th target label, read from row `n` of `targets` via the
/// `tgt_n` (batch) and `tgt_s` (position) strides.
#[cube]
fn l_prime_class<I: Numeric>(
    s: usize,
    targets: &Tensor<I>,
    n: usize,
    tgt_n: usize,
    tgt_s: usize,
    blank: usize,
) -> usize {
    if s % 2 == 1 {
        // Labels are stored as an integer type; round-trip through u32
        // before widening to usize for indexing.
        u32::cast_from(targets[n * tgt_n + ((s - 1) / 2) * tgt_s]) as usize
    } else {
        blank
    }
}
/// Numerically stable two-operand log-sum-exp:
/// `max(a, b) + ln(1 + exp(min - max))`.
///
/// If even the larger operand is below `unreachable_threshold`, both values
/// encode "log 0" (an unreachable path), so the maximum is returned
/// unchanged — presumably to keep the finite neg-inf sentinel from creeping
/// upward by ln(2) on every step.
#[cube]
fn log_sum_exp2<F: Float>(a: F, b: F, unreachable_threshold: F, one: F) -> F {
    let mut mx = a;
    let mut mn = b;
    if b > a {
        mx = b;
        mn = a;
    }
    if mx < unreachable_threshold {
        mx
    } else {
        // mn - mx <= 0, so the exp cannot overflow.
        mx + (one + (mn - mx).exp()).ln()
    }
}
/// One CTC dynamic-programming update for a single l' position:
/// `log_p + logsumexp(near, near_m1[, near_m2 when skip_allowed])`.
///
/// Call sites pass the previous timestep's scores at s, s-1, s-2 (forward
/// pass) or s, s+1, s+2 (backward pass); `skip_allowed` gates the
/// two-position skip transition between distinct non-blank labels.
#[cube]
fn recurrence_step<F: Float>(
    near: F,
    near_m1: F,
    near_m2: F,
    log_p: F,
    skip_allowed: bool,
    unreachable_threshold: F,
    one: F,
) -> F {
    // Fold the two (or three) incoming paths in log space, then add the
    // emission log-probability for this position's class.
    let lse_01 = log_sum_exp2::<F>(near, near_m1, unreachable_threshold, one);
    let combined = if skip_allowed {
        log_sum_exp2::<F>(lse_01, near_m2, unreachable_threshold, one)
    } else {
        lse_01
    };
    log_p + combined
}
/// Combines the two terminal alpha values (final blank at s = 2*target_len
/// and final label at s = 2*target_len - 1) into the negative
/// log-likelihood: `-logsumexp(last_blank, last_label)`.
///
/// When both terminals sit below `unreachable_threshold` no valid alignment
/// exists and `exp(1000 * target_len)` is returned as the penalty: for
/// `target_len >= 1` this saturates to +inf in f32.
/// NOTE(review): for `target_len == 0` that penalty evaluates to exp(0) = 1,
/// not +inf — confirm this branch is intended (or unreachable) for empty
/// targets.
#[cube]
fn finalize_nll<F: Float>(
    last_blank: F,
    last_label: F,
    target_len: usize,
    unreachable_threshold: F,
    one: F,
) -> F {
    // Inlined two-operand log-sum-exp (same scheme as `log_sum_exp2`).
    let mut mx = last_blank;
    let mut mn = last_label;
    if last_label > last_blank {
        mx = last_label;
        mn = last_blank;
    }
    if mx < unreachable_threshold {
        (F::new(1000.0_f32) * F::cast_from(target_len as u32)).exp()
    } else {
        // Negate the log-likelihood to obtain the NLL.
        F::new(0.0) - (mx + (one + (mn - mx).exp()).ln())
    }
}
/// Loss for a sample whose input length is zero: 0 when the target is also
/// empty (nothing to align), otherwise the same `exp(1000 * target_len)`
/// penalty used by `finalize_nll` for impossible alignments (saturates to
/// +inf in f32 for `target_len >= 1`).
#[cube]
fn empty_input_nll<F: Float>(target_len: usize) -> F {
    if target_len == 0 {
        F::new(0.0)
    } else {
        (F::new(1000.0_f32) * F::cast_from(target_len as u32)).exp()
    }
}
/// Forward CTC loss kernel: one cube per batch element `n` (CUBE_POS_X),
/// with the cube's units striding cooperatively over positions `s` of the
/// extended label sequence l' = [blank, l1, blank, ..., blank] of length
/// `2 * target_len + 1`.
///
/// Tensor layouts, as addressed through `stride(...)` below: `log_probs` is
/// `[time, batch, class]`, `targets` is `[batch, position]`. The per-sample
/// negative log-likelihood is written to `output[n]`.
///
/// Shared memory is double-buffered: entries `[0, alpha_cap)` hold the
/// current timestep's alpha, `[alpha_cap, 2 * alpha_cap)` the next one.
#[cube(launch)]
fn ctc_loss_kernel<F: Float, I: Numeric>(
    log_probs: &Tensor<F>,
    targets: &Tensor<I>,
    input_lengths: &Tensor<I>,
    target_lengths: &Tensor<I>,
    output: &mut Tensor<F>,
    blank: u32,
    #[comptime] alpha_capacity: u32,
    #[define(F, I)] _dtypes: [StorageType; 2],
) {
    let n = CUBE_POS_X as usize; // batch element owned by this cube
    let cube_dim = CUBE_DIM_X as usize; // stride of the unit-level `s` loops
    let alpha_cap = alpha_capacity as usize;
    let blank_u = blank as usize;
    let target_len = u32::cast_from(target_lengths[n]) as usize;
    let input_len = u32::cast_from(input_lengths[n]) as usize;
    let l_prime_len = 2 * target_len + 1;
    // Degenerate case: no timesteps. Unit 0 writes the result and the whole
    // cube exits before touching shared memory.
    if input_len == 0 {
        if UNIT_POS_X == 0 {
            output[n] = empty_input_nll::<F>(target_len);
        }
        terminate!();
    }
    let lp_t = log_probs.stride(0);
    let lp_n = log_probs.stride(1);
    let lp_c = log_probs.stride(2);
    let tgt_n = targets.stride(0);
    let tgt_s = targets.stride(1);
    let mut alpha = SharedMemory::<F>::new(2 * alpha_cap);
    // Finite stand-in for log(0); kept well below `unreachable_threshold`.
    let neg_inf = F::new(-6.0e4_f32);
    let unreachable_threshold = F::new(-1.0e4_f32);
    let one = F::new(1.0);
    // t = 0 initialisation: only s = 0 (blank) and s = 1 (first label, when
    // it exists) are reachable; all other positions start at "log 0".
    let mut s = UNIT_POS_X as usize;
    while s < l_prime_len {
        let mut init = neg_inf;
        if s == 0 {
            init = log_probs[n * lp_n + blank_u * lp_c];
        } else if s == 1 {
            let l1 = u32::cast_from(targets[n * tgt_n]) as usize;
            init = log_probs[n * lp_n + l1 * lp_c];
        }
        alpha[s] = init;
        s += cube_dim;
    }
    sync_cube();
    // Forward recursion over the remaining timesteps.
    for t in 1..input_len {
        let mut s = UNIT_POS_X as usize;
        while s < l_prime_len {
            let l_class = l_prime_class::<I>(s, targets, n, tgt_n, tgt_s, blank_u);
            let log_p = log_probs[t * lp_t + n * lp_n + l_class * lp_c];
            // Standard CTC skip rule: the s-2 transition is only allowed
            // between distinct non-blank labels.
            let l_class_m2 = if s >= 2 {
                l_prime_class::<I>(s - 2, targets, n, tgt_n, tgt_s, blank_u)
            } else {
                blank_u
            };
            let skip_allowed = s >= 2 && l_class != blank_u && l_class != l_class_m2;
            let a_s = alpha[s];
            let mut a_s_m1 = neg_inf;
            if s >= 1 {
                a_s_m1 = alpha[s - 1];
            }
            let mut a_s_m2 = neg_inf;
            if s >= 2 {
                a_s_m2 = alpha[s - 2];
            }
            // Write into the second bank so reads of the first bank stay
            // consistent until the sync below.
            alpha[alpha_cap + s] = recurrence_step::<F>(
                a_s,
                a_s_m1,
                a_s_m2,
                log_p,
                skip_allowed,
                unreachable_threshold,
                one,
            );
            s += cube_dim;
        }
        sync_cube();
        // Copy the freshly computed timestep back into the first bank.
        let mut s = UNIT_POS_X as usize;
        while s < l_prime_len {
            alpha[s] = alpha[alpha_cap + s];
            s += cube_dim;
        }
        sync_cube();
    }
    // A valid path must end in the final blank (s = 2*target_len) or, when a
    // label exists, the final label (s = 2*target_len - 1).
    if UNIT_POS_X == 0 {
        let last_blank = alpha[2 * target_len];
        let mut last_label = neg_inf;
        if target_len > 0 {
            last_label = alpha[2 * target_len - 1];
        }
        output[n] = finalize_nll::<F>(
            last_blank,
            last_label,
            target_len,
            unreachable_threshold,
            one,
        );
    }
}
/// Launches the forward CTC loss kernel and returns the per-sample negative
/// log-likelihoods as a `[batch_size]` tensor.
///
/// `log_probs` is rank-3 (destructured below as `[time, batch, classes]`),
/// `targets` is rank-2 `[batch, max_target_len]`, and `blank` is the class
/// index of the CTC blank symbol.
///
/// # Panics
/// When `2 * max_target_len + 1` exceeds `SHARED_ALPHA_CAPACITY`, the
/// kernel's compile-time shared-memory budget.
pub fn ctc_loss<R: CubeRuntime>(
    log_probs: CubeTensor<R>,
    targets: CubeTensor<R>,
    input_lengths: CubeTensor<R>,
    target_lengths: CubeTensor<R>,
    blank: usize,
) -> CubeTensor<R> {
    // The kernel addresses tensors through raw strides and assumes dense
    // layouts, so force contiguity up front.
    let log_probs = into_contiguous(log_probs);
    let targets = into_contiguous(targets);
    let input_lengths = into_contiguous(input_lengths);
    let target_lengths = into_contiguous(target_lengths);

    let probs_shape = log_probs.shape();
    let [_t, batch_size, _c] = probs_shape.dims::<3>();
    let tgt_shape = targets.shape();
    let max_target_len = tgt_shape.dims::<2>()[1];
    let max_l_prime = 2 * max_target_len + 1;
    assert!(
        max_l_prime as u32 <= SHARED_ALPHA_CAPACITY,
        "ctc_loss: 2 * max_target_len + 1 = {} exceeds the kernel's shared-memory \
         alpha capacity ({}). Reduce target length or raise SHARED_ALPHA_CAPACITY.",
        max_l_prime,
        SHARED_ALPHA_CAPACITY,
    );

    let client = log_probs.client.clone();
    let device = log_probs.device.clone();
    let f_dtype = log_probs.dtype;
    let i_dtype = targets.dtype;

    // Cap the cube width at the hardware limit and at 256 lanes; the kernel
    // loops with a stride, so fewer lanes than l' positions is fine.
    let hw_limit = client.properties().hardware.max_cube_dim.0;
    let lanes = u32::min(max_l_prime as u32, u32::min(hw_limit, 256));

    let output = empty_device_dtype::<R>(client.clone(), device, Shape::new([batch_size]), f_dtype);

    // One cube per batch element.
    ctc_loss_kernel::launch::<R>(
        &client,
        CubeCount::Static(batch_size as u32, 1, 1),
        CubeDim::new_1d(lanes),
        log_probs.into_tensor_arg(),
        targets.into_tensor_arg(),
        input_lengths.into_tensor_arg(),
        target_lengths.into_tensor_arg(),
        output.clone().into_tensor_arg(),
        blank as u32,
        max_l_prime as u32,
        [f_dtype.into(), i_dtype.into()],
    );
    output
}
/// Combined forward/backward CTC kernel: one cube per batch element.
///
/// In addition to the per-sample NLL (`nll_out[n]`), it records the full
/// alpha (forward) and beta (backward) lattices per timestep into
/// `alpha_out` / `beta_out`, both addressed `[time, batch, s]` via their
/// strides. Entries outside each sample's valid (t, s) range are left
/// untouched (the host pre-fills them).
///
/// Shared memory is double-buffered exactly as in `ctc_loss_kernel`; the
/// same buffer is reused first for alpha, then for beta.
#[cube(launch)]
fn ctc_alpha_beta_kernel<F: Float, I: Numeric>(
    log_probs: &Tensor<F>,
    targets: &Tensor<I>,
    input_lengths: &Tensor<I>,
    target_lengths: &Tensor<I>,
    alpha_out: &mut Tensor<F>,
    beta_out: &mut Tensor<F>,
    nll_out: &mut Tensor<F>,
    blank: u32,
    #[comptime] alpha_capacity: u32,
    #[define(F, I)] _dtypes: [StorageType; 2],
) {
    let n = CUBE_POS_X as usize; // batch element owned by this cube
    let cube_dim = CUBE_DIM_X as usize; // stride of the unit-level `s` loops
    let alpha_cap = alpha_capacity as usize;
    let blank_u = blank as usize;
    let target_len = u32::cast_from(target_lengths[n]) as usize;
    let input_len = u32::cast_from(input_lengths[n]) as usize;
    let l_prime_len = 2 * target_len + 1;
    // Degenerate case: no timesteps. Only the NLL is written; alpha/beta
    // keep their host-side fill values.
    if input_len == 0 {
        if UNIT_POS_X == 0 {
            nll_out[n] = empty_input_nll::<F>(target_len);
        }
        terminate!();
    }
    let lp_t = log_probs.stride(0);
    let lp_n = log_probs.stride(1);
    let lp_c = log_probs.stride(2);
    let tgt_n = targets.stride(0);
    let tgt_s = targets.stride(1);
    let ao_t = alpha_out.stride(0);
    let ao_n = alpha_out.stride(1);
    let ao_s = alpha_out.stride(2);
    let bo_t = beta_out.stride(0);
    let bo_n = beta_out.stride(1);
    let bo_s = beta_out.stride(2);
    let mut state = SharedMemory::<F>::new(2 * alpha_cap);
    // Finite stand-in for log(0); kept well below `unreachable_threshold`.
    let neg_inf = F::new(-6.0e4_f32);
    let unreachable_threshold = F::new(-1.0e4_f32);
    let one = F::new(1.0);
    // ---- Forward pass (alpha) ----
    // t = 0: only s = 0 (blank) and s = 1 (first label) are reachable.
    let mut s = UNIT_POS_X as usize;
    while s < l_prime_len {
        let mut init = neg_inf;
        if s == 0 {
            init = log_probs[n * lp_n + blank_u * lp_c];
        } else if s == 1 {
            let l1 = u32::cast_from(targets[n * tgt_n]) as usize;
            init = log_probs[n * lp_n + l1 * lp_c];
        }
        state[s] = init;
        alpha_out[n * ao_n + s * ao_s] = init;
        s += cube_dim;
    }
    sync_cube();
    for t in 1..input_len {
        let mut s = UNIT_POS_X as usize;
        while s < l_prime_len {
            let l_class = l_prime_class::<I>(s, targets, n, tgt_n, tgt_s, blank_u);
            let log_p = log_probs[t * lp_t + n * lp_n + l_class * lp_c];
            // Standard CTC skip rule: s-2 transition only between distinct
            // non-blank labels.
            let l_class_m2 = if s >= 2 {
                l_prime_class::<I>(s - 2, targets, n, tgt_n, tgt_s, blank_u)
            } else {
                blank_u
            };
            let skip_allowed = s >= 2 && l_class != blank_u && l_class != l_class_m2;
            let a_s = state[s];
            let mut a_s_m1 = neg_inf;
            if s >= 1 {
                a_s_m1 = state[s - 1];
            }
            let mut a_s_m2 = neg_inf;
            if s >= 2 {
                a_s_m2 = state[s - 2];
            }
            // Second bank holds the new timestep until the sync below.
            state[alpha_cap + s] = recurrence_step::<F>(
                a_s,
                a_s_m1,
                a_s_m2,
                log_p,
                skip_allowed,
                unreachable_threshold,
                one,
            );
            s += cube_dim;
        }
        sync_cube();
        // Copy back into the first bank and persist this timestep's alpha.
        let mut s = UNIT_POS_X as usize;
        while s < l_prime_len {
            state[s] = state[alpha_cap + s];
            alpha_out[t * ao_t + n * ao_n + s * ao_s] = state[s];
            s += cube_dim;
        }
        sync_cube();
    }
    // NLL from the two terminal alpha values (same as `ctc_loss_kernel`).
    if UNIT_POS_X == 0 {
        let last_blank = state[2 * target_len];
        let mut last_label = neg_inf;
        if target_len > 0 {
            last_label = state[2 * target_len - 1];
        }
        nll_out[n] = finalize_nll::<F>(
            last_blank,
            last_label,
            target_len,
            unreachable_threshold,
            one,
        );
    }
    // Barrier before `state` is reused for the backward pass.
    sync_cube();
    // ---- Backward pass (beta) ----
    // t = input_len - 1: only the two terminal positions of l' are
    // reachable; they are seeded with their emission log-probability.
    let t_last = input_len - 1;
    let mut s = UNIT_POS_X as usize;
    while s < l_prime_len {
        let is_last_blank = s == 2 * target_len;
        let is_last_label = target_len > 0 && s == 2 * target_len - 1;
        let mut init = neg_inf;
        if is_last_blank || is_last_label {
            let l_class = l_prime_class::<I>(s, targets, n, tgt_n, tgt_s, blank_u);
            init = log_probs[t_last * lp_t + n * lp_n + l_class * lp_c];
        }
        state[s] = init;
        beta_out[t_last * bo_t + n * bo_n + s * bo_s] = init;
        s += cube_dim;
    }
    sync_cube();
    // Mirror of the forward recursion, walking t from input_len - 2 down to
    // 0 with transitions s -> s, s+1, s+2.
    for t_rev in 1..input_len {
        let t = input_len - 1 - t_rev;
        let mut s = UNIT_POS_X as usize;
        while s < l_prime_len {
            let l_class = l_prime_class::<I>(s, targets, n, tgt_n, tgt_s, blank_u);
            let log_p = log_probs[t * lp_t + n * lp_n + l_class * lp_c];
            let l_class_p2 = if s + 2 < l_prime_len {
                l_prime_class::<I>(s + 2, targets, n, tgt_n, tgt_s, blank_u)
            } else {
                blank_u
            };
            let skip_allowed = s + 2 < l_prime_len && l_class != blank_u && l_class != l_class_p2;
            let b_s = state[s];
            let mut b_s_p1 = neg_inf;
            if s + 1 < l_prime_len {
                b_s_p1 = state[s + 1];
            }
            let mut b_s_p2 = neg_inf;
            if s + 2 < l_prime_len {
                b_s_p2 = state[s + 2];
            }
            state[alpha_cap + s] = recurrence_step::<F>(
                b_s,
                b_s_p1,
                b_s_p2,
                log_p,
                skip_allowed,
                unreachable_threshold,
                one,
            );
            s += cube_dim;
        }
        sync_cube();
        // Copy back and persist this timestep's beta.
        let mut s = UNIT_POS_X as usize;
        while s < l_prime_len {
            state[s] = state[alpha_cap + s];
            beta_out[t * bo_t + n * bo_n + s * bo_s] = state[s];
            s += cube_dim;
        }
        sync_cube();
    }
}
/// Launches the combined forward/backward CTC kernel and returns
/// `(alpha, beta, nll)`:
/// - `alpha`, `beta`: `[max_input_length, batch, 2 * max_target_len + 1]`
///   lattices, pre-filled with -inf so positions the kernel never writes
///   read as log(0);
/// - `nll`: `[batch]` per-sample negative log-likelihoods.
///
/// # Panics
/// When `2 * max_target_len + 1` exceeds `SHARED_ALPHA_CAPACITY`.
pub fn ctc_alpha_beta<R: CubeRuntime>(
    log_probs: CubeTensor<R>,
    targets: CubeTensor<R>,
    input_lengths: CubeTensor<R>,
    target_lengths: CubeTensor<R>,
    blank: usize,
) -> (CubeTensor<R>, CubeTensor<R>, CubeTensor<R>) {
    // The kernel addresses tensors through raw strides and assumes dense
    // layouts, so force contiguity up front.
    let log_probs = into_contiguous(log_probs);
    let targets = into_contiguous(targets);
    let input_lengths = into_contiguous(input_lengths);
    let target_lengths = into_contiguous(target_lengths);

    let probs_shape = log_probs.shape();
    let [max_input_length, batch_size, _c] = probs_shape.dims::<3>();
    let tgt_shape = targets.shape();
    let max_target_len = tgt_shape.dims::<2>()[1];
    let max_l_prime = 2 * max_target_len + 1;
    assert!(
        max_l_prime as u32 <= SHARED_ALPHA_CAPACITY,
        "ctc_loss_backward: 2 * max_target_len + 1 = {} exceeds the kernel's shared-memory \
         alpha capacity ({}). Reduce target length or raise SHARED_ALPHA_CAPACITY.",
        max_l_prime,
        SHARED_ALPHA_CAPACITY,
    );

    let client = log_probs.client.clone();
    let device = log_probs.device.clone();
    let f_dtype = log_probs.dtype;
    let i_dtype = targets.dtype;

    // Cap the cube width at the hardware limit and at 256 lanes; the kernel
    // loops with a stride, so fewer lanes than l' positions is fine.
    let hw_limit = client.properties().hardware.max_cube_dim.0;
    let lanes = u32::min(max_l_prime as u32, u32::min(hw_limit, 256));

    // Pre-fill the alpha/beta lattices with -inf so untouched entries read
    // as log(0).
    let lattice_shape = Shape::new([max_input_length, batch_size, max_l_prime]);
    let fill = InputScalar::new(f32::NEG_INFINITY, f_dtype);
    let alpha_out = crate::ops::numeric::full_device_dtype::<R>(
        client.clone(),
        lattice_shape.clone(),
        device.clone(),
        fill,
        f_dtype,
    );
    let beta_out = crate::ops::numeric::full_device_dtype::<R>(
        client.clone(),
        lattice_shape,
        device.clone(),
        fill,
        f_dtype,
    );
    let nll_out =
        empty_device_dtype::<R>(client.clone(), device, Shape::new([batch_size]), f_dtype);

    // One cube per batch element.
    ctc_alpha_beta_kernel::launch::<R>(
        &client,
        CubeCount::Static(batch_size as u32, 1, 1),
        CubeDim::new_1d(lanes),
        log_probs.into_tensor_arg(),
        targets.into_tensor_arg(),
        input_lengths.into_tensor_arg(),
        target_lengths.into_tensor_arg(),
        alpha_out.clone().into_tensor_arg(),
        beta_out.clone().into_tensor_arg(),
        nll_out.clone().into_tensor_arg(),
        blank as u32,
        max_l_prime as u32,
        [f_dtype.into(), i_dtype.into()],
    );
    (alpha_out, beta_out, nll_out)
}