1use serde::{Deserialize, Serialize};
10use std::path::PathBuf;
11
12use crate::MemvidError;
13
14#[cfg(feature = "whisper")]
16use std::path::Path;
17#[cfg(feature = "whisper")]
18use crate::Result;
19
/// Metadata describing one Whisper checkpoint known to this crate.
#[derive(Debug, Clone)]
pub struct WhisperModelInfo {
    /// Hugging Face Hub repository id (e.g. "openai/whisper-small.en").
    pub model_id: &'static str,
    /// Short local name used in configuration (e.g. "whisper-small-en").
    pub name: &'static str,
    /// Approximate download size in megabytes.
    pub size_mb: f32,
    /// Whether this entry is the fallback used for unknown model names.
    pub is_default: bool,
    /// "en" for English-only checkpoints, "multilingual" otherwise.
    pub language: &'static str,
}
38
/// Registry of Whisper models this crate knows how to resolve.
/// Exactly one entry must have `is_default: true` — the lookup helpers
/// below `expect` on that invariant.
pub static WHISPER_MODELS: &[WhisperModelInfo] = &[
    WhisperModelInfo {
        model_id: "openai/whisper-small.en",
        name: "whisper-small-en",
        size_mb: 244.0,
        is_default: true,
        language: "en",
    },
    WhisperModelInfo {
        model_id: "openai/whisper-small",
        name: "whisper-small",
        size_mb: 244.0,
        is_default: false,
        language: "multilingual",
    },
];
56
57pub fn get_whisper_model_info(name: &str) -> &'static WhisperModelInfo {
59 WHISPER_MODELS
60 .iter()
61 .find(|m| m.name == name || m.model_id == name)
62 .unwrap_or_else(|| {
63 WHISPER_MODELS
64 .iter()
65 .find(|m| m.is_default)
66 .expect("default whisper model")
67 })
68}
69
70pub fn default_whisper_model_info() -> &'static WhisperModelInfo {
72 WHISPER_MODELS
73 .iter()
74 .find(|m| m.is_default)
75 .expect("default whisper model exists")
76}
77
/// Runtime configuration for the Whisper transcription pipeline.
#[derive(Debug, Clone)]
pub struct WhisperConfig {
    /// Short model name (see `WHISPER_MODELS`) or a raw Hub repo id.
    pub model_name: String,
    /// Directory intended for cached model files.
    pub models_dir: PathBuf,
    /// When true, network downloads should be avoided.
    pub offline: bool,
}
92
93impl Default for WhisperConfig {
94 fn default() -> Self {
95 let models_dir = std::env::var("MEMVID_MODELS_DIR")
96 .ok()
97 .map(PathBuf::from)
98 .or_else(|| dirs_next::home_dir().map(|d| d.join(".memvid/models")))
99 .unwrap_or_else(|| PathBuf::from(".memvid/models"));
100
101 let model_name = std::env::var("MEMVID_WHISPER_MODEL")
102 .unwrap_or_else(|_| "whisper-small-en".to_string());
103
104 let offline = std::env::var("MEMVID_OFFLINE").is_ok();
105
106 Self {
107 model_name,
108 models_dir,
109 offline,
110 }
111 }
112}
113
/// Errors specific to the Whisper transcription pipeline.
///
/// These are collapsed into the crate-wide `MemvidError` by the `From`
/// impl in this file, so callers outside the module see a single variant.
#[derive(Debug, thiserror::Error)]
pub enum WhisperError {
    /// The requested model could not be located; `hint` suggests a remedy.
    #[error("Whisper model '{model}' not found. {hint}")]
    ModelNotFound { model: String, hint: String },

    /// Decoding an audio file from disk failed (open, probe, or codec).
    #[error("Failed to decode audio at {path:?}: {cause}")]
    AudioDecodeError { path: PathBuf, cause: String },

    /// Decoding an in-memory audio buffer failed.
    #[error("Failed to decode audio bytes: {cause}")]
    AudioBytesDecodeError { cause: String },

    /// Model loading, a forward pass, or token decoding failed.
    #[error("Whisper inference error: {cause}")]
    InferenceError { cause: String },

