Skip to main content

ct2rs/
whisper.rs

1// whisper.rs
2//
3// Copyright (c) 2023-2024 Junpei Kawamoto
4//
5// This software is released under the MIT License.
6//
7// http://opensource.org/licenses/mit-license.php
8
//! This module provides a speech transcriber.
10
11use std::fmt::{Debug, Formatter};
12use std::fs::File;
13use std::io::BufReader;
14use std::path::Path;
15
16use anyhow::{anyhow, Result};
17use mel_spec::mel::{log_mel_spectrogram, mel, norm_mel};
18use mel_spec::stft::Spectrogram;
19use ndarray::{s, stack, Array2, Axis};
20use serde::Deserialize;
21
22pub use super::sys::WhisperOptions;
23use super::tokenizers::hf;
24use super::{sys, Config, Tokenizer};
25
26const PREPROCESSOR_CONFIG_FILE: &str = "preprocessor_config.json";
27
/// A speech transcriber using the Whisper speech recognition model published by OpenAI.
///
/// # Example
/// ```no_run
/// use ct2rs::Whisper;
///
/// # fn main() -> anyhow::Result<()>{
/// let whisper = Whisper::new("/path/to/model", Default::default())?;
///
/// let sampling_rate = whisper.sampling_rate();
/// // Sample the source audio at the sampling rates shown above.
/// // Each sample must be normalized to the range [-1, 1].
/// let samples = vec![];
///
/// let res = whisper.generate(&samples, None, false, &Default::default())?;
/// for r in res {
///     println!("{}", r);
/// }
/// # Ok(())
/// # }
/// ```
pub struct Whisper {
    // Low-level model wrapper providing language detection and generation.
    whisper: sys::Whisper,
    // Hugging Face tokenizer used to decode generated token IDs into text.
    tokenizer: hf::Tokenizer,
    // Audio preprocessing parameters read from `preprocessor_config.json`.
    config: PreprocessorConfig,
}
54
impl Whisper {
    /// Initializes the transcriber.
    ///
    /// # Arguments
    /// * `model_path` - A path to the directory containing the language model to be loaded.
    ///   It must also contain the tokenizer file and `preprocessor_config.json`.
    /// * `config` - A [`Config`] structure that specifies various settings
    ///   and configurations for the `Whisper`.
    ///
    /// # Returns
    /// Returns a `Result` that, if successful, contains the initialized `Whisper`. If an error
    /// occurs during initialization, the function will return an error wrapped in the `Result`.
    pub fn new<T: AsRef<Path>>(model_path: T, config: Config) -> Result<Self> {
        Ok(Self {
            whisper: sys::Whisper::new(&model_path, config)?,
            tokenizer: hf::Tokenizer::new(&model_path)?,
            config: PreprocessorConfig::read(model_path.as_ref().join(PREPROCESSOR_CONFIG_FILE))?,
        })
    }

