use crate::DType;
use numr::algorithm::fft::{FftAlgorithms, FftDirection, FftNormalization};
use numr::error::{Error, Result};
use numr::ops::{ComplexOps, ScalarOps, ShapeOps, TensorOps};
use numr::runtime::{Runtime, RuntimeClient};
use numr::tensor::Tensor;
pub fn resample_impl<R, C>(client: &C, x: &Tensor<R>, num: usize, den: usize) -> Result<Tensor<R>>
where
R: Runtime<DType = DType>,
C: FftAlgorithms<R>
+ ComplexOps<R>
+ ScalarOps<R>
+ TensorOps<R>
+ ShapeOps<R>
+ RuntimeClient<R>,
{
let n = x.shape()[0];
let device = x.device();
if n == 0 {
return Err(Error::InvalidArgument {
arg: "x",
reason: "Input signal cannot be empty".to_string(),
});
}
if num == 0 || den == 0 {
return Err(Error::InvalidArgument {
arg: "num/den",
reason: "Resampling factors must be positive".to_string(),
});
}
if num == den {
return Ok(x.clone());
}
let output_len = (n * num).div_ceil(den);
let fft_len = n.max(output_len).next_power_of_two();
let x_padded = if n < fft_len {
let pad_amount = fft_len - n;
client.pad(x, &[0, pad_amount], 0.0)?
} else {
x.clone()
};
let zeros = Tensor::zeros(&[fft_len], x.dtype(), device);
let x_complex = client.make_complex(&x_padded, &zeros)?;
let fft = client.fft(&x_complex, FftDirection::Forward, FftNormalization::None)?;
let half_n = n / 2;
let half_out = output_len / 2;
let pos_copy = half_n.min(half_out);
let neg_orig = (n.saturating_sub(1)) / 2;
let neg_out = (output_len.saturating_sub(1)) / 2;
let neg_copy = neg_orig.min(neg_out);
let out_fft_len = output_len.next_power_of_two();
let pos_len = pos_copy + 1; let pos_freqs = fft.narrow(0, 0, pos_len)?;
let neg_freqs = if neg_copy > 0 {
Some(fft.narrow(0, fft_len - neg_copy, neg_copy)?)
} else {
None
};
let new_fft = if let Some(neg) = neg_freqs {
let middle_len = output_len.saturating_sub(pos_len).saturating_sub(neg_copy);
if middle_len > 0 {
let zeros_real = Tensor::zeros(&[middle_len], x.dtype(), device);
let zeros_imag = Tensor::zeros(&[middle_len], x.dtype(), device);
let middle = client.make_complex(&zeros_real, &zeros_imag)?;
client.cat(&[&pos_freqs, &middle, &neg], 0)?
} else {
if pos_len + neg_copy == output_len {
client.cat(&[&pos_freqs, &neg], 0)?
} else if pos_len >= output_len {
pos_freqs.narrow(0, 0, output_len)?
} else {
let neg_keep = output_len - pos_len;
let neg_trunc = neg.narrow(0, neg_copy - neg_keep, neg_keep)?;
client.cat(&[&pos_freqs, &neg_trunc], 0)?
}
}
} else {
if output_len > pos_len {
let pad = output_len - pos_len;
let zeros_real = Tensor::zeros(&[pad], x.dtype(), device);
let zeros_imag = Tensor::zeros(&[pad], x.dtype(), device);
let padding = client.make_complex(&zeros_real, &zeros_imag)?;
client.cat(&[&pos_freqs, &padding], 0)?
} else if output_len < pos_len {
pos_freqs.narrow(0, 0, output_len)?
} else {
pos_freqs
}
};
debug_assert_eq!(new_fft.shape()[0], output_len);
let curr_len = new_fft.shape()[0];
let padded_fft = if curr_len < out_fft_len {
let pad = out_fft_len - curr_len;
let zeros_real = Tensor::zeros(&[pad], x.dtype(), device);
let zeros_imag = Tensor::zeros(&[pad], x.dtype(), device);
let padding = client.make_complex(&zeros_real, &zeros_imag)?;
client.cat(&[&new_fft, &padding], 0)?
} else {
new_fft
};
let scale = output_len as f64 / n as f64;
let fft_re = client.real(&padded_fft)?;
let fft_im = client.imag(&padded_fft)?;
let scaled_re = client.mul_scalar(&fft_re, scale)?;
let scaled_im = client.mul_scalar(&fft_im, scale)?;
let scaled_fft = client.make_complex(&scaled_re, &scaled_im)?;
let result_complex = client.fft(
&scaled_fft,
FftDirection::Inverse,
FftNormalization::Backward,
)?;
let result_real = client.real(&result_complex)?;
let result = result_real.narrow(0, 0, output_len)?;
Ok(result)
}