use serde::{Deserialize, Serialize};
use std::path::PathBuf;

use crate::MemvidError;

#[cfg(feature = "whisper")]
use crate::Result;
#[cfg(feature = "whisper")]
use std::path::Path;

/// Metadata for a supported Whisper model.
#[derive(Debug, Clone)]
pub struct WhisperModelInfo {
    /// Hugging Face model identifier, e.g. "openai/whisper-small.en".
    pub model_id: &'static str,
    /// Short name used in configuration.
    pub name: &'static str,
    /// Approximate model size in megabytes.
    pub size_mb: f32,
    /// Whether this entry is the default model.
    pub is_default: bool,
    /// Language support: "en" or "multilingual".
    pub language: &'static str,
}

/// Registry of known Whisper models.
pub static WHISPER_MODELS: &[WhisperModelInfo] = &[
    WhisperModelInfo {
        model_id: "openai/whisper-small.en",
        name: "whisper-small-en",
        size_mb: 244.0,
        is_default: true,
        language: "en",
    },
    WhisperModelInfo {
        model_id: "openai/whisper-small",
        name: "whisper-small",
        size_mb: 244.0,
        is_default: false,
        language: "multilingual",
    },
];

/// Looks up a model by short name or Hugging Face id, falling back to the
/// default model when the name is unknown.
pub fn get_whisper_model_info(name: &str) -> &'static WhisperModelInfo {
    WHISPER_MODELS
        .iter()
        .find(|m| m.name == name || m.model_id == name)
        .unwrap_or_else(|| {
            WHISPER_MODELS
                .iter()
                .find(|m| m.is_default)
                .expect("default whisper model")
        })
}

/// Returns the registry entry marked as the default model.
pub fn default_whisper_model_info() -> &'static WhisperModelInfo {
    WHISPER_MODELS
        .iter()
        .find(|m| m.is_default)
        .expect("default whisper model exists")
}
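
// Usage sketch (hypothetical call sites, for illustration): lookups accept
// either the short name or the Hugging Face id, and unknown names fall back
// to the default entry instead of erroring.
//
//     let info = get_whisper_model_info("openai/whisper-small");
//     assert_eq!(info.name, "whisper-small");
//     let fallback = get_whisper_model_info("no-such-model");
//     assert_eq!(fallback.name, default_whisper_model_info().name);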

/// Configuration for Whisper transcription.
#[derive(Debug, Clone)]
pub struct WhisperConfig {
    /// Model short name, e.g. "whisper-small-en" (see [`WHISPER_MODELS`]).
    pub model_name: String,
    /// Directory where downloaded models are cached.
    pub models_dir: PathBuf,
    /// Offline mode flag (set when `MEMVID_OFFLINE` is present in the environment).
    pub offline: bool,
}

impl Default for WhisperConfig {
    fn default() -> Self {
        let models_dir = std::env::var("MEMVID_MODELS_DIR")
            .ok()
            .map(PathBuf::from)
            .or_else(|| dirs_next::home_dir().map(|d| d.join(".memvid/models")))
            .unwrap_or_else(|| PathBuf::from(".memvid/models"));

        let model_name = std::env::var("MEMVID_WHISPER_MODEL")
            .unwrap_or_else(|_| "whisper-small-en".to_string());

        let offline = std::env::var("MEMVID_OFFLINE").is_ok();

        Self {
            model_name,
            models_dir,
            offline,
        }
    }
}
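
// All three fields come from the environment; an illustrative (assumed)
// invocation, where `<command>` stands in for whatever binary embeds this
// crate:
//
//     MEMVID_WHISPER_MODEL=whisper-small \
//     MEMVID_MODELS_DIR=/opt/models \
//     MEMVID_OFFLINE=1 <command>
//
// MEMVID_OFFLINE only needs to be present; its value is ignored.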

/// Errors produced by the Whisper transcription pipeline.
#[derive(Debug, thiserror::Error)]
pub enum WhisperError {
    #[error("Whisper model '{model}' not found. {hint}")]
    ModelNotFound { model: String, hint: String },

    #[error("Failed to decode audio at {path:?}: {cause}")]
    AudioDecodeError { path: PathBuf, cause: String },

    #[error("Failed to decode audio bytes: {cause}")]
    AudioBytesDecodeError { cause: String },

    #[error("Whisper inference error: {cause}")]
    InferenceError { cause: String },

    #[error("Failed to download Whisper model: {cause}")]
    DownloadError { cause: String },
}

impl From<WhisperError> for MemvidError {
    fn from(err: WhisperError) -> Self {
        MemvidError::ExtractionFailed {
            reason: err.to_string().into_boxed_str(),
        }
    }
}

/// Result of transcribing an audio source.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TranscriptionResult {
    /// Full transcript text.
    pub text: String,
    /// Language code of the transcript.
    pub language: String,
    /// Duration of the decoded audio in seconds.
    pub duration_secs: f32,
    /// Timestamped segments; a missing field deserializes to an empty list.
    #[serde(default)]
    pub segments: Vec<TranscriptionSegment>,
}

/// A timestamped span of transcribed text.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TranscriptionSegment {
    /// Start time in seconds.
    pub start: f32,
    /// End time in seconds.
    pub end: f32,
    /// Text for this span.
    pub text: String,
}
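
// With the serde derives above, a result serializes along these lines
// (illustrative values):
//
//     {
//       "text": "hello world",
//       "language": "en",
//       "duration_secs": 1.5,
//       "segments": [{ "start": 0.0, "end": 1.5, "text": "hello world" }]
//     }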

/// Audio decoding and resampling helpers, compiled with the `whisper` feature.
#[cfg(feature = "whisper")]
mod audio {
    use super::*;
    use std::fs::File;
    use symphonia::core::audio::SampleBuffer;
    use symphonia::core::codecs::DecoderOptions;
    use symphonia::core::formats::FormatOptions;
    use symphonia::core::io::MediaSourceStream;
    use symphonia::core::meta::MetadataOptions;
    use symphonia::core::probe::Hint;

    /// Whisper models expect 16 kHz mono PCM input.
    pub const WHISPER_SAMPLE_RATE: u32 = 16000;

    /// Decodes an audio file to 16 kHz mono f32 PCM, returning the samples
    /// and the duration of the source audio in seconds.
    pub fn decode_audio_file(path: &Path) -> Result<(Vec<f32>, f32)> {
        let file = File::open(path).map_err(|e| WhisperError::AudioDecodeError {
            path: path.to_path_buf(),
            cause: e.to_string(),
        })?;

        let mss = MediaSourceStream::new(Box::new(file), Default::default());

        // Give the probe a format hint from the file extension, if any.
        let mut hint = Hint::new();
        if let Some(ext) = path.extension().and_then(|e| e.to_str()) {
            hint.with_extension(ext);
        }

        let format_opts = FormatOptions::default();
        let metadata_opts = MetadataOptions::default();
        let probed = symphonia::default::get_probe()
            .format(&hint, mss, &format_opts, &metadata_opts)
            .map_err(|e| WhisperError::AudioDecodeError {
                path: path.to_path_buf(),
                cause: format!("Failed to probe audio format: {}", e),
            })?;

        let mut format = probed.format;

        // Pick the first track with a real codec.
        let track = format
            .tracks()
            .iter()
            .find(|t| t.codec_params.codec != symphonia::core::codecs::CODEC_TYPE_NULL)
            .ok_or_else(|| WhisperError::AudioDecodeError {
                path: path.to_path_buf(),
                cause: "No audio track found".to_string(),
            })?;

        let track_id = track.id;
        let sample_rate = track.codec_params.sample_rate.unwrap_or(44100);
        let channels = track.codec_params.channels.map(|c| c.count()).unwrap_or(2);

        let decoder_opts = DecoderOptions::default();
        let mut decoder = symphonia::default::get_codecs()
            .make(&track.codec_params, &decoder_opts)
            .map_err(|e| WhisperError::AudioDecodeError {
                path: path.to_path_buf(),
                cause: format!("Failed to create decoder: {}", e),
            })?;

        let mut samples: Vec<f32> = Vec::new();

        // Decode packets until end of stream, downmixing to mono as we go.
        loop {
            let packet = match format.next_packet() {
                Ok(p) => p,
                Err(symphonia::core::errors::Error::IoError(e))
                    if e.kind() == std::io::ErrorKind::UnexpectedEof =>
                {
                    break;
                }
                Err(_) => break,
            };

            if packet.track_id() != track_id {
                continue;
            }

            let decoded = match decoder.decode(&packet) {
                Ok(d) => d,
                Err(_) => continue,
            };

            let spec = *decoded.spec();
            let num_frames = decoded.frames();

            if num_frames == 0 {
                continue;
            }

            let mut sample_buf = SampleBuffer::<f32>::new(num_frames as u64, spec);
            sample_buf.copy_interleaved_ref(decoded);

            let interleaved = sample_buf.samples();

            // Downmix interleaved channels to mono by averaging.
            if channels > 1 {
                for chunk in interleaved.chunks(channels) {
                    let mono: f32 = chunk.iter().sum::<f32>() / channels as f32;
                    samples.push(mono);
                }
            } else {
                samples.extend_from_slice(interleaved);
            }
        }

        let duration_secs = samples.len() as f32 / sample_rate as f32;

        // Log basic signal statistics before resampling.
        let pre_min = samples.iter().cloned().fold(f32::INFINITY, f32::min);
        let pre_max = samples.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
        let pre_rms = (samples.iter().map(|x| x * x).sum::<f32>() / samples.len() as f32).sqrt();
        tracing::info!(
            sample_rate = sample_rate,
            channels = channels,
            samples_before = samples.len(),
            pre_min = pre_min,
            pre_max = pre_max,
            pre_rms = pre_rms,
            "Audio before resampling"
        );

        // Resample to Whisper's 16 kHz if needed.
        let samples = if sample_rate != WHISPER_SAMPLE_RATE {
            let resampled = resample_sinc(&samples, sample_rate, WHISPER_SAMPLE_RATE);

            let post_min = resampled.iter().cloned().fold(f32::INFINITY, f32::min);
            let post_max = resampled.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
            let post_rms =
                (resampled.iter().map(|x| x * x).sum::<f32>() / resampled.len() as f32).sqrt();
            tracing::info!(
                samples_after = resampled.len(),
                post_min = post_min,
                post_max = post_max,
                post_rms = post_rms,
                "Audio after resampling"
            );
            resampled
        } else {
            tracing::info!("Audio already at 16kHz, no resampling needed");
            samples
        };

        Ok((samples, duration_secs))
    }
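
    // Usage sketch (hypothetical caller): the returned samples are 16 kHz
    // mono, while `duration` reflects the source clip before resampling, so
    // the two agree only approximately.
    //
    //     let (pcm, duration) = decode_audio_file(Path::new("clip.wav"))?;
    //     let approx = pcm.len() as f32 / WHISPER_SAMPLE_RATE as f32;
    //     assert!((approx - duration).abs() < 0.1);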

    /// Resamples mono audio with rubato's FFT-based fixed-input resampler.
    /// (Despite the name, this uses `FftFixedIn` rather than a sinc kernel.)
    fn resample_sinc(samples: &[f32], from_rate: u32, to_rate: u32) -> Vec<f32> {
        use rubato::{FftFixedIn, Resampler};

        if from_rate == to_rate {
            return samples.to_vec();
        }

        let chunk_size = 1024;
        let mut resampler = FftFixedIn::<f32>::new(
            from_rate as usize,
            to_rate as usize,
            chunk_size,
            2, // sub-chunks
            1, // channels (mono)
        )
        .expect("Failed to create resampler");

        let mut output = Vec::new();
        let mut pos = 0;

        // Feed fixed-size chunks, zero-padding the final partial chunk.
        while pos < samples.len() {
            let end = (pos + chunk_size).min(samples.len());
            let chunk = &samples[pos..end];

            let input_chunk: Vec<f32> = if chunk.len() < chunk_size {
                let mut padded = chunk.to_vec();
                padded.resize(chunk_size, 0.0);
                padded
            } else {
                chunk.to_vec()
            };

            let input = vec![input_chunk];
            let resampled = resampler.process(&input, None).expect("Resampling failed");

            if !resampled.is_empty() && !resampled[0].is_empty() {
                output.extend_from_slice(&resampled[0]);
            }

            pos += chunk_size;
        }

        // Trim any extra samples introduced by the final zero-padded chunk.
        let expected_len = (samples.len() as f64 * to_rate as f64 / from_rate as f64) as usize;
        output.truncate(expected_len);

        output
    }
}

#[cfg(feature = "whisper")]
pub use audio::*;

/// Candle-based Whisper inference, compiled with the `whisper` feature.
#[cfg(feature = "whisper")]
mod inference {
    use super::*;
    use candle_core::{DType, Device, IndexOp, Tensor};
    use candle_nn::VarBuilder;
    use candle_transformers::models::whisper::{self as m, Config, audio};
    use hf_hub::{Repo, RepoType, api::sync::Api};
    use tokenizers::Tokenizer;

    /// Speech-to-text transcriber backed by a candle Whisper model.
    pub struct WhisperTranscriber {
        model: Model,
        tokenizer: Tokenizer,
        config: Config,
        mel_filters: Vec<f32>,
        device: Device,
    }

    /// Model weights, either full-precision or quantized.
    #[allow(dead_code)]
    enum Model {
        Normal(m::model::Whisper),
        Quantized(m::quantized_model::Whisper),
    }

    impl WhisperTranscriber {
        /// Loads the model, tokenizer, and mel filter bank, downloading the
        /// files from the Hugging Face Hub if they are not already cached.
        pub fn new(config: &WhisperConfig) -> Result<Self> {
            let device = Self::select_device();
            tracing::info!(device = ?device, "Using device for Whisper");
            // Map short names to Hugging Face model ids; anything else is
            // assumed to already be a model id.
            let model_id = match config.model_name.as_str() {
                "whisper-small-en" => "openai/whisper-small.en",
                "whisper-small" => "openai/whisper-small",
                "whisper-tiny.en" => "openai/whisper-tiny.en",
                "whisper-tiny" => "openai/whisper-tiny",
                "whisper-base.en" => "openai/whisper-base.en",
                "whisper-base" => "openai/whisper-base",
                "whisper-medium.en" => "openai/whisper-medium.en",
                "whisper-medium" => "openai/whisper-medium",
                "whisper-large-v3" => "openai/whisper-large-v3",
                other => other,
            };

            tracing::info!(model_id = model_id, "Loading Whisper model");

            let api = Api::new().map_err(|e| WhisperError::DownloadError {
                cause: e.to_string(),
            })?;
            let repo = api.repo(Repo::with_revision(
                model_id.to_string(),
                RepoType::Model,
                "main".to_string(),
            ));

            // Fetch the model artifacts (cached locally by hf-hub).
            let config_path = repo
                .get("config.json")
                .map_err(|e| WhisperError::DownloadError {
                    cause: format!("Failed to download config.json: {}", e),
                })?;
            let tokenizer_path =
                repo.get("tokenizer.json")
                    .map_err(|e| WhisperError::DownloadError {
                        cause: format!("Failed to download tokenizer.json: {}", e),
                    })?;
            let model_path =
                repo.get("model.safetensors")
                    .map_err(|e| WhisperError::DownloadError {
                        cause: format!("Failed to download model.safetensors: {}", e),
                    })?;

            let config_str = std::fs::read_to_string(&config_path).map_err(|e| {
                WhisperError::InferenceError {
                    cause: format!("Failed to read config: {}", e),
                }
            })?;
            let model_config: Config =
                serde_json::from_str(&config_str).map_err(|e| WhisperError::InferenceError {
                    cause: format!("Failed to parse config: {}", e),
                })?;

            let tokenizer = Tokenizer::from_file(&tokenizer_path).map_err(|e| {
                WhisperError::InferenceError {
                    cause: format!("Failed to load tokenizer: {}", e),
                }
            })?;

            // Select the precomputed mel filter bank matching the model's mel
            // bin count (shipped as raw little-endian f32 bytes).
            let mel_bytes = match model_config.num_mel_bins {
                80 => include_bytes!("melfilters.bytes").as_slice(),
                128 => include_bytes!("melfilters128.bytes").as_slice(),
                n => {
                    return Err(WhisperError::InferenceError {
                        cause: format!("Unsupported number of mel bins: {}", n),
                    }
                    .into());
                }
            };
            let mut mel_filters = vec![0f32; mel_bytes.len() / 4];
            <byteorder::LittleEndian as byteorder::ByteOrder>::read_f32_into(
                mel_bytes,
                &mut mel_filters,
            );

            // Safety: memory-maps the safetensors file; we assume the file is
            // not modified while the mapping is alive.
            let vb = unsafe {
                VarBuilder::from_mmaped_safetensors(&[model_path], DType::F32, &device).map_err(
                    |e| WhisperError::InferenceError {
                        cause: format!("Failed to load model weights: {}", e),
                    },
                )?
            };
            let model = Model::Normal(m::model::Whisper::load(&vb, model_config.clone()).map_err(
                |e| WhisperError::InferenceError {
                    cause: format!("Failed to load Whisper model: {}", e),
                },
            )?);

            tracing::info!("Whisper model loaded successfully");

            Ok(Self {
                model,
                tokenizer,
                config: model_config,
                mel_filters,
                device,
            })
        }

        /// Picks the best available compute device: Metal, then CUDA, then
        /// CPU, depending on the enabled cargo features.
        fn select_device() -> Device {
            #[cfg(feature = "metal")]
            {
                if let Ok(device) = Device::new_metal(0) {
                    tracing::info!("Metal GPU available");
                    return device;
                }
            }

            #[cfg(feature = "cuda")]
            {
                if let Ok(device) = Device::new_cuda(0) {
                    tracing::info!("CUDA GPU available");
                    return device;
                }
            }

            tracing::info!("Using CPU (no GPU acceleration)");
            Device::Cpu
        }

        /// Decodes an audio file and transcribes it.
        pub fn transcribe_file(&mut self, path: &Path) -> Result<TranscriptionResult> {
            let (pcm_data, duration_secs) = super::decode_audio_file(path)?;

            // Log signal statistics to help diagnose bad decodes.
            let audio_min = pcm_data.iter().cloned().fold(f32::INFINITY, f32::min);
            let audio_max = pcm_data.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
            let audio_mean = pcm_data.iter().sum::<f32>() / pcm_data.len() as f32;
            let audio_rms =
                (pcm_data.iter().map(|x| x * x).sum::<f32>() / pcm_data.len() as f32).sqrt();

            tracing::info!(
                duration = duration_secs,
                samples = pcm_data.len(),
                min = audio_min,
                max = audio_max,
                mean = audio_mean,
                rms = audio_rms,
                "Audio decoded"
            );

            self.transcribe_pcm(&pcm_data, duration_secs)
        }
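
        // Usage sketch (hypothetical caller, `whisper` feature enabled):
        //
        //     let mut t = WhisperTranscriber::new(&WhisperConfig::default())?;
        //     let result = t.transcribe_file(Path::new("meeting.wav"))?;
        //     println!("{} ({} segments)", result.text, result.segments.len());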

        /// Transcribes 16 kHz mono PCM by trimming leading/trailing silence,
        /// splitting into 30-second chunks, and greedily decoding each chunk.
        pub fn transcribe_pcm(
            &mut self,
            pcm_data: &[f32],
            duration_secs: f32,
        ) -> Result<TranscriptionResult> {
            const CHUNK_LENGTH: usize = 30 * 16000; // 30 s of audio at 16 kHz
            const N_FRAMES: usize = 3000; // mel frames per 30 s window
            const SAMPLE_RATE: f32 = 16000.0;

            // Trim silence with a windowed RMS gate (window = 100 ms).
            let silence_threshold = 0.01;
            let window_size = 1600;
            let start_sample = find_speech_start(pcm_data, silence_threshold, window_size);
            let end_sample = find_speech_end(pcm_data, silence_threshold, window_size);
            // Guard against a degenerate range: the two scans use differently
            // aligned windows, so in principle they could cross.
            let end_sample = end_sample.max(start_sample);

            let trimmed_start = start_sample as f32 / SAMPLE_RATE;
            let trimmed_end = end_sample as f32 / SAMPLE_RATE;

            tracing::info!(
                start_sample = start_sample,
                end_sample = end_sample,
                trimmed_start_sec = trimmed_start,
                trimmed_end_sec = trimmed_end,
                original_duration = duration_secs,
                "Trimmed silence"
            );

            let pcm_data = &pcm_data[start_sample..end_sample];
            let _trimmed_duration = pcm_data.len() as f32 / SAMPLE_RATE;

            let mut all_text = String::new();
            let mut segments = Vec::new();

            // Number of 30-second chunks, rounding up.
            let num_chunks = (pcm_data.len() + CHUNK_LENGTH - 1) / CHUNK_LENGTH;

            for chunk_idx in 0..num_chunks {
                let chunk_start = chunk_idx * CHUNK_LENGTH;
                let chunk_end = (chunk_start + CHUNK_LENGTH).min(pcm_data.len());
                let chunk = &pcm_data[chunk_start..chunk_end];

                // Absolute timestamps, offset by the trimmed leading silence.
                let start_time = trimmed_start + chunk_start as f32 / SAMPLE_RATE;
                let end_time = trimmed_start + chunk_end as f32 / SAMPLE_RATE;

                tracing::info!(
                    chunk = chunk_idx + 1,
                    total = num_chunks,
                    start = start_time,
                    end = end_time,
                    "Processing chunk"
                );

                // Reset the decoder's KV cache between chunks.
                match &mut self.model {
                    Model::Normal(m) => m.decoder.reset_kv_cache(),
                    Model::Quantized(m) => m.decoder.reset_kv_cache(),
                }

                let mel = audio::pcm_to_mel(&self.config, chunk, &self.mel_filters);
                let n_mels = self.config.num_mel_bins;
                let mel_len = mel.len();
                let n_frames = mel_len / n_mels;

                if chunk_idx == 0 {
                    tracing::info!(
                        num_mel_bins = self.config.num_mel_bins,
                        max_source_positions = self.config.max_source_positions,
                        max_target_positions = self.config.max_target_positions,
                        "Model config"
                    );

                    let mel_min = mel.iter().cloned().fold(f32::INFINITY, f32::min);
                    let mel_max = mel.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
                    let mel_mean = mel.iter().sum::<f32>() / mel.len() as f32;

                    tracing::info!(
                        mel_len = mel_len,
                        n_mels = n_mels,
                        n_frames = n_frames,
                        chunk_samples = chunk.len(),
                        expected_frames = 3000,
                        mel_min = mel_min,
                        mel_max = mel_max,
                        mel_mean = mel_mean,
                        "Mel spectrogram computed"
                    );
                }

                // The encoder expects exactly N_FRAMES mel frames per window,
                // so pad short chunks with zeros and truncate long ones,
                // per mel bin.
                let mel = if n_frames < N_FRAMES {
                    let mut padded = vec![0.0f32; n_mels * N_FRAMES];
                    for bin in 0..n_mels {
                        let src_start = bin * n_frames;
                        let dst_start = bin * N_FRAMES;
                        padded[dst_start..dst_start + n_frames]
                            .copy_from_slice(&mel[src_start..src_start + n_frames]);
                    }
                    padded
                } else if n_frames > N_FRAMES {
                    let mut truncated = vec![0.0f32; n_mels * N_FRAMES];
                    for bin in 0..n_mels {
                        let src_start = bin * n_frames;
                        let dst_start = bin * N_FRAMES;
                        truncated[dst_start..dst_start + N_FRAMES]
                            .copy_from_slice(&mel[src_start..src_start + N_FRAMES]);
                    }
                    truncated
                } else {
                    mel
                };

                let mel =
                    Tensor::from_vec(mel, (1, n_mels, N_FRAMES), &self.device).map_err(|e| {
                        WhisperError::InferenceError {
                            cause: format!("Failed to create mel tensor: {}", e),
                        }
                    })?;

                if chunk_idx == 0 {
                    let mel_shape = mel.shape();
                    tracing::info!(
                        mel_shape = ?mel_shape,
                        "Mel tensor shape"
                    );
                }

                let audio_features = match &mut self.model {
                    Model::Normal(m) => m.encoder.forward(&mel, true),
                    Model::Quantized(m) => m.encoder.forward(&mel, true),
                }
                .map_err(|e| WhisperError::InferenceError {
                    cause: format!("Encoder forward failed: {}", e),
                })?;

                if chunk_idx == 0 {
                    let af_shape = audio_features.shape();
                    tracing::info!(
                        audio_features_shape = ?af_shape,
                        "Audio features from encoder"
                    );
                }

                let sot_token = self.token_id(m::SOT_TOKEN)?;
                let transcribe_token = self.token_id(m::TRANSCRIBE_TOKEN)?;
                let eot_token = self.token_id(m::EOT_TOKEN)?;
                let no_timestamps_token = self.token_id(m::NO_TIMESTAMPS_TOKEN)?;

                if chunk_idx == 0 {
                    let en_token = self.tokenizer.token_to_id("<|en|>");
                    tracing::info!(
                        sot = sot_token,
                        transcribe = transcribe_token,
                        eot = eot_token,
                        no_timestamps = no_timestamps_token,
                        en_token = ?en_token,
                        "Special tokens"
                    );
                }

                let has_language_token = self.tokenizer.token_to_id("<|en|>").is_some();

                // English-only checkpoints (vocab_size == 51864) take no
                // language token in the prompt; multilingual ones do when the
                // tokenizer knows "<|en|>".
                let is_english_only = self.config.vocab_size == 51864;

                let tokens = if is_english_only {
                    vec![sot_token, transcribe_token, no_timestamps_token]
                } else if has_language_token {
                    let language_token = self.token_id("<|en|>")?;
                    vec![
                        sot_token,
                        language_token,
                        transcribe_token,
                        no_timestamps_token,
                    ]
                } else {
                    vec![sot_token, transcribe_token, no_timestamps_token]
                };

                if chunk_idx == 0 {
                    tracing::info!(
                        is_english_only = is_english_only,
                        vocab_size = self.config.vocab_size,
                        prompt_tokens = ?tokens,
                        "Initial prompt"
                    );
                }
                let mut all_tokens = tokens.clone();

                let sample_len = self.config.max_target_positions / 2;
                let mut repeat_count = 0;
                let mut last_token: Option<u32> = None;

                let suppress_tokens = &self.config.suppress_tokens;

                // Greedy decode: the full token sequence is re-fed each step.
                for i in 0..sample_len {
                    let tokens_tensor = Tensor::new(all_tokens.as_slice(), &self.device)
                        .and_then(|t| t.unsqueeze(0))
                        .map_err(|e| WhisperError::InferenceError {
                            cause: format!("Failed to create tokens tensor: {}", e),
                        })?;

                    if chunk_idx == 0 && i < 3 {
                        tracing::info!(
                            step = i,
                            all_tokens_len = all_tokens.len(),
                            tokens_shape = ?tokens_tensor.shape(),
                            "Decoder input"
                        );
                    }

                    let logits = match &mut self.model {
                        Model::Normal(m) => {
                            let hidden = m
                                .decoder
                                .forward(&tokens_tensor, &audio_features, true)
                                .map_err(|e| WhisperError::InferenceError {
                                    cause: format!("Decoder forward failed: {}", e),
                                })?;
                            m.decoder.final_linear(&hidden).map_err(|e| {
                                WhisperError::InferenceError {
                                    cause: format!("Final linear failed: {}", e),
                                }
                            })?
                        }
                        Model::Quantized(m) => {
                            let hidden = m
                                .decoder
                                .forward(&tokens_tensor, &audio_features, true)
                                .map_err(|e| WhisperError::InferenceError {
                                    cause: format!("Decoder forward failed: {}", e),
                                })?;
                            m.decoder.final_linear(&hidden).map_err(|e| {
                                WhisperError::InferenceError {
                                    cause: format!("Final linear failed: {}", e),
                                }
                            })?
                        }
                    };

                    if chunk_idx == 0 && i == 0 {
                        tracing::info!(
                            logits_shape = ?logits.shape(),
                            "Decoder output logits"
                        );
                    }

                    let (_, seq_len, _) =
                        logits.dims3().map_err(|e| WhisperError::InferenceError {
                            cause: format!("Failed to get logits dims: {}", e),
                        })?;
                    let mut logits_vec = logits
                        .i((0, seq_len - 1, ..))
                        .and_then(|t| t.to_vec1::<f32>())
                        .map_err(|e| WhisperError::InferenceError {
                            cause: format!("Failed to extract logits: {}", e),
                        })?;

                    // Apply the model's configured suppress list.
                    for &token_id in suppress_tokens.iter() {
                        if (token_id as usize) < logits_vec.len() {
                            logits_vec[token_id as usize] = f32::NEG_INFINITY;
                        }
                    }

                    // Discourage stopping before any text has been produced.
                    if all_tokens.len() < 10 {
                        logits_vec[eot_token as usize] = f32::NEG_INFINITY;
                    }

                    // Never re-emit prompt tokens, and mask the
                    // timestamp/special-token tail. EOT is 50256 in the
                    // English-only vocabulary and 50257 in the multilingual
                    // one, so start the mask just past EOT to keep it
                    // reachable.
                    logits_vec[sot_token as usize] = f32::NEG_INFINITY;
                    logits_vec[transcribe_token as usize] = f32::NEG_INFINITY;
                    logits_vec[no_timestamps_token as usize] = f32::NEG_INFINITY;
                    let special_start = ((eot_token as usize) + 1).max(50257);
                    for token_id in special_start..logits_vec.len() {
                        logits_vec[token_id] = f32::NEG_INFINITY;
                    }

                    if chunk_idx == 0 && i == 0 {
                        tracing::info!(
                            suppress_count = suppress_tokens.len(),
                            eot_suppressed = all_tokens.len() < 10,
                            "Applied token suppression"
                        );
                    }

                    // Greedy decoding: take the argmax over the vocabulary.
                    let next_token = logits_vec
                        .iter()
                        .enumerate()
                        .max_by(|(_, a), (_, b)| {
                            a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal)
                        })
                        .map(|(idx, _)| idx as u32)
                        .unwrap_or(eot_token);

                    if chunk_idx == 0 && i < 5 {
                        let max_logit =
                            logits_vec.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
                        let min_logit = logits_vec.iter().cloned().fold(f32::INFINITY, f32::min);
                        tracing::info!(
                            step = i,
                            next_token = next_token,
                            max_logit = max_logit,
                            min_logit = min_logit,
                            "Decoding step"
                        );
                    }

                    if next_token == eot_token || next_token >= self.config.vocab_size as u32 {
                        if chunk_idx == 0 && i < 5 {
                            tracing::info!(
                                next_token = next_token,
                                eot = eot_token,
                                "Stopping: EOT or invalid token"
                            );
                        }
                        break;
                    }

                    // Bail out of degenerate repetition loops.
                    if Some(next_token) == last_token {
                        repeat_count += 1;
                        if repeat_count > 3 {
                            tracing::debug!("Breaking due to token repetition");
                            break;
                        }
                    } else {
                        repeat_count = 0;
                    }
                    last_token = Some(next_token);

                    all_tokens.push(next_token);
                }

                // Length of the initial prompt (3 or 4 tokens, depending on
                // whether a language token was inserted above).
                let prompt_len = tokens.len();

                if chunk_idx == 0 {
                    tracing::info!(
                        prompt_tokens = ?&all_tokens[..prompt_len],
                        generated_tokens = ?&all_tokens[prompt_len..],
                        total = all_tokens.len(),
                        "Generated tokens for chunk"
                    );
                }

                let chunk_text = self
                    .tokenizer
                    .decode(&all_tokens[prompt_len..], true) // skip special tokens
                    .map_err(|e| WhisperError::InferenceError {
                        cause: format!("Failed to decode tokens: {}", e),
                    })?;

                let trimmed_text = chunk_text.trim();
                if !trimmed_text.is_empty() {
                    if !all_text.is_empty() {
                        all_text.push(' ');
                    }
                    all_text.push_str(trimmed_text);

                    segments.push(TranscriptionSegment {
                        start: start_time,
                        end: end_time,
                        text: trimmed_text.to_string(),
                    });
                }
            }

            Ok(TranscriptionResult {
                text: all_text.trim().to_string(),
                language: "en".to_string(),
                duration_secs,
                segments,
            })
        }

        /// Resolves a special token string to its id in the tokenizer
        /// vocabulary.
        fn token_id(&self, token: &str) -> Result<u32> {
            self.tokenizer.token_to_id(token).ok_or_else(|| {
                WhisperError::InferenceError {
                    cause: format!("Token '{}' not found in vocabulary", token),
                }
                .into()
            })
        }
    }

    /// Scans forward in steps of `window_size`, returning the index one
    /// window before the first window whose RMS exceeds `threshold`.
    fn find_speech_start(samples: &[f32], threshold: f32, window_size: usize) -> usize {
        for i in (0..samples.len()).step_by(window_size) {
            let end = (i + window_size).min(samples.len());
            let window = &samples[i..end];
            let rms = (window.iter().map(|x| x * x).sum::<f32>() / window.len() as f32).sqrt();
            if rms > threshold {
                // Back off one window so the speech onset is not clipped.
                return i.saturating_sub(window_size);
            }
        }
        0 // no speech found; keep everything from the start
    }

    /// Scans backward in steps of `window_size`, returning one window past
    /// the last position whose RMS exceeds `threshold`.
    fn find_speech_end(samples: &[f32], threshold: f32, window_size: usize) -> usize {
        for i in (0..samples.len()).rev().step_by(window_size) {
            let start = i.saturating_sub(window_size);
            let window = &samples[start..=i.min(samples.len() - 1)];
            let rms = (window.iter().map(|x| x * x).sum::<f32>() / window.len() as f32).sqrt();
            if rms > threshold {
                // Pad one window so the speech tail is not clipped.
                return (i + window_size).min(samples.len());
            }
        }
        samples.len() // no speech found; keep everything to the end
    }
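
    // A minimal sanity check for the silence-trimming helpers, using a
    // synthetic signal (silence, a 0.5-amplitude burst, silence). The exact
    // indices asserted below follow from the window math above; they are
    // illustrative, not part of any spec.
    #[cfg(test)]
    mod trim_tests {
        use super::*;

        #[test]
        fn trims_leading_and_trailing_silence() {
            let window = 1600;
            let mut samples = vec![0.0f32; 3200];
            samples.extend(std::iter::repeat(0.5f32).take(3200));
            samples.extend(std::iter::repeat(0.0f32).take(3200));

            // Speech spans indices 3200..6400; both scans pad by one window.
            let start = find_speech_start(&samples, 0.1, window);
            let end = find_speech_end(&samples, 0.1, window);
            assert_eq!(start, 1600);
            assert_eq!(end, 7999);
            assert!(start <= 3200 && end >= 6400);

            // All-silence input is kept whole rather than trimmed to nothing.
            let silence = vec![0.0f32; 4800];
            assert_eq!(find_speech_start(&silence, 0.1, window), 0);
            assert_eq!(find_speech_end(&silence, 0.1, window), silence.len());
        }
    }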
}

#[cfg(feature = "whisper")]
pub use inference::WhisperTranscriber;

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn whisper_model_registry() {
        let default = default_whisper_model_info();
        assert_eq!(default.name, "whisper-small-en");
        assert!(default.is_default);
        assert_eq!(default.language, "en");

        // Unknown names fall back to the default model.
        let unknown = get_whisper_model_info("nonexistent");
        assert_eq!(unknown.name, "whisper-small-en");
    }

    #[test]
    fn whisper_config_defaults() {
        let config = WhisperConfig::default();
        assert_eq!(config.model_name, "whisper-small-en");
    }
}