memvid_core/
whisper.rs

//! Whisper audio transcription with Candle inference.
//!
//! This module provides complete Whisper transcription functionality, including:
//! - Audio decoding (MP3, WAV, FLAC, etc.) via symphonia
//! - Resampling to 16 kHz via rubato
//! - Whisper model inference via candle-transformers
//! - Automatic model download from the Hugging Face Hub
8
9use serde::{Deserialize, Serialize};
10use std::path::PathBuf;
11
12use crate::MemvidError;
13
14// These are only used when whisper feature is enabled
15#[cfg(feature = "whisper")]
16use crate::Result;
17#[cfg(feature = "whisper")]
18use std::path::Path;
19
20// ============================================================================
21// Model Registry
22// ============================================================================
23
/// Metadata for one Whisper checkpoint listed in [`WHISPER_MODELS`].
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct WhisperModelInfo {
    /// Hugging Face Hub repository id (e.g. `"openai/whisper-small.en"`).
    pub model_id: &'static str,
    /// Short human-readable name (e.g. `"whisper-small-en"`).
    pub name: &'static str,
    /// Approximate model size in MB.
    ///
    /// NOTE(review): the registered value 244.0 matches whisper-small's published
    /// parameter count (244M) rather than its on-disk size — confirm the unit.
    pub size_mb: f32,
    /// Whether this is the default model.
    pub is_default: bool,
    /// Language (`"en"` for English-only models, `"multilingual"` for others).
    pub language: &'static str,
}

/// Registry of known Whisper models.
///
/// Exactly one entry should set `is_default`. The lookups below use the first
/// such entry and panic if there is none.
pub static WHISPER_MODELS: &[WhisperModelInfo] = &[
    WhisperModelInfo {
        model_id: "openai/whisper-small.en",
        name: "whisper-small-en",
        size_mb: 244.0,
        is_default: true,
        language: "en",
    },
    WhisperModelInfo {
        model_id: "openai/whisper-small",
        name: "whisper-small",
        size_mb: 244.0,
        is_default: false,
        language: "multilingual",
    },
];

/// Look up a registered model by its short `name` or its Hub `model_id`.
///
/// Returns `None` for unknown names. This lets callers reject them instead of
/// silently falling back to the default model.
pub fn find_whisper_model_info(name: &str) -> Option<&'static WhisperModelInfo> {
    WHISPER_MODELS
        .iter()
        .find(|m| m.name == name || m.model_id == name)
}

/// Get model info by short name or Hub model id.
///
/// Unknown names fall back to the default model (whisper-small-en).
///
/// # Panics
/// Panics only if [`WHISPER_MODELS`] has no default entry (a registry bug).
pub fn get_whisper_model_info(name: &str) -> &'static WhisperModelInfo {
    find_whisper_model_info(name).unwrap_or_else(default_whisper_model_info)
}

/// Get the default model info: the first registry entry with `is_default` set.
///
/// # Panics
/// Panics if [`WHISPER_MODELS`] has no default entry (a registry bug).
pub fn default_whisper_model_info() -> &'static WhisperModelInfo {
    WHISPER_MODELS
        .iter()
        .find(|m| m.is_default)
        .expect("default whisper model exists")
}
77
78// ============================================================================
79// Whisper Model Configuration
80// ============================================================================
81
82/// Configuration for Whisper model initialization
83#[derive(Debug, Clone)]
84pub struct WhisperConfig {
85    /// Model name (e.g., "whisper-small-en")
86    pub model_name: String,
87    /// Directory where models are cached
88    pub models_dir: PathBuf,
89    /// Whether to run in offline mode (no downloads)
90    pub offline: bool,
91}
92
93impl Default for WhisperConfig {
94    fn default() -> Self {
95        let models_dir = std::env::var("MEMVID_MODELS_DIR")
96            .ok()
97            .map(PathBuf::from)
98            .or_else(|| dirs_next::home_dir().map(|d| d.join(".memvid/models")))
99            .unwrap_or_else(|| PathBuf::from(".memvid/models"));
100
101        let model_name = std::env::var("MEMVID_WHISPER_MODEL")
102            .unwrap_or_else(|_| "whisper-small-en".to_string());
103
104        let offline = std::env::var("MEMVID_OFFLINE").is_ok();
105
106        Self {
107            model_name,
108            models_dir,
109            offline,
110        }
111    }
112}
113
114// ============================================================================
115// Whisper Error Types
116// ============================================================================
117
118/// Whisper-specific errors
119#[derive(Debug, thiserror::Error)]
120pub enum WhisperError {
121    /// Model not found
122    #[error("Whisper model '{model}' not found. {hint}")]
123    ModelNotFound { model: String, hint: String },
124
125    /// Audio decode failed
126    #[error("Failed to decode audio at {path:?}: {cause}")]
127    AudioDecodeError { path: PathBuf, cause: String },
128
129    /// Audio bytes decode failed
130    #[error("Failed to decode audio bytes: {cause}")]
131    AudioBytesDecodeError { cause: String },
132
133    /// Inference error
134    #[error("Whisper inference error: {cause}")]
135    InferenceError { cause: String },
136
137    /// Model download failed
138    #[error("Failed to download Whisper model: {cause}")]
139    DownloadError { cause: String },
140}
141
142impl From<WhisperError> for MemvidError {
143    fn from(err: WhisperError) -> Self {
144        MemvidError::ExtractionFailed {
145            reason: err.to_string().into_boxed_str(),
146        }
147    }
148}
149
150// ============================================================================
151// Transcription Result
152// ============================================================================
153
154/// Result of audio transcription
155#[derive(Debug, Clone, Serialize, Deserialize)]
156pub struct TranscriptionResult {
157    /// The transcribed text
158    pub text: String,
159    /// Language detected or specified
160    pub language: String,
161    /// Duration of audio in seconds
162    pub duration_secs: f32,
163    /// Optional timestamps for segments
164    #[serde(default)]
165    pub segments: Vec<TranscriptionSegment>,
166}
167
168/// A segment of transcription with timestamps
169#[derive(Debug, Clone, Serialize, Deserialize)]
170pub struct TranscriptionSegment {
171    /// Start time in seconds
172    pub start: f32,
173    /// End time in seconds
174    pub end: f32,
175    /// Transcribed text for this segment
176    pub text: String,
177}
178
179// ============================================================================
180// Audio Decoding (Feature-gated)
181// ============================================================================
182
/// Audio decoding and resampling into the 16 kHz mono format Whisper expects.
#[cfg(feature = "whisper")]
mod audio {
    use super::*;
    use std::fs::File;
    use symphonia::core::audio::SampleBuffer;
    use symphonia::core::codecs::DecoderOptions;
    use symphonia::core::formats::FormatOptions;
    use symphonia::core::io::MediaSourceStream;
    use symphonia::core::meta::MetadataOptions;
    use symphonia::core::probe::Hint;

