use candle::{DType, Device, Error, Tensor};
use crate::models::whisper::audio::{log_mel_spectrogram_, Float};
/// Computes a log-mel spectrogram from raw PCM samples.
///
/// Thin wrapper around [`log_mel_spectrogram_`] that supplies the Whisper
/// FFT/hop/mel-bin configuration from this module's constants. Output is a
/// flat row-major buffer of `N_MELS * n_frames` values.
pub fn pcm_to_mel<T: Float>(samples: &[T], filters: &[T]) -> Vec<T> {
    // Fixed Whisper front-end parameters (see module-level constants).
    let n_fft = super::N_FFT;
    let hop_length = super::HOP_LENGTH;
    let n_mels = super::N_MELS;
    let speed_up = false; // no 2x speed-up trick
    log_mel_spectrogram_(samples, filters, n_fft, hop_length, n_mels, speed_up)
}
/// Converts raw PCM audio into chunked mel-spectrogram features for the
/// Whisper encoder.
///
/// The mel spectrogram is laid out as `(N_MELS, total_frames)`, zero-padded
/// on the time axis so the frame count is a multiple of the encoder context
/// (`MAX_SOURCE_POSITIONS`), then split into chunks.
///
/// # Arguments
/// * `audio` - PCM samples (mono f32).
/// * `filters` - precomputed mel filterbank, as expected by [`pcm_to_mel`].
/// * `device` - device on which the output tensor is allocated.
///
/// # Returns
/// A tensor of shape `(num_chunks, N_MELS, MAX_SOURCE_POSITIONS)`.
///
/// # Errors
/// Returns any tensor-construction or reshaping error from candle.
pub fn extract_features(audio: &[f32], filters: &[f32], device: &Device) -> Result<Tensor, Error> {
    const N_MELS: usize = super::N_MELS;
    // Maximum number of mel frames the Whisper encoder accepts per chunk
    // (30 s of audio at the standard hop length).
    const MAX_SOURCE_POSITIONS: usize = 3000;

    // Row-major (N_MELS, total_frames) buffer.
    let mel_vec = pcm_to_mel(audio, filters);
    let total_frames = mel_vec.len() / N_MELS;

    // Build the tensor directly from the host buffer; no intermediate
    // device->host round-trip is needed.
    let mel_tensor = Tensor::from_vec(mel_vec, (N_MELS, total_frames), device)
        .map_err(|e| Error::Msg(format!("Failed to create mel tensor: {e}")))?;

    // Zero-pad the time axis up to a whole number of encoder chunks.
    let num_chunks = total_frames.div_ceil(MAX_SOURCE_POSITIONS);
    let padded_frames = num_chunks * MAX_SOURCE_POSITIONS;
    let padding_needed = padded_frames - total_frames;
    let mel_padded = if padding_needed > 0 {
        let padding = Tensor::zeros((N_MELS, padding_needed), DType::F32, device)?;
        Tensor::cat(&[&mel_tensor, &padding], 1)?
    } else {
        mel_tensor
    };

    // (N_MELS, num_chunks * MAX) -> (N_MELS, num_chunks, MAX)
    //                            -> (num_chunks, N_MELS, MAX)
    let reshaped = mel_padded.reshape((N_MELS, num_chunks, MAX_SOURCE_POSITIONS))?;
    let audio_features = reshaped.transpose(0, 1)?;
    Ok(audio_features)
}