burn_cubecl/kernel/ctc.rs

use cubecl::prelude::*;

use crate::{
    CubeRuntime, kernel::into_contiguous, ops::numeric::empty_device_dtype, tensor::CubeTensor,
};
use burn_backend::{Shape, TensorMetadata};

/// Maximum `2 * max_target_len + 1` the kernel supports. The alpha/beta state is
/// held in shared memory as two f32 buffers of this size (active row + scratch),
/// so peak shared use at full capacity is `2 * 8192 * 4 = 64 KB`. Apple Metal
/// caps shared memory at 32 KB per block, so the launch site sizes the buffer to
/// the actual per-batch `max_l_prime`; this constant is only the kernel-side
/// upper bound. Inputs exceeding it panic rather than silently degrade.
const SHARED_ALPHA_CAPACITY: u32 = 8192;

/// Class label at position `s` of the blank-inserted label sequence `l'`.
/// Odd `s` reads the underlying target at index `(s-1)/2`; even `s` is a blank.
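/// For example, with targets `[a, b]` the blank-inserted sequence is
/// `l' = [blank, a, blank, b, blank]`, so `s = 3` maps to target index
/// `(3 - 1) / 2 = 1`, i.e. `b`.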
#[cube]
fn l_prime_class<I: Numeric>(
    s: usize,
    targets: &Tensor<I>,
    n: usize,
    tgt_n: usize,
    tgt_s: usize,
    blank: usize,
) -> usize {
    if s % 2 == 1 {
        u32::cast_from(targets[n * tgt_n + ((s - 1) / 2) * tgt_s]) as usize
    } else {
        blank
    }
}

/// Numerically stable `log(exp(a) + exp(b))` with a sentinel short-circuit.
/// When `max(a, b) < unreachable_threshold`, returns `max(a, b)` directly so
/// the sentinel value doesn't drift upward each recursion step when both
/// inputs sit at the `-6e4` floor.
///
/// The threshold's magnitude is forced by f16: the sentinel can't go below
/// `-65504` (f16 max magnitude), so it's `-6e4`, and the threshold has to sit
/// above the sentinel but below any plausible legit alpha value, leaving a
/// narrow band around `-1e4`. On sufficiently long sequences where legit
/// alpha values naturally drop below `-1e4` (roughly `T * log(1/C) < -1e4`),
/// reachable states get misclassified as unreachable. Mitigation is a
/// WGSL-only path with a smaller sentinel; WGSL spec 8.7 lets implementations
/// replace runtime `1/0` with zero, so `-inf` can't be synthesized reliably.
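///
/// Away from the sentinel this is the usual rearrangement
/// `log(exp(a) + exp(b)) = max(a, b) + ln(1 + exp(min(a, b) - max(a, b)))`,
/// which only ever exponentiates a non-positive argument.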
#[cube]
fn log_sum_exp2<F: Float>(a: F, b: F, unreachable_threshold: F, one: F) -> F {
    let mut mx = a;
    let mut mn = b;
    if b > a {
        mx = b;
        mn = a;
    }
    if mx < unreachable_threshold {
        mx
    } else {
        mx + (one + (mn - mx).exp()).ln()
    }
}

/// Single alpha (or beta) recurrence step. `near`, `near_m1`, `near_m2` are
/// the three values from the previous time row (alpha: `t-1`; beta: `t+1`).
/// `log_p` is the emission log-prob at the current `(t, l'[s])` and
/// `skip_allowed` toggles the 2-position skip transition.
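///
/// In the alpha direction this computes
/// `alpha_t(s) = log p_t(l'_s) + logsumexp(alpha_{t-1}(s), alpha_{t-1}(s-1)[, alpha_{t-1}(s-2)])`,
/// with the last term included only when `skip_allowed`; the beta phase feeds
/// in the `s`, `s+1`, `s+2` values from row `t+1` instead.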
#[cube]
fn recurrence_step<F: Float>(
    near: F,
    near_m1: F,
    near_m2: F,
    log_p: F,
    skip_allowed: bool,
    unreachable_threshold: F,
    one: F,
) -> F {
    let lse_01 = log_sum_exp2::<F>(near, near_m1, unreachable_threshold, one);
    let combined = if skip_allowed {
        log_sum_exp2::<F>(lse_01, near_m2, unreachable_threshold, one)
    } else {
        lse_01
    };
    log_p + combined
}

/// Final `-log(alpha_last_blank + alpha_last_label)` reduction. Synthesizes a
/// true `+inf` via `exp()` overflow when both final alphas are at the sentinel
/// (the target is unreachable), so downstream `zero_infinity` logic can detect
/// it via `is_inf`. Builds the overflow arithmetically from a runtime-dependent
/// value (`target_len`, guaranteed >= 1 here) to keep WGSL's comptime-overflow
/// validator quiet.
#[cube]
fn finalize_nll<F: Float>(
    last_blank: F,
    last_label: F,
    target_len: usize,
    unreachable_threshold: F,
    one: F,
) -> F {
    let mut mx = last_blank;
    let mut mn = last_label;
    if last_label > last_blank {
        mx = last_label;
        mn = last_blank;
    }
    if mx < unreachable_threshold {
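        // target_len >= 1 whenever this branch is reached (see the doc comment),
        // so exp(1000 * target_len) overflows to a true +inf.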
        (F::new(1000.0_f32) * F::cast_from(target_len as u32)).exp()
    } else {
        F::new(0.0) - (mx + (one + (mn - mx).exp()).ln())
    }
}

/// Value to emit when `input_len == 0`. `target_len == 0` is the only case
/// with a valid alignment (P(empty | empty) = 1, nll = 0); otherwise the
/// target is unreachable and the output is `+inf` synthesized via overflow.
#[cube]
fn empty_input_nll<F: Float>(target_len: usize) -> F {
    if target_len == 0 {
        F::new(0.0)
    } else {
        (F::new(1000.0_f32) * F::cast_from(target_len as u32)).exp()
    }
}

/// CTC alpha-recursion kernel.
///
/// Each cube handles one batch element. `cube_dim.x` is fixed at launch time
/// (capped to the runtime's hardware limit); each thread strides over the `s`
/// positions of the modified label sequence `l'` (length `2 * target_len + 1`),
/// covering arbitrary target lengths up to `SHARED_ALPHA_CAPACITY`. `alpha` is
/// kept in shared memory and the time loop runs sequentially inside the kernel,
/// using two `sync_cube()` barriers per iteration: one to fence reads of
/// `alpha[t-1]` before any thread writes `alpha[t]`, one to publish the new row
/// before the next iteration. This collapses what would otherwise be roughly
/// `40 * T` host-side dispatches into a single kernel launch.
///
/// Impossible alignments use a large finite negative sentinel (`-6.0e4`)
/// rather than true `-inf`, because WGSL rejects `f32(-inf)` as an identifier
/// and f16's range caps at ~65504. The recurrence treats values below a
/// threshold (`-1.0e4`) as unreachable. If an entire sequence has no valid
/// alignment (e.g. `target_length > input_length`), the kernel synthesizes
/// `+inf` in the output so downstream `zero_infinity` masking in `burn-nn`
/// can detect it via `is_inf`.
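///
/// For example, a sample with `target_len = 3` has `l'` of length 7
/// (`blank, y1, blank, y2, blank, y3, blank`); at the launch-time cap of 256
/// threads each thread covers at most one `s` position, while longer targets
/// make each thread loop over `s`, `s + cube_dim`, `s + 2 * cube_dim`, ...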
#[cube(launch)]
fn ctc_loss_kernel<F: Float, I: Numeric>(
    log_probs: &Tensor<F>,      // [T, N, C]
    targets: &Tensor<I>,        // [N, S_max]
    input_lengths: &Tensor<I>,  // [N]
    target_lengths: &Tensor<I>, // [N]
    output: &mut Tensor<F>,     // [N]
    blank: u32,
    #[comptime] alpha_capacity: u32,
    #[define(F, I)] _dtypes: [StorageType; 2],
) {
    let n = CUBE_POS_X as usize;
    let cube_dim = CUBE_DIM_X as usize;
    let alpha_cap = alpha_capacity as usize;
    let blank_u = blank as usize;

    let target_len = u32::cast_from(target_lengths[n]) as usize;
    let input_len = u32::cast_from(input_lengths[n]) as usize;
    let l_prime_len = 2 * target_len + 1;

    // Empty-input edge case: handled identically in both kernels to keep the
    // forward loss and the backward nll agreeing for this sample.
    if input_len == 0 {
        if UNIT_POS_X == 0 {
            output[n] = empty_input_nll::<F>(target_len);
        }
        terminate!();
    }

    let lp_t = log_probs.stride(0);
    let lp_n = log_probs.stride(1);
    let lp_c = log_probs.stride(2);
    let tgt_n = targets.stride(0);
    let tgt_s = targets.stride(1);

    // Two adjacent regions: alpha[0..alpha_cap] is the active row, the second
    // half [alpha_cap..2*alpha_cap] is a write scratch buffer that we copy back
    // to the active region after a sync. This avoids write-after-read hazards
    // across stride batches in the t-loop (a thread writing alpha[s] for time t
    // could race with another thread that still needs the t-1 value of that
    // slot as its own alpha[s-1] or alpha[s-2]).
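    //
    //     alpha: [ active row: 0..alpha_cap | scratch: alpha_cap..2*alpha_cap ]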
    let mut alpha = SharedMemory::<F>::new(2 * alpha_cap);
    // Sentinel for unreachable states. f16 caps at ~65504 magnitude, so we
    // can't go lower than `-6e4` without blowing past that range; WGSL also
    // rejects `f32(-inf)` as an identifier, so a real -inf literal isn't an
    // option anyway. On f32 the sentinel drifts slightly each recursion step
    // (log(2) per step when both log_sum_exp inputs sit at the sentinel),
    // which is why the recurrence compares against a threshold instead of
    // checking `== neg_inf`. See `log_sum_exp2` for the long-sequence caveat.
    let neg_inf = F::new(-6.0e4_f32);
    let unreachable_threshold = F::new(-1.0e4_f32);
    let one = F::new(1.0);

    // Initialize alpha at t = 0 for s < l_prime_len; positions beyond that
    // are never read by the recurrence (s < l_prime_len in every read) so
    // they don't need to be touched.
    let mut s = UNIT_POS_X as usize;
    while s < l_prime_len {
        let mut init = neg_inf;
        if s == 0 {
            init = log_probs[n * lp_n + blank_u * lp_c];
        } else if s == 1 {
            let l1 = u32::cast_from(targets[n * tgt_n]) as usize;
            init = log_probs[n * lp_n + l1 * lp_c];
        }
        alpha[s] = init;
        s += cube_dim;
    }
    sync_cube();

    // Sequential time loop. Each iteration re-strides over s positions to
    // compute alpha[t, s] from alpha[t-1, *] and writes back to the same
    // shared memory after a full read fence.
    for t in 1..input_len {
        let mut s = UNIT_POS_X as usize;
        while s < l_prime_len {
            let l_class = l_prime_class::<I>(s, targets, n, tgt_n, tgt_s, blank_u);
            let log_p = log_probs[t * lp_t + n * lp_n + l_class * lp_c];

            let l_class_m2 = if s >= 2 {
                l_prime_class::<I>(s - 2, targets, n, tgt_n, tgt_s, blank_u)
            } else {
                blank_u
            };
            let skip_allowed = s >= 2 && l_class != blank_u && l_class != l_class_m2;

            let a_s = alpha[s];
            let mut a_s_m1 = neg_inf;
            if s >= 1 {
                a_s_m1 = alpha[s - 1];
            }
            let mut a_s_m2 = neg_inf;
            if s >= 2 {
                a_s_m2 = alpha[s - 2];
            }

            alpha[alpha_cap + s] = recurrence_step::<F>(
                a_s,
                a_s_m1,
                a_s_m2,
                log_p,
                skip_allowed,
                unreachable_threshold,
                one,
            );
            s += cube_dim;
        }
        sync_cube();

        // Second pass: copy scratch back into the active alpha slots.
        let mut s = UNIT_POS_X as usize;
        while s < l_prime_len {
            alpha[s] = alpha[alpha_cap + s];
            s += cube_dim;
        }
        sync_cube();
    }

    // Reduce: only thread 0 writes the output for this batch element.
    if UNIT_POS_X == 0 {
        let last_blank = alpha[2 * target_len];
        // Guard target_len = 0: index 2*0 - 1 underflows. Use -inf so
        // log_sum_exp(last_blank, -inf) = last_blank (log_sum_exp(x, x) = x+ln2
        // would be wrong here).
        let mut last_label = neg_inf;
        if target_len > 0 {
            last_label = alpha[2 * target_len - 1];
        }
        output[n] = finalize_nll::<F>(
            last_blank,
            last_label,
            target_len,
            unreachable_threshold,
            one,
        );
    }
}

/// Fused CTC loss for burn-cubecl. Single kernel launch covers the entire
/// alpha recursion across all timesteps.
///
/// Panics if `2 * max_target_len + 1` exceeds `SHARED_ALPHA_CAPACITY` (8192).
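///
/// A rough call sketch (shapes only; obtaining the `CubeTensor` handles and
/// picking the runtime `R` are elided and backend-specific):
///
/// ```rust,ignore
/// // log_probs:      [T, N, C] log-softmax outputs
/// // targets:        [N, S_max] class indices
/// // input_lengths:  [N]
/// // target_lengths: [N]
/// let nll = ctc_loss::<R>(log_probs, targets, input_lengths, target_lengths, /* blank */ 0);
/// // nll: [N] per-sample negative log-likelihood, +inf where no alignment exists
/// ```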
pub fn ctc_loss<R: CubeRuntime>(
    log_probs: CubeTensor<R>,
    targets: CubeTensor<R>,
    input_lengths: CubeTensor<R>,
    target_lengths: CubeTensor<R>,
    blank: usize,
) -> CubeTensor<R> {
    // Manual stride indexing below requires a contiguous physical layout;
    // fusion-produced tensors may arrive with layouts that break that
    // assumption. No-op when already contiguous.
    let log_probs = into_contiguous(log_probs);
    let targets = into_contiguous(targets);
    let input_lengths = into_contiguous(input_lengths);
    let target_lengths = into_contiguous(target_lengths);

    let log_probs_shape = log_probs.shape();
    let [_t, batch_size, _c] = log_probs_shape.dims::<3>();
    let target_shape = targets.shape();
    let max_target_len = target_shape.dims::<2>()[1];
    let max_l_prime = 2 * max_target_len + 1;

    assert!(
        max_l_prime as u32 <= SHARED_ALPHA_CAPACITY,
        "ctc_loss: 2 * max_target_len + 1 = {} exceeds the kernel's shared-memory \
         alpha capacity ({}). Reduce target length or raise SHARED_ALPHA_CAPACITY.",
        max_l_prime,
        SHARED_ALPHA_CAPACITY,
    );

    // Pick a thread count that fits the runtime's per-cube limit. We don't
    // need one thread per s position - threads stride over s.
    let hw_max = log_probs.client.properties().hardware.max_cube_dim.0;
    let cube_dim_x = (max_l_prime as u32).min(hw_max).min(256);

    let client = log_probs.client.clone();
    let device = log_probs.device.clone();
    let f_dtype = log_probs.dtype;
    let i_dtype = targets.dtype;
    let output = empty_device_dtype::<R>(client.clone(), device, Shape::new([batch_size]), f_dtype);

    let cube_count = CubeCount::Static(batch_size as u32, 1, 1);
    let cube_dim = CubeDim::new_1d(cube_dim_x);

    // Pass the actual max_l_prime (not the static capacity) so shared memory
    // is sized to what we need. Metal limits threadgroup memory to 32 KB;
    // allocating 2 * 8192 * sizeof(f32) = 64 KB would silently corrupt on
    // Apple GPUs. Different max_l_prime values trigger separate kernel
    // compilations (it's a comptime param), but that's fine: target lengths
    // are stable within a dataset.
    ctc_loss_kernel::launch::<R>(
        &client,
        cube_count,
        cube_dim,
        log_probs.into_tensor_arg(),
        targets.into_tensor_arg(),
        input_lengths.into_tensor_arg(),
        target_lengths.into_tensor_arg(),
        output.clone().into_tensor_arg(),
        blank as u32,
        max_l_prime as u32,
        [f_dtype.into(), i_dtype.into()],
    );

    output
}

/// Fused CTC alpha + beta recursion kernel.
///
/// Runs the full forward alpha recursion and reverse beta recursion for one
/// batch element per cube, reusing the same shared-memory layout twice.
/// Writes `alpha_out[T, N, 2S+1]`, `beta_out[T, N, 2S+1]` and the per-sample
/// negative log-likelihood `nll_out[N]`. The three outputs are everything the
/// default CTC gradient-composition helper needs, so the caller can finish the
/// backward pass with a handful of element-wise tensor ops.
///
/// The alpha phase is identical to `ctc_loss_kernel` except it additionally
/// publishes each row to global memory. The beta phase mirrors it in reverse:
/// initialize at `t = input_len - 1` from `log_probs[t, l'[s]]` at the two
/// boundary `s` positions, then step backward reading `beta[t+1, s]`,
/// `beta[t+1, s+1]`, and (when the skip transition is allowed) `beta[t+1, s+2]`.
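///
/// In log space the beta step mirrors the alpha one:
/// `beta_t(s) = log p_t(l'_s) + logsumexp(beta_{t+1}(s), beta_{t+1}(s+1)[, beta_{t+1}(s+2)])`,
/// with the `s+2` term included only when the skip transition is allowed.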
#[cube(launch)]
fn ctc_alpha_beta_kernel<F: Float, I: Numeric>(
    log_probs: &Tensor<F>,      // [T, N, C]
    targets: &Tensor<I>,        // [N, S_max]
    input_lengths: &Tensor<I>,  // [N]
    target_lengths: &Tensor<I>, // [N]
    alpha_out: &mut Tensor<F>,  // [T, N, 2S+1]
    beta_out: &mut Tensor<F>,   // [T, N, 2S+1]
    nll_out: &mut Tensor<F>,    // [N]
    blank: u32,
    #[comptime] alpha_capacity: u32,
    #[define(F, I)] _dtypes: [StorageType; 2],
) {
    let n = CUBE_POS_X as usize;
    let cube_dim = CUBE_DIM_X as usize;
    let alpha_cap = alpha_capacity as usize;
    let blank_u = blank as usize;

    let target_len = u32::cast_from(target_lengths[n]) as usize;
    let input_len = u32::cast_from(input_lengths[n]) as usize;
    let l_prime_len = 2 * target_len + 1;

    // Empty input: alpha_out and beta_out stay at the host-side -inf pre-fill.
    // Emit the semantically correct nll (0 for target_len=0, +inf otherwise).
    if input_len == 0 {
        if UNIT_POS_X == 0 {
            nll_out[n] = empty_input_nll::<F>(target_len);
        }
        terminate!();
    }

    let lp_t = log_probs.stride(0);
    let lp_n = log_probs.stride(1);
    let lp_c = log_probs.stride(2);
    let tgt_n = targets.stride(0);
    let tgt_s = targets.stride(1);
    let ao_t = alpha_out.stride(0);
    let ao_n = alpha_out.stride(1);
    let ao_s = alpha_out.stride(2);
    let bo_t = beta_out.stride(0);
    let bo_n = beta_out.stride(1);
    let bo_s = beta_out.stride(2);

    // Shared memory layout: [0..alpha_cap] is the active row; [alpha_cap..2*alpha_cap]
    // is scratch for the next row. Same layout is reused for alpha and beta. Beta
    // reads are guarded by `s + 1 < l_prime_len` / `s + 2 < l_prime_len`, so the
    // residual alpha values sitting in the active row between phases are never
    // observed by beta (its boundary init overwrites every slot it reads).
    let mut state = SharedMemory::<F>::new(2 * alpha_cap);
    // Sentinel for unreachable states. See ctc_loss_kernel for the full
    // rationale: f16's 65504 magnitude cap forces the -6e4 floor, WGSL
    // rejects f32(-inf) literals, and the threshold catches sentinel drift.
    let neg_inf = F::new(-6.0e4_f32);
    let unreachable_threshold = F::new(-1.0e4_f32);
    let one = F::new(1.0);

    // Alpha phase (forward).
    //
    // Initialize alpha at t = 0 for s < l_prime_len. Positions beyond
    // l_prime_len are never read by the recurrence, so they don't need
    // to be touched in shared memory; and they stay at the host-side -inf
    // pre-fill in alpha_out.
    let mut s = UNIT_POS_X as usize;
    while s < l_prime_len {
        let mut init = neg_inf;
        if s == 0 {
            init = log_probs[n * lp_n + blank_u * lp_c];
        } else if s == 1 {
            let l1 = u32::cast_from(targets[n * tgt_n]) as usize;
            init = log_probs[n * lp_n + l1 * lp_c];
        }
        state[s] = init;
        alpha_out[n * ao_n + s * ao_s] = init;
        s += cube_dim;
    }
    sync_cube();

    for t in 1..input_len {
        let mut s = UNIT_POS_X as usize;
        while s < l_prime_len {
            let l_class = l_prime_class::<I>(s, targets, n, tgt_n, tgt_s, blank_u);
            let log_p = log_probs[t * lp_t + n * lp_n + l_class * lp_c];

            let l_class_m2 = if s >= 2 {
                l_prime_class::<I>(s - 2, targets, n, tgt_n, tgt_s, blank_u)
            } else {
                blank_u
            };
            let skip_allowed = s >= 2 && l_class != blank_u && l_class != l_class_m2;

            let a_s = state[s];
            let mut a_s_m1 = neg_inf;
            if s >= 1 {
                a_s_m1 = state[s - 1];
            }
            let mut a_s_m2 = neg_inf;
            if s >= 2 {
                a_s_m2 = state[s - 2];
            }

            state[alpha_cap + s] = recurrence_step::<F>(
                a_s,
                a_s_m1,
                a_s_m2,
                log_p,
                skip_allowed,
                unreachable_threshold,
                one,
            );
            s += cube_dim;
        }
        sync_cube();

        let mut s = UNIT_POS_X as usize;
        while s < l_prime_len {
            state[s] = state[alpha_cap + s];
            alpha_out[t * ao_t + n * ao_n + s * ao_s] = state[s];
            s += cube_dim;
        }
        sync_cube();
    }

    if UNIT_POS_X == 0 {
        let last_blank = state[2 * target_len];
        // See ctc_loss_kernel: -inf sentinel keeps log_sum_exp correct for target_len = 0.
        let mut last_label = neg_inf;
        if target_len > 0 {
            last_label = state[2 * target_len - 1];
        }
        nll_out[n] = finalize_nll::<F>(
            last_blank,
            last_label,
            target_len,
            unreachable_threshold,
            one,
        );
    }

    // Fence thread 0's read of state[2*target_len] / state[2*target_len - 1]
    // against the beta boundary init, which writes those same positions.
    sync_cube();

    // Beta phase (reverse).
    //
    // Boundary initialization at t = input_len - 1: set beta[s] = log_probs[t, l'[s]]
    // at s = 2*target_len, and when target_len > 0 also at s = 2*target_len - 1.
    // All other s positions in range get -inf.
    let t_last = input_len - 1;
    let mut s = UNIT_POS_X as usize;
    while s < l_prime_len {
        let is_last_blank = s == 2 * target_len;
        let is_last_label = target_len > 0 && s == 2 * target_len - 1;
        let mut init = neg_inf;
        if is_last_blank || is_last_label {
            let l_class = l_prime_class::<I>(s, targets, n, tgt_n, tgt_s, blank_u);
            init = log_probs[t_last * lp_t + n * lp_n + l_class * lp_c];
        }
        state[s] = init;
        beta_out[t_last * bo_t + n * bo_n + s * bo_s] = init;
        s += cube_dim;
    }
    sync_cube();

    // Step back from t = input_len - 2 down to t = 0.
    for t_rev in 1..input_len {
        let t = input_len - 1 - t_rev;

        let mut s = UNIT_POS_X as usize;
        while s < l_prime_len {
            let l_class = l_prime_class::<I>(s, targets, n, tgt_n, tgt_s, blank_u);
            let log_p = log_probs[t * lp_t + n * lp_n + l_class * lp_c];

            let l_class_p2 = if s + 2 < l_prime_len {
                l_prime_class::<I>(s + 2, targets, n, tgt_n, tgt_s, blank_u)
            } else {
                blank_u
            };
            let skip_allowed = s + 2 < l_prime_len && l_class != blank_u && l_class != l_class_p2;

            let b_s = state[s];
            let mut b_s_p1 = neg_inf;
            if s + 1 < l_prime_len {
                b_s_p1 = state[s + 1];
            }
            let mut b_s_p2 = neg_inf;
            if s + 2 < l_prime_len {
                b_s_p2 = state[s + 2];
            }

            state[alpha_cap + s] = recurrence_step::<F>(
                b_s,
                b_s_p1,
                b_s_p2,
                log_p,
                skip_allowed,
                unreachable_threshold,
                one,
            );
            s += cube_dim;
        }
        sync_cube();

        let mut s = UNIT_POS_X as usize;
        while s < l_prime_len {
            state[s] = state[alpha_cap + s];
            beta_out[t * bo_t + n * bo_n + s * bo_s] = state[s];
            s += cube_dim;
        }
        sync_cube();
    }
}

/// Host entry point for the fused alpha + beta + nll kernel.
///
/// Returns `(log_alpha_full, log_beta_full, nll)` with shapes
/// `([T, N, 2S+1], [T, N, 2S+1], [N])`. Positions outside the valid
/// `(t < input_length, s < 2*target_length+1)` rectangle hold the
/// pre-fill value `-inf`, matching the default backend's convention.
///
/// Panics if `2 * max_target_len + 1` exceeds `SHARED_ALPHA_CAPACITY`.
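///
/// A rough call sketch (tensor construction elided; the arguments have the
/// same shapes as in `ctc_loss`):
///
/// ```rust,ignore
/// let (log_alpha, log_beta, nll) =
///     ctc_alpha_beta::<R>(log_probs, targets, input_lengths, target_lengths, /* blank */ 0);
/// // log_alpha, log_beta: [T, N, 2 * S_max + 1]; nll: [N]
/// ```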
pub fn ctc_alpha_beta<R: CubeRuntime>(
    log_probs: CubeTensor<R>,
    targets: CubeTensor<R>,
    input_lengths: CubeTensor<R>,
    target_lengths: CubeTensor<R>,
    blank: usize,
) -> (CubeTensor<R>, CubeTensor<R>, CubeTensor<R>) {
    // Manual stride indexing below requires a contiguous physical layout;
    // fusion-produced tensors may arrive with layouts that break that
    // assumption. No-op when already contiguous.
    let log_probs = into_contiguous(log_probs);
    let targets = into_contiguous(targets);
    let input_lengths = into_contiguous(input_lengths);
    let target_lengths = into_contiguous(target_lengths);

    let log_probs_shape = log_probs.shape();
    let [max_input_length, batch_size, _c] = log_probs_shape.dims::<3>();
    let target_shape = targets.shape();
    let max_target_len = target_shape.dims::<2>()[1];
    let max_l_prime = 2 * max_target_len + 1;

    assert!(
        max_l_prime as u32 <= SHARED_ALPHA_CAPACITY,
        "ctc_loss_backward: 2 * max_target_len + 1 = {} exceeds the kernel's shared-memory \
         alpha capacity ({}). Reduce target length or raise SHARED_ALPHA_CAPACITY.",
        max_l_prime,
        SHARED_ALPHA_CAPACITY,
    );

    let hw_max = log_probs.client.properties().hardware.max_cube_dim.0;
    let cube_dim_x = (max_l_prime as u32).min(hw_max).min(256);

    let client = log_probs.client.clone();
    let device = log_probs.device.clone();
    let f_dtype = log_probs.dtype;
    let i_dtype = targets.dtype;

    // Pre-fill alpha/beta with -inf so positions the kernel doesn't touch
    // (s >= 2 * target_length + 1, or t outside the valid range, for an
    // individual batch element) are not read as stale zeros by the gradient
    // composition.
    let shape_abt = Shape::new([max_input_length, batch_size, max_l_prime]);
    let neg_inf = InputScalar::new(f32::NEG_INFINITY, f_dtype);
    let alpha_out = crate::ops::numeric::full_device_dtype::<R>(
        client.clone(),
        shape_abt.clone(),
        device.clone(),
        neg_inf,
        f_dtype,
    );
    let beta_out = crate::ops::numeric::full_device_dtype::<R>(
        client.clone(),
        shape_abt,
        device.clone(),
        neg_inf,
        f_dtype,
    );
    let nll_out =
        empty_device_dtype::<R>(client.clone(), device, Shape::new([batch_size]), f_dtype);

    let cube_count = CubeCount::Static(batch_size as u32, 1, 1);
    let cube_dim = CubeDim::new_1d(cube_dim_x);

    ctc_alpha_beta_kernel::launch::<R>(
        &client,
        cube_count,
        cube_dim,
        log_probs.into_tensor_arg(),
        targets.into_tensor_arg(),
        input_lengths.into_tensor_arg(),
        target_lengths.into_tensor_arg(),
        alpha_out.clone().into_tensor_arg(),
        beta_out.clone().into_tensor_arg(),
        nll_out.clone().into_tensor_arg(),
        blank as u32,
        max_l_prime as u32,
        [f_dtype.into(), i_dtype.into()],
    );

    (alpha_out, beta_out, nll_out)
}