1use serde::{Deserialize, Serialize};
10use std::path::{Path, PathBuf};
11
12use crate::{MemvidError, Result};
13
/// Static metadata describing one Whisper model known to this crate.
#[derive(Debug, Clone)]
pub struct WhisperModelInfo {
    /// Hugging Face Hub repository id (e.g. "openai/whisper-small.en").
    pub model_id: &'static str,
    /// Short local name used in configuration (e.g. "whisper-small-en").
    pub name: &'static str,
    /// Approximate model download size in megabytes.
    pub size_mb: f32,
    /// Whether this entry is the fallback/default model.
    pub is_default: bool,
    /// Language coverage tag: "en" or "multilingual".
    pub language: &'static str,
}
32
/// Registry of Whisper models this crate knows how to resolve.
///
/// Exactly one entry carries `is_default: true`; both
/// `get_whisper_model_info` and `default_whisper_model_info` rely on that
/// invariant for their fallback lookups.
pub static WHISPER_MODELS: &[WhisperModelInfo] = &[
    WhisperModelInfo {
        model_id: "openai/whisper-small.en",
        name: "whisper-small-en",
        size_mb: 244.0,
        is_default: true,
        language: "en",
    },
    WhisperModelInfo {
        model_id: "openai/whisper-small",
        name: "whisper-small",
        size_mb: 244.0,
        is_default: false,
        language: "multilingual",
    },
];
50
51pub fn get_whisper_model_info(name: &str) -> &'static WhisperModelInfo {
53 WHISPER_MODELS
54 .iter()
55 .find(|m| m.name == name || m.model_id == name)
56 .unwrap_or_else(|| {
57 WHISPER_MODELS
58 .iter()
59 .find(|m| m.is_default)
60 .expect("default whisper model")
61 })
62}
63
64pub fn default_whisper_model_info() -> &'static WhisperModelInfo {
66 WHISPER_MODELS
67 .iter()
68 .find(|m| m.is_default)
69 .expect("default whisper model exists")
70}
71
/// Runtime configuration for the Whisper transcriber.
#[derive(Debug, Clone)]
pub struct WhisperConfig {
    /// Short model name (e.g. "whisper-small-en") or a full HF repo id.
    pub model_name: String,
    /// Directory intended for cached model files.
    /// NOTE(review): not read by `WhisperTranscriber::new`, which goes through
    /// the hf-hub default cache — confirm whether that is intended.
    pub models_dir: PathBuf,
    /// When true, network access should be avoided.
    /// NOTE(review): also not consulted by the model loader today.
    pub offline: bool,
}
86
87impl Default for WhisperConfig {
88 fn default() -> Self {
89 let models_dir = std::env::var("MEMVID_MODELS_DIR")
90 .ok()
91 .map(PathBuf::from)
92 .or_else(|| dirs_next::home_dir().map(|d| d.join(".memvid/models")))
93 .unwrap_or_else(|| PathBuf::from(".memvid/models"));
94
95 let model_name = std::env::var("MEMVID_WHISPER_MODEL")
96 .unwrap_or_else(|_| "whisper-small-en".to_string());
97
98 let offline = std::env::var("MEMVID_OFFLINE").is_ok();
99
100 Self {
101 model_name,
102 models_dir,
103 offline,
104 }
105 }
106}
107
/// Errors raised by the Whisper transcription pipeline.
///
/// All variants are converted into `MemvidError::ExtractionFailed` by the
/// `From` impl in this module, so callers outside see a single error kind.
#[derive(Debug, thiserror::Error)]
pub enum WhisperError {
    /// The requested model name/id is unknown; `hint` is appended to the
    /// rendered message to suggest a remedy.
    #[error("Whisper model '{model}' not found. {hint}")]
    ModelNotFound { model: String, hint: String },

    /// An audio file on disk could not be opened, probed, or decoded.
    #[error("Failed to decode audio at {path:?}: {cause}")]
    AudioDecodeError { path: PathBuf, cause: String },

    /// An in-memory audio buffer could not be decoded.
    #[error("Failed to decode audio bytes: {cause}")]
    AudioBytesDecodeError { cause: String },

    /// Model loading, tensor construction, or a forward pass failed.
    #[error("Whisper inference error: {cause}")]
    InferenceError { cause: String },

