1use std::collections::HashMap;
8use std::sync::Arc;
9
10use async_trait::async_trait;
11use candle_core::{DType, Device as CandleDevice, Tensor};
12use ferrum_interfaces::{
13 model_executor::{
14 AttentionType, DecodeInput, DecodeOutput, ExecutorCapabilities, MemoryRequirements,
15 PrefillInput, PrefillOutput,
16 },
17 ModelExecutor, TensorRef,
18};
19use ferrum_types::{DataType, Device, FerrumError, ModelInfo, ModelType, Result};
20use tracing::info;
21
22use super::common;
23use crate::architectures::whisper::WhisperModelWrapper;
24use crate::audio_processor;
25
/// Default id of the first timestamp token (`<|0.00|>`) in the standard multilingual
/// Whisper vocabulary.
const TIMESTAMP_BEGIN: u32 = 50364;

/// Mel frames per timestamp step; converts a timestamp index into a seek offset in frames.
const INPUT_STRIDE: usize = 2;

/// Seconds per timestamp step (not used by the decoding loop in this module).
const TIME_PRECISION: f64 = 0.02;

/// Token ids masked out at every decoding step.
///
/// NOTE(review): presumably punctuation/symbol tokens that do not correspond to speech —
/// confirm where this list was taken from.
const NON_SPEECH_TOKENS: &[u32] = &[
    1, 2, 7, 8, 9, 10, 14, 25, 26, 27, 28, 29, 31, 58, 59, 60, 61, 62, 63, 90, 91, 92, 93, 357,
    366, 438, 532, 685, 691, 1060, 1258, 1261, 1435, 1436, 1652, 2028, 2029, 2150, 2404, 2932,
    3292, 3455, 3723, 4100, 5751, 6283, 6347, 6436, 6615, 7579, 8765, 9929, 10563, 10813, 11318,
    12380, 14117, 14397, 14734, 15003, 15068, 15206, 16450, 16805, 17193, 17832, 19063, 19438,
    19635, 20203, 21111, 24220, 24408, 25212, 25830, 26622, 28156, 28279, 29464, 31650, 32302,
    32470, 36865, 42863, 47425, 49870, 50254,
];
43
44pub struct WhisperModelExecutor {
46 model: WhisperModelWrapper,
47 tokenizer: tokenizers::Tokenizer,
48 info: ModelInfo,
49 sot_token: u32,
51 eot_token: u32,
52 transcribe_token: u32,
53 translate_token: u32,
54 no_timestamps_token: u32,
55 no_speech_token: u32, sot_prev: u32,
57 sot_lm: u32,
58 language_tokens: HashMap<String, u32>,
59 suppress_token_ids: Vec<u32>,
61 sample_len: usize,
63}
64
65impl WhisperModelExecutor {
66 pub fn from_path(model_path: &str, device: CandleDevice, dtype: DType) -> Result<Self> {
68 let dir = std::path::Path::new(model_path);
69
70 let model = WhisperModelWrapper::from_model_dir(dir, device, dtype)?;
71
72 let tokenizer = tokenizers::Tokenizer::from_file(dir.join("tokenizer.json"))
73 .map_err(|e| FerrumError::model(format!("load tokenizer: {e}")))?;
74
75 let sot_token = token_id(&tokenizer, "<|startoftranscript|>");
77 let eot_token = token_id(&tokenizer, "<|endoftext|>");
78 let transcribe_token = token_id(&tokenizer, "<|transcribe|>");
79 let translate_token = token_id(&tokenizer, "<|translate|>");
80 let no_timestamps_token = token_id(&tokenizer, "<|notimestamps|>");
81 let no_speech_token = token_id(&tokenizer, "<|nocaptions|>");
82 let sot_prev = token_id(&tokenizer, "<|startofprev|>");
83 let sot_lm = token_id(&tokenizer, "<|startoflm|>");
84
85 let mut language_tokens = HashMap::new();
87 for lang in &[
88 "en", "zh", "ja", "ko", "fr", "de", "es", "ru", "ar", "pt", "it", "nl", "tr", "pl",
89 "sv", "da", "fi", "hu", "cs", "ro", "bg", "uk", "el", "hr", "sk", "th", "vi", "id",
90 "ms", "hi", "ta", "te", "ur", "fa", "he", "ca", "gl", "eu", "la",
91 ] {
92 let token_str = format!("<|{lang}|>");
93 if let Some(id) = tokenizer.token_to_id(&token_str) {
94 language_tokens.insert(lang.to_string(), id);
95 }
96 }
97
98 let mut suppress_ids: Vec<u32> = NON_SPEECH_TOKENS.to_vec();
100 suppress_ids.extend_from_slice(&[
101 transcribe_token,
102 translate_token,
103 sot_token,
104 sot_prev,
105 sot_lm,
106 no_speech_token,
107 ]);
108 suppress_ids.sort();
109 suppress_ids.dedup();
110
111 let sample_len = model.config().max_target_positions / 2;
112
113 let info = ModelInfo {
114 model_id: ferrum_types::ModelId(model_path.to_string()),
115 model_type: ModelType::Custom("whisper".to_string()),
116 hidden_size: model.config().d_model,
117 vocab_size: model.config().vocab_size,
118 num_layers: model.config().encoder_layers + model.config().decoder_layers,
119 num_heads: model.config().encoder_attention_heads,
120 num_kv_heads: model.config().decoder_attention_heads,
121 num_parameters: 0,
122 max_sequence_length: model.config().max_target_positions,
123 device: Device::CPU,
124 dtype: DataType::FP32,
125 version: None,
126 license: None,
127 metadata: HashMap::new(),
128 };
129
130 info!(
131 "WhisperModelExecutor: {} (d_model={}, languages={}, suppress_tokens={})",
132 model_path,
133 model.config().d_model,
134 language_tokens.len(),
135 suppress_ids.len(),
136 );
137
138 Ok(Self {
139 model,
140 tokenizer,
141 info,
142 sot_token,
143 eot_token,
144 transcribe_token,
145 translate_token,
146 no_timestamps_token,
147 no_speech_token,
148 sot_prev,
149 sot_lm,
150 language_tokens,
151 suppress_token_ids: suppress_ids,
152 sample_len,
153 })
154 }
155
156 pub fn transcribe_file(&self, audio_path: &str, language: Option<&str>) -> Result<String> {
158 let pcm = audio_processor::load_audio(audio_path)?;
159 self.transcribe_pcm(&pcm, language)
160 }
161
162 pub fn transcribe_bytes(&self, audio_data: &[u8], language: Option<&str>) -> Result<String> {
164 let pcm = audio_processor::load_audio_bytes(audio_data)?;
165 self.transcribe_pcm(&pcm, language)
166 }
167
168 fn transcribe_pcm(&self, pcm: &[f32], language: Option<&str>) -> Result<String> {
175 let lang_token = language
176 .and_then(|l| self.language_tokens.get(l).copied())
177 .unwrap_or_else(|| {
178 self.language_tokens
179 .get("en")
180 .copied()
181 .unwrap_or(self.sot_token + 1)
182 });
183
184 let n_samples = candle_transformers::models::whisper::N_SAMPLES;
186 let n_frames = candle_transformers::models::whisper::N_FRAMES;
187 let mut padded_pcm = pcm.to_vec();
188 padded_pcm.resize(padded_pcm.len() + n_samples, 0.0); let content_frames = pcm.len() / candle_transformers::models::whisper::HOP_LENGTH;
190
191 let sot_sequence = vec![self.sot_token, lang_token, self.transcribe_token];
193 let sample_begin = sot_sequence.len();
194
195 let blank_token = 220u32; let temperatures: &[f32] = &[0.0, 0.2, 0.4, 0.6, 0.8, 1.0];
200
201 let max_initial_timestamp_index: u32 = 50;
203
204 let mut seek: usize = 0;
205 let mut all_tokens: Vec<u32> = Vec::new();
206
207 while seek < content_frames {
208 let segment_size = n_frames.min(content_frames - seek);
209
210 let mel = self.mel_segment_at(&padded_pcm, seek, n_frames)?;
212
213 let encoder_out = self.model.encode(&mel)?;
215
216 let (tokens, avg_logprob, no_speech_prob, _temperature) = self.decode_with_fallback(
218 &encoder_out,
219 &sot_sequence,
220 sample_begin,
221 blank_token,
222 max_initial_timestamp_index,
223 temperatures,
224 )?;
225
226 let should_skip = no_speech_prob > 0.6 && avg_logprob < -1.0;
229 if should_skip {
230 seek += segment_size;
231 continue;
232 }
233
234 let sampled = &tokens[sample_begin..];
236 let timestamp_mask: Vec<bool> = sampled.iter().map(|&t| t >= TIMESTAMP_BEGIN).collect();
237
238 let mut consecutive_indices = Vec::new();
240 for i in 0..timestamp_mask.len().saturating_sub(1) {
241 if timestamp_mask[i] && timestamp_mask[i + 1] {
242 consecutive_indices.push(i + 1);
243 }
244 }
245
246 let text_tokens: Vec<u32> = sampled
248 .iter()
249 .copied()
250 .filter(|&t| t < self.eot_token)
251 .collect();
252 all_tokens.extend_from_slice(&text_tokens);
253
254 if !consecutive_indices.is_empty() {
255 let single_timestamp_ending = timestamp_mask.len() >= 2
257 && !timestamp_mask[timestamp_mask.len() - 2]
258 && timestamp_mask[timestamp_mask.len() - 1];
259
260 if single_timestamp_ending {
261 seek += segment_size;
262 } else {
263 let last_idx = *consecutive_indices.last().unwrap();
264 let last_ts_pos = (sampled[last_idx] - TIMESTAMP_BEGIN) as usize;
265 seek += last_ts_pos * INPUT_STRIDE;
266 }
267 } else {
268 let timestamps: Vec<u32> = sampled
270 .iter()
271 .copied()
272 .filter(|&t| t >= TIMESTAMP_BEGIN)
273 .collect();
274 if !timestamps.is_empty() && *timestamps.last().unwrap() != TIMESTAMP_BEGIN {
275 let last_ts_pos = (*timestamps.last().unwrap() - TIMESTAMP_BEGIN) as usize;
276 seek += last_ts_pos * INPUT_STRIDE;
277 } else {
278 seek += segment_size;
279 }
280 }
281 }
282
283 let text = self
285 .tokenizer
286 .decode(&all_tokens, true)
287 .map_err(|e| FerrumError::model(format!("decode tokens: {e}")))?;
288
289 Ok(text.trim().to_string())
290 }
291
292 fn mel_segment_at(&self, pcm: &[f32], seek_frames: usize, n_frames: usize) -> Result<Tensor> {
294 let hop = candle_transformers::models::whisper::HOP_LENGTH;
295 let start_sample = seek_frames * hop;
296 let n_samples = candle_transformers::models::whisper::N_SAMPLES;
297 let end_sample = (start_sample + n_samples).min(pcm.len());
298 let segment = &pcm[start_sample..end_sample];
299 self.model.pcm_to_mel_tensor(segment)
300 }
301
302 fn decode_with_fallback(
305 &self,
306 encoder_out: &Tensor,
307 sot_sequence: &[u32],
308 sample_begin: usize,
309 blank_token: u32,
310 max_initial_timestamp_index: u32,
311 temperatures: &[f32],
312 ) -> Result<(Vec<u32>, f32, f32, f32)> {
313 let mut last_result = None;
314
315 for &temp in temperatures {
316 let (tokens, avg_logprob, no_speech_prob) = self.decode_segment(
317 encoder_out,
318 sot_sequence,
319 sample_begin,
320 blank_token,
321 max_initial_timestamp_index,
322 temp,
323 )?;
324
325 let text_tokens: Vec<u32> = tokens[sample_begin..]
326 .iter()
327 .copied()
328 .filter(|&t| t < self.eot_token)
329 .collect();
330 let text = self
331 .tokenizer
332 .decode(&text_tokens, true)
333 .unwrap_or_default();
334
335 let cr = compression_ratio(&text);
336
337 let mut needs_fallback = false;
340 if cr > 2.4 {
341 needs_fallback = true;
342 }
343 if avg_logprob < -1.0 {
344 needs_fallback = true;
345 }
346 if no_speech_prob > 0.6 {
347 needs_fallback = false; }
349
350 last_result = Some((tokens, avg_logprob, no_speech_prob, temp));
351
352 if !needs_fallback {
353 break;
354 }
355 }
356
357 last_result.ok_or_else(|| FerrumError::model("decode_with_fallback: no result"))
358 }
359
360 fn decode_segment(
363 &self,
364 encoder_out: &Tensor,
365 sot_sequence: &[u32],
366 sample_begin: usize,
367 blank_token: u32,
368 max_initial_timestamp_index: u32,
369 temperature: f32,
370 ) -> Result<(Vec<u32>, f32, f32)> {
371 self.model.reset_decoder();
372
373 let mut tokens: Vec<u32> = sot_sequence.to_vec();
374 let mut sum_logprobs: f32 = 0.0;
375 let mut no_speech_prob: f32 = 0.0;
376 let mut n_text_tokens: usize = 0;
377
378 let mut logits = self.model.decode_step(&tokens, encoder_out)?;
380
381 for step in 0..self.sample_len {
382 if step == 0 {
384 let sot_logits = &logits; let probs = softmax_vec(sot_logits);
386 no_speech_prob = probs[self.no_speech_token as usize];
387 }
388
389 let sampled_tokens = &tokens[sample_begin..];
392
393 if sampled_tokens.is_empty() {
395 logits[blank_token as usize] = f32::NEG_INFINITY;
396 logits[self.eot_token as usize] = f32::NEG_INFINITY;
397 }
398
399 for &t in &self.suppress_token_ids {
401 if (t as usize) < logits.len() {
402 logits[t as usize] = f32::NEG_INFINITY;
403 }
404 }
405
406 self.apply_timestamp_rules(
408 &mut logits,
409 sampled_tokens,
410 sample_begin,
411 max_initial_timestamp_index,
412 step,
413 );
414
415 let next_token = if temperature == 0.0 {
417 argmax(&logits)
418 } else {
419 sample_with_temperature(&logits, temperature)
420 };
421
422 let log_probs = log_softmax_vec(&logits);
424 if next_token != self.eot_token {
425 sum_logprobs += log_probs[next_token as usize];
426 n_text_tokens += 1;
427 }
428
429 tokens.push(next_token);
430
431 if next_token == self.eot_token
432 || tokens.len() > self.model.config().max_target_positions
433 {
434 break;
435 }
436
437 let text_tail: Vec<u32> = tokens[sample_begin..]
440 .iter()
441 .copied()
442 .filter(|&t| t < TIMESTAMP_BEGIN && t != self.eot_token)
443 .collect();
444 if text_tail.len() >= 6 {
445 let last = *text_tail.last().unwrap();
446 let consecutive = text_tail.iter().rev().take_while(|&&t| t == last).count();
447 if consecutive >= 5 {
448 let mut keep = tokens.len();
450 let mut removed = 0;
451 while keep > sample_begin && removed < consecutive - 1 {
452 keep -= 1;
453 if tokens[keep] == last {
454 removed += 1;
455 }
456 }
457 tokens.truncate(keep + 1);
458 break;
459 }
460 }
461
462 logits = self.model.decode_step(&[next_token], encoder_out)?;
464 }
465
466 let avg_logprob = if n_text_tokens > 0 {
467 sum_logprobs / n_text_tokens as f32
468 } else {
469 f32::NEG_INFINITY
470 };
471
472 Ok((tokens, avg_logprob, no_speech_prob))
473 }
474
475 fn apply_timestamp_rules(
477 &self,
478 logits: &mut [f32],
479 sampled_tokens: &[u32],
480 _sample_begin: usize,
481 max_initial_timestamp_index: u32,
482 _step: usize,
483 ) {
484 let ts_begin = TIMESTAMP_BEGIN as usize;
485
486 logits[self.no_timestamps_token as usize] = f32::NEG_INFINITY;
488
489 let last_was_timestamp =
491 !sampled_tokens.is_empty() && *sampled_tokens.last().unwrap() >= TIMESTAMP_BEGIN;
492
493 let penultimate_was_timestamp =
494 sampled_tokens.len() < 2 || sampled_tokens[sampled_tokens.len() - 2] >= TIMESTAMP_BEGIN;
495
496 if last_was_timestamp {
497 if penultimate_was_timestamp {
498 for i in ts_begin..logits.len() {
500 logits[i] = f32::NEG_INFINITY;
501 }
502 } else {
503 for i in 0..self.eot_token as usize {
505 logits[i] = f32::NEG_INFINITY;
506 }
507 }
508 }
509
510 let timestamps: Vec<u32> = sampled_tokens
512 .iter()
513 .copied()
514 .filter(|&t| t >= TIMESTAMP_BEGIN)
515 .collect();
516 if !timestamps.is_empty() {
517 let timestamp_last = if last_was_timestamp && !penultimate_was_timestamp {
518 *timestamps.last().unwrap()
519 } else {
520 *timestamps.last().unwrap() + 1
521 };
522 for i in ts_begin..timestamp_last as usize {
523 if i < logits.len() {
524 logits[i] = f32::NEG_INFINITY;
525 }
526 }
527 }
528
529 if sampled_tokens.is_empty() {
531 for i in 0..ts_begin {
532 logits[i] = f32::NEG_INFINITY;
533 }
534 let last_allowed = TIMESTAMP_BEGIN + max_initial_timestamp_index;
535 for i in (last_allowed as usize + 1)..logits.len() {
536 logits[i] = f32::NEG_INFINITY;
537 }
538 }
539
540 let log_probs = log_softmax_vec(logits);
542 let ts_logsumexp = {
543 let max_ts = log_probs[ts_begin..]
544 .iter()
545 .copied()
546 .fold(f32::NEG_INFINITY, f32::max);
547 if max_ts.is_finite() {
548 max_ts
549 + log_probs[ts_begin..]
550 .iter()
551 .map(|&lp| (lp - max_ts).exp())
552 .sum::<f32>()
553 .ln()
554 } else {
555 f32::NEG_INFINITY
556 }
557 };
558 let max_text_logprob = log_probs[..ts_begin]
559 .iter()
560 .copied()
561 .fold(f32::NEG_INFINITY, f32::max);
562
563 if ts_logsumexp > max_text_logprob {
564 for i in 0..ts_begin {
565 logits[i] = f32::NEG_INFINITY;
566 }
567 }
568 }
569}
570
571fn token_id(tokenizer: &tokenizers::Tokenizer, token: &str) -> u32 {
574 tokenizer.token_to_id(token).unwrap_or(0)
575}
576
/// Index of the largest value in `v`, or `0` for an empty slice.
///
/// Ties and NaN comparisons resolve to the later index, matching `Iterator::max_by`.
fn argmax(v: &[f32]) -> u32 {
    let mut best = 0usize;
    for i in 1..v.len() {
        // The candidate replaces the incumbent unless the incumbent is strictly greater.
        if v[best].partial_cmp(&v[i]) != Some(std::cmp::Ordering::Greater) {
            best = i;
        }
    }
    best as u32
}
584
/// Numerically stable softmax: subtracts the maximum before exponentiating.
fn softmax_vec(logits: &[f32]) -> Vec<f32> {
    let peak = logits.iter().copied().fold(f32::NEG_INFINITY, f32::max);
    let mut probs: Vec<f32> = logits.iter().map(|&x| (x - peak).exp()).collect();
    let total: f32 = probs.iter().sum();
    for p in probs.iter_mut() {
        *p /= total;
    }
    probs
}
591
/// Numerically stable log-softmax: `x - (max + ln(sum(exp(x - max))))`.
fn log_softmax_vec(logits: &[f32]) -> Vec<f32> {
    let peak = logits.iter().copied().fold(f32::NEG_INFINITY, f32::max);
    let normalizer = peak + logits.iter().map(|&x| (x - peak).exp()).sum::<f32>().ln();
    logits.iter().map(|&x| x - normalizer).collect()
}
598
599fn sample_with_temperature(logits: &[f32], temperature: f32) -> u32 {
600 let scaled: Vec<f32> = logits.iter().map(|&x| x / temperature).collect();
601 let probs = softmax_vec(&scaled);
602 let r: f32 = rand_f32();
604 let mut cumulative = 0.0;
605 for (i, &p) in probs.iter().enumerate() {
606 cumulative += p;
607 if cumulative >= r {
608 return i as u32;
609 }
610 }
611 (probs.len() - 1) as u32
612}
613
/// Returns a pseudo-random `f32` strictly inside `(0, 1)` from a process-wide xorshift64
/// generator with a fixed seed.
///
/// Not cryptographically secure. The state is advanced atomically, so concurrent callers
/// always receive distinct draws.
fn rand_f32() -> f32 {
    use std::sync::atomic::{AtomicU64, Ordering};

    /// One xorshift64 step (13, 7, 17).
    fn step(mut s: u64) -> u64 {
        s ^= s << 13;
        s ^= s >> 7;
        s ^= s << 17;
        s
    }

    static STATE: AtomicU64 = AtomicU64::new(0x12345678_9abcdef0);
    // `fetch_update` makes the read-modify-write a single atomic operation. The closure
    // always returns `Some`, so both arms carry the previous state.
    let prev = match STATE.fetch_update(Ordering::Relaxed, Ordering::Relaxed, |s| Some(step(s))) {
        Ok(s) | Err(s) => s,
    };
    let s = step(prev);
    // Use the top 23 bits: `(k + 0.5) / 2^23` is exact in f32 and lies strictly in (0, 1).
    // `s as f32 / u64::MAX as f32` could round to exactly 1.0.
    ((s >> 41) as f32 + 0.5) / (1u64 << 23) as f32
}
625
626fn compression_ratio(text: &str) -> f32 {
628 if text.is_empty() {
629 return 0.0;
630 }
631 use flate2::write::DeflateEncoder;
632 use flate2::Compression;
633 use std::io::Write;
634 let text_bytes = text.as_bytes();
635 let mut encoder = DeflateEncoder::new(Vec::new(), Compression::default());
636 encoder.write_all(text_bytes).unwrap();
637 let compressed = encoder.finish().unwrap();
638 text_bytes.len() as f32 / compressed.len().max(1) as f32
639}
640
641#[derive(Clone, Debug)]
644#[allow(dead_code)]
645struct DummyWhisperCache;
646
647impl ferrum_interfaces::KvCacheHandle for DummyWhisperCache {
648 fn block_table(&self) -> &ferrum_interfaces::BlockTable {
649 static EMPTY: std::sync::OnceLock<ferrum_interfaces::BlockTable> =
650 std::sync::OnceLock::new();
651 EMPTY.get_or_init(|| ferrum_interfaces::BlockTable::new(16))
652 }
653 fn block_table_mut(&mut self) -> &mut ferrum_interfaces::BlockTable {
654 unimplemented!()
655 }
656 fn as_any(&self) -> &dyn std::any::Any {
657 self
658 }
659 fn device(&self) -> Device {
660 Device::CPU
661 }
662 fn num_layers(&self) -> usize {
663 0
664 }
665 fn num_heads(&self) -> usize {
666 0
667 }
668 fn head_dim(&self) -> usize {
669 0
670 }
671 fn key_cache(&self, _: usize) -> Result<Option<TensorRef>> {
672 Ok(None)
673 }
674 fn value_cache(&self, _: usize) -> Result<Option<TensorRef>> {
675 Ok(None)
676 }
677 fn clone_handle(&self) -> Result<Arc<dyn ferrum_interfaces::KvCacheHandle>> {
678 Ok(Arc::new(self.clone()))
679 }
680 fn stats(&self) -> ferrum_interfaces::CacheHandleStats {
681 ferrum_interfaces::CacheHandleStats {
682 memory_bytes: 0,
683 blocks_allocated: 0,
684 tokens_stored: 0,
685 utilization: 0.0,
686 last_access: std::time::Instant::now(),
687 }
688 }
689 fn is_valid(&self) -> bool {
690 true
691 }
692 fn cache_id(&self) -> String {
693 "whisper_dummy".to_string()
694 }
695}
696
697#[async_trait]
698impl ModelExecutor for WhisperModelExecutor {
699 fn info(&self) -> &ModelInfo {
700 &self.info
701 }
702
703 async fn prefill(&self, _input: &PrefillInput) -> Result<PrefillOutput> {
704 Err(FerrumError::model(
705 "Whisper uses transcribe(), not prefill/decode",
706 ))
707 }
708
709 async fn decode(&self, _input: &DecodeInput) -> Result<DecodeOutput> {
710 Err(FerrumError::model(
711 "Whisper uses transcribe(), not prefill/decode",
712 ))
713 }
714
715 fn capabilities(&self) -> ExecutorCapabilities {
716 ExecutorCapabilities {
717 max_batch_size: 1,
718 max_sequence_length: self.info.max_sequence_length,
719 attention_mechanisms: vec![AttentionType::MultiHead],
720 supports_dynamic_batching: false,
721 supports_continuous_batching: false,
722 supports_speculative_decoding: false,
723 supports_tensor_parallelism: false,
724 supports_pipeline_parallelism: false,
725 supported_dtypes: vec![DataType::FP32],
726 supported_devices: vec![self.info.device.clone()],
727 memory_requirements: MemoryRequirements {
728 parameter_memory: 0,
729 activation_memory_per_token: 0,
730 kv_cache_memory_per_token: 0,
731 overhead_memory: 0,
732 },
733 }
734 }
735
736 fn release_cache(&self, _: &str) {}
737
738 fn status(&self) -> ferrum_interfaces::model_executor::ExecutorStatus {
739 common::default_executor_status()
740 }
741}