candle_transformers/models/whisper/
audio.rs1use candle::utils::get_num_threads;
5use std::sync::Arc;
6use std::thread;
7
8pub trait Float:
9 num_traits::Float + num_traits::FloatConst + num_traits::NumAssign + Send + Sync
10{
11}
12
13impl Float for f32 {}
14impl Float for f64 {}
15
16fn fft<T: Float>(inp: &[T]) -> Vec<T> {
18 let n = inp.len();
19 let zero = T::zero();
20 if n == 1 {
21 return vec![inp[0], zero];
22 }
23 if n % 2 == 1 {
24 return dft(inp);
25 }
26 let mut out = vec![zero; n * 2];
27
28 let mut even = Vec::with_capacity(n / 2);
29 let mut odd = Vec::with_capacity(n / 2);
30
31 for (i, &inp) in inp.iter().enumerate() {
32 if i % 2 == 0 {
33 even.push(inp)
34 } else {
35 odd.push(inp);
36 }
37 }
38
39 let even_fft = fft(&even);
40 let odd_fft = fft(&odd);
41
42 let two_pi = T::PI() + T::PI();
43 let n_t = T::from(n).unwrap();
44 for k in 0..n / 2 {
45 let k_t = T::from(k).unwrap();
46 let theta = two_pi * k_t / n_t;
47 let re = theta.cos();
48 let im = -theta.sin();
49
50 let re_odd = odd_fft[2 * k];
51 let im_odd = odd_fft[2 * k + 1];
52
53 out[2 * k] = even_fft[2 * k] + re * re_odd - im * im_odd;
54 out[2 * k + 1] = even_fft[2 * k + 1] + re * im_odd + im * re_odd;
55
56 out[2 * (k + n / 2)] = even_fft[2 * k] - re * re_odd + im * im_odd;
57 out[2 * (k + n / 2) + 1] = even_fft[2 * k + 1] - re * im_odd - im * re_odd;
58 }
59 out
60}
61
62fn dft<T: Float>(inp: &[T]) -> Vec<T> {
64 let zero = T::zero();
65 let n = inp.len();
66 let two_pi = T::PI() + T::PI();
67
68 let mut out = Vec::with_capacity(2 * n);
69 let n_t = T::from(n).unwrap();
70 for k in 0..n {
71 let k_t = T::from(k).unwrap();
72 let mut re = zero;
73 let mut im = zero;
74
75 for (j, &inp) in inp.iter().enumerate() {
76 let j_t = T::from(j).unwrap();
77 let angle = two_pi * k_t * j_t / n_t;
78 re += inp * angle.cos();
79 im -= inp * angle.sin();
80 }
81
82 out.push(re);
83 out.push(im);
84 }
85 out
86}
87
88#[allow(clippy::too_many_arguments)]
89fn log_mel_spectrogram_w<T: Float>(
91 ith: usize,
92 hann: &[T],
93 samples: &[T],
94 filters: &[T],
95 fft_size: usize,
96 fft_step: usize,
97 speed_up: bool,
98 n_len: usize,
99 n_mel: usize,
100 n_threads: usize,
101) -> Vec<T> {
102 let n_fft = if speed_up {
103 1 + fft_size / 4
104 } else {
105 1 + fft_size / 2
106 };
107
108 let zero = T::zero();
109 let half = T::from(0.5).unwrap();
110 let mut fft_in = vec![zero; fft_size];
111 let mut mel = vec![zero; n_len * n_mel];
112 let n_samples = samples.len();
113 let end = std::cmp::min(n_samples / fft_step + 1, n_len);
114
115 for i in (ith..end).step_by(n_threads) {
116 let offset = i * fft_step;
117
118 for j in 0..std::cmp::min(fft_size, n_samples - offset) {
120 fft_in[j] = hann[j] * samples[offset + j];
121 }
122
123 if n_samples - offset < fft_size {
125 fft_in[n_samples - offset..].fill(zero);
126 }
127
128 let mut fft_out: Vec<T> = fft(&fft_in);
130
131 for j in 0..fft_size {
133 fft_out[j] = fft_out[2 * j] * fft_out[2 * j] + fft_out[2 * j + 1] * fft_out[2 * j + 1];
134 }
135 for j in 1..fft_size / 2 {
136 let v = fft_out[fft_size - j];
137 fft_out[j] += v;
138 }
139
140 if speed_up {
141 for j in 0..n_fft {
143 fft_out[j] = half * (fft_out[2 * j] + fft_out[2 * j + 1]);
144 }
145 }
146
147 for j in 0..n_mel {
149 let mut sum = zero;
150 let mut k = 0;
151 while k < n_fft.saturating_sub(3) {
153 sum += fft_out[k] * filters[j * n_fft + k]
154 + fft_out[k + 1] * filters[j * n_fft + k + 1]
155 + fft_out[k + 2] * filters[j * n_fft + k + 2]
156 + fft_out[k + 3] * filters[j * n_fft + k + 3];
157 k += 4;
158 }
159 while k < n_fft {
161 sum += fft_out[k] * filters[j * n_fft + k];
162 k += 1;
163 }
164 mel[j * n_len + i] = T::max(sum, T::from(1e-10).unwrap()).log10();
165 }
166 }
167 mel
168}
169
170pub fn log_mel_spectrogram_<T: Float>(
171 samples: &[T],
172 filters: &[T],
173 fft_size: usize,
174 fft_step: usize,
175 n_mel: usize,
176 speed_up: bool,
177) -> Vec<T> {
178 let zero = T::zero();
179 let two_pi = T::PI() + T::PI();
180 let half = T::from(0.5).unwrap();
181 let one = T::from(1.0).unwrap();
182 let four = T::from(4.0).unwrap();
183 let fft_size_t = T::from(fft_size).unwrap();
184
185 let hann: Vec<T> = (0..fft_size)
186 .map(|i| half * (one - ((two_pi * T::from(i).unwrap()) / fft_size_t).cos()))
187 .collect();
188 let n_len = samples.len() / fft_step;
189
190 let pad = 100 * super::CHUNK_LENGTH / 2;
192 let n_len = if n_len % pad != 0 {
193 (n_len / pad + 1) * pad
194 } else {
195 n_len
196 };
197 let n_len = n_len + pad;
198 let samples = {
199 let mut samples_padded = samples.to_vec();
200 let to_add = n_len * fft_step - samples.len();
201 samples_padded.extend(std::iter::repeat_n(zero, to_add));
202 samples_padded
203 };
204
205 let n_threads = std::cmp::min(get_num_threads() - get_num_threads() % 2, 12);
207 let n_threads = std::cmp::max(n_threads, 2);
208
209 let hann = Arc::new(hann);
210 let samples = Arc::new(samples);
211 let filters = Arc::new(filters);
212
213 let all_outputs = thread::scope(|s| {
216 (0..n_threads)
217 .map(|thread_id| {
219 let hann = Arc::clone(&hann);
220 let samples = Arc::clone(&samples);
221 let filters = Arc::clone(&filters);
222 s.spawn(move || {
224 log_mel_spectrogram_w(
225 thread_id, &hann, &samples, &filters, fft_size, fft_step, speed_up, n_len,
226 n_mel, n_threads,
227 )
228 })
229 })
230 .collect::<Vec<_>>()
231 .into_iter()
232 .map(|handle| handle.join().expect("Thread failed"))
234 .collect::<Vec<_>>()
235 });
236
237 let l = all_outputs[0].len();
238 let mut mel = vec![zero; l];
239
240 for segment_start in (0..l).step_by(n_threads) {
242 for thread_output in all_outputs.iter() {
244 for offset in 0..n_threads {
246 let mel_index = segment_start + offset; if mel_index < mel.len() {
248 mel[mel_index] += thread_output[mel_index];
250 }
251 }
252 }
253 }
254
255 let mmax = mel
256 .iter()
257 .max_by(|&u, &v| u.partial_cmp(v).unwrap_or(std::cmp::Ordering::Greater))
258 .copied()
259 .unwrap_or(zero)
260 - T::from(8).unwrap();
261 for m in mel.iter_mut() {
262 let v = T::max(*m, mmax);
263 *m = v / four + one
264 }
265 mel
266}
267
268pub fn pcm_to_mel<T: Float>(cfg: &super::Config, samples: &[T], filters: &[T]) -> Vec<T> {
269 log_mel_spectrogram_(
270 samples,
271 filters,
272 super::N_FFT,
273 super::HOP_LENGTH,
274 cfg.num_mel_bins,
275 false,
276 )
277}
278
#[cfg(test)]
mod tests {
    use super::*;