    /// Fetching model files from the Hugging Face Hub failed.
    #[error("Failed to download Whisper model: {cause}")]
    DownloadError { cause: String },
}
135
/// Any Whisper failure surfaces to crate callers as an extraction failure,
/// carrying the rendered error text as the boxed reason string.
impl From<WhisperError> for MemvidError {
    fn from(err: WhisperError) -> Self {
        MemvidError::ExtractionFailed {
            reason: err.to_string().into_boxed_str(),
        }
    }
}
143
/// Complete output of one transcription run.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TranscriptionResult {
    /// Full transcript text (chunk texts joined with single spaces).
    pub text: String,
    /// Language code; the current pipeline always reports "en".
    pub language: String,
    /// Duration of the source audio in seconds (before silence trimming).
    pub duration_secs: f32,
    /// Per-chunk timed segments; defaults to empty when absent in
    /// previously serialized data.
    #[serde(default)]
    pub segments: Vec<TranscriptionSegment>,
}
161
/// One timed span of transcript text.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TranscriptionSegment {
    /// Start time in seconds, relative to the original audio.
    pub start: f32,
    /// End time in seconds, relative to the original audio.
    pub end: f32,
    /// Text transcribed for this span.
    pub text: String,
}
172
173#[cfg(feature = "whisper")]
178mod audio {
179 use super::*;
180 use std::fs::File;
181 use symphonia::core::audio::SampleBuffer;
182 use symphonia::core::codecs::DecoderOptions;
183 use symphonia::core::formats::FormatOptions;
184 use symphonia::core::io::MediaSourceStream;
185 use symphonia::core::meta::MetadataOptions;
186 use symphonia::core::probe::Hint;
187
188 pub const WHISPER_SAMPLE_RATE: u32 = 16000;
190
    /// Decodes an audio file into mono f32 PCM at 16 kHz.
    ///
    /// Returns `(samples, duration_secs)`, where `duration_secs` is derived
    /// from the sample count at the *source* rate (so it reflects the real
    /// audio duration, independent of the resampling below).
    ///
    /// # Errors
    /// Returns `WhisperError::AudioDecodeError` when the file cannot be
    /// opened, probed, or a decoder cannot be created. Individual demux/decode
    /// errors mid-stream are tolerated: bad packets are skipped and non-EOF
    /// demux errors simply end the stream (best-effort decoding).
    pub fn decode_audio_file(path: &Path) -> Result<(Vec<f32>, f32)> {
        let file = File::open(path).map_err(|e| WhisperError::AudioDecodeError {
            path: path.to_path_buf(),
            cause: e.to_string(),
        })?;

        let mss = MediaSourceStream::new(Box::new(file), Default::default());

        // Give the probe a container hint from the file extension, if any.
        let mut hint = Hint::new();
        if let Some(ext) = path.extension().and_then(|e| e.to_str()) {
            hint.with_extension(ext);
        }

        let format_opts = FormatOptions::default();
        let metadata_opts = MetadataOptions::default();
        let probed = symphonia::default::get_probe()
            .format(&hint, mss, &format_opts, &metadata_opts)
            .map_err(|e| WhisperError::AudioDecodeError {
                path: path.to_path_buf(),
                cause: format!("Failed to probe audio format: {}", e),
            })?;

        let mut format = probed.format;

        // Pick the first track with a real codec (skips NULL/data tracks).
        let track = format
            .tracks()
            .iter()
            .find(|t| t.codec_params.codec != symphonia::core::codecs::CODEC_TYPE_NULL)
            .ok_or_else(|| WhisperError::AudioDecodeError {
                path: path.to_path_buf(),
                cause: "No audio track found".to_string(),
            })?;

        let track_id = track.id;
        // Fall back to CD-style defaults when the container omits these.
        let sample_rate = track.codec_params.sample_rate.unwrap_or(44100);
        let channels = track.codec_params.channels.map(|c| c.count()).unwrap_or(2);

        let decoder_opts = DecoderOptions::default();
        let mut decoder = symphonia::default::get_codecs()
            .make(&track.codec_params, &decoder_opts)
            .map_err(|e| WhisperError::AudioDecodeError {
                path: path.to_path_buf(),
                cause: format!("Failed to create decoder: {}", e),
            })?;

        let mut samples: Vec<f32> = Vec::new();

        // Demux/decode loop: EOF (and any other demux error) ends the stream;
        // packets for other tracks and packets that fail to decode are skipped.
        loop {
            let packet = match format.next_packet() {
                Ok(p) => p,
                Err(symphonia::core::errors::Error::IoError(e))
                    if e.kind() == std::io::ErrorKind::UnexpectedEof =>
                {
                    break;
                }
                Err(_) => break,
            };

            if packet.track_id() != track_id {
                continue;
            }

            let decoded = match decoder.decode(&packet) {
                Ok(d) => d,
                Err(_) => continue,
            };

            let spec = *decoded.spec();
            let num_frames = decoded.frames();

            if num_frames == 0 {
                continue;
            }

            // Copy the decoded frames out as interleaved f32.
            let mut sample_buf = SampleBuffer::<f32>::new(num_frames as u64, spec);
            sample_buf.copy_interleaved_ref(decoded);

            let interleaved = sample_buf.samples();

            // Downmix to mono by averaging channels.
            // NOTE(review): the channel count comes from the track header, not
            // from each packet's `spec` — assumes a constant channel layout
            // across the stream; confirm for exotic inputs.
            if channels > 1 {
                for chunk in interleaved.chunks(channels) {
                    let mono: f32 = chunk.iter().sum::<f32>() / channels as f32;
                    samples.push(mono);
                }
            } else {
                samples.extend_from_slice(interleaved);
            }
        }

        // Duration from the source-rate sample count (before resampling).
        let duration_secs = samples.len() as f32 / sample_rate as f32;

        // Diagnostics ahead of resampling. NOTE(review): when no samples were
        // decoded, pre_rms is NaN (0/0) and min/max stay at ±infinity.
        let pre_min = samples.iter().cloned().fold(f32::INFINITY, f32::min);
        let pre_max = samples.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
        let pre_rms = (samples.iter().map(|x| x * x).sum::<f32>() / samples.len() as f32).sqrt();
        tracing::info!(
            sample_rate = sample_rate,
            channels = channels,
            samples_before = samples.len(),
            pre_min = pre_min,
            pre_max = pre_max,
            pre_rms = pre_rms,
            "Audio before resampling"
        );

        // Whisper requires 16 kHz input; resample only when needed.
        let samples = if sample_rate != WHISPER_SAMPLE_RATE {
            let resampled = resample_sinc(&samples, sample_rate, WHISPER_SAMPLE_RATE);

            let post_min = resampled.iter().cloned().fold(f32::INFINITY, f32::min);
            let post_max = resampled.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
            let post_rms = (resampled.iter().map(|x| x * x).sum::<f32>() / resampled.len() as f32).sqrt();
            tracing::info!(
                samples_after = resampled.len(),
                post_min = post_min,
                post_max = post_max,
                post_rms = post_rms,
                "Audio after resampling"
            );
            resampled
        } else {
            tracing::info!("Audio already at 16kHz, no resampling needed");
            samples
        };

        Ok((samples, duration_secs))
    }
326
327 fn resample_sinc(samples: &[f32], from_rate: u32, to_rate: u32) -> Vec<f32> {
329 use rubato::{FftFixedIn, Resampler};
330
331 if from_rate == to_rate {
332 return samples.to_vec();
333 }
334
335 let chunk_size = 1024;
337 let mut resampler = FftFixedIn::<f32>::new(
338 from_rate as usize,
339 to_rate as usize,
340 chunk_size,
341 2, 1, )
344 .expect("Failed to create resampler");
345
346 let mut output = Vec::new();
347 let mut pos = 0;
348
349 while pos < samples.len() {
351 let end = (pos + chunk_size).min(samples.len());
352 let chunk = &samples[pos..end];
353
354 let input_chunk: Vec<f32> = if chunk.len() < chunk_size {
356 let mut padded = chunk.to_vec();
357 padded.resize(chunk_size, 0.0);
358 padded
359 } else {
360 chunk.to_vec()
361 };
362
363 let input = vec![input_chunk];
364 let resampled = resampler.process(&input, None).expect("Resampling failed");
365
366 if !resampled.is_empty() && !resampled[0].is_empty() {
367 output.extend_from_slice(&resampled[0]);
368 }
369
370 pos += chunk_size;
371 }
372
373 let expected_len = (samples.len() as f64 * to_rate as f64 / from_rate as f64) as usize;
375 output.truncate(expected_len);
376
377 output
378 }
379}
380
381#[cfg(feature = "whisper")]
382pub use audio::*;
383
384#[cfg(feature = "whisper")]
389mod inference {
390 use super::*;
391 use candle_core::{DType, Device, IndexOp, Tensor};
392 use candle_nn::VarBuilder;
393 use candle_transformers::models::whisper::{self as m, audio, Config};
394 use hf_hub::{api::sync::Api, Repo, RepoType};
395 use tokenizers::Tokenizer;
396
    /// A loaded Whisper model plus everything needed to run inference.
    pub struct WhisperTranscriber {
        /// Model weights (full-precision or quantized).
        model: Model,
        /// BPE tokenizer matching the model's vocabulary.
        tokenizer: Tokenizer,
        /// Parsed model config (mel bins, vocab size, position limits, ...).
        config: Config,
        /// Flattened mel filterbank used to build spectrograms.
        mel_filters: Vec<f32>,
        /// Compute device (CPU, or GPU when built with metal/cuda features).
        device: Device,
    }
405
    /// Weight-storage variants. Only `Normal` is constructed by `new` today;
    /// the quantized arm is kept (hence `allow(dead_code)`) for future use.
    #[allow(dead_code)]
    enum Model {
        Normal(m::model::Whisper),
        Quantized(m::quantized_model::Whisper),
    }
411
412 impl WhisperTranscriber {
        /// Creates a transcriber: resolves `config.model_name` to a Hugging
        /// Face repo id, fetches (or reuses from the hf-hub cache) the config,
        /// tokenizer and safetensors weights, and loads the model onto the
        /// best available device.
        ///
        /// NOTE(review): `config.models_dir` and `config.offline` are not
        /// consulted here — downloads go through the hf-hub default cache and
        /// may hit the network regardless; confirm whether that is intended.
        ///
        /// # Errors
        /// `WhisperError::DownloadError` when the hub API or a file fetch
        /// fails; `WhisperError::InferenceError` when the config, tokenizer or
        /// weights cannot be read, parsed, or loaded, or when the model uses
        /// an unsupported mel-bin count.
        pub fn new(config: &WhisperConfig) -> Result<Self> {
            let device = Self::select_device();
            tracing::info!(device = ?device, "Using device for Whisper");
            // Map short local names to hub repo ids; anything unrecognized is
            // assumed to already be a repo id and passed through unchanged.
            let model_id = match config.model_name.as_str() {
                "whisper-small-en" => "openai/whisper-small.en",
                "whisper-small" => "openai/whisper-small",
                "whisper-tiny.en" => "openai/whisper-tiny.en",
                "whisper-tiny" => "openai/whisper-tiny",
                "whisper-base.en" => "openai/whisper-base.en",
                "whisper-base" => "openai/whisper-base",
                "whisper-medium.en" => "openai/whisper-medium.en",
                "whisper-medium" => "openai/whisper-medium",
                "whisper-large-v3" => "openai/whisper-large-v3",
                other => other,
            };

            tracing::info!(model_id = model_id, "Loading Whisper model");

            let api = Api::new().map_err(|e| WhisperError::DownloadError {
                cause: e.to_string(),
            })?;
            let repo = api.repo(Repo::with_revision(
                model_id.to_string(),
                RepoType::Model,
                "main".to_string(),
            ));

            // The three files Candle needs: config, tokenizer, weights.
            let config_path = repo.get("config.json").map_err(|e| WhisperError::DownloadError {
                cause: format!("Failed to download config.json: {}", e),
            })?;
            let tokenizer_path = repo.get("tokenizer.json").map_err(|e| WhisperError::DownloadError {
                cause: format!("Failed to download tokenizer.json: {}", e),
            })?;
            let model_path = repo.get("model.safetensors").map_err(|e| WhisperError::DownloadError {
                cause: format!("Failed to download model.safetensors: {}", e),
            })?;

            let config_str = std::fs::read_to_string(&config_path).map_err(|e| WhisperError::InferenceError {
                cause: format!("Failed to read config: {}", e),
            })?;
            let model_config: Config = serde_json::from_str(&config_str).map_err(|e| WhisperError::InferenceError {
                cause: format!("Failed to parse config: {}", e),
            })?;

            let tokenizer = Tokenizer::from_file(&tokenizer_path)
                .map_err(|e| WhisperError::InferenceError {
                    cause: format!("Failed to load tokenizer: {}", e),
                })?;

            // Precomputed mel filterbanks shipped with the crate, selected by
            // the bin count the model config declares (80 or 128).
            let mel_bytes = match model_config.num_mel_bins {
                80 => include_bytes!("melfilters.bytes").as_slice(),
                128 => include_bytes!("melfilters128.bytes").as_slice(),
                n => return Err(WhisperError::InferenceError {
                    cause: format!("Unsupported number of mel bins: {}", n),
                }.into()),
            };
            // The filter files are raw little-endian f32 values.
            let mut mel_filters = vec![0f32; mel_bytes.len() / 4];
            <byteorder::LittleEndian as byteorder::ByteOrder>::read_f32_into(mel_bytes, &mut mel_filters);

            // SAFETY: memory-maps the safetensors file; sound as long as the
            // file is not truncated or rewritten while mapped.
            let vb = unsafe {
                VarBuilder::from_mmaped_safetensors(&[model_path], DType::F32, &device)
                    .map_err(|e| WhisperError::InferenceError {
                        cause: format!("Failed to load model weights: {}", e),
                    })?
            };
            let model = Model::Normal(m::model::Whisper::load(&vb, model_config.clone())
                .map_err(|e| WhisperError::InferenceError {
                    cause: format!("Failed to load Whisper model: {}", e),
                })?);

            tracing::info!("Whisper model loaded successfully");

            Ok(Self {
                model,
                tokenizer,
                config: model_config,
                mel_filters,
                device,
            })
        }
500
        /// Picks the best available compute device: Metal first, then CUDA
        /// (each only when the corresponding cargo feature is enabled *and*
        /// device 0 can actually be created), falling back to the CPU.
        fn select_device() -> Device {
            #[cfg(feature = "metal")]
            {
                if let Ok(device) = Device::new_metal(0) {
                    tracing::info!("Metal GPU available");
                    return device;
                }
            }

            #[cfg(feature = "cuda")]
            {
                if let Ok(device) = Device::new_cuda(0) {
                    tracing::info!("CUDA GPU available");
                    return device;
                }
            }

            tracing::info!("Using CPU (no GPU acceleration)");
            Device::Cpu
        }
525
526 pub fn transcribe_file(&mut self, path: &Path) -> Result<TranscriptionResult> {
528 let (pcm_data, duration_secs) = super::decode_audio_file(path)?;
530
531 let audio_min = pcm_data.iter().cloned().fold(f32::INFINITY, f32::min);
533 let audio_max = pcm_data.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
534 let audio_mean = pcm_data.iter().sum::<f32>() / pcm_data.len() as f32;
535 let audio_rms = (pcm_data.iter().map(|x| x * x).sum::<f32>() / pcm_data.len() as f32).sqrt();
536
537 tracing::info!(
538 duration = duration_secs,
539 samples = pcm_data.len(),
540 min = audio_min,
541 max = audio_max,
542 mean = audio_mean,
543 rms = audio_rms,
544 "Audio decoded"
545 );
546
547 self.transcribe_pcm(&pcm_data, duration_secs)
548 }
549
550 pub fn transcribe_pcm(&mut self, pcm_data: &[f32], duration_secs: f32) -> Result<TranscriptionResult> {
552 const CHUNK_LENGTH: usize = 30 * 16000; const N_FRAMES: usize = 3000; const SAMPLE_RATE: f32 = 16000.0;
556
557 let silence_threshold = 0.01; let window_size = 1600; let start_sample = find_speech_start(pcm_data, silence_threshold, window_size);
562 let end_sample = find_speech_end(pcm_data, silence_threshold, window_size);
563
564 let trimmed_start = start_sample as f32 / SAMPLE_RATE;
565 let trimmed_end = end_sample as f32 / SAMPLE_RATE;
566
567 tracing::info!(
568 start_sample = start_sample,
569 end_sample = end_sample,
570 trimmed_start_sec = trimmed_start,
571 trimmed_end_sec = trimmed_end,
572 original_duration = duration_secs,
573 "Trimmed silence"
574 );
575
576 let pcm_data = &pcm_data[start_sample..end_sample];
578 let _trimmed_duration = pcm_data.len() as f32 / SAMPLE_RATE;
579
580 let mut all_text = String::new();
581 let mut segments = Vec::new();
582
583 let num_chunks = (pcm_data.len() + CHUNK_LENGTH - 1) / CHUNK_LENGTH;
585
586 for chunk_idx in 0..num_chunks {
587 let chunk_start = chunk_idx * CHUNK_LENGTH;
588 let chunk_end = (chunk_start + CHUNK_LENGTH).min(pcm_data.len());
589 let chunk = &pcm_data[chunk_start..chunk_end];
590
591 let start_time = trimmed_start + chunk_start as f32 / SAMPLE_RATE;
593 let end_time = trimmed_start + chunk_end as f32 / SAMPLE_RATE;
594
595 tracing::info!(
596 chunk = chunk_idx + 1,
597 total = num_chunks,
598 start = start_time,
599 end = end_time,
600 "Processing chunk"
601 );
602
603 match &mut self.model {
605 Model::Normal(m) => m.decoder.reset_kv_cache(),
606 Model::Quantized(m) => m.decoder.reset_kv_cache(),
607 }
608
609 let mel = audio::pcm_to_mel(&self.config, chunk, &self.mel_filters);
611 let n_mels = self.config.num_mel_bins;
612 let mel_len = mel.len();
613 let n_frames = mel_len / n_mels;
614
615 if chunk_idx == 0 {
616 tracing::info!(
618 num_mel_bins = self.config.num_mel_bins,
619 max_source_positions = self.config.max_source_positions,
620 max_target_positions = self.config.max_target_positions,
621 "Model config"
622 );
623
624 let mel_min = mel.iter().cloned().fold(f32::INFINITY, f32::min);
626 let mel_max = mel.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
627 let mel_mean = mel.iter().sum::<f32>() / mel.len() as f32;
628
629 tracing::info!(
630 mel_len = mel_len,
631 n_mels = n_mels,
632 n_frames = n_frames,
633 chunk_samples = chunk.len(),
634 expected_frames = 3000,
635 mel_min = mel_min,
636 mel_max = mel_max,
637 mel_mean = mel_mean,
638 "Mel spectrogram computed"
639 );
640 }
641
642 let mel = if n_frames < N_FRAMES {
646 let mut padded = vec![0.0f32; n_mels * N_FRAMES];
648 for bin in 0..n_mels {
649 let src_start = bin * n_frames;
650 let dst_start = bin * N_FRAMES;
651 padded[dst_start..dst_start + n_frames].copy_from_slice(&mel[src_start..src_start + n_frames]);
652 }
653 padded
654 } else if n_frames > N_FRAMES {
655 let mut truncated = vec![0.0f32; n_mels * N_FRAMES];
657 for bin in 0..n_mels {
658 let src_start = bin * n_frames;
659 let dst_start = bin * N_FRAMES;
660 truncated[dst_start..dst_start + N_FRAMES].copy_from_slice(&mel[src_start..src_start + N_FRAMES]);
661 }
662 truncated
663 } else {
664 mel
665 };
666
667 let mel = Tensor::from_vec(
668 mel,
669 (1, n_mels, N_FRAMES),
670 &self.device,
671 ).map_err(|e| WhisperError::InferenceError {
672 cause: format!("Failed to create mel tensor: {}", e),
673 })?;
674
675 if chunk_idx == 0 {
676 let mel_shape = mel.shape();
677 tracing::info!(
678 mel_shape = ?mel_shape,
679 "Mel tensor shape"
680 );
681 }
682
683 let audio_features = match &mut self.model {
685 Model::Normal(m) => m.encoder.forward(&mel, true),
686 Model::Quantized(m) => m.encoder.forward(&mel, true),
687 }.map_err(|e| WhisperError::InferenceError {
688 cause: format!("Encoder forward failed: {}", e),
689 })?;
690
691 if chunk_idx == 0 {
692 let af_shape = audio_features.shape();
693 tracing::info!(
694 audio_features_shape = ?af_shape,
695 "Audio features from encoder"
696 );
697 }
698
699 let sot_token = self.token_id(m::SOT_TOKEN)?;
701 let transcribe_token = self.token_id(m::TRANSCRIBE_TOKEN)?;
702 let eot_token = self.token_id(m::EOT_TOKEN)?;
703 let no_timestamps_token = self.token_id(m::NO_TIMESTAMPS_TOKEN)?;
704
705 if chunk_idx == 0 {
706 let en_token = self.tokenizer.token_to_id("<|en|>");
707 tracing::info!(
708 sot = sot_token,
709 transcribe = transcribe_token,
710 eot = eot_token,
711 no_timestamps = no_timestamps_token,
712 en_token = ?en_token,
713 "Special tokens"
714 );
715 }
716
717 let has_language_token = self.tokenizer.token_to_id("<|en|>").is_some();
721
722 let is_english_only = self.config.vocab_size == 51864;
724
725 let tokens = if is_english_only {
726 vec![sot_token, transcribe_token, no_timestamps_token]
728 } else if has_language_token {
729 let language_token = self.token_id("<|en|>")?;
731 vec![sot_token, language_token, transcribe_token, no_timestamps_token]
732 } else {
733 vec![sot_token, transcribe_token, no_timestamps_token]
735 };
736
737 if chunk_idx == 0 {
738 tracing::info!(
739 is_english_only = is_english_only,
740 vocab_size = self.config.vocab_size,
741 prompt_tokens = ?tokens,
742 "Initial prompt"
743 );
744 }
745 let mut all_tokens = tokens.clone();
746
747 let sample_len = self.config.max_target_positions / 2;
749 let mut repeat_count = 0;
750 let mut last_token: Option<u32> = None;
751
752 let suppress_tokens = &self.config.suppress_tokens;
754
755 for i in 0..sample_len {
756 let tokens_tensor = Tensor::new(all_tokens.as_slice(), &self.device)
760 .and_then(|t| t.unsqueeze(0))
761 .map_err(|e| WhisperError::InferenceError {
762 cause: format!("Failed to create tokens tensor: {}", e),
763 })?;
764
765 if chunk_idx == 0 && i < 3 {
766 tracing::info!(
767 step = i,
768 all_tokens_len = all_tokens.len(),
769 tokens_shape = ?tokens_tensor.shape(),
770 "Decoder input"
771 );
772 }
773
774 let logits = match &mut self.model {
777 Model::Normal(m) => {
778 let hidden = m.decoder.forward(&tokens_tensor, &audio_features, true)
779 .map_err(|e| WhisperError::InferenceError {
780 cause: format!("Decoder forward failed: {}", e),
781 })?;
782 m.decoder.final_linear(&hidden)
783 .map_err(|e| WhisperError::InferenceError {
784 cause: format!("Final linear failed: {}", e),
785 })?
786 }
787 Model::Quantized(m) => {
788 let hidden = m.decoder.forward(&tokens_tensor, &audio_features, true)
789 .map_err(|e| WhisperError::InferenceError {
790 cause: format!("Decoder forward failed: {}", e),
791 })?;
792 m.decoder.final_linear(&hidden)
793 .map_err(|e| WhisperError::InferenceError {
794 cause: format!("Final linear failed: {}", e),
795 })?
796 }
797 };
798
799 if chunk_idx == 0 && i == 0 {
800 tracing::info!(
801 logits_shape = ?logits.shape(),
802 "Decoder output logits"
803 );
804 }
805
806 let (_, seq_len, _) = logits.dims3().map_err(|e| WhisperError::InferenceError {
808 cause: format!("Failed to get logits dims: {}", e),
809 })?;
810 let mut logits_vec = logits.i((0, seq_len - 1, ..))
811 .and_then(|t| t.to_vec1::<f32>())
812 .map_err(|e| WhisperError::InferenceError {
813 cause: format!("Failed to extract logits: {}", e),
814 })?;
815
816 for &token_id in suppress_tokens.iter() {
818 if (token_id as usize) < logits_vec.len() {
819 logits_vec[token_id as usize] = f32::NEG_INFINITY;
820 }
821 }
822
823 if all_tokens.len() < 10 {
825 logits_vec[eot_token as usize] = f32::NEG_INFINITY;
826 }
827
828 logits_vec[sot_token as usize] = f32::NEG_INFINITY;
832 logits_vec[transcribe_token as usize] = f32::NEG_INFINITY;
833 logits_vec[no_timestamps_token as usize] = f32::NEG_INFINITY;
834 for token_id in 50257..logits_vec.len() {
836 logits_vec[token_id] = f32::NEG_INFINITY;
837 }
838
839 if chunk_idx == 0 && i == 0 {
840 tracing::info!(
841 suppress_count = suppress_tokens.len(),
842 eot_suppressed = all_tokens.len() < 10,
843 "Applied token suppression"
844 );
845 }
846
847 let next_token = logits_vec
849 .iter()
850 .enumerate()
851 .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
852 .map(|(idx, _)| idx as u32)
853 .unwrap_or(eot_token);
854
855 if chunk_idx == 0 && i < 5 {
856 let max_logit = logits_vec.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
857 let min_logit = logits_vec.iter().cloned().fold(f32::INFINITY, f32::min);
858 tracing::info!(
859 step = i,
860 next_token = next_token,
861 max_logit = max_logit,
862 min_logit = min_logit,
863 "Decoding step"
864 );
865 }
866
867 if next_token == eot_token || next_token >= self.config.vocab_size as u32 {
868 if chunk_idx == 0 && i < 5 {
869 tracing::info!(next_token = next_token, eot = eot_token, "Stopping: EOT or invalid token");
870 }
871 break;
872 }
873
874 if Some(next_token) == last_token {
876 repeat_count += 1;
877 if repeat_count > 3 {
878 tracing::debug!("Breaking due to token repetition");
879 break;
880 }
881 } else {
882 repeat_count = 0;
883 }
884 last_token = Some(next_token);
885
886 all_tokens.push(next_token);
887 }
888
889 let prompt_len = if is_english_only { 3 } else { 4 };
891
892 if chunk_idx == 0 {
893 tracing::info!(
894 prompt_tokens = ?&all_tokens[..prompt_len],
895 generated_tokens = ?&all_tokens[prompt_len..],
896 total = all_tokens.len(),
897 "Generated tokens for chunk"
898 );
899 }
900
901 let chunk_text = self.tokenizer
902 .decode(&all_tokens[prompt_len..], true) .map_err(|e| WhisperError::InferenceError {
904 cause: format!("Failed to decode tokens: {}", e),
905 })?;
906
907 let trimmed_text = chunk_text.trim();
908 if !trimmed_text.is_empty() {
909 if !all_text.is_empty() {
910 all_text.push(' ');
911 }
912 all_text.push_str(trimmed_text);
913
914 segments.push(TranscriptionSegment {
915 start: start_time,
916 end: end_time,
917 text: trimmed_text.to_string(),
918 });
919 }
920 }
921
922 Ok(TranscriptionResult {
923 text: all_text.trim().to_string(),
924 language: "en".to_string(),
925 duration_secs,
926 segments,
927 })
928 }
929
930 fn token_id(&self, token: &str) -> Result<u32> {
931 self.tokenizer
932 .token_to_id(token)
933 .ok_or_else(|| WhisperError::InferenceError {
934 cause: format!("Token '{}' not found in vocabulary", token),
935 }.into())
936 }
937 }
938
939 fn find_speech_start(samples: &[f32], threshold: f32, window_size: usize) -> usize {
941 for i in (0..samples.len()).step_by(window_size) {
942 let end = (i + window_size).min(samples.len());
943 let window = &samples[i..end];
944 let rms = (window.iter().map(|x| x * x).sum::<f32>() / window.len() as f32).sqrt();
945 if rms > threshold {
946 return i.saturating_sub(window_size);
948 }
949 }
950 0 }
952
953 fn find_speech_end(samples: &[f32], threshold: f32, window_size: usize) -> usize {
955 for i in (0..samples.len()).rev().step_by(window_size) {
956 let start = i.saturating_sub(window_size);
957 let window = &samples[start..=i.min(samples.len() - 1)];
958 let rms = (window.iter().map(|x| x * x).sum::<f32>() / window.len() as f32).sqrt();
959 if rms > threshold {
960 return (i + window_size).min(samples.len());
962 }
963 }
964 samples.len() }
966}
967
968#[cfg(feature = "whisper")]
969pub use inference::WhisperTranscriber;
970
#[cfg(test)]
mod tests {
    use super::*;

    /// The registry exposes the English small model as the default, and
    /// unknown lookup names fall back to it rather than erroring.
    #[test]
    fn whisper_model_registry() {
        let default = default_whisper_model_info();
        assert_eq!(default.name, "whisper-small-en");
        assert!(default.is_default);
        assert_eq!(default.language, "en");

        let unknown = get_whisper_model_info("nonexistent");
        assert_eq!(unknown.name, "whisper-small-en");
    }

    /// Default config picks the English small model.
    /// NOTE(review): assumes MEMVID_WHISPER_MODEL is unset in the test
    /// environment — this test is sensitive to the process environment.
    #[test]
    fn whisper_config_defaults() {
        let config = WhisperConfig::default();
        assert_eq!(config.model_name, "whisper-small-en");
    }
}
996}