candle_transformers/models/whisper/
audio.rs

1// Audio processing code, adapted from whisper.cpp
2// https://github.com/ggerganov/whisper.cpp
3
4use candle::utils::get_num_threads;
5use std::sync::Arc;
6use std::thread;
7
8pub trait Float:
9    num_traits::Float + num_traits::FloatConst + num_traits::NumAssign + Send + Sync
10{
11}
12
13impl Float for f32 {}
14impl Float for f64 {}
15
16// https://github.com/ggerganov/whisper.cpp/blob/4774d2feb01a772a15de81ffc34b34a1f294f020/whisper.cpp#L2357
17fn fft<T: Float>(inp: &[T]) -> Vec<T> {
18    let n = inp.len();
19    let zero = T::zero();
20    if n == 1 {
21        return vec![inp[0], zero];
22    }
23    if n % 2 == 1 {
24        return dft(inp);
25    }
26    let mut out = vec![zero; n * 2];
27
28    let mut even = Vec::with_capacity(n / 2);
29    let mut odd = Vec::with_capacity(n / 2);
30
31    for (i, &inp) in inp.iter().enumerate() {
32        if i % 2 == 0 {
33            even.push(inp)
34        } else {
35            odd.push(inp);
36        }
37    }
38
39    let even_fft = fft(&even);
40    let odd_fft = fft(&odd);
41
42    let two_pi = T::PI() + T::PI();
43    let n_t = T::from(n).unwrap();
44    for k in 0..n / 2 {
45        let k_t = T::from(k).unwrap();
46        let theta = two_pi * k_t / n_t;
47        let re = theta.cos();
48        let im = -theta.sin();
49
50        let re_odd = odd_fft[2 * k];
51        let im_odd = odd_fft[2 * k + 1];
52
53        out[2 * k] = even_fft[2 * k] + re * re_odd - im * im_odd;
54        out[2 * k + 1] = even_fft[2 * k + 1] + re * im_odd + im * re_odd;
55
56        out[2 * (k + n / 2)] = even_fft[2 * k] - re * re_odd + im * im_odd;
57        out[2 * (k + n / 2) + 1] = even_fft[2 * k + 1] - re * im_odd - im * re_odd;
58    }
59    out
60}
61
62// https://github.com/ggerganov/whisper.cpp/blob/4774d2feb01a772a15de81ffc34b34a1f294f020/whisper.cpp#L2337
63fn dft<T: Float>(inp: &[T]) -> Vec<T> {
64    let zero = T::zero();
65    let n = inp.len();
66    let two_pi = T::PI() + T::PI();
67
68    let mut out = Vec::with_capacity(2 * n);
69    let n_t = T::from(n).unwrap();
70    for k in 0..n {
71        let k_t = T::from(k).unwrap();
72        let mut re = zero;
73        let mut im = zero;
74
75        for (j, &inp) in inp.iter().enumerate() {
76            let j_t = T::from(j).unwrap();
77            let angle = two_pi * k_t * j_t / n_t;
78            re += inp * angle.cos();
79            im -= inp * angle.sin();
80        }
81
82        out.push(re);
83        out.push(im);
84    }
85    out
86}
87
88#[allow(clippy::too_many_arguments)]
89// https://github.com/ggerganov/whisper.cpp/blob/4774d2feb01a772a15de81ffc34b34a1f294f020/whisper.cpp#L2414
90fn log_mel_spectrogram_w<T: Float>(
91    ith: usize,
92    hann: &[T],
93    samples: &[T],
94    filters: &[T],
95    fft_size: usize,
96    fft_step: usize,
97    speed_up: bool,
98    n_len: usize,
99    n_mel: usize,
100    n_threads: usize,
101) -> Vec<T> {
102    let n_fft = if speed_up {
103        1 + fft_size / 4
104    } else {
105        1 + fft_size / 2
106    };
107
108    let zero = T::zero();
109    let half = T::from(0.5).unwrap();
110    let mut fft_in = vec![zero; fft_size];
111    let mut mel = vec![zero; n_len * n_mel];
112    let n_samples = samples.len();
113    let end = std::cmp::min(n_samples / fft_step + 1, n_len);
114
115    for i in (ith..end).step_by(n_threads) {
116        let offset = i * fft_step;
117
118        // apply Hanning window
119        for j in 0..std::cmp::min(fft_size, n_samples - offset) {
120            fft_in[j] = hann[j] * samples[offset + j];
121        }
122
123        // fill the rest with zeros
124        if n_samples - offset < fft_size {
125            fft_in[n_samples - offset..].fill(zero);
126        }
127
128        // FFT
129        let mut fft_out: Vec<T> = fft(&fft_in);
130
131        // Calculate modulus^2 of complex numbers
132        for j in 0..fft_size {
133            fft_out[j] = fft_out[2 * j] * fft_out[2 * j] + fft_out[2 * j + 1] * fft_out[2 * j + 1];
134        }
135        for j in 1..fft_size / 2 {
136            let v = fft_out[fft_size - j];
137            fft_out[j] += v;
138        }
139
140        if speed_up {
141            // scale down in the frequency domain results in a speed up in the time domain
142            for j in 0..n_fft {
143                fft_out[j] = half * (fft_out[2 * j] + fft_out[2 * j + 1]);
144            }
145        }
146
147        // mel spectrogram
148        for j in 0..n_mel {
149            let mut sum = zero;
150            let mut k = 0;
151            // Unroll loop
152            while k < n_fft.saturating_sub(3) {
153                sum += fft_out[k] * filters[j * n_fft + k]
154                    + fft_out[k + 1] * filters[j * n_fft + k + 1]
155                    + fft_out[k + 2] * filters[j * n_fft + k + 2]
156                    + fft_out[k + 3] * filters[j * n_fft + k + 3];
157                k += 4;
158            }
159            // Handle remainder
160            while k < n_fft {
161                sum += fft_out[k] * filters[j * n_fft + k];
162                k += 1;
163            }
164            mel[j * n_len + i] = T::max(sum, T::from(1e-10).unwrap()).log10();
165        }
166    }
167    mel
168}
169
170pub fn log_mel_spectrogram_<T: Float>(
171    samples: &[T],
172    filters: &[T],
173    fft_size: usize,
174    fft_step: usize,
175    n_mel: usize,
176    speed_up: bool,
177) -> Vec<T> {
178    let zero = T::zero();
179    let two_pi = T::PI() + T::PI();
180    let half = T::from(0.5).unwrap();
181    let one = T::from(1.0).unwrap();
182    let four = T::from(4.0).unwrap();
183    let fft_size_t = T::from(fft_size).unwrap();
184
185    let hann: Vec<T> = (0..fft_size)
186        .map(|i| half * (one - ((two_pi * T::from(i).unwrap()) / fft_size_t).cos()))
187        .collect();
188    let n_len = samples.len() / fft_step;
189
190    // pad audio with at least one extra chunk of zeros
191    let pad = 100 * super::CHUNK_LENGTH / 2;
192    let n_len = if n_len % pad != 0 {
193        (n_len / pad + 1) * pad
194    } else {
195        n_len
196    };
197    let n_len = n_len + pad;
198    let samples = {
199        let mut samples_padded = samples.to_vec();
200        let to_add = n_len * fft_step - samples.len();
201        samples_padded.extend(std::iter::repeat_n(zero, to_add));
202        samples_padded
203    };
204
205    // ensure that the number of threads is even and less than 12
206    let n_threads = std::cmp::min(get_num_threads() - get_num_threads() % 2, 12);
207    let n_threads = std::cmp::max(n_threads, 2);
208
209    let hann = Arc::new(hann);
210    let samples = Arc::new(samples);
211    let filters = Arc::new(filters);
212
213    // use scope to allow for non static references to be passed to the threads
214    // and directly collect the results into a single vector
215    let all_outputs = thread::scope(|s| {
216        (0..n_threads)
217            // create threads and return their handles
218            .map(|thread_id| {
219                let hann = Arc::clone(&hann);
220                let samples = Arc::clone(&samples);
221                let filters = Arc::clone(&filters);
222                // spawn new thread and start work
223                s.spawn(move || {
224                    log_mel_spectrogram_w(
225                        thread_id, &hann, &samples, &filters, fft_size, fft_step, speed_up, n_len,
226                        n_mel, n_threads,
227                    )
228                })
229            })
230            .collect::<Vec<_>>()
231            .into_iter()
232            // wait for each thread to finish and collect their results
233            .map(|handle| handle.join().expect("Thread failed"))
234            .collect::<Vec<_>>()
235    });
236
237    let l = all_outputs[0].len();
238    let mut mel = vec![zero; l];
239
240    // iterate over mel spectrogram segments, dividing work by threads.
241    for segment_start in (0..l).step_by(n_threads) {
242        // go through each thread's output.
243        for thread_output in all_outputs.iter() {
244            // add each thread's piece to our mel spectrogram.
245            for offset in 0..n_threads {
246                let mel_index = segment_start + offset; // find location in mel.
247                if mel_index < mel.len() {
248                    // Make sure we don't go out of bounds.
249                    mel[mel_index] += thread_output[mel_index];
250                }
251            }
252        }
253    }
254
255    let mmax = mel
256        .iter()
257        .max_by(|&u, &v| u.partial_cmp(v).unwrap_or(std::cmp::Ordering::Greater))
258        .copied()
259        .unwrap_or(zero)
260        - T::from(8).unwrap();
261    for m in mel.iter_mut() {
262        let v = T::max(*m, mmax);
263        *m = v / four + one
264    }
265    mel
266}
267
268pub fn pcm_to_mel<T: Float>(cfg: &super::Config, samples: &[T], filters: &[T]) -> Vec<T> {
269    log_mel_spectrogram_(
270        samples,
271        filters,
272        super::N_FFT,
273        super::HOP_LENGTH,
274        cfg.num_mel_bins,
275        false,
276    )
277}
278
#[cfg(test)]
mod tests {
    use super::*;

