memvid_core/
whisper.rs

1//! Whisper audio transcription with Candle inference.
2//!
3//! This module provides complete Whisper transcription functionality including:
4//! - Audio decoding (MP3, WAV, FLAC, etc.) via symphonia
5//! - Resampling to 16kHz via rubato
6//! - Whisper model inference via candle-transformers
7//! - Automatic model download from HuggingFace Hub
8
9use serde::{Deserialize, Serialize};
10use std::path::PathBuf;
11
12use crate::MemvidError;
13
14// These are only used when whisper feature is enabled
15#[cfg(feature = "whisper")]
16use std::path::Path;
17#[cfg(feature = "whisper")]
18use crate::Result;
19
20// ============================================================================
21// Model Registry
22// ============================================================================
23
/// Static description of one Whisper checkpoint that can be selected by name.
///
/// Every string field is `'static`, so entries can live in a `static` registry
/// and be handed out by reference.
#[derive(Clone, Debug)]
pub struct WhisperModelInfo {
    /// Repository identifier on the HuggingFace Hub (for example `openai/whisper-small`).
    pub model_id: &'static str,
    /// Short, human-friendly name used to select the model.
    pub name: &'static str,
    /// Rough size of the model, in megabytes.
    pub size_mb: f32,
    /// `true` for the entry that is used when no other model matches.
    pub is_default: bool,
    /// Language coverage: `"en"` for English-only models, `"multilingual"` otherwise.
    pub language: &'static str,
}
38
39/// Available Whisper models registry
40pub static WHISPER_MODELS: &[WhisperModelInfo] = &[
41    WhisperModelInfo {
42        model_id: "openai/whisper-small.en",
43        name: "whisper-small-en",
44        size_mb: 244.0,
45        is_default: true,
46        language: "en",
47    },
48    WhisperModelInfo {
49        model_id: "openai/whisper-small",
50        name: "whisper-small",
51        size_mb: 244.0,
52        is_default: false,
53        language: "multilingual",
54    },
55];
56
57/// Get model info by name, defaults to whisper-small-en
58pub fn get_whisper_model_info(name: &str) -> &'static WhisperModelInfo {
59    WHISPER_MODELS
60        .iter()
61        .find(|m| m.name == name || m.model_id == name)
62        .unwrap_or_else(|| {
63            WHISPER_MODELS
64                .iter()
65                .find(|m| m.is_default)
66                .expect("default whisper model")
67        })
68}
69
70/// Get the default model info
71pub fn default_whisper_model_info() -> &'static WhisperModelInfo {
72    WHISPER_MODELS
73        .iter()
74        .find(|m| m.is_default)
75        .expect("default whisper model exists")
76}
77
78// ============================================================================
79// Whisper Model Configuration
80// ============================================================================
81
/// Settings used when constructing a Whisper transcriber.
///
/// The `Default` implementation fills these from the `MEMVID_WHISPER_MODEL`,
/// `MEMVID_MODELS_DIR` and `MEMVID_OFFLINE` environment variables.
#[derive(Clone, Debug)]
pub struct WhisperConfig {
    /// Name of the model to load, e.g. `"whisper-small-en"`.
    pub model_name: String,
    /// Directory in which downloaded models are cached.
    pub models_dir: PathBuf,
    /// When `true`, run offline: no model downloads should be attempted.
    pub offline: bool,
}
92
93impl Default for WhisperConfig {
94    fn default() -> Self {
95        let models_dir = std::env::var("MEMVID_MODELS_DIR")
96            .ok()
97            .map(PathBuf::from)
98            .or_else(|| dirs_next::home_dir().map(|d| d.join(".memvid/models")))
99            .unwrap_or_else(|| PathBuf::from(".memvid/models"));
100
101        let model_name = std::env::var("MEMVID_WHISPER_MODEL")
102            .unwrap_or_else(|_| "whisper-small-en".to_string());
103
104        let offline = std::env::var("MEMVID_OFFLINE").is_ok();
105
106        Self {
107            model_name,
108            models_dir,
109            offline,
110        }
111    }
112}
113
114// ============================================================================
115// Whisper Error Types
116// ============================================================================
117
118/// Whisper-specific errors
119#[derive(Debug, thiserror::Error)]
120pub enum WhisperError {
121    /// Model not found
122    #[error("Whisper model '{model}' not found. {hint}")]
123    ModelNotFound { model: String, hint: String },
124
125    /// Audio decode failed
126    #[error("Failed to decode audio at {path:?}: {cause}")]
127    AudioDecodeError { path: PathBuf, cause: String },
128
129    /// Audio bytes decode failed
130    #[error("Failed to decode audio bytes: {cause}")]
131    AudioBytesDecodeError { cause: String },
132
133    /// Inference error
134    #[error("Whisper inference error: {cause}")]
135    InferenceError { cause: String },
136
137    /// Model download failed
138    #[error("Failed to download Whisper model: {cause}")]
139    DownloadError { cause: String },
140}
141
142impl From<WhisperError> for MemvidError {
143    fn from(err: WhisperError) -> Self {
144        MemvidError::ExtractionFailed {
145            reason: err.to_string().into_boxed_str(),
146        }
147    }
148}
149
150// ============================================================================
151// Transcription Result
152// ============================================================================
153
154/// Result of audio transcription
155#[derive(Debug, Clone, Serialize, Deserialize)]
156pub struct TranscriptionResult {
157    /// The transcribed text
158    pub text: String,
159    /// Language detected or specified
160    pub language: String,
161    /// Duration of audio in seconds
162    pub duration_secs: f32,
163    /// Optional timestamps for segments
164    #[serde(default)]
165    pub segments: Vec<TranscriptionSegment>,
166}
167
168/// A segment of transcription with timestamps
169#[derive(Debug, Clone, Serialize, Deserialize)]
170pub struct TranscriptionSegment {
171    /// Start time in seconds
172    pub start: f32,
173    /// End time in seconds
174    pub end: f32,
175    /// Transcribed text for this segment
176    pub text: String,
177}
178
179// ============================================================================
180// Audio Decoding (Feature-gated)
181// ============================================================================
182
183#[cfg(feature = "whisper")]
184mod audio {
185    use super::*;
186    use std::fs::File;
187    use symphonia::core::audio::SampleBuffer;
188    use symphonia::core::codecs::DecoderOptions;
189    use symphonia::core::formats::FormatOptions;
190    use symphonia::core::io::MediaSourceStream;
191    use symphonia::core::meta::MetadataOptions;
192    use symphonia::core::probe::Hint;
193
194    /// Whisper sample rate (always 16kHz)
195    pub const WHISPER_SAMPLE_RATE: u32 = 16000;
196
197    /// Decode audio file to f32 samples, resampling to 16kHz mono
198    pub fn decode_audio_file(path: &Path) -> Result<(Vec<f32>, f32)> {
199        let file = File::open(path).map_err(|e| WhisperError::AudioDecodeError {
200            path: path.to_path_buf(),
201            cause: e.to_string(),
202        })?;
203
204        let mss = MediaSourceStream::new(Box::new(file), Default::default());
205
206        // Create a hint based on file extension
207        let mut hint = Hint::new();
208        if let Some(ext) = path.extension().and_then(|e| e.to_str()) {
209            hint.with_extension(ext);
210        }
211
212        // Probe the media source
213        let format_opts = FormatOptions::default();
214        let metadata_opts = MetadataOptions::default();
215        let probed = symphonia::default::get_probe()
216            .format(&hint, mss, &format_opts, &metadata_opts)
217            .map_err(|e| WhisperError::AudioDecodeError {
218                path: path.to_path_buf(),
219                cause: format!("Failed to probe audio format: {}", e),
220            })?;
221
222        let mut format = probed.format;
223
224        // Find the first audio track
225        let track = format
226            .tracks()
227            .iter()
228            .find(|t| t.codec_params.codec != symphonia::core::codecs::CODEC_TYPE_NULL)
229            .ok_or_else(|| WhisperError::AudioDecodeError {
230                path: path.to_path_buf(),
231                cause: "No audio track found".to_string(),
232            })?;
233
234        let track_id = track.id;
235        let sample_rate = track.codec_params.sample_rate.unwrap_or(44100);
236        let channels = track.codec_params.channels.map(|c| c.count()).unwrap_or(2);
237
238        // Create decoder
239        let decoder_opts = DecoderOptions::default();
240        let mut decoder = symphonia::default::get_codecs()
241            .make(&track.codec_params, &decoder_opts)
242            .map_err(|e| WhisperError::AudioDecodeError {
243                path: path.to_path_buf(),
244                cause: format!("Failed to create decoder: {}", e),
245            })?;
246
247        let mut samples: Vec<f32> = Vec::new();
248
249        // Decode all packets
250        loop {
251            let packet = match format.next_packet() {
252                Ok(p) => p,
253                Err(symphonia::core::errors::Error::IoError(e))
254                    if e.kind() == std::io::ErrorKind::UnexpectedEof =>
255                {
256                    break;
257                }
258                Err(_) => break,
259            };
260
261            if packet.track_id() != track_id {
262                continue;
263            }
264
265            let decoded = match decoder.decode(&packet) {
266                Ok(d) => d,
267                Err(_) => continue,
268            };
269
270            let spec = *decoded.spec();
271            let num_frames = decoded.frames();
272
273            if num_frames == 0 {
274                continue;
275            }
276
277            let mut sample_buf = SampleBuffer::<f32>::new(num_frames as u64, spec);
278            sample_buf.copy_interleaved_ref(decoded);
279
280            let interleaved = sample_buf.samples();
281
282            // Convert to mono by averaging channels
283            if channels > 1 {
284                for chunk in interleaved.chunks(channels) {
285                    let mono: f32 = chunk.iter().sum::<f32>() / channels as f32;
286                    samples.push(mono);
287                }
288            } else {
289                samples.extend_from_slice(interleaved);
290            }
291        }
292
293        let duration_secs = samples.len() as f32 / sample_rate as f32;
294
295        // Log pre-resampling stats
296        let pre_min = samples.iter().cloned().fold(f32::INFINITY, f32::min);
297        let pre_max = samples.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
298        let pre_rms = (samples.iter().map(|x| x * x).sum::<f32>() / samples.len() as f32).sqrt();
299        tracing::info!(
300            sample_rate = sample_rate,
301            channels = channels,
302            samples_before = samples.len(),
303            pre_min = pre_min,
304            pre_max = pre_max,
305            pre_rms = pre_rms,
306            "Audio before resampling"
307        );
308
309        // High-quality sinc resampling to 16kHz
310        let samples = if sample_rate != WHISPER_SAMPLE_RATE {
311            let resampled = resample_sinc(&samples, sample_rate, WHISPER_SAMPLE_RATE);
312
313            // Log post-resampling stats
314            let post_min = resampled.iter().cloned().fold(f32::INFINITY, f32::min);
315            let post_max = resampled.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
316            let post_rms = (resampled.iter().map(|x| x * x).sum::<f32>() / resampled.len() as f32).sqrt();
317            tracing::info!(
318                samples_after = resampled.len(),
319                post_min = post_min,
320                post_max = post_max,
321                post_rms = post_rms,
322                "Audio after resampling"
323            );
324            resampled
325        } else {
326            tracing::info!("Audio already at 16kHz, no resampling needed");
327            samples
328        };
329
330        Ok((samples, duration_secs))
331    }
332
333    /// High-quality sinc resampling using rubato
334    fn resample_sinc(samples: &[f32], from_rate: u32, to_rate: u32) -> Vec<f32> {
335        use rubato::{FftFixedIn, Resampler};
336
337        if from_rate == to_rate {
338            return samples.to_vec();
339        }
340
341        // Create resampler
342        let chunk_size = 1024;
343        let mut resampler = FftFixedIn::<f32>::new(
344            from_rate as usize,
345            to_rate as usize,
346            chunk_size,
347            2, // sub_chunks for quality
348            1, // mono
349        )
350        .expect("Failed to create resampler");
351
352        let mut output = Vec::new();
353        let mut pos = 0;
354
355        // Process in chunks
356        while pos < samples.len() {
357            let end = (pos + chunk_size).min(samples.len());
358            let chunk = &samples[pos..end];
359
360            // Pad if needed
361            let input_chunk: Vec<f32> = if chunk.len() < chunk_size {
362                let mut padded = chunk.to_vec();
363                padded.resize(chunk_size, 0.0);
364                padded
365            } else {
366                chunk.to_vec()
367            };
368
369            let input = vec![input_chunk];
370            let resampled = resampler.process(&input, None).expect("Resampling failed");
371
372            if !resampled.is_empty() && !resampled[0].is_empty() {
373                output.extend_from_slice(&resampled[0]);
374            }
375
376            pos += chunk_size;
377        }
378
379        // Trim to expected length
380        let expected_len = (samples.len() as f64 * to_rate as f64 / from_rate as f64) as usize;
381        output.truncate(expected_len);
382
383        output
384    }
385}
386
387#[cfg(feature = "whisper")]
388pub use audio::*;
389
390// ============================================================================
391// Whisper Transcriber (Candle Inference)
392// ============================================================================
393
394#[cfg(feature = "whisper")]
395mod inference {
396    use super::*;
397    use candle_core::{DType, Device, IndexOp, Tensor};
398    use candle_nn::VarBuilder;
399    use candle_transformers::models::whisper::{self as m, audio, Config};
400    use hf_hub::{api::sync::Api, Repo, RepoType};
401    use tokenizers::Tokenizer;
402
403    /// Whisper model wrapper for transcription
404    pub struct WhisperTranscriber {
405        model: Model,
406        tokenizer: Tokenizer,
407        config: Config,
408        mel_filters: Vec<f32>,
409        device: Device,
410    }
411
412    #[allow(dead_code)]
413    enum Model {
414        Normal(m::model::Whisper),
415        Quantized(m::quantized_model::Whisper),
416    }
417
418    impl WhisperTranscriber {
419        /// Create a new WhisperTranscriber, downloading the model if needed
420        pub fn new(config: &WhisperConfig) -> Result<Self> {
421            // Use GPU if available: Metal (macOS) or CUDA (NVIDIA)
422            let device = Self::select_device();
423            tracing::info!(device = ?device, "Using device for Whisper");
424            let model_id = match config.model_name.as_str() {
425                "whisper-small-en" => "openai/whisper-small.en",
426                "whisper-small" => "openai/whisper-small",
427                "whisper-tiny.en" => "openai/whisper-tiny.en",
428                "whisper-tiny" => "openai/whisper-tiny",
429                "whisper-base.en" => "openai/whisper-base.en",
430                "whisper-base" => "openai/whisper-base",
431                "whisper-medium.en" => "openai/whisper-medium.en",
432                "whisper-medium" => "openai/whisper-medium",
433                "whisper-large-v3" => "openai/whisper-large-v3",
434                other => other, // Allow direct model IDs
435            };
436
437            tracing::info!(model_id = model_id, "Loading Whisper model");
438
439            let api = Api::new().map_err(|e| WhisperError::DownloadError {
440                cause: e.to_string(),
441            })?;
442            let repo = api.repo(Repo::with_revision(
443                model_id.to_string(),
444                RepoType::Model,
445                "main".to_string(),
446            ));
447
448            // Download model files
449            let config_path = repo.get("config.json").map_err(|e| WhisperError::DownloadError {
450                cause: format!("Failed to download config.json: {}", e),
451            })?;
452            let tokenizer_path = repo.get("tokenizer.json").map_err(|e| WhisperError::DownloadError {
453                cause: format!("Failed to download tokenizer.json: {}", e),
454            })?;
455            let model_path = repo.get("model.safetensors").map_err(|e| WhisperError::DownloadError {
456                cause: format!("Failed to download model.safetensors: {}", e),
457            })?;
458
459            // Load config
460            let config_str = std::fs::read_to_string(&config_path).map_err(|e| WhisperError::InferenceError {
461                cause: format!("Failed to read config: {}", e),
462            })?;
463            let model_config: Config = serde_json::from_str(&config_str).map_err(|e| WhisperError::InferenceError {
464                cause: format!("Failed to parse config: {}", e),
465            })?;
466
467            // Load tokenizer
468            let tokenizer = Tokenizer::from_file(&tokenizer_path)
469                .map_err(|e| WhisperError::InferenceError {
470                    cause: format!("Failed to load tokenizer: {}", e),
471                })?;
472
473            // Load mel filters
474            let mel_bytes = match model_config.num_mel_bins {
475                80 => include_bytes!("melfilters.bytes").as_slice(),
476                128 => include_bytes!("melfilters128.bytes").as_slice(),
477                n => return Err(WhisperError::InferenceError {
478                    cause: format!("Unsupported number of mel bins: {}", n),
479                }.into()),
480            };
481            let mut mel_filters = vec![0f32; mel_bytes.len() / 4];
482            <byteorder::LittleEndian as byteorder::ByteOrder>::read_f32_into(mel_bytes, &mut mel_filters);
483
484            // Load model weights
485            let vb = unsafe {
486                VarBuilder::from_mmaped_safetensors(&[model_path], DType::F32, &device)
487                    .map_err(|e| WhisperError::InferenceError {
488                        cause: format!("Failed to load model weights: {}", e),
489                    })?
490            };
491            let model = Model::Normal(m::model::Whisper::load(&vb, model_config.clone())
492                .map_err(|e| WhisperError::InferenceError {
493                    cause: format!("Failed to load Whisper model: {}", e),
494                })?);
495
496            tracing::info!("Whisper model loaded successfully");
497
498            Ok(Self {
499                model,
500                tokenizer,
501                config: model_config,
502                mel_filters,
503                device,
504            })
505        }
506
507        /// Select the best available device (GPU if available, otherwise CPU)
508        fn select_device() -> Device {
509            // Try Metal (macOS Apple Silicon / AMD)
510            #[cfg(feature = "metal")]
511            {
512                if let Ok(device) = Device::new_metal(0) {
513                    tracing::info!("Metal GPU available");
514                    return device;
515                }
516            }
517
518            // Try CUDA (NVIDIA GPUs)
519            #[cfg(feature = "cuda")]
520            {
521                if let Ok(device) = Device::new_cuda(0) {
522                    tracing::info!("CUDA GPU available");
523                    return device;
524                }
525            }
526
527            // Fallback to CPU
528            tracing::info!("Using CPU (no GPU acceleration)");
529            Device::Cpu
530        }
531
532        /// Transcribe an audio file
533        pub fn transcribe_file(&mut self, path: &Path) -> Result<TranscriptionResult> {
534            // Decode audio to PCM
535            let (pcm_data, duration_secs) = super::decode_audio_file(path)?;
536
537            // Check audio statistics
538            let audio_min = pcm_data.iter().cloned().fold(f32::INFINITY, f32::min);
539            let audio_max = pcm_data.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
540            let audio_mean = pcm_data.iter().sum::<f32>() / pcm_data.len() as f32;
541            let audio_rms = (pcm_data.iter().map(|x| x * x).sum::<f32>() / pcm_data.len() as f32).sqrt();
542
543            tracing::info!(
544                duration = duration_secs,
545                samples = pcm_data.len(),
546                min = audio_min,
547                max = audio_max,
548                mean = audio_mean,
549                rms = audio_rms,
550                "Audio decoded"
551            );
552
553            self.transcribe_pcm(&pcm_data, duration_secs)
554        }
555
556        /// Transcribe PCM audio samples (16kHz mono f32)
557        pub fn transcribe_pcm(&mut self, pcm_data: &[f32], duration_secs: f32) -> Result<TranscriptionResult> {
558            // Whisper processes audio in 30-second chunks
559            const CHUNK_LENGTH: usize = 30 * 16000; // 30 seconds at 16kHz
560            const N_FRAMES: usize = 3000; // frames per chunk
561            const SAMPLE_RATE: f32 = 16000.0;
562
563            // Detect and trim leading silence
564            let silence_threshold = 0.01; // RMS threshold for silence
565            let window_size = 1600; // 100ms windows at 16kHz
566
567            let start_sample = find_speech_start(pcm_data, silence_threshold, window_size);
568            let end_sample = find_speech_end(pcm_data, silence_threshold, window_size);
569
570            let trimmed_start = start_sample as f32 / SAMPLE_RATE;
571            let trimmed_end = end_sample as f32 / SAMPLE_RATE;
572
573            tracing::info!(
574                start_sample = start_sample,
575                end_sample = end_sample,
576                trimmed_start_sec = trimmed_start,
577                trimmed_end_sec = trimmed_end,
578                original_duration = duration_secs,
579                "Trimmed silence"
580            );
581
582            // Use trimmed audio
583            let pcm_data = &pcm_data[start_sample..end_sample];
584            let _trimmed_duration = pcm_data.len() as f32 / SAMPLE_RATE;
585
586            let mut all_text = String::new();
587            let mut segments = Vec::new();
588
589            // Process audio in chunks
590            let num_chunks = (pcm_data.len() + CHUNK_LENGTH - 1) / CHUNK_LENGTH;
591
592            for chunk_idx in 0..num_chunks {
593                let chunk_start = chunk_idx * CHUNK_LENGTH;
594                let chunk_end = (chunk_start + CHUNK_LENGTH).min(pcm_data.len());
595                let chunk = &pcm_data[chunk_start..chunk_end];
596
597                // Adjust timestamps to account for trimmed silence
598                let start_time = trimmed_start + chunk_start as f32 / SAMPLE_RATE;
599                let end_time = trimmed_start + chunk_end as f32 / SAMPLE_RATE;
600
601                tracing::info!(
602                    chunk = chunk_idx + 1,
603                    total = num_chunks,
604                    start = start_time,
605                    end = end_time,
606                    "Processing chunk"
607                );
608
609                // Reset decoder KV cache for each new chunk
610                match &mut self.model {
611                    Model::Normal(m) => m.decoder.reset_kv_cache(),
612                    Model::Quantized(m) => m.decoder.reset_kv_cache(),
613                }
614
615                // Convert chunk to mel spectrogram
616                let mel = audio::pcm_to_mel(&self.config, chunk, &self.mel_filters);
617                let n_mels = self.config.num_mel_bins;
618                let mel_len = mel.len();
619                let n_frames = mel_len / n_mels;
620
621                if chunk_idx == 0 {
622                    // Print config for debugging
623                    tracing::info!(
624                        num_mel_bins = self.config.num_mel_bins,
625                        max_source_positions = self.config.max_source_positions,
626                        max_target_positions = self.config.max_target_positions,
627                        "Model config"
628                    );
629
630                    // Mel statistics
631                    let mel_min = mel.iter().cloned().fold(f32::INFINITY, f32::min);
632                    let mel_max = mel.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
633                    let mel_mean = mel.iter().sum::<f32>() / mel.len() as f32;
634
635                    tracing::info!(
636                        mel_len = mel_len,
637                        n_mels = n_mels,
638                        n_frames = n_frames,
639                        chunk_samples = chunk.len(),
640                        expected_frames = 3000,
641                        mel_min = mel_min,
642                        mel_max = mel_max,
643                        mel_mean = mel_mean,
644                        "Mel spectrogram computed"
645                    );
646                }
647
648                // Ensure we have exactly 3000 frames (pad or truncate)
649                // NOTE: mel array from pcm_to_mel is stored as [mel_bin_0_all_frames, mel_bin_1_all_frames, ...]
650                // So each mel bin has n_frames contiguous values: mel[bin * n_frames + frame]
651                let mel = if n_frames < N_FRAMES {
652                    // Pad each mel bin's frames with zeros to reach N_FRAMES
653                    let mut padded = vec![0.0f32; n_mels * N_FRAMES];
654                    for bin in 0..n_mels {
655                        let src_start = bin * n_frames;
656                        let dst_start = bin * N_FRAMES;
657                        padded[dst_start..dst_start + n_frames].copy_from_slice(&mel[src_start..src_start + n_frames]);
658                    }
659                    padded
660                } else if n_frames > N_FRAMES {
661                    // Truncate each mel bin's frames to N_FRAMES
662                    let mut truncated = vec![0.0f32; n_mels * N_FRAMES];
663                    for bin in 0..n_mels {
664                        let src_start = bin * n_frames;
665                        let dst_start = bin * N_FRAMES;
666                        truncated[dst_start..dst_start + N_FRAMES].copy_from_slice(&mel[src_start..src_start + N_FRAMES]);
667                    }
668                    truncated
669                } else {
670                    mel
671                };
672
673                let mel = Tensor::from_vec(
674                    mel,
675                    (1, n_mels, N_FRAMES),
676                    &self.device,
677                ).map_err(|e| WhisperError::InferenceError {
678                    cause: format!("Failed to create mel tensor: {}", e),
679                })?;
680
681                if chunk_idx == 0 {
682                    let mel_shape = mel.shape();
683                    tracing::info!(
684                        mel_shape = ?mel_shape,
685                        "Mel tensor shape"
686                    );
687                }
688
689                // Run encoder
690                let audio_features = match &mut self.model {
691                    Model::Normal(m) => m.encoder.forward(&mel, true),
692                    Model::Quantized(m) => m.encoder.forward(&mel, true),
693                }.map_err(|e| WhisperError::InferenceError {
694                    cause: format!("Encoder forward failed: {}", e),
695                })?;
696
697                if chunk_idx == 0 {
698                    let af_shape = audio_features.shape();
699                    tracing::info!(
700                        audio_features_shape = ?af_shape,
701                        "Audio features from encoder"
702                    );
703                }
704
705                // Get special token IDs
706                let sot_token = self.token_id(m::SOT_TOKEN)?;
707                let transcribe_token = self.token_id(m::TRANSCRIBE_TOKEN)?;
708                let eot_token = self.token_id(m::EOT_TOKEN)?;
709                let no_timestamps_token = self.token_id(m::NO_TIMESTAMPS_TOKEN)?;
710
711                if chunk_idx == 0 {
712                    let en_token = self.tokenizer.token_to_id("<|en|>");
713                    tracing::info!(
714                        sot = sot_token,
715                        transcribe = transcribe_token,
716                        eot = eot_token,
717                        no_timestamps = no_timestamps_token,
718                        en_token = ?en_token,
719                        "Special tokens"
720                    );
721                }
722
723                // Build initial prompt
724                // For English-only models (*.en), we DON'T use language token
725                // For multilingual models, we add language token after sot_token
726                let has_language_token = self.tokenizer.token_to_id("<|en|>").is_some();
727
728                // English-only models have vocab size 51864, multilingual have 51865
729                let is_english_only = self.config.vocab_size == 51864;
730
731                let tokens = if is_english_only {
732                    // English-only: SOT -> transcribe -> notimestamps
733                    vec![sot_token, transcribe_token, no_timestamps_token]
734                } else if has_language_token {
735                    // Multilingual: SOT -> language -> transcribe -> notimestamps
736                    let language_token = self.token_id("<|en|>")?;
737                    vec![sot_token, language_token, transcribe_token, no_timestamps_token]
738                } else {
739                    // Fallback
740                    vec![sot_token, transcribe_token, no_timestamps_token]
741                };
742
743                if chunk_idx == 0 {
744                    tracing::info!(
745                        is_english_only = is_english_only,
746                        vocab_size = self.config.vocab_size,
747                        prompt_tokens = ?tokens,
748                        "Initial prompt"
749                    );
750                }
751                let mut all_tokens = tokens.clone();
752
753                // Autoregressive decoding with token suppression
754                let sample_len = self.config.max_target_positions / 2;
755                let mut repeat_count = 0;
756                let mut last_token: Option<u32> = None;
757
758                // Build suppression mask
759                let suppress_tokens = &self.config.suppress_tokens;
760
761                for i in 0..sample_len {
762                    // For autoregressive decoding with KV cache:
763                    // - First iteration: pass all prompt tokens, flush_kv_cache=true
764                    // - Subsequent iterations: pass only the new token, flush_kv_cache=false
765                    let tokens_tensor = Tensor::new(all_tokens.as_slice(), &self.device)
766                        .and_then(|t| t.unsqueeze(0))
767                        .map_err(|e| WhisperError::InferenceError {
768                            cause: format!("Failed to create tokens tensor: {}", e),
769                        })?;
770
771                    if chunk_idx == 0 && i < 3 {
772                        tracing::info!(
773                            step = i,
774                            all_tokens_len = all_tokens.len(),
775                            tokens_shape = ?tokens_tensor.shape(),
776                            "Decoder input"
777                        );
778                    }
779
780                    // Get hidden states from decoder, then project to vocabulary
781                    // Always pass all tokens (candle doesn't use KV cache the same way as PyTorch)
782                    let logits = match &mut self.model {
783                        Model::Normal(m) => {
784                            let hidden = m.decoder.forward(&tokens_tensor, &audio_features, true)
785                                .map_err(|e| WhisperError::InferenceError {
786                                    cause: format!("Decoder forward failed: {}", e),
787                                })?;
788                            m.decoder.final_linear(&hidden)
789                                .map_err(|e| WhisperError::InferenceError {
790                                    cause: format!("Final linear failed: {}", e),
791                                })?
792                        }
793                        Model::Quantized(m) => {
794                            let hidden = m.decoder.forward(&tokens_tensor, &audio_features, true)
795                                .map_err(|e| WhisperError::InferenceError {
796                                    cause: format!("Decoder forward failed: {}", e),
797                                })?;
798                            m.decoder.final_linear(&hidden)
799                                .map_err(|e| WhisperError::InferenceError {
800                                    cause: format!("Final linear failed: {}", e),
801                                })?
802                        }
803                    };
804
805                    if chunk_idx == 0 && i == 0 {
806                        tracing::info!(
807                            logits_shape = ?logits.shape(),
808                            "Decoder output logits"
809                        );
810                    }
811
812                    // Get logits for last position
813                    let (_, seq_len, _) = logits.dims3().map_err(|e| WhisperError::InferenceError {
814                        cause: format!("Failed to get logits dims: {}", e),
815                    })?;
816                    let mut logits_vec = logits.i((0, seq_len - 1, ..))
817                        .and_then(|t| t.to_vec1::<f32>())
818                        .map_err(|e| WhisperError::InferenceError {
819                            cause: format!("Failed to extract logits: {}", e),
820                        })?;
821
822                    // Apply token suppression from config
823                    for &token_id in suppress_tokens.iter() {
824                        if (token_id as usize) < logits_vec.len() {
825                            logits_vec[token_id as usize] = f32::NEG_INFINITY;
826                        }
827                    }
828
829                    // Suppress EOT token for first few steps to allow generation
830                    if all_tokens.len() < 10 {
831                        logits_vec[eot_token as usize] = f32::NEG_INFINITY;
832                    }
833
834                    // Suppress all special tokens during generation:
835                    // - SOT (50257), language tokens (50258-50261), task tokens (50358-50359),
836                    // - no_timestamps (50362), and timestamp tokens (50363+)
837                    logits_vec[sot_token as usize] = f32::NEG_INFINITY;
838                    logits_vec[transcribe_token as usize] = f32::NEG_INFINITY;
839                    logits_vec[no_timestamps_token as usize] = f32::NEG_INFINITY;
840                    // Suppress all tokens from 50257 onward (special tokens) except those in normal vocab
841                    for token_id in 50257..logits_vec.len() {
842                        logits_vec[token_id] = f32::NEG_INFINITY;
843                    }
844
845                    if chunk_idx == 0 && i == 0 {
846                        tracing::info!(
847                            suppress_count = suppress_tokens.len(),
848                            eot_suppressed = all_tokens.len() < 10,
849                            "Applied token suppression"
850                        );
851                    }
852
853                    // Find argmax
854                    let next_token = logits_vec
855                        .iter()
856                        .enumerate()
857                        .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
858                        .map(|(idx, _)| idx as u32)
859                        .unwrap_or(eot_token);
860
861                    if chunk_idx == 0 && i < 5 {
862                        let max_logit = logits_vec.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
863                        let min_logit = logits_vec.iter().cloned().fold(f32::INFINITY, f32::min);
864                        tracing::info!(
865                            step = i,
866                            next_token = next_token,
867                            max_logit = max_logit,
868                            min_logit = min_logit,
869                            "Decoding step"
870                        );
871                    }
872
873                    if next_token == eot_token || next_token >= self.config.vocab_size as u32 {
874                        if chunk_idx == 0 && i < 5 {
875                            tracing::info!(next_token = next_token, eot = eot_token, "Stopping: EOT or invalid token");
876                        }
877                        break;
878                    }
879
880                    // Check for excessive repetition (stop if same token repeats >3 times)
881                    if Some(next_token) == last_token {
882                        repeat_count += 1;
883                        if repeat_count > 3 {
884                            tracing::debug!("Breaking due to token repetition");
885                            break;
886                        }
887                    } else {
888                        repeat_count = 0;
889                    }
890                    last_token = Some(next_token);
891
892                    all_tokens.push(next_token);
893                }
894
895                // Decode tokens to text for this chunk
896                let prompt_len = if is_english_only { 3 } else { 4 };
897
898                if chunk_idx == 0 {
899                    tracing::info!(
900                        prompt_tokens = ?&all_tokens[..prompt_len],
901                        generated_tokens = ?&all_tokens[prompt_len..],
902                        total = all_tokens.len(),
903                        "Generated tokens for chunk"
904                    );
905                }
906
907                let chunk_text = self.tokenizer
908                    .decode(&all_tokens[prompt_len..], true) // Skip prompt tokens
909                    .map_err(|e| WhisperError::InferenceError {
910                        cause: format!("Failed to decode tokens: {}", e),
911                    })?;
912
913                let trimmed_text = chunk_text.trim();
914                if !trimmed_text.is_empty() {
915                    if !all_text.is_empty() {
916                        all_text.push(' ');
917                    }
918                    all_text.push_str(trimmed_text);
919
920                    segments.push(TranscriptionSegment {
921                        start: start_time,
922                        end: end_time,
923                        text: trimmed_text.to_string(),
924                    });
925                }
926            }
927
928            Ok(TranscriptionResult {
929                text: all_text.trim().to_string(),
930                language: "en".to_string(),
931                duration_secs,
932                segments,
933            })
934        }
935
936        fn token_id(&self, token: &str) -> Result<u32> {
937            self.tokenizer
938                .token_to_id(token)
939                .ok_or_else(|| WhisperError::InferenceError {
940                    cause: format!("Token '{}' not found in vocabulary", token),
941                }.into())
942        }
943    }
944
945    /// Find the sample index where speech starts (after leading silence)
946    fn find_speech_start(samples: &[f32], threshold: f32, window_size: usize) -> usize {
947        for i in (0..samples.len()).step_by(window_size) {
948            let end = (i + window_size).min(samples.len());
949            let window = &samples[i..end];
950            let rms = (window.iter().map(|x| x * x).sum::<f32>() / window.len() as f32).sqrt();
951            if rms > threshold {
952                // Found speech, go back a bit to not cut off the start
953                return i.saturating_sub(window_size);
954            }
955        }
956        0 // No silence found, return start
957    }
958
959    /// Find the sample index where speech ends (before trailing silence)
960    fn find_speech_end(samples: &[f32], threshold: f32, window_size: usize) -> usize {
961        for i in (0..samples.len()).rev().step_by(window_size) {
962            let start = i.saturating_sub(window_size);
963            let window = &samples[start..=i.min(samples.len() - 1)];
964            let rms = (window.iter().map(|x| x * x).sum::<f32>() / window.len() as f32).sqrt();
965            if rms > threshold {
966                // Found speech, go forward a bit to not cut off the end
967                return (i + window_size).min(samples.len());
968            }
969        }
970        samples.len() // No silence found, return end
971    }
972}
973
974#[cfg(feature = "whisper")]
975pub use inference::WhisperTranscriber;
976
977// ============================================================================
978// Tests
979// ============================================================================
980
#[cfg(test)]
mod tests {
    use super::*;

    /// The default registry entry is the English-only small model, and a
    /// lookup with an unregistered name yields an entry with that same name.
    #[test]
    fn whisper_model_registry() {
        let registry_default = default_whisper_model_info();
        assert_eq!(registry_default.name, "whisper-small-en");
        assert!(registry_default.is_default);
        assert_eq!(registry_default.language, "en");

        // A name that is not registered resolves to the default model.
        assert_eq!(
            get_whisper_model_info("nonexistent").name,
            "whisper-small-en"
        );
    }

    /// A default-constructed config selects the English-only small model.
    #[test]
    fn whisper_config_defaults() {
        assert_eq!(WhisperConfig::default().model_name, "whisper-small-en");
    }
}