Skip to main content

ferrum_models/executor/
whisper_executor.rs

//! Whisper ASR Executor — full decode pipeline matching Python whisper.
//!
//! Implements: timestamp-based sequential decode, logit suppression (SuppressBlank,
//! SuppressTokens, ApplyTimestampRules), temperature fallback, compression-ratio
//! check, and seek-based segmentation. Unlike Python whisper, decoding also stops
//! early when the same text token is repeated many times in a row.
6
7use std::collections::HashMap;
8use std::sync::Arc;
9
10use async_trait::async_trait;
11use candle_core::{DType, Device as CandleDevice, Tensor};
12use ferrum_interfaces::{
13    model_executor::{
14        AttentionType, DecodeInput, DecodeOutput, ExecutorCapabilities, MemoryRequirements,
15        PrefillInput, PrefillOutput,
16    },
17    ModelExecutor, TensorRef,
18};
19use ferrum_types::{DataType, Device, FerrumError, ModelInfo, ModelType, Result};
20use tracing::info;
21
22use super::common;
23use crate::architectures::whisper::WhisperModelWrapper;
24use crate::audio_processor;
25
// ── Token constants ─────────────────────────────────────────────────────
// These match Python whisper's tokenizer for the multilingual (v1/v2) vocabulary.

/// Id of the first timestamp token `<|0.00|>`; every id at or above it is a timestamp.
///
/// This value is specific to the multilingual v1/v2 vocabulary. Per OpenAI's
/// tokenizer layout, English-only models start timestamps at 50363 and large-v3
/// at 50365, so resolving `<|0.00|>` from the loaded tokenizer is preferable.
const TIMESTAMP_BEGIN: u32 = 50364;

/// Mel frames per encoder output position (`N_FRAMES / n_audio_ctx` = 3000 / 1500).
///
/// One timestamp step therefore corresponds to this many mel frames when
/// converting a timestamp token back into a seek offset.
const INPUT_STRIDE: usize = 2;

/// Seconds per timestamp token (`INPUT_STRIDE * HOP_LENGTH / SAMPLE_RATE`).
// Allowed to be unused: the decoder may express timestamp limits directly in
// token indices, in which case this value only serves as documentation.
#[allow(dead_code)]
const TIME_PRECISION: f64 = 0.02;

