1#![allow(dead_code, unused_imports, unused_variables, unused_mut, unused_parens)]
8
9use std::collections::HashMap;
10use std::sync::Arc;
11
12use async_trait::async_trait;
13use candle_core::{DType, Device as CandleDevice, Tensor};
14use ferrum_interfaces::{
15 model_executor::{
16 AttentionType, DecodeInput, DecodeOutput, ExecutorCapabilities, MemoryRequirements,
17 PrefillInput, PrefillOutput,
18 },
19 ModelExecutor, TensorRef,
20};
21use ferrum_types::{DataType, Device, FerrumError, ModelInfo, ModelType, Result};
22use tracing::info;
23
24use super::common;
25use crate::architectures::whisper::WhisperModelWrapper;
26use crate::audio_processor;
27
/// Token id of the first timestamp token. Every id at or above this value is treated as a
/// timestamp by the transcription and decoding code in this file.
/// NOTE(review): hard-coded to 50364 — presumably matches the vocabulary of the checkpoints
/// this executor targets; confirm for checkpoints whose tokenizer uses a different layout.
const TIMESTAMP_BEGIN: u32 = 50364;
/// Factor applied to a timestamp index (token id minus `TIMESTAMP_BEGIN`) to obtain the
/// number of frames by which the transcription loop advances its seek position.
const INPUT_STRIDE: usize = 2;
/// NOTE(review): presumably the duration in seconds of one timestamp step; it is not
/// referenced by any code visible in this file.
const TIME_PRECISION: f64 = 0.02;
/// Token ids copied into the executor's suppression list at construction time; every id in
/// that list has its logit forced to negative infinity on each decoding step.
const NON_SPEECH_TOKENS: &[u32] = &[
    1, 2, 7, 8, 9, 10, 14, 25, 26, 27, 28, 29, 31, 58, 59, 60, 61, 62, 63, 90, 91, 92, 93, 357,
    366, 438, 532, 685, 691, 1060, 1258, 1261, 1435, 1436, 1652, 2028, 2029, 2150, 2404, 2932,
    3292, 3455, 3723, 4100, 5751, 6283, 6347, 6436, 6615, 7579, 8765, 9929, 10563, 10813, 11318,
    12380, 14117, 14397, 14734, 15003, 15068, 15206, 16450, 16805, 17193, 17832, 19063, 19438,
    19635, 20203, 21111, 24220, 24408, 25212, 25830, 26622, 28156, 28279, 29464, 31650, 32302,
    32470, 36865, 42863, 47425, 49870, 50254,
];
45
/// Executor that transcribes audio with a Whisper model: it owns the model wrapper, the
/// tokenizer, and the special-token ids resolved from that tokenizer at construction time.
pub struct WhisperModelExecutor {
    /// Model wrapper used for mel conversion, encoding and step-wise decoding.
    model: WhisperModelWrapper,
    /// Tokenizer used to resolve special-token ids and to turn sampled ids back into text.
    tokenizer: tokenizers::Tokenizer,
    /// Static model description returned by `ModelExecutor::info`.
    info: ModelInfo,
    /// Id of `<|startoftranscript|>`.
    sot_token: u32,
    /// Id of `<|endoftext|>`; decoding of a segment stops when this id is sampled.
    eot_token: u32,
    /// Id of `<|transcribe|>`.
    transcribe_token: u32,
    /// Id of `<|translate|>`.
    translate_token: u32,
    /// Id of `<|notimestamps|>`; always masked out by the timestamp rules.
    no_timestamps_token: u32,
    /// Id whose probability on the first decoding step is reported as the no-speech
    /// probability.
    no_speech_token: u32,
    /// Id of `<|startofprev|>`.
    sot_prev: u32,
    /// Id of `<|startoflm|>`.
    sot_lm: u32,
    /// Language code (e.g. `"en"`) mapped to the id of its `<|xx|>` token; only languages
    /// found in the tokenizer are present.
    language_tokens: HashMap<String, u32>,
    /// Sorted, de-duplicated ids whose logits are set to negative infinity on every step.
    suppress_token_ids: Vec<u32>,
    /// Maximum number of decoding steps per segment (half of `max_target_positions`).
    sample_len: usize,
}
66
67impl WhisperModelExecutor {
68 pub fn from_path(model_path: &str, device: CandleDevice, dtype: DType) -> Result<Self> {
70 let dir = std::path::Path::new(model_path);
71
72 let model = WhisperModelWrapper::from_model_dir(dir, device, dtype)?;
73
74 let tokenizer = tokenizers::Tokenizer::from_file(dir.join("tokenizer.json"))
75 .map_err(|e| FerrumError::model(format!("load tokenizer: {e}")))?;
76
77 let sot_token = token_id(&tokenizer, "<|startoftranscript|>");
79 let eot_token = token_id(&tokenizer, "<|endoftext|>");
80 let transcribe_token = token_id(&tokenizer, "<|transcribe|>");
81 let translate_token = token_id(&tokenizer, "<|translate|>");
82 let no_timestamps_token = token_id(&tokenizer, "<|notimestamps|>");
83 let no_speech_token = token_id(&tokenizer, "<|nocaptions|>");
84 let sot_prev = token_id(&tokenizer, "<|startofprev|>");
85 let sot_lm = token_id(&tokenizer, "<|startoflm|>");
86
87 let mut language_tokens = HashMap::new();
89 for lang in &[
90 "en", "zh", "ja", "ko", "fr", "de", "es", "ru", "ar", "pt", "it", "nl", "tr", "pl",
91 "sv", "da", "fi", "hu", "cs", "ro", "bg", "uk", "el", "hr", "sk", "th", "vi", "id",
92 "ms", "hi", "ta", "te", "ur", "fa", "he", "ca", "gl", "eu", "la",
93 ] {
94 let token_str = format!("<|{lang}|>");
95 if let Some(id) = tokenizer.token_to_id(&token_str) {
96 language_tokens.insert(lang.to_string(), id);
97 }
98 }
99
100 let mut suppress_ids: Vec<u32> = NON_SPEECH_TOKENS.to_vec();
102 suppress_ids.extend_from_slice(&[
103 transcribe_token,
104 translate_token,
105 sot_token,
106 sot_prev,
107 sot_lm,
108 no_speech_token,
109 ]);
110 suppress_ids.sort();
111 suppress_ids.dedup();
112
113 let sample_len = model.config().max_target_positions / 2;
114
115 let info = ModelInfo {
116 model_id: ferrum_types::ModelId(model_path.to_string()),
117 model_type: ModelType::Custom("whisper".to_string()),
118 hidden_size: model.config().d_model,
119 vocab_size: model.config().vocab_size,
120 num_layers: model.config().encoder_layers + model.config().decoder_layers,
121 num_heads: model.config().encoder_attention_heads,
122 num_kv_heads: model.config().decoder_attention_heads,
123 num_parameters: 0,
124 max_sequence_length: model.config().max_target_positions,
125 device: Device::CPU,
126 dtype: DataType::FP32,
127 version: None,
128 license: None,
129 metadata: HashMap::new(),
130 };
131
132 info!(
133 "WhisperModelExecutor: {} (d_model={}, languages={}, suppress_tokens={})",
134 model_path,
135 model.config().d_model,
136 language_tokens.len(),
137 suppress_ids.len(),
138 );
139
140 Ok(Self {
141 model,
142 tokenizer,
143 info,
144 sot_token,
145 eot_token,
146 transcribe_token,
147 translate_token,
148 no_timestamps_token,
149 no_speech_token,
150 sot_prev,
151 sot_lm,
152 language_tokens,
153 suppress_token_ids: suppress_ids,
154 sample_len,
155 })
156 }
157
158 pub fn transcribe_file(&self, audio_path: &str, language: Option<&str>) -> Result<String> {
160 let pcm = audio_processor::load_audio(audio_path)?;
161 self.transcribe_pcm(&pcm, language)
162 }
163
164 pub fn transcribe_bytes(&self, audio_data: &[u8], language: Option<&str>) -> Result<String> {
166 let pcm = audio_processor::load_audio_bytes(audio_data)?;
167 self.transcribe_pcm(&pcm, language)
168 }
169
170 fn transcribe_pcm(&self, pcm: &[f32], language: Option<&str>) -> Result<String> {
177 let lang_token = language
178 .and_then(|l| self.language_tokens.get(l).copied())
179 .unwrap_or_else(|| {
180 self.language_tokens
181 .get("en")
182 .copied()
183 .unwrap_or(self.sot_token + 1)
184 });
185
186 let n_samples = candle_transformers::models::whisper::N_SAMPLES;
188 let n_frames = candle_transformers::models::whisper::N_FRAMES;
189 let mut padded_pcm = pcm.to_vec();
190 padded_pcm.resize(padded_pcm.len() + n_samples, 0.0); let content_frames = pcm.len() / candle_transformers::models::whisper::HOP_LENGTH;
192
193 let sot_sequence = vec![self.sot_token, lang_token, self.transcribe_token];
195 let sample_begin = sot_sequence.len();
196
197 let blank_token = 220u32; let temperatures: &[f32] = &[0.0, 0.2, 0.4, 0.6, 0.8, 1.0];
202
203 let max_initial_timestamp_index: u32 = 50;
205
206 let mut seek: usize = 0;
207 let mut all_tokens: Vec<u32> = Vec::new();
208
209 while seek < content_frames {
210 let segment_size = n_frames.min(content_frames - seek);
211
212 let mel = self.mel_segment_at(&padded_pcm, seek, n_frames)?;
214
215 let encoder_out = self.model.encode(&mel)?;
217
218 let (tokens, avg_logprob, no_speech_prob, _temperature) = self.decode_with_fallback(
220 &encoder_out,
221 &sot_sequence,
222 sample_begin,
223 blank_token,
224 max_initial_timestamp_index,
225 temperatures,
226 )?;
227
228 let should_skip = no_speech_prob > 0.6 && avg_logprob < -1.0;
231 if should_skip {
232 seek += segment_size;
233 continue;
234 }
235
236 let sampled = &tokens[sample_begin..];
238 let timestamp_mask: Vec<bool> = sampled.iter().map(|&t| t >= TIMESTAMP_BEGIN).collect();
239
240 let mut consecutive_indices = Vec::new();
242 for i in 0..timestamp_mask.len().saturating_sub(1) {
243 if timestamp_mask[i] && timestamp_mask[i + 1] {
244 consecutive_indices.push(i + 1);
245 }
246 }
247
248 let text_tokens: Vec<u32> = sampled
250 .iter()
251 .copied()
252 .filter(|&t| t < self.eot_token)
253 .collect();
254 all_tokens.extend_from_slice(&text_tokens);
255
256 if !consecutive_indices.is_empty() {
257 let single_timestamp_ending = timestamp_mask.len() >= 2
259 && !timestamp_mask[timestamp_mask.len() - 2]
260 && timestamp_mask[timestamp_mask.len() - 1];
261
262 if single_timestamp_ending {
263 seek += segment_size;
264 } else {
265 let last_idx = *consecutive_indices.last().unwrap();
266 let last_ts_pos = (sampled[last_idx] - TIMESTAMP_BEGIN) as usize;
267 seek += last_ts_pos * INPUT_STRIDE;
268 }
269 } else {
270 let timestamps: Vec<u32> = sampled
272 .iter()
273 .copied()
274 .filter(|&t| t >= TIMESTAMP_BEGIN)
275 .collect();
276 if !timestamps.is_empty() && *timestamps.last().unwrap() != TIMESTAMP_BEGIN {
277 let last_ts_pos = (*timestamps.last().unwrap() - TIMESTAMP_BEGIN) as usize;
278 seek += last_ts_pos * INPUT_STRIDE;
279 } else {
280 seek += segment_size;
281 }
282 }
283 }
284
285 let text = self
287 .tokenizer
288 .decode(&all_tokens, true)
289 .map_err(|e| FerrumError::model(format!("decode tokens: {e}")))?;
290
291 Ok(text.trim().to_string())
292 }
293
294 fn mel_segment_at(&self, pcm: &[f32], seek_frames: usize, n_frames: usize) -> Result<Tensor> {
296 let hop = candle_transformers::models::whisper::HOP_LENGTH;
297 let start_sample = seek_frames * hop;
298 let n_samples = candle_transformers::models::whisper::N_SAMPLES;
299 let end_sample = (start_sample + n_samples).min(pcm.len());
300 let segment = &pcm[start_sample..end_sample];
301 self.model.pcm_to_mel_tensor(segment)
302 }
303
304 fn decode_with_fallback(
307 &self,
308 encoder_out: &Tensor,
309 sot_sequence: &[u32],
310 sample_begin: usize,
311 blank_token: u32,
312 max_initial_timestamp_index: u32,
313 temperatures: &[f32],
314 ) -> Result<(Vec<u32>, f32, f32, f32)> {
315 let mut last_result = None;
316
317 for &temp in temperatures {
318 let (tokens, avg_logprob, no_speech_prob) = self.decode_segment(
319 encoder_out,
320 sot_sequence,
321 sample_begin,
322 blank_token,
323 max_initial_timestamp_index,
324 temp,
325 )?;
326
327 let text_tokens: Vec<u32> = tokens[sample_begin..]
328 .iter()
329 .copied()
330 .filter(|&t| t < self.eot_token)
331 .collect();
332 let text = self
333 .tokenizer
334 .decode(&text_tokens, true)
335 .unwrap_or_default();
336
337 let cr = compression_ratio(&text);
338
339 let mut needs_fallback = false;
342 if cr > 2.4 {
343 needs_fallback = true;
344 }
345 if avg_logprob < -1.0 {
346 needs_fallback = true;
347 }
348 if no_speech_prob > 0.6 {
349 needs_fallback = false; }
351
352 last_result = Some((tokens, avg_logprob, no_speech_prob, temp));
353
354 if !needs_fallback {
355 break;
356 }
357 }
358
359 last_result.ok_or_else(|| FerrumError::model("decode_with_fallback: no result"))
360 }
361
362 fn decode_segment(
365 &self,
366 encoder_out: &Tensor,
367 sot_sequence: &[u32],
368 sample_begin: usize,
369 blank_token: u32,
370 max_initial_timestamp_index: u32,
371 temperature: f32,
372 ) -> Result<(Vec<u32>, f32, f32)> {
373 self.model.reset_decoder();
374
375 let mut tokens: Vec<u32> = sot_sequence.to_vec();
376 let mut sum_logprobs: f32 = 0.0;
377 let mut no_speech_prob: f32 = 0.0;
378 let mut n_text_tokens: usize = 0;
379
380 let mut logits = self.model.decode_step(&tokens, encoder_out)?;
382
383 for step in 0..self.sample_len {
384 if step == 0 {
386 let sot_logits = &logits; let probs = softmax_vec(sot_logits);
388 no_speech_prob = probs[self.no_speech_token as usize];
389 }
390
391 let sampled_tokens = &tokens[sample_begin..];
394
395 if sampled_tokens.is_empty() {
397 logits[blank_token as usize] = f32::NEG_INFINITY;
398 logits[self.eot_token as usize] = f32::NEG_INFINITY;
399 }
400
401 for &t in &self.suppress_token_ids {
403 if (t as usize) < logits.len() {
404 logits[t as usize] = f32::NEG_INFINITY;
405 }
406 }
407
408 self.apply_timestamp_rules(
410 &mut logits,
411 sampled_tokens,
412 sample_begin,
413 max_initial_timestamp_index,
414 step,
415 );
416
417 let next_token = if temperature == 0.0 {
419 argmax(&logits)
420 } else {
421 sample_with_temperature(&logits, temperature)
422 };
423
424 let log_probs = log_softmax_vec(&logits);
426 if next_token != self.eot_token {
427 sum_logprobs += log_probs[next_token as usize];
428 n_text_tokens += 1;
429 }
430
431 tokens.push(next_token);
432
433 if next_token == self.eot_token
434 || tokens.len() > self.model.config().max_target_positions
435 {
436 break;
437 }
438
439 let text_tail: Vec<u32> = tokens[sample_begin..]
442 .iter()
443 .copied()
444 .filter(|&t| t < TIMESTAMP_BEGIN && t != self.eot_token)
445 .collect();
446 if text_tail.len() >= 6 {
447 let last = *text_tail.last().unwrap();
448 let consecutive = text_tail.iter().rev().take_while(|&&t| t == last).count();
449 if consecutive >= 5 {
450 let mut keep = tokens.len();
452 let mut removed = 0;
453 while keep > sample_begin && removed < consecutive - 1 {
454 keep -= 1;
455 if tokens[keep] == last {
456 removed += 1;
457 }
458 }
459 tokens.truncate(keep + 1);
460 break;
461 }
462 }
463
464 logits = self.model.decode_step(&[next_token], encoder_out)?;
466 }
467
468 let avg_logprob = if n_text_tokens > 0 {
469 sum_logprobs / n_text_tokens as f32
470 } else {
471 f32::NEG_INFINITY
472 };
473
474 Ok((tokens, avg_logprob, no_speech_prob))
475 }
476
477 fn apply_timestamp_rules(
479 &self,
480 logits: &mut [f32],
481 sampled_tokens: &[u32],
482 _sample_begin: usize,
483 max_initial_timestamp_index: u32,
484 _step: usize,
485 ) {
486 let ts_begin = TIMESTAMP_BEGIN as usize;
487
488 logits[self.no_timestamps_token as usize] = f32::NEG_INFINITY;
490
491 let last_was_timestamp =
493 !sampled_tokens.is_empty() && *sampled_tokens.last().unwrap() >= TIMESTAMP_BEGIN;
494
495 let penultimate_was_timestamp =
496 sampled_tokens.len() < 2 || sampled_tokens[sampled_tokens.len() - 2] >= TIMESTAMP_BEGIN;
497
498 if last_was_timestamp {
499 if penultimate_was_timestamp {
500 for i in ts_begin..logits.len() {
502 logits[i] = f32::NEG_INFINITY;
503 }
504 } else {
505 for i in 0..self.eot_token as usize {
507 logits[i] = f32::NEG_INFINITY;
508 }
509 }
510 }
511
512 let timestamps: Vec<u32> = sampled_tokens
514 .iter()
515 .copied()
516 .filter(|&t| t >= TIMESTAMP_BEGIN)
517 .collect();
518 if !timestamps.is_empty() {
519 let timestamp_last = if last_was_timestamp && !penultimate_was_timestamp {
520 *timestamps.last().unwrap()
521 } else {
522 *timestamps.last().unwrap() + 1
523 };
524 for i in ts_begin..timestamp_last as usize {
525 if i < logits.len() {
526 logits[i] = f32::NEG_INFINITY;
527 }
528 }
529 }
530
531 if sampled_tokens.is_empty() {
533 for i in 0..ts_begin {
534 logits[i] = f32::NEG_INFINITY;
535 }
536 let last_allowed = TIMESTAMP_BEGIN + max_initial_timestamp_index;
537 for i in (last_allowed as usize + 1)..logits.len() {
538 logits[i] = f32::NEG_INFINITY;
539 }
540 }
541
542 let log_probs = log_softmax_vec(logits);
544 let ts_logsumexp = {
545 let max_ts = log_probs[ts_begin..]
546 .iter()
547 .copied()
548 .fold(f32::NEG_INFINITY, f32::max);
549 if max_ts.is_finite() {
550 max_ts
551 + log_probs[ts_begin..]
552 .iter()
553 .map(|&lp| (lp - max_ts).exp())
554 .sum::<f32>()
555 .ln()
556 } else {
557 f32::NEG_INFINITY
558 }
559 };
560 let max_text_logprob = log_probs[..ts_begin]
561 .iter()
562 .copied()
563 .fold(f32::NEG_INFINITY, f32::max);
564
565 if ts_logsumexp > max_text_logprob {
566 for i in 0..ts_begin {
567 logits[i] = f32::NEG_INFINITY;
568 }
569 }
570 }
571}
572
573fn token_id(tokenizer: &tokenizers::Tokenizer, token: &str) -> u32 {
576 tokenizer.token_to_id(token).unwrap_or(0)
577}
578
/// Returns the index of the largest value in `v`, or 0 for an empty slice.
///
/// When several entries compare equal (or are incomparable, e.g. NaN) the later one wins.
fn argmax(v: &[f32]) -> u32 {
    let mut best: Option<(usize, f32)> = None;
    for (idx, &value) in v.iter().enumerate() {
        // The current best survives only when it is strictly greater than the candidate.
        let keep_current = matches!(best, Some((_, top)) if top > value);
        if !keep_current {
            best = Some((idx, value));
        }
    }
    best.map_or(0, |(idx, _)| idx as u32)
}
586
/// Computes the softmax of `logits`, subtracting the maximum first for numerical stability.
///
/// Entries equal to negative infinity receive probability 0; an empty input yields an
/// empty vector.
fn softmax_vec(logits: &[f32]) -> Vec<f32> {
    let peak = logits
        .iter()
        .fold(f32::NEG_INFINITY, |acc, &x| f32::max(acc, x));
    let mut out: Vec<f32> = logits.iter().map(|&x| (x - peak).exp()).collect();
    let total: f32 = out.iter().sum();
    for value in out.iter_mut() {
        *value = *value / total;
    }
    out
}
593
/// Computes the log-softmax of `logits` using the max-shifted log-sum-exp.
///
/// Entries equal to negative infinity stay at negative infinity.
fn log_softmax_vec(logits: &[f32]) -> Vec<f32> {
    let peak = logits
        .iter()
        .fold(f32::NEG_INFINITY, |acc, &x| f32::max(acc, x));
    let shifted_sum: f32 = logits.iter().map(|&x| (x - peak).exp()).sum();
    let log_norm = peak + shifted_sum.ln();

    let mut out = Vec::with_capacity(logits.len());
    for &x in logits {
        out.push(x - log_norm);
    }
    out
}
600
601fn sample_with_temperature(logits: &[f32], temperature: f32) -> u32 {
602 let scaled: Vec<f32> = logits.iter().map(|&x| x / temperature).collect();
603 let probs = softmax_vec(&scaled);
604 let r: f32 = rand_f32();
606 let mut cumulative = 0.0;
607 for (i, &p) in probs.iter().enumerate() {
608 cumulative += p;
609 if cumulative >= r {
610 return i as u32;
611 }
612 }
613 (probs.len() - 1) as u32
614}
615
/// Returns a pseudo-random value in `[0, 1]` from a process-wide xorshift generator.
///
/// The generator state lives in a static atomic with a fixed seed, so the sequence is the
/// same on every run. The state is loaded and stored in two separate relaxed operations.
fn rand_f32() -> f32 {
    use std::sync::atomic::{AtomicU64, Ordering};

    static STATE: AtomicU64 = AtomicU64::new(0x12345678_9abcdef0);

    let previous = STATE.load(Ordering::Relaxed);
    let step_one = previous ^ (previous << 13);
    let step_two = step_one ^ (step_one >> 7);
    let next = step_two ^ (step_two << 17);
    STATE.store(next, Ordering::Relaxed);

    next as f32 / u64::MAX as f32
}
627
628fn compression_ratio(text: &str) -> f32 {
630 if text.is_empty() {
631 return 0.0;
632 }
633 use flate2::write::DeflateEncoder;
634 use flate2::Compression;
635 use std::io::Write;
636 let text_bytes = text.as_bytes();
637 let mut encoder = DeflateEncoder::new(Vec::new(), Compression::default());
638 encoder.write_all(text_bytes).unwrap();
639 let compressed = encoder.finish().unwrap();
640 text_bytes.len() as f32 / compressed.len().max(1) as f32
641}
642
/// Placeholder KV-cache handle: a stateless unit type whose `KvCacheHandle` implementation
/// reports an empty cache (no layers, no heads, no stored tensors).
#[derive(Clone, Debug)]
#[allow(dead_code)]
struct DummyWhisperCache;
648
649impl ferrum_interfaces::KvCacheHandle for DummyWhisperCache {
650 fn block_table(&self) -> &ferrum_interfaces::BlockTable {
651 static EMPTY: std::sync::OnceLock<ferrum_interfaces::BlockTable> =
652 std::sync::OnceLock::new();
653 EMPTY.get_or_init(|| ferrum_interfaces::BlockTable::new(16))
654 }
655 fn block_table_mut(&mut self) -> &mut ferrum_interfaces::BlockTable {
656 unimplemented!()
657 }
658 fn as_any(&self) -> &dyn std::any::Any {
659 self
660 }
661 fn device(&self) -> Device {
662 Device::CPU
663 }
664 fn num_layers(&self) -> usize {
665 0
666 }
667 fn num_heads(&self) -> usize {
668 0
669 }
670 fn head_dim(&self) -> usize {
671 0
672 }
673 fn key_cache(&self, _: usize) -> Result<Option<TensorRef>> {
674 Ok(None)
675 }
676 fn value_cache(&self, _: usize) -> Result<Option<TensorRef>> {
677 Ok(None)
678 }
679 fn clone_handle(&self) -> Result<Arc<dyn ferrum_interfaces::KvCacheHandle>> {
680 Ok(Arc::new(self.clone()))
681 }
682 fn stats(&self) -> ferrum_interfaces::CacheHandleStats {
683 ferrum_interfaces::CacheHandleStats {
684 memory_bytes: 0,
685 blocks_allocated: 0,
686 tokens_stored: 0,
687 utilization: 0.0,
688 last_access: std::time::Instant::now(),
689 }
690 }
691 fn is_valid(&self) -> bool {
692 true
693 }
694 fn cache_id(&self) -> String {
695 "whisper_dummy".to_string()
696 }
697}
698
699#[async_trait]
700impl ModelExecutor for WhisperModelExecutor {
701 fn info(&self) -> &ModelInfo {
702 &self.info
703 }
704
705 async fn prefill(&self, _input: &PrefillInput) -> Result<PrefillOutput> {
706 Err(FerrumError::model(
707 "Whisper uses transcribe(), not prefill/decode",
708 ))
709 }
710
711 async fn decode(&self, _input: &DecodeInput) -> Result<DecodeOutput> {
712 Err(FerrumError::model(
713 "Whisper uses transcribe(), not prefill/decode",
714 ))
715 }
716
717 fn capabilities(&self) -> ExecutorCapabilities {
718 ExecutorCapabilities {
719 max_batch_size: 1,
720 max_sequence_length: self.info.max_sequence_length,
721 attention_mechanisms: vec![AttentionType::MultiHead],
722 supports_dynamic_batching: false,
723 supports_continuous_batching: false,
724 supports_speculative_decoding: false,
725 supports_tensor_parallelism: false,
726 supports_pipeline_parallelism: false,
727 supported_dtypes: vec![DataType::FP32],
728 supported_devices: vec![self.info.device.clone()],
729 memory_requirements: MemoryRequirements {
730 parameter_memory: 0,
731 activation_memory_per_token: 0,
732 kv_cache_memory_per_token: 0,
733 overhead_memory: 0,
734 },
735 }
736 }
737
738 fn release_cache(&self, _: &str) {}
739
740 fn status(&self) -> ferrum_interfaces::model_executor::ExecutorStatus {
741 common::default_executor_status()
742 }
743}