Skip to main content

memvid_core/
whisper.rs

1// Safe expect: Static Whisper model lookup with guaranteed default.
2#![allow(clippy::unwrap_used, clippy::expect_used)]
3//! Whisper audio transcription with Candle inference.
4//!
5//! This module provides complete Whisper transcription functionality including:
6//! - Audio decoding (MP3, WAV, FLAC, etc.) via symphonia
//! - Resampling to 16kHz via built-in linear interpolation
8//! - Whisper model inference via candle-transformers
9//! - Automatic model download from `HuggingFace` Hub
10
11use serde::{Deserialize, Serialize};
12use std::path::PathBuf;
13
14use crate::MemvidError;
15
16// These are only used when whisper feature is enabled
17#[cfg(feature = "whisper")]
18use crate::Result;
19#[cfg(feature = "whisper")]
20use std::path::Path;
21
22// ============================================================================
23// Model Registry
24// ============================================================================
25
26/// Quantization type for Whisper models
27#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
28pub enum QuantizationType {
29    /// Full precision FP32 (default, highest accuracy)
30    FP32,
31    /// 8-bit quantization (~75% smaller, ~15-20% faster)
32    Q8K,
33    /// 4-bit quantization (~87.5% smaller, ~25-30% faster)
34    Q4K,
35}
36
37impl Default for QuantizationType {
38    fn default() -> Self {
39        Self::FP32
40    }
41}
42
43/// Available Whisper models with verified `HuggingFace` model IDs
44#[derive(Debug, Clone)]
45pub struct WhisperModelInfo {
46    /// Model identifier for `HuggingFace`
47    pub model_id: &'static str,
48    /// Human-readable name
49    pub name: &'static str,
50    /// Approximate model size in MB
51    pub size_mb: f32,
52    /// Whether this is the default model
53    pub is_default: bool,
54    /// Language (e.g., "en" for English-only models, "multilingual" for others)
55    pub language: &'static str,
56    /// Quantization type (FP32, Q8K, Q4K)
57    pub quantization: QuantizationType,
58    /// Model file format ("safetensors" for FP32, "gguf" for quantized)
59    pub file_format: &'static str,
60}
61
62/// Available Whisper models registry
63pub static WHISPER_MODELS: &[WhisperModelInfo] = &[
64    // FP32 models (default, highest accuracy)
65    WhisperModelInfo {
66        model_id: "openai/whisper-small.en",
67        name: "whisper-small-en",
68        size_mb: 244.0,
69        is_default: true,
70        language: "en",
71        quantization: QuantizationType::FP32,
72        file_format: "safetensors",
73    },
74    WhisperModelInfo {
75        model_id: "openai/whisper-small",
76        name: "whisper-small",
77        size_mb: 244.0,
78        is_default: false,
79        language: "multilingual",
80        quantization: QuantizationType::FP32,
81        file_format: "safetensors",
82    },
83    // Tiny FP32 models (faster, less accurate)
84    WhisperModelInfo {
85        model_id: "openai/whisper-tiny.en",
86        name: "whisper-tiny-en",
87        size_mb: 75.0,
88        is_default: false,
89        language: "en",
90        quantization: QuantizationType::FP32,
91        file_format: "safetensors",
92    },
93    // Q8K quantized tiny models (~75% smaller, faster)
94    // Uses lmz/candle-whisper quantized models from HuggingFace
95    WhisperModelInfo {
96        model_id: "lmz/candle-whisper",
97        name: "whisper-tiny-en-q8k",
98        size_mb: 19.0,
99        is_default: false,
100        language: "en",
101        quantization: QuantizationType::Q8K,
102        file_format: "gguf",
103    },
104    WhisperModelInfo {
105        model_id: "lmz/candle-whisper",
106        name: "whisper-tiny-q8k",
107        size_mb: 19.0,
108        is_default: false,
109        language: "multilingual",
110        quantization: QuantizationType::Q8K,
111        file_format: "gguf",
112    },
113];
114
115/// Get model info by name, defaults to whisper-small-en
116#[must_use]
117pub fn get_whisper_model_info(name: &str) -> &'static WhisperModelInfo {
118    WHISPER_MODELS
119        .iter()
120        .find(|m| m.name == name || m.model_id == name)
121        .unwrap_or_else(|| {
122            WHISPER_MODELS
123                .iter()
124                .find(|m| m.is_default)
125                .expect("default whisper model")
126        })
127}
128
129/// Get the default model info
130#[must_use]
131pub fn default_whisper_model_info() -> &'static WhisperModelInfo {
132    WHISPER_MODELS
133        .iter()
134        .find(|m| m.is_default)
135        .expect("default whisper model exists")
136}
137
138// ============================================================================
139// Whisper Model Configuration
140// ============================================================================
141
/// Settings used to initialize a Whisper model.
#[derive(Clone, Debug)]
pub struct WhisperConfig {
    /// Name of the model to load (e.g., "whisper-small-en")
    pub model_name: String,
    /// Directory in which downloaded models are cached
    pub models_dir: PathBuf,
    /// When `true`, run in offline mode (no downloads)
    pub offline: bool,
}
152
153impl Default for WhisperConfig {
154    fn default() -> Self {
155        let models_dir = std::env::var("MEMVID_MODELS_DIR")
156            .ok()
157            .map(PathBuf::from)
158            .or_else(|| dirs_next::home_dir().map(|d| d.join(".memvid/models")))
159            .unwrap_or_else(|| PathBuf::from(".memvid/models"));
160
161        let model_name = std::env::var("MEMVID_WHISPER_MODEL")
162            .unwrap_or_else(|_| "whisper-small-en".to_string());
163
164        let offline = std::env::var("MEMVID_OFFLINE").is_ok();
165
166        Self {
167            model_name,
168            models_dir,
169            offline,
170        }
171    }
172}
173
174impl WhisperConfig {
175    /// Create config with Q8K quantized tiny model (~19 MB, very fast)
176    ///
177    /// Uses lmz/candle-whisper quantized models from HuggingFace.
178    /// Trade-off: Lower accuracy than whisper-small, but much faster.
179    #[must_use]
180    pub fn with_quantization() -> Self {
181        Self {
182            model_name: "whisper-tiny-en-q8k".to_string(),
183            ..Default::default()
184        }
185    }
186
187    /// Create config with specific model name
188    #[must_use]
189    pub fn with_model(model_name: impl Into<String>) -> Self {
190        Self {
191            model_name: model_name.into(),
192            ..Default::default()
193        }
194    }
195
196    /// Create config for multilingual Q8K quantized tiny model
197    #[must_use]
198    pub fn multilingual_quantized() -> Self {
199        Self {
200            model_name: "whisper-tiny-q8k".to_string(),
201            ..Default::default()
202        }
203    }
204
205    /// Create config for tiny FP32 model (75 MB, faster than small)
206    #[must_use]
207    pub fn tiny() -> Self {
208        Self {
209            model_name: "whisper-tiny-en".to_string(),
210            ..Default::default()
211        }
212    }
213}
214
215// ============================================================================
216// Whisper Error Types
217// ============================================================================
218
219/// Whisper-specific errors
220#[derive(Debug, thiserror::Error)]
221pub enum WhisperError {
222    /// Model not found
223    #[error("Whisper model '{model}' not found. {hint}")]
224    ModelNotFound { model: String, hint: String },
225
226    /// Audio decode failed
227    #[error("Failed to decode audio at {path:?}: {cause}")]
228    AudioDecodeError { path: PathBuf, cause: String },
229
230    /// Audio bytes decode failed
231    #[error("Failed to decode audio bytes: {cause}")]
232    AudioBytesDecodeError { cause: String },
233
234    /// Inference error
235    #[error("Whisper inference error: {cause}")]
236    InferenceError { cause: String },
237
238    /// Model download failed
239    #[error("Failed to download Whisper model: {cause}")]
240    DownloadError { cause: String },
241}
242
243impl From<WhisperError> for MemvidError {
244    fn from(err: WhisperError) -> Self {
245        MemvidError::ExtractionFailed {
246            reason: err.to_string().into_boxed_str(),
247        }
248    }
249}
250
251// ============================================================================
252// Transcription Result
253// ============================================================================
254
255/// Result of audio transcription
256#[derive(Debug, Clone, Serialize, Deserialize)]
257pub struct TranscriptionResult {
258    /// The transcribed text
259    pub text: String,
260    /// Language detected or specified
261    pub language: String,
262    /// Duration of audio in seconds
263    pub duration_secs: f32,
264    /// Optional timestamps for segments
265    #[serde(default)]
266    pub segments: Vec<TranscriptionSegment>,
267}
268
269/// A segment of transcription with timestamps
270#[derive(Debug, Clone, Serialize, Deserialize)]
271pub struct TranscriptionSegment {
272    /// Start time in seconds
273    pub start: f32,
274    /// End time in seconds
275    pub end: f32,
276    /// Transcribed text for this segment
277    pub text: String,
278}
279
280// ============================================================================
281// Audio Decoding (Feature-gated)
282// ============================================================================
283
284#[cfg(feature = "whisper")]
285mod audio {
286    use super::*;
287    use std::fs::File;
288    use symphonia::core::audio::SampleBuffer;
289    use symphonia::core::codecs::DecoderOptions;
290    use symphonia::core::formats::FormatOptions;
291    use symphonia::core::io::MediaSourceStream;
292    use symphonia::core::meta::MetadataOptions;
293    use symphonia::core::probe::Hint;
294
/// Sample rate, in Hz, that audio is converted to before transcription (always 16kHz).
pub const WHISPER_SAMPLE_RATE: u32 = 16_000;
297
298    /// Decode audio file to f32 samples, resampling to 16kHz mono
299    pub fn decode_audio_file(path: &Path) -> Result<(Vec<f32>, f32)> {
300        let file = File::open(path).map_err(|e| WhisperError::AudioDecodeError {
301            path: path.to_path_buf(),
302            cause: e.to_string(),
303        })?;
304
305        let mss = MediaSourceStream::new(Box::new(file), Default::default());
306
307        // Create a hint based on file extension
308        let mut hint = Hint::new();
309        if let Some(ext) = path.extension().and_then(|e| e.to_str()) {
310            hint.with_extension(ext);
311        }
312
313        // Probe the media source
314        let format_opts = FormatOptions::default();
315        let metadata_opts = MetadataOptions::default();
316        let probed = symphonia::default::get_probe()
317            .format(&hint, mss, &format_opts, &metadata_opts)
318            .map_err(|e| WhisperError::AudioDecodeError {
319                path: path.to_path_buf(),
320                cause: format!("Failed to probe audio format: {}", e),
321            })?;
322
323        let mut format = probed.format;
324
325        // Find the first audio track
326        let track = format
327            .tracks()
328            .iter()
329            .find(|t| t.codec_params.codec != symphonia::core::codecs::CODEC_TYPE_NULL)
330            .ok_or_else(|| WhisperError::AudioDecodeError {
331                path: path.to_path_buf(),
332                cause: "No audio track found".to_string(),
333            })?;
334
335        let track_id = track.id;
336        let sample_rate = track.codec_params.sample_rate.unwrap_or(44100);
337        let channels = track.codec_params.channels.map(|c| c.count()).unwrap_or(2);
338
339        // Create decoder
340        let decoder_opts = DecoderOptions::default();
341        let mut decoder = symphonia::default::get_codecs()
342            .make(&track.codec_params, &decoder_opts)
343            .map_err(|e| WhisperError::AudioDecodeError {
344                path: path.to_path_buf(),
345                cause: format!("Failed to create decoder: {}", e),
346            })?;
347
348        let mut samples: Vec<f32> = Vec::new();
349
350        // Decode all packets
351        loop {
352            let packet = match format.next_packet() {
353                Ok(p) => p,
354                Err(symphonia::core::errors::Error::IoError(e))
355                    if e.kind() == std::io::ErrorKind::UnexpectedEof =>
356                {
357                    break;
358                }
359                Err(_) => break,
360            };
361
362            if packet.track_id() != track_id {
363                continue;
364            }
365
366            let decoded = match decoder.decode(&packet) {
367                Ok(d) => d,
368                Err(_) => continue,
369            };
370
371            let spec = *decoded.spec();
372            let num_frames = decoded.frames();
373
374            if num_frames == 0 {
375                continue;
376            }
377
378            let mut sample_buf = SampleBuffer::<f32>::new(num_frames as u64, spec);
379            sample_buf.copy_interleaved_ref(decoded);
380
381            let interleaved = sample_buf.samples();
382
383            // Convert to mono by averaging channels
384            if channels > 1 {
385                for chunk in interleaved.chunks(channels) {
386                    let mono: f32 = chunk.iter().sum::<f32>() / channels as f32;
387                    samples.push(mono);
388                }
389            } else {
390                samples.extend_from_slice(interleaved);
391            }
392        }
393
394        let duration_secs = samples.len() as f32 / sample_rate as f32;
395
396        // Log pre-resampling stats
397        let pre_min = samples.iter().cloned().fold(f32::INFINITY, f32::min);
398        let pre_max = samples.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
399        let pre_rms = (samples.iter().map(|x| x * x).sum::<f32>() / samples.len() as f32).sqrt();
400        tracing::info!(
401            sample_rate = sample_rate,
402            channels = channels,
403            samples_before = samples.len(),
404            pre_min = pre_min,
405            pre_max = pre_max,
406            pre_rms = pre_rms,
407            "Audio before resampling"
408        );
409
410        // High-quality sinc resampling to 16kHz
411        let samples = if sample_rate != WHISPER_SAMPLE_RATE {
412            let resampled = resample_sinc(&samples, sample_rate, WHISPER_SAMPLE_RATE);
413
414            // Log post-resampling stats
415            let post_min = resampled.iter().cloned().fold(f32::INFINITY, f32::min);
416            let post_max = resampled.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
417            let post_rms =
418                (resampled.iter().map(|x| x * x).sum::<f32>() / resampled.len() as f32).sqrt();
419            tracing::info!(
420                samples_after = resampled.len(),
421                post_min = post_min,
422                post_max = post_max,
423                post_rms = post_rms,
424                "Audio after resampling"
425            );
426            resampled
427        } else {
428            tracing::info!("Audio already at 16kHz, no resampling needed");
429            samples
430        };
431
432        Ok((samples, duration_secs))
433    }
434
/// Resample `samples` from `from_rate` Hz to `to_rate` Hz by linear interpolation.
///
/// Despite the name, this is not a sinc resampler: each output sample is a
/// linear blend of the two neighbouring input samples, and no low-pass filter
/// is applied when downsampling.
/// Note: rubato 1.0 changed API to require audio_core buffer types, so it is
/// not used here.
///
/// Returns a copy of the input when the rates are equal. Returns an empty
/// vector when the input is empty or either rate is 0, since a zero rate has
/// no usable time base.
fn resample_sinc(samples: &[f32], from_rate: u32, to_rate: u32) -> Vec<f32> {
    if from_rate == to_rate {
        return samples.to_vec();
    }
    // `from_rate == 0` would make the ratio infinite and the capacity request
    // below overflow; `to_rate == 0` and empty input produce no output anyway.
    if from_rate == 0 || to_rate == 0 || samples.is_empty() {
        return Vec::new();
    }

    let ratio = f64::from(to_rate) / f64::from(from_rate);
    let output_len = (samples.len() as f64 * ratio).ceil() as usize;
    let mut output = Vec::with_capacity(output_len);

    for i in 0..output_len {
        // Position of output sample `i` on the input time axis.
        let src_pos = i as f64 / ratio;
        let src_idx = src_pos.floor() as usize;
        let frac = (src_pos - src_idx as f64) as f32;

        match (samples.get(src_idx), samples.get(src_idx + 1)) {
            // Linear interpolation between samples
            (Some(&left), Some(&right)) => output.push(left * (1.0 - frac) + right * frac),
            // Last input sample has no right neighbour: hold its value.
            (Some(&left), None) => output.push(left),
            _ => {}
        }
    }

    output
}
463}
464
465#[cfg(feature = "whisper")]
466pub use audio::*;
467
468// ============================================================================
469// Whisper Transcriber (Candle Inference)
470// ============================================================================
471
472#[cfg(feature = "whisper")]
473mod inference {
474    use super::*;
475    use candle_core::{DType, Device, IndexOp, Tensor};
476    use candle_nn::VarBuilder;
477    use candle_transformers::models::whisper::{self as m, Config, audio};
478    use hf_hub::{Repo, RepoType, api::sync::Api};
479    use tokenizers::Tokenizer;
480
481    /// Whisper model wrapper for transcription
482    pub struct WhisperTranscriber {
483        model: Model,
484        tokenizer: Tokenizer,
485        config: Config,
486        mel_filters: Vec<f32>,
487        device: Device,
488    }
489
490    #[allow(dead_code)]
491    enum Model {
492        Normal(m::model::Whisper),
493        Quantized(m::quantized_model::Whisper),
494    }
495
496    impl WhisperTranscriber {
497        /// Create a new WhisperTranscriber, downloading the model if needed
498        pub fn new(config: &WhisperConfig) -> Result<Self> {
499            // Use GPU if available: Metal (macOS) or CUDA (NVIDIA)
500            let device = Self::select_device();
501            tracing::info!(device = ?device, "Using device for Whisper");
502
503            // Get model info from registry
504            let model_info = get_whisper_model_info(&config.model_name);
505            let is_quantized = model_info.quantization != QuantizationType::FP32;
506
507            tracing::info!(
508                model_name = %config.model_name,
509                model_id = %model_info.model_id,
510                quantization = ?model_info.quantization,
511                file_format = %model_info.file_format,
512                "Loading Whisper model"
513            );
514
515            let api = Api::new().map_err(|e| WhisperError::DownloadError {
516                cause: e.to_string(),
517            })?;
518            let repo = api.repo(Repo::with_revision(
519                model_info.model_id.to_string(),
520                RepoType::Model,
521                "main".to_string(),
522            ));
523
524            // Download config and tokenizer files
525            // For quantized models, config/tokenizer come from the base OpenAI model
526            let (config_path, tokenizer_path) = if is_quantized {
527                // Quantized tiny models need config from openai/whisper-tiny
528                let base_model_id = match model_info.language {
529                    "en" => "openai/whisper-tiny.en",
530                    _ => "openai/whisper-tiny",
531                };
532                let base_repo = api.repo(Repo::with_revision(
533                    base_model_id.to_string(),
534                    RepoType::Model,
535                    "main".to_string(),
536                ));
537
538                let cfg =
539                    base_repo
540                        .get("config.json")
541                        .map_err(|e| WhisperError::DownloadError {
542                            cause: format!("Failed to download config.json: {}", e),
543                        })?;
544                let tok =
545                    base_repo
546                        .get("tokenizer.json")
547                        .map_err(|e| WhisperError::DownloadError {
548                            cause: format!("Failed to download tokenizer.json: {}", e),
549                        })?;
550                (cfg, tok)
551            } else {
552                let cfg = repo
553                    .get("config.json")
554                    .map_err(|e| WhisperError::DownloadError {
555                        cause: format!("Failed to download config.json: {}", e),
556                    })?;
557                let tok = repo
558                    .get("tokenizer.json")
559                    .map_err(|e| WhisperError::DownloadError {
560                        cause: format!("Failed to download tokenizer.json: {}", e),
561                    })?;
562                (cfg, tok)
563            };
564
565            // Load config
566            let config_str = std::fs::read_to_string(&config_path).map_err(|e| {
567                WhisperError::InferenceError {
568                    cause: format!("Failed to read config: {}", e),
569                }
570            })?;
571            let model_config: Config =
572                serde_json::from_str(&config_str).map_err(|e| WhisperError::InferenceError {
573                    cause: format!("Failed to parse config: {}", e),
574                })?;
575
576            // Load tokenizer
577            let tokenizer = Tokenizer::from_file(&tokenizer_path).map_err(|e| {
578                WhisperError::InferenceError {
579                    cause: format!("Failed to load tokenizer: {}", e),
580                }
581            })?;
582
583            // Load mel filters
584            let mel_bytes = match model_config.num_mel_bins {
585                80 => include_bytes!("melfilters.bytes").as_slice(),
586                128 => include_bytes!("melfilters128.bytes").as_slice(),
587                n => {
588                    return Err(WhisperError::InferenceError {
589                        cause: format!("Unsupported number of mel bins: {}", n),
590                    }
591                    .into());
592                }
593            };
594            let mut mel_filters = vec![0f32; mel_bytes.len() / 4];
595            <byteorder::LittleEndian as byteorder::ByteOrder>::read_f32_into(
596                mel_bytes,
597                &mut mel_filters,
598            );
599
600            // Load model based on quantization type
601            let model = match model_info.quantization {
602                QuantizationType::FP32 => {
603                    // Download and load FP32 safetensors model
604                    let model_path =
605                        repo.get("model.safetensors")
606                            .map_err(|e| WhisperError::DownloadError {
607                                cause: format!("Failed to download model.safetensors: {}", e),
608                            })?;
609
610                    let vb = unsafe {
611                        VarBuilder::from_mmaped_safetensors(&[model_path], DType::F32, &device)
612                            .map_err(|e| WhisperError::InferenceError {
613                                cause: format!("Failed to load model weights: {}", e),
614                            })?
615                    };
616                    Model::Normal(m::model::Whisper::load(&vb, model_config.clone()).map_err(
617                        |e| WhisperError::InferenceError {
618                            cause: format!("Failed to load Whisper model: {}", e),
619                        },
620                    )?)
621                }
622                QuantizationType::Q8K | QuantizationType::Q4K => {
623                    // Download and load quantized GGUF model from lmz/candle-whisper
624                    // Available files: model-tiny-en-q80.gguf, model-tiny-q80.gguf, etc.
625                    let gguf_filename = match (model_info.language, model_info.quantization) {
626                        ("en", QuantizationType::Q8K) => "model-tiny-en-q80.gguf",
627                        ("en", QuantizationType::Q4K) => "model-tiny-en-q40.gguf",
628                        (_, QuantizationType::Q8K) => "model-tiny-q80.gguf",
629                        (_, QuantizationType::Q4K) => "model-tiny-q40.gguf",
630                        _ => "model-tiny-q80.gguf",
631                    };
632
633                    let model_path =
634                        repo.get(gguf_filename)
635                            .map_err(|e| WhisperError::DownloadError {
636                                cause: format!("Failed to download {}: {}", gguf_filename, e),
637                            })?;
638
639                    tracing::info!(
640                        gguf_file = %gguf_filename,
641                        quantization = ?model_info.quantization,
642                        "Loading quantized GGUF model"
643                    );
644
645                    let vb = candle_transformers::quantized_var_builder::VarBuilder::from_gguf(
646                        &model_path,
647                        &device,
648                    )
649                    .map_err(|e| WhisperError::InferenceError {
650                        cause: format!("Failed to load quantized model: {}", e),
651                    })?;
652
653                    Model::Quantized(
654                        m::quantized_model::Whisper::load(&vb, model_config.clone()).map_err(
655                            |e| WhisperError::InferenceError {
656                                cause: format!("Failed to load quantized Whisper model: {}", e),
657                            },
658                        )?,
659                    )
660                }
661            };
662
663            tracing::info!("Whisper model loaded successfully");
664
665            Ok(Self {
666                model,
667                tokenizer,
668                config: model_config,
669                mel_filters,
670                device,
671            })
672        }
673
674        /// Select the best available device (GPU if available, otherwise CPU)
675        fn select_device() -> Device {
676            // Try Metal (macOS Apple Silicon / AMD)
677            #[cfg(feature = "metal")]
678            {
679                if let Ok(device) = Device::new_metal(0) {
680                    tracing::info!("Metal GPU available");
681                    return device;
682                }
683            }
684
685            // Try CUDA (NVIDIA GPUs)
686            #[cfg(feature = "cuda")]
687            {
688                if let Ok(device) = Device::new_cuda(0) {
689                    tracing::info!("CUDA GPU available");
690                    return device;
691                }
692            }
693
694            // Fallback to CPU
695            tracing::info!("Using CPU (no GPU acceleration)");
696            Device::Cpu
697        }
698
699        /// Transcribe an audio file
700        pub fn transcribe_file(&mut self, path: &Path) -> Result<TranscriptionResult> {
701            // Decode audio to PCM
702            let (pcm_data, duration_secs) = super::decode_audio_file(path)?;
703
704            // Check audio statistics
705            let audio_min = pcm_data.iter().cloned().fold(f32::INFINITY, f32::min);
706            let audio_max = pcm_data.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
707            let audio_mean = pcm_data.iter().sum::<f32>() / pcm_data.len() as f32;
708            let audio_rms =
709                (pcm_data.iter().map(|x| x * x).sum::<f32>() / pcm_data.len() as f32).sqrt();
710
711            tracing::info!(
712                duration = duration_secs,
713                samples = pcm_data.len(),
714                min = audio_min,
715                max = audio_max,
716                mean = audio_mean,
717                rms = audio_rms,
718                "Audio decoded"
719            );
720
721            self.transcribe_pcm(&pcm_data, duration_secs)
722        }
723
724        /// Transcribe PCM audio samples (16kHz mono f32)
725        pub fn transcribe_pcm(
726            &mut self,
727            pcm_data: &[f32],
728            duration_secs: f32,
729        ) -> Result<TranscriptionResult> {
730            // Whisper processes audio in 30-second chunks
731            const CHUNK_LENGTH: usize = 30 * 16000; // 30 seconds at 16kHz
732            const N_FRAMES: usize = 3000; // frames per chunk
733            const SAMPLE_RATE: f32 = 16000.0;
734
735            // Detect and trim leading silence
736            let silence_threshold = 0.01; // RMS threshold for silence
737            let window_size = 1600; // 100ms windows at 16kHz
738
739            let start_sample = find_speech_start(pcm_data, silence_threshold, window_size);
740            let end_sample = find_speech_end(pcm_data, silence_threshold, window_size);
741
742            let trimmed_start = start_sample as f32 / SAMPLE_RATE;
743            let trimmed_end = end_sample as f32 / SAMPLE_RATE;
744
745            tracing::info!(
746                start_sample = start_sample,
747                end_sample = end_sample,
748                trimmed_start_sec = trimmed_start,
749                trimmed_end_sec = trimmed_end,
750                original_duration = duration_secs,
751                "Trimmed silence"
752            );
753
754            // Use trimmed audio
755            let pcm_data = &pcm_data[start_sample..end_sample];
756            let _trimmed_duration = pcm_data.len() as f32 / SAMPLE_RATE;
757
758            let mut all_text = String::new();
759            let mut segments = Vec::new();
760
761            // Process audio in chunks
762            let num_chunks = (pcm_data.len() + CHUNK_LENGTH - 1) / CHUNK_LENGTH;
763
764            for chunk_idx in 0..num_chunks {
765                let chunk_start = chunk_idx * CHUNK_LENGTH;
766                let chunk_end = (chunk_start + CHUNK_LENGTH).min(pcm_data.len());
767                let chunk = &pcm_data[chunk_start..chunk_end];
768
769                // Adjust timestamps to account for trimmed silence
770                let start_time = trimmed_start + chunk_start as f32 / SAMPLE_RATE;
771                let end_time = trimmed_start + chunk_end as f32 / SAMPLE_RATE;
772
773                tracing::info!(
774                    chunk = chunk_idx + 1,
775                    total = num_chunks,
776                    start = start_time,
777                    end = end_time,
778                    "Processing chunk"
779                );
780
781                // Reset decoder KV cache for each new chunk
782                match &mut self.model {
783                    Model::Normal(m) => m.decoder.reset_kv_cache(),
784                    Model::Quantized(m) => m.decoder.reset_kv_cache(),
785                }
786
787                // Convert chunk to mel spectrogram
788                let mel = audio::pcm_to_mel(&self.config, chunk, &self.mel_filters);
789                let n_mels = self.config.num_mel_bins;
790                let mel_len = mel.len();
791                let n_frames = mel_len / n_mels;
792
793                if chunk_idx == 0 {
794                    // Print config for debugging
795                    tracing::info!(
796                        num_mel_bins = self.config.num_mel_bins,
797                        max_source_positions = self.config.max_source_positions,
798                        max_target_positions = self.config.max_target_positions,
799                        "Model config"
800                    );
801
802                    // Mel statistics
803                    let mel_min = mel.iter().cloned().fold(f32::INFINITY, f32::min);
804                    let mel_max = mel.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
805                    let mel_mean = mel.iter().sum::<f32>() / mel.len() as f32;
806
807                    tracing::info!(
808                        mel_len = mel_len,
809                        n_mels = n_mels,
810                        n_frames = n_frames,
811                        chunk_samples = chunk.len(),
812                        expected_frames = 3000,
813                        mel_min = mel_min,
814                        mel_max = mel_max,
815                        mel_mean = mel_mean,
816                        "Mel spectrogram computed"
817                    );
818                }
819
820                // Ensure we have exactly 3000 frames (pad or truncate)
821                // NOTE: mel array from pcm_to_mel is stored as [mel_bin_0_all_frames, mel_bin_1_all_frames, ...]
822                // So each mel bin has n_frames contiguous values: mel[bin * n_frames + frame]
823                let mel = if n_frames < N_FRAMES {
824                    // Pad each mel bin's frames with zeros to reach N_FRAMES
825                    let mut padded = vec![0.0f32; n_mels * N_FRAMES];
826                    for bin in 0..n_mels {
827                        let src_start = bin * n_frames;
828                        let dst_start = bin * N_FRAMES;
829                        padded[dst_start..dst_start + n_frames]
830                            .copy_from_slice(&mel[src_start..src_start + n_frames]);
831                    }
832                    padded
833                } else if n_frames > N_FRAMES {
834                    // Truncate each mel bin's frames to N_FRAMES
835                    let mut truncated = vec![0.0f32; n_mels * N_FRAMES];
836                    for bin in 0..n_mels {
837                        let src_start = bin * n_frames;
838                        let dst_start = bin * N_FRAMES;
839                        truncated[dst_start..dst_start + N_FRAMES]
840                            .copy_from_slice(&mel[src_start..src_start + N_FRAMES]);
841                    }
842                    truncated
843                } else {
844                    mel
845                };
846
847                let mel =
848                    Tensor::from_vec(mel, (1, n_mels, N_FRAMES), &self.device).map_err(|e| {
849                        WhisperError::InferenceError {
850                            cause: format!("Failed to create mel tensor: {}", e),
851                        }
852                    })?;
853
854                if chunk_idx == 0 {
855                    let mel_shape = mel.shape();
856                    tracing::info!(
857                        mel_shape = ?mel_shape,
858                        "Mel tensor shape"
859                    );
860                }
861
862                // Run encoder
863                let audio_features = match &mut self.model {
864                    Model::Normal(m) => m.encoder.forward(&mel, true),
865                    Model::Quantized(m) => m.encoder.forward(&mel, true),
866                }
867                .map_err(|e| WhisperError::InferenceError {
868                    cause: format!("Encoder forward failed: {}", e),
869                })?;
870
871                if chunk_idx == 0 {
872                    let af_shape = audio_features.shape();
873                    tracing::info!(
874                        audio_features_shape = ?af_shape,
875                        "Audio features from encoder"
876                    );
877                }
878
879                // Get special token IDs
880                let sot_token = self.token_id(m::SOT_TOKEN)?;
881                let transcribe_token = self.token_id(m::TRANSCRIBE_TOKEN)?;
882                let eot_token = self.token_id(m::EOT_TOKEN)?;
883                let no_timestamps_token = self.token_id(m::NO_TIMESTAMPS_TOKEN)?;
884
885                if chunk_idx == 0 {
886                    let en_token = self.tokenizer.token_to_id("<|en|>");
887                    tracing::info!(
888                        sot = sot_token,
889                        transcribe = transcribe_token,
890                        eot = eot_token,
891                        no_timestamps = no_timestamps_token,
892                        en_token = ?en_token,
893                        "Special tokens"
894                    );
895                }
896
897                // Build initial prompt
898                // For English-only models (*.en), we DON'T use language token
899                // For multilingual models, we add language token after sot_token
900                let has_language_token = self.tokenizer.token_to_id("<|en|>").is_some();
901
902                // English-only models have vocab size 51864, multilingual have 51865
903                let is_english_only = self.config.vocab_size == 51864;
904
905                let tokens = if is_english_only {
906                    // English-only: SOT -> transcribe -> notimestamps
907                    vec![sot_token, transcribe_token, no_timestamps_token]
908                } else if has_language_token {
909                    // Multilingual: SOT -> language -> transcribe -> notimestamps
910                    let language_token = self.token_id("<|en|>")?;
911                    vec![
912                        sot_token,
913                        language_token,
914                        transcribe_token,
915                        no_timestamps_token,
916                    ]
917                } else {
918                    // Fallback
919                    vec![sot_token, transcribe_token, no_timestamps_token]
920                };
921
922                if chunk_idx == 0 {
923                    tracing::info!(
924                        is_english_only = is_english_only,
925                        vocab_size = self.config.vocab_size,
926                        prompt_tokens = ?tokens,
927                        "Initial prompt"
928                    );
929                }
930                let mut all_tokens = tokens.clone();
931
932                // Autoregressive decoding with token suppression
933                let sample_len = self.config.max_target_positions / 2;
934                let mut repeat_count = 0;
935                let mut last_token: Option<u32> = None;
936
937                // Build suppression mask
938                let suppress_tokens = &self.config.suppress_tokens;
939
940                for i in 0..sample_len {
941                    // For autoregressive decoding with KV cache:
942                    // - First iteration: pass all prompt tokens, flush_kv_cache=true
943                    // - Subsequent iterations: pass only the new token, flush_kv_cache=false
944                    let tokens_tensor = Tensor::new(all_tokens.as_slice(), &self.device)
945                        .and_then(|t| t.unsqueeze(0))
946                        .map_err(|e| WhisperError::InferenceError {
947                            cause: format!("Failed to create tokens tensor: {}", e),
948                        })?;
949
950                    if chunk_idx == 0 && i < 3 {
951                        tracing::info!(
952                            step = i,
953                            all_tokens_len = all_tokens.len(),
954                            tokens_shape = ?tokens_tensor.shape(),
955                            "Decoder input"
956                        );
957                    }
958
959                    // Get hidden states from decoder, then project to vocabulary
960                    // Always pass all tokens (candle doesn't use KV cache the same way as PyTorch)
961                    let logits = match &mut self.model {
962                        Model::Normal(m) => {
963                            let hidden = m
964                                .decoder
965                                .forward(&tokens_tensor, &audio_features, true)
966                                .map_err(|e| WhisperError::InferenceError {
967                                cause: format!("Decoder forward failed: {}", e),
968                            })?;
969                            m.decoder.final_linear(&hidden).map_err(|e| {
970                                WhisperError::InferenceError {
971                                    cause: format!("Final linear failed: {}", e),
972                                }
973                            })?
974                        }
975                        Model::Quantized(m) => {
976                            let hidden = m
977                                .decoder
978                                .forward(&tokens_tensor, &audio_features, true)
979                                .map_err(|e| WhisperError::InferenceError {
980                                cause: format!("Decoder forward failed: {}", e),
981                            })?;
982                            m.decoder.final_linear(&hidden).map_err(|e| {
983                                WhisperError::InferenceError {
984                                    cause: format!("Final linear failed: {}", e),
985                                }
986                            })?
987                        }
988                    };
989
990                    if chunk_idx == 0 && i == 0 {
991                        tracing::info!(
992                            logits_shape = ?logits.shape(),
993                            "Decoder output logits"
994                        );
995                    }
996
997                    // Get logits for last position
998                    let (_, seq_len, _) =
999                        logits.dims3().map_err(|e| WhisperError::InferenceError {
1000                            cause: format!("Failed to get logits dims: {}", e),
1001                        })?;
1002                    let mut logits_vec = logits
1003                        .i((0, seq_len - 1, ..))
1004                        .and_then(|t| t.to_vec1::<f32>())
1005                        .map_err(|e| WhisperError::InferenceError {
1006                            cause: format!("Failed to extract logits: {}", e),
1007                        })?;
1008
1009                    // Apply token suppression from config
1010                    for &token_id in suppress_tokens.iter() {
1011                        if (token_id as usize) < logits_vec.len() {
1012                            logits_vec[token_id as usize] = f32::NEG_INFINITY;
1013                        }
1014                    }
1015
1016                    // Suppress EOT token for first few steps to allow generation
1017                    if all_tokens.len() < 10 {
1018                        logits_vec[eot_token as usize] = f32::NEG_INFINITY;
1019                    }
1020
1021                    // Suppress all special tokens during generation:
1022                    // - SOT (50257), language tokens (50258-50261), task tokens (50358-50359),
1023                    // - no_timestamps (50362), and timestamp tokens (50363+)
1024                    logits_vec[sot_token as usize] = f32::NEG_INFINITY;
1025                    logits_vec[transcribe_token as usize] = f32::NEG_INFINITY;
1026                    logits_vec[no_timestamps_token as usize] = f32::NEG_INFINITY;
1027                    // Suppress all tokens from 50257 onward (special tokens) except those in normal vocab
1028                    for token_id in 50257..logits_vec.len() {
1029                        logits_vec[token_id] = f32::NEG_INFINITY;
1030                    }
1031
1032                    if chunk_idx == 0 && i == 0 {
1033                        tracing::info!(
1034                            suppress_count = suppress_tokens.len(),
1035                            eot_suppressed = all_tokens.len() < 10,
1036                            "Applied token suppression"
1037                        );
1038                    }
1039
1040                    // Find argmax
1041                    let next_token = logits_vec
1042                        .iter()
1043                        .enumerate()
1044                        .max_by(|(_, a), (_, b)| {
1045                            a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal)
1046                        })
1047                        .map(|(idx, _)| idx as u32)
1048                        .unwrap_or(eot_token);
1049
1050                    if chunk_idx == 0 && i < 5 {
1051                        let max_logit =
1052                            logits_vec.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
1053                        let min_logit = logits_vec.iter().cloned().fold(f32::INFINITY, f32::min);
1054                        tracing::info!(
1055                            step = i,
1056                            next_token = next_token,
1057                            max_logit = max_logit,
1058                            min_logit = min_logit,
1059                            "Decoding step"
1060                        );
1061                    }
1062
1063                    if next_token == eot_token || next_token >= self.config.vocab_size as u32 {
1064                        if chunk_idx == 0 && i < 5 {
1065                            tracing::info!(
1066                                next_token = next_token,
1067                                eot = eot_token,
1068                                "Stopping: EOT or invalid token"
1069                            );
1070                        }
1071                        break;
1072                    }
1073
1074                    // Check for excessive repetition (stop if same token repeats >3 times)
1075                    if Some(next_token) == last_token {
1076                        repeat_count += 1;
1077                        if repeat_count > 3 {
1078                            tracing::debug!("Breaking due to token repetition");
1079                            break;
1080                        }
1081                    } else {
1082                        repeat_count = 0;
1083                    }
1084                    last_token = Some(next_token);
1085
1086                    all_tokens.push(next_token);
1087                }
1088
1089                // Decode tokens to text for this chunk
1090                let prompt_len = if is_english_only { 3 } else { 4 };
1091
1092                if chunk_idx == 0 {
1093                    tracing::info!(
1094                        prompt_tokens = ?&all_tokens[..prompt_len],
1095                        generated_tokens = ?&all_tokens[prompt_len..],
1096                        total = all_tokens.len(),
1097                        "Generated tokens for chunk"
1098                    );
1099                }
1100
1101                let chunk_text = self
1102                    .tokenizer
1103                    .decode(&all_tokens[prompt_len..], true) // Skip prompt tokens
1104                    .map_err(|e| WhisperError::InferenceError {
1105                        cause: format!("Failed to decode tokens: {}", e),
1106                    })?;
1107
1108                let trimmed_text = chunk_text.trim();
1109                if !trimmed_text.is_empty() {
1110                    if !all_text.is_empty() {
1111                        all_text.push(' ');
1112                    }
1113                    all_text.push_str(trimmed_text);
1114
1115                    segments.push(TranscriptionSegment {
1116                        start: start_time,
1117                        end: end_time,
1118                        text: trimmed_text.to_string(),
1119                    });
1120                }
1121            }
1122
1123            Ok(TranscriptionResult {
1124                text: all_text.trim().to_string(),
1125                language: "en".to_string(),
1126                duration_secs,
1127                segments,
1128            })
1129        }
1130
1131        fn token_id(&self, token: &str) -> Result<u32> {
1132            self.tokenizer.token_to_id(token).ok_or_else(|| {
1133                WhisperError::InferenceError {
1134                    cause: format!("Token '{}' not found in vocabulary", token),
1135                }
1136                .into()
1137            })
1138        }
1139    }
1140
1141    /// Find the sample index where speech starts (after leading silence)
1142    fn find_speech_start(samples: &[f32], threshold: f32, window_size: usize) -> usize {
1143        for i in (0..samples.len()).step_by(window_size) {
1144            let end = (i + window_size).min(samples.len());
1145            let window = &samples[i..end];
1146            let rms = (window.iter().map(|x| x * x).sum::<f32>() / window.len() as f32).sqrt();
1147            if rms > threshold {
1148                // Found speech, go back a bit to not cut off the start
1149                return i.saturating_sub(window_size);
1150            }
1151        }
1152        0 // No silence found, return start
1153    }
1154
1155    /// Find the sample index where speech ends (before trailing silence)
1156    fn find_speech_end(samples: &[f32], threshold: f32, window_size: usize) -> usize {
1157        for i in (0..samples.len()).rev().step_by(window_size) {
1158            let start = i.saturating_sub(window_size);
1159            let window = &samples[start..=i.min(samples.len() - 1)];
1160            let rms = (window.iter().map(|x| x * x).sum::<f32>() / window.len() as f32).sqrt();
1161            if rms > threshold {
1162                // Found speech, go forward a bit to not cut off the end
1163                return (i + window_size).min(samples.len());
1164            }
1165        }
1166        samples.len() // No silence found, return end
1167    }
1168}
1169
1170#[cfg(feature = "whisper")]
1171pub use inference::WhisperTranscriber;
1172
1173// ============================================================================
1174// Tests
1175// ============================================================================
1176
#[cfg(test)]
mod tests {
    use super::*;

    /// Name the registry and the default config are both expected to report.
    const EXPECTED_DEFAULT_MODEL: &str = "whisper-small-en";

    /// The registry's default entry is the English small model, and lookups
    /// for unknown names resolve to that same entry.
    #[test]
    fn whisper_model_registry() {
        let fallback = default_whisper_model_info();
        assert_eq!(fallback.name, EXPECTED_DEFAULT_MODEL);
        assert!(fallback.is_default);
        assert_eq!(fallback.language, "en");

        // An unrecognised model name resolves to the default entry.
        assert_eq!(
            get_whisper_model_info("nonexistent").name,
            EXPECTED_DEFAULT_MODEL
        );
    }

    /// `WhisperConfig::default()` selects the default model by name.
    #[test]
    fn whisper_config_defaults() {
        assert_eq!(WhisperConfig::default().model_name, EXPECTED_DEFAULT_MODEL);
    }
}