Skip to main content

ferrum_models/executor/
whisper_executor.rs

1//! Whisper ASR Executor — full decode pipeline matching Python whisper.
2//!
3//! Implements: timestamp-based sequential decode, logit suppression (SuppressBlank,
4//! SuppressTokens, ApplyTimestampRules), temperature fallback, compression ratio
5//! check, seek-based segmentation.
6
7#![allow(dead_code, unused_imports, unused_variables, unused_mut, unused_parens)]
8
9use std::collections::HashMap;
10use std::sync::Arc;
11
12use async_trait::async_trait;
13use candle_core::{DType, Device as CandleDevice, Tensor};
14use ferrum_interfaces::{
15    model_executor::{
16        AttentionType, DecodeInput, DecodeOutput, ExecutorCapabilities, MemoryRequirements,
17        PrefillInput, PrefillOutput,
18    },
19    ModelExecutor, TensorRef,
20};
21use ferrum_types::{DataType, Device, FerrumError, ModelInfo, ModelType, Result};
22use tracing::info;
23
24use super::common;
25use crate::audio_processor;
26use crate::multimodal::whisper::WhisperModelWrapper;
27
// ── Token constants ─────────────────────────────────────────────────────
// These match Python whisper's tokenizer for multilingual models.
// NOTE(review): these IDs are hard-coded rather than read from the loaded
// tokenizer; presumably a checkpoint with a different vocabulary layout would
// need different values — confirm before supporting new model variants.

// Tokens with an ID at or above this value are treated as timestamp tokens
// throughout this file.
const TIMESTAMP_BEGIN: u32 = 50364;
const INPUT_STRIDE: usize = 2; // mel frames per output token (N_FRAMES / n_audio_ctx = 3000/1500)
const TIME_PRECISION: f64 = 0.02; // seconds per timestamp token (INPUT_STRIDE * HOP_LENGTH / SAMPLE_RATE)

