use cubecl::prelude::*;

use crate::{
    CubeRuntime, kernel::into_contiguous, ops::numeric::empty_device_dtype, tensor::CubeTensor,
};
use burn_backend::{Shape, TensorMetadata};

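/// Upper bound on the extended-label length `2 * max_target_len + 1` that the
/// kernels below can process: each cube allocates two shared-memory buffers of
/// this many elements (the current and next row of the recurrence).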
const SHARED_ALPHA_CAPACITY: u32 = 8192;

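/// Class index at position `s` of the extended label sequence
/// `l' = [blank, l_1, blank, l_2, ..., l_S, blank]`: odd positions map to
/// target label `(s - 1) / 2`, even positions to `blank`.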
#[cube]
fn l_prime_class<I: Numeric>(
    s: usize,
    targets: &Tensor<I>,
    n: usize,
    tgt_n: usize,
    tgt_s: usize,
    blank: usize,
) -> usize {
    if s % 2 == 1 {
        u32::cast_from(targets[n * tgt_n + ((s - 1) / 2) * tgt_s]) as usize
    } else {
        blank
    }
}

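/// Numerically stable two-operand log-sum-exp, `ln(exp(a) + exp(b))`, computed
/// as `max + ln(1 + exp(min - max))`; for example, combining `ln(0.5)` and
/// `ln(0.25)` yields `ln(0.75)`. When even the larger operand sits below
/// `unreachable_threshold` it is returned unchanged, so unreachable paths stay
/// pinned near the sentinel instead of drifting upward step by step.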
#[cube]
fn log_sum_exp2<F: Float>(a: F, b: F, unreachable_threshold: F, one: F) -> F {
    let mut mx = a;
    let mut mn = b;
    if b > a {
        mx = b;
        mn = a;
    }
    if mx < unreachable_threshold {
        mx
    } else {
        mx + (one + (mn - mx).exp()).ln()
    }
}

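/// One step of the CTC recurrence in log space:
///
/// `alpha_t(s) = log_p + logsumexp(near, near_m1[, near_m2])`
///
/// where the `near*` values are the admissible predecessors of position `s`
/// and the third is included only when `skip_allowed` holds (`l'_s` is a
/// non-blank label that differs from `l'_{s-2}`). The backward pass reuses
/// this step with the successors `s`, `s + 1`, `s + 2` passed in the same
/// slots.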
#[cube]
fn recurrence_step<F: Float>(
    near: F,
    near_m1: F,
    near_m2: F,
    log_p: F,
    skip_allowed: bool,
    unreachable_threshold: F,
    one: F,
) -> F {
    let lse_01 = log_sum_exp2::<F>(near, near_m1, unreachable_threshold, one);
    let combined = if skip_allowed {
        log_sum_exp2::<F>(lse_01, near_m2, unreachable_threshold, one)
    } else {
        lse_01
    };
    log_p + combined
}

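/// Reduces the two terminal alpha values (paths ending on the final blank or
/// the final label) to the negative log-likelihood. If both are below the
/// unreachable threshold the target cannot be aligned to the input;
/// `exp(1000 * target_len)` then overflows to `+inf`, signalling an infinite
/// loss for the unalignable sample.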
#[cube]
fn finalize_nll<F: Float>(
    last_blank: F,
    last_label: F,
    target_len: usize,
    unreachable_threshold: F,
    one: F,
) -> F {
    let mut mx = last_blank;
    let mut mn = last_label;
    if last_label > last_blank {
        mx = last_label;
        mn = last_blank;
    }
    if mx < unreachable_threshold {
        // Unalignable target: exp(1000 * target_len) overflows to +inf.
        (F::new(1000.0_f32) * F::cast_from(target_len as u32)).exp()
    } else {
        F::new(0.0) - (mx + (one + (mn - mx).exp()).ln())
    }
}

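/// Loss for a sample with zero input length: an empty target is matched
/// trivially (loss 0), while a non-empty target is unalignable and receives
/// the same overflow-to-infinity penalty as `finalize_nll`.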
#[cube]
fn empty_input_nll<F: Float>(target_len: usize) -> F {
    if target_len == 0 {
        F::new(0.0)
    } else {
        (F::new(1000.0_f32) * F::cast_from(target_len as u32)).exp()
    }
}

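/// Forward-only CTC loss kernel. Each cube handles one batch element `n`: its
/// units stride over the `2 * target_len + 1` extended-label positions,
/// ping-ponging alpha rows through shared memory, and unit 0 writes the final
/// negative log-likelihood to `output[n]`.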
#[cube(launch)]
fn ctc_loss_kernel<F: Float, I: Numeric>(
    log_probs: &Tensor<F>,
    targets: &Tensor<I>,
    input_lengths: &Tensor<I>,
    target_lengths: &Tensor<I>,
    output: &mut Tensor<F>,
    blank: u32,
    #[comptime] alpha_capacity: u32,
    #[define(F, I)] _dtypes: [StorageType; 2],
) {
    let n = CUBE_POS_X as usize;
    let cube_dim = CUBE_DIM_X as usize;
    let alpha_cap = alpha_capacity as usize;
    let blank_u = blank as usize;

    let target_len = u32::cast_from(target_lengths[n]) as usize;
    let input_len = u32::cast_from(input_lengths[n]) as usize;
    let l_prime_len = 2 * target_len + 1;

    if input_len == 0 {
        if UNIT_POS_X == 0 {
            output[n] = empty_input_nll::<F>(target_len);
        }
        terminate!();
    }

    let lp_t = log_probs.stride(0);
    let lp_n = log_probs.stride(1);
    let lp_c = log_probs.stride(2);
    let tgt_n = targets.stride(0);
    let tgt_s = targets.stride(1);

    // Ping-pong buffers: [0, alpha_cap) holds alpha_{t-1},
    // [alpha_cap, 2 * alpha_cap) receives alpha_t.
    let mut alpha = SharedMemory::<F>::new(2 * alpha_cap);
    // Finite sentinels standing in for log(0) (kept within f16 range).
    let neg_inf = F::new(-6.0e4_f32);
    let unreachable_threshold = F::new(-1.0e4_f32);
    let one = F::new(1.0);

    // t = 0: only the leading blank and the first label are reachable.
    let mut s = UNIT_POS_X as usize;
    while s < l_prime_len {
        let mut init = neg_inf;
        if s == 0 {
            init = log_probs[n * lp_n + blank_u * lp_c];
        } else if s == 1 {
            let l1 = u32::cast_from(targets[n * tgt_n]) as usize;
            init = log_probs[n * lp_n + l1 * lp_c];
        }
        alpha[s] = init;
        s += cube_dim;
    }
    sync_cube();

    for t in 1..input_len {
        let mut s = UNIT_POS_X as usize;
        while s < l_prime_len {
            let l_class = l_prime_class::<I>(s, targets, n, tgt_n, tgt_s, blank_u);
            let log_p = log_probs[t * lp_t + n * lp_n + l_class * lp_c];

            let l_class_m2 = if s >= 2 {
                l_prime_class::<I>(s - 2, targets, n, tgt_n, tgt_s, blank_u)
            } else {
                blank_u
            };
            let skip_allowed = s >= 2 && l_class != blank_u && l_class != l_class_m2;

            let a_s = alpha[s];
            let mut a_s_m1 = neg_inf;
            if s >= 1 {
                a_s_m1 = alpha[s - 1];
            }
            let mut a_s_m2 = neg_inf;
            if s >= 2 {
                a_s_m2 = alpha[s - 2];
            }

            alpha[alpha_cap + s] = recurrence_step::<F>(
                a_s,
                a_s_m1,
                a_s_m2,
                log_p,
                skip_allowed,
                unreachable_threshold,
                one,
            );
            s += cube_dim;
        }
        sync_cube();

        // Publish alpha_t as the new alpha_{t-1}.
        let mut s = UNIT_POS_X as usize;
        while s < l_prime_len {
            alpha[s] = alpha[alpha_cap + s];
            s += cube_dim;
        }
        sync_cube();
    }

    if UNIT_POS_X == 0 {
        let last_blank = alpha[2 * target_len];
        let mut last_label = neg_inf;
        if target_len > 0 {
            last_label = alpha[2 * target_len - 1];
        }
        output[n] = finalize_nll::<F>(
            last_blank,
            last_label,
            target_len,
            unreachable_threshold,
            one,
        );
    }
}

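/// Computes per-sample CTC negative log-likelihoods on the device.
///
/// `log_probs` has shape `[T, N, C]` and is expected to hold log-probabilities
/// over the `C` classes; `targets` is `[N, S]`, and `input_lengths` /
/// `target_lengths` give the valid extent of each sample. Returns a
/// `[batch_size]` tensor of losses. One cube is launched per batch element,
/// with at most 256 units striding over the `2 * S + 1` label positions.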
pub fn ctc_loss<R: CubeRuntime>(
    log_probs: CubeTensor<R>,
    targets: CubeTensor<R>,
    input_lengths: CubeTensor<R>,
    target_lengths: CubeTensor<R>,
    blank: usize,
) -> CubeTensor<R> {
    let log_probs = into_contiguous(log_probs);
    let targets = into_contiguous(targets);
    let input_lengths = into_contiguous(input_lengths);
    let target_lengths = into_contiguous(target_lengths);

    let log_probs_shape = log_probs.shape();
    let [_t, batch_size, _c] = log_probs_shape.dims::<3>();
    let target_shape = targets.shape();
    let max_target_len = target_shape.dims::<2>()[1];
    let max_l_prime = 2 * max_target_len + 1;

    assert!(
        max_l_prime as u32 <= SHARED_ALPHA_CAPACITY,
        "ctc_loss: 2 * max_target_len + 1 = {} exceeds the kernel's shared-memory \
         alpha capacity ({}). Reduce target length or raise SHARED_ALPHA_CAPACITY.",
        max_l_prime,
        SHARED_ALPHA_CAPACITY,
    );

    let hw_max = log_probs.client.properties().hardware.max_cube_dim.0;
    let cube_dim_x = (max_l_prime as u32).min(hw_max).min(256);

    let client = log_probs.client.clone();
    let device = log_probs.device.clone();
    let f_dtype = log_probs.dtype;
    let i_dtype = targets.dtype;
    let output = empty_device_dtype::<R>(client.clone(), device, Shape::new([batch_size]), f_dtype);

    let cube_count = CubeCount::Static(batch_size as u32, 1, 1);
    let cube_dim = CubeDim::new_1d(cube_dim_x);

    ctc_loss_kernel::launch::<R>(
        &client,
        cube_count,
        cube_dim,
        log_probs.into_tensor_arg(),
        targets.into_tensor_arg(),
        input_lengths.into_tensor_arg(),
        target_lengths.into_tensor_arg(),
        output.clone().into_tensor_arg(),
        blank as u32,
        max_l_prime as u32,
        [f_dtype.into(), i_dtype.into()],
    );

    output
}

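/// Combined forward/backward CTC kernel for gradient computation. Structured
/// like `ctc_loss_kernel`, but every alpha row is persisted to `alpha_out` as
/// it is produced; the shared-memory buffers are then reused for the backward
/// recurrence, filling `beta_out` from `t = input_len - 1` down to `t = 0`,
/// while `nll_out` receives the same loss as the forward-only kernel.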
#[cube(launch)]
fn ctc_alpha_beta_kernel<F: Float, I: Numeric>(
    log_probs: &Tensor<F>,
    targets: &Tensor<I>,
    input_lengths: &Tensor<I>,
    target_lengths: &Tensor<I>,
    alpha_out: &mut Tensor<F>,
    beta_out: &mut Tensor<F>,
    nll_out: &mut Tensor<F>,
    blank: u32,
    #[comptime] alpha_capacity: u32,
    #[define(F, I)] _dtypes: [StorageType; 2],
) {
    let n = CUBE_POS_X as usize;
    let cube_dim = CUBE_DIM_X as usize;
    let alpha_cap = alpha_capacity as usize;
    let blank_u = blank as usize;

    let target_len = u32::cast_from(target_lengths[n]) as usize;
    let input_len = u32::cast_from(input_lengths[n]) as usize;
    let l_prime_len = 2 * target_len + 1;

    if input_len == 0 {
        if UNIT_POS_X == 0 {
            nll_out[n] = empty_input_nll::<F>(target_len);
        }
        terminate!();
    }

    let lp_t = log_probs.stride(0);
    let lp_n = log_probs.stride(1);
    let lp_c = log_probs.stride(2);
    let tgt_n = targets.stride(0);
    let tgt_s = targets.stride(1);
    let ao_t = alpha_out.stride(0);
    let ao_n = alpha_out.stride(1);
    let ao_s = alpha_out.stride(2);
    let bo_t = beta_out.stride(0);
    let bo_n = beta_out.stride(1);
    let bo_s = beta_out.stride(2);

    // Ping-pong buffers, used first for the alpha rows and then reused for beta.
    let mut state = SharedMemory::<F>::new(2 * alpha_cap);
    // Finite sentinels standing in for log(0) (kept within f16 range).
    let neg_inf = F::new(-6.0e4_f32);
    let unreachable_threshold = F::new(-1.0e4_f32);
    let one = F::new(1.0);

    // Forward pass, t = 0: only the leading blank and the first label are reachable.
    let mut s = UNIT_POS_X as usize;
    while s < l_prime_len {
        let mut init = neg_inf;
        if s == 0 {
            init = log_probs[n * lp_n + blank_u * lp_c];
        } else if s == 1 {
            let l1 = u32::cast_from(targets[n * tgt_n]) as usize;
            init = log_probs[n * lp_n + l1 * lp_c];
        }
        state[s] = init;
        alpha_out[n * ao_n + s * ao_s] = init;
        s += cube_dim;
    }
    sync_cube();

    for t in 1..input_len {
        let mut s = UNIT_POS_X as usize;
        while s < l_prime_len {
            let l_class = l_prime_class::<I>(s, targets, n, tgt_n, tgt_s, blank_u);
            let log_p = log_probs[t * lp_t + n * lp_n + l_class * lp_c];

            let l_class_m2 = if s >= 2 {
                l_prime_class::<I>(s - 2, targets, n, tgt_n, tgt_s, blank_u)
            } else {
                blank_u
            };
            let skip_allowed = s >= 2 && l_class != blank_u && l_class != l_class_m2;

            let a_s = state[s];
            let mut a_s_m1 = neg_inf;
            if s >= 1 {
                a_s_m1 = state[s - 1];
            }
            let mut a_s_m2 = neg_inf;
            if s >= 2 {
                a_s_m2 = state[s - 2];
            }

            state[alpha_cap + s] = recurrence_step::<F>(
                a_s,
                a_s_m1,
                a_s_m2,
                log_p,
                skip_allowed,
                unreachable_threshold,
                one,
            );
            s += cube_dim;
        }
        sync_cube();

        // Publish alpha_t as the new alpha_{t-1} and persist the row.
        let mut s = UNIT_POS_X as usize;
        while s < l_prime_len {
            state[s] = state[alpha_cap + s];
            alpha_out[t * ao_t + n * ao_n + s * ao_s] = state[s];
            s += cube_dim;
        }
        sync_cube();
    }

    if UNIT_POS_X == 0 {
        let last_blank = state[2 * target_len];
        let mut last_label = neg_inf;
        if target_len > 0 {
            last_label = state[2 * target_len - 1];
        }
        nll_out[n] = finalize_nll::<F>(
            last_blank,
            last_label,
            target_len,
            unreachable_threshold,
            one,
        );
    }

    sync_cube();

    // Backward pass, t = input_len - 1: only paths ending on the final blank
    // or the final label are valid.
    let t_last = input_len - 1;
    let mut s = UNIT_POS_X as usize;
    while s < l_prime_len {
        let is_last_blank = s == 2 * target_len;
        let is_last_label = target_len > 0 && s == 2 * target_len - 1;
        let mut init = neg_inf;
        if is_last_blank || is_last_label {
            let l_class = l_prime_class::<I>(s, targets, n, tgt_n, tgt_s, blank_u);
            init = log_probs[t_last * lp_t + n * lp_n + l_class * lp_c];
        }
        state[s] = init;
        beta_out[t_last * bo_t + n * bo_n + s * bo_s] = init;
        s += cube_dim;
    }
    sync_cube();

    for t_rev in 1..input_len {
        let t = input_len - 1 - t_rev;

        let mut s = UNIT_POS_X as usize;
        while s < l_prime_len {
            let l_class = l_prime_class::<I>(s, targets, n, tgt_n, tgt_s, blank_u);
            let log_p = log_probs[t * lp_t + n * lp_n + l_class * lp_c];

            let l_class_p2 = if s + 2 < l_prime_len {
                l_prime_class::<I>(s + 2, targets, n, tgt_n, tgt_s, blank_u)
            } else {
                blank_u
            };
            let skip_allowed = s + 2 < l_prime_len && l_class != blank_u && l_class != l_class_p2;

            let b_s = state[s];
            let mut b_s_p1 = neg_inf;
            if s + 1 < l_prime_len {
                b_s_p1 = state[s + 1];
            }
            let mut b_s_p2 = neg_inf;
            if s + 2 < l_prime_len {
                b_s_p2 = state[s + 2];
            }

            state[alpha_cap + s] = recurrence_step::<F>(
                b_s,
                b_s_p1,
                b_s_p2,
                log_p,
                skip_allowed,
                unreachable_threshold,
                one,
            );
            s += cube_dim;
        }
        sync_cube();

        // Publish beta_t as the new beta_{t+1} and persist the row.
        let mut s = UNIT_POS_X as usize;
        while s < l_prime_len {
            state[s] = state[alpha_cap + s];
            beta_out[t * bo_t + n * bo_n + s * bo_s] = state[s];
            s += cube_dim;
        }
        sync_cube();
    }
}

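/// Launches `ctc_alpha_beta_kernel` and returns `(alpha, beta, nll)`, where
/// `alpha` and `beta` have shape `[T, N, 2 * S + 1]` and `nll` has shape
/// `[N]`. Both lattices are pre-filled with `-inf`, presumably so entries
/// outside a sample's valid range stay inert when the caller assembles
/// gradients.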
pub fn ctc_alpha_beta<R: CubeRuntime>(
    log_probs: CubeTensor<R>,
    targets: CubeTensor<R>,
    input_lengths: CubeTensor<R>,
    target_lengths: CubeTensor<R>,
    blank: usize,
) -> (CubeTensor<R>, CubeTensor<R>, CubeTensor<R>) {
    let log_probs = into_contiguous(log_probs);
    let targets = into_contiguous(targets);
    let input_lengths = into_contiguous(input_lengths);
    let target_lengths = into_contiguous(target_lengths);

    let log_probs_shape = log_probs.shape();
    let [max_input_length, batch_size, _c] = log_probs_shape.dims::<3>();
    let target_shape = targets.shape();
    let max_target_len = target_shape.dims::<2>()[1];
    let max_l_prime = 2 * max_target_len + 1;

    assert!(
        max_l_prime as u32 <= SHARED_ALPHA_CAPACITY,
        "ctc_alpha_beta: 2 * max_target_len + 1 = {} exceeds the kernel's shared-memory \
         alpha capacity ({}). Reduce target length or raise SHARED_ALPHA_CAPACITY.",
        max_l_prime,
        SHARED_ALPHA_CAPACITY,
    );

    let hw_max = log_probs.client.properties().hardware.max_cube_dim.0;
    let cube_dim_x = (max_l_prime as u32).min(hw_max).min(256);

    let client = log_probs.client.clone();
    let device = log_probs.device.clone();
    let f_dtype = log_probs.dtype;
    let i_dtype = targets.dtype;

    let shape_abt = Shape::new([max_input_length, batch_size, max_l_prime]);
    let neg_inf = InputScalar::new(f32::NEG_INFINITY, f_dtype);
    let alpha_out = crate::ops::numeric::full_device_dtype::<R>(
        client.clone(),
        shape_abt.clone(),
        device.clone(),
        neg_inf,
        f_dtype,
    );
    let beta_out = crate::ops::numeric::full_device_dtype::<R>(
        client.clone(),
        shape_abt,
        device.clone(),
        neg_inf,
        f_dtype,
    );
    let nll_out =
        empty_device_dtype::<R>(client.clone(), device, Shape::new([batch_size]), f_dtype);

    let cube_count = CubeCount::Static(batch_size as u32, 1, 1);
    let cube_dim = CubeDim::new_1d(cube_dim_x);

    ctc_alpha_beta_kernel::launch::<R>(
        &client,
        cube_count,
        cube_dim,
        log_probs.into_tensor_arg(),
        targets.into_tensor_arg(),
        input_lengths.into_tensor_arg(),
        target_lengths.into_tensor_arg(),
        alpha_out.clone().into_tensor_arg(),
        beta_out.clone().into_tensor_arg(),
        nll_out.clone().into_tensor_arg(),
        blank as u32,
        max_l_prime as u32,
        [f_dtype.into(), i_dtype.into()],
    );

    (alpha_out, beta_out, nll_out)
}
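
// What follows is a minimal CPU reference sketch of the same forward recursion,
// handy for sanity-checking the kernels on a new runtime. The `reference`
// module, `ctc_nll_cpu`, and its test are illustrative additions (not used by
// the kernels above); they assume `log_probs[t][c]` already holds
// log-probabilities, and they use true `-inf` / `+inf` rather than the finite
// device sentinels, so agreement with the kernel holds only up to those
// sentinels.
#[cfg(test)]
mod reference {
    /// `ln(exp(a) + exp(b))` via the usual max trick.
    fn log_sum_exp2(a: f64, b: f64) -> f64 {
        let (mx, mn) = if a >= b { (a, b) } else { (b, a) };
        if mx == f64::NEG_INFINITY {
            mx
        } else {
            mx + (mn - mx).exp().ln_1p()
        }
    }

    /// CTC negative log-likelihood for one sample over the extended label
    /// sequence `[blank, l_1, blank, ..., l_S, blank]`.
    pub fn ctc_nll_cpu(log_probs: &[Vec<f64>], targets: &[usize], blank: usize) -> f64 {
        let t_len = log_probs.len();
        let s_len = 2 * targets.len() + 1;
        if t_len == 0 {
            // Mirrors empty_input_nll: an empty target matches, anything else cannot.
            return if targets.is_empty() { 0.0 } else { f64::INFINITY };
        }
        let class = |s: usize| if s % 2 == 1 { targets[(s - 1) / 2] } else { blank };
        // t = 0: only the leading blank and the first label are reachable.
        let mut alpha = vec![f64::NEG_INFINITY; s_len];
        alpha[0] = log_probs[0][blank];
        if s_len > 1 {
            alpha[1] = log_probs[0][class(1)];
        }
        for t in 1..t_len {
            let prev = alpha.clone();
            for s in 0..s_len {
                let c = class(s);
                let mut acc = prev[s];
                if s >= 1 {
                    acc = log_sum_exp2(acc, prev[s - 1]);
                }
                // Label-to-label skip, forbidden across blanks and repeats.
                if s >= 2 && c != blank && c != class(s - 2) {
                    acc = log_sum_exp2(acc, prev[s - 2]);
                }
                alpha[s] = log_probs[t][c] + acc;
            }
        }
        let tail = if s_len > 1 {
            log_sum_exp2(alpha[s_len - 1], alpha[s_len - 2])
        } else {
            alpha[0]
        };
        -tail
    }

    #[test]
    fn single_step_single_label() {
        // One timestep, one label: the only valid path emits that label once,
        // so the loss is -ln p(label).
        let log_probs = vec![vec![(0.25f64).ln(), (0.75f64).ln()]];
        let nll = ctc_nll_cpu(&log_probs, &[1], 0);
        assert!((nll + (0.75f64).ln()).abs() < 1e-12);
    }
}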