    /// Fetching model artifacts (config/tokenizer/weights) failed.
    #[error("Failed to download Whisper model: {cause}")]
    DownloadError { cause: String },
}
141
impl From<WhisperError> for MemvidError {
    /// Collapse any Whisper error into the crate-wide extraction failure,
    /// preserving the detailed message as the boxed reason string.
    fn from(err: WhisperError) -> Self {
        MemvidError::ExtractionFailed {
            reason: err.to_string().into_boxed_str(),
        }
    }
}
149
/// Output of a transcription run.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TranscriptionResult {
    /// Full transcript (chunk texts joined with spaces, trimmed).
    pub text: String,
    /// Language code; the inference path in this file always emits "en".
    pub language: String,
    /// Duration of the source audio in seconds (before silence trimming).
    pub duration_secs: f32,
    /// Per-chunk segments; defaults to empty when absent in serialized data.
    #[serde(default)]
    pub segments: Vec<TranscriptionSegment>,
}
167
/// A time-stamped span of transcript text.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TranscriptionSegment {
    /// Start time in seconds, relative to the original (untrimmed) audio.
    pub start: f32,
    /// End time in seconds, relative to the original (untrimmed) audio.
    pub end: f32,
    /// Transcribed text for this span.
    pub text: String,
}
178
#[cfg(feature = "whisper")]
mod audio {
    //! Audio decoding and resampling: turn any symphonia-supported audio
    //! file into 16 kHz mono f32 PCM suitable for Whisper.

    use super::*;
    use std::fs::File;
    use symphonia::core::audio::SampleBuffer;
    use symphonia::core::codecs::DecoderOptions;
    use symphonia::core::formats::FormatOptions;
    use symphonia::core::io::MediaSourceStream;
    use symphonia::core::meta::MetadataOptions;
    use symphonia::core::probe::Hint;

    /// Sample rate (Hz) expected by Whisper models.
    pub const WHISPER_SAMPLE_RATE: u32 = 16000;

    /// Decode an audio file into mono f32 PCM at 16 kHz.
    ///
    /// Returns `(samples, duration_secs)`. The duration is computed from the
    /// decoded mono sample count at the *source* sample rate, i.e. it is the
    /// real duration of the file regardless of resampling.
    ///
    /// Multi-channel audio is downmixed to mono by averaging channels;
    /// audio at any other rate is resampled to 16 kHz.
    ///
    /// # Errors
    /// `WhisperError::AudioDecodeError` when the file cannot be opened, the
    /// container cannot be probed, no audio track exists, or no decoder can
    /// be created. Per-packet decode errors are skipped, not propagated.
    pub fn decode_audio_file(path: &Path) -> Result<(Vec<f32>, f32)> {
        let file = File::open(path).map_err(|e| WhisperError::AudioDecodeError {
            path: path.to_path_buf(),
            cause: e.to_string(),
        })?;

        let mss = MediaSourceStream::new(Box::new(file), Default::default());

        // Give the format probe a hint from the file extension, if any.
        let mut hint = Hint::new();
        if let Some(ext) = path.extension().and_then(|e| e.to_str()) {
            hint.with_extension(ext);
        }

        let format_opts = FormatOptions::default();
        let metadata_opts = MetadataOptions::default();
        let probed = symphonia::default::get_probe()
            .format(&hint, mss, &format_opts, &metadata_opts)
            .map_err(|e| WhisperError::AudioDecodeError {
                path: path.to_path_buf(),
                cause: format!("Failed to probe audio format: {}", e),
            })?;

        let mut format = probed.format;

        // Pick the first track that actually carries a codec.
        let track = format
            .tracks()
            .iter()
            .find(|t| t.codec_params.codec != symphonia::core::codecs::CODEC_TYPE_NULL)
            .ok_or_else(|| WhisperError::AudioDecodeError {
                path: path.to_path_buf(),
                cause: "No audio track found".to_string(),
            })?;

        let track_id = track.id;
        // Fall back to CD-style defaults when the container omits them.
        let sample_rate = track.codec_params.sample_rate.unwrap_or(44100);
        let channels = track.codec_params.channels.map(|c| c.count()).unwrap_or(2);

        let decoder_opts = DecoderOptions::default();
        let mut decoder = symphonia::default::get_codecs()
            .make(&track.codec_params, &decoder_opts)
            .map_err(|e| WhisperError::AudioDecodeError {
                path: path.to_path_buf(),
                cause: format!("Failed to create decoder: {}", e),
            })?;

        let mut samples: Vec<f32> = Vec::new();

        // Demux/decode loop: EOF ends the stream; any other packet error
        // also terminates (best-effort), and individual decode failures are
        // skipped so one bad packet does not abort the whole file.
        loop {
            let packet = match format.next_packet() {
                Ok(p) => p,
                Err(symphonia::core::errors::Error::IoError(e))
                    if e.kind() == std::io::ErrorKind::UnexpectedEof =>
                {
                    break;
                }
                Err(_) => break,
            };

            // Ignore packets from other (e.g. metadata/video) tracks.
            if packet.track_id() != track_id {
                continue;
            }

            let decoded = match decoder.decode(&packet) {
                Ok(d) => d,
                Err(_) => continue,
            };

            let spec = *decoded.spec();
            let num_frames = decoded.frames();

            if num_frames == 0 {
                continue;
            }

            // Copy the decoded frame into an interleaved f32 buffer.
            let mut sample_buf = SampleBuffer::<f32>::new(num_frames as u64, spec);
            sample_buf.copy_interleaved_ref(decoded);

            let interleaved = sample_buf.samples();

            // Downmix to mono by averaging the channels of each frame.
            if channels > 1 {
                for chunk in interleaved.chunks(channels) {
                    let mono: f32 = chunk.iter().sum::<f32>() / channels as f32;
                    samples.push(mono);
                }
            } else {
                samples.extend_from_slice(interleaved);
            }
        }

        // Duration at the source rate, before any resampling.
        let duration_secs = samples.len() as f32 / sample_rate as f32;

        // Diagnostic stats only.
        // NOTE(review): when `samples` is empty these fold/divide to
        // Inf/NaN; harmless here since the values are only logged.
        let pre_min = samples.iter().cloned().fold(f32::INFINITY, f32::min);
        let pre_max = samples.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
        let pre_rms = (samples.iter().map(|x| x * x).sum::<f32>() / samples.len() as f32).sqrt();
        tracing::info!(
            sample_rate = sample_rate,
            channels = channels,
            samples_before = samples.len(),
            pre_min = pre_min,
            pre_max = pre_max,
            pre_rms = pre_rms,
            "Audio before resampling"
        );

        // Resample to 16 kHz unless the source already matches.
        let samples = if sample_rate != WHISPER_SAMPLE_RATE {
            let resampled = resample_sinc(&samples, sample_rate, WHISPER_SAMPLE_RATE);

            let post_min = resampled.iter().cloned().fold(f32::INFINITY, f32::min);
            let post_max = resampled.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
            let post_rms = (resampled.iter().map(|x| x * x).sum::<f32>() / resampled.len() as f32).sqrt();
            tracing::info!(
                samples_after = resampled.len(),
                post_min = post_min,
                post_max = post_max,
                post_rms = post_rms,
                "Audio after resampling"
            );
            resampled
        } else {
            tracing::info!("Audio already at 16kHz, no resampling needed");
            samples
        };

        Ok((samples, duration_secs))
    }

