#![allow(clippy::unwrap_used, clippy::expect_used)]

//! Speech-to-text support built on OpenAI Whisper models.
//!
//! The model registry and configuration types below are always available;
//! audio decoding and inference are gated behind the `whisper` feature.

use serde::{Deserialize, Serialize};
use std::path::PathBuf;

use crate::MemvidError;

#[cfg(feature = "whisper")]
use crate::Result;
#[cfg(feature = "whisper")]
use std::path::Path;
/// Quantization applied to a Whisper model's weights.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum QuantizationType {
    /// Full-precision 32-bit floats (safetensors weights).
    FP32,
    /// 8-bit quantized weights (GGUF).
    Q8K,
    /// 4-bit quantized weights (GGUF).
    Q4K,
}

impl Default for QuantizationType {
    fn default() -> Self {
        Self::FP32
    }
}

/// Metadata describing a known Whisper model.
#[derive(Debug, Clone)]
pub struct WhisperModelInfo {
    /// Hugging Face model id (e.g. `openai/whisper-small.en`).
    pub model_id: &'static str,
    /// Short name used to select the model in [`WhisperConfig`].
    pub name: &'static str,
    /// Approximate download size in megabytes.
    pub size_mb: f32,
    /// Whether this entry is used when no model is configured.
    pub is_default: bool,
    /// `"en"` for English-only models, `"multilingual"` otherwise.
    pub language: &'static str,
    /// Weight quantization of the distributed files.
    pub quantization: QuantizationType,
    /// Weight file format (`"safetensors"` or `"gguf"`).
    pub file_format: &'static str,
}

/// Registry of supported Whisper models.
pub static WHISPER_MODELS: &[WhisperModelInfo] = &[
    WhisperModelInfo {
        model_id: "openai/whisper-small.en",
        name: "whisper-small-en",
        size_mb: 244.0,
        is_default: true,
        language: "en",
        quantization: QuantizationType::FP32,
        file_format: "safetensors",
    },
    WhisperModelInfo {
        model_id: "openai/whisper-small",
        name: "whisper-small",
        size_mb: 244.0,
        is_default: false,
        language: "multilingual",
        quantization: QuantizationType::FP32,
        file_format: "safetensors",
    },
    WhisperModelInfo {
        model_id: "openai/whisper-tiny.en",
        name: "whisper-tiny-en",
        size_mb: 75.0,
        is_default: false,
        language: "en",
        quantization: QuantizationType::FP32,
        file_format: "safetensors",
    },
    // Quantized GGUF builds hosted in the lmz/candle-whisper repo.
    WhisperModelInfo {
        model_id: "lmz/candle-whisper",
        name: "whisper-tiny-en-q8k",
        size_mb: 19.0,
        is_default: false,
        language: "en",
        quantization: QuantizationType::Q8K,
        file_format: "gguf",
    },
    WhisperModelInfo {
        model_id: "lmz/candle-whisper",
        name: "whisper-tiny-q8k",
        size_mb: 19.0,
        is_default: false,
        language: "multilingual",
        quantization: QuantizationType::Q8K,
        file_format: "gguf",
    },
];

/// Looks up a model by short name or Hugging Face model id.
///
/// Unknown names fall back to the default entry rather than failing.
#[must_use]
pub fn get_whisper_model_info(name: &str) -> &'static WhisperModelInfo {
    WHISPER_MODELS
        .iter()
        .find(|m| m.name == name || m.model_id == name)
        .unwrap_or_else(|| {
            WHISPER_MODELS
                .iter()
                .find(|m| m.is_default)
                .expect("default whisper model")
        })
}

/// Returns the entry in [`WHISPER_MODELS`] marked `is_default`.
#[must_use]
pub fn default_whisper_model_info() -> &'static WhisperModelInfo {
    WHISPER_MODELS
        .iter()
        .find(|m| m.is_default)
        .expect("default whisper model exists")
}

/// Configuration for Whisper transcription.
///
/// Defaults honor the `MEMVID_MODELS_DIR`, `MEMVID_WHISPER_MODEL`, and
/// `MEMVID_OFFLINE` environment variables.
#[derive(Debug, Clone)]
pub struct WhisperConfig {
    /// Short name or Hugging Face id of the model to load.
    pub model_name: String,
    /// Directory for locally stored models.
    pub models_dir: PathBuf,
    /// Set when `MEMVID_OFFLINE` is present in the environment.
    pub offline: bool,
}

impl Default for WhisperConfig {
    fn default() -> Self {
        let models_dir = std::env::var("MEMVID_MODELS_DIR")
            .ok()
            .map(PathBuf::from)
            .or_else(|| dirs_next::home_dir().map(|d| d.join(".memvid/models")))
            .unwrap_or_else(|| PathBuf::from(".memvid/models"));

        let model_name = std::env::var("MEMVID_WHISPER_MODEL")
            .unwrap_or_else(|_| "whisper-small-en".to_string());

        let offline = std::env::var("MEMVID_OFFLINE").is_ok();

        Self {
            model_name,
            models_dir,
            offline,
        }
    }
}

impl WhisperConfig {
    /// Quantized English-only config (`whisper-tiny-en-q8k`): a much
    /// smaller download in exchange for some accuracy.
    #[must_use]
    pub fn with_quantization() -> Self {
        Self {
            model_name: "whisper-tiny-en-q8k".to_string(),
            ..Default::default()
        }
    }

    /// Config for a specific model name or Hugging Face id.
    #[must_use]
    pub fn with_model(model_name: impl Into<String>) -> Self {
        Self {
            model_name: model_name.into(),
            ..Default::default()
        }
    }

    /// Quantized multilingual config (`whisper-tiny-q8k`).
    #[must_use]
    pub fn multilingual_quantized() -> Self {
        Self {
            model_name: "whisper-tiny-q8k".to_string(),
            ..Default::default()
        }
    }

    /// Full-precision English-only tiny config (`whisper-tiny-en`).
    #[must_use]
    pub fn tiny() -> Self {
        Self {
            model_name: "whisper-tiny-en".to_string(),
            ..Default::default()
        }
    }
}
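
// A minimal selection sketch; the model names are registry entries from
// WHISPER_MODELS above, and unknown names fall back to the default at
// load time:
//
//     let small = WhisperConfig::default();                  // whisper-small-en
//     let fast = WhisperConfig::with_quantization();         // ~19 MB GGUF
//     let custom = WhisperConfig::with_model("openai/whisper-tiny.en");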

/// Errors raised while downloading models, decoding audio, or running
/// inference.
#[derive(Debug, thiserror::Error)]
pub enum WhisperError {
    /// The requested model is not in the registry.
    #[error("Whisper model '{model}' not found. {hint}")]
    ModelNotFound { model: String, hint: String },

    /// An audio file could not be decoded.
    #[error("Failed to decode audio at {path:?}: {cause}")]
    AudioDecodeError { path: PathBuf, cause: String },

    /// An in-memory audio buffer could not be decoded.
    #[error("Failed to decode audio bytes: {cause}")]
    AudioBytesDecodeError { cause: String },

    /// Model execution failed.
    #[error("Whisper inference error: {cause}")]
    InferenceError { cause: String },

    /// A model file could not be fetched.
    #[error("Failed to download Whisper model: {cause}")]
    DownloadError { cause: String },
}

impl From<WhisperError> for MemvidError {
    fn from(err: WhisperError) -> Self {
        MemvidError::ExtractionFailed {
            reason: err.to_string().into_boxed_str(),
        }
    }
}

/// The result of transcribing one piece of audio.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TranscriptionResult {
    /// Full transcript, with chunk texts joined by single spaces.
    pub text: String,
    /// Language code of the transcript (currently always `"en"`).
    pub language: String,
    /// Duration of the source audio in seconds.
    pub duration_secs: f32,
    /// Per-chunk segments with approximate timestamps.
    #[serde(default)]
    pub segments: Vec<TranscriptionSegment>,
}

/// One transcribed span, with times in seconds from the start of the audio.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TranscriptionSegment {
    /// Segment start time.
    pub start: f32,
    /// Segment end time.
    pub end: f32,
    /// Text recognized within this span.
    pub text: String,
}
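
// Results serialize with serde; a round-trip sketch (assumes `serde_json`,
// which this module already uses for model configs):
//
//     let json = serde_json::to_string(&result)?;
//     let back: TranscriptionResult = serde_json::from_str(&json)?;
//
// `segments` is `#[serde(default)]`, so payloads written before segments
// existed still deserialize.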

#[cfg(feature = "whisper")]
mod audio {
    //! Decoding of audio files via Symphonia into the mono 16 kHz f32 PCM
    //! that Whisper expects.
    use super::*;
    use std::fs::File;
    use symphonia::core::audio::SampleBuffer;
    use symphonia::core::codecs::DecoderOptions;
    use symphonia::core::formats::FormatOptions;
    use symphonia::core::io::MediaSourceStream;
    use symphonia::core::meta::MetadataOptions;
    use symphonia::core::probe::Hint;

    /// Sample rate Whisper models expect.
    pub const WHISPER_SAMPLE_RATE: u32 = 16000;

    /// Decodes an audio file to mono f32 PCM at [`WHISPER_SAMPLE_RATE`].
    ///
    /// Returns the samples and the source duration in seconds.
    pub fn decode_audio_file(path: &Path) -> Result<(Vec<f32>, f32)> {
        let file = File::open(path).map_err(|e| WhisperError::AudioDecodeError {
            path: path.to_path_buf(),
            cause: e.to_string(),
        })?;

        let mss = MediaSourceStream::new(Box::new(file), Default::default());

        // Give the probe a hint from the file extension, if any.
        let mut hint = Hint::new();
        if let Some(ext) = path.extension().and_then(|e| e.to_str()) {
            hint.with_extension(ext);
        }

        let format_opts = FormatOptions::default();
        let metadata_opts = MetadataOptions::default();
        let probed = symphonia::default::get_probe()
            .format(&hint, mss, &format_opts, &metadata_opts)
            .map_err(|e| WhisperError::AudioDecodeError {
                path: path.to_path_buf(),
                cause: format!("Failed to probe audio format: {}", e),
            })?;

        let mut format = probed.format;

        // Pick the first real (non-null) audio track.
        let track = format
            .tracks()
            .iter()
            .find(|t| t.codec_params.codec != symphonia::core::codecs::CODEC_TYPE_NULL)
            .ok_or_else(|| WhisperError::AudioDecodeError {
                path: path.to_path_buf(),
                cause: "No audio track found".to_string(),
            })?;

        let track_id = track.id;
        let sample_rate = track.codec_params.sample_rate.unwrap_or(44100);
        let channels = track.codec_params.channels.map(|c| c.count()).unwrap_or(2);

        let decoder_opts = DecoderOptions::default();
        let mut decoder = symphonia::default::get_codecs()
            .make(&track.codec_params, &decoder_opts)
            .map_err(|e| WhisperError::AudioDecodeError {
                path: path.to_path_buf(),
                cause: format!("Failed to create decoder: {}", e),
            })?;

        let mut samples: Vec<f32> = Vec::new();

        // Decode packet by packet until end of stream.
        loop {
            let packet = match format.next_packet() {
                Ok(p) => p,
                Err(symphonia::core::errors::Error::IoError(e))
                    if e.kind() == std::io::ErrorKind::UnexpectedEof =>
                {
                    break;
                }
                Err(_) => break,
            };

            if packet.track_id() != track_id {
                continue;
            }

            let decoded = match decoder.decode(&packet) {
                Ok(d) => d,
                Err(_) => continue,
            };

            let spec = *decoded.spec();
            let num_frames = decoded.frames();

            if num_frames == 0 {
                continue;
            }

            let mut sample_buf = SampleBuffer::<f32>::new(num_frames as u64, spec);
            sample_buf.copy_interleaved_ref(decoded);

            let interleaved = sample_buf.samples();

            // Downmix to mono by averaging across channels.
            if channels > 1 {
                for chunk in interleaved.chunks(channels) {
                    let mono: f32 = chunk.iter().sum::<f32>() / channels as f32;
                    samples.push(mono);
                }
            } else {
                samples.extend_from_slice(interleaved);
            }
        }

        let duration_secs = samples.len() as f32 / sample_rate as f32;

        // Log basic signal statistics before resampling, for debugging.
        let pre_min = samples.iter().cloned().fold(f32::INFINITY, f32::min);
        let pre_max = samples.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
        let pre_rms = (samples.iter().map(|x| x * x).sum::<f32>() / samples.len() as f32).sqrt();
        tracing::info!(
            sample_rate = sample_rate,
            channels = channels,
            samples_before = samples.len(),
            pre_min = pre_min,
            pre_max = pre_max,
            pre_rms = pre_rms,
            "Audio before resampling"
        );

        // Resample to 16 kHz if the source uses a different rate.
        let samples = if sample_rate != WHISPER_SAMPLE_RATE {
            let resampled = resample_sinc(&samples, sample_rate, WHISPER_SAMPLE_RATE);

            let post_min = resampled.iter().cloned().fold(f32::INFINITY, f32::min);
            let post_max = resampled.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
            let post_rms =
                (resampled.iter().map(|x| x * x).sum::<f32>() / resampled.len() as f32).sqrt();
            tracing::info!(
                samples_after = resampled.len(),
                post_min = post_min,
                post_max = post_max,
                post_rms = post_rms,
                "Audio after resampling"
            );
            resampled
        } else {
            tracing::info!("Audio already at 16kHz, no resampling needed");
            samples
        };

        Ok((samples, duration_secs))
    }

    /// Resamples `samples` from `from_rate` to `to_rate`.
    ///
    /// Despite the name, this performs linear interpolation between
    /// neighboring samples rather than windowed-sinc resampling.
    fn resample_sinc(samples: &[f32], from_rate: u32, to_rate: u32) -> Vec<f32> {
        if from_rate == to_rate {
            return samples.to_vec();
        }

        let ratio = to_rate as f64 / from_rate as f64;
        let output_len = (samples.len() as f64 * ratio).ceil() as usize;
        let mut output = Vec::with_capacity(output_len);

        for i in 0..output_len {
            // Map each output index back to a fractional source position.
            let src_pos = i as f64 / ratio;
            let src_idx = src_pos.floor() as usize;
            let frac = (src_pos - src_idx as f64) as f32;

            if src_idx + 1 < samples.len() {
                // Linear blend of the two neighboring source samples.
                let sample = samples[src_idx] * (1.0 - frac) + samples[src_idx + 1] * frac;
                output.push(sample);
            } else if src_idx < samples.len() {
                output.push(samples[src_idx]);
            }
        }

        output
    }
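
    // Sanity checks for the resampler's passthrough and 2x-downsampling
    // behavior (values chosen so interpolation lands exactly on sources).
    #[cfg(test)]
    mod resample_tests {
        use super::resample_sinc;

        #[test]
        fn passthrough_when_rates_match() {
            let input = vec![0.0f32, 1.0, 2.0];
            assert_eq!(resample_sinc(&input, 16000, 16000), input);
        }

        #[test]
        fn halves_sample_count_when_downsampling_2x() {
            // ratio = 0.5, so output[i] samples source index 2*i exactly.
            let input = vec![0.0f32, 1.0, 2.0, 3.0];
            assert_eq!(resample_sinc(&input, 32000, 16000), vec![0.0, 2.0]);
        }
    }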
}

#[cfg(feature = "whisper")]
pub use audio::*;
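
// A minimal decoding sketch (hypothetical file path; requires the `whisper`
// feature). The returned PCM is mono f32 at 16 kHz:
//
//     let (pcm, secs) = decode_audio_file(Path::new("clip.wav"))?;
//     // pcm.len() is roughly secs * 16_000 once resampled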

#[cfg(feature = "whisper")]
mod inference {
    //! Candle-based Whisper inference: model loading from Hugging Face,
    //! mel-spectrogram preparation, and greedy decoding.
    use super::*;
    use candle_core::{DType, Device, IndexOp, Tensor};
    use candle_nn::VarBuilder;
    use candle_transformers::models::whisper::{self as m, Config, audio};
    use hf_hub::{Repo, RepoType, api::sync::Api};
    use tokenizers::Tokenizer;

    /// A loaded Whisper model plus the state needed to run transcription.
    pub struct WhisperTranscriber {
        model: Model,
        tokenizer: Tokenizer,
        config: Config,
        mel_filters: Vec<f32>,
        device: Device,
    }

    /// Either a full-precision or a GGUF-quantized Whisper model.
    #[allow(dead_code)]
    enum Model {
        Normal(m::model::Whisper),
        Quantized(m::quantized_model::Whisper),
    }

    impl WhisperTranscriber {
        /// Downloads (or reuses previously downloaded) model files via
        /// `hf_hub` and loads the model described by `config`.
        pub fn new(config: &WhisperConfig) -> Result<Self> {
            let device = Self::select_device();
            tracing::info!(device = ?device, "Using device for Whisper");

            let model_info = get_whisper_model_info(&config.model_name);
            let is_quantized = model_info.quantization != QuantizationType::FP32;

            tracing::info!(
                model_name = %config.model_name,
                model_id = %model_info.model_id,
                quantization = ?model_info.quantization,
                file_format = %model_info.file_format,
                "Loading Whisper model"
            );

            let api = Api::new().map_err(|e| WhisperError::DownloadError {
                cause: e.to_string(),
            })?;
            let repo = api.repo(Repo::with_revision(
                model_info.model_id.to_string(),
                RepoType::Model,
                "main".to_string(),
            ));

            // Quantized models take config.json and tokenizer.json from the
            // matching OpenAI tiny repo rather than the GGUF repo.
            let (config_path, tokenizer_path) = if is_quantized {
                let base_model_id = match model_info.language {
                    "en" => "openai/whisper-tiny.en",
                    _ => "openai/whisper-tiny",
                };
                let base_repo = api.repo(Repo::with_revision(
                    base_model_id.to_string(),
                    RepoType::Model,
                    "main".to_string(),
                ));

                let cfg =
                    base_repo
                        .get("config.json")
                        .map_err(|e| WhisperError::DownloadError {
                            cause: format!("Failed to download config.json: {}", e),
                        })?;
                let tok =
                    base_repo
                        .get("tokenizer.json")
                        .map_err(|e| WhisperError::DownloadError {
                            cause: format!("Failed to download tokenizer.json: {}", e),
                        })?;
                (cfg, tok)
            } else {
                let cfg = repo
                    .get("config.json")
                    .map_err(|e| WhisperError::DownloadError {
                        cause: format!("Failed to download config.json: {}", e),
                    })?;
                let tok = repo
                    .get("tokenizer.json")
                    .map_err(|e| WhisperError::DownloadError {
                        cause: format!("Failed to download tokenizer.json: {}", e),
                    })?;
                (cfg, tok)
            };

            let config_str = std::fs::read_to_string(&config_path).map_err(|e| {
                WhisperError::InferenceError {
                    cause: format!("Failed to read config: {}", e),
                }
            })?;
            let model_config: Config =
                serde_json::from_str(&config_str).map_err(|e| WhisperError::InferenceError {
                    cause: format!("Failed to parse config: {}", e),
                })?;

            let tokenizer = Tokenizer::from_file(&tokenizer_path).map_err(|e| {
                WhisperError::InferenceError {
                    cause: format!("Failed to load tokenizer: {}", e),
                }
            })?;

            // Pick the precomputed mel filterbank matching the model's bin count.
            let mel_bytes = match model_config.num_mel_bins {
                80 => include_bytes!("melfilters.bytes").as_slice(),
                128 => include_bytes!("melfilters128.bytes").as_slice(),
                n => {
                    return Err(WhisperError::InferenceError {
                        cause: format!("Unsupported number of mel bins: {}", n),
                    }
                    .into());
                }
            };
            let mut mel_filters = vec![0f32; mel_bytes.len() / 4];
            <byteorder::LittleEndian as byteorder::ByteOrder>::read_f32_into(
                mel_bytes,
                &mut mel_filters,
            );

            let model = match model_info.quantization {
                QuantizationType::FP32 => {
                    let model_path =
                        repo.get("model.safetensors")
                            .map_err(|e| WhisperError::DownloadError {
                                cause: format!("Failed to download model.safetensors: {}", e),
                            })?;

                    // SAFETY: memory-maps the downloaded safetensors file; the
                    // file is not modified while mapped.
                    let vb = unsafe {
                        VarBuilder::from_mmaped_safetensors(&[model_path], DType::F32, &device)
                            .map_err(|e| WhisperError::InferenceError {
                                cause: format!("Failed to load model weights: {}", e),
                            })?
                    };
                    Model::Normal(m::model::Whisper::load(&vb, model_config.clone()).map_err(
                        |e| WhisperError::InferenceError {
                            cause: format!("Failed to load Whisper model: {}", e),
                        },
                    )?)
                }
                QuantizationType::Q8K | QuantizationType::Q4K => {
                    // Map language and quantization onto the GGUF filename
                    // published in the repo.
                    let gguf_filename = match (model_info.language, model_info.quantization) {
                        ("en", QuantizationType::Q8K) => "model-tiny-en-q80.gguf",
                        ("en", QuantizationType::Q4K) => "model-tiny-en-q40.gguf",
                        (_, QuantizationType::Q8K) => "model-tiny-q80.gguf",
                        (_, QuantizationType::Q4K) => "model-tiny-q40.gguf",
                        // Unreachable: the outer match guarantees Q8K or Q4K.
                        _ => "model-tiny-q80.gguf",
                    };

                    let model_path =
                        repo.get(gguf_filename)
                            .map_err(|e| WhisperError::DownloadError {
                                cause: format!("Failed to download {}: {}", gguf_filename, e),
                            })?;

                    tracing::info!(
                        gguf_file = %gguf_filename,
                        quantization = ?model_info.quantization,
                        "Loading quantized GGUF model"
                    );

                    let vb = candle_transformers::quantized_var_builder::VarBuilder::from_gguf(
                        &model_path,
                        &device,
                    )
                    .map_err(|e| WhisperError::InferenceError {
                        cause: format!("Failed to load quantized model: {}", e),
                    })?;

                    Model::Quantized(
                        m::quantized_model::Whisper::load(&vb, model_config.clone()).map_err(
                            |e| WhisperError::InferenceError {
                                cause: format!("Failed to load quantized Whisper model: {}", e),
                            },
                        )?,
                    )
                }
            };

            tracing::info!("Whisper model loaded successfully");

            Ok(Self {
                model,
                tokenizer,
                config: model_config,
                mel_filters,
                device,
            })
        }

        /// Picks the compute device: Metal, then CUDA, then CPU, depending on
        /// which features are compiled in and which devices initialize.
        fn select_device() -> Device {
            #[cfg(feature = "metal")]
            {
                if let Ok(device) = Device::new_metal(0) {
                    tracing::info!("Metal GPU available");
                    return device;
                }
            }

            #[cfg(feature = "cuda")]
            {
                if let Ok(device) = Device::new_cuda(0) {
                    tracing::info!("CUDA GPU available");
                    return device;
                }
            }

            tracing::info!("Using CPU (no GPU acceleration)");
            Device::Cpu
        }
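
        // GPU support is a compile-time opt-in: a sketch of the intended
        // feature wiring (assumed to match the cfg gates above) would be
        // `--features whisper,metal` on macOS or `--features whisper,cuda`
        // on NVIDIA hardware; without either, inference runs on CPU.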

        /// Decodes an audio file to 16 kHz mono PCM and transcribes it.
        pub fn transcribe_file(&mut self, path: &Path) -> Result<TranscriptionResult> {
            let (pcm_data, duration_secs) = super::decode_audio_file(path)?;

            // Basic signal statistics, logged to help diagnose bad input.
            let audio_min = pcm_data.iter().cloned().fold(f32::INFINITY, f32::min);
            let audio_max = pcm_data.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
            let audio_mean = pcm_data.iter().sum::<f32>() / pcm_data.len() as f32;
            let audio_rms =
                (pcm_data.iter().map(|x| x * x).sum::<f32>() / pcm_data.len() as f32).sqrt();

            tracing::info!(
                duration = duration_secs,
                samples = pcm_data.len(),
                min = audio_min,
                max = audio_max,
                mean = audio_mean,
                rms = audio_rms,
                "Audio decoded"
            );

            self.transcribe_pcm(&pcm_data, duration_secs)
        }
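
        // End-to-end sketch (hypothetical path; the first call blocks while
        // model files download from Hugging Face):
        //
        //     let mut t = WhisperTranscriber::new(&WhisperConfig::default())?;
        //     let result = t.transcribe_file(Path::new("talk.mp3"))?;
        //     println!("{} ({:.1}s)", result.text, result.duration_secs);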

        /// Transcribes mono 16 kHz PCM, chunking into 30-second windows.
        pub fn transcribe_pcm(
            &mut self,
            pcm_data: &[f32],
            duration_secs: f32,
        ) -> Result<TranscriptionResult> {
            // Whisper operates on 30 s windows: 30 s at 16 kHz is 480_000
            // samples, which the mel stage turns into 3000 frames.
            const CHUNK_LENGTH: usize = 30 * 16000;
            const N_FRAMES: usize = 3000;
            const SAMPLE_RATE: f32 = 16000.0;

            // Trim leading/trailing silence with a simple RMS gate over
            // 100 ms (1600-sample) windows.
            let silence_threshold = 0.01;
            let window_size = 1600;
            let start_sample = find_speech_start(pcm_data, silence_threshold, window_size);
            let end_sample = find_speech_end(pcm_data, silence_threshold, window_size);

            let trimmed_start = start_sample as f32 / SAMPLE_RATE;
            let trimmed_end = end_sample as f32 / SAMPLE_RATE;

            tracing::info!(
                start_sample = start_sample,
                end_sample = end_sample,
                trimmed_start_sec = trimmed_start,
                trimmed_end_sec = trimmed_end,
                original_duration = duration_secs,
                "Trimmed silence"
            );

            let pcm_data = &pcm_data[start_sample..end_sample];
            let _trimmed_duration = pcm_data.len() as f32 / SAMPLE_RATE;

            let mut all_text = String::new();
            let mut segments = Vec::new();

            // Ceiling division: number of 30 s chunks covering the audio.
            let num_chunks = (pcm_data.len() + CHUNK_LENGTH - 1) / CHUNK_LENGTH;

            for chunk_idx in 0..num_chunks {
                let chunk_start = chunk_idx * CHUNK_LENGTH;
                let chunk_end = (chunk_start + CHUNK_LENGTH).min(pcm_data.len());
                let chunk = &pcm_data[chunk_start..chunk_end];

                // Timestamps are relative to the original, untrimmed audio.
                let start_time = trimmed_start + chunk_start as f32 / SAMPLE_RATE;
                let end_time = trimmed_start + chunk_end as f32 / SAMPLE_RATE;

                tracing::info!(
                    chunk = chunk_idx + 1,
                    total = num_chunks,
                    start = start_time,
                    end = end_time,
                    "Processing chunk"
                );

                // Each chunk is decoded from scratch; clear the KV cache.
                match &mut self.model {
                    Model::Normal(m) => m.decoder.reset_kv_cache(),
                    Model::Quantized(m) => m.decoder.reset_kv_cache(),
                }

                // Mel spectrogram, laid out bin-major:
                // mel[bin * n_frames + frame].
                let mel = audio::pcm_to_mel(&self.config, chunk, &self.mel_filters);
                let n_mels = self.config.num_mel_bins;
                let mel_len = mel.len();
                let n_frames = mel_len / n_mels;

                if chunk_idx == 0 {
                    tracing::info!(
                        num_mel_bins = self.config.num_mel_bins,
                        max_source_positions = self.config.max_source_positions,
                        max_target_positions = self.config.max_target_positions,
                        "Model config"
                    );

                    let mel_min = mel.iter().cloned().fold(f32::INFINITY, f32::min);
                    let mel_max = mel.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
                    let mel_mean = mel.iter().sum::<f32>() / mel.len() as f32;

                    tracing::info!(
                        mel_len = mel_len,
                        n_mels = n_mels,
                        n_frames = n_frames,
                        chunk_samples = chunk.len(),
                        expected_frames = 3000,
                        mel_min = mel_min,
                        mel_max = mel_max,
                        mel_mean = mel_mean,
                        "Mel spectrogram computed"
                    );
                }

                // The encoder expects exactly N_FRAMES frames: zero-pad short
                // chunks and truncate long ones, bin by bin.
                let mel = if n_frames < N_FRAMES {
                    let mut padded = vec![0.0f32; n_mels * N_FRAMES];
                    for bin in 0..n_mels {
                        let src_start = bin * n_frames;
                        let dst_start = bin * N_FRAMES;
                        padded[dst_start..dst_start + n_frames]
                            .copy_from_slice(&mel[src_start..src_start + n_frames]);
                    }
                    padded
                } else if n_frames > N_FRAMES {
                    let mut truncated = vec![0.0f32; n_mels * N_FRAMES];
                    for bin in 0..n_mels {
                        let src_start = bin * n_frames;
                        let dst_start = bin * N_FRAMES;
                        truncated[dst_start..dst_start + N_FRAMES]
                            .copy_from_slice(&mel[src_start..src_start + N_FRAMES]);
                    }
                    truncated
                } else {
                    mel
                };

                let mel =
                    Tensor::from_vec(mel, (1, n_mels, N_FRAMES), &self.device).map_err(|e| {
                        WhisperError::InferenceError {
                            cause: format!("Failed to create mel tensor: {}", e),
                        }
                    })?;

                if chunk_idx == 0 {
                    let mel_shape = mel.shape();
                    tracing::info!(
                        mel_shape = ?mel_shape,
                        "Mel tensor shape"
                    );
                }

                // One encoder pass per chunk; decoding attends to its output.
                let audio_features = match &mut self.model {
                    Model::Normal(m) => m.encoder.forward(&mel, true),
                    Model::Quantized(m) => m.encoder.forward(&mel, true),
                }
                .map_err(|e| WhisperError::InferenceError {
                    cause: format!("Encoder forward failed: {}", e),
                })?;

                if chunk_idx == 0 {
                    let af_shape = audio_features.shape();
                    tracing::info!(
                        audio_features_shape = ?af_shape,
                        "Audio features from encoder"
                    );
                }

                let sot_token = self.token_id(m::SOT_TOKEN)?;
                let transcribe_token = self.token_id(m::TRANSCRIBE_TOKEN)?;
                let eot_token = self.token_id(m::EOT_TOKEN)?;
                let no_timestamps_token = self.token_id(m::NO_TIMESTAMPS_TOKEN)?;

                if chunk_idx == 0 {
                    let en_token = self.tokenizer.token_to_id("<|en|>");
                    tracing::info!(
                        sot = sot_token,
                        transcribe = transcribe_token,
                        eot = eot_token,
                        no_timestamps = no_timestamps_token,
                        en_token = ?en_token,
                        "Special tokens"
                    );
                }

                // English-only checkpoints (vocab size 51864) have no
                // language tokens, so the prompt omits `<|en|>`.
                let has_language_token = self.tokenizer.token_to_id("<|en|>").is_some();

                let is_english_only = self.config.vocab_size == 51864;

                let tokens = if is_english_only {
                    vec![sot_token, transcribe_token, no_timestamps_token]
                } else if has_language_token {
                    let language_token = self.token_id("<|en|>")?;
                    vec![
                        sot_token,
                        language_token,
                        transcribe_token,
                        no_timestamps_token,
                    ]
                } else {
                    vec![sot_token, transcribe_token, no_timestamps_token]
                };

                if chunk_idx == 0 {
                    tracing::info!(
                        is_english_only = is_english_only,
                        vocab_size = self.config.vocab_size,
                        prompt_tokens = ?tokens,
                        "Initial prompt"
                    );
                }
                let mut all_tokens = tokens.clone();

                // Cap generation well below the decoder's positional limit.
                let sample_len = self.config.max_target_positions / 2;
                let mut repeat_count = 0;
                let mut last_token: Option<u32> = None;

                let suppress_tokens = &self.config.suppress_tokens;

                for i in 0..sample_len {
                    // Re-run the decoder over every token generated so far;
                    // input shape is (1, seq_len).
                    let tokens_tensor = Tensor::new(all_tokens.as_slice(), &self.device)
                        .and_then(|t| t.unsqueeze(0))
                        .map_err(|e| WhisperError::InferenceError {
                            cause: format!("Failed to create tokens tensor: {}", e),
                        })?;

                    if chunk_idx == 0 && i < 3 {
                        tracing::info!(
                            step = i,
                            all_tokens_len = all_tokens.len(),
                            tokens_shape = ?tokens_tensor.shape(),
                            "Decoder input"
                        );
                    }

                    let logits = match &mut self.model {
                        Model::Normal(m) => {
                            let hidden = m
                                .decoder
                                .forward(&tokens_tensor, &audio_features, true)
                                .map_err(|e| WhisperError::InferenceError {
                                    cause: format!("Decoder forward failed: {}", e),
                                })?;
                            m.decoder.final_linear(&hidden).map_err(|e| {
                                WhisperError::InferenceError {
                                    cause: format!("Final linear failed: {}", e),
                                }
                            })?
                        }
                        Model::Quantized(m) => {
                            let hidden = m
                                .decoder
                                .forward(&tokens_tensor, &audio_features, true)
                                .map_err(|e| WhisperError::InferenceError {
                                    cause: format!("Decoder forward failed: {}", e),
                                })?;
                            m.decoder.final_linear(&hidden).map_err(|e| {
                                WhisperError::InferenceError {
                                    cause: format!("Final linear failed: {}", e),
                                }
                            })?
                        }
                    };

                    if chunk_idx == 0 && i == 0 {
                        tracing::info!(
                            logits_shape = ?logits.shape(),
                            "Decoder output logits"
                        );
                    }

                    // Keep only the logits for the final position.
                    let (_, seq_len, _) =
                        logits.dims3().map_err(|e| WhisperError::InferenceError {
                            cause: format!("Failed to get logits dims: {}", e),
                        })?;
                    let mut logits_vec = logits
                        .i((0, seq_len - 1, ..))
                        .and_then(|t| t.to_vec1::<f32>())
                        .map_err(|e| WhisperError::InferenceError {
                            cause: format!("Failed to extract logits: {}", e),
                        })?;

                    // Mask out tokens the model config marks as suppressed.
                    for &token_id in suppress_tokens.iter() {
                        if (token_id as usize) < logits_vec.len() {
                            logits_vec[token_id as usize] = f32::NEG_INFINITY;
                        }
                    }

                    // Forbid EOT until a few tokens have been produced.
                    if all_tokens.len() < 10 {
                        logits_vec[eot_token as usize] = f32::NEG_INFINITY;
                    }

                    // Suppress the prompt tokens and the whole special-token
                    // id range (50257 and up) so only text tokens remain.
                    logits_vec[sot_token as usize] = f32::NEG_INFINITY;
                    logits_vec[transcribe_token as usize] = f32::NEG_INFINITY;
                    logits_vec[no_timestamps_token as usize] = f32::NEG_INFINITY;
                    for token_id in 50257..logits_vec.len() {
                        logits_vec[token_id] = f32::NEG_INFINITY;
                    }

                    if chunk_idx == 0 && i == 0 {
                        tracing::info!(
                            suppress_count = suppress_tokens.len(),
                            eot_suppressed = all_tokens.len() < 10,
                            "Applied token suppression"
                        );
                    }

                    // Greedy decoding: argmax over the masked logits.
                    let next_token = logits_vec
                        .iter()
                        .enumerate()
                        .max_by(|(_, a), (_, b)| {
                            a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal)
                        })
                        .map(|(idx, _)| idx as u32)
                        .unwrap_or(eot_token);

                    if chunk_idx == 0 && i < 5 {
                        let max_logit =
                            logits_vec.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
                        let min_logit = logits_vec.iter().cloned().fold(f32::INFINITY, f32::min);
                        tracing::info!(
                            step = i,
                            next_token = next_token,
                            max_logit = max_logit,
                            min_logit = min_logit,
                            "Decoding step"
                        );
                    }

                    if next_token == eot_token || next_token >= self.config.vocab_size as u32 {
                        if chunk_idx == 0 && i < 5 {
                            tracing::info!(
                                next_token = next_token,
                                eot = eot_token,
                                "Stopping: EOT or invalid token"
                            );
                        }
                        break;
                    }

                    // Bail out if the same token repeats more than three
                    // times in a row, a common greedy-decoding failure mode.
                    if Some(next_token) == last_token {
                        repeat_count += 1;
                        if repeat_count > 3 {
                            tracing::debug!("Breaking due to token repetition");
                            break;
                        }
                    } else {
                        repeat_count = 0;
                    }
                    last_token = Some(next_token);

                    all_tokens.push(next_token);
                }

                // Skip the prompt when decoding back to text. Using the real
                // prompt length (rather than 3 or 4 by language) stays correct
                // even if a multilingual tokenizer lacks `<|en|>`.
                let prompt_len = tokens.len();

                if chunk_idx == 0 {
                    tracing::info!(
                        prompt_tokens = ?&all_tokens[..prompt_len],
                        generated_tokens = ?&all_tokens[prompt_len..],
                        total = all_tokens.len(),
                        "Generated tokens for chunk"
                    );
                }

                let chunk_text = self
                    .tokenizer
                    // `true` skips special tokens in the decoded string.
                    .decode(&all_tokens[prompt_len..], true)
                    .map_err(|e| WhisperError::InferenceError {
                        cause: format!("Failed to decode tokens: {}", e),
                    })?;

                let trimmed_text = chunk_text.trim();
                if !trimmed_text.is_empty() {
                    if !all_text.is_empty() {
                        all_text.push(' ');
                    }
                    all_text.push_str(trimmed_text);

                    segments.push(TranscriptionSegment {
                        start: start_time,
                        end: end_time,
                        text: trimmed_text.to_string(),
                    });
                }
            }

            Ok(TranscriptionResult {
                text: all_text.trim().to_string(),
                language: "en".to_string(),
                duration_secs,
                segments,
            })
        }

        /// Resolves a special-token string to its id, erroring if absent.
        fn token_id(&self, token: &str) -> Result<u32> {
            self.tokenizer.token_to_id(token).ok_or_else(|| {
                WhisperError::InferenceError {
                    cause: format!("Token '{}' not found in vocabulary", token),
                }
                .into()
            })
        }
    }

    /// Scans forward in `window_size` windows and returns the sample index
    /// where the window RMS first exceeds `threshold`.
    fn find_speech_start(samples: &[f32], threshold: f32, window_size: usize) -> usize {
        for i in (0..samples.len()).step_by(window_size) {
            let end = (i + window_size).min(samples.len());
            let window = &samples[i..end];
            let rms = (window.iter().map(|x| x * x).sum::<f32>() / window.len() as f32).sqrt();
            if rms > threshold {
                // Back off one window so the onset is not clipped.
                return i.saturating_sub(window_size);
            }
        }
        // All silence: keep everything.
        0
    }

    /// Scans backward in `window_size` windows and returns the exclusive end
    /// index just past the last window whose RMS exceeds `threshold`.
    fn find_speech_end(samples: &[f32], threshold: f32, window_size: usize) -> usize {
        for i in (0..samples.len()).rev().step_by(window_size) {
            let start = i.saturating_sub(window_size);
            let window = &samples[start..=i.min(samples.len() - 1)];
            let rms = (window.iter().map(|x| x * x).sum::<f32>() / window.len() as f32).sqrt();
            if rms > threshold {
                // Pad one window past the detected end, clamped to the buffer.
                return (i + window_size).min(samples.len());
            }
        }
        // All silence: keep everything.
        samples.len()
    }
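
    // A small check of the RMS silence gate on a synthetic signal:
    // 0.2 s of silence, 0.1 s of 0.5-amplitude "speech", then 0.1 s of
    // silence at 16 kHz, gated with the same 0.01 threshold and
    // 1600-sample windows used by `transcribe_pcm`.
    #[cfg(test)]
    mod trim_tests {
        use super::{find_speech_end, find_speech_start};

        #[test]
        fn trims_leading_and_trailing_silence() {
            let mut samples = vec![0.0f32; 3200];
            samples.extend(std::iter::repeat(0.5f32).take(1600));
            samples.extend(std::iter::repeat(0.0f32).take(1600));

            // Speech occupies 3200..4800; the gate backs off one window.
            assert_eq!(find_speech_start(&samples, 0.01, 1600), 1600);
            // The padded end is clamped to the buffer length.
            assert_eq!(find_speech_end(&samples, 0.01, 1600), samples.len());
        }
    }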
}

#[cfg(feature = "whisper")]
pub use inference::WhisperTranscriber;

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn whisper_model_registry() {
        let default = default_whisper_model_info();
        assert_eq!(default.name, "whisper-small-en");
        assert!(default.is_default);
        assert_eq!(default.language, "en");

        // Unknown names fall back to the default entry.
        let unknown = get_whisper_model_info("nonexistent");
        assert_eq!(unknown.name, "whisper-small-en");
    }

    #[test]
    fn whisper_config_defaults() {
        let config = WhisperConfig::default();
        assert_eq!(config.model_name, "whisper-small-en");
    }
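
    // Extra coverage for the convenience constructors and lookup by
    // Hugging Face id; every expected value comes from WHISPER_MODELS.
    #[test]
    fn whisper_config_constructors() {
        assert_eq!(
            WhisperConfig::with_quantization().model_name,
            "whisper-tiny-en-q8k"
        );
        assert_eq!(
            WhisperConfig::multilingual_quantized().model_name,
            "whisper-tiny-q8k"
        );
        assert_eq!(WhisperConfig::tiny().model_name, "whisper-tiny-en");
        assert_eq!(WhisperConfig::with_model("custom").model_name, "custom");
    }

    #[test]
    fn registry_lookup_by_model_id() {
        let by_id = get_whisper_model_info("openai/whisper-small");
        assert_eq!(by_id.name, "whisper-small");
        assert_eq!(by_id.language, "multilingual");
        assert_eq!(QuantizationType::default(), QuantizationType::FP32);
    }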
}