/// Token ids of symbols that should not appear in a transcription
/// (Python whisper's `tokenizer.non_speech_tokens` for the multilingual vocabulary).
/// Sorted ascending.
const NON_SPEECH_TOKENS: &[u32] = &[
    1, 2, 7, 8, 9, 10, 14, 25, 26, 27, 28, 29, 31, 58, 59, 60, 61, 62, 63, 90, 91, 92, 93, 357,
    366, 438, 532, 685, 691, 1060, 1258, 1261, 1435, 1436, 1652, 2028, 2029, 2150, 2404, 2932,
    3292, 3455, 3723, 4100, 5751, 6283, 6347, 6436, 6615, 7579, 8765, 9929, 10563, 10813, 11318,
    12380, 14117, 14397, 14734, 15003, 15068, 15206, 16450, 16805, 17193, 17832, 19063, 19438,
    19635, 20203, 21111, 24220, 24408, 25212, 25830, 26622, 28156, 28279, 29464, 31650, 32302,
    32470, 36865, 42863, 47425, 49870, 50254,
];
43
44/// Whisper executor for speech-to-text.
45pub struct WhisperModelExecutor {
46    model: WhisperModelWrapper,
47    tokenizer: tokenizers::Tokenizer,
48    info: ModelInfo,
49    // Special token IDs
50    sot_token: u32,
51    eot_token: u32,
52    transcribe_token: u32,
53    translate_token: u32,
54    no_timestamps_token: u32,
55    no_speech_token: u32, // <|nocaptions|> = 50362
56    sot_prev: u32,
57    sot_lm: u32,
58    language_tokens: HashMap<String, u32>,
59    /// Precomputed suppress mask: token IDs that are always suppressed.
60    suppress_token_ids: Vec<u32>,
61    /// Sample length (max decode tokens per segment).
62    sample_len: usize,
63}
64
65impl WhisperModelExecutor {
66    /// Load from model directory.
67    pub fn from_path(model_path: &str, device: CandleDevice, dtype: DType) -> Result<Self> {
68        let dir = std::path::Path::new(model_path);
69
70        let model = WhisperModelWrapper::from_model_dir(dir, device, dtype)?;
71
72        let tokenizer = tokenizers::Tokenizer::from_file(dir.join("tokenizer.json"))
73            .map_err(|e| FerrumError::model(format!("load tokenizer: {e}")))?;
74
75        // Resolve special token IDs
76        let sot_token = token_id(&tokenizer, "<|startoftranscript|>");
77        let eot_token = token_id(&tokenizer, "<|endoftext|>");
78        let transcribe_token = token_id(&tokenizer, "<|transcribe|>");
79        let translate_token = token_id(&tokenizer, "<|translate|>");
80        let no_timestamps_token = token_id(&tokenizer, "<|notimestamps|>");
81        let no_speech_token = token_id(&tokenizer, "<|nocaptions|>");
82        let sot_prev = token_id(&tokenizer, "<|startofprev|>");
83        let sot_lm = token_id(&tokenizer, "<|startoflm|>");
84
85        // Build language token map
86        let mut language_tokens = HashMap::new();
87        for lang in &[
88            "en", "zh", "ja", "ko", "fr", "de", "es", "ru", "ar", "pt", "it", "nl", "tr", "pl",
89            "sv", "da", "fi", "hu", "cs", "ro", "bg", "uk", "el", "hr", "sk", "th", "vi", "id",
90            "ms", "hi", "ta", "te", "ur", "fa", "he", "ca", "gl", "eu", "la",
91        ] {
92            let token_str = format!("<|{lang}|>");
93            if let Some(id) = tokenizer.token_to_id(&token_str) {
94                language_tokens.insert(lang.to_string(), id);
95            }
96        }
97
98        // Build suppress token list (matches Python _get_suppress_tokens)
99        let mut suppress_ids: Vec<u32> = NON_SPEECH_TOKENS.to_vec();
100        suppress_ids.extend_from_slice(&[
101            transcribe_token,
102            translate_token,
103            sot_token,
104            sot_prev,
105            sot_lm,
106            no_speech_token,
107        ]);
108        suppress_ids.sort();
109        suppress_ids.dedup();
110
111        let sample_len = model.config().max_target_positions / 2;
112
113        let info = ModelInfo {
114            model_id: ferrum_types::ModelId(model_path.to_string()),
115            model_type: ModelType::Custom("whisper".to_string()),
116            hidden_size: model.config().d_model,
117            vocab_size: model.config().vocab_size,
118            num_layers: model.config().encoder_layers + model.config().decoder_layers,
119            num_heads: model.config().encoder_attention_heads,
120            num_kv_heads: model.config().decoder_attention_heads,
121            num_parameters: 0,
122            max_sequence_length: model.config().max_target_positions,
123            device: Device::CPU,
124            dtype: DataType::FP32,
125            version: None,
126            license: None,
127            metadata: HashMap::new(),
128        };
129
130        info!(
131            "WhisperModelExecutor: {} (d_model={}, languages={}, suppress_tokens={})",
132            model_path,
133            model.config().d_model,
134            language_tokens.len(),
135            suppress_ids.len(),
136        );
137
138        Ok(Self {
139            model,
140            tokenizer,
141            info,
142            sot_token,
143            eot_token,
144            transcribe_token,
145            translate_token,
146            no_timestamps_token,
147            no_speech_token,
148            sot_prev,
149            sot_lm,
150            language_tokens,
151            suppress_token_ids: suppress_ids,
152            sample_len,
153        })
154    }
155
156    /// Transcribe audio file → text.
157    pub fn transcribe_file(&self, audio_path: &str, language: Option<&str>) -> Result<String> {
158        let pcm = audio_processor::load_audio(audio_path)?;
159        self.transcribe_pcm(&pcm, language)
160    }
161
162    /// Transcribe raw audio bytes (WAV/any) → text.
163    pub fn transcribe_bytes(&self, audio_data: &[u8], language: Option<&str>) -> Result<String> {
164        let pcm = audio_processor::load_audio_bytes(audio_data)?;
165        self.transcribe_pcm(&pcm, language)
166    }
167
168    /// Full transcription pipeline matching Python whisper.transcribe().
169    ///
170    /// - Computes mel for entire audio (padded by 30s of silence)
171    /// - Seek-based loop over 30s segments
172    /// - For each segment: encode → decode with timestamp rules → extract segments
173    /// - Temperature fallback on high compression ratio or low avg logprob
174    fn transcribe_pcm(&self, pcm: &[f32], language: Option<&str>) -> Result<String> {
175        let lang_token = language
176            .and_then(|l| self.language_tokens.get(l).copied())
177            .unwrap_or_else(|| {
178                self.language_tokens
179                    .get("en")
180                    .copied()
181                    .unwrap_or(self.sot_token + 1)
182            });
183
184        // Compute mel for full audio + 30s padding (matching Python: padding=N_SAMPLES)
185        let n_samples = candle_transformers::models::whisper::N_SAMPLES;
186        let n_frames = candle_transformers::models::whisper::N_FRAMES;
187        let mut padded_pcm = pcm.to_vec();
188        padded_pcm.resize(padded_pcm.len() + n_samples, 0.0); // 30s silence padding
189        let content_frames = pcm.len() / candle_transformers::models::whisper::HOP_LENGTH;
190
191        // Initial tokens: SOT + language + transcribe (WITHOUT no_timestamps — we use timestamps)
192        let sot_sequence = vec![self.sot_token, lang_token, self.transcribe_token];
193        let sample_begin = sot_sequence.len();
194
195        // Blank token for SuppressBlank
196        let blank_token = 220u32; // space token, matches Python tokenizer.encode(" ")
197
198        // Temperatures for fallback
199        let temperatures: &[f32] = &[0.0, 0.2, 0.4, 0.6, 0.8, 1.0];
200
201        // Max initial timestamp: 1.0 second → index 50 (1.0 / 0.02)
202        let max_initial_timestamp_index: u32 = 50;
203
204        let mut seek: usize = 0;
205        let mut all_tokens: Vec<u32> = Vec::new();
206
207        while seek < content_frames {
208            let segment_size = n_frames.min(content_frames - seek);
209
210            // Extract mel segment
211            let mel = self.mel_segment_at(&padded_pcm, seek, n_frames)?;
212
213            // Encode
214            let encoder_out = self.model.encode(&mel)?;
215
216            // Decode with temperature fallback
217            let (tokens, avg_logprob, no_speech_prob, _temperature) = self.decode_with_fallback(
218                &encoder_out,
219                &sot_sequence,
220                sample_begin,
221                blank_token,
222                max_initial_timestamp_index,
223                temperatures,
224            )?;
225
226            // No speech check (matching Python transcribe.py):
227            // Skip segment if no_speech_prob is high, unless logprob is also high
228            let should_skip = no_speech_prob > 0.6 && avg_logprob < -1.0;
229            if should_skip {
230                seek += segment_size;
231                continue;
232            }
233
234            // Parse timestamp tokens to determine seek advancement
235            let sampled = &tokens[sample_begin..];
236            let timestamp_mask: Vec<bool> = sampled.iter().map(|&t| t >= TIMESTAMP_BEGIN).collect();
237
238            // Find consecutive timestamp pairs
239            let mut consecutive_indices = Vec::new();
240            for i in 0..timestamp_mask.len().saturating_sub(1) {
241                if timestamp_mask[i] && timestamp_mask[i + 1] {
242                    consecutive_indices.push(i + 1);
243                }
244            }
245
246            // Collect text tokens (strip timestamps and special tokens)
247            let text_tokens: Vec<u32> = sampled
248                .iter()
249                .copied()
250                .filter(|&t| t < self.eot_token)
251                .collect();
252            all_tokens.extend_from_slice(&text_tokens);
253
254            if !consecutive_indices.is_empty() {
255                // Has consecutive timestamps → use last timestamp to advance seek
256                let single_timestamp_ending = timestamp_mask.len() >= 2
257                    && !timestamp_mask[timestamp_mask.len() - 2]
258                    && timestamp_mask[timestamp_mask.len() - 1];
259
260                if single_timestamp_ending {
261                    seek += segment_size;
262                } else {
263                    let last_idx = *consecutive_indices.last().unwrap();
264                    let last_ts_pos = (sampled[last_idx] - TIMESTAMP_BEGIN) as usize;
265                    seek += last_ts_pos * INPUT_STRIDE;
266                }
267            } else {
268                // No consecutive timestamps — check for any single timestamp
269                let timestamps: Vec<u32> = sampled
270                    .iter()
271                    .copied()
272                    .filter(|&t| t >= TIMESTAMP_BEGIN)
273                    .collect();
274                if !timestamps.is_empty() && *timestamps.last().unwrap() != TIMESTAMP_BEGIN {
275                    let last_ts_pos = (*timestamps.last().unwrap() - TIMESTAMP_BEGIN) as usize;
276                    seek += last_ts_pos * INPUT_STRIDE;
277                } else {
278                    seek += segment_size;
279                }
280            }
281        }
282
283        // Decode all collected text tokens
284        let text = self
285            .tokenizer
286            .decode(&all_tokens, true)
287            .map_err(|e| FerrumError::model(format!("decode tokens: {e}")))?;
288
289        Ok(text.trim().to_string())
290    }
291
292    /// Extract mel segment at given seek position, pad to n_frames.
293    fn mel_segment_at(&self, pcm: &[f32], seek_frames: usize, n_frames: usize) -> Result<Tensor> {
294        let hop = candle_transformers::models::whisper::HOP_LENGTH;
295        let start_sample = seek_frames * hop;
296        let n_samples = candle_transformers::models::whisper::N_SAMPLES;
297        let end_sample = (start_sample + n_samples).min(pcm.len());
298        let segment = &pcm[start_sample..end_sample];
299        self.model.pcm_to_mel_tensor(segment)
300    }
301
302    /// Decode one segment with temperature fallback.
303    /// Returns (all_tokens, avg_logprob, no_speech_prob, temperature_used).
304    fn decode_with_fallback(
305        &self,
306        encoder_out: &Tensor,
307        sot_sequence: &[u32],
308        sample_begin: usize,
309        blank_token: u32,
310        max_initial_timestamp_index: u32,
311        temperatures: &[f32],
312    ) -> Result<(Vec<u32>, f32, f32, f32)> {
313        let mut last_result = None;
314
315        for &temp in temperatures {
316            let (tokens, avg_logprob, no_speech_prob) = self.decode_segment(
317                encoder_out,
318                sot_sequence,
319                sample_begin,
320                blank_token,
321                max_initial_timestamp_index,
322                temp,
323            )?;
324
325            let text_tokens: Vec<u32> = tokens[sample_begin..]
326                .iter()
327                .copied()
328                .filter(|&t| t < self.eot_token)
329                .collect();
330            let text = self
331                .tokenizer
332                .decode(&text_tokens, true)
333                .unwrap_or_default();
334
335            let cr = compression_ratio(&text);
336
337            // Matching Python: fallback if too repetitive or logprob too low,
338            // but NOT if it's silence (high no_speech_prob overrides).
339            let mut needs_fallback = false;
340            if cr > 2.4 {
341                needs_fallback = true;
342            }
343            if avg_logprob < -1.0 {
344                needs_fallback = true;
345            }
346            if no_speech_prob > 0.6 {
347                needs_fallback = false; // silence — accept as-is
348            }
349
350            last_result = Some((tokens, avg_logprob, no_speech_prob, temp));
351
352            if !needs_fallback {
353                break;
354            }
355        }
356
357        last_result.ok_or_else(|| FerrumError::model("decode_with_fallback: no result"))
358    }
359
360    /// Decode one segment at a given temperature.
361    /// Returns (full_token_sequence, avg_logprob, no_speech_prob).
362    fn decode_segment(
363        &self,
364        encoder_out: &Tensor,
365        sot_sequence: &[u32],
366        sample_begin: usize,
367        blank_token: u32,
368        max_initial_timestamp_index: u32,
369        temperature: f32,
370    ) -> Result<(Vec<u32>, f32, f32)> {
371        self.model.reset_decoder();
372
373        let mut tokens: Vec<u32> = sot_sequence.to_vec();
374        let mut sum_logprobs: f32 = 0.0;
375        let mut no_speech_prob: f32 = 0.0;
376        let mut n_text_tokens: usize = 0;
377
378        // First forward: feed all initial tokens
379        let mut logits = self.model.decode_step(&tokens, encoder_out)?;
380
381        for step in 0..self.sample_len {
382            // On first step, capture no_speech_prob
383            if step == 0 {
384                let sot_logits = &logits; // logits from last position of initial tokens
385                let probs = softmax_vec(sot_logits);
386                no_speech_prob = probs[self.no_speech_token as usize];
387            }
388
389            // ─── Apply logit filters (matching Python) ───
390
391            let sampled_tokens = &tokens[sample_begin..];
392
393            // 1. SuppressBlank: on first text token, suppress blank and EOT
394            if sampled_tokens.is_empty() {
395                logits[blank_token as usize] = f32::NEG_INFINITY;
396                logits[self.eot_token as usize] = f32::NEG_INFINITY;
397            }
398
399            // 2. SuppressTokens: always suppress non-speech + special tokens
400            for &t in &self.suppress_token_ids {
401                if (t as usize) < logits.len() {
402                    logits[t as usize] = f32::NEG_INFINITY;
403                }
404            }
405
406            // 3. ApplyTimestampRules
407            self.apply_timestamp_rules(
408                &mut logits,
409                sampled_tokens,
410                sample_begin,
411                max_initial_timestamp_index,
412                step,
413            );
414
415            // ─── Select next token ───
416            let next_token = if temperature == 0.0 {
417                argmax(&logits)
418            } else {
419                sample_with_temperature(&logits, temperature)
420            };
421
422            // Track logprobs
423            let log_probs = log_softmax_vec(&logits);
424            if next_token != self.eot_token {
425                sum_logprobs += log_probs[next_token as usize];
426                n_text_tokens += 1;
427            }
428
429            tokens.push(next_token);
430
431            if next_token == self.eot_token
432                || tokens.len() > self.model.config().max_target_positions
433            {
434                break;
435            }
436
437            // Repetition detection on text tokens only (ignoring timestamps).
438            // If a text token repeats > 5 times in the last 10 text tokens, stop.
439            let text_tail: Vec<u32> = tokens[sample_begin..]
440                .iter()
441                .copied()
442                .filter(|&t| t < TIMESTAMP_BEGIN && t != self.eot_token)
443                .collect();
444            if text_tail.len() >= 6 {
445                let last = *text_tail.last().unwrap();
446                let consecutive = text_tail.iter().rev().take_while(|&&t| t == last).count();
447                if consecutive >= 5 {
448                    // Trim: remove the repeated tokens, keep 1
449                    let mut keep = tokens.len();
450                    let mut removed = 0;
451                    while keep > sample_begin && removed < consecutive - 1 {
452                        keep -= 1;
453                        if tokens[keep] == last {
454                            removed += 1;
455                        }
456                    }
457                    tokens.truncate(keep + 1);
458                    break;
459                }
460            }
461
462            // Next step: feed only the new token
463            logits = self.model.decode_step(&[next_token], encoder_out)?;
464        }
465
466        let avg_logprob = if n_text_tokens > 0 {
467            sum_logprobs / n_text_tokens as f32
468        } else {
469            f32::NEG_INFINITY
470        };
471
472        Ok((tokens, avg_logprob, no_speech_prob))
473    }
474
475    /// Apply timestamp rules (matching Python ApplyTimestampRules).
476    fn apply_timestamp_rules(
477        &self,
478        logits: &mut [f32],
479        sampled_tokens: &[u32],
480        _sample_begin: usize,
481        max_initial_timestamp_index: u32,
482        _step: usize,
483    ) {
484        let ts_begin = TIMESTAMP_BEGIN as usize;
485
486        // Suppress <|notimestamps|>
487        logits[self.no_timestamps_token as usize] = f32::NEG_INFINITY;
488
489        // Timestamp pairing rules
490        let last_was_timestamp =
491            !sampled_tokens.is_empty() && *sampled_tokens.last().unwrap() >= TIMESTAMP_BEGIN;
492
493        let penultimate_was_timestamp =
494            sampled_tokens.len() < 2 || sampled_tokens[sampled_tokens.len() - 2] >= TIMESTAMP_BEGIN;
495
496        if last_was_timestamp {
497            if penultimate_was_timestamp {
498                // Two timestamps in a row → must produce non-timestamp
499                for i in ts_begin..logits.len() {
500                    logits[i] = f32::NEG_INFINITY;
501                }
502            } else {
503                // Timestamp after text → must produce timestamp or EOT (suppress all text)
504                for i in 0..self.eot_token as usize {
505                    logits[i] = f32::NEG_INFINITY;
506                }
507            }
508        }
509
510        // Monotonically increasing timestamps
511        let timestamps: Vec<u32> = sampled_tokens
512            .iter()
513            .copied()
514            .filter(|&t| t >= TIMESTAMP_BEGIN)
515            .collect();
516        if !timestamps.is_empty() {
517            let timestamp_last = if last_was_timestamp && !penultimate_was_timestamp {
518                *timestamps.last().unwrap()
519            } else {
520                *timestamps.last().unwrap() + 1
521            };
522            for i in ts_begin..timestamp_last as usize {
523                if i < logits.len() {
524                    logits[i] = f32::NEG_INFINITY;
525                }
526            }
527        }
528
529        // First token: must be a timestamp, constrained by max_initial_timestamp
530        if sampled_tokens.is_empty() {
531            for i in 0..ts_begin {
532                logits[i] = f32::NEG_INFINITY;
533            }
534            let last_allowed = TIMESTAMP_BEGIN + max_initial_timestamp_index;
535            for i in (last_allowed as usize + 1)..logits.len() {
536                logits[i] = f32::NEG_INFINITY;
537            }
538        }
539
540        // If sum of timestamp probabilities > max text token probability, force timestamp
541        let log_probs = log_softmax_vec(logits);
542        let ts_logsumexp = {
543            let max_ts = log_probs[ts_begin..]
544                .iter()
545                .copied()
546                .fold(f32::NEG_INFINITY, f32::max);
547            if max_ts.is_finite() {
548                max_ts
549                    + log_probs[ts_begin..]
550                        .iter()
551                        .map(|&lp| (lp - max_ts).exp())
552                        .sum::<f32>()
553                        .ln()
554            } else {
555                f32::NEG_INFINITY
556            }
557        };
558        let max_text_logprob = log_probs[..ts_begin]
559            .iter()
560            .copied()
561            .fold(f32::NEG_INFINITY, f32::max);
562
563        if ts_logsumexp > max_text_logprob {
564            for i in 0..ts_begin {
565                logits[i] = f32::NEG_INFINITY;
566            }
567        }
568    }
569}
570
571// ── Utility functions ───────────────────────────────────────────────────
572
573fn token_id(tokenizer: &tokenizers::Tokenizer, token: &str) -> u32 {
574    tokenizer.token_to_id(token).unwrap_or(0)
575}
576
/// Index of the largest value in `v`, as a token id.
///
/// Ties resolve to the *first* maximum, as with `torch.argmax`, so a fully
/// masked (all `-inf`) row yields 0 instead of the last index. NaN entries are
/// ignored. An empty slice yields 0.
fn argmax(v: &[f32]) -> u32 {
    let mut best_index = 0usize;
    let mut best_value = f32::NEG_INFINITY;
    for (i, &x) in v.iter().enumerate() {
        // Strict `>` keeps the earliest index on ties and skips NaN.
        if x > best_value {
            best_value = x;
            best_index = i;
        }
    }
    best_index as u32
}
584
/// Numerically stable softmax: subtracts the maximum logit before exponentiating,
/// then normalises so the outputs sum to 1. Entries equal to `-inf` get probability 0.
fn softmax_vec(logits: &[f32]) -> Vec<f32> {
    let peak = logits.iter().fold(f32::NEG_INFINITY, |m, &x| m.max(x));
    let mut weights: Vec<f32> = logits.iter().map(|&x| (x - peak).exp()).collect();
    let total: f32 = weights.iter().sum();
    for w in &mut weights {
        *w /= total;
    }
    weights
}
591
/// Numerically stable log-softmax: `x - (max + ln(sum(exp(x - max))))` per entry.
fn log_softmax_vec(logits: &[f32]) -> Vec<f32> {
    let peak = logits.iter().fold(f32::NEG_INFINITY, |m, &x| m.max(x));
    let normalizer = peak + logits.iter().map(|&x| (x - peak).exp()).sum::<f32>().ln();
    logits.iter().map(|&x| x - normalizer).collect()
}
598
599fn sample_with_temperature(logits: &[f32], temperature: f32) -> u32 {
600    let scaled: Vec<f32> = logits.iter().map(|&x| x / temperature).collect();
601    let probs = softmax_vec(&scaled);
602    // Weighted random sampling
603    let r: f32 = rand_f32();
604    let mut cumulative = 0.0;
605    for (i, &p) in probs.iter().enumerate() {
606        cumulative += p;
607        if cumulative >= r {
608            return i as u32;
609        }
610    }
611    (probs.len() - 1) as u32
612}
613
/// Returns a pseudo-random `f32` in `[0, 1)`.
///
/// Uses a process-wide xorshift64 generator with a fixed seed, so the sequence
/// is the same on every run. That is fine for temperature sampling but not for
/// anything security-sensitive. The state is advanced with an atomic
/// read-modify-write, so concurrent callers never consume the same state.
fn rand_f32() -> f32 {
    use std::sync::atomic::{AtomicU64, Ordering};

    fn xorshift64(mut s: u64) -> u64 {
        s ^= s << 13;
        s ^= s >> 7;
        s ^= s << 17;
        s
    }

    static STATE: AtomicU64 = AtomicU64::new(0x12345678_9abcdef0);
    let previous = match STATE.fetch_update(Ordering::Relaxed, Ordering::Relaxed, |s| {
        Some(xorshift64(s))
    }) {
        Ok(s) | Err(s) => s,
    };
    let next = xorshift64(previous);
    // The top 24 bits fit exactly in an f32 mantissa, so the result is strictly below 1.0.
    (next >> 40) as f32 / (1u64 << 24) as f32
}
625
626/// Compression ratio using zlib deflate (matches Python whisper.utils.compression_ratio).
627fn compression_ratio(text: &str) -> f32 {
628    if text.is_empty() {
629        return 0.0;
630    }
631    use flate2::write::DeflateEncoder;
632    use flate2::Compression;
633    use std::io::Write;
634    let text_bytes = text.as_bytes();
635    let mut encoder = DeflateEncoder::new(Vec::new(), Compression::default());
636    encoder.write_all(text_bytes).unwrap();
637    let compressed = encoder.finish().unwrap();
638    text_bytes.len() as f32 / compressed.len().max(1) as f32
639}
640
641// ── Dummy KV cache + ModelExecutor trait impl ───────────────────────────
642
643#[derive(Clone, Debug)]
644#[allow(dead_code)]
645struct DummyWhisperCache;
646
647impl ferrum_interfaces::KvCacheHandle for DummyWhisperCache {
648    fn block_table(&self) -> &ferrum_interfaces::BlockTable {
649        static EMPTY: std::sync::OnceLock<ferrum_interfaces::BlockTable> =
650            std::sync::OnceLock::new();
651        EMPTY.get_or_init(|| ferrum_interfaces::BlockTable::new(16))
652    }
653    fn block_table_mut(&mut self) -> &mut ferrum_interfaces::BlockTable {
654        unimplemented!()
655    }
656    fn as_any(&self) -> &dyn std::any::Any {
657        self
658    }
659    fn device(&self) -> Device {
660        Device::CPU
661    }
662    fn num_layers(&self) -> usize {
663        0
664    }
665    fn num_heads(&self) -> usize {
666        0
667    }
668    fn head_dim(&self) -> usize {
669        0
670    }
671    fn key_cache(&self, _: usize) -> Result<Option<TensorRef>> {
672        Ok(None)
673    }
674    fn value_cache(&self, _: usize) -> Result<Option<TensorRef>> {
675        Ok(None)
676    }
677    fn clone_handle(&self) -> Result<Arc<dyn ferrum_interfaces::KvCacheHandle>> {
678        Ok(Arc::new(self.clone()))
679    }
680    fn stats(&self) -> ferrum_interfaces::CacheHandleStats {
681        ferrum_interfaces::CacheHandleStats {
682            memory_bytes: 0,
683            blocks_allocated: 0,
684            tokens_stored: 0,
685            utilization: 0.0,
686            last_access: std::time::Instant::now(),
687        }
688    }
689    fn is_valid(&self) -> bool {
690        true
691    }
692    fn cache_id(&self) -> String {
693        "whisper_dummy".to_string()
694    }
695}
696
697#[async_trait]
698impl ModelExecutor for WhisperModelExecutor {
699    fn info(&self) -> &ModelInfo {
700        &self.info
701    }
702
703    async fn prefill(&self, _input: &PrefillInput) -> Result<PrefillOutput> {
704        Err(FerrumError::model(
705            "Whisper uses transcribe(), not prefill/decode",
706        ))
707    }
708
709    async fn decode(&self, _input: &DecodeInput) -> Result<DecodeOutput> {
710        Err(FerrumError::model(
711            "Whisper uses transcribe(), not prefill/decode",
712        ))
713    }
714
715    fn capabilities(&self) -> ExecutorCapabilities {
716        ExecutorCapabilities {
717            max_batch_size: 1,
718            max_sequence_length: self.info.max_sequence_length,
719            attention_mechanisms: vec![AttentionType::MultiHead],
720            supports_dynamic_batching: false,
721            supports_continuous_batching: false,
722            supports_speculative_decoding: false,
723            supports_tensor_parallelism: false,
724            supports_pipeline_parallelism: false,
725            supported_dtypes: vec![DataType::FP32],
726            supported_devices: vec![self.info.device.clone()],
727            memory_requirements: MemoryRequirements {
728                parameter_memory: 0,
729                activation_memory_per_token: 0,
730                kv_cache_memory_per_token: 0,
731                overhead_memory: 0,
732            },
733        }
734    }
735
736    fn release_cache(&self, _: &str) {}
737
738    fn status(&self) -> ferrum_interfaces::model_executor::ExecutorStatus {
739        common::default_executor_status()
740    }
741}