1use std::fmt::{Debug, Formatter};
12use std::fs::File;
13use std::io::BufReader;
14use std::path::Path;
15
16use anyhow::{anyhow, Result};
17use mel_spec::mel::{log_mel_spectrogram, mel, norm_mel};
18use mel_spec::stft::Spectrogram;
19use ndarray::{s, stack, Array2, Axis};
20use serde::Deserialize;
21
22pub use super::sys::WhisperOptions;
23use super::tokenizers::hf;
24use super::{sys, Config, Tokenizer};
25
/// File name of the preprocessor configuration that `Whisper::new` looks up inside the
/// model directory.
const PREPROCESSOR_CONFIG_FILE: &str = "preprocessor_config.json";
27
28pub struct Whisper {
50 whisper: sys::Whisper,
51 tokenizer: hf::Tokenizer,
52 config: PreprocessorConfig,
53}
54
55impl Whisper {
56 pub fn new<T: AsRef<Path>>(model_path: T, config: Config) -> Result<Self> {
67 Ok(Self {
68 whisper: sys::Whisper::new(&model_path, config)?,
69 tokenizer: hf::Tokenizer::new(&model_path)?,
70 config: PreprocessorConfig::read(model_path.as_ref().join(PREPROCESSOR_CONFIG_FILE))?,
71 })
72 }
73
74 pub fn generate(
90 &self,
91 samples: &[f32],
92 language: Option<&str>,
93 timestamp: bool,
94 options: &WhisperOptions,
95 ) -> Result<Vec<String>> {
96 let mut stft = Spectrogram::new(self.config.n_fft, self.config.hop_length);
97
98 let mut mel_spectrogram_vec = vec![];
99 for chunk in samples.chunks(self.config.n_samples) {
100 let mut mel_spectrogram_per_chunk =
101 Array2::zeros((self.config.feature_size, self.config.nb_max_frames));
102 for (i, flame) in chunk.chunks(self.config.hop_length).enumerate() {
103 if let Some(fft_frame) = stft.add(flame) {
104 let mel = norm_mel(&log_mel_spectrogram(&fft_frame, &self.config.mel_filters))
105 .mapv(|v| v as f32);
106 mel_spectrogram_per_chunk
107 .slice_mut(s![.., i])
108 .assign(&mel.slice(s![.., 0]));
109 }
110 }
111 mel_spectrogram_vec.push(mel_spectrogram_per_chunk);
112 }
113
114 let mut mel_spectrogram = stack(
115 Axis(0),
116 &mel_spectrogram_vec
117 .iter()
118 .map(|a| a.view())
119 .collect::<Vec<_>>(),
120 )?;
121 if !mel_spectrogram.is_standard_layout() {
122 mel_spectrogram = mel_spectrogram.as_standard_layout().into_owned()
123 }
124
125 let shape = mel_spectrogram.shape().to_vec();
126 let storage_view = sys::StorageView::new(
127 &shape,
128 mel_spectrogram.as_slice_mut().unwrap(),
129 Default::default(),
130 )?;
131
132 let lang_token = match language {
134 Some(lang) => {
135 format!("<|{}|>", lang)
136 }
137 None => {
138 let detection_result = self.whisper.detect_language(&storage_view)?;
139 detection_result
140 .into_iter()
141 .next()
142 .ok_or_else(|| anyhow!("failed to detect language"))?
143 .into_iter()
144 .next()
145 .ok_or_else(|| anyhow!("failed to detect language"))?
146 .language
147 }
148 };
149
150 let mut prompt = vec!["<|startoftranscript|>", &lang_token, "<|transcribe|>"];
152 if !timestamp {
153 prompt.push("<|notimestamps|>");
154 }
155 self.whisper
156 .generate(
157 &storage_view,
158 &vec![prompt; mel_spectrogram_vec.len()],
159 options,
160 )?
161 .into_iter()
162 .map(|res| {
163 let r = res
164 .sequences
165 .into_iter()
166 .next()
167 .ok_or_else(|| anyhow!("failed to transcribe samples"))?;
168 self.tokenizer.decode(r)
169 })
170 .collect()
171 }
172
173 pub fn sampling_rate(&self) -> usize {
175 self.config.sampling_rate
176 }
177
178 pub fn n_samples(&self) -> usize {
180 self.config.n_samples
181 }
182
183 #[inline]
185 pub fn is_multilingual(&self) -> bool {
186 self.whisper.is_multilingual()
187 }
188
189 #[inline]
191 pub fn num_languages(&self) -> usize {
192 self.whisper.num_languages()
193 }
194
195 #[inline]
197 pub fn num_queued_batches(&self) -> usize {
198 self.whisper.num_queued_batches()
199 }
200
201 #[inline]
203 pub fn num_active_batches(&self) -> usize {
204 self.whisper.num_active_batches()
205 }
206
207 #[inline]
209 pub fn num_replicas(&self) -> usize {
210 self.whisper.num_replicas()
211 }
212}
213
214impl Debug for Whisper {
215 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
216 write!(f, "{:?}", self.whisper)
217 }
218}
219
220#[derive(Debug)]
221#[allow(dead_code)]
222struct PreprocessorConfig {
223 chunk_length: usize,
224 feature_extractor_type: String,
225 feature_size: usize,
226 hop_length: usize,
227 n_fft: usize,
228 n_samples: usize,
229 nb_max_frames: usize,
230 padding_side: String,
231 padding_value: f32,
232 processor_class: String,
233 return_attention_mask: bool,
234 sampling_rate: usize,
235 mel_filters: Array2<f64>,
236}
237
238impl PreprocessorConfig {
239 fn read<T: AsRef<Path>>(path: T) -> Result<Self> {
240 let file = File::open(path)?;
241 let reader = BufReader::new(file);
242
243 #[derive(Deserialize)]
244 struct PreprocessorConfigAux {
245 chunk_length: usize,
246 feature_extractor_type: String,
247 feature_size: usize,
248 hop_length: usize,
249 n_fft: usize,
250 n_samples: usize,
251 nb_max_frames: usize,
252 padding_side: String,
253 padding_value: f32,
254 processor_class: String,
255 return_attention_mask: bool,
256 sampling_rate: usize,
257 mel_filters: Option<Vec<Vec<f64>>>,
258 }
259 let aux: PreprocessorConfigAux = serde_json::from_reader(reader)?;
260
261 let mel_filters = if let Some(mel_filters) = aux.mel_filters {
262 let rows = mel_filters.len();
263 let cols = mel_filters.first().map(|row| row.len()).unwrap_or_default();
264 Array2::from_shape_vec((rows, cols), mel_filters.into_iter().flatten().collect())?
265 } else {
266 mel(
267 aux.sampling_rate as f64,
268 aux.n_fft,
269 aux.feature_size,
270 None,
271 None,
272 false,
273 true,
274 )
275 };
276
277 Ok(Self {
278 chunk_length: aux.chunk_length,
279 feature_extractor_type: aux.feature_extractor_type,
280 feature_size: aux.feature_size,
281 hop_length: aux.hop_length,
282 n_fft: aux.n_fft,
283 n_samples: aux.n_samples,
284 nb_max_frames: aux.nb_max_frames,
285 padding_side: aux.padding_side,
286 padding_value: aux.padding_value,
287 processor_class: aux.processor_class,
288 return_attention_mask: aux.return_attention_mask,
289 sampling_rate: aux.sampling_rate,
290 mel_filters,
291 })
292 }
293}
294
#[cfg(test)]
#[cfg(feature = "hub")]
mod tests {
    use crate::{download_model, Config, Device, Whisper};

    /// Model downloaded by the tests below.
    const MODEL_ID: &str = "jkawamoto/whisper-tiny-ct2";

    /// Checks that the `Debug` output of a loaded model mentions the model directory name.
    ///
    /// Ignored by default because it needs to download the model.
    #[test]
    #[ignore]
    fn test_whisper_debug() {
        let model_path = download_model(MODEL_ID).unwrap();
        let device = if cfg!(feature = "cuda") {
            Device::CUDA
        } else {
            Device::CPU
        };
        let w = Whisper::new(
            &model_path,
            Config {
                device,
                ..Default::default()
            },
        )
        .unwrap();

        let dir_name = model_path.file_name().unwrap().to_str().unwrap();
        assert!(format!("{:?}", w).contains(dir_name));
    }
}