    /// Transcribe the given samples.
    ///
    /// # Arguments
    /// * `samples` - Samples of the source audio. They must be sampled at the sampling rate
    ///   returned by [`sampling_rate`][Whisper::sampling_rate] method and normalized to the range
    ///   `[-1, 1]`. If the samples are longer than the maximum number of samples returned by
    ///   [`n_samples`][Whisper::n_samples] method, they will be processed in segments.
    /// * `language` - An optional language setting. It transcribes assuming the specified language.
    ///   If `None`, it uses Whisper's language detection.
    /// * `timestamp` - If `true`, the output will include timestamps.
    /// * `options` - Settings.
    ///
    /// # Returns
    /// Returns a `Result` containing a vector of transcribed strings if successful,
    /// or an error if the translation fails.
    pub fn generate(
        &self,
        samples: &[f32],
        language: Option<&str>,
        timestamp: bool,
        options: &WhisperOptions,
    ) -> Result<Vec<String>> {
        // Short-time Fourier transform that converts raw samples into FFT frames.
        let mut stft = Spectrogram::new(self.config.n_fft, self.config.hop_length);

        // Build one log-mel spectrogram matrix per chunk of up to `n_samples` samples.
        // Each matrix is (feature_size, nb_max_frames); chunks shorter than a full
        // segment are implicitly zero-padded because the matrix starts as zeros.
        let mut mel_spectrogram_vec = vec![];
        for chunk in samples.chunks(self.config.n_samples) {
            let mut mel_spectrogram_per_chunk =
                Array2::zeros((self.config.feature_size, self.config.nb_max_frames));
            for (i, flame) in chunk.chunks(self.config.hop_length).enumerate() {
                // `add` yields a complete FFT frame only once enough samples are buffered.
                if let Some(fft_frame) = stft.add(flame) {
                    let mel = norm_mel(&log_mel_spectrogram(&fft_frame, &self.config.mel_filters))
                        .mapv(|v| v as f32);
                    // Store this frame's normalized mel column at frame index `i`.
                    mel_spectrogram_per_chunk
                        .slice_mut(s![.., i])
                        .assign(&mel.slice(s![.., 0]));
                }
            }
            mel_spectrogram_vec.push(mel_spectrogram_per_chunk);
        }

        // Stack per-chunk matrices into a single batched 3-D array along a new axis 0.
        let mut mel_spectrogram = stack(
            Axis(0),
            &mel_spectrogram_vec
                .iter()
                .map(|a| a.view())
                .collect::<Vec<_>>(),
        )?;
        // StorageView below requires contiguous row-major memory.
        if !mel_spectrogram.is_standard_layout() {
            mel_spectrogram = mel_spectrogram.as_standard_layout().into_owned()
        }

        let shape = mel_spectrogram.shape().to_vec();
        // unwrap is safe: the array was just forced into standard (contiguous) layout.
        let storage_view = sys::StorageView::new(
            &shape,
            mel_spectrogram.as_slice_mut().unwrap(),
            Default::default(),
        )?;

        // Detect language.
        let lang_token = match language {
            Some(lang) => {
                format!("<|{}|>", lang)
            }
            None => {
                // Take the top-ranked language of the first batch item; error out
                // if detection returned no candidates at all.
                let detection_result = self.whisper.detect_language(&storage_view)?;
                detection_result
                    .into_iter()
                    .next()
                    .ok_or_else(|| anyhow!("failed to detect language"))?
                    .into_iter()
                    .next()
                    .ok_or_else(|| anyhow!("failed to detect language"))?
                    .language
            }
        };

        // Transcribe.
        // Whisper expects a special-token prompt; the same prompt is repeated
        // for every chunk in the batch.
        let mut prompt = vec!["<|startoftranscript|>", &lang_token, "<|transcribe|>"];
        if !timestamp {
            prompt.push("<|notimestamps|>");
        }
        self.whisper
            .generate(
                &storage_view,
                &vec![prompt; mel_spectrogram_vec.len()],
                options,
            )?
            .into_iter()
            .map(|res| {
                // Decode only the first (best) sequence of each result.
                let r = res
                    .sequences
                    .into_iter()
                    .next()
                    .ok_or_else(|| anyhow!("failed to transcribe samples"))?;
                self.tokenizer.decode(r)
            })
            .collect()
    }

    /// Returns the expected sampling rate.
    pub fn sampling_rate(&self) -> usize {
        self.config.sampling_rate
    }

    /// Max number of samples per batch.
    pub fn n_samples(&self) -> usize {
        self.config.n_samples
    }

    /// Returns `true` if this model is multilingual.
    #[inline]
    pub fn is_multilingual(&self) -> bool {
        self.whisper.is_multilingual()
    }

    /// Returns the number of languages supported.
    #[inline]
    pub fn num_languages(&self) -> usize {
        self.whisper.num_languages()
    }

    /// Number of batches in the work queue.
    #[inline]
    pub fn num_queued_batches(&self) -> usize {
        self.whisper.num_queued_batches()
    }

    /// Number of batches in the work queue or currently processed by a worker.
    #[inline]
    pub fn num_active_batches(&self) -> usize {
        self.whisper.num_active_batches()
    }