    /// Whisper sample rate (always 16kHz)
    pub const WHISPER_SAMPLE_RATE: u32 = 16000;

    /// Rate assumed when neither the decoded audio nor the container reports one.
    const FALLBACK_SAMPLE_RATE: u32 = 44100;

    /// Decode an audio file into 16 kHz mono `f32` samples.
    ///
    /// Returns `(samples, duration_secs)`. `duration_secs` is measured on the mono
    /// signal at its original rate, before resampling.
    ///
    /// Multi-channel audio is down-mixed by averaging the channels of each frame.
    /// The channel count and sample rate come from the decoded buffers themselves.
    /// The track's codec parameters are only a fallback, so files whose container
    /// omits that metadata are still mixed and resampled correctly.
    ///
    /// Decoding is best-effort:
    /// - packets that fail to decode are skipped;
    /// - a demuxer error other than end-of-stream stops reading early.
    ///
    /// Both cases are logged with `tracing::warn!`.
    ///
    /// # Errors
    /// Returns [`WhisperError::AudioDecodeError`] if:
    /// - the file cannot be opened or probed;
    /// - it has no audio track or no decoder exists for its codec;
    /// - the sample rate is zero;
    /// - resampling fails.
    pub fn decode_audio_file(path: &Path) -> Result<(Vec<f32>, f32)> {
        let decode_err = |cause: String| WhisperError::AudioDecodeError {
            path: path.to_path_buf(),
            cause,
        };

        let file = File::open(path).map_err(|e| decode_err(e.to_string()))?;
        let mss = MediaSourceStream::new(Box::new(file), Default::default());

        // Create a hint based on file extension
        let mut hint = Hint::new();
        if let Some(ext) = path.extension().and_then(|e| e.to_str()) {
            hint.with_extension(ext);
        }

        let probed = symphonia::default::get_probe()
            .format(
                &hint,
                mss,
                &FormatOptions::default(),
                &MetadataOptions::default(),
            )
            .map_err(|e| decode_err(format!("Failed to probe audio format: {}", e)))?;
        let mut format = probed.format;

        // Find the first audio track
        let track = format
            .tracks()
            .iter()
            .find(|t| t.codec_params.codec != symphonia::core::codecs::CODEC_TYPE_NULL)
            .ok_or_else(|| decode_err("No audio track found".to_string()))?;

        let track_id = track.id;
        // Container-level metadata may be missing, so it is only used as a fallback.
        let declared_rate = track.codec_params.sample_rate;
        let declared_channels = track.codec_params.channels.map(|c| c.count());

        let mut decoder = symphonia::default::get_codecs()
            .make(&track.codec_params, &DecoderOptions::default())
            .map_err(|e| decode_err(format!("Failed to create decoder: {}", e)))?;

        let mut samples: Vec<f32> = Vec::new();
        let mut observed_rate: Option<u32> = None;
        let mut observed_channels: Option<usize> = None;
        let mut skipped_packets: usize = 0;

        loop {
            let packet = match format.next_packet() {
                Ok(packet) => packet,
                Err(symphonia::core::errors::Error::IoError(e))
                    if e.kind() == std::io::ErrorKind::UnexpectedEof =>
                {
                    break;
                }
                Err(e) => {
                    // Keep whatever was decoded so far, but do not fail silently.
                    tracing::warn!(
                        path = %path.display(),
                        error = %e,
                        "Stopped reading audio packets early"
                    );
                    break;
                }
            };

            if packet.track_id() != track_id {
                continue;
            }

            let decoded = match decoder.decode(&packet) {
                Ok(decoded) => decoded,
                Err(_) => {
                    skipped_packets += 1;
                    continue;
                }
            };

            let spec = *decoded.spec();
            let num_frames = decoded.frames();
            if num_frames == 0 {
                continue;
            }

            // The decoded buffer's layout is authoritative, not the container metadata.
            let packet_channels = spec.channels.count().max(1);
            observed_rate = observed_rate.or(Some(spec.rate));
            observed_channels = observed_channels.or(Some(packet_channels));

            let mut sample_buf = SampleBuffer::<f32>::new(num_frames as u64, spec);
            sample_buf.copy_interleaved_ref(decoded);
            let interleaved = sample_buf.samples();

            // Convert to mono by averaging channels
            if packet_channels > 1 {
                samples.extend(
                    interleaved
                        .chunks_exact(packet_channels)
                        .map(|frame| frame.iter().sum::<f32>() / packet_channels as f32),
                );
            } else {
                samples.extend_from_slice(interleaved);
            }
        }

        if skipped_packets > 0 {
            tracing::warn!(
                path = %path.display(),
                skipped_packets = skipped_packets,
                "Skipped audio packets that failed to decode"
            );
        }

        let sample_rate = observed_rate
            .or(declared_rate)
            .unwrap_or(FALLBACK_SAMPLE_RATE);
        if sample_rate == 0 {
            return Err(decode_err("Invalid sample rate 0".to_string()).into());
        }
        let channels = observed_channels.or(declared_channels).unwrap_or(2);

        let duration_secs = samples.len() as f32 / sample_rate as f32;

        let (pre_min, pre_max, pre_rms) = signal_stats(&samples);
        tracing::info!(
            sample_rate = sample_rate,
            channels = channels,
            samples_before = samples.len(),
            pre_min = pre_min,
            pre_max = pre_max,
            pre_rms = pre_rms,
            "Audio before resampling"
        );

        // High-quality sinc resampling to 16kHz
        let samples = if sample_rate != WHISPER_SAMPLE_RATE {
            let resampled =
                resample_sinc(&samples, sample_rate, WHISPER_SAMPLE_RATE).map_err(decode_err)?;

            let (post_min, post_max, post_rms) = signal_stats(&resampled);
            tracing::info!(
                samples_after = resampled.len(),
                post_min = post_min,
                post_max = post_max,
                post_rms = post_rms,
                "Audio after resampling"
            );
            resampled
        } else {
            tracing::info!("Audio already at 16kHz, no resampling needed");
            samples
        };

        Ok((samples, duration_secs))
    }

    /// Return `(min, max, rms)` of `samples` for diagnostics.
    ///
    /// All three are `0.0` for empty input, avoiding `NaN`/infinite values in logs.
    fn signal_stats(samples: &[f32]) -> (f32, f32, f32) {
        if samples.is_empty() {
            return (0.0, 0.0, 0.0);
        }
        let min = samples.iter().copied().fold(f32::INFINITY, f32::min);
        let max = samples.iter().copied().fold(f32::NEG_INFINITY, f32::max);
        let rms = (samples.iter().map(|x| x * x).sum::<f32>() / samples.len() as f32).sqrt();
        (min, max, rms)
    }

    /// Resample mono `samples` from `from_rate` to `to_rate` with rubato's FFT resampler.
    ///
    /// Input is fed in fixed 1024-sample chunks, with the last one zero-padded.
    /// The output is truncated to `len * to_rate / from_rate` samples.
    ///
    /// NOTE(review): FFT resamplers typically add a small output delay that is not
    /// compensated here, so the last few milliseconds may be cut — confirm acceptable.
    ///
    /// # Errors
    /// Returns a description if the resampler cannot be built (e.g. an unusable rate)
    /// or a processing call fails.
    fn resample_sinc(
        samples: &[f32],
        from_rate: u32,
        to_rate: u32,
    ) -> std::result::Result<Vec<f32>, String> {
        use rubato::{FftFixedIn, Resampler};

        if from_rate == to_rate {
            return Ok(samples.to_vec());
        }

        const CHUNK_SIZE: usize = 1024;
        let mut resampler = FftFixedIn::<f32>::new(
            from_rate as usize,
            to_rate as usize,
            CHUNK_SIZE,
            2, // sub_chunks for quality
            1, // mono
        )
        .map_err(|e| format!("Failed to create resampler: {}", e))?;

        let expected_len = (samples.len() as f64 * to_rate as f64 / from_rate as f64) as usize;
        let mut output = Vec::with_capacity(expected_len + CHUNK_SIZE);

        // One reusable single-channel input buffer; FftFixedIn consumes exactly
        // CHUNK_SIZE frames per call.
        let mut input = vec![Vec::with_capacity(CHUNK_SIZE)];

        for chunk in samples.chunks(CHUNK_SIZE) {
            let buf = &mut input[0];
            buf.clear();
            buf.extend_from_slice(chunk);
            buf.resize(CHUNK_SIZE, 0.0);

            let resampled = resampler
                .process(&input, None)
                .map_err(|e| format!("Resampling failed: {}", e))?;
            if let Some(channel) = resampled.first() {
                output.extend_from_slice(channel);
            }
        }

        // Trim to expected length
        output.truncate(expected_len);
        Ok(output)
    }
}
387
388#[cfg(feature = "whisper")]
389pub use audio::*;
390
391// ============================================================================
392// Whisper Transcriber (Candle Inference)
393// ============================================================================
394
395#[cfg(feature = "whisper")]
396mod inference {
397    use super::*;
398    use candle_core::{DType, Device, IndexOp, Tensor};
399    use candle_nn::VarBuilder;
400    use candle_transformers::models::whisper::{self as m, Config, audio};
401    use hf_hub::{Repo, RepoType, api::sync::Api};
402    use tokenizers::Tokenizer;
403
404    /// Whisper model wrapper for transcription
405    pub struct WhisperTranscriber {
406        model: Model,
407        tokenizer: Tokenizer,
408        config: Config,
409        mel_filters: Vec<f32>,
410        device: Device,
411    }
412
413    #[allow(dead_code)]
414    enum Model {
415        Normal(m::model::Whisper),
416        Quantized(m::quantized_model::Whisper),
417    }
418
419    impl WhisperTranscriber {
420        /// Create a new WhisperTranscriber, downloading the model if needed
421        pub fn new(config: &WhisperConfig) -> Result<Self> {
422            // Use GPU if available: Metal (macOS) or CUDA (NVIDIA)
423            let device = Self::select_device();
424            tracing::info!(device = ?device, "Using device for Whisper");
425            let model_id = match config.model_name.as_str() {
426                "whisper-small-en" => "openai/whisper-small.en",
427                "whisper-small" => "openai/whisper-small",
428                "whisper-tiny.en" => "openai/whisper-tiny.en",
429                "whisper-tiny" => "openai/whisper-tiny",
430                "whisper-base.en" => "openai/whisper-base.en",
431                "whisper-base" => "openai/whisper-base",
432                "whisper-medium.en" => "openai/whisper-medium.en",
433                "whisper-medium" => "openai/whisper-medium",
434                "whisper-large-v3" => "openai/whisper-large-v3",
435                other => other, // Allow direct model IDs
436            };
437
438            tracing::info!(model_id = model_id, "Loading Whisper model");
439
440            let api = Api::new().map_err(|e| WhisperError::DownloadError {
441                cause: e.to_string(),
442            })?;
443            let repo = api.repo(Repo::with_revision(
444                model_id.to_string(),
445                RepoType::Model,
446                "main".to_string(),
447            ));
448
449            // Download model files
450            let config_path = repo
451                .get("config.json")
452                .map_err(|e| WhisperError::DownloadError {
453                    cause: format!("Failed to download config.json: {}", e),
454                })?;
455            let tokenizer_path =
456                repo.get("tokenizer.json")
457                    .map_err(|e| WhisperError::DownloadError {
458                        cause: format!("Failed to download tokenizer.json: {}", e),
459                    })?;
460            let model_path =
461                repo.get("model.safetensors")
462                    .map_err(|e| WhisperError::DownloadError {
463                        cause: format!("Failed to download model.safetensors: {}", e),
464                    })?;
465
466            // Load config
467            let config_str = std::fs::read_to_string(&config_path).map_err(|e| {
468                WhisperError::InferenceError {
469                    cause: format!("Failed to read config: {}", e),
470                }
471            })?;
472            let model_config: Config =
473                serde_json::from_str(&config_str).map_err(|e| WhisperError::InferenceError {
474                    cause: format!("Failed to parse config: {}", e),
475                })?;
476
477            // Load tokenizer
478            let tokenizer = Tokenizer::from_file(&tokenizer_path).map_err(|e| {
479                WhisperError::InferenceError {
480                    cause: format!("Failed to load tokenizer: {}", e),
481                }
482            })?;
483
484            // Load mel filters
485            let mel_bytes = match model_config.num_mel_bins {
486                80 => include_bytes!("melfilters.bytes").as_slice(),
487                128 => include_bytes!("melfilters128.bytes").as_slice(),
488                n => {
489                    return Err(WhisperError::InferenceError {
490                        cause: format!("Unsupported number of mel bins: {}", n),
491                    }
492                    .into());
493                }
494            };
495            let mut mel_filters = vec![0f32; mel_bytes.len() / 4];
496            <byteorder::LittleEndian as byteorder::ByteOrder>::read_f32_into(
497                mel_bytes,
498                &mut mel_filters,
499            );
500
501            // Load model weights
502            let vb = unsafe {
503                VarBuilder::from_mmaped_safetensors(&[model_path], DType::F32, &device).map_err(
504                    |e| WhisperError::InferenceError {
505                        cause: format!("Failed to load model weights: {}", e),
506                    },
507                )?
508            };
509            let model = Model::Normal(m::model::Whisper::load(&vb, model_config.clone()).map_err(
510                |e| WhisperError::InferenceError {
511                    cause: format!("Failed to load Whisper model: {}", e),
512                },
513            )?);
514
515            tracing::info!("Whisper model loaded successfully");
516
517            Ok(Self {
518                model,
519                tokenizer,
520                config: model_config,
521                mel_filters,
522                device,
523            })
524        }
525
526        /// Select the best available device (GPU if available, otherwise CPU)
527        fn select_device() -> Device {
528            // Try Metal (macOS Apple Silicon / AMD)
529            #[cfg(feature = "metal")]
530            {
531                if let Ok(device) = Device::new_metal(0) {
532                    tracing::info!("Metal GPU available");
533                    return device;
534                }
535            }
536
537            // Try CUDA (NVIDIA GPUs)
538            #[cfg(feature = "cuda")]
539            {
540                if let Ok(device) = Device::new_cuda(0) {
541                    tracing::info!("CUDA GPU available");
542                    return device;
543                }
544            }
545
546            // Fallback to CPU
547            tracing::info!("Using CPU (no GPU acceleration)");
548            Device::Cpu
549        }
550
551        /// Transcribe an audio file
552        pub fn transcribe_file(&mut self, path: &Path) -> Result<TranscriptionResult> {
553            // Decode audio to PCM
554            let (pcm_data, duration_secs) = super::decode_audio_file(path)?;
555
556            // Check audio statistics
557            let audio_min = pcm_data.iter().cloned().fold(f32::INFINITY, f32::min);
558            let audio_max = pcm_data.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
559            let audio_mean = pcm_data.iter().sum::<f32>() / pcm_data.len() as f32;
560            let audio_rms =
561                (pcm_data.iter().map(|x| x * x).sum::<f32>() / pcm_data.len() as f32).sqrt();
562
563            tracing::info!(
564                duration = duration_secs,
565                samples = pcm_data.len(),
566                min = audio_min,
567                max = audio_max,
568                mean = audio_mean,
569                rms = audio_rms,
570                "Audio decoded"
571            );
572
573            self.transcribe_pcm(&pcm_data, duration_secs)
574        }
575
576        /// Transcribe PCM audio samples (16kHz mono f32)
577        pub fn transcribe_pcm(
578            &mut self,
579            pcm_data: &[f32],
580            duration_secs: f32,
581        ) -> Result<TranscriptionResult> {
582            // Whisper processes audio in 30-second chunks
583            const CHUNK_LENGTH: usize = 30 * 16000; // 30 seconds at 16kHz
584            const N_FRAMES: usize = 3000; // frames per chunk
585            const SAMPLE_RATE: f32 = 16000.0;
586
587            // Detect and trim leading silence
588            let silence_threshold = 0.01; // RMS threshold for silence
589            let window_size = 1600; // 100ms windows at 16kHz
590
591            let start_sample = find_speech_start(pcm_data, silence_threshold, window_size);
592            let end_sample = find_speech_end(pcm_data, silence_threshold, window_size);
593
594            let trimmed_start = start_sample as f32 / SAMPLE_RATE;
595            let trimmed_end = end_sample as f32 / SAMPLE_RATE;
596
597            tracing::info!(
598                start_sample = start_sample,
599                end_sample = end_sample,
600                trimmed_start_sec = trimmed_start,
601                trimmed_end_sec = trimmed_end,
602                original_duration = duration_secs,
603                "Trimmed silence"
604            );
605
606            // Use trimmed audio
607            let pcm_data = &pcm_data[start_sample..end_sample];
608            let _trimmed_duration = pcm_data.len() as f32 / SAMPLE_RATE;
609
610            let mut all_text = String::new();
611            let mut segments = Vec::new();
612
613            // Process audio in chunks
614            let num_chunks = (pcm_data.len() + CHUNK_LENGTH - 1) / CHUNK_LENGTH;
615
616            for chunk_idx in 0..num_chunks {
617                let chunk_start = chunk_idx * CHUNK_LENGTH;
618                let chunk_end = (chunk_start + CHUNK_LENGTH).min(pcm_data.len());
619                let chunk = &pcm_data[chunk_start..chunk_end];
620
621                // Adjust timestamps to account for trimmed silence
622                let start_time = trimmed_start + chunk_start as f32 / SAMPLE_RATE;
623                let end_time = trimmed_start + chunk_end as f32 / SAMPLE_RATE;
624
625                tracing::info!(
626                    chunk = chunk_idx + 1,
627                    total = num_chunks,
628                    start = start_time,
629                    end = end_time,
630                    "Processing chunk"
631                );
632
633                // Reset decoder KV cache for each new chunk
634                match &mut self.model {
635                    Model::Normal(m) => m.decoder.reset_kv_cache(),
636                    Model::Quantized(m) => m.decoder.reset_kv_cache(),
637                }
638
639                // Convert chunk to mel spectrogram
640                let mel = audio::pcm_to_mel(&self.config, chunk, &self.mel_filters);
641                let n_mels = self.config.num_mel_bins;
642                let mel_len = mel.len();
643                let n_frames = mel_len / n_mels;
644
645                if chunk_idx == 0 {
646                    // Print config for debugging
647                    tracing::info!(
648                        num_mel_bins = self.config.num_mel_bins,
649                        max_source_positions = self.config.max_source_positions,
650                        max_target_positions = self.config.max_target_positions,
651                        "Model config"
652                    );
653
654                    // Mel statistics
655                    let mel_min = mel.iter().cloned().fold(f32::INFINITY, f32::min);
656                    let mel_max = mel.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
657                    let mel_mean = mel.iter().sum::<f32>() / mel.len() as f32;
658
659                    tracing::info!(
660                        mel_len = mel_len,
661                        n_mels = n_mels,
662                        n_frames = n_frames,
663                        chunk_samples = chunk.len(),
664                        expected_frames = 3000,
665                        mel_min = mel_min,
666                        mel_max = mel_max,
667                        mel_mean = mel_mean,
668                        "Mel spectrogram computed"
669                    );
670                }
671
672                // Ensure we have exactly 3000 frames (pad or truncate)
673                // NOTE: mel array from pcm_to_mel is stored as [mel_bin_0_all_frames, mel_bin_1_all_frames, ...]
674                // So each mel bin has n_frames contiguous values: mel[bin * n_frames + frame]
675                let mel = if n_frames < N_FRAMES {
676                    // Pad each mel bin's frames with zeros to reach N_FRAMES
677                    let mut padded = vec![0.0f32; n_mels * N_FRAMES];
678                    for bin in 0..n_mels {
679                        let src_start = bin * n_frames;
680                        let dst_start = bin * N_FRAMES;
681                        padded[dst_start..dst_start + n_frames]
682                            .copy_from_slice(&mel[src_start..src_start + n_frames]);
683                    }
684                    padded
685                } else if n_frames > N_FRAMES {
686                    // Truncate each mel bin's frames to N_FRAMES
687                    let mut truncated = vec![0.0f32; n_mels * N_FRAMES];
688                    for bin in 0..n_mels {
689                        let src_start = bin * n_frames;
690                        let dst_start = bin * N_FRAMES;
691                        truncated[dst_start..dst_start + N_FRAMES]
692                            .copy_from_slice(&mel[src_start..src_start + N_FRAMES]);
693                    }
694                    truncated
695                } else {
696                    mel
697                };
698
699                let mel =
700                    Tensor::from_vec(mel, (1, n_mels, N_FRAMES), &self.device).map_err(|e| {
701                        WhisperError::InferenceError {
702                            cause: format!("Failed to create mel tensor: {}", e),
703                        }
704                    })?;
705
706                if chunk_idx == 0 {
707                    let mel_shape = mel.shape();
708                    tracing::info!(
709                        mel_shape = ?mel_shape,
710                        "Mel tensor shape"
711                    );
712                }
713
714                // Run encoder
715                let audio_features = match &mut self.model {
716                    Model::Normal(m) => m.encoder.forward(&mel, true),
717                    Model::Quantized(m) => m.encoder.forward(&mel, true),
718                }
719                .map_err(|e| WhisperError::InferenceError {
720                    cause: format!("Encoder forward failed: {}", e),
721                })?;
722
723                if chunk_idx == 0 {
724                    let af_shape = audio_features.shape();
725                    tracing::info!(
726                        audio_features_shape = ?af_shape,
727                        "Audio features from encoder"
728                    );
729                }
730
731                // Get special token IDs
732                let sot_token = self.token_id(m::SOT_TOKEN)?;
733                let transcribe_token = self.token_id(m::TRANSCRIBE_TOKEN)?;
734                let eot_token = self.token_id(m::EOT_TOKEN)?;
735                let no_timestamps_token = self.token_id(m::NO_TIMESTAMPS_TOKEN)?;
736
737                if chunk_idx == 0 {
738                    let en_token = self.tokenizer.token_to_id("<|en|>");
739                    tracing::info!(
740                        sot = sot_token,
741                        transcribe = transcribe_token,
742                        eot = eot_token,
743                        no_timestamps = no_timestamps_token,
744                        en_token = ?en_token,
745                        "Special tokens"
746                    );
747                }
748
749                // Build initial prompt
750                // For English-only models (*.en), we DON'T use language token
751                // For multilingual models, we add language token after sot_token
752                let has_language_token = self.tokenizer.token_to_id("<|en|>").is_some();
753
754                // English-only models have vocab size 51864, multilingual have 51865
755                let is_english_only = self.config.vocab_size == 51864;
756
757                let tokens = if is_english_only {
758                    // English-only: SOT -> transcribe -> notimestamps
759                    vec![sot_token, transcribe_token, no_timestamps_token]
760                } else if has_language_token {
761                    // Multilingual: SOT -> language -> transcribe -> notimestamps
762                    let language_token = self.token_id("<|en|>")?;
763                    vec![
764                        sot_token,
765                        language_token,
766                        transcribe_token,
767                        no_timestamps_token,
768                    ]
769                } else {
770                    // Fallback
771                    vec![sot_token, transcribe_token, no_timestamps_token]
772                };
773
774                if chunk_idx == 0 {
775                    tracing::info!(
776                        is_english_only = is_english_only,
777                        vocab_size = self.config.vocab_size,
778                        prompt_tokens = ?tokens,
779                        "Initial prompt"
780                    );
781                }
782                let mut all_tokens = tokens.clone();
783
784                // Autoregressive decoding with token suppression
785                let sample_len = self.config.max_target_positions / 2;
786                let mut repeat_count = 0;
787                let mut last_token: Option<u32> = None;
788
789                // Build suppression mask
790                let suppress_tokens = &self.config.suppress_tokens;
791
792                for i in 0..sample_len {
793                    // For autoregressive decoding with KV cache:
794                    // - First iteration: pass all prompt tokens, flush_kv_cache=true
795                    // - Subsequent iterations: pass only the new token, flush_kv_cache=false
796                    let tokens_tensor = Tensor::new(all_tokens.as_slice(), &self.device)
797                        .and_then(|t| t.unsqueeze(0))
798                        .map_err(|e| WhisperError::InferenceError {
799                            cause: format!("Failed to create tokens tensor: {}", e),
800                        })?;
801
802                    if chunk_idx == 0 && i < 3 {
803                        tracing::info!(
804                            step = i,
805                            all_tokens_len = all_tokens.len(),
806                            tokens_shape = ?tokens_tensor.shape(),
807                            "Decoder input"
808                        );
809                    }
810
811                    // Get hidden states from decoder, then project to vocabulary
812                    // Always pass all tokens (candle doesn't use KV cache the same way as PyTorch)
813                    let logits = match &mut self.model {
814                        Model::Normal(m) => {
815                            let hidden = m
816                                .decoder
817                                .forward(&tokens_tensor, &audio_features, true)
818                                .map_err(|e| WhisperError::InferenceError {
819                                cause: format!("Decoder forward failed: {}", e),
820                            })?;
821                            m.decoder.final_linear(&hidden).map_err(|e| {
822                                WhisperError::InferenceError {
823                                    cause: format!("Final linear failed: {}", e),
824                                }
825                            })?
826                        }
827                        Model::Quantized(m) => {
828                            let hidden = m
829                                .decoder
830                                .forward(&tokens_tensor, &audio_features, true)
831                                .map_err(|e| WhisperError::InferenceError {
832                                cause: format!("Decoder forward failed: {}", e),
833                            })?;
834                            m.decoder.final_linear(&hidden).map_err(|e| {
835                                WhisperError::InferenceError {
836                                    cause: format!("Final linear failed: {}", e),
837                                }
838                            })?
839                        }
840                    };
841
842                    if chunk_idx == 0 && i == 0 {
843                        tracing::info!(
844                            logits_shape = ?logits.shape(),
845                            "Decoder output logits"
846                        );
847                    }
848
849                    // Get logits for last position
850                    let (_, seq_len, _) =
851                        logits.dims3().map_err(|e| WhisperError::InferenceError {
852                            cause: format!("Failed to get logits dims: {}", e),
853                        })?;
854                    let mut logits_vec = logits
855                        .i((0, seq_len - 1, ..))
856                        .and_then(|t| t.to_vec1::<f32>())
857                        .map_err(|e| WhisperError::InferenceError {
858                            cause: format!("Failed to extract logits: {}", e),
859                        })?;
860
861                    // Apply token suppression from config
862                    for &token_id in suppress_tokens.iter() {
863                        if (token_id as usize) < logits_vec.len() {
864                            logits_vec[token_id as usize] = f32::NEG_INFINITY;
865                        }
866                    }
867
868                    // Suppress EOT token for first few steps to allow generation
869                    if all_tokens.len() < 10 {
870                        logits_vec[eot_token as usize] = f32::NEG_INFINITY;
871                    }
872
873                    // Suppress all special tokens during generation:
874                    // - SOT (50257), language tokens (50258-50261), task tokens (50358-50359),
875                    // - no_timestamps (50362), and timestamp tokens (50363+)
876                    logits_vec[sot_token as usize] = f32::NEG_INFINITY;
877                    logits_vec[transcribe_token as usize] = f32::NEG_INFINITY;
878                    logits_vec[no_timestamps_token as usize] = f32::NEG_INFINITY;
879                    // Suppress all tokens from 50257 onward (special tokens) except those in normal vocab
880                    for token_id in 50257..logits_vec.len() {
881                        logits_vec[token_id] = f32::NEG_INFINITY;
882                    }
883
884                    if chunk_idx == 0 && i == 0 {
885                        tracing::info!(
886                            suppress_count = suppress_tokens.len(),
887                            eot_suppressed = all_tokens.len() < 10,
888                            "Applied token suppression"
889                        );
890                    }
891
892                    // Find argmax
893                    let next_token = logits_vec
894                        .iter()
895                        .enumerate()
896                        .max_by(|(_, a), (_, b)| {
897                            a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal)
898                        })
899                        .map(|(idx, _)| idx as u32)
900                        .unwrap_or(eot_token);
901
902                    if chunk_idx == 0 && i < 5 {
903                        let max_logit =
904                            logits_vec.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
905                        let min_logit = logits_vec.iter().cloned().fold(f32::INFINITY, f32::min);
906                        tracing::info!(
907                            step = i,
908                            next_token = next_token,
909                            max_logit = max_logit,
910                            min_logit = min_logit,
911                            "Decoding step"
912                        );
913                    }
914
915                    if next_token == eot_token || next_token >= self.config.vocab_size as u32 {
916                        if chunk_idx == 0 && i < 5 {
917                            tracing::info!(
918                                next_token = next_token,
919                                eot = eot_token,
920                                "Stopping: EOT or invalid token"
921                            );
922                        }
923                        break;
924                    }
925
926                    // Check for excessive repetition (stop if same token repeats >3 times)
927                    if Some(next_token) == last_token {
928                        repeat_count += 1;
929                        if repeat_count > 3 {
930                            tracing::debug!("Breaking due to token repetition");
931                            break;
932                        }
933                    } else {
934                        repeat_count = 0;
935                    }
936                    last_token = Some(next_token);
937
938                    all_tokens.push(next_token);
939                }
940
941                // Decode tokens to text for this chunk
942                let prompt_len = if is_english_only { 3 } else { 4 };
943
944                if chunk_idx == 0 {
945                    tracing::info!(
946                        prompt_tokens = ?&all_tokens[..prompt_len],
947                        generated_tokens = ?&all_tokens[prompt_len..],
948                        total = all_tokens.len(),
949                        "Generated tokens for chunk"
950                    );
951                }
952
953                let chunk_text = self
954                    .tokenizer
955                    .decode(&all_tokens[prompt_len..], true) // Skip prompt tokens
956                    .map_err(|e| WhisperError::InferenceError {
957                        cause: format!("Failed to decode tokens: {}", e),
958                    })?;
959
960                let trimmed_text = chunk_text.trim();
961                if !trimmed_text.is_empty() {
962                    if !all_text.is_empty() {
963                        all_text.push(' ');
964                    }
965                    all_text.push_str(trimmed_text);
966
967                    segments.push(TranscriptionSegment {
968                        start: start_time,
969                        end: end_time,
970                        text: trimmed_text.to_string(),
971                    });
972                }
973            }
974
975            Ok(TranscriptionResult {
976                text: all_text.trim().to_string(),
977                language: "en".to_string(),
978                duration_secs,
979                segments,
980            })
981        }
982
983        fn token_id(&self, token: &str) -> Result<u32> {
984            self.tokenizer.token_to_id(token).ok_or_else(|| {
985                WhisperError::InferenceError {
986                    cause: format!("Token '{}' not found in vocabulary", token),
987                }
988                .into()
989            })
990        }
991    }
992
993    /// Find the sample index where speech starts (after leading silence)
994    fn find_speech_start(samples: &[f32], threshold: f32, window_size: usize) -> usize {
995        for i in (0..samples.len()).step_by(window_size) {
996            let end = (i + window_size).min(samples.len());
997            let window = &samples[i..end];
998            let rms = (window.iter().map(|x| x * x).sum::<f32>() / window.len() as f32).sqrt();
999            if rms > threshold {
1000                // Found speech, go back a bit to not cut off the start
1001                return i.saturating_sub(window_size);
1002            }
1003        }
1004        0 // No silence found, return start
1005    }
1006
1007    /// Find the sample index where speech ends (before trailing silence)
1008    fn find_speech_end(samples: &[f32], threshold: f32, window_size: usize) -> usize {
1009        for i in (0..samples.len()).rev().step_by(window_size) {
1010            let start = i.saturating_sub(window_size);
1011            let window = &samples[start..=i.min(samples.len() - 1)];
1012            let rms = (window.iter().map(|x| x * x).sum::<f32>() / window.len() as f32).sqrt();
1013            if rms > threshold {
1014                // Found speech, go forward a bit to not cut off the end
1015                return (i + window_size).min(samples.len());
1016            }
1017        }
1018        samples.len() // No silence found, return end
1019    }
1020}
1021
1022#[cfg(feature = "whisper")]
1023pub use inference::WhisperTranscriber;
1024
1025// ============================================================================
1026// Tests
1027// ============================================================================
1028
#[cfg(test)]
mod tests {
    use super::*;

    /// The registry's default entry is the English-only small model, and any
    /// unrecognised model name falls back to that same entry.
    #[test]
    fn whisper_model_registry() {
        let expected_name = "whisper-small-en";

        let default = default_whisper_model_info();
        assert_eq!(default.name, expected_name);
        assert!(default.is_default);
        assert_eq!(default.language, "en");

        // Lookup of an unknown name yields the default model.
        let fallback = get_whisper_model_info("nonexistent");
        assert_eq!(fallback.name, expected_name);
    }

    /// `WhisperConfig::default()` selects the default registry model by name.
    #[test]
    fn whisper_config_defaults() {
        assert_eq!(WhisperConfig::default().model_name, "whisper-small-en");
    }
}