    /// Unit impulse at index 1 of a length-4 signal, shared by the transform tests.
    const IMPULSE: [f64; 4] = [0.0, 1.0, 0.0, 0.0];

    /// Length of the spectrogram produced for an all-zero signal of `n_samples` samples,
    /// using an all-zero filter bank of the same length and no speed-up.
    fn zero_signal_mel_len(
        n_samples: usize,
        fft_size: usize,
        fft_step: usize,
        n_mel: usize,
    ) -> usize {
        let samples = vec![0.0f64; n_samples];
        let filters = vec![0.0f64; n_samples];
        log_mel_spectrogram_(&samples, &filters, fft_size, fft_step, n_mel, false).len()
    }

    #[test]
    fn test_fft() {
        let expected = [
            1.0,
            0.0,
            6.123233995736766e-17,
            -1.0,
            -1.0,
            0.0,
            -6.123233995736766e-17,
            1.0,
        ];
        assert_eq!(fft(&IMPULSE), expected);
    }

    #[test]
    fn test_dft() {
        let expected = [
            1.0,
            0.0,
            6.123233995736766e-17,
            -1.0,
            -1.0,
            -1.2246467991473532e-16,
            -1.8369701987210297e-16,
            1.0,
        ];
        assert_eq!(dft(&IMPULSE), expected);
    }

    #[test]
    fn test_log_mel_spectrogram() {
        assert_eq!(zero_signal_mel_len(1000, 100, 10, 10), 30_000);
    }

    #[test]
    fn test_tiny_log_mel_spectrogram() {
        assert_eq!(zero_signal_mel_len(100, 20, 2, 2), 6_000);
    }
}