//! CTC (Connectionist Temporal Classification) loss: forward alpha recursion
//! and gradient-composition helpers. (burn_backend/backend/ops/modules/ctc.rs)

use burn_std::{Shape, Slice};

use crate::{
    Backend, TensorMetadata, get_device_settings,
    tensor::{BoolTensor, FloatTensor, IntTensor},
};
7
8/// Default CTC loss implementation using the forward (alpha) algorithm.
9///
10/// Computes the Connectionist Temporal Classification loss by summing over
11/// all valid alignments between the input and target sequences.
12///
13/// # Arguments
14///
15/// * `log_probs` - Log-probabilities of shape `[T, N, C]`
16/// * `targets` - Target indices of shape `[N, S]`
17/// * `input_lengths` - Actual input sequence lengths per batch element `[N]`
18/// * `target_lengths` - Actual target lengths per batch element `[N]`
19/// * `blank` - Index of the blank label
20///
21/// # Returns
22///
23/// Per-sample loss of shape `[N]`
24pub fn ctc_loss_default<B: Backend>(
25    log_probs: FloatTensor<B>,
26    targets: IntTensor<B>,
27    input_lengths: IntTensor<B>,
28    target_lengths: IntTensor<B>,
29    blank: usize,
30) -> FloatTensor<B> {
31    let alpha = AlphaCtx::<B>::compute(
32        log_probs,
33        &targets,
34        input_lengths,
35        target_lengths.clone(),
36        blank,
37    );
38    extract_loss::<B>(&alpha, target_lengths)
39}
40
41/// Compose the CTC gradient w.r.t. `log_probs` from pre-computed alpha, beta, and nll.
42///
43/// The T-iteration alpha and beta recursions are the dominant cost of the backward
44/// pass. Backends that fuse those recursions into a single kernel launch can call
45/// this helper to reuse the gradient composition.
46///
47/// # Arguments
48///
49/// * `log_probs` - Log-probabilities `[T, N, C]`
50/// * `targets` - Target label indices `[N, S]`
51/// * `input_lengths` - Actual input sequence lengths per batch element `[N]`
52/// * `grad_loss` - Upstream gradient w.r.t. the per-sample loss `[N]`
53/// * `log_alpha_full` - Alpha recursion output `[T, N, 2S+1]`
54/// * `log_beta_full` - Beta recursion output `[T, N, 2S+1]`
55/// * `nll` - Per-sample negative log-likelihood (forward loss) `[N]`
56/// * `blank` - Index of the blank label
57#[allow(clippy::too_many_arguments)]
58pub fn ctc_grad_from_alpha_beta_default<B: Backend>(
59    log_probs: FloatTensor<B>,
60    targets: IntTensor<B>,
61    input_lengths: IntTensor<B>,
62    grad_loss: FloatTensor<B>,
63    log_alpha_full: FloatTensor<B>,
64    log_beta_full: FloatTensor<B>,
65    nll: FloatTensor<B>,
66    blank: usize,
67) -> FloatTensor<B> {
68    let log_probs_shape = log_probs.shape();
69    let [max_input_length, batch_size, num_classes] = log_probs_shape.dims::<3>();
70    let target_shape = targets.shape();
71    let max_target_len = target_shape.dims::<2>()[1];
72    let max_l_prime_len = 2 * max_target_len + 1;
73    let device = B::float_device(&log_probs);
74    let int_dtype: burn_std::IntDType = targets.dtype().into();
75    let settings = get_device_settings::<B>(&device);
76
77    let blank_inserted_targets = insert_blanks::<B>(
78        &targets,
79        batch_size,
80        max_target_len,
81        max_l_prime_len,
82        blank,
83        &device,
84        int_dtype,
85    );
86
87    // Both log_alpha[t, n, s] and log_beta[t, n, s] include a factor of
88    // log_probs[t, n, l'[s]] (added on every recursion step). The CTC paper's
89    // alpha_hat * beta_hat product divides one of those factors out, so we
90    // subtract log_probs[t, n, l'[s]] when forming log_post.
91    //
92    // We then divide by total_prob = exp(-nll) to obtain the alignment
93    // posterior, which in log space means *adding* nll (since nll = -log P,
94    // dividing by P is adding nll). Per PyTorch's CTC backward kernel:
95    //   log_post[t, n, s] = log_alpha + log_beta - log_probs[t, n, l'[s]] - log P
96    //                     = log_alpha + log_beta - log_probs[t, n, l'[s]] + nll
97    let indices_3d = B::int_reshape(
98        blank_inserted_targets,
99        Shape::new([1, batch_size, max_l_prime_len]),
100    );
101    let indices_3d = B::int_expand(
102        indices_3d,
103        Shape::new([max_input_length, batch_size, max_l_prime_len]),
104    );
105    let log_probs_at_l = B::float_gather(2, log_probs.clone(), indices_3d.clone());
106
107    // Samples with an unreachable target yield nll = +inf. For those, log_alpha
108    // stays at -inf at many (t, s) while log_beta is finite at the boundary, so
109    // log_post = (-inf) + finite - finite + (+inf) = NaN and -exp(NaN) = NaN
110    // contaminates the gradient. `NaN * 0 = NaN` under IEEE 754, so zero_infinity
111    // masking on the outer grad_loss can't clear it. Capture the mask now and
112    // zero the gradient for those samples at the end.
113    let nll_is_inf = B::float_is_inf(nll.clone(), settings.bool_dtype);
114
115    let nll_b = B::float_reshape(nll, Shape::new([1, batch_size, 1]));
116    let nll_b = B::float_expand(
117        nll_b,
118        Shape::new([max_input_length, batch_size, max_l_prime_len]),
119    );
120    let log_post = B::float_add(
121        B::float_sub(B::float_add(log_alpha_full, log_beta_full), log_probs_at_l),
122        nll_b,
123    );
124
125    // grad starts as exp(log_probs) * grad_loss[None, :, None].
126    let grad_loss_3d = B::float_reshape(grad_loss, Shape::new([1, batch_size, 1]));
127    let grad_loss_b = B::float_expand(
128        grad_loss_3d.clone(),
129        Shape::new([max_input_length, batch_size, num_classes]),
130    );
131    let mut grad = B::float_mul(B::float_exp(log_probs), grad_loss_b);
132
133    // Subtract sum over s of grad_loss[n] * exp(log_post[t, n, s]) at index l'[n, s].
134    let grad_loss_post = B::float_expand(
135        grad_loss_3d,
136        Shape::new([max_input_length, batch_size, max_l_prime_len]),
137    );
138    let scatter_value = B::float_neg(B::float_mul(B::float_exp(log_post), grad_loss_post));
139
140    grad = B::float_scatter_add(2, grad, indices_3d, scatter_value);
141
142    // Mask out timesteps where t >= input_lengths[n].
143    let t_indices = B::int_arange(0..max_input_length as i64, &device, int_dtype);
144    let t_indices = B::int_reshape(t_indices, Shape::new([max_input_length, 1, 1]));
145    let t_indices = B::int_expand(
146        t_indices,
147        Shape::new([max_input_length, batch_size, num_classes]),
148    );
149    let il_b = B::int_reshape(input_lengths, Shape::new([1, batch_size, 1]));
150    let il_b = B::int_expand(
151        il_b,
152        Shape::new([max_input_length, batch_size, num_classes]),
153    );
154    let oob_mask = B::int_greater_equal(t_indices, il_b, settings.bool_dtype);
155
156    // Broadcast the nll-is-inf mask across [T, N, C] and OR with oob_mask so a
157    // single mask_fill zeros both unreachable samples and out-of-bound timesteps.
158    let nll_inf_b = B::bool_reshape(nll_is_inf, Shape::new([1, batch_size, 1]));
159    let nll_inf_b = B::bool_expand(
160        nll_inf_b,
161        Shape::new([max_input_length, batch_size, num_classes]),
162    );
163    let mask = B::bool_or(oob_mask, nll_inf_b);
164    B::float_mask_fill(grad, mask, 0.0.into())
165}
166
167/// Cached state from the alpha recursion. Only `last` is consumed by
168/// `ctc_loss_default` (via `extract_loss`); the other fields hold intermediate
169/// products that backends with a native backward kernel could reuse if wired
170/// up. They are kept here to document the recursion's outputs.
171#[allow(dead_code)]
172struct AlphaCtx<B: Backend> {
173    /// `log_alpha[T, N, 2S+1]` (full history).
174    full: FloatTensor<B>,
175    /// `log_alpha[T-1, :, :]` (last timestep; used to read out the loss).
176    last: FloatTensor<B>,
177    /// `l'` after blank insertion `[N, 2S+1]`.
178    blank_inserted_targets: IntTensor<B>,
179    /// `log_probs[t, n, l'[n, s]]` pre-gathered as `[T, N, 2S+1]`.
180    log_probs_at_l_full: FloatTensor<B>,
181    max_l_prime_len: usize,
182}
183
184impl<B: Backend> AlphaCtx<B> {
185    fn compute(
186        log_probs: FloatTensor<B>,
187        targets: &IntTensor<B>,
188        input_lengths: IntTensor<B>,
189        target_lengths: IntTensor<B>,
190        blank: usize,
191    ) -> Self {
192        let log_probs_shape = log_probs.shape();
193        let [max_input_length, batch_size, num_classes] = log_probs_shape.dims::<3>();
194        let target_shape = targets.shape();
195        let max_target_len = target_shape.dims::<2>()[1];
196        let device = B::float_device(&log_probs);
197        let float_dtype: burn_std::FloatDType = log_probs.dtype().into();
198        let int_dtype: burn_std::IntDType = targets.dtype().into();
199        let settings = get_device_settings::<B>(&device);
200
201        let max_l_prime_len = 2 * max_target_len + 1;
202        let blank_inserted_targets = insert_blanks::<B>(
203            targets,
204            batch_size,
205            max_target_len,
206            max_l_prime_len,
207            blank,
208            &device,
209            int_dtype,
210        );
211
212        // Pre-allocate the full alpha tensor [T, N, 2S+1] filled with -inf.
213        let mut alpha_full = B::float_full(
214            Shape::new([max_input_length, batch_size, max_l_prime_len]),
215            f32::NEG_INFINITY.into(),
216            &device,
217            float_dtype,
218        );
219
220        // Initialize alpha[0, :, 0] = log_probs[0, :, blank]
221        // and alpha[0, :, 1] = log_probs[0, :, l'[1]].
222        let log_probs_t0 = B::float_slice(
223            log_probs.clone(),
224            &[Slice::new(0, Some(1), 1), Slice::full(), Slice::full()],
225        );
226        let log_probs_t0 = B::float_reshape(log_probs_t0, Shape::new([batch_size, num_classes]));
227
228        let first_blank = B::int_slice(
229            blank_inserted_targets.clone(),
230            &[Slice::full(), Slice::new(0, Some(1), 1)],
231        );
232        let log_prob_blank = B::float_gather(1, log_probs_t0.clone(), first_blank);
233        // Broadcast to [1, N, 1] for slice_assign into alpha_full.
234        let log_prob_blank_3d = B::float_reshape(log_prob_blank, Shape::new([1, batch_size, 1]));
235        alpha_full = B::float_slice_assign(
236            alpha_full,
237            &[
238                Slice::new(0, Some(1), 1),
239                Slice::full(),
240                Slice::new(0, Some(1), 1),
241            ],
242            log_prob_blank_3d,
243        );
244
245        if max_l_prime_len > 1 {
246            let first_label = B::int_slice(
247                blank_inserted_targets.clone(),
248                &[Slice::full(), Slice::new(1, Some(2), 1)],
249            );
250            let log_prob_first = B::float_gather(1, log_probs_t0, first_label);
251            let log_prob_first_3d =
252                B::float_reshape(log_prob_first, Shape::new([1, batch_size, 1]));
253            alpha_full = B::float_slice_assign(
254                alpha_full,
255                &[
256                    Slice::new(0, Some(1), 1),
257                    Slice::full(),
258                    Slice::new(1, Some(2), 1),
259                ],
260                log_prob_first_3d,
261            );
262        }
263
264        // Track the latest row separately for the recursion (cheaper than
265        // re-slicing alpha_full each iteration).
266        let mut log_alpha = B::float_slice(
267            alpha_full.clone(),
268            &[Slice::new(0, Some(1), 1), Slice::full(), Slice::full()],
269        );
270        log_alpha = B::float_reshape(log_alpha, Shape::new([batch_size, max_l_prime_len]));
271
272        let l_prime_mask = create_l_prime_mask::<B>(
273            &blank_inserted_targets,
274            batch_size,
275            max_l_prime_len,
276            blank,
277            &device,
278            int_dtype,
279            settings.bool_dtype,
280        );
281        let s_mask = create_s_mask::<B>(
282            &target_lengths,
283            batch_size,
284            max_l_prime_len,
285            &device,
286            int_dtype,
287            settings.bool_dtype,
288        );
289
290        // Hoist out of the T-loop: padding tensors for right_shift (same
291        // value/shape at every iteration) and the full `[T, N, 2S+1]`
292        // gather of log_probs at l' (one T-sized gather replaces T small
293        // gathers).
294        let pad_1 = B::float_full(
295            Shape::new([batch_size, 1]),
296            f32::NEG_INFINITY.into(),
297            &device,
298            float_dtype,
299        );
300        let pad_2 = B::float_full(
301            Shape::new([batch_size, 2]),
302            f32::NEG_INFINITY.into(),
303            &device,
304            float_dtype,
305        );
306        let indices_3d = B::int_expand(
307            B::int_reshape(
308                blank_inserted_targets.clone(),
309                Shape::new([1, batch_size, max_l_prime_len]),
310            ),
311            Shape::new([max_input_length, batch_size, max_l_prime_len]),
312        );
313        let log_probs_at_l_full = B::float_gather(2, log_probs.clone(), indices_3d);
314
315        // Precompute `combined_mask_all[t, n, s] = (input_lengths[n] > t) AND
316        // s_mask[n, s]` for every t in one shot. The T-loop reads its row via
317        // a metadata-only slice instead of recomputing the `int_greater_elem`
318        // + bool_and per iteration.
319        let t_indices_2d = B::int_expand(
320            B::int_reshape(
321                B::int_arange(0..max_input_length as i64, &device, int_dtype),
322                Shape::new([max_input_length, 1]),
323            ),
324            Shape::new([max_input_length, batch_size]),
325        );
326        let il_tn = B::int_expand(
327            B::int_reshape(input_lengths.clone(), Shape::new([1, batch_size])),
328            Shape::new([max_input_length, batch_size]),
329        );
330        let t_mask_all = B::bool_expand(
331            B::bool_reshape(
332                B::int_greater(il_tn, t_indices_2d, settings.bool_dtype),
333                Shape::new([max_input_length, batch_size, 1]),
334            ),
335            Shape::new([max_input_length, batch_size, max_l_prime_len]),
336        );
337        let s_mask_bcast = B::bool_expand(
338            B::bool_reshape(s_mask.clone(), Shape::new([1, batch_size, max_l_prime_len])),
339            Shape::new([max_input_length, batch_size, max_l_prime_len]),
340        );
341        let combined_mask_all = B::bool_and(t_mask_all, s_mask_bcast);
342
343        for t in 1..max_input_length {
344            let combined_mask = B::bool_reshape(
345                B::bool_slice(
346                    combined_mask_all.clone(),
347                    &[
348                        Slice::new(t as isize, Some(t as isize + 1), 1),
349                        Slice::full(),
350                        Slice::full(),
351                    ],
352                ),
353                Shape::new([batch_size, max_l_prime_len]),
354            );
355
356            let log_alpha_s = log_alpha.clone();
357            let log_alpha_s_m1 = right_shift::<B>(&log_alpha, &pad_1, max_l_prime_len, 1);
358            let log_alpha_s_m2 = right_shift::<B>(&log_alpha, &pad_2, max_l_prime_len, 2);
359
360            let bar = log_sum_exp::<B>(log_alpha_s, log_alpha_s_m1, settings.bool_dtype);
361            let bar_with_skip = log_sum_exp::<B>(bar.clone(), log_alpha_s_m2, settings.bool_dtype);
362            let log_alpha_combined = B::float_mask_where(bar, l_prime_mask.clone(), bar_with_skip);
363
364            // Slice row t from the pre-gathered `[T, N, 2S+1]` tensor.
365            let log_probs_at_l = B::float_reshape(
366                B::float_slice(
367                    log_probs_at_l_full.clone(),
368                    &[
369                        Slice::new(t as isize, Some(t as isize + 1), 1),
370                        Slice::full(),
371                        Slice::full(),
372                    ],
373                ),
374                Shape::new([batch_size, max_l_prime_len]),
375            );
376            let new_alpha = B::float_add(log_alpha_combined, log_probs_at_l);
377            log_alpha = B::float_mask_where(log_alpha, combined_mask, new_alpha);
378
379            let log_alpha_3d = B::float_reshape(
380                log_alpha.clone(),
381                Shape::new([1, batch_size, max_l_prime_len]),
382            );
383            alpha_full = B::float_slice_assign(
384                alpha_full,
385                &[
386                    Slice::new(t as isize, Some(t as isize + 1), 1),
387                    Slice::full(),
388                    Slice::full(),
389                ],
390                log_alpha_3d,
391            );
392        }
393
394        Self {
395            full: alpha_full,
396            last: log_alpha,
397            blank_inserted_targets,
398            log_probs_at_l_full,
399            max_l_prime_len,
400        }
401    }
402}
403
404/// Extract the per-sample loss from the last alpha row.
405fn extract_loss<B: Backend>(alpha: &AlphaCtx<B>, target_lengths: IntTensor<B>) -> FloatTensor<B> {
406    let log_alpha_shape = alpha.last.shape();
407    let [batch_size, _] = log_alpha_shape.dims::<2>();
408    let device = B::float_device(&alpha.last);
409    let settings = get_device_settings::<B>(&device);
410
411    let last_blank_idx = B::int_mul_scalar(target_lengths.clone(), 2.into());
412    let last_blank_idx = B::int_reshape(last_blank_idx, Shape::new([batch_size, 1]));
413    let last_label_idx = B::int_clamp_min(
414        B::int_sub_scalar(last_blank_idx.clone(), 1.into()),
415        0.into(),
416    );
417
418    let log_alpha_last_blank = B::float_gather(1, alpha.last.clone(), last_blank_idx);
419    let log_alpha_last_blank = B::float_reshape(log_alpha_last_blank, Shape::new([batch_size]));
420
421    let log_alpha_last_label = B::float_gather(1, alpha.last.clone(), last_label_idx);
422    let log_alpha_last_label = B::float_reshape(log_alpha_last_label, Shape::new([batch_size]));
423
424    // For target_lengths == 0, last_label is meaningless: substitute -inf.
425    let target_len_zero = B::int_equal_elem(target_lengths, 0.into(), settings.bool_dtype);
426    let log_alpha_last_label = B::float_mask_fill(
427        log_alpha_last_label,
428        target_len_zero,
429        f32::NEG_INFINITY.into(),
430    );
431
432    let log_likelihood = log_sum_exp::<B>(
433        log_alpha_last_blank,
434        log_alpha_last_label,
435        settings.bool_dtype,
436    );
437    B::float_neg(log_likelihood)
438}
439
440/// Insert blank labels between each target label: [b, l1, b, l2, ..., b]
441fn insert_blanks<B: Backend>(
442    targets: &IntTensor<B>,
443    batch_size: usize,
444    max_target_len: usize,
445    max_l_prime_len: usize,
446    blank: usize,
447    device: &B::Device,
448    int_dtype: burn_std::IntDType,
449) -> IntTensor<B> {
450    let result = B::int_full(
451        Shape::new([batch_size, max_l_prime_len]),
452        (blank as i64).into(),
453        device,
454        int_dtype,
455    );
456
457    if max_target_len == 0 {
458        return result;
459    }
460
461    // Place every target label at odd columns {1, 3, 5, ...} in one
462    // strided slice_assign, equivalent to `result[:, 1::2] = targets`.
463    B::int_slice_assign(
464        result,
465        &[Slice::full(), Slice::new(1, None, 2)],
466        targets.clone(),
467    )
468}
469
470/// Right-shift a 2D float tensor by `shift` positions, prepending the
471/// pre-allocated `padding` tensor (shape `[batch_size, shift]`, value
472/// `-inf`) instead of materializing it each call.
473///
474/// Called inside the T-loop of the alpha recursion; hoisting the padding
475/// out of the loop eliminates `O(T)` `float_full` allocations.
476fn right_shift<B: Backend>(
477    tensor: &FloatTensor<B>,
478    padding: &FloatTensor<B>,
479    cols: usize,
480    shift: usize,
481) -> FloatTensor<B> {
482    // Shifting by more than the column count pushes every data slot off
483    // the right. Avoid the `cols - shift` usize underflow when
484    // `max_target_len == 0` (so `max_l_prime_len == 1`) by narrowing the
485    // all-`-inf` padding down to `cols`.
486    if cols < shift {
487        return B::float_slice(
488            padding.clone(),
489            &[Slice::full(), Slice::new(0, Some(cols as isize), 1)],
490        );
491    }
492    let shortened = B::float_slice(
493        tensor.clone(),
494        &[
495            Slice::full(),
496            Slice::new(0, Some((cols - shift) as isize), 1),
497        ],
498    );
499    B::float_cat(alloc::vec![padding.clone(), shortened], 1)
500}
501
502/// Compute `log(exp(a) + exp(b))` in a numerically stable way.
503///
504/// `log_sum_exp(a, b) = max(a, b) + log1p(exp(-|a - b|))`. The edge case is
505/// `a = b = -inf`, where `-|(-inf) - (-inf)| = NaN`; we detect `max == -inf`
506/// and substitute a `-inf` diff so the final sum stays `-inf` (both via the
507/// mask and because `log1p(exp(-inf)) = 0`). Gradient-safe: no `NaN` flows
508/// through the forward intermediates when inputs are `-inf`.
509///
510/// Precondition: inputs must be `<= 0` (log-probabilities). `+inf` inputs are
511/// not guarded and produce `NaN`; callers outside the CTC recursion should
512/// validate this themselves.
513fn log_sum_exp<B: Backend>(
514    a: FloatTensor<B>,
515    b: FloatTensor<B>,
516    bool_dtype: burn_std::BoolDType,
517) -> FloatTensor<B> {
518    // `-inf` values in `a` or `b` would make `a - b` evaluate to `NaN`
519    // (when both are `-inf`) and the backward pass through that `NaN`
520    // intermediate propagates `NaN` into the gradient even when the
521    // forward mask discards it (`0 * NaN = NaN` in IEEE). Clamp `-inf`
522    // to `0` on safe copies used only for the diff computation; compute
523    // `max` on the original values so its output is correct in the
524    // `-inf` cases.
525    let a_is_neg_inf = B::float_equal_elem(a.clone(), f32::NEG_INFINITY.into(), bool_dtype);
526    let b_is_neg_inf = B::float_equal_elem(b.clone(), f32::NEG_INFINITY.into(), bool_dtype);
527    let either_neg_inf = B::bool_or(a_is_neg_inf.clone(), b_is_neg_inf.clone());
528
529    let a_safe = B::float_mask_fill(a.clone(), a_is_neg_inf, 0.0.into());
530    let b_safe = B::float_mask_fill(b.clone(), b_is_neg_inf, 0.0.into());
531
532    let lt_mask = B::float_lower(a.clone(), b.clone(), bool_dtype);
533    let mx = B::float_mask_where(a, lt_mask, b);
534
535    // diff_safe = -|a_safe - b_safe|. Finite by construction. When either
536    // input was `-inf`, force it to `-inf` so `exp(diff) == 0` and the
537    // `log1p` term contributes nothing (`result = mx`). When both were
538    // `-inf`, `mx = -inf` so `result = -inf + 0 = -inf`.
539    let diff_safe = B::float_neg(B::float_abs(B::float_sub(a_safe, b_safe)));
540    let diff_final = B::float_mask_fill(diff_safe, either_neg_inf, f32::NEG_INFINITY.into());
541
542    B::float_add(mx, B::float_log1p(B::float_exp(diff_final)))
543}
544
545/// Mask for the alpha skip transition: `l'[s] != blank AND l'[s] != l'[s-2] AND s >= 2`.
546fn create_l_prime_mask<B: Backend>(
547    blank_inserted_targets: &IntTensor<B>,
548    batch_size: usize,
549    max_l_prime_len: usize,
550    blank: usize,
551    device: &B::Device,
552    int_dtype: burn_std::IntDType,
553    bool_dtype: burn_std::BoolDType,
554) -> BoolTensor<B> {
555    // The mask requires `s >= 2`, which is unsatisfiable when max_l_prime_len < 2
556    // (i.e. targets have shape [N, 0]). Bail out before the `max_l_prime_len - 2`
557    // usize subtraction underflows.
558    if max_l_prime_len < 2 {
559        return B::bool_zeros(
560            Shape::new([batch_size, max_l_prime_len]),
561            device,
562            bool_dtype,
563        );
564    }
565    let l_prime = blank_inserted_targets.clone();
566
567    let not_blank = B::int_not_equal_elem(l_prime.clone(), (blank as i64).into(), bool_dtype);
568
569    let l_prime_shifted = {
570        let padding = B::int_full(
571            Shape::new([batch_size, 2]),
572            (blank as i64).into(),
573            device,
574            int_dtype,
575        );
576        let shortened = B::int_slice(
577            l_prime.clone(),
578            &[
579                Slice::full(),
580                Slice::new(0, Some((max_l_prime_len - 2) as isize), 1),
581            ],
582        );
583        B::int_cat(alloc::vec![padding, shortened], 1)
584    };
585    let not_equal_s_m2 = B::int_not_equal(l_prime, l_prime_shifted, bool_dtype);
586
587    let col_indices = B::int_arange(0..max_l_prime_len as i64, device, int_dtype);
588    let col_indices = B::int_reshape(col_indices, Shape::new([1, max_l_prime_len]));
589    let col_indices = B::int_expand(col_indices, Shape::new([batch_size, max_l_prime_len]));
590    let s_ge_2 = B::int_greater_equal_elem(col_indices, 2.into(), bool_dtype);
591
592    B::bool_and(B::bool_and(not_blank, not_equal_s_m2), s_ge_2)
593}
594
595/// Create a mask for valid s positions: s < 2 * target_length + 1
596fn create_s_mask<B: Backend>(
597    target_lengths: &IntTensor<B>,
598    batch_size: usize,
599    max_l_prime_len: usize,
600    device: &B::Device,
601    int_dtype: burn_std::IntDType,
602    bool_dtype: burn_std::BoolDType,
603) -> BoolTensor<B> {
604    let col_indices = B::int_arange(0..max_l_prime_len as i64, device, int_dtype);
605    let col_indices = B::int_reshape(col_indices, Shape::new([1, max_l_prime_len]));
606    let col_indices = B::int_expand(col_indices, Shape::new([batch_size, max_l_prime_len]));
607
608    let lengths = B::int_mul_scalar(target_lengths.clone(), 2.into());
609    let lengths = B::int_add_scalar(lengths, 1.into());
610    let lengths = B::int_reshape(lengths, Shape::new([batch_size, 1]));
611    let lengths = B::int_expand(lengths, Shape::new([batch_size, max_l_prime_len]));
612
613    B::int_lower(col_indices, lengths, bool_dtype)
614}