    /// Number of parallel replicas.
    #[inline]
    pub fn num_replicas(&self) -> usize {
        self.whisper.num_replicas()
    }
}
213
214impl Debug for Whisper {
215    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
216        write!(f, "{:?}", self.whisper)
217    }
218}
219
/// Audio preprocessing parameters, loaded from `preprocessor_config.json`.
///
/// Only a subset of fields is used by this module; the rest are retained for
/// completeness (hence `allow(dead_code)`).
#[derive(Debug)]
#[allow(dead_code)]
struct PreprocessorConfig {
    // Chunk length from the config file (presumably in seconds — not used here directly).
    chunk_length: usize,
    feature_extractor_type: String,
    // Number of mel bins; rows of each per-chunk spectrogram matrix.
    feature_size: usize,
    // Hop length in samples between consecutive STFT frames.
    hop_length: usize,
    // FFT window size in samples.
    n_fft: usize,
    // Maximum number of samples processed per segment.
    n_samples: usize,
    // Maximum number of frames per segment; columns of each spectrogram matrix.
    nb_max_frames: usize,
    padding_side: String,
    padding_value: f32,
    processor_class: String,
    return_attention_mask: bool,
    // Expected sampling rate of the input audio.
    sampling_rate: usize,
    // Mel filter bank applied to FFT frames (read from config, or computed).
    mel_filters: Array2<f64>,
}
237
238impl PreprocessorConfig {
239    fn read<T: AsRef<Path>>(path: T) -> Result<Self> {
240        let file = File::open(path)?;
241        let reader = BufReader::new(file);
242
243        #[derive(Deserialize)]
244        struct PreprocessorConfigAux {
245            chunk_length: usize,
246            feature_extractor_type: String,
247            feature_size: usize,
248            hop_length: usize,
249            n_fft: usize,
250            n_samples: usize,
251            nb_max_frames: usize,
252            padding_side: String,
253            padding_value: f32,
254            processor_class: String,
255            return_attention_mask: bool,
256            sampling_rate: usize,
257            mel_filters: Option<Vec<Vec<f64>>>,
258        }
259        let aux: PreprocessorConfigAux = serde_json::from_reader(reader)?;
260
261        let mel_filters = if let Some(mel_filters) = aux.mel_filters {
262            let rows = mel_filters.len();
263            let cols = mel_filters.first().map(|row| row.len()).unwrap_or_default();
264            Array2::from_shape_vec((rows, cols), mel_filters.into_iter().flatten().collect())?
265        } else {
266            mel(
267                aux.sampling_rate as f64,
268                aux.n_fft,
269                aux.feature_size,
270                None,
271                None,
272                false,
273                true,
274            )
275        };
276
277        Ok(Self {
278            chunk_length: aux.chunk_length,
279            feature_extractor_type: aux.feature_extractor_type,
280            feature_size: aux.feature_size,
281            hop_length: aux.hop_length,
282            n_fft: aux.n_fft,
283            n_samples: aux.n_samples,
284            nb_max_frames: aux.nb_max_frames,
285            padding_side: aux.padding_side,
286            padding_value: aux.padding_value,
287            processor_class: aux.processor_class,
288            return_attention_mask: aux.return_attention_mask,
289            sampling_rate: aux.sampling_rate,
290            mel_filters,
291        })
292    }
293}
294
#[cfg(test)]
#[cfg(feature = "hub")]
mod tests {
    use crate::{download_model, Config, Device, Whisper};

    const MODEL_ID: &str = "jkawamoto/whisper-tiny-ct2";

    /// Checks that the `Debug` output of `Whisper` mentions the model directory name.
    #[test]
    #[ignore]
    fn test_whisper_debug() {
        let model_path = download_model(MODEL_ID).unwrap();

        // Prefer CUDA when the feature is enabled; otherwise run on the CPU.
        let device = if cfg!(feature = "cuda") {
            Device::CUDA
        } else {
            Device::CPU
        };
        let config = Config {
            device,
            ..Default::default()
        };
        let w = Whisper::new(&model_path, config).unwrap();

        let dir_name = model_path.file_name().unwrap().to_str().unwrap();
        assert!(format!("{:?}", w).contains(dir_name));
    }
}