    /// Resample `samples` from `from_rate` to `to_rate` using rubato's
    /// FFT-based fixed-input-size resampler.
    ///
    /// The input is processed in fixed 1024-sample chunks (the final chunk
    /// is zero-padded) and the output is truncated to the theoretically
    /// expected length.
    ///
    /// NOTE(review): the resampler's internal filter delay is neither
    /// flushed nor compensated, so the output is shifted by that latency
    /// and the very end of the signal can be lost — presumably acceptable
    /// for speech transcription; verify if exact alignment ever matters.
    fn resample_sinc(samples: &[f32], from_rate: u32, to_rate: u32) -> Vec<f32> {
        use rubato::{FftFixedIn, Resampler};

        if from_rate == to_rate {
            return samples.to_vec();
        }

        let chunk_size = 1024;
        let mut resampler = FftFixedIn::<f32>::new(
            from_rate as usize,
            to_rate as usize,
            chunk_size,
            2, // sub-chunks per process() call
            1, // channel count (input is already mono)
        )
        .expect("Failed to create resampler");

        let mut output = Vec::new();
        let mut pos = 0;

        while pos < samples.len() {
            let end = (pos + chunk_size).min(samples.len());
            let chunk = &samples[pos..end];

            // The FFT resampler requires full chunks; zero-pad the tail.
            let input_chunk: Vec<f32> = if chunk.len() < chunk_size {
                let mut padded = chunk.to_vec();
                padded.resize(chunk_size, 0.0);
                padded
            } else {
                chunk.to_vec()
            };

            let input = vec![input_chunk];
            let resampled = resampler.process(&input, None).expect("Resampling failed");

            if !resampled.is_empty() && !resampled[0].is_empty() {
                output.extend_from_slice(&resampled[0]);
            }

            pos += chunk_size;
        }

        // Drop the tail contributed by zero-padding: keep only the ideal
        // resampled length.
        let expected_len = (samples.len() as f64 * to_rate as f64 / from_rate as f64) as usize;
        output.truncate(expected_len);

        output
    }
}
386
387#[cfg(feature = "whisper")]
388pub use audio::*;
389
390#[cfg(feature = "whisper")]
395mod inference {
396 use super::*;
397 use candle_core::{DType, Device, IndexOp, Tensor};
398 use candle_nn::VarBuilder;
399 use candle_transformers::models::whisper::{self as m, audio, Config};
400 use hf_hub::{api::sync::Api, Repo, RepoType};
401 use tokenizers::Tokenizer;
402
    /// A loaded Whisper model together with everything needed to run
    /// transcription: tokenizer, hyperparameters, mel filterbank, device.
    pub struct WhisperTranscriber {
        // Encoder/decoder weights (full-precision or quantized).
        model: Model,
        // Tokenizer for the model's vocabulary (from tokenizer.json).
        tokenizer: Tokenizer,
        // Hyperparameters parsed from config.json.
        config: Config,
        // Precomputed mel filterbank coefficients (f32, little-endian).
        mel_filters: Vec<f32>,
        // Compute device the weights live on (CPU, CUDA, or Metal).
        device: Device,
    }
411
    /// Weight-storage variants. Only `Normal` is constructed in this file;
    /// `Quantized` is kept (hence the `allow(dead_code)`) so both arms are
    /// handled uniformly at every call site.
    #[allow(dead_code)]
    enum Model {
        Normal(m::model::Whisper),
        Quantized(m::quantized_model::Whisper),
    }
417
418 impl WhisperTranscriber {
        /// Load a Whisper model (config, tokenizer, weights) and prepare it
        /// for inference on the best available device.
        ///
        /// Short names from [`WhisperConfig::model_name`] are mapped to
        /// Hugging Face Hub repo ids; any other string is treated as a repo
        /// id verbatim.
        ///
        /// NOTE(review): `config.models_dir` and `config.offline` are not
        /// consulted here — artifacts go through the hf-hub cache and may
        /// hit the network on a cache miss. Confirm whether offline mode
        /// should be honored in this path.
        ///
        /// # Errors
        /// `WhisperError::DownloadError` when fetching artifacts fails;
        /// `WhisperError::InferenceError` when parsing or loading them fails.
        pub fn new(config: &WhisperConfig) -> Result<Self> {
            let device = Self::select_device();
            tracing::info!(device = ?device, "Using device for Whisper");
            // Map friendly names to Hub repo ids; unknown names pass through.
            let model_id = match config.model_name.as_str() {
                "whisper-small-en" => "openai/whisper-small.en",
                "whisper-small" => "openai/whisper-small",
                "whisper-tiny.en" => "openai/whisper-tiny.en",
                "whisper-tiny" => "openai/whisper-tiny",
                "whisper-base.en" => "openai/whisper-base.en",
                "whisper-base" => "openai/whisper-base",
                "whisper-medium.en" => "openai/whisper-medium.en",
                "whisper-medium" => "openai/whisper-medium",
                "whisper-large-v3" => "openai/whisper-large-v3",
                other => other,
            };

            tracing::info!(model_id = model_id, "Loading Whisper model");

            let api = Api::new().map_err(|e| WhisperError::DownloadError {
                cause: e.to_string(),
            })?;
            let repo = api.repo(Repo::with_revision(
                model_id.to_string(),
                RepoType::Model,
                "main".to_string(),
            ));

            // Fetch (or reuse from the hf-hub cache) the three artifacts.
            let config_path = repo.get("config.json").map_err(|e| WhisperError::DownloadError {
                cause: format!("Failed to download config.json: {}", e),
            })?;
            let tokenizer_path = repo.get("tokenizer.json").map_err(|e| WhisperError::DownloadError {
                cause: format!("Failed to download tokenizer.json: {}", e),
            })?;
            let model_path = repo.get("model.safetensors").map_err(|e| WhisperError::DownloadError {
                cause: format!("Failed to download model.safetensors: {}", e),
            })?;

            let config_str = std::fs::read_to_string(&config_path).map_err(|e| WhisperError::InferenceError {
                cause: format!("Failed to read config: {}", e),
            })?;
            let model_config: Config = serde_json::from_str(&config_str).map_err(|e| WhisperError::InferenceError {
                cause: format!("Failed to parse config: {}", e),
            })?;

            let tokenizer = Tokenizer::from_file(&tokenizer_path)
                .map_err(|e| WhisperError::InferenceError {
                    cause: format!("Failed to load tokenizer: {}", e),
                })?;

            // Mel filterbanks are bundled in the binary as raw little-endian
            // f32 bytes; pick the set matching the checkpoint's bin count.
            let mel_bytes = match model_config.num_mel_bins {
                80 => include_bytes!("melfilters.bytes").as_slice(),
                128 => include_bytes!("melfilters128.bytes").as_slice(),
                n => return Err(WhisperError::InferenceError {
                    cause: format!("Unsupported number of mel bins: {}", n),
                }.into()),
            };
            let mut mel_filters = vec![0f32; mel_bytes.len() / 4];
            <byteorder::LittleEndian as byteorder::ByteOrder>::read_f32_into(mel_bytes, &mut mel_filters);

            // SAFETY: memory-maps the safetensors file; sound as long as the
            // cached file is not modified/truncated while mapped (the hf-hub
            // cache treats downloaded files as immutable).
            let vb = unsafe {
                VarBuilder::from_mmaped_safetensors(&[model_path], DType::F32, &device)
                    .map_err(|e| WhisperError::InferenceError {
                        cause: format!("Failed to load model weights: {}", e),
                    })?
            };
            let model = Model::Normal(m::model::Whisper::load(&vb, model_config.clone())
                .map_err(|e| WhisperError::InferenceError {
                    cause: format!("Failed to load Whisper model: {}", e),
                })?);

            tracing::info!("Whisper model loaded successfully");

            Ok(Self {
                model,
                tokenizer,
                config: model_config,
                mel_filters,
                device,
            })
        }
506
        /// Pick the best available compute device.
        ///
        /// Preference order: Metal (when built with the `metal` feature),
        /// then CUDA (with the `cuda` feature), then CPU. A failed device
        /// construction falls through silently to the next option.
        fn select_device() -> Device {
            #[cfg(feature = "metal")]
            {
                if let Ok(device) = Device::new_metal(0) {
                    tracing::info!("Metal GPU available");
                    return device;
                }
            }

            #[cfg(feature = "cuda")]
            {
                if let Ok(device) = Device::new_cuda(0) {
                    tracing::info!("CUDA GPU available");
                    return device;
                }
            }

            tracing::info!("Using CPU (no GPU acceleration)");
            Device::Cpu
        }
531
        /// Decode an audio file to 16 kHz mono PCM and transcribe it.
        ///
        /// # Errors
        /// Propagates decode failures from `decode_audio_file` and any
        /// inference error from `transcribe_pcm`.
        pub fn transcribe_file(&mut self, path: &Path) -> Result<TranscriptionResult> {
            let (pcm_data, duration_secs) = super::decode_audio_file(path)?;

            // Signal statistics for diagnostics only.
            // NOTE(review): empty audio makes mean/rms NaN; harmless since
            // the values are only logged.
            let audio_min = pcm_data.iter().cloned().fold(f32::INFINITY, f32::min);
            let audio_max = pcm_data.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
            let audio_mean = pcm_data.iter().sum::<f32>() / pcm_data.len() as f32;
            let audio_rms = (pcm_data.iter().map(|x| x * x).sum::<f32>() / pcm_data.len() as f32).sqrt();

            tracing::info!(
                duration = duration_secs,
                samples = pcm_data.len(),
                min = audio_min,
                max = audio_max,
                mean = audio_mean,
                rms = audio_rms,
                "Audio decoded"
            );

            self.transcribe_pcm(&pcm_data, duration_secs)
        }
555
556 pub fn transcribe_pcm(&mut self, pcm_data: &[f32], duration_secs: f32) -> Result<TranscriptionResult> {
558 const CHUNK_LENGTH: usize = 30 * 16000; const N_FRAMES: usize = 3000; const SAMPLE_RATE: f32 = 16000.0;
562
563 let silence_threshold = 0.01; let window_size = 1600; let start_sample = find_speech_start(pcm_data, silence_threshold, window_size);
568 let end_sample = find_speech_end(pcm_data, silence_threshold, window_size);
569
570 let trimmed_start = start_sample as f32 / SAMPLE_RATE;
571 let trimmed_end = end_sample as f32 / SAMPLE_RATE;
572
573 tracing::info!(
574 start_sample = start_sample,
575 end_sample = end_sample,
576 trimmed_start_sec = trimmed_start,
577 trimmed_end_sec = trimmed_end,
578 original_duration = duration_secs,
579 "Trimmed silence"
580 );
581
582 let pcm_data = &pcm_data[start_sample..end_sample];
584 let _trimmed_duration = pcm_data.len() as f32 / SAMPLE_RATE;
585
586 let mut all_text = String::new();
587 let mut segments = Vec::new();
588
589 let num_chunks = (pcm_data.len() + CHUNK_LENGTH - 1) / CHUNK_LENGTH;
591
592 for chunk_idx in 0..num_chunks {
593 let chunk_start = chunk_idx * CHUNK_LENGTH;
594 let chunk_end = (chunk_start + CHUNK_LENGTH).min(pcm_data.len());
595 let chunk = &pcm_data[chunk_start..chunk_end];
596
597 let start_time = trimmed_start + chunk_start as f32 / SAMPLE_RATE;
599 let end_time = trimmed_start + chunk_end as f32 / SAMPLE_RATE;
600
601 tracing::info!(
602 chunk = chunk_idx + 1,
603 total = num_chunks,
604 start = start_time,
605 end = end_time,
606 "Processing chunk"
607 );
608
609 match &mut self.model {
611 Model::Normal(m) => m.decoder.reset_kv_cache(),
612 Model::Quantized(m) => m.decoder.reset_kv_cache(),
613 }
614
615 let mel = audio::pcm_to_mel(&self.config, chunk, &self.mel_filters);
617 let n_mels = self.config.num_mel_bins;
618 let mel_len = mel.len();
619 let n_frames = mel_len / n_mels;
620
621 if chunk_idx == 0 {
622 tracing::info!(
624 num_mel_bins = self.config.num_mel_bins,
625 max_source_positions = self.config.max_source_positions,
626 max_target_positions = self.config.max_target_positions,
627 "Model config"
628 );
629
630 let mel_min = mel.iter().cloned().fold(f32::INFINITY, f32::min);
632 let mel_max = mel.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
633 let mel_mean = mel.iter().sum::<f32>() / mel.len() as f32;
634
635 tracing::info!(
636 mel_len = mel_len,
637 n_mels = n_mels,
638 n_frames = n_frames,
639 chunk_samples = chunk.len(),
640 expected_frames = 3000,
641 mel_min = mel_min,
642 mel_max = mel_max,
643 mel_mean = mel_mean,
644 "Mel spectrogram computed"
645 );
646 }
647
648 let mel = if n_frames < N_FRAMES {
652 let mut padded = vec![0.0f32; n_mels * N_FRAMES];
654 for bin in 0..n_mels {
655 let src_start = bin * n_frames;
656 let dst_start = bin * N_FRAMES;
657 padded[dst_start..dst_start + n_frames].copy_from_slice(&mel[src_start..src_start + n_frames]);
658 }
659 padded
660 } else if n_frames > N_FRAMES {
661 let mut truncated = vec![0.0f32; n_mels * N_FRAMES];
663 for bin in 0..n_mels {
664 let src_start = bin * n_frames;
665 let dst_start = bin * N_FRAMES;
666 truncated[dst_start..dst_start + N_FRAMES].copy_from_slice(&mel[src_start..src_start + N_FRAMES]);
667 }
668 truncated
669 } else {
670 mel
671 };
672
673 let mel = Tensor::from_vec(
674 mel,
675 (1, n_mels, N_FRAMES),
676 &self.device,
677 ).map_err(|e| WhisperError::InferenceError {
678 cause: format!("Failed to create mel tensor: {}", e),
679 })?;
680
681 if chunk_idx == 0 {
682 let mel_shape = mel.shape();
683 tracing::info!(
684 mel_shape = ?mel_shape,
685 "Mel tensor shape"
686 );
687 }
688
689 let audio_features = match &mut self.model {
691 Model::Normal(m) => m.encoder.forward(&mel, true),
692 Model::Quantized(m) => m.encoder.forward(&mel, true),
693 }.map_err(|e| WhisperError::InferenceError {
694 cause: format!("Encoder forward failed: {}", e),
695 })?;
696
697 if chunk_idx == 0 {
698 let af_shape = audio_features.shape();
699 tracing::info!(
700 audio_features_shape = ?af_shape,
701 "Audio features from encoder"
702 );
703 }
704
705 let sot_token = self.token_id(m::SOT_TOKEN)?;
707 let transcribe_token = self.token_id(m::TRANSCRIBE_TOKEN)?;
708 let eot_token = self.token_id(m::EOT_TOKEN)?;
709 let no_timestamps_token = self.token_id(m::NO_TIMESTAMPS_TOKEN)?;
710
711 if chunk_idx == 0 {
712 let en_token = self.tokenizer.token_to_id("<|en|>");
713 tracing::info!(
714 sot = sot_token,
715 transcribe = transcribe_token,
716 eot = eot_token,
717 no_timestamps = no_timestamps_token,
718 en_token = ?en_token,
719 "Special tokens"
720 );
721 }
722
723 let has_language_token = self.tokenizer.token_to_id("<|en|>").is_some();
727
728 let is_english_only = self.config.vocab_size == 51864;
730
731 let tokens = if is_english_only {
732 vec![sot_token, transcribe_token, no_timestamps_token]
734 } else if has_language_token {
735 let language_token = self.token_id("<|en|>")?;
737 vec![sot_token, language_token, transcribe_token, no_timestamps_token]
738 } else {
739 vec![sot_token, transcribe_token, no_timestamps_token]
741 };
742
743 if chunk_idx == 0 {
744 tracing::info!(
745 is_english_only = is_english_only,
746 vocab_size = self.config.vocab_size,
747 prompt_tokens = ?tokens,
748 "Initial prompt"
749 );
750 }
751 let mut all_tokens = tokens.clone();
752
753 let sample_len = self.config.max_target_positions / 2;
755 let mut repeat_count = 0;
756 let mut last_token: Option<u32> = None;
757
758 let suppress_tokens = &self.config.suppress_tokens;
760
761 for i in 0..sample_len {
762 let tokens_tensor = Tensor::new(all_tokens.as_slice(), &self.device)
766 .and_then(|t| t.unsqueeze(0))
767 .map_err(|e| WhisperError::InferenceError {
768 cause: format!("Failed to create tokens tensor: {}", e),
769 })?;
770
771 if chunk_idx == 0 && i < 3 {
772 tracing::info!(
773 step = i,
774 all_tokens_len = all_tokens.len(),
775 tokens_shape = ?tokens_tensor.shape(),
776 "Decoder input"
777 );
778 }
779
780 let logits = match &mut self.model {
783 Model::Normal(m) => {
784 let hidden = m.decoder.forward(&tokens_tensor, &audio_features, true)
785 .map_err(|e| WhisperError::InferenceError {
786 cause: format!("Decoder forward failed: {}", e),
787 })?;
788 m.decoder.final_linear(&hidden)
789 .map_err(|e| WhisperError::InferenceError {
790 cause: format!("Final linear failed: {}", e),
791 })?
792 }
793 Model::Quantized(m) => {
794 let hidden = m.decoder.forward(&tokens_tensor, &audio_features, true)
795 .map_err(|e| WhisperError::InferenceError {
796 cause: format!("Decoder forward failed: {}", e),
797 })?;
798 m.decoder.final_linear(&hidden)
799 .map_err(|e| WhisperError::InferenceError {
800 cause: format!("Final linear failed: {}", e),
801 })?
802 }
803 };
804
805 if chunk_idx == 0 && i == 0 {
806 tracing::info!(
807 logits_shape = ?logits.shape(),
808 "Decoder output logits"
809 );
810 }
811
812 let (_, seq_len, _) = logits.dims3().map_err(|e| WhisperError::InferenceError {
814 cause: format!("Failed to get logits dims: {}", e),
815 })?;
816 let mut logits_vec = logits.i((0, seq_len - 1, ..))
817 .and_then(|t| t.to_vec1::<f32>())
818 .map_err(|e| WhisperError::InferenceError {
819 cause: format!("Failed to extract logits: {}", e),
820 })?;
821
822 for &token_id in suppress_tokens.iter() {
824 if (token_id as usize) < logits_vec.len() {
825 logits_vec[token_id as usize] = f32::NEG_INFINITY;
826 }
827 }
828
829 if all_tokens.len() < 10 {
831 logits_vec[eot_token as usize] = f32::NEG_INFINITY;
832 }
833
834 logits_vec[sot_token as usize] = f32::NEG_INFINITY;
838 logits_vec[transcribe_token as usize] = f32::NEG_INFINITY;
839 logits_vec[no_timestamps_token as usize] = f32::NEG_INFINITY;
840 for token_id in 50257..logits_vec.len() {
842 logits_vec[token_id] = f32::NEG_INFINITY;
843 }
844
845 if chunk_idx == 0 && i == 0 {
846 tracing::info!(
847 suppress_count = suppress_tokens.len(),
848 eot_suppressed = all_tokens.len() < 10,
849 "Applied token suppression"
850 );
851 }
852
853 let next_token = logits_vec
855 .iter()
856 .enumerate()
857 .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
858 .map(|(idx, _)| idx as u32)
859 .unwrap_or(eot_token);
860
861 if chunk_idx == 0 && i < 5 {
862 let max_logit = logits_vec.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
863 let min_logit = logits_vec.iter().cloned().fold(f32::INFINITY, f32::min);
864 tracing::info!(
865 step = i,
866 next_token = next_token,
867 max_logit = max_logit,
868 min_logit = min_logit,
869 "Decoding step"
870 );
871 }
872
873 if next_token == eot_token || next_token >= self.config.vocab_size as u32 {
874 if chunk_idx == 0 && i < 5 {
875 tracing::info!(next_token = next_token, eot = eot_token, "Stopping: EOT or invalid token");
876 }
877 break;
878 }
879
880 if Some(next_token) == last_token {
882 repeat_count += 1;
883 if repeat_count > 3 {
884 tracing::debug!("Breaking due to token repetition");
885 break;
886 }
887 } else {
888 repeat_count = 0;
889 }
890 last_token = Some(next_token);
891
892 all_tokens.push(next_token);
893 }
894
895 let prompt_len = if is_english_only { 3 } else { 4 };
897
898 if chunk_idx == 0 {
899 tracing::info!(
900 prompt_tokens = ?&all_tokens[..prompt_len],
901 generated_tokens = ?&all_tokens[prompt_len..],
902 total = all_tokens.len(),
903 "Generated tokens for chunk"
904 );
905 }
906
907 let chunk_text = self.tokenizer
908 .decode(&all_tokens[prompt_len..], true) .map_err(|e| WhisperError::InferenceError {
910 cause: format!("Failed to decode tokens: {}", e),
911 })?;
912
913 let trimmed_text = chunk_text.trim();
914 if !trimmed_text.is_empty() {
915 if !all_text.is_empty() {
916 all_text.push(' ');
917 }
918 all_text.push_str(trimmed_text);
919
920 segments.push(TranscriptionSegment {
921 start: start_time,
922 end: end_time,
923 text: trimmed_text.to_string(),
924 });
925 }
926 }
927
928 Ok(TranscriptionResult {
929 text: all_text.trim().to_string(),
930 language: "en".to_string(),
931 duration_secs,
932 segments,
933 })
934 }
935
936 fn token_id(&self, token: &str) -> Result<u32> {
937 self.tokenizer
938 .token_to_id(token)
939 .ok_or_else(|| WhisperError::InferenceError {
940 cause: format!("Token '{}' not found in vocabulary", token),
941 }.into())
942 }
943 }
944
945 fn find_speech_start(samples: &[f32], threshold: f32, window_size: usize) -> usize {
947 for i in (0..samples.len()).step_by(window_size) {
948 let end = (i + window_size).min(samples.len());
949 let window = &samples[i..end];
950 let rms = (window.iter().map(|x| x * x).sum::<f32>() / window.len() as f32).sqrt();
951 if rms > threshold {
952 return i.saturating_sub(window_size);
954 }
955 }
956 0 }
958
959 fn find_speech_end(samples: &[f32], threshold: f32, window_size: usize) -> usize {
961 for i in (0..samples.len()).rev().step_by(window_size) {
962 let start = i.saturating_sub(window_size);
963 let window = &samples[start..=i.min(samples.len() - 1)];
964 let rms = (window.iter().map(|x| x * x).sum::<f32>() / window.len() as f32).sqrt();
965 if rms > threshold {
966 return (i + window_size).min(samples.len());
968 }
969 }
970 samples.len() }
972}
973
974#[cfg(feature = "whisper")]
975pub use inference::WhisperTranscriber;
976
977#[cfg(test)]
982mod tests {
983 use super::*;
984
    // The registry must expose whisper-small-en as its default, and unknown
    // names must resolve to that default rather than erroring.
    #[test]
    fn whisper_model_registry() {
        let default = default_whisper_model_info();
        assert_eq!(default.name, "whisper-small-en");
        assert!(default.is_default);
        assert_eq!(default.language, "en");

        // Unknown names fall back to the default entry.
        let unknown = get_whisper_model_info("nonexistent");
        assert_eq!(unknown.name, "whisper-small-en");
    }
996
    // NOTE(review): asserts the built-in default model name; this will fail
    // if MEMVID_WHISPER_MODEL is set in the test environment, since
    // WhisperConfig::default() reads it.
    #[test]
    fn whisper_config_defaults() {
        let config = WhisperConfig::default();
        assert_eq!(config.model_name, "whisper-small-en");
    }
1002}