    /// Element-wise comparison with an absolute tolerance.
    ///
    /// FFT/DFT outputs go through `sin`/`cos`, whose last-bit rounding depends
    /// on the platform's libm, so exact `assert_eq!` on them is brittle.
    fn assert_close(actual: &[f64], expected: &[f64], tol: f64) {
        assert_eq!(actual.len(), expected.len(), "length mismatch");
        for (i, (a, e)) in actual.iter().zip(expected).enumerate() {
            assert!((a - e).abs() <= tol, "index {i}: got {a}, expected {e}");
        }
    }

    #[test]
    fn test_fft() {
        // The DFT of a unit impulse at n = 1 is e^{-2*pi*i*k/4} = [1, -i, -1, i].
        let input: Vec<f64> = vec![0.0, 1.0, 0.0, 0.0];
        let output = fft(&input);
        assert_close(&output, &[1.0, 0.0, 0.0, -1.0, -1.0, 0.0, 0.0, 1.0], 1e-12);
    }

    #[test]
    fn test_dft() {
        let input: Vec<f64> = vec![0.0, 1.0, 0.0, 0.0];
        let output = dft(&input);
        assert_close(&output, &[1.0, 0.0, 0.0, -1.0, -1.0, 0.0, 0.0, 1.0], 1e-12);
    }

    #[test]
    fn test_fft_matches_dft() {
        // Length 12 = 2 * 2 * 3 exercises both the radix-2 split and the
        // odd-length DFT fallback.
        let input: Vec<f64> = (0..12).map(|i| f64::from((i * 7) % 5) - 1.5).collect();
        assert_close(&fft(&input), &dft(&input), 1e-9);
    }

    #[test]
    fn test_log_mel_spectrogram() {
        let samples: Vec<f64> = vec![0.0; 1000];
        let filters: Vec<f64> = vec![0.0; 1000];
        let output = log_mel_spectrogram_(&samples, &filters, 100, 10, 10, false);
        assert_eq!(output.len(), 30_000);
        // Silence: every energy hits the 1e-10 floor (log10 = -10), which the
        // final normalization maps to -10 / 4 + 1 = -1.5.
        assert!(output.iter().all(|v| (v + 1.5).abs() < 1e-9));
    }

    #[test]
    fn test_tiny_log_mel_spectrogram() {
        let samples: Vec<f64> = vec![0.0; 100];
        let filters: Vec<f64> = vec![0.0; 100];
        let output = log_mel_spectrogram_(&samples, &filters, 20, 2, 2, false);
        assert_eq!(output.len(), 6_000);
        assert!(output.iter().all(|v| (v + 1.5).abs() < 1e-9));
    }
}