#![allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)]
use candle_core::{DType, Device, Result, Tensor};
/// Precomputes the index tensors used by [`apply_audio_delay`].
///
/// Returns `(t_idx_bxtxc, indices_btcx3)` where:
/// * `t_idx_bxtxc` — `[B, T, C]` i64 tensor of `t - delay[c]`, *unclamped*,
///   so negative / out-of-range entries can later be masked to BOS/PAD;
/// * `indices_btcx3` — `[B*T*C, 3]` i64 tensor of `(batch, clamped_time,
///   channel)` triples suitable for [`gather_nd`].
pub fn build_delay_indices(
    b: usize,
    t: usize,
    c: usize,
    delay_pattern: &[i64],
    device: &Device,
) -> Result<(Tensor, Tensor)> {
    // Per-channel delays shaped [1, 1, C] for broadcasting.
    let delays_11c = Tensor::from_slice(delay_pattern, (c,), device)?.reshape((1, 1, c))?;
    // Time grid [B, T, 1].
    let time_bt1 = Tensor::arange(0i64, t as i64, device)?
        .reshape((1, t, 1))?
        .expand((b, t, 1))?;
    // Shifted time per channel; may be negative or exceed t-1 (kept unclamped
    // on purpose — the caller masks those positions).
    let t_idx_bxtxc = time_bt1.broadcast_sub(&delays_11c)?;
    // Batch and channel index grids, both [B, T, C].
    let b_idx = Tensor::arange(0i64, b as i64, device)?
        .reshape((b, 1, 1))?
        .expand((b, t, c))?;
    let c_idx = Tensor::arange(0i64, c as i64, device)?
        .reshape((1, 1, c))?
        .expand((b, t, c))?;
    // Clamp the shifted time into [0, t-1] so gathering stays in-bounds.
    // candle's clamp takes same-shaped bound tensors, hence the broadcasts.
    let lo = Tensor::zeros((1,), DType::I64, device)?.broadcast_as(t_idx_bxtxc.shape())?;
    let hi = Tensor::from_slice(&[t as i64 - 1], (1,), device)?
        .broadcast_as(t_idx_bxtxc.shape())?;
    let t_clamped = t_idx_bxtxc.clamp(&lo, &hi)?;
    // Flatten everything into [B*T*C, 3] (batch, time, channel) triples.
    let indices_btcx3 = Tensor::stack(
        &[
            b_idx.flatten_all()?,
            t_clamped.flatten_all()?,
            c_idx.flatten_all()?,
        ],
        1,
    )?;
    Ok((t_idx_bxtxc, indices_btcx3))
}
/// Applies the per-channel delay pattern to `audio_bxtxc` (`[B, T, C]`).
///
/// Uses `precomp = (t_idx_bxtxc, indices_btcx3)` from [`build_delay_indices`]:
/// positions whose shifted time is negative receive `bos_value`, positions at
/// or past `T` receive `pad_value`, and everything else takes the gathered
/// (delayed) sample.
///
/// # Panics
/// Panics if `audio_bxtxc` is not rank 3.
pub fn apply_audio_delay(
    audio_bxtxc: &Tensor,
    pad_value: i64,
    bos_value: i64,
    precomp: &(Tensor, Tensor),
) -> Result<Tensor> {
    let (t_idx_bxtxc, indices_btcx3) = precomp;
    let device = audio_bxtxc.device();
    let dims = audio_bxtxc.dims();
    assert_eq!(dims.len(), 3, "Expected 3D tensor for audio_bxtxc");
    // Gather the delayed samples, then restore the [B, T, C] shape.
    let gathered = gather_nd(audio_bxtxc, indices_btcx3)?.reshape(dims)?;
    // Masks over the *unclamped* shifted time: before clip start → BOS,
    // past clip end → PAD.
    let zero = Tensor::zeros((1,), DType::I64, device)?;
    let t_len = Tensor::from_slice(&[dims[1] as i64], (1,), device)?;
    let mask_bos = t_idx_bxtxc.broadcast_lt(&zero)?;
    let mask_pad = t_idx_bxtxc.broadcast_ge(&t_len)?;
    let dtype = audio_bxtxc.dtype();
    let bos = Tensor::from_slice(&[bos_value], (1,), device)?
        .to_dtype(dtype)?
        .broadcast_as(mask_pad.shape())?;
    let pad = Tensor::from_slice(&[pad_value], (1,), device)?
        .to_dtype(dtype)?
        .broadcast_as(mask_pad.shape())?;
    // Select PAD first, then BOS on top (BOS wins where both masks overlap,
    // matching the original ordering).
    let padded = mask_pad.where_cond(&pad, &gathered.broadcast_as(mask_pad.shape())?)?;
    mask_bos.where_cond(&bos, &padded)
}
/// Precomputes the index tensors used by [`revert_audio_delay`].
///
/// Returns `(t_idx_bxtxc, indices_btcx3)` where `t_idx_bxtxc` is
/// `min(t + delay[c], t - 1)` per position (`[B, T, C]`, i64) and
/// `indices_btcx3` is the flattened `[B*T*C, 3]` list of
/// `(batch, time, channel)` triples for [`gather_nd`].
pub fn build_revert_indices(
    b: usize,
    t: usize,
    c: usize,
    delay_pattern: &[i64],
    device: &Device,
) -> Result<(Tensor, Tensor)> {
    // Per-channel delays shaped [1, 1, C] for broadcasting.
    let delays_11c = Tensor::from_slice(delay_pattern, (c,), device)?.reshape((1, 1, c))?;
    // Time grid [B, T, 1].
    let time_bt1 = Tensor::arange(0i64, t as i64, device)?
        .reshape((1, t, 1))?
        .expand((b, t, 1))?;
    // Shift forward by each channel's delay, capped at the last frame so the
    // gather stays in-bounds. NOTE(review): because of this cap, the pad mask
    // in revert_audio_delay only fires when it is called with a smaller `t`
    // than the one used here — confirm that is the callers' intent.
    let last_frame = Tensor::from_slice(&[t as i64 - 1], (1,), device)?;
    let t_idx_bxtxc = time_bt1
        .broadcast_add(&delays_11c)?
        .broadcast_minimum(&last_frame)?;
    // Batch and channel index grids, both [B, T, C].
    let b_idx = Tensor::arange(0i64, b as i64, device)?
        .reshape((b, 1, 1))?
        .expand((b, t, c))?;
    let c_idx = Tensor::arange(0i64, c as i64, device)?
        .reshape((1, 1, c))?
        .expand((b, t, c))?;
    // Flatten everything into [B*T*C, 3] (batch, time, channel) triples.
    let indices_btcx3 = Tensor::stack(
        &[
            b_idx.flatten_all()?,
            t_idx_bxtxc.flatten_all()?,
            c_idx.flatten_all()?,
        ],
        1,
    )?;
    Ok((t_idx_bxtxc, indices_btcx3))
}
/// Reverts the per-channel delay pattern from `audio_bxtxc` (`[B, T, C]`).
///
/// Uses `precomp = (t_idx_bxtxc, indices_btcx3)` from
/// [`build_revert_indices`]; positions whose shifted time reaches `t` or
/// beyond are replaced with `pad_value` (cast to the audio dtype).
pub fn revert_audio_delay(
    audio_bxtxc: &Tensor,
    pad_value: i64,
    precomp: &(Tensor, Tensor),
    t: usize,
) -> Result<Tensor> {
    let (t_idx_bxtxc, indices_btcx3) = precomp;
    let device = audio_bxtxc.device();
    let dims = audio_bxtxc.dims();
    // Undo the delay by gathering, then restore the input shape.
    let gathered = gather_nd(audio_bxtxc, indices_btcx3)?.reshape(dims)?;
    // Pad out positions whose shifted time falls at or past `t`. `t` is the
    // caller-supplied valid length and may differ from the length used when
    // the indices were built.
    let t_len = Tensor::from_slice(&[t as i64], (1,), device)?;
    let mask_pad = t_idx_bxtxc.broadcast_ge(&t_len)?;
    let pad = Tensor::from_slice(&[pad_value], (1,), device)?
        .to_dtype(audio_bxtxc.dtype())?
        .broadcast_as(mask_pad.shape())?;
    mask_pad.where_cond(&pad, &gathered)
}
/// Gathers `tensor[b, t, c]` for every row `(b, t, c)` of `indices`
/// (`[N, 3]`, i64) from a rank-3 `tensor`, returning a flat `[N]` tensor.
///
/// The previous implementation pulled every index back to the host with
/// `to_scalar` in a loop — one device sync plus one tiny tensor per element,
/// O(N) round trips. Instead, linearize the triples on-device
/// (`flat = b*T*C + t*C + c`) and gather with a single `index_select`.
///
/// # Panics
/// Panics if `tensor` is not rank 3.
pub fn gather_nd(tensor: &Tensor, indices: &Tensor) -> Result<Tensor> {
    let dims = tensor.dims();
    assert_eq!(dims.len(), 3, "Expected 3D tensor for gather_nd");
    let (t_dim, c_dim) = (dims[1] as i64, dims[2] as i64);
    let device = tensor.device();
    // Split the [N, 3] triples into per-axis columns of shape [N].
    let b_idx = indices.narrow(1, 0, 1)?.squeeze(1)?;
    let t_idx = indices.narrow(1, 1, 1)?.squeeze(1)?;
    let c_idx = indices.narrow(1, 2, 1)?.squeeze(1)?;
    // Row-major linearization: flat = b * (T*C) + t * C + c.
    let stride_b = Tensor::from_slice(&[t_dim * c_dim], (1,), device)?;
    let stride_t = Tensor::from_slice(&[c_dim], (1,), device)?;
    let flat_idx = b_idx
        .broadcast_mul(&stride_b)?
        .broadcast_add(&t_idx.broadcast_mul(&stride_t)?)?
        .broadcast_add(&c_idx)?;
    // One gather over the flattened tensor; i64 ids are accepted by candle's
    // index_select.
    tensor.flatten_all()?.index_select(&flat_idx, 0)
}