/// Non-speech token IDs (from Python whisper `tokenizer.non_speech_tokens`).
/// They are merged into the always-suppressed token list when the executor is
/// loaded, so these symbols never appear in a transcription.
const NON_SPEECH_TOKENS: &[u32] = &[
    1, 2, 7, 8, 9, 10, 14, 25, 26, 27, 28, 29, 31, 58, 59, 60, 61, 62, 63, 90, 91, 92, 93, 357,
    366, 438, 532, 685, 691, 1060, 1258, 1261, 1435, 1436, 1652, 2028, 2029, 2150, 2404, 2932,
    3292, 3455, 3723, 4100, 5751, 6283, 6347, 6436, 6615, 7579, 8765, 9929, 10563, 10813, 11318,
    12380, 14117, 14397, 14734, 15003, 15068, 15206, 16450, 16805, 17193, 17832, 19063, 19438,
    19635, 20203, 21111, 24220, 24408, 25212, 25830, 26622, 28156, 28279, 29464, 31650, 32302,
    32470, 36865, 42863, 47425, 49870, 50254,
];
45
/// Whisper executor for speech-to-text.
///
/// Bundles the model wrapper, its tokenizer and the special-token IDs that the
/// decode loop needs. Instances are created with `from_path`.
pub struct WhisperModelExecutor {
    /// Encoder/decoder wrapper used for mel extraction, encoding and decode steps.
    model: WhisperModelWrapper,
    /// Tokenizer loaded from `tokenizer.json` in the model directory.
    tokenizer: tokenizers::Tokenizer,
    /// Static model description returned by `ModelExecutor::info`.
    info: ModelInfo,
    // Special token IDs
    /// ID of `<|startoftranscript|>`.
    sot_token: u32,
    /// ID of `<|endoftext|>`; tokens below this ID are treated as text.
    eot_token: u32,
    /// ID of `<|transcribe|>`.
    transcribe_token: u32,
    /// ID of `<|translate|>`.
    translate_token: u32,
    /// ID of `<|notimestamps|>`; always masked by the timestamp rules.
    no_timestamps_token: u32,
    no_speech_token: u32, // <|nocaptions|> = 50362
    /// ID of `<|startofprev|>`.
    sot_prev: u32,
    /// ID of `<|startoflm|>`.
    sot_lm: u32,
    /// Language code (e.g. `"en"`) → ID of the matching `<|xx|>` token.
    language_tokens: HashMap<String, u32>,
    /// Precomputed suppress mask: token IDs that are always suppressed.
    suppress_token_ids: Vec<u32>,
    /// Sample length (max decode tokens per segment); set to half of the
    /// model's `max_target_positions` at load time.
    sample_len: usize,
}
66
67impl WhisperModelExecutor {
68    /// Load from model directory.
69    pub fn from_path(model_path: &str, device: CandleDevice, dtype: DType) -> Result<Self> {
70        let dir = std::path::Path::new(model_path);
71
72        let model = WhisperModelWrapper::from_model_dir(dir, device, dtype)?;
73
74        let tokenizer = tokenizers::Tokenizer::from_file(dir.join("tokenizer.json"))
75            .map_err(|e| FerrumError::model(format!("load tokenizer: {e}")))?;
76
77        // Resolve special token IDs
78        let sot_token = token_id(&tokenizer, "<|startoftranscript|>");
79        let eot_token = token_id(&tokenizer, "<|endoftext|>");
80        let transcribe_token = token_id(&tokenizer, "<|transcribe|>");
81        let translate_token = token_id(&tokenizer, "<|translate|>");
82        let no_timestamps_token = token_id(&tokenizer, "<|notimestamps|>");
83        let no_speech_token = token_id(&tokenizer, "<|nocaptions|>");
84        let sot_prev = token_id(&tokenizer, "<|startofprev|>");
85        let sot_lm = token_id(&tokenizer, "<|startoflm|>");
86
87        // Build language token map
88        let mut language_tokens = HashMap::new();
89        for lang in &[
90            "en", "zh", "ja", "ko", "fr", "de", "es", "ru", "ar", "pt", "it", "nl", "tr", "pl",
91            "sv", "da", "fi", "hu", "cs", "ro", "bg", "uk", "el", "hr", "sk", "th", "vi", "id",
92            "ms", "hi", "ta", "te", "ur", "fa", "he", "ca", "gl", "eu", "la",
93        ] {
94            let token_str = format!("<|{lang}|>");
95            if let Some(id) = tokenizer.token_to_id(&token_str) {
96                language_tokens.insert(lang.to_string(), id);
97            }
98        }
99
100        // Build suppress token list (matches Python _get_suppress_tokens)
101        let mut suppress_ids: Vec<u32> = NON_SPEECH_TOKENS.to_vec();
102        suppress_ids.extend_from_slice(&[
103            transcribe_token,
104            translate_token,
105            sot_token,
106            sot_prev,
107            sot_lm,
108            no_speech_token,
109        ]);
110        suppress_ids.sort();
111        suppress_ids.dedup();
112
113        let sample_len = model.config().max_target_positions / 2;
114
115        let info = ModelInfo {
116            model_id: ferrum_types::ModelId(model_path.to_string()),
117            model_type: ModelType::Custom("whisper".to_string()),
118            hidden_size: model.config().d_model,
119            vocab_size: model.config().vocab_size,
120            num_layers: model.config().encoder_layers + model.config().decoder_layers,
121            num_heads: model.config().encoder_attention_heads,
122            num_kv_heads: model.config().decoder_attention_heads,
123            num_parameters: 0,
124            max_sequence_length: model.config().max_target_positions,
125            device: Device::CPU,
126            dtype: DataType::FP32,
127            version: None,
128            license: None,
129            metadata: HashMap::new(),
130        };
131
132        info!(
133            "WhisperModelExecutor: {} (d_model={}, languages={}, suppress_tokens={})",
134            model_path,
135            model.config().d_model,
136            language_tokens.len(),
137            suppress_ids.len(),
138        );
139
140        Ok(Self {
141            model,
142            tokenizer,
143            info,
144            sot_token,
145            eot_token,
146            transcribe_token,
147            translate_token,
148            no_timestamps_token,
149            no_speech_token,
150            sot_prev,
151            sot_lm,
152            language_tokens,
153            suppress_token_ids: suppress_ids,
154            sample_len,
155        })
156    }
157
158    /// Transcribe audio file → text.
159    pub fn transcribe_file(&self, audio_path: &str, language: Option<&str>) -> Result<String> {
160        let pcm = audio_processor::load_audio(audio_path)?;
161        self.transcribe_pcm(&pcm, language)
162    }
163
164    /// Transcribe raw audio bytes (WAV/any) → text.
165    pub fn transcribe_bytes(&self, audio_data: &[u8], language: Option<&str>) -> Result<String> {
166        let pcm = audio_processor::load_audio_bytes(audio_data)?;
167        self.transcribe_pcm(&pcm, language)
168    }
169
170    /// Full transcription pipeline matching Python whisper.transcribe().
171    ///
172    /// - Computes mel for entire audio (padded by 30s of silence)
173    /// - Seek-based loop over 30s segments
174    /// - For each segment: encode → decode with timestamp rules → extract segments
175    /// - Temperature fallback on high compression ratio or low avg logprob
176    fn transcribe_pcm(&self, pcm: &[f32], language: Option<&str>) -> Result<String> {
177        let lang_token = language
178            .and_then(|l| self.language_tokens.get(l).copied())
179            .unwrap_or_else(|| {
180                self.language_tokens
181                    .get("en")
182                    .copied()
183                    .unwrap_or(self.sot_token + 1)
184            });
185
186        // Compute mel for full audio + 30s padding (matching Python: padding=N_SAMPLES)
187        let n_samples = candle_transformers::models::whisper::N_SAMPLES;
188        let n_frames = candle_transformers::models::whisper::N_FRAMES;
189        let mut padded_pcm = pcm.to_vec();
190        padded_pcm.resize(padded_pcm.len() + n_samples, 0.0); // 30s silence padding
191        let content_frames = pcm.len() / candle_transformers::models::whisper::HOP_LENGTH;
192
193        // Initial tokens: SOT + language + transcribe (WITHOUT no_timestamps — we use timestamps)
194        let sot_sequence = vec![self.sot_token, lang_token, self.transcribe_token];
195        let sample_begin = sot_sequence.len();
196
197        // Blank token for SuppressBlank
198        let blank_token = 220u32; // space token, matches Python tokenizer.encode(" ")
199
200        // Temperatures for fallback
201        let temperatures: &[f32] = &[0.0, 0.2, 0.4, 0.6, 0.8, 1.0];
202
203        // Max initial timestamp: 1.0 second → index 50 (1.0 / 0.02)
204        let max_initial_timestamp_index: u32 = 50;
205
206        let mut seek: usize = 0;
207        let mut all_tokens: Vec<u32> = Vec::new();
208
209        while seek < content_frames {
210            let segment_size = n_frames.min(content_frames - seek);
211
212            // Extract mel segment
213            let mel = self.mel_segment_at(&padded_pcm, seek, n_frames)?;
214
215            // Encode
216            let encoder_out = self.model.encode(&mel)?;
217
218            // Decode with temperature fallback
219            let (tokens, avg_logprob, no_speech_prob, _temperature) = self.decode_with_fallback(
220                &encoder_out,
221                &sot_sequence,
222                sample_begin,
223                blank_token,
224                max_initial_timestamp_index,
225                temperatures,
226            )?;
227
228            // No speech check (matching Python transcribe.py):
229            // Skip segment if no_speech_prob is high, unless logprob is also high
230            let should_skip = no_speech_prob > 0.6 && avg_logprob < -1.0;
231            if should_skip {
232                seek += segment_size;
233                continue;
234            }
235
236            // Parse timestamp tokens to determine seek advancement
237            let sampled = &tokens[sample_begin..];
238            let timestamp_mask: Vec<bool> = sampled.iter().map(|&t| t >= TIMESTAMP_BEGIN).collect();
239
240            // Find consecutive timestamp pairs
241            let mut consecutive_indices = Vec::new();
242            for i in 0..timestamp_mask.len().saturating_sub(1) {
243                if timestamp_mask[i] && timestamp_mask[i + 1] {
244                    consecutive_indices.push(i + 1);
245                }
246            }
247
248            // Collect text tokens (strip timestamps and special tokens)
249            let text_tokens: Vec<u32> = sampled
250                .iter()
251                .copied()
252                .filter(|&t| t < self.eot_token)
253                .collect();
254            all_tokens.extend_from_slice(&text_tokens);
255
256            if !consecutive_indices.is_empty() {
257                // Has consecutive timestamps → use last timestamp to advance seek
258                let single_timestamp_ending = timestamp_mask.len() >= 2
259                    && !timestamp_mask[timestamp_mask.len() - 2]
260                    && timestamp_mask[timestamp_mask.len() - 1];
261
262                if single_timestamp_ending {
263                    seek += segment_size;
264                } else {
265                    let last_idx = *consecutive_indices.last().unwrap();
266                    let last_ts_pos = (sampled[last_idx] - TIMESTAMP_BEGIN) as usize;
267                    seek += last_ts_pos * INPUT_STRIDE;
268                }
269            } else {
270                // No consecutive timestamps — check for any single timestamp
271                let timestamps: Vec<u32> = sampled
272                    .iter()
273                    .copied()
274                    .filter(|&t| t >= TIMESTAMP_BEGIN)
275                    .collect();
276                if !timestamps.is_empty() && *timestamps.last().unwrap() != TIMESTAMP_BEGIN {
277                    let last_ts_pos = (*timestamps.last().unwrap() - TIMESTAMP_BEGIN) as usize;
278                    seek += last_ts_pos * INPUT_STRIDE;
279                } else {
280                    seek += segment_size;
281                }
282            }
283        }
284
285        // Decode all collected text tokens
286        let text = self
287            .tokenizer
288            .decode(&all_tokens, true)
289            .map_err(|e| FerrumError::model(format!("decode tokens: {e}")))?;
290
291        Ok(text.trim().to_string())
292    }
293
294    /// Extract mel segment at given seek position, pad to n_frames.
295    fn mel_segment_at(&self, pcm: &[f32], seek_frames: usize, n_frames: usize) -> Result<Tensor> {
296        let hop = candle_transformers::models::whisper::HOP_LENGTH;
297        let start_sample = seek_frames * hop;
298        let n_samples = candle_transformers::models::whisper::N_SAMPLES;
299        let end_sample = (start_sample + n_samples).min(pcm.len());
300        let segment = &pcm[start_sample..end_sample];
301        self.model.pcm_to_mel_tensor(segment)
302    }
303
304    /// Decode one segment with temperature fallback.
305    /// Returns (all_tokens, avg_logprob, no_speech_prob, temperature_used).
306    fn decode_with_fallback(
307        &self,
308        encoder_out: &Tensor,
309        sot_sequence: &[u32],
310        sample_begin: usize,
311        blank_token: u32,
312        max_initial_timestamp_index: u32,
313        temperatures: &[f32],
314    ) -> Result<(Vec<u32>, f32, f32, f32)> {
315        let mut last_result = None;
316
317        for &temp in temperatures {
318            let (tokens, avg_logprob, no_speech_prob) = self.decode_segment(
319                encoder_out,
320                sot_sequence,
321                sample_begin,
322                blank_token,
323                max_initial_timestamp_index,
324                temp,
325            )?;
326
327            let text_tokens: Vec<u32> = tokens[sample_begin..]
328                .iter()
329                .copied()
330                .filter(|&t| t < self.eot_token)
331                .collect();
332            let text = self
333                .tokenizer
334                .decode(&text_tokens, true)
335                .unwrap_or_default();
336
337            let cr = compression_ratio(&text);
338
339            // Matching Python: fallback if too repetitive or logprob too low,
340            // but NOT if it's silence (high no_speech_prob overrides).
341            let mut needs_fallback = false;
342            if cr > 2.4 {
343                needs_fallback = true;
344            }
345            if avg_logprob < -1.0 {
346                needs_fallback = true;
347            }
348            if no_speech_prob > 0.6 {
349                needs_fallback = false; // silence — accept as-is
350            }
351
352            last_result = Some((tokens, avg_logprob, no_speech_prob, temp));
353
354            if !needs_fallback {
355                break;
356            }
357        }
358
359        last_result.ok_or_else(|| FerrumError::model("decode_with_fallback: no result"))
360    }
361
362    /// Decode one segment at a given temperature.
363    /// Returns (full_token_sequence, avg_logprob, no_speech_prob).
364    fn decode_segment(
365        &self,
366        encoder_out: &Tensor,
367        sot_sequence: &[u32],
368        sample_begin: usize,
369        blank_token: u32,
370        max_initial_timestamp_index: u32,
371        temperature: f32,
372    ) -> Result<(Vec<u32>, f32, f32)> {
373        self.model.reset_decoder();
374
375        let mut tokens: Vec<u32> = sot_sequence.to_vec();
376        let mut sum_logprobs: f32 = 0.0;
377        let mut no_speech_prob: f32 = 0.0;
378        let mut n_text_tokens: usize = 0;
379
380        // First forward: feed all initial tokens
381        let mut logits = self.model.decode_step(&tokens, encoder_out)?;
382
383        for step in 0..self.sample_len {
384            // On first step, capture no_speech_prob
385            if step == 0 {
386                let sot_logits = &logits; // logits from last position of initial tokens
387                let probs = softmax_vec(sot_logits);
388                no_speech_prob = probs[self.no_speech_token as usize];
389            }
390
391            // ─── Apply logit filters (matching Python) ───
392
393            let sampled_tokens = &tokens[sample_begin..];
394
395            // 1. SuppressBlank: on first text token, suppress blank and EOT
396            if sampled_tokens.is_empty() {
397                logits[blank_token as usize] = f32::NEG_INFINITY;
398                logits[self.eot_token as usize] = f32::NEG_INFINITY;
399            }
400
401            // 2. SuppressTokens: always suppress non-speech + special tokens
402            for &t in &self.suppress_token_ids {
403                if (t as usize) < logits.len() {
404                    logits[t as usize] = f32::NEG_INFINITY;
405                }
406            }
407
408            // 3. ApplyTimestampRules
409            self.apply_timestamp_rules(
410                &mut logits,
411                sampled_tokens,
412                sample_begin,
413                max_initial_timestamp_index,
414                step,
415            );
416
417            // ─── Select next token ───
418            let next_token = if temperature == 0.0 {
419                argmax(&logits)
420            } else {
421                sample_with_temperature(&logits, temperature)
422            };
423
424            // Track logprobs
425            let log_probs = log_softmax_vec(&logits);
426            if next_token != self.eot_token {
427                sum_logprobs += log_probs[next_token as usize];
428                n_text_tokens += 1;
429            }
430
431            tokens.push(next_token);
432
433            if next_token == self.eot_token
434                || tokens.len() > self.model.config().max_target_positions
435            {
436                break;
437            }
438
439            // Repetition detection on text tokens only (ignoring timestamps).
440            // If a text token repeats > 5 times in the last 10 text tokens, stop.
441            let text_tail: Vec<u32> = tokens[sample_begin..]
442                .iter()
443                .copied()
444                .filter(|&t| t < TIMESTAMP_BEGIN && t != self.eot_token)
445                .collect();
446            if text_tail.len() >= 6 {
447                let last = *text_tail.last().unwrap();
448                let consecutive = text_tail.iter().rev().take_while(|&&t| t == last).count();
449                if consecutive >= 5 {
450                    // Trim: remove the repeated tokens, keep 1
451                    let mut keep = tokens.len();
452                    let mut removed = 0;
453                    while keep > sample_begin && removed < consecutive - 1 {
454                        keep -= 1;
455                        if tokens[keep] == last {
456                            removed += 1;
457                        }
458                    }
459                    tokens.truncate(keep + 1);
460                    break;
461                }
462            }
463
464            // Next step: feed only the new token
465            logits = self.model.decode_step(&[next_token], encoder_out)?;
466        }
467
468        let avg_logprob = if n_text_tokens > 0 {
469            sum_logprobs / n_text_tokens as f32
470        } else {
471            f32::NEG_INFINITY
472        };
473
474        Ok((tokens, avg_logprob, no_speech_prob))
475    }
476
477    /// Apply timestamp rules (matching Python ApplyTimestampRules).
478    fn apply_timestamp_rules(
479        &self,
480        logits: &mut [f32],
481        sampled_tokens: &[u32],
482        _sample_begin: usize,
483        max_initial_timestamp_index: u32,
484        _step: usize,
485    ) {
486        let ts_begin = TIMESTAMP_BEGIN as usize;
487
488        // Suppress <|notimestamps|>
489        logits[self.no_timestamps_token as usize] = f32::NEG_INFINITY;
490
491        // Timestamp pairing rules
492        let last_was_timestamp =
493            !sampled_tokens.is_empty() && *sampled_tokens.last().unwrap() >= TIMESTAMP_BEGIN;
494
495        let penultimate_was_timestamp =
496            sampled_tokens.len() < 2 || sampled_tokens[sampled_tokens.len() - 2] >= TIMESTAMP_BEGIN;
497
498        if last_was_timestamp {
499            if penultimate_was_timestamp {
500                // Two timestamps in a row → must produce non-timestamp
501                for i in ts_begin..logits.len() {
502                    logits[i] = f32::NEG_INFINITY;
503                }
504            } else {
505                // Timestamp after text → must produce timestamp or EOT (suppress all text)
506                for i in 0..self.eot_token as usize {
507                    logits[i] = f32::NEG_INFINITY;
508                }
509            }
510        }
511
512        // Monotonically increasing timestamps
513        let timestamps: Vec<u32> = sampled_tokens
514            .iter()
515            .copied()
516            .filter(|&t| t >= TIMESTAMP_BEGIN)
517            .collect();
518        if !timestamps.is_empty() {
519            let timestamp_last = if last_was_timestamp && !penultimate_was_timestamp {
520                *timestamps.last().unwrap()
521            } else {
522                *timestamps.last().unwrap() + 1
523            };
524            for i in ts_begin..timestamp_last as usize {
525                if i < logits.len() {
526                    logits[i] = f32::NEG_INFINITY;
527                }
528            }
529        }
530
531        // First token: must be a timestamp, constrained by max_initial_timestamp
532        if sampled_tokens.is_empty() {
533            for i in 0..ts_begin {
534                logits[i] = f32::NEG_INFINITY;
535            }
536            let last_allowed = TIMESTAMP_BEGIN + max_initial_timestamp_index;
537            for i in (last_allowed as usize + 1)..logits.len() {
538                logits[i] = f32::NEG_INFINITY;
539            }
540        }
541
542        // If sum of timestamp probabilities > max text token probability, force timestamp
543        let log_probs = log_softmax_vec(logits);
544        let ts_logsumexp = {
545            let max_ts = log_probs[ts_begin..]
546                .iter()
547                .copied()
548                .fold(f32::NEG_INFINITY, f32::max);
549            if max_ts.is_finite() {
550                max_ts
551                    + log_probs[ts_begin..]
552                        .iter()
553                        .map(|&lp| (lp - max_ts).exp())
554                        .sum::<f32>()
555                        .ln()
556            } else {
557                f32::NEG_INFINITY
558            }
559        };
560        let max_text_logprob = log_probs[..ts_begin]
561            .iter()
562            .copied()
563            .fold(f32::NEG_INFINITY, f32::max);
564
565        if ts_logsumexp > max_text_logprob {
566            for i in 0..ts_begin {
567                logits[i] = f32::NEG_INFINITY;
568            }
569        }
570    }
571}
572
573// ── Utility functions ───────────────────────────────────────────────────
574
575fn token_id(tokenizer: &tokenizers::Tokenizer, token: &str) -> u32 {
576    tokenizer.token_to_id(token).unwrap_or(0)
577}
578
/// Index of the largest value in `v`, or 0 for an empty slice.
///
/// When several entries share the maximum, the last of them wins. An entry
/// that cannot be ordered against the current best (NaN) replaces it.
fn argmax(v: &[f32]) -> u32 {
    let mut best: Option<(usize, f32)> = None;
    for (idx, &value) in v.iter().enumerate() {
        // Keep the current best only when it is strictly greater.
        let replace = match best {
            None => true,
            Some((_, top)) => !(top > value),
        };
        if replace {
            best = Some((idx, value));
        }
    }
    best.map_or(0, |(idx, _)| idx as u32)
}
586
/// Softmax over `logits`, shifted by the maximum for numerical stability.
///
/// Returns one probability per input element; an empty input yields an empty
/// vector.
fn softmax_vec(logits: &[f32]) -> Vec<f32> {
    let peak = logits.iter().fold(f32::NEG_INFINITY, |m, &x| m.max(x));
    let mut weights: Vec<f32> = logits.iter().map(|&x| (x - peak).exp()).collect();
    let total: f32 = weights.iter().sum();
    for w in &mut weights {
        *w /= total;
    }
    weights
}
593
/// Log-softmax over `logits`: each element minus the log of the sum of
/// exponentials, computed with the maximum factored out for stability.
fn log_softmax_vec(logits: &[f32]) -> Vec<f32> {
    let peak = logits.iter().fold(f32::NEG_INFINITY, |m, &x| m.max(x));
    let shifted_total: f32 = logits.iter().map(|&x| (x - peak).exp()).sum();
    let log_norm = peak + shifted_total.ln();

    let mut out = Vec::with_capacity(logits.len());
    for &x in logits {
        out.push(x - log_norm);
    }
    out
}
600
601fn sample_with_temperature(logits: &[f32], temperature: f32) -> u32 {
602    let scaled: Vec<f32> = logits.iter().map(|&x| x / temperature).collect();
603    let probs = softmax_vec(&scaled);
604    // Weighted random sampling
605    let r: f32 = rand_f32();
606    let mut cumulative = 0.0;
607    for (i, &p) in probs.iter().enumerate() {
608        cumulative += p;
609        if cumulative >= r {
610            return i as u32;
611        }
612    }
613    (probs.len() - 1) as u32
614}
615
/// Pseudo-random number in the half-open range `[0, 1)`.
///
/// Uses a process-wide xorshift64 generator with a fixed seed (good enough for
/// temperature sampling; not suitable for anything security-related). The state
/// is advanced with a single atomic read-modify-write, so concurrent callers
/// never observe the same state twice.
fn rand_f32() -> f32 {
    use std::sync::atomic::{AtomicU64, Ordering};
    static STATE: AtomicU64 = AtomicU64::new(0x12345678_9abcdef0);

    // One xorshift64 step.
    fn step(mut s: u64) -> u64 {
        s ^= s << 13;
        s ^= s >> 7;
        s ^= s << 17;
        s
    }

    // Relaxed is sufficient: the state is a single independent atomic and no
    // other memory is published through it.
    let previous = match STATE.fetch_update(Ordering::Relaxed, Ordering::Relaxed, |s| Some(step(s)))
    {
        Ok(p) | Err(p) => p,
    };
    let current = step(previous);

    // Keep the top 24 bits: they convert to f32 exactly, so the quotient is
    // strictly below 1.0 (dividing the full u64 could round up to exactly 1.0).
    (current >> 40) as f32 / (1u32 << 24) as f32
}
627
628/// Compression ratio using zlib deflate (matches Python whisper.utils.compression_ratio).
629fn compression_ratio(text: &str) -> f32 {
630    if text.is_empty() {
631        return 0.0;
632    }
633    use flate2::write::DeflateEncoder;
634    use flate2::Compression;
635    use std::io::Write;
636    let text_bytes = text.as_bytes();
637    let mut encoder = DeflateEncoder::new(Vec::new(), Compression::default());
638    encoder.write_all(text_bytes).unwrap();
639    let compressed = encoder.finish().unwrap();
640    text_bytes.len() as f32 / compressed.len().max(1) as f32
641}
642
643// ── Dummy KV cache + ModelExecutor trait impl ───────────────────────────
644
/// Placeholder KV-cache handle for Whisper.
///
/// Carries no state: its `KvCacheHandle` implementation reports zero layers,
/// heads and memory, and returns no key/value tensors.
#[derive(Clone, Debug)]
#[allow(dead_code)]
struct DummyWhisperCache;
648
649impl ferrum_interfaces::KvCacheHandle for DummyWhisperCache {
650    fn block_table(&self) -> &ferrum_interfaces::BlockTable {
651        static EMPTY: std::sync::OnceLock<ferrum_interfaces::BlockTable> =
652            std::sync::OnceLock::new();
653        EMPTY.get_or_init(|| ferrum_interfaces::BlockTable::new(16))
654    }
655    fn block_table_mut(&mut self) -> &mut ferrum_interfaces::BlockTable {
656        unimplemented!()
657    }
658    fn as_any(&self) -> &dyn std::any::Any {
659        self
660    }
661    fn device(&self) -> Device {
662        Device::CPU
663    }
664    fn num_layers(&self) -> usize {
665        0
666    }
667    fn num_heads(&self) -> usize {
668        0
669    }
670    fn head_dim(&self) -> usize {
671        0
672    }
673    fn key_cache(&self, _: usize) -> Result<Option<TensorRef>> {
674        Ok(None)
675    }
676    fn value_cache(&self, _: usize) -> Result<Option<TensorRef>> {
677        Ok(None)
678    }
679    fn clone_handle(&self) -> Result<Arc<dyn ferrum_interfaces::KvCacheHandle>> {
680        Ok(Arc::new(self.clone()))
681    }
682    fn stats(&self) -> ferrum_interfaces::CacheHandleStats {
683        ferrum_interfaces::CacheHandleStats {
684            memory_bytes: 0,
685            blocks_allocated: 0,
686            tokens_stored: 0,
687            utilization: 0.0,
688            last_access: std::time::Instant::now(),
689        }
690    }
691    fn is_valid(&self) -> bool {
692        true
693    }
694    fn cache_id(&self) -> String {
695        "whisper_dummy".to_string()
696    }
697}
698
699#[async_trait]
700impl ModelExecutor for WhisperModelExecutor {
701    fn info(&self) -> &ModelInfo {
702        &self.info
703    }
704
705    async fn prefill(&self, _input: &PrefillInput) -> Result<PrefillOutput> {
706        Err(FerrumError::model(
707            "Whisper uses transcribe(), not prefill/decode",
708        ))
709    }
710
711    async fn decode(&self, _input: &DecodeInput) -> Result<DecodeOutput> {
712        Err(FerrumError::model(
713            "Whisper uses transcribe(), not prefill/decode",
714        ))
715    }
716
717    fn capabilities(&self) -> ExecutorCapabilities {
718        ExecutorCapabilities {
719            max_batch_size: 1,
720            max_sequence_length: self.info.max_sequence_length,
721            attention_mechanisms: vec![AttentionType::MultiHead],
722            supports_dynamic_batching: false,
723            supports_continuous_batching: false,
724            supports_speculative_decoding: false,
725            supports_tensor_parallelism: false,
726            supports_pipeline_parallelism: false,
727            supported_dtypes: vec![DataType::FP32],
728            supported_devices: vec![self.info.device.clone()],
729            memory_requirements: MemoryRequirements {
730                parameter_memory: 0,
731                activation_memory_per_token: 0,
732                kv_cache_memory_per_token: 0,
733                overhead_memory: 0,
734            },
735        }
736    }
737
738    fn release_cache(&self, _: &str) {}
739
740    fn status(&self) -> ferrum_interfaces::model_executor::ExecutorStatus {
741        common::default_executor_status()
742    }
743}