1use burn_std::{Shape, Slice};
2
3use crate::{
4 Backend, TensorMetadata, get_device_settings,
5 tensor::{BoolTensor, FloatTensor, IntTensor},
6};
7
8pub fn ctc_loss_default<B: Backend>(
25 log_probs: FloatTensor<B>,
26 targets: IntTensor<B>,
27 input_lengths: IntTensor<B>,
28 target_lengths: IntTensor<B>,
29 blank: usize,
30) -> FloatTensor<B> {
31 let alpha = AlphaCtx::<B>::compute(
32 log_probs,
33 &targets,
34 input_lengths,
35 target_lengths.clone(),
36 blank,
37 );
38 extract_loss::<B>(&alpha, target_lengths)
39}
40
#[allow(clippy::too_many_arguments)]
/// Computes the gradient of the CTC loss with respect to `log_probs` from the
/// precomputed forward (`log_alpha_full`) and backward (`log_beta_full`)
/// lattices.
///
/// Implements the standard CTC gradient
/// `d(nll)/d(lp[t, b, c]) = exp(lp) - sum_{s: l'[s] = c} exp(alpha + beta - lp + nll)`,
/// scaled element-wise by the incoming `grad_loss`, with time steps past each
/// sequence's `input_lengths` and batches whose `nll` is infinite zeroed out.
///
/// Shapes (as established by the reshapes/gathers below):
/// * `log_probs`: `[max_input_length, batch_size, num_classes]`
/// * `targets`: `[batch_size, max_target_len]`
/// * `log_alpha_full` / `log_beta_full`: broadcast-compatible with
///   `[max_input_length, batch_size, 2 * max_target_len + 1]`
/// * `grad_loss` / `nll`: `batch_size` elements (reshaped to `[1, B, 1]`)
pub fn ctc_grad_from_alpha_beta_default<B: Backend>(
    log_probs: FloatTensor<B>,
    targets: IntTensor<B>,
    input_lengths: IntTensor<B>,
    grad_loss: FloatTensor<B>,
    log_alpha_full: FloatTensor<B>,
    log_beta_full: FloatTensor<B>,
    nll: FloatTensor<B>,
    blank: usize,
) -> FloatTensor<B> {
    let log_probs_shape = log_probs.shape();
    let [max_input_length, batch_size, num_classes] = log_probs_shape.dims::<3>();
    let target_shape = targets.shape();
    let max_target_len = target_shape.dims::<2>()[1];
    // Extended target l' interleaves blanks with labels: length 2n + 1.
    let max_l_prime_len = 2 * max_target_len + 1;
    let device = B::float_device(&log_probs);
    let int_dtype: burn_std::IntDType = targets.dtype().into();
    let settings = get_device_settings::<B>(&device);

    // l'[b, s]: targets with the blank symbol inserted at every even position.
    let blank_inserted_targets = insert_blanks::<B>(
        &targets,
        batch_size,
        max_target_len,
        max_l_prime_len,
        blank,
        &device,
        int_dtype,
    );

    // Broadcast l' over time so it can index the class dimension at each step.
    let indices_3d = B::int_reshape(
        blank_inserted_targets,
        Shape::new([1, batch_size, max_l_prime_len]),
    );
    let indices_3d = B::int_expand(
        indices_3d,
        Shape::new([max_input_length, batch_size, max_l_prime_len]),
    );
    // lp[t, b, l'[b, s]] — log-prob of each extended-target symbol per step.
    let log_probs_at_l = B::float_gather(2, log_probs.clone(), indices_3d.clone());

    // Batches with infinite loss (no valid alignment) get a zero gradient.
    let nll_is_inf = B::float_is_inf(nll.clone(), settings.bool_dtype);

    let nll_b = B::float_reshape(nll, Shape::new([1, batch_size, 1]));
    let nll_b = B::float_expand(
        nll_b,
        Shape::new([max_input_length, batch_size, max_l_prime_len]),
    );
    // Log posterior over lattice positions: alpha + beta - lp + nll.
    // Adding nll (= -log P) normalises by the total path probability.
    let log_post = B::float_add(
        B::float_sub(B::float_add(log_alpha_full, log_beta_full), log_probs_at_l),
        nll_b,
    );

    // First gradient term: exp(lp) * grad_loss, over every class.
    let grad_loss_3d = B::float_reshape(grad_loss, Shape::new([1, batch_size, 1]));
    let grad_loss_b = B::float_expand(
        grad_loss_3d.clone(),
        Shape::new([max_input_length, batch_size, num_classes]),
    );
    let mut grad = B::float_mul(B::float_exp(log_probs), grad_loss_b);

    // Second term: subtract the posterior mass, accumulated per class via
    // scatter-add (several positions s may map to the same class c).
    let grad_loss_post = B::float_expand(
        grad_loss_3d,
        Shape::new([max_input_length, batch_size, max_l_prime_len]),
    );
    let scatter_value = B::float_neg(B::float_mul(B::float_exp(log_post), grad_loss_post));

    grad = B::float_scatter_add(2, grad, indices_3d, scatter_value);

    // Mask out time steps beyond each sequence's actual input length...
    let t_indices = B::int_arange(0..max_input_length as i64, &device, int_dtype);
    let t_indices = B::int_reshape(t_indices, Shape::new([max_input_length, 1, 1]));
    let t_indices = B::int_expand(
        t_indices,
        Shape::new([max_input_length, batch_size, num_classes]),
    );
    let il_b = B::int_reshape(input_lengths, Shape::new([1, batch_size, 1]));
    let il_b = B::int_expand(
        il_b,
        Shape::new([max_input_length, batch_size, num_classes]),
    );
    let oob_mask = B::int_greater_equal(t_indices, il_b, settings.bool_dtype);

    // ...and whole batch entries whose loss was infinite.
    let nll_inf_b = B::bool_reshape(nll_is_inf, Shape::new([1, batch_size, 1]));
    let nll_inf_b = B::bool_expand(
        nll_inf_b,
        Shape::new([max_input_length, batch_size, num_classes]),
    );
    let mask = B::bool_or(oob_mask, nll_inf_b);
    B::float_mask_fill(grad, mask, 0.0.into())
}
166
#[allow(dead_code)]
/// Intermediate state produced by the CTC forward (alpha) recursion.
struct AlphaCtx<B: Backend> {
    /// Log-alpha lattice for every time step:
    /// `[max_input_length, batch_size, max_l_prime_len]`.
    full: FloatTensor<B>,
    /// Log-alpha values at the final time step: `[batch_size, max_l_prime_len]`.
    last: FloatTensor<B>,
    /// Extended targets `l'` (blanks interleaved): `[batch_size, max_l_prime_len]`.
    blank_inserted_targets: IntTensor<B>,
    /// `log_probs` gathered at `l'` for all time steps:
    /// `[max_input_length, batch_size, max_l_prime_len]`.
    log_probs_at_l_full: FloatTensor<B>,
    /// Length of the extended target sequence: `2 * max_target_len + 1`.
    max_l_prime_len: usize,
}
183
impl<B: Backend> AlphaCtx<B> {
    /// Runs the CTC forward recursion, producing the full log-alpha lattice
    /// plus the values at the final time step.
    ///
    /// Initialisation (t = 0): `alpha[0, b, 0] = lp[0, b, blank]` and, when
    /// the extended target has more than one position,
    /// `alpha[0, b, 1] = lp[0, b, l'[b, 1]]`; everything else is `-inf`.
    ///
    /// Recursion (t >= 1):
    /// `alpha[t, s] = lse(alpha[t-1, s], alpha[t-1, s-1] [, alpha[t-1, s-2]]) + lp[t, l'[s]]`,
    /// where the `s-2` "skip" term is included only at positions allowed by
    /// `create_l_prime_mask`, and updates are frozen (previous value kept)
    /// outside each sequence's valid `t`/`s` range.
    fn compute(
        log_probs: FloatTensor<B>,
        targets: &IntTensor<B>,
        input_lengths: IntTensor<B>,
        target_lengths: IntTensor<B>,
        blank: usize,
    ) -> Self {
        let log_probs_shape = log_probs.shape();
        let [max_input_length, batch_size, num_classes] = log_probs_shape.dims::<3>();
        let target_shape = targets.shape();
        let max_target_len = target_shape.dims::<2>()[1];
        let device = B::float_device(&log_probs);
        let float_dtype: burn_std::FloatDType = log_probs.dtype().into();
        let int_dtype: burn_std::IntDType = targets.dtype().into();
        let settings = get_device_settings::<B>(&device);

        // Extended target l' interleaves blanks: length 2n + 1.
        let max_l_prime_len = 2 * max_target_len + 1;
        let blank_inserted_targets = insert_blanks::<B>(
            targets,
            batch_size,
            max_target_len,
            max_l_prime_len,
            blank,
            &device,
            int_dtype,
        );

        // Whole lattice starts at log(0) = -inf.
        let mut alpha_full = B::float_full(
            Shape::new([max_input_length, batch_size, max_l_prime_len]),
            f32::NEG_INFINITY.into(),
            &device,
            float_dtype,
        );

        // lp at t = 0, flattened to [batch, classes] for the init gathers.
        let log_probs_t0 = B::float_slice(
            log_probs.clone(),
            &[Slice::new(0, Some(1), 1), Slice::full(), Slice::full()],
        );
        let log_probs_t0 = B::float_reshape(log_probs_t0, Shape::new([batch_size, num_classes]));

        // alpha[0, b, 0] = lp[0, b, l'[b, 0]] (l'[0] is always the blank).
        let first_blank = B::int_slice(
            blank_inserted_targets.clone(),
            &[Slice::full(), Slice::new(0, Some(1), 1)],
        );
        let log_prob_blank = B::float_gather(1, log_probs_t0.clone(), first_blank);
        let log_prob_blank_3d = B::float_reshape(log_prob_blank, Shape::new([1, batch_size, 1]));
        alpha_full = B::float_slice_assign(
            alpha_full,
            &[
                Slice::new(0, Some(1), 1),
                Slice::full(),
                Slice::new(0, Some(1), 1),
            ],
            log_prob_blank_3d,
        );

        // alpha[0, b, 1] = lp[0, b, l'[b, 1]] when a first label exists.
        if max_l_prime_len > 1 {
            let first_label = B::int_slice(
                blank_inserted_targets.clone(),
                &[Slice::full(), Slice::new(1, Some(2), 1)],
            );
            let log_prob_first = B::float_gather(1, log_probs_t0, first_label);
            let log_prob_first_3d =
                B::float_reshape(log_prob_first, Shape::new([1, batch_size, 1]));
            alpha_full = B::float_slice_assign(
                alpha_full,
                &[
                    Slice::new(0, Some(1), 1),
                    Slice::full(),
                    Slice::new(1, Some(2), 1),
                ],
                log_prob_first_3d,
            );
        }

        // Running alpha row ([batch, l'_len]) carried through the time loop.
        let mut log_alpha = B::float_slice(
            alpha_full.clone(),
            &[Slice::new(0, Some(1), 1), Slice::full(), Slice::full()],
        );
        log_alpha = B::float_reshape(log_alpha, Shape::new([batch_size, max_l_prime_len]));

        // Positions where the s-2 "skip" transition is allowed.
        let l_prime_mask = create_l_prime_mask::<B>(
            &blank_inserted_targets,
            batch_size,
            max_l_prime_len,
            blank,
            &device,
            int_dtype,
            settings.bool_dtype,
        );
        // Positions s within each batch's extended length 2n + 1.
        let s_mask = create_s_mask::<B>(
            &target_lengths,
            batch_size,
            max_l_prime_len,
            &device,
            int_dtype,
            settings.bool_dtype,
        );

        // -inf paddings reused by right_shift on every iteration.
        let pad_1 = B::float_full(
            Shape::new([batch_size, 1]),
            f32::NEG_INFINITY.into(),
            &device,
            float_dtype,
        );
        let pad_2 = B::float_full(
            Shape::new([batch_size, 2]),
            f32::NEG_INFINITY.into(),
            &device,
            float_dtype,
        );
        // Pre-gather lp[t, b, l'[b, s]] for all t at once.
        let indices_3d = B::int_expand(
            B::int_reshape(
                blank_inserted_targets.clone(),
                Shape::new([1, batch_size, max_l_prime_len]),
            ),
            Shape::new([max_input_length, batch_size, max_l_prime_len]),
        );
        let log_probs_at_l_full = B::float_gather(2, log_probs.clone(), indices_3d);

        // Precompute the full validity mask (t < input_len AND s in range)
        // so the loop body only has to slice it per time step.
        let t_indices_2d = B::int_expand(
            B::int_reshape(
                B::int_arange(0..max_input_length as i64, &device, int_dtype),
                Shape::new([max_input_length, 1]),
            ),
            Shape::new([max_input_length, batch_size]),
        );
        let il_tn = B::int_expand(
            B::int_reshape(input_lengths.clone(), Shape::new([1, batch_size])),
            Shape::new([max_input_length, batch_size]),
        );
        let t_mask_all = B::bool_expand(
            B::bool_reshape(
                B::int_greater(il_tn, t_indices_2d, settings.bool_dtype),
                Shape::new([max_input_length, batch_size, 1]),
            ),
            Shape::new([max_input_length, batch_size, max_l_prime_len]),
        );
        let s_mask_bcast = B::bool_expand(
            B::bool_reshape(s_mask.clone(), Shape::new([1, batch_size, max_l_prime_len])),
            Shape::new([max_input_length, batch_size, max_l_prime_len]),
        );
        let combined_mask_all = B::bool_and(t_mask_all, s_mask_bcast);

        for t in 1..max_input_length {
            // Validity mask for this time step, [batch, l'_len].
            let combined_mask = B::bool_reshape(
                B::bool_slice(
                    combined_mask_all.clone(),
                    &[
                        Slice::new(t as isize, Some(t as isize + 1), 1),
                        Slice::full(),
                        Slice::full(),
                    ],
                ),
                Shape::new([batch_size, max_l_prime_len]),
            );

            // Previous alpha at s, s-1 and s-2 (shifted with -inf fill).
            let log_alpha_s = log_alpha.clone();
            let log_alpha_s_m1 = right_shift::<B>(&log_alpha, &pad_1, max_l_prime_len, 1);
            let log_alpha_s_m2 = right_shift::<B>(&log_alpha, &pad_2, max_l_prime_len, 2);

            // Combine transitions; include the skip term only where allowed.
            let bar = log_sum_exp::<B>(log_alpha_s, log_alpha_s_m1, settings.bool_dtype);
            let bar_with_skip = log_sum_exp::<B>(bar.clone(), log_alpha_s_m2, settings.bool_dtype);
            let log_alpha_combined = B::float_mask_where(bar, l_prime_mask.clone(), bar_with_skip);

            // Emission term for this step, sliced from the pre-gathered tensor.
            let log_probs_at_l = B::float_reshape(
                B::float_slice(
                    log_probs_at_l_full.clone(),
                    &[
                        Slice::new(t as isize, Some(t as isize + 1), 1),
                        Slice::full(),
                        Slice::full(),
                    ],
                ),
                Shape::new([batch_size, max_l_prime_len]),
            );
            let new_alpha = B::float_add(log_alpha_combined, log_probs_at_l);
            // Update only the valid region; keep old values elsewhere so the
            // final-step alpha stays frozen for shorter sequences.
            log_alpha = B::float_mask_where(log_alpha, combined_mask, new_alpha);

            // Persist this step's row into the full lattice.
            let log_alpha_3d = B::float_reshape(
                log_alpha.clone(),
                Shape::new([1, batch_size, max_l_prime_len]),
            );
            alpha_full = B::float_slice_assign(
                alpha_full,
                &[
                    Slice::new(t as isize, Some(t as isize + 1), 1),
                    Slice::full(),
                    Slice::full(),
                ],
                log_alpha_3d,
            );
        }

        Self {
            full: alpha_full,
            last: log_alpha,
            blank_inserted_targets,
            log_probs_at_l_full,
            max_l_prime_len,
        }
    }
}
403
404fn extract_loss<B: Backend>(alpha: &AlphaCtx<B>, target_lengths: IntTensor<B>) -> FloatTensor<B> {
406 let log_alpha_shape = alpha.last.shape();
407 let [batch_size, _] = log_alpha_shape.dims::<2>();
408 let device = B::float_device(&alpha.last);
409 let settings = get_device_settings::<B>(&device);
410
411 let last_blank_idx = B::int_mul_scalar(target_lengths.clone(), 2.into());
412 let last_blank_idx = B::int_reshape(last_blank_idx, Shape::new([batch_size, 1]));
413 let last_label_idx = B::int_clamp_min(
414 B::int_sub_scalar(last_blank_idx.clone(), 1.into()),
415 0.into(),
416 );
417
418 let log_alpha_last_blank = B::float_gather(1, alpha.last.clone(), last_blank_idx);
419 let log_alpha_last_blank = B::float_reshape(log_alpha_last_blank, Shape::new([batch_size]));
420
421 let log_alpha_last_label = B::float_gather(1, alpha.last.clone(), last_label_idx);
422 let log_alpha_last_label = B::float_reshape(log_alpha_last_label, Shape::new([batch_size]));
423
424 let target_len_zero = B::int_equal_elem(target_lengths, 0.into(), settings.bool_dtype);
426 let log_alpha_last_label = B::float_mask_fill(
427 log_alpha_last_label,
428 target_len_zero,
429 f32::NEG_INFINITY.into(),
430 );
431
432 let log_likelihood = log_sum_exp::<B>(
433 log_alpha_last_blank,
434 log_alpha_last_label,
435 settings.bool_dtype,
436 );
437 B::float_neg(log_likelihood)
438}
439
440fn insert_blanks<B: Backend>(
442 targets: &IntTensor<B>,
443 batch_size: usize,
444 max_target_len: usize,
445 max_l_prime_len: usize,
446 blank: usize,
447 device: &B::Device,
448 int_dtype: burn_std::IntDType,
449) -> IntTensor<B> {
450 let result = B::int_full(
451 Shape::new([batch_size, max_l_prime_len]),
452 (blank as i64).into(),
453 device,
454 int_dtype,
455 );
456
457 if max_target_len == 0 {
458 return result;
459 }
460
461 B::int_slice_assign(
464 result,
465 &[Slice::full(), Slice::new(1, None, 2)],
466 targets.clone(),
467 )
468}
469
470fn right_shift<B: Backend>(
477 tensor: &FloatTensor<B>,
478 padding: &FloatTensor<B>,
479 cols: usize,
480 shift: usize,
481) -> FloatTensor<B> {
482 if cols < shift {
487 return B::float_slice(
488 padding.clone(),
489 &[Slice::full(), Slice::new(0, Some(cols as isize), 1)],
490 );
491 }
492 let shortened = B::float_slice(
493 tensor.clone(),
494 &[
495 Slice::full(),
496 Slice::new(0, Some((cols - shift) as isize), 1),
497 ],
498 );
499 B::float_cat(alloc::vec![padding.clone(), shortened], 1)
500}
501
502fn log_sum_exp<B: Backend>(
514 a: FloatTensor<B>,
515 b: FloatTensor<B>,
516 bool_dtype: burn_std::BoolDType,
517) -> FloatTensor<B> {
518 let a_is_neg_inf = B::float_equal_elem(a.clone(), f32::NEG_INFINITY.into(), bool_dtype);
526 let b_is_neg_inf = B::float_equal_elem(b.clone(), f32::NEG_INFINITY.into(), bool_dtype);
527 let either_neg_inf = B::bool_or(a_is_neg_inf.clone(), b_is_neg_inf.clone());
528
529 let a_safe = B::float_mask_fill(a.clone(), a_is_neg_inf, 0.0.into());
530 let b_safe = B::float_mask_fill(b.clone(), b_is_neg_inf, 0.0.into());
531
532 let lt_mask = B::float_lower(a.clone(), b.clone(), bool_dtype);
533 let mx = B::float_mask_where(a, lt_mask, b);
534
535 let diff_safe = B::float_neg(B::float_abs(B::float_sub(a_safe, b_safe)));
540 let diff_final = B::float_mask_fill(diff_safe, either_neg_inf, f32::NEG_INFINITY.into());
541
542 B::float_add(mx, B::float_log1p(B::float_exp(diff_final)))
543}
544
545fn create_l_prime_mask<B: Backend>(
547 blank_inserted_targets: &IntTensor<B>,
548 batch_size: usize,
549 max_l_prime_len: usize,
550 blank: usize,
551 device: &B::Device,
552 int_dtype: burn_std::IntDType,
553 bool_dtype: burn_std::BoolDType,
554) -> BoolTensor<B> {
555 if max_l_prime_len < 2 {
559 return B::bool_zeros(
560 Shape::new([batch_size, max_l_prime_len]),
561 device,
562 bool_dtype,
563 );
564 }
565 let l_prime = blank_inserted_targets.clone();
566
567 let not_blank = B::int_not_equal_elem(l_prime.clone(), (blank as i64).into(), bool_dtype);
568
569 let l_prime_shifted = {
570 let padding = B::int_full(
571 Shape::new([batch_size, 2]),
572 (blank as i64).into(),
573 device,
574 int_dtype,
575 );
576 let shortened = B::int_slice(
577 l_prime.clone(),
578 &[
579 Slice::full(),
580 Slice::new(0, Some((max_l_prime_len - 2) as isize), 1),
581 ],
582 );
583 B::int_cat(alloc::vec![padding, shortened], 1)
584 };
585 let not_equal_s_m2 = B::int_not_equal(l_prime, l_prime_shifted, bool_dtype);
586
587 let col_indices = B::int_arange(0..max_l_prime_len as i64, device, int_dtype);
588 let col_indices = B::int_reshape(col_indices, Shape::new([1, max_l_prime_len]));
589 let col_indices = B::int_expand(col_indices, Shape::new([batch_size, max_l_prime_len]));
590 let s_ge_2 = B::int_greater_equal_elem(col_indices, 2.into(), bool_dtype);
591
592 B::bool_and(B::bool_and(not_blank, not_equal_s_m2), s_ge_2)
593}
594
595fn create_s_mask<B: Backend>(
597 target_lengths: &IntTensor<B>,
598 batch_size: usize,
599 max_l_prime_len: usize,
600 device: &B::Device,
601 int_dtype: burn_std::IntDType,
602 bool_dtype: burn_std::BoolDType,
603) -> BoolTensor<B> {
604 let col_indices = B::int_arange(0..max_l_prime_len as i64, device, int_dtype);
605 let col_indices = B::int_reshape(col_indices, Shape::new([1, max_l_prime_len]));
606 let col_indices = B::int_expand(col_indices, Shape::new([batch_size, max_l_prime_len]));
607
608 let lengths = B::int_mul_scalar(target_lengths.clone(), 2.into());
609 let lengths = B::int_add_scalar(lengths, 1.into());
610 let lengths = B::int_reshape(lengths, Shape::new([batch_size, 1]));
611 let lengths = B::int_expand(lengths, Shape::new([batch_size, max_l_prime_len]));
612
613 B::int_lower(col_indices, lengths, bool_dtype)
614}