memvid_core/
whisper.rs

//! Whisper audio transcription with Candle inference.
//!
//! This module provides complete Whisper transcription functionality including:
//! - Audio decoding (MP3, WAV, FLAC, etc.) via symphonia
//! - Resampling to 16kHz via rubato
//! - Whisper model inference via candle-transformers
//! - Automatic model download from HuggingFace Hub
8
9use serde::{Deserialize, Serialize};
10use std::path::{Path, PathBuf};
11
12use crate::{MemvidError, Result};
13
14// ============================================================================
15// Model Registry
16// ============================================================================
17
/// Static description of one entry in the Whisper model registry.
///
/// Entries are looked up either by their short `name` or by their
/// HuggingFace `model_id`.
#[derive(Debug, Clone)]
pub struct WhisperModelInfo {
    /// HuggingFace repository id, e.g. `"openai/whisper-small.en"`.
    pub model_id: &'static str,
    /// Short, human-friendly model name, e.g. `"whisper-small-en"`.
    pub name: &'static str,
    /// Approximate model size in megabytes.
    ///
    /// NOTE(review): the registered value for whisper-small (244) matches its
    /// published parameter count in millions rather than the on-disk weight
    /// size — confirm the intended unit.
    pub size_mb: f32,
    /// Marks the entry used as the fallback/default model.
    pub is_default: bool,
    /// `"en"` for English-only models, `"multilingual"` for the others.
    pub language: &'static str,
}
32
33/// Available Whisper models registry
34pub static WHISPER_MODELS: &[WhisperModelInfo] = &[
35    WhisperModelInfo {
36        model_id: "openai/whisper-small.en",
37        name: "whisper-small-en",
38        size_mb: 244.0,
39        is_default: true,
40        language: "en",
41    },
42    WhisperModelInfo {
43        model_id: "openai/whisper-small",
44        name: "whisper-small",
45        size_mb: 244.0,
46        is_default: false,
47        language: "multilingual",
48    },
49];
50
51/// Get model info by name, defaults to whisper-small-en
52pub fn get_whisper_model_info(name: &str) -> &'static WhisperModelInfo {
53    WHISPER_MODELS
54        .iter()
55        .find(|m| m.name == name || m.model_id == name)
56        .unwrap_or_else(|| {
57            WHISPER_MODELS
58                .iter()
59                .find(|m| m.is_default)
60                .expect("default whisper model")
61        })
62}
63
64/// Get the default model info
65pub fn default_whisper_model_info() -> &'static WhisperModelInfo {
66    WHISPER_MODELS
67        .iter()
68        .find(|m| m.is_default)
69        .expect("default whisper model exists")
70}
71
72// ============================================================================
73// Whisper Model Configuration
74// ============================================================================
75
/// Settings used when constructing a Whisper transcriber.
///
/// `Default` fills these from the `MEMVID_*` environment variables.
#[derive(Debug, Clone)]
pub struct WhisperConfig {
    /// Short registry name (e.g. `"whisper-small-en"`) or a raw HuggingFace
    /// model id.
    pub model_name: String,
    /// Root directory for cached model files.
    pub models_dir: PathBuf,
    /// When `true`, model downloads should not be attempted.
    pub offline: bool,
}
86
87impl Default for WhisperConfig {
88    fn default() -> Self {
89        let models_dir = std::env::var("MEMVID_MODELS_DIR")
90            .ok()
91            .map(PathBuf::from)
92            .or_else(|| dirs_next::home_dir().map(|d| d.join(".memvid/models")))
93            .unwrap_or_else(|| PathBuf::from(".memvid/models"));
94
95        let model_name = std::env::var("MEMVID_WHISPER_MODEL")
96            .unwrap_or_else(|_| "whisper-small-en".to_string());
97
98        let offline = std::env::var("MEMVID_OFFLINE").is_ok();
99
100        Self {
101            model_name,
102            models_dir,
103            offline,
104        }
105    }
106}
107
108// ============================================================================
109// Whisper Error Types
110// ============================================================================
111
112/// Whisper-specific errors
113#[derive(Debug, thiserror::Error)]
114pub enum WhisperError {
115    /// Model not found
116    #[error("Whisper model '{model}' not found. {hint}")]
117    ModelNotFound { model: String, hint: String },
118
119    /// Audio decode failed
120    #[error("Failed to decode audio at {path:?}: {cause}")]
121    AudioDecodeError { path: PathBuf, cause: String },
122
123    /// Audio bytes decode failed
124    #[error("Failed to decode audio bytes: {cause}")]
125    AudioBytesDecodeError { cause: String },
126
127    /// Inference error
128    #[error("Whisper inference error: {cause}")]
129    InferenceError { cause: String },
130
131    /// Model download failed
132    #[error("Failed to download Whisper model: {cause}")]
133    DownloadError { cause: String },
134}
135
136impl From<WhisperError> for MemvidError {
137    fn from(err: WhisperError) -> Self {
138        MemvidError::ExtractionFailed {
139            reason: err.to_string().into_boxed_str(),
140        }
141    }
142}
143
144// ============================================================================
145// Transcription Result
146// ============================================================================
147
148/// Result of audio transcription
149#[derive(Debug, Clone, Serialize, Deserialize)]
150pub struct TranscriptionResult {
151    /// The transcribed text
152    pub text: String,
153    /// Language detected or specified
154    pub language: String,
155    /// Duration of audio in seconds
156    pub duration_secs: f32,
157    /// Optional timestamps for segments
158    #[serde(default)]
159    pub segments: Vec<TranscriptionSegment>,
160}
161
162/// A segment of transcription with timestamps
163#[derive(Debug, Clone, Serialize, Deserialize)]
164pub struct TranscriptionSegment {
165    /// Start time in seconds
166    pub start: f32,
167    /// End time in seconds
168    pub end: f32,
169    /// Transcribed text for this segment
170    pub text: String,
171}
172
173// ============================================================================
174// Audio Decoding (Feature-gated)
175// ============================================================================
176
#[cfg(feature = "whisper")]
mod audio {
    use super::*;
    use std::fs::File;
    use symphonia::core::audio::SampleBuffer;
    use symphonia::core::codecs::{DecoderOptions, CODEC_TYPE_NULL};
    use symphonia::core::errors::Error as SymphoniaError;
    use symphonia::core::formats::FormatOptions;
    use symphonia::core::io::MediaSourceStream;
    use symphonia::core::meta::MetadataOptions;
    use symphonia::core::probe::Hint;

    /// Whisper sample rate (always 16kHz)
    pub const WHISPER_SAMPLE_RATE: u32 = 16000;

    /// Sample rate assumed when neither the decoded audio nor the container
    /// reports one.
    const FALLBACK_SAMPLE_RATE: u32 = 44100;

    /// Channel count reported in logs when nothing was decoded and the
    /// container declares no channel layout.
    const FALLBACK_CHANNELS: usize = 2;

    /// Decode an audio file to mono `f32` samples at [`WHISPER_SAMPLE_RATE`].
    ///
    /// Multi-channel audio is downmixed by averaging the channels of each
    /// frame (using the channel count of the decoded buffers, not the
    /// container's declaration), then resampled to 16 kHz if needed.
    ///
    /// Returns `(samples, duration_secs)`, where `duration_secs` is the length
    /// of the decoded audio measured at its source sample rate.
    ///
    /// Decoding is best-effort: packets that fail to decode are skipped, and a
    /// non-EOF error while reading packets stops decoding early (logged as a
    /// warning), returning the audio read so far.
    ///
    /// # Errors
    ///
    /// `WhisperError::AudioDecodeError` if the file cannot be opened, its
    /// format cannot be probed, it has no audio track, no decoder is available
    /// for the track, or resampling fails.
    pub fn decode_audio_file(path: &Path) -> Result<(Vec<f32>, f32)> {
        let decode_err = |cause: String| WhisperError::AudioDecodeError {
            path: path.to_path_buf(),
            cause,
        };

        let file = File::open(path).map_err(|e| decode_err(e.to_string()))?;
        let mss = MediaSourceStream::new(Box::new(file), Default::default());

        // Create a hint based on file extension
        let mut hint = Hint::new();
        if let Some(ext) = path.extension().and_then(|e| e.to_str()) {
            hint.with_extension(ext);
        }

        // Probe the media source
        let probed = symphonia::default::get_probe()
            .format(
                &hint,
                mss,
                &FormatOptions::default(),
                &MetadataOptions::default(),
            )
            .map_err(|e| decode_err(format!("Failed to probe audio format: {}", e)))?;
        let mut format = probed.format;

        // Find the first audio track
        let track = format
            .tracks()
            .iter()
            .find(|t| t.codec_params.codec != CODEC_TYPE_NULL)
            .ok_or_else(|| decode_err("No audio track found".to_string()))?;
        let track_id = track.id;
        // Container-declared values; these can be missing, so the decoded
        // buffers' own spec takes precedence below.
        let declared_rate = track.codec_params.sample_rate;
        let declared_channels = track.codec_params.channels.map(|c| c.count());

        let mut decoder = symphonia::default::get_codecs()
            .make(&track.codec_params, &DecoderOptions::default())
            .map_err(|e| decode_err(format!("Failed to create decoder: {}", e)))?;

        let mut samples: Vec<f32> = Vec::new();
        let mut decoded_rate: Option<u32> = None;
        let mut decoded_channels: Option<usize> = None;

        // Decode all packets
        loop {
            let packet = match format.next_packet() {
                Ok(p) => p,
                // Normal end of stream.
                Err(SymphoniaError::IoError(e))
                    if e.kind() == std::io::ErrorKind::UnexpectedEof =>
                {
                    break;
                }
                Err(e) => {
                    // Keep best-effort semantics, but make the truncation visible.
                    tracing::warn!(
                        path = %path.display(),
                        error = %e,
                        "Stopped reading audio packets early"
                    );
                    break;
                }
            };

            if packet.track_id() != track_id {
                continue;
            }

            let decoded = match decoder.decode(&packet) {
                Ok(d) => d,
                Err(e) => {
                    // Corrupt packets are skipped rather than aborting the file.
                    tracing::debug!(error = %e, "Skipping undecodable audio packet");
                    continue;
                }
            };

            let spec = *decoded.spec();
            let num_frames = decoded.frames();
            if num_frames == 0 {
                continue;
            }

            // Use the layout of the buffer actually decoded: defaulting a missing
            // container declaration to stereo would average mono samples in
            // pairs and halve the audio length.
            let channels = spec.channels.count().max(1);
            if decoded_rate.is_none() {
                decoded_rate = Some(spec.rate);
            }
            decoded_channels = Some(channels);

            let mut sample_buf = SampleBuffer::<f32>::new(num_frames as u64, spec);
            sample_buf.copy_interleaved_ref(decoded);
            let interleaved = sample_buf.samples();

            // Convert to mono by averaging channels
            if channels > 1 {
                samples.extend(
                    interleaved
                        .chunks_exact(channels)
                        .map(|frame| frame.iter().sum::<f32>() / channels as f32),
                );
            } else {
                samples.extend_from_slice(interleaved);
            }
        }

        let sample_rate = decoded_rate
            .or(declared_rate)
            .unwrap_or(FALLBACK_SAMPLE_RATE);
        let channels = decoded_channels
            .or(declared_channels)
            .unwrap_or(FALLBACK_CHANNELS);
        let duration_secs = samples.len() as f32 / sample_rate as f32;

        // Log pre-resampling stats
        let (pre_min, pre_max, pre_rms) = sample_stats(&samples);
        tracing::info!(
            sample_rate = sample_rate,
            channels = channels,
            samples_before = samples.len(),
            pre_min = pre_min,
            pre_max = pre_max,
            pre_rms = pre_rms,
            "Audio before resampling"
        );

        // Resample to 16kHz when needed
        let samples = if sample_rate != WHISPER_SAMPLE_RATE {
            let resampled = resample_sinc(&samples, sample_rate, WHISPER_SAMPLE_RATE)
                .map_err(|cause| decode_err(format!("Failed to resample audio: {}", cause)))?;

            // Log post-resampling stats
            let (post_min, post_max, post_rms) = sample_stats(&resampled);
            tracing::info!(
                samples_after = resampled.len(),
                post_min = post_min,
                post_max = post_max,
                post_rms = post_rms,
                "Audio after resampling"
            );
            resampled
        } else {
            tracing::info!("Audio already at 16kHz, no resampling needed");
            samples
        };

        Ok((samples, duration_secs))
    }

    /// Min, max and RMS of `samples`, for diagnostic logging.
    ///
    /// Returns all zeros for empty input instead of ±infinity and NaN.
    fn sample_stats(samples: &[f32]) -> (f32, f32, f32) {
        if samples.is_empty() {
            return (0.0, 0.0, 0.0);
        }
        let min = samples.iter().copied().fold(f32::INFINITY, f32::min);
        let max = samples.iter().copied().fold(f32::NEG_INFINITY, f32::max);
        let rms = (samples.iter().map(|x| x * x).sum::<f32>() / samples.len() as f32).sqrt();
        (min, max, rms)
    }

    /// Resample mono audio from `from_rate` to `to_rate` Hz with rubato's
    /// FFT-based `FftFixedIn` resampler.
    ///
    /// Input is fed in fixed 1024-frame blocks, with the last block zero-padded.
    /// The resampler delays its output by `output_delay()` frames, so extra
    /// zero blocks are fed to flush the tail and that many leading frames are
    /// dropped. This keeps the result time-aligned with the input. The result
    /// is truncated to `samples.len() * to_rate / from_rate` frames.
    ///
    /// Returns the input unchanged when the rates are equal or the input is
    /// empty. Returns a textual error if either rate is zero or rubato fails.
    fn resample_sinc(
        samples: &[f32],
        from_rate: u32,
        to_rate: u32,
    ) -> std::result::Result<Vec<f32>, String> {
        use rubato::{FftFixedIn, Resampler};

        const CHUNK_SIZE: usize = 1024;

        if from_rate == to_rate || samples.is_empty() {
            return Ok(samples.to_vec());
        }
        if from_rate == 0 || to_rate == 0 {
            return Err(format!("invalid sample rates {} -> {}", from_rate, to_rate));
        }

        let mut resampler = FftFixedIn::<f32>::new(
            from_rate as usize,
            to_rate as usize,
            CHUNK_SIZE,
            2, // sub_chunks for quality
            1, // mono
        )
        .map_err(|e| format!("failed to create resampler: {}", e))?;

        let ratio = to_rate as f64 / from_rate as f64;
        let expected_len = (samples.len() as f64 * ratio) as usize;
        let delay = resampler.output_delay();
        let wanted = expected_len + delay;

        // Upper bound on the blocks needed: the input, enough zero blocks to
        // flush the delay, plus slack for internal buffering. This guarantees
        // the loop terminates.
        let input_blocks = (samples.len() + CHUNK_SIZE - 1) / CHUNK_SIZE;
        let flush_blocks = (delay as f64 / ratio / CHUNK_SIZE as f64).ceil() as usize;
        let max_blocks = input_blocks + flush_blocks + 2;

        let mut output = Vec::with_capacity(wanted + CHUNK_SIZE);
        let mut block = vec![0.0f32; CHUNK_SIZE];
        let mut pos = 0;

        for _ in 0..max_blocks {
            if output.len() >= wanted {
                break;
            }
            // Next slice of input, zero-padded; whole zero blocks once exhausted.
            let take = (samples.len() - pos).min(CHUNK_SIZE);
            block[..take].copy_from_slice(&samples[pos..pos + take]);
            block[take..].fill(0.0);
            pos += take;

            let resampled = resampler
                .process(std::slice::from_ref(&block), None)
                .map_err(|e| format!("resampling failed: {}", e))?;
            if let Some(channel) = resampled.first() {
                output.extend_from_slice(channel);
            }
        }

        // Drop the resampler's start-up delay, then trim padding-derived frames.
        output.drain(..delay.min(output.len()));
        output.truncate(expected_len);
        Ok(output)
    }
}
380
381#[cfg(feature = "whisper")]
382pub use audio::*;
383
384// ============================================================================
385// Whisper Transcriber (Candle Inference)
386// ============================================================================
387
388#[cfg(feature = "whisper")]
389mod inference {
390    use super::*;
391    use candle_core::{DType, Device, IndexOp, Tensor};
392    use candle_nn::VarBuilder;
393    use candle_transformers::models::whisper::{self as m, audio, Config};
394    use hf_hub::{api::sync::Api, Repo, RepoType};
395    use tokenizers::Tokenizer;
396
397    /// Whisper model wrapper for transcription
398    pub struct WhisperTranscriber {
399        model: Model,
400        tokenizer: Tokenizer,
401        config: Config,
402        mel_filters: Vec<f32>,
403        device: Device,
404    }
405
406    #[allow(dead_code)]
407    enum Model {
408        Normal(m::model::Whisper),
409        Quantized(m::quantized_model::Whisper),
410    }
411
412    impl WhisperTranscriber {
413        /// Create a new WhisperTranscriber, downloading the model if needed
414        pub fn new(config: &WhisperConfig) -> Result<Self> {
415            // Use GPU if available: Metal (macOS) or CUDA (NVIDIA)
416            let device = Self::select_device();
417            tracing::info!(device = ?device, "Using device for Whisper");
418            let model_id = match config.model_name.as_str() {
419                "whisper-small-en" => "openai/whisper-small.en",
420                "whisper-small" => "openai/whisper-small",
421                "whisper-tiny.en" => "openai/whisper-tiny.en",
422                "whisper-tiny" => "openai/whisper-tiny",
423                "whisper-base.en" => "openai/whisper-base.en",
424                "whisper-base" => "openai/whisper-base",
425                "whisper-medium.en" => "openai/whisper-medium.en",
426                "whisper-medium" => "openai/whisper-medium",
427                "whisper-large-v3" => "openai/whisper-large-v3",
428                other => other, // Allow direct model IDs
429            };
430
431            tracing::info!(model_id = model_id, "Loading Whisper model");
432
433            let api = Api::new().map_err(|e| WhisperError::DownloadError {
434                cause: e.to_string(),
435            })?;
436            let repo = api.repo(Repo::with_revision(
437                model_id.to_string(),
438                RepoType::Model,
439                "main".to_string(),
440            ));
441
442            // Download model files
443            let config_path = repo.get("config.json").map_err(|e| WhisperError::DownloadError {
444                cause: format!("Failed to download config.json: {}", e),
445            })?;
446            let tokenizer_path = repo.get("tokenizer.json").map_err(|e| WhisperError::DownloadError {
447                cause: format!("Failed to download tokenizer.json: {}", e),
448            })?;
449            let model_path = repo.get("model.safetensors").map_err(|e| WhisperError::DownloadError {
450                cause: format!("Failed to download model.safetensors: {}", e),
451            })?;
452
453            // Load config
454            let config_str = std::fs::read_to_string(&config_path).map_err(|e| WhisperError::InferenceError {
455                cause: format!("Failed to read config: {}", e),
456            })?;
457            let model_config: Config = serde_json::from_str(&config_str).map_err(|e| WhisperError::InferenceError {
458                cause: format!("Failed to parse config: {}", e),
459            })?;
460
461            // Load tokenizer
462            let tokenizer = Tokenizer::from_file(&tokenizer_path)
463                .map_err(|e| WhisperError::InferenceError {
464                    cause: format!("Failed to load tokenizer: {}", e),
465                })?;
466
467            // Load mel filters
468            let mel_bytes = match model_config.num_mel_bins {
469                80 => include_bytes!("melfilters.bytes").as_slice(),
470                128 => include_bytes!("melfilters128.bytes").as_slice(),
471                n => return Err(WhisperError::InferenceError {
472                    cause: format!("Unsupported number of mel bins: {}", n),
473                }.into()),
474            };
475            let mut mel_filters = vec![0f32; mel_bytes.len() / 4];
476            <byteorder::LittleEndian as byteorder::ByteOrder>::read_f32_into(mel_bytes, &mut mel_filters);
477
478            // Load model weights
479            let vb = unsafe {
480                VarBuilder::from_mmaped_safetensors(&[model_path], DType::F32, &device)
481                    .map_err(|e| WhisperError::InferenceError {
482                        cause: format!("Failed to load model weights: {}", e),
483                    })?
484            };
485            let model = Model::Normal(m::model::Whisper::load(&vb, model_config.clone())
486                .map_err(|e| WhisperError::InferenceError {
487                    cause: format!("Failed to load Whisper model: {}", e),
488                })?);
489
490            tracing::info!("Whisper model loaded successfully");
491
492            Ok(Self {
493                model,
494                tokenizer,
495                config: model_config,
496                mel_filters,
497                device,
498            })
499        }
500
501        /// Select the best available device (GPU if available, otherwise CPU)
502        fn select_device() -> Device {
503            // Try Metal (macOS Apple Silicon / AMD)
504            #[cfg(feature = "metal")]
505            {
506                if let Ok(device) = Device::new_metal(0) {
507                    tracing::info!("Metal GPU available");
508                    return device;
509                }
510            }
511
512            // Try CUDA (NVIDIA GPUs)
513            #[cfg(feature = "cuda")]
514            {
515                if let Ok(device) = Device::new_cuda(0) {
516                    tracing::info!("CUDA GPU available");
517                    return device;
518                }
519            }
520
521            // Fallback to CPU
522            tracing::info!("Using CPU (no GPU acceleration)");
523            Device::Cpu
524        }
525
526        /// Transcribe an audio file
527        pub fn transcribe_file(&mut self, path: &Path) -> Result<TranscriptionResult> {
528            // Decode audio to PCM
529            let (pcm_data, duration_secs) = super::decode_audio_file(path)?;
530
531            // Check audio statistics
532            let audio_min = pcm_data.iter().cloned().fold(f32::INFINITY, f32::min);
533            let audio_max = pcm_data.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
534            let audio_mean = pcm_data.iter().sum::<f32>() / pcm_data.len() as f32;
535            let audio_rms = (pcm_data.iter().map(|x| x * x).sum::<f32>() / pcm_data.len() as f32).sqrt();
536
537            tracing::info!(
538                duration = duration_secs,
539                samples = pcm_data.len(),
540                min = audio_min,
541                max = audio_max,
542                mean = audio_mean,
543                rms = audio_rms,
544                "Audio decoded"
545            );
546
547            self.transcribe_pcm(&pcm_data, duration_secs)
548        }
549
550        /// Transcribe PCM audio samples (16kHz mono f32)
551        pub fn transcribe_pcm(&mut self, pcm_data: &[f32], duration_secs: f32) -> Result<TranscriptionResult> {
552            // Whisper processes audio in 30-second chunks
553            const CHUNK_LENGTH: usize = 30 * 16000; // 30 seconds at 16kHz
554            const N_FRAMES: usize = 3000; // frames per chunk
555            const SAMPLE_RATE: f32 = 16000.0;
556
557            // Detect and trim leading silence
558            let silence_threshold = 0.01; // RMS threshold for silence
559            let window_size = 1600; // 100ms windows at 16kHz
560
561            let start_sample = find_speech_start(pcm_data, silence_threshold, window_size);
562            let end_sample = find_speech_end(pcm_data, silence_threshold, window_size);
563
564            let trimmed_start = start_sample as f32 / SAMPLE_RATE;
565            let trimmed_end = end_sample as f32 / SAMPLE_RATE;
566
567            tracing::info!(
568                start_sample = start_sample,
569                end_sample = end_sample,
570                trimmed_start_sec = trimmed_start,
571                trimmed_end_sec = trimmed_end,
572                original_duration = duration_secs,
573                "Trimmed silence"
574            );
575
576            // Use trimmed audio
577            let pcm_data = &pcm_data[start_sample..end_sample];
578            let _trimmed_duration = pcm_data.len() as f32 / SAMPLE_RATE;
579
580            let mut all_text = String::new();
581            let mut segments = Vec::new();
582
583            // Process audio in chunks
584            let num_chunks = (pcm_data.len() + CHUNK_LENGTH - 1) / CHUNK_LENGTH;
585
586            for chunk_idx in 0..num_chunks {
587                let chunk_start = chunk_idx * CHUNK_LENGTH;
588                let chunk_end = (chunk_start + CHUNK_LENGTH).min(pcm_data.len());
589                let chunk = &pcm_data[chunk_start..chunk_end];
590
591                // Adjust timestamps to account for trimmed silence
592                let start_time = trimmed_start + chunk_start as f32 / SAMPLE_RATE;
593                let end_time = trimmed_start + chunk_end as f32 / SAMPLE_RATE;
594
595                tracing::info!(
596                    chunk = chunk_idx + 1,
597                    total = num_chunks,
598                    start = start_time,
599                    end = end_time,
600                    "Processing chunk"
601                );
602
603                // Reset decoder KV cache for each new chunk
604                match &mut self.model {
605                    Model::Normal(m) => m.decoder.reset_kv_cache(),
606                    Model::Quantized(m) => m.decoder.reset_kv_cache(),
607                }
608
609                // Convert chunk to mel spectrogram
610                let mel = audio::pcm_to_mel(&self.config, chunk, &self.mel_filters);
611                let n_mels = self.config.num_mel_bins;
612                let mel_len = mel.len();
613                let n_frames = mel_len / n_mels;
614
615                if chunk_idx == 0 {
616                    // Print config for debugging
617                    tracing::info!(
618                        num_mel_bins = self.config.num_mel_bins,
619                        max_source_positions = self.config.max_source_positions,
620                        max_target_positions = self.config.max_target_positions,
621                        "Model config"
622                    );
623
624                    // Mel statistics
625                    let mel_min = mel.iter().cloned().fold(f32::INFINITY, f32::min);
626                    let mel_max = mel.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
627                    let mel_mean = mel.iter().sum::<f32>() / mel.len() as f32;
628
629                    tracing::info!(
630                        mel_len = mel_len,
631                        n_mels = n_mels,
632                        n_frames = n_frames,
633                        chunk_samples = chunk.len(),
634                        expected_frames = 3000,
635                        mel_min = mel_min,
636                        mel_max = mel_max,
637                        mel_mean = mel_mean,
638                        "Mel spectrogram computed"
639                    );
640                }
641
642                // Ensure we have exactly 3000 frames (pad or truncate)
643                // NOTE: mel array from pcm_to_mel is stored as [mel_bin_0_all_frames, mel_bin_1_all_frames, ...]
644                // So each mel bin has n_frames contiguous values: mel[bin * n_frames + frame]
645                let mel = if n_frames < N_FRAMES {
646                    // Pad each mel bin's frames with zeros to reach N_FRAMES
647                    let mut padded = vec![0.0f32; n_mels * N_FRAMES];
648                    for bin in 0..n_mels {
649                        let src_start = bin * n_frames;
650                        let dst_start = bin * N_FRAMES;
651                        padded[dst_start..dst_start + n_frames].copy_from_slice(&mel[src_start..src_start + n_frames]);
652                    }
653                    padded
654                } else if n_frames > N_FRAMES {
655                    // Truncate each mel bin's frames to N_FRAMES
656                    let mut truncated = vec![0.0f32; n_mels * N_FRAMES];
657                    for bin in 0..n_mels {
658                        let src_start = bin * n_frames;
659                        let dst_start = bin * N_FRAMES;
660                        truncated[dst_start..dst_start + N_FRAMES].copy_from_slice(&mel[src_start..src_start + N_FRAMES]);
661                    }
662                    truncated
663                } else {
664                    mel
665                };
666
667                let mel = Tensor::from_vec(
668                    mel,
669                    (1, n_mels, N_FRAMES),
670                    &self.device,
671                ).map_err(|e| WhisperError::InferenceError {
672                    cause: format!("Failed to create mel tensor: {}", e),
673                })?;
674
675                if chunk_idx == 0 {
676                    let mel_shape = mel.shape();
677                    tracing::info!(
678                        mel_shape = ?mel_shape,
679                        "Mel tensor shape"
680                    );
681                }
682
683                // Run encoder
684                let audio_features = match &mut self.model {
685                    Model::Normal(m) => m.encoder.forward(&mel, true),
686                    Model::Quantized(m) => m.encoder.forward(&mel, true),
687                }.map_err(|e| WhisperError::InferenceError {
688                    cause: format!("Encoder forward failed: {}", e),
689                })?;
690
691                if chunk_idx == 0 {
692                    let af_shape = audio_features.shape();
693                    tracing::info!(
694                        audio_features_shape = ?af_shape,
695                        "Audio features from encoder"
696                    );
697                }
698
699                // Get special token IDs
700                let sot_token = self.token_id(m::SOT_TOKEN)?;
701                let transcribe_token = self.token_id(m::TRANSCRIBE_TOKEN)?;
702                let eot_token = self.token_id(m::EOT_TOKEN)?;
703                let no_timestamps_token = self.token_id(m::NO_TIMESTAMPS_TOKEN)?;
704
705                if chunk_idx == 0 {
706                    let en_token = self.tokenizer.token_to_id("<|en|>");
707                    tracing::info!(
708                        sot = sot_token,
709                        transcribe = transcribe_token,
710                        eot = eot_token,
711                        no_timestamps = no_timestamps_token,
712                        en_token = ?en_token,
713                        "Special tokens"
714                    );
715                }
716
717                // Build initial prompt
718                // For English-only models (*.en), we DON'T use language token
719                // For multilingual models, we add language token after sot_token
720                let has_language_token = self.tokenizer.token_to_id("<|en|>").is_some();
721
722                // English-only models have vocab size 51864, multilingual have 51865
723                let is_english_only = self.config.vocab_size == 51864;
724
725                let tokens = if is_english_only {
726                    // English-only: SOT -> transcribe -> notimestamps
727                    vec![sot_token, transcribe_token, no_timestamps_token]
728                } else if has_language_token {
729                    // Multilingual: SOT -> language -> transcribe -> notimestamps
730                    let language_token = self.token_id("<|en|>")?;
731                    vec![sot_token, language_token, transcribe_token, no_timestamps_token]
732                } else {
733                    // Fallback
734                    vec![sot_token, transcribe_token, no_timestamps_token]
735                };
736
737                if chunk_idx == 0 {
738                    tracing::info!(
739                        is_english_only = is_english_only,
740                        vocab_size = self.config.vocab_size,
741                        prompt_tokens = ?tokens,
742                        "Initial prompt"
743                    );
744                }
745                let mut all_tokens = tokens.clone();
746
747                // Autoregressive decoding with token suppression
748                let sample_len = self.config.max_target_positions / 2;
749                let mut repeat_count = 0;
750                let mut last_token: Option<u32> = None;
751
752                // Build suppression mask
753                let suppress_tokens = &self.config.suppress_tokens;
754
755                for i in 0..sample_len {
756                    // For autoregressive decoding with KV cache:
757                    // - First iteration: pass all prompt tokens, flush_kv_cache=true
758                    // - Subsequent iterations: pass only the new token, flush_kv_cache=false
759                    let tokens_tensor = Tensor::new(all_tokens.as_slice(), &self.device)
760                        .and_then(|t| t.unsqueeze(0))
761                        .map_err(|e| WhisperError::InferenceError {
762                            cause: format!("Failed to create tokens tensor: {}", e),
763                        })?;
764
765                    if chunk_idx == 0 && i < 3 {
766                        tracing::info!(
767                            step = i,
768                            all_tokens_len = all_tokens.len(),
769                            tokens_shape = ?tokens_tensor.shape(),
770                            "Decoder input"
771                        );
772                    }
773
774                    // Get hidden states from decoder, then project to vocabulary
775                    // Always pass all tokens (candle doesn't use KV cache the same way as PyTorch)
776                    let logits = match &mut self.model {
777                        Model::Normal(m) => {
778                            let hidden = m.decoder.forward(&tokens_tensor, &audio_features, true)
779                                .map_err(|e| WhisperError::InferenceError {
780                                    cause: format!("Decoder forward failed: {}", e),
781                                })?;
782                            m.decoder.final_linear(&hidden)
783                                .map_err(|e| WhisperError::InferenceError {
784                                    cause: format!("Final linear failed: {}", e),
785                                })?
786                        }
787                        Model::Quantized(m) => {
788                            let hidden = m.decoder.forward(&tokens_tensor, &audio_features, true)
789                                .map_err(|e| WhisperError::InferenceError {
790                                    cause: format!("Decoder forward failed: {}", e),
791                                })?;
792                            m.decoder.final_linear(&hidden)
793                                .map_err(|e| WhisperError::InferenceError {
794                                    cause: format!("Final linear failed: {}", e),
795                                })?
796                        }
797                    };
798
799                    if chunk_idx == 0 && i == 0 {
800                        tracing::info!(
801                            logits_shape = ?logits.shape(),
802                            "Decoder output logits"
803                        );
804                    }
805
806                    // Get logits for last position
807                    let (_, seq_len, _) = logits.dims3().map_err(|e| WhisperError::InferenceError {
808                        cause: format!("Failed to get logits dims: {}", e),
809                    })?;
810                    let mut logits_vec = logits.i((0, seq_len - 1, ..))
811                        .and_then(|t| t.to_vec1::<f32>())
812                        .map_err(|e| WhisperError::InferenceError {
813                            cause: format!("Failed to extract logits: {}", e),
814                        })?;
815
816                    // Apply token suppression from config
817                    for &token_id in suppress_tokens.iter() {
818                        if (token_id as usize) < logits_vec.len() {
819                            logits_vec[token_id as usize] = f32::NEG_INFINITY;
820                        }
821                    }
822
823                    // Suppress EOT token for first few steps to allow generation
824                    if all_tokens.len() < 10 {
825                        logits_vec[eot_token as usize] = f32::NEG_INFINITY;
826                    }
827
828                    // Suppress all special tokens during generation:
829                    // - SOT (50257), language tokens (50258-50261), task tokens (50358-50359),
830                    // - no_timestamps (50362), and timestamp tokens (50363+)
831                    logits_vec[sot_token as usize] = f32::NEG_INFINITY;
832                    logits_vec[transcribe_token as usize] = f32::NEG_INFINITY;
833                    logits_vec[no_timestamps_token as usize] = f32::NEG_INFINITY;
834                    // Suppress all tokens from 50257 onward (special tokens) except those in normal vocab
835                    for token_id in 50257..logits_vec.len() {
836                        logits_vec[token_id] = f32::NEG_INFINITY;
837                    }
838
839                    if chunk_idx == 0 && i == 0 {
840                        tracing::info!(
841                            suppress_count = suppress_tokens.len(),
842                            eot_suppressed = all_tokens.len() < 10,
843                            "Applied token suppression"
844                        );
845                    }
846
847                    // Find argmax
848                    let next_token = logits_vec
849                        .iter()
850                        .enumerate()
851                        .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
852                        .map(|(idx, _)| idx as u32)
853                        .unwrap_or(eot_token);
854
855                    if chunk_idx == 0 && i < 5 {
856                        let max_logit = logits_vec.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
857                        let min_logit = logits_vec.iter().cloned().fold(f32::INFINITY, f32::min);
858                        tracing::info!(
859                            step = i,
860                            next_token = next_token,
861                            max_logit = max_logit,
862                            min_logit = min_logit,
863                            "Decoding step"
864                        );
865                    }
866
867                    if next_token == eot_token || next_token >= self.config.vocab_size as u32 {
868                        if chunk_idx == 0 && i < 5 {
869                            tracing::info!(next_token = next_token, eot = eot_token, "Stopping: EOT or invalid token");
870                        }
871                        break;
872                    }
873
874                    // Check for excessive repetition (stop if same token repeats >3 times)
875                    if Some(next_token) == last_token {
876                        repeat_count += 1;
877                        if repeat_count > 3 {
878                            tracing::debug!("Breaking due to token repetition");
879                            break;
880                        }
881                    } else {
882                        repeat_count = 0;
883                    }
884                    last_token = Some(next_token);
885
886                    all_tokens.push(next_token);
887                }
888
889                // Decode tokens to text for this chunk
890                let prompt_len = if is_english_only { 3 } else { 4 };
891
892                if chunk_idx == 0 {
893                    tracing::info!(
894                        prompt_tokens = ?&all_tokens[..prompt_len],
895                        generated_tokens = ?&all_tokens[prompt_len..],
896                        total = all_tokens.len(),
897                        "Generated tokens for chunk"
898                    );
899                }
900
901                let chunk_text = self.tokenizer
902                    .decode(&all_tokens[prompt_len..], true) // Skip prompt tokens
903                    .map_err(|e| WhisperError::InferenceError {
904                        cause: format!("Failed to decode tokens: {}", e),
905                    })?;
906
907                let trimmed_text = chunk_text.trim();
908                if !trimmed_text.is_empty() {
909                    if !all_text.is_empty() {
910                        all_text.push(' ');
911                    }
912                    all_text.push_str(trimmed_text);
913
914                    segments.push(TranscriptionSegment {
915                        start: start_time,
916                        end: end_time,
917                        text: trimmed_text.to_string(),
918                    });
919                }
920            }
921
922            Ok(TranscriptionResult {
923                text: all_text.trim().to_string(),
924                language: "en".to_string(),
925                duration_secs,
926                segments,
927            })
928        }
929
930        fn token_id(&self, token: &str) -> Result<u32> {
931            self.tokenizer
932                .token_to_id(token)
933                .ok_or_else(|| WhisperError::InferenceError {
934                    cause: format!("Token '{}' not found in vocabulary", token),
935                }.into())
936        }
937    }
938
939    /// Find the sample index where speech starts (after leading silence)
940    fn find_speech_start(samples: &[f32], threshold: f32, window_size: usize) -> usize {
941        for i in (0..samples.len()).step_by(window_size) {
942            let end = (i + window_size).min(samples.len());
943            let window = &samples[i..end];
944            let rms = (window.iter().map(|x| x * x).sum::<f32>() / window.len() as f32).sqrt();
945            if rms > threshold {
946                // Found speech, go back a bit to not cut off the start
947                return i.saturating_sub(window_size);
948            }
949        }
950        0 // No silence found, return start
951    }
952
953    /// Find the sample index where speech ends (before trailing silence)
954    fn find_speech_end(samples: &[f32], threshold: f32, window_size: usize) -> usize {
955        for i in (0..samples.len()).rev().step_by(window_size) {
956            let start = i.saturating_sub(window_size);
957            let window = &samples[start..=i.min(samples.len() - 1)];
958            let rms = (window.iter().map(|x| x * x).sum::<f32>() / window.len() as f32).sqrt();
959            if rms > threshold {
960                // Found speech, go forward a bit to not cut off the end
961                return (i + window_size).min(samples.len());
962            }
963        }
964        samples.len() // No silence found, return end
965    }
966}
967
968#[cfg(feature = "whisper")]
969pub use inference::WhisperTranscriber;
970
971// ============================================================================
972// Tests
973// ============================================================================
974
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn whisper_model_registry() {
        let default = default_whisper_model_info();
        assert_eq!(default.name, "whisper-small-en");
        assert!(default.is_default);
        assert_eq!(default.language, "en");

        // Unknown model returns default
        let unknown = get_whisper_model_info("nonexistent");
        assert_eq!(unknown.name, "whisper-small-en");
    }

    /// The fallback in `get_whisper_model_info` relies on there being exactly one
    /// default entry in the registry.
    #[test]
    fn whisper_model_registry_has_single_default() {
        let defaults = WHISPER_MODELS.iter().filter(|m| m.is_default).count();
        assert_eq!(defaults, 1);
    }

    /// Lookups match either the short registry name or the HuggingFace model id.
    #[test]
    fn whisper_model_lookup_by_name_and_id() {
        let by_name = get_whisper_model_info("whisper-small");
        assert_eq!(by_name.model_id, "openai/whisper-small");
        assert_eq!(by_name.language, "multilingual");
        assert!(!by_name.is_default);

        let by_id = get_whisper_model_info("openai/whisper-small");
        assert_eq!(by_id.name, "whisper-small");

        let en_by_id = get_whisper_model_info("openai/whisper-small.en");
        assert_eq!(en_by_id.name, "whisper-small-en");
    }

    #[test]
    fn whisper_config_defaults() {
        let config = WhisperConfig::default();
        assert_eq!(config.model_name, "whisper-small-en");
    }
}