use burn_std::{Shape, Slice};
use crate::{
Backend, TensorMetadata, get_device_settings,
tensor::{BoolTensor, FloatTensor, IntTensor},
};
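/// Computes the connectionist temporal classification (CTC) loss for a batch.
///
/// `log_probs` has shape `[max_input_length, batch_size, num_classes]` and holds
/// per-timestep class log-probabilities (typically log-softmax outputs). `targets`
/// has shape `[batch_size, max_target_len]`; `input_lengths` and `target_lengths`
/// give the valid lengths per sample, and `blank` is the index of the blank label.
///
/// Runs the forward (alpha) recursion and returns the per-sample negative
/// log-likelihood, shape `[batch_size]`.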
pub fn ctc_loss_default<B: Backend>(
log_probs: FloatTensor<B>,
targets: IntTensor<B>,
input_lengths: IntTensor<B>,
target_lengths: IntTensor<B>,
blank: usize,
) -> FloatTensor<B> {
let alpha = AlphaCtx::<B>::compute(
log_probs,
&targets,
input_lengths,
target_lengths.clone(),
blank,
);
extract_loss::<B>(&alpha, target_lengths)
}
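/// Computes the gradient of the CTC loss with respect to `log_probs` from the
/// forward (`log_alpha_full`) and backward (`log_beta_full`) lattices and the
/// per-sample negative log-likelihood `nll`, scaled by the incoming `grad_loss`.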
#[allow(clippy::too_many_arguments)]
pub fn ctc_grad_from_alpha_beta_default<B: Backend>(
log_probs: FloatTensor<B>,
targets: IntTensor<B>,
input_lengths: IntTensor<B>,
grad_loss: FloatTensor<B>,
log_alpha_full: FloatTensor<B>,
log_beta_full: FloatTensor<B>,
nll: FloatTensor<B>,
blank: usize,
) -> FloatTensor<B> {
let log_probs_shape = log_probs.shape();
let [max_input_length, batch_size, num_classes] = log_probs_shape.dims::<3>();
let target_shape = targets.shape();
let max_target_len = target_shape.dims::<2>()[1];
let max_l_prime_len = 2 * max_target_len + 1;
let device = B::float_device(&log_probs);
let int_dtype: burn_std::IntDType = targets.dtype().into();
let settings = get_device_settings::<B>(&device);
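    // Extended label sequence l' = [blank, l_1, blank, l_2, ..., l_L, blank],
    // of length 2 * max_target_len + 1.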
let blank_inserted_targets = insert_blanks::<B>(
&targets,
batch_size,
max_target_len,
max_l_prime_len,
blank,
&device,
int_dtype,
);
let indices_3d = B::int_reshape(
blank_inserted_targets,
Shape::new([1, batch_size, max_l_prime_len]),
);
let indices_3d = B::int_expand(
indices_3d,
Shape::new([max_input_length, batch_size, max_l_prime_len]),
);
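    // Gather log p_t(l'_s) at every timestep:
    // shape [max_input_length, batch_size, max_l_prime_len].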
let log_probs_at_l = B::float_gather(2, log_probs.clone(), indices_3d.clone());
let nll_is_inf = B::float_is_inf(nll.clone(), settings.bool_dtype);
let nll_b = B::float_reshape(nll, Shape::new([1, batch_size, 1]));
let nll_b = B::float_expand(
nll_b,
Shape::new([max_input_length, batch_size, max_l_prime_len]),
);
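    // Log posterior over lattice positions:
    // log_post = log_alpha + log_beta - log p_t(l'_s) - log P, where nll = -log P.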
let log_post = B::float_add(
B::float_sub(B::float_add(log_alpha_full, log_beta_full), log_probs_at_l),
nll_b,
);
let grad_loss_3d = B::float_reshape(grad_loss, Shape::new([1, batch_size, 1]));
let grad_loss_b = B::float_expand(
grad_loss_3d.clone(),
Shape::new([max_input_length, batch_size, num_classes]),
);
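    // First gradient term: p_t(c) * grad_loss for every class c.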
let mut grad = B::float_mul(B::float_exp(log_probs), grad_loss_b);
let grad_loss_post = B::float_expand(
grad_loss_3d,
Shape::new([max_input_length, batch_size, max_l_prime_len]),
);
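    // Second term: subtract the posterior mass, scattered from lattice positions s
    // back onto their class indices l'_s.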
let scatter_value = B::float_neg(B::float_mul(B::float_exp(log_post), grad_loss_post));
grad = B::float_scatter_add(2, grad, indices_3d, scatter_value);
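    // Zero the gradient for padded timesteps (t >= input_length) and for samples
    // whose loss is infinite (no valid alignment).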
let t_indices = B::int_arange(0..max_input_length as i64, &device, int_dtype);
let t_indices = B::int_reshape(t_indices, Shape::new([max_input_length, 1, 1]));
let t_indices = B::int_expand(
t_indices,
Shape::new([max_input_length, batch_size, num_classes]),
);
let il_b = B::int_reshape(input_lengths, Shape::new([1, batch_size, 1]));
let il_b = B::int_expand(
il_b,
Shape::new([max_input_length, batch_size, num_classes]),
);
let oob_mask = B::int_greater_equal(t_indices, il_b, settings.bool_dtype);
let nll_inf_b = B::bool_reshape(nll_is_inf, Shape::new([1, batch_size, 1]));
let nll_inf_b = B::bool_expand(
nll_inf_b,
Shape::new([max_input_length, batch_size, num_classes]),
);
let mask = B::bool_or(oob_mask, nll_inf_b);
B::float_mask_fill(grad, mask, 0.0.into())
}
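/// Intermediate state produced by the forward (alpha) pass of CTC.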
#[allow(dead_code)]
struct AlphaCtx<B: Backend> {
    /// Full alpha lattice, shape `[max_input_length, batch_size, max_l_prime_len]`.
    full: FloatTensor<B>,
    /// Alpha row frozen at each sample's last valid timestep,
    /// shape `[batch_size, max_l_prime_len]`.
    last: FloatTensor<B>,
    /// Extended label sequence l' with blanks interleaved.
    blank_inserted_targets: IntTensor<B>,
    /// Emission log-probabilities gathered at l' for every timestep.
    log_probs_at_l_full: FloatTensor<B>,
    /// Length of l', i.e. `2 * max_target_len + 1`.
    max_l_prime_len: usize,
}
impl<B: Backend> AlphaCtx<B> {
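    /// Runs the forward (alpha) recursion over the extended label sequence,
    /// returning the full lattice together with the state needed to extract the
    /// loss and to compute gradients.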
fn compute(
log_probs: FloatTensor<B>,
targets: &IntTensor<B>,
input_lengths: IntTensor<B>,
target_lengths: IntTensor<B>,
blank: usize,
) -> Self {
let log_probs_shape = log_probs.shape();
let [max_input_length, batch_size, num_classes] = log_probs_shape.dims::<3>();
let target_shape = targets.shape();
let max_target_len = target_shape.dims::<2>()[1];
let device = B::float_device(&log_probs);
let float_dtype: burn_std::FloatDType = log_probs.dtype().into();
let int_dtype: burn_std::IntDType = targets.dtype().into();
let settings = get_device_settings::<B>(&device);
let max_l_prime_len = 2 * max_target_len + 1;
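        // Extended label sequence l' = [blank, l_1, blank, ..., l_L, blank].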
let blank_inserted_targets = insert_blanks::<B>(
targets,
batch_size,
max_target_len,
max_l_prime_len,
blank,
&device,
int_dtype,
);
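        // Alpha lattice initialized to log(0) = -inf.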
let mut alpha_full = B::float_full(
Shape::new([max_input_length, batch_size, max_l_prime_len]),
f32::NEG_INFINITY.into(),
&device,
float_dtype,
);
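        // t = 0: alpha_0(0) = log p_0(blank) and, when a first label exists,
        // alpha_0(1) = log p_0(l_1); every other position stays at -inf.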
let log_probs_t0 = B::float_slice(
log_probs.clone(),
&[Slice::new(0, Some(1), 1), Slice::full(), Slice::full()],
);
let log_probs_t0 = B::float_reshape(log_probs_t0, Shape::new([batch_size, num_classes]));
let first_blank = B::int_slice(
blank_inserted_targets.clone(),
&[Slice::full(), Slice::new(0, Some(1), 1)],
);
let log_prob_blank = B::float_gather(1, log_probs_t0.clone(), first_blank);
let log_prob_blank_3d = B::float_reshape(log_prob_blank, Shape::new([1, batch_size, 1]));
alpha_full = B::float_slice_assign(
alpha_full,
&[
Slice::new(0, Some(1), 1),
Slice::full(),
Slice::new(0, Some(1), 1),
],
log_prob_blank_3d,
);
if max_l_prime_len > 1 {
let first_label = B::int_slice(
blank_inserted_targets.clone(),
&[Slice::full(), Slice::new(1, Some(2), 1)],
);
let log_prob_first = B::float_gather(1, log_probs_t0, first_label);
let log_prob_first_3d =
B::float_reshape(log_prob_first, Shape::new([1, batch_size, 1]));
alpha_full = B::float_slice_assign(
alpha_full,
&[
Slice::new(0, Some(1), 1),
Slice::full(),
Slice::new(1, Some(2), 1),
],
log_prob_first_3d,
);
}
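        // Running row alpha_t, carried across the time loop.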
let mut log_alpha = B::float_slice(
alpha_full.clone(),
&[Slice::new(0, Some(1), 1), Slice::full(), Slice::full()],
);
log_alpha = B::float_reshape(log_alpha, Shape::new([batch_size, max_l_prime_len]));
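        // l_prime_mask marks positions where the skip transition from s - 2 is
        // allowed; s_mask marks positions s < 2 * target_length + 1 valid per sample.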
let l_prime_mask = create_l_prime_mask::<B>(
&blank_inserted_targets,
batch_size,
max_l_prime_len,
blank,
&device,
int_dtype,
settings.bool_dtype,
);
let s_mask = create_s_mask::<B>(
&target_lengths,
batch_size,
max_l_prime_len,
&device,
int_dtype,
settings.bool_dtype,
);
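        // -inf padding columns consumed by the shift-by-1 and shift-by-2 below.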
let pad_1 = B::float_full(
Shape::new([batch_size, 1]),
f32::NEG_INFINITY.into(),
&device,
float_dtype,
);
let pad_2 = B::float_full(
Shape::new([batch_size, 2]),
f32::NEG_INFINITY.into(),
&device,
float_dtype,
);
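        // Precompute, outside the time loop, the emission log-probs gathered at l'
        // and the combined (t, batch, s) validity mask.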
let indices_3d = B::int_expand(
B::int_reshape(
blank_inserted_targets.clone(),
Shape::new([1, batch_size, max_l_prime_len]),
),
Shape::new([max_input_length, batch_size, max_l_prime_len]),
);
let log_probs_at_l_full = B::float_gather(2, log_probs.clone(), indices_3d);
let t_indices_2d = B::int_expand(
B::int_reshape(
B::int_arange(0..max_input_length as i64, &device, int_dtype),
Shape::new([max_input_length, 1]),
),
Shape::new([max_input_length, batch_size]),
);
let il_tn = B::int_expand(
B::int_reshape(input_lengths.clone(), Shape::new([1, batch_size])),
Shape::new([max_input_length, batch_size]),
);
let t_mask_all = B::bool_expand(
B::bool_reshape(
B::int_greater(il_tn, t_indices_2d, settings.bool_dtype),
Shape::new([max_input_length, batch_size, 1]),
),
Shape::new([max_input_length, batch_size, max_l_prime_len]),
);
let s_mask_bcast = B::bool_expand(
B::bool_reshape(s_mask.clone(), Shape::new([1, batch_size, max_l_prime_len])),
Shape::new([max_input_length, batch_size, max_l_prime_len]),
);
let combined_mask_all = B::bool_and(t_mask_all, s_mask_bcast);
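        // Forward recursion for t = 1..T:
        //   alpha_t(s) = logsumexp(alpha_{t-1}(s), alpha_{t-1}(s-1)[, alpha_{t-1}(s-2)])
        //              + log p_t(l'_s),
        // applied only where position and timestep are valid, so each row freezes
        // once t reaches that sample's input length.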
for t in 1..max_input_length {
let combined_mask = B::bool_reshape(
B::bool_slice(
combined_mask_all.clone(),
&[
Slice::new(t as isize, Some(t as isize + 1), 1),
Slice::full(),
Slice::full(),
],
),
Shape::new([batch_size, max_l_prime_len]),
);
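            // Lattice predecessors alpha_{t-1}(s), alpha_{t-1}(s-1), alpha_{t-1}(s-2),
            // the shifted rows padded with -inf.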
let log_alpha_s = log_alpha.clone();
let log_alpha_s_m1 = right_shift::<B>(&log_alpha, &pad_1, max_l_prime_len, 1);
let log_alpha_s_m2 = right_shift::<B>(&log_alpha, &pad_2, max_l_prime_len, 2);
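            // `bar` is the two-term sum over s and s-1; the skip term over s-2 is
            // added only where l_prime_mask allows it.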
let bar = log_sum_exp::<B>(log_alpha_s, log_alpha_s_m1, settings.bool_dtype);
let bar_with_skip = log_sum_exp::<B>(bar.clone(), log_alpha_s_m2, settings.bool_dtype);
let log_alpha_combined = B::float_mask_where(bar, l_prime_mask.clone(), bar_with_skip);
let log_probs_at_l = B::float_reshape(
B::float_slice(
log_probs_at_l_full.clone(),
&[
Slice::new(t as isize, Some(t as isize + 1), 1),
Slice::full(),
Slice::full(),
],
),
Shape::new([batch_size, max_l_prime_len]),
);
let new_alpha = B::float_add(log_alpha_combined, log_probs_at_l);
log_alpha = B::float_mask_where(log_alpha, combined_mask, new_alpha);
let log_alpha_3d = B::float_reshape(
log_alpha.clone(),
Shape::new([1, batch_size, max_l_prime_len]),
);
alpha_full = B::float_slice_assign(
alpha_full,
&[
Slice::new(t as isize, Some(t as isize + 1), 1),
Slice::full(),
Slice::full(),
],
log_alpha_3d,
);
}
Self {
full: alpha_full,
last: log_alpha,
blank_inserted_targets,
log_probs_at_l_full,
max_l_prime_len,
}
}
}
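/// Extracts the per-sample negative log-likelihood from the final alpha row.
///
/// A complete path ends either on the last blank (position `2 * target_length`)
/// or on the last label (position `2 * target_length - 1`), so the loss is the
/// negated logsumexp of those two entries; the label entry is masked out for
/// empty targets.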
fn extract_loss<B: Backend>(alpha: &AlphaCtx<B>, target_lengths: IntTensor<B>) -> FloatTensor<B> {
let log_alpha_shape = alpha.last.shape();
let [batch_size, _] = log_alpha_shape.dims::<2>();
let device = B::float_device(&alpha.last);
let settings = get_device_settings::<B>(&device);
let last_blank_idx = B::int_mul_scalar(target_lengths.clone(), 2.into());
let last_blank_idx = B::int_reshape(last_blank_idx, Shape::new([batch_size, 1]));
let last_label_idx = B::int_clamp_min(
B::int_sub_scalar(last_blank_idx.clone(), 1.into()),
0.into(),
);
let log_alpha_last_blank = B::float_gather(1, alpha.last.clone(), last_blank_idx);
let log_alpha_last_blank = B::float_reshape(log_alpha_last_blank, Shape::new([batch_size]));
let log_alpha_last_label = B::float_gather(1, alpha.last.clone(), last_label_idx);
let log_alpha_last_label = B::float_reshape(log_alpha_last_label, Shape::new([batch_size]));
let target_len_zero = B::int_equal_elem(target_lengths, 0.into(), settings.bool_dtype);
let log_alpha_last_label = B::float_mask_fill(
log_alpha_last_label,
target_len_zero,
f32::NEG_INFINITY.into(),
);
let log_likelihood = log_sum_exp::<B>(
log_alpha_last_blank,
log_alpha_last_label,
settings.bool_dtype,
);
B::float_neg(log_likelihood)
}
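/// Interleaves `blank` with the target labels to build the extended sequence
/// l' = [blank, l_1, blank, l_2, ..., l_L, blank], shape
/// `[batch_size, max_l_prime_len]`.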
fn insert_blanks<B: Backend>(
targets: &IntTensor<B>,
batch_size: usize,
max_target_len: usize,
max_l_prime_len: usize,
blank: usize,
device: &B::Device,
int_dtype: burn_std::IntDType,
) -> IntTensor<B> {
let result = B::int_full(
Shape::new([batch_size, max_l_prime_len]),
(blank as i64).into(),
device,
int_dtype,
);
if max_target_len == 0 {
return result;
}
B::int_slice_assign(
result,
&[Slice::full(), Slice::new(1, None, 2)],
targets.clone(),
)
}
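/// Shifts each row right by `shift` columns, filling the vacated columns from
/// `padding` and truncating back to `cols` columns.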
fn right_shift<B: Backend>(
tensor: &FloatTensor<B>,
padding: &FloatTensor<B>,
cols: usize,
shift: usize,
) -> FloatTensor<B> {
if cols < shift {
return B::float_slice(
padding.clone(),
&[Slice::full(), Slice::new(0, Some(cols as isize), 1)],
);
}
let shortened = B::float_slice(
tensor.clone(),
&[
Slice::full(),
Slice::new(0, Some((cols - shift) as isize), 1),
],
);
B::float_cat(alloc::vec![padding.clone(), shortened], 1)
}
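/// Numerically stable elementwise `log(exp(a) + exp(b))`.
///
/// Uses `logsumexp(a, b) = max(a, b) + log1p(exp(-|a - b|))`, with explicit
/// handling of `-inf` (log 0): if either operand is `-inf`, the result collapses
/// to the maximum of the two.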
fn log_sum_exp<B: Backend>(
a: FloatTensor<B>,
b: FloatTensor<B>,
bool_dtype: burn_std::BoolDType,
) -> FloatTensor<B> {
let a_is_neg_inf = B::float_equal_elem(a.clone(), f32::NEG_INFINITY.into(), bool_dtype);
let b_is_neg_inf = B::float_equal_elem(b.clone(), f32::NEG_INFINITY.into(), bool_dtype);
let either_neg_inf = B::bool_or(a_is_neg_inf.clone(), b_is_neg_inf.clone());
let a_safe = B::float_mask_fill(a.clone(), a_is_neg_inf, 0.0.into());
let b_safe = B::float_mask_fill(b.clone(), b_is_neg_inf, 0.0.into());
let lt_mask = B::float_lower(a.clone(), b.clone(), bool_dtype);
let mx = B::float_mask_where(a, lt_mask, b);
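    // |a - b| computed from the -inf-free copies; the correction term is forced to
    // log1p(exp(-inf)) = 0 whenever either input is -inf, leaving just mx.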
let diff_safe = B::float_neg(B::float_abs(B::float_sub(a_safe, b_safe)));
let diff_final = B::float_mask_fill(diff_safe, either_neg_inf, f32::NEG_INFINITY.into());
B::float_add(mx, B::float_log1p(B::float_exp(diff_final)))
}
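/// Marks the positions of l' where the skip transition from s - 2 is allowed:
/// l'_s must be a non-blank label, must differ from l'_{s-2}, and s >= 2.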
fn create_l_prime_mask<B: Backend>(
blank_inserted_targets: &IntTensor<B>,
batch_size: usize,
max_l_prime_len: usize,
blank: usize,
device: &B::Device,
int_dtype: burn_std::IntDType,
bool_dtype: burn_std::BoolDType,
) -> BoolTensor<B> {
if max_l_prime_len < 2 {
return B::bool_zeros(
Shape::new([batch_size, max_l_prime_len]),
device,
bool_dtype,
);
}
let l_prime = blank_inserted_targets.clone();
let not_blank = B::int_not_equal_elem(l_prime.clone(), (blank as i64).into(), bool_dtype);
let l_prime_shifted = {
let padding = B::int_full(
Shape::new([batch_size, 2]),
(blank as i64).into(),
device,
int_dtype,
);
let shortened = B::int_slice(
l_prime.clone(),
&[
Slice::full(),
Slice::new(0, Some((max_l_prime_len - 2) as isize), 1),
],
);
B::int_cat(alloc::vec![padding, shortened], 1)
};
let not_equal_s_m2 = B::int_not_equal(l_prime, l_prime_shifted, bool_dtype);
let col_indices = B::int_arange(0..max_l_prime_len as i64, device, int_dtype);
let col_indices = B::int_reshape(col_indices, Shape::new([1, max_l_prime_len]));
let col_indices = B::int_expand(col_indices, Shape::new([batch_size, max_l_prime_len]));
let s_ge_2 = B::int_greater_equal_elem(col_indices, 2.into(), bool_dtype);
B::bool_and(B::bool_and(not_blank, not_equal_s_m2), s_ge_2)
}
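/// Marks the positions s < 2 * target_length + 1 that are valid for each sample,
/// masking out padding beyond the extended target sequence.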
fn create_s_mask<B: Backend>(
target_lengths: &IntTensor<B>,
batch_size: usize,
max_l_prime_len: usize,
device: &B::Device,
int_dtype: burn_std::IntDType,
bool_dtype: burn_std::BoolDType,
) -> BoolTensor<B> {
let col_indices = B::int_arange(0..max_l_prime_len as i64, device, int_dtype);
let col_indices = B::int_reshape(col_indices, Shape::new([1, max_l_prime_len]));
let col_indices = B::int_expand(col_indices, Shape::new([batch_size, max_l_prime_len]));
let lengths = B::int_mul_scalar(target_lengths.clone(), 2.into());
let lengths = B::int_add_scalar(lengths, 1.into());
let lengths = B::int_reshape(lengths, Shape::new([batch_size, 1]));
let lengths = B::int_expand(lengths, Shape::new([batch_size, max_l_prime_len]));
B::int_lower(col_indices, lengths, bool_dtype)
}