use crate::{WhisperState, FullParams, Result, WhisperError, Segment, TranscriptionResult};
use std::io::Write;
use flate2::Compression;
use flate2::write::ZlibEncoder;
use whisper_cpp_plus_sys as ffi;
/// Acceptance limits applied to a finished transcription attempt.
///
/// Every limit is optional; a `None` entry switches the corresponding check
/// off (see `TranscriptionAttempt::meets_thresholds` for the exact rules).
#[derive(Debug, Clone)]
pub struct QualityThresholds {
    /// Upper bound on the text compression ratio; an attempt whose ratio is
    /// strictly greater than this value is rejected.
    pub compression_ratio_threshold: Option<f32>,
    /// Lower bound on the average token log-probability; an attempt whose
    /// average is strictly below this value is a rejection candidate.
    pub log_prob_threshold: Option<f32>,
    /// Consulted only for attempts that fall below `log_prob_threshold`: such
    /// an attempt is rejected when its no-speech probability is less than or
    /// equal to this value. With `None`, a low log-probability always rejects.
    pub no_speech_threshold: Option<f32>,
}

impl Default for QualityThresholds {
    /// Builds the stock limits: compression ratio 2.4, log-probability -1.0,
    /// no-speech probability 0.6.
    fn default() -> Self {
        let compression_ratio_threshold = Some(2.4);
        let log_prob_threshold = Some(-1.0);
        let no_speech_threshold = Some(0.6);
        QualityThresholds {
            compression_ratio_threshold,
            log_prob_threshold,
            no_speech_threshold,
        }
    }
}
/// Configuration for a transcription run that retries at several sampling
/// temperatures until an attempt passes the quality thresholds.
#[derive(Clone)]
pub struct EnhancedTranscriptionParams {
/// Decoding parameters that are cloned as the starting point of every attempt.
pub base: FullParams,
/// Temperatures to try, in the order given.
pub temperatures: Vec<f32>,
/// Limits an attempt has to satisfy to be accepted immediately.
pub thresholds: QualityThresholds,
/// Attempts whose temperature is strictly greater than this value are run
/// with an empty initial prompt.
pub prompt_reset_on_temperature: f32,
}
impl EnhancedTranscriptionParams {
pub fn from_base(base: FullParams) -> Self {
Self {
base,
temperatures: vec![0.0, 0.2, 0.4, 0.6, 0.8, 1.0],
thresholds: QualityThresholds::default(),
prompt_reset_on_temperature: 0.5,
}
}
pub fn builder() -> EnhancedTranscriptionParamsBuilder {
EnhancedTranscriptionParamsBuilder::new()
}
}
/// Chainable builder for [`EnhancedTranscriptionParams`].
pub struct EnhancedTranscriptionParamsBuilder {
/// The parameter set being assembled; returned as-is by `build`.
params: EnhancedTranscriptionParams,
}
impl EnhancedTranscriptionParamsBuilder {
    /// Creates a builder seeded with `FullParams::default()` and the stock
    /// fallback settings from [`EnhancedTranscriptionParams::from_base`].
    pub fn new() -> Self {
        Self {
            params: EnhancedTranscriptionParams::from_base(FullParams::default()),
        }
    }

    /// Replaces the base decoding parameters wholesale.
    ///
    /// Anything previously applied to the base (for example via
    /// [`language`](Self::language)) is discarded.
    pub fn base_params(mut self, params: FullParams) -> Self {
        self.params.base = params;
        self
    }

    /// Sets the language on the base decoding parameters.
    pub fn language(mut self, lang: &str) -> Self {
        self.params.base = self.params.base.language(lang);
        self
    }

    /// Sets the temperatures to try, in order.
    pub fn temperatures(mut self, temps: Vec<f32>) -> Self {
        self.params.temperatures = temps;
        self
    }

    /// Sets the compression-ratio limit; `None` disables that check.
    pub fn compression_ratio_threshold(mut self, threshold: Option<f32>) -> Self {
        self.params.thresholds.compression_ratio_threshold = threshold;
        self
    }

    /// Sets the average log-probability limit; `None` disables that check.
    pub fn log_prob_threshold(mut self, threshold: Option<f32>) -> Self {
        self.params.thresholds.log_prob_threshold = threshold;
        self
    }

    /// Sets the no-speech probability limit; `None` removes it.
    ///
    /// Provided so that all three fields of [`QualityThresholds`] can be
    /// configured through the builder.
    pub fn no_speech_threshold(mut self, threshold: Option<f32>) -> Self {
        self.params.thresholds.no_speech_threshold = threshold;
        self
    }

    /// Sets the temperature above which attempts run with an empty initial
    /// prompt.
    pub fn prompt_reset_on_temperature(mut self, temperature: f32) -> Self {
        self.params.prompt_reset_on_temperature = temperature;
        self
    }

    /// Finishes the builder and returns the assembled parameters.
    pub fn build(self) -> EnhancedTranscriptionParams {
        self.params
    }
}
impl Default for EnhancedTranscriptionParamsBuilder {
fn default() -> Self {
Self::new()
}
}
/// Returns the ratio of `text`'s UTF-8 byte length to the length of its
/// zlib-compressed form (default compression level).
///
/// Larger values mean the text compresses better, which is what highly
/// repetitive text does.
///
/// # Panics
///
/// Panics if the encoder reports an error while writing or finishing; the
/// sink is an in-memory `Vec<u8>`.
pub fn calculate_compression_ratio(text: &str) -> f32 {
    let raw_len = text.len();

    let mut zlib = ZlibEncoder::new(Vec::new(), Compression::default());
    zlib.write_all(text.as_bytes()).unwrap();
    let packed_len = zlib.finish().unwrap().len();

    raw_len as f32 / packed_len as f32
}
/// Outcome of one decoding pass at a single temperature, together with the
/// quality metrics used to accept or rank it.
#[derive(Debug)]
pub struct TranscriptionAttempt {
/// Full transcript text for the attempt.
pub text: String,
/// Per-segment results for the attempt.
pub segments: Vec<Segment>,
/// Sampling temperature the attempt was decoded with.
pub temperature: f32,
/// Compression ratio of `text`; compared against
/// `QualityThresholds::compression_ratio_threshold`.
pub compression_ratio: f32,
/// Average token log-probability; compared against
/// `QualityThresholds::log_prob_threshold`.
pub avg_logprob: f32,
/// No-speech probability; compared against
/// `QualityThresholds::no_speech_threshold`.
pub no_speech_prob: f32,
}
impl TranscriptionAttempt {
    /// Reports whether this attempt passes every enabled check in `thresholds`.
    ///
    /// The attempt fails when either of the following holds:
    ///
    /// * a compression-ratio limit is set and `compression_ratio` is strictly
    ///   greater than it;
    /// * a log-probability limit is set, `avg_logprob` is strictly below it,
    ///   and either no no-speech limit is set or `no_speech_prob` is less than
    ///   or equal to that limit.
    pub fn meets_thresholds(&self, thresholds: &QualityThresholds) -> bool {
        let too_repetitive = matches!(
            thresholds.compression_ratio_threshold,
            Some(limit) if self.compression_ratio > limit
        );

        let low_confidence = match thresholds.log_prob_threshold {
            Some(floor) if self.avg_logprob < floor => match thresholds.no_speech_threshold {
                // A no-speech probability above the limit excuses the low score.
                Some(silence_limit) => self.no_speech_prob <= silence_limit,
                None => true,
            },
            _ => false,
        };

        !(too_repetitive || low_confidence)
    }
}
/// Wrapper around a mutably borrowed [`WhisperState`] that adds transcription
/// with temperature fallback.
pub struct EnhancedWhisperState<'a> {
/// The underlying state, borrowed for the lifetime of the wrapper.
state: &'a mut WhisperState,
}
impl<'a> EnhancedWhisperState<'a> {
pub fn new(state: &'a mut WhisperState) -> Self {
Self { state }
}
fn get_no_speech_prob(&self, segment_idx: i32) -> f32 {
unsafe {
ffi::whisper_full_get_segment_no_speech_prob_from_state(
self.state.ptr,
segment_idx
)
}
}
fn calculate_avg_logprob(&self, segment_idx: i32) -> f32 {
let n_tokens = self.state.full_n_tokens(segment_idx);
if n_tokens == 0 {
return 0.0;
}
let mut sum_logprob = 0.0;
for i in 0..n_tokens {
let prob = self.state.full_get_token_prob(segment_idx, i);
if prob > 0.0 {
sum_logprob += prob.ln();
}
}
sum_logprob / n_tokens as f32
}
pub fn transcribe_with_fallback(
&mut self,
params: EnhancedTranscriptionParams,
audio: &[f32],
) -> Result<TranscriptionResult> {
let mut all_attempts = Vec::new();
let mut below_cr_attempts = Vec::new();
for temperature in ¶ms.temperatures {
let mut current_params = params.base.clone();
current_params = current_params.temperature(*temperature);
if *temperature > params.prompt_reset_on_temperature {
current_params = current_params.initial_prompt("");
}
self.state.full(current_params, audio)?;
let n_segments = self.state.full_n_segments();
let mut segments = Vec::new();
let mut text = String::new();
let mut total_logprob = 0.0;
let mut total_tokens = 0;
for i in 0..n_segments {
let segment_text = self.state.full_get_segment_text(i)?;
let (start_ms, end_ms) = self.state.full_get_segment_timestamps(i);
let speaker_turn_next = self.state.full_get_segment_speaker_turn_next(i);
if i > 0 {
text.push(' ');
}
text.push_str(&segment_text);
segments.push(Segment {
start_ms,
end_ms,
text: segment_text,
speaker_turn_next,
});
let avg_lp = self.calculate_avg_logprob(i);
let n_tokens = self.state.full_n_tokens(i);
total_logprob += avg_lp * n_tokens as f32;
total_tokens += n_tokens;
}
let avg_logprob = if total_tokens > 0 {
total_logprob / total_tokens as f32
} else {
0.0
};
let compression_ratio = calculate_compression_ratio(&text);
let no_speech_prob = if n_segments > 0 {
self.get_no_speech_prob(0)
} else {
0.0
};
let attempt = TranscriptionAttempt {
text: text.clone(),
segments: segments.clone(),
temperature: *temperature,
compression_ratio,
avg_logprob,
no_speech_prob,
};
if attempt.meets_thresholds(¶ms.thresholds) {
return Ok(TranscriptionResult {
text: attempt.text,
segments: attempt.segments,
});
}
if let Some(cr_threshold) = params.thresholds.compression_ratio_threshold {
if attempt.compression_ratio <= cr_threshold {
below_cr_attempts.push(attempt);
} else {
all_attempts.push(attempt);
}
} else {
all_attempts.push(attempt);
}
}
let best_attempt = if !below_cr_attempts.is_empty() {
below_cr_attempts.into_iter()
.max_by(|a, b| a.avg_logprob.partial_cmp(&b.avg_logprob).unwrap())
} else {
all_attempts.into_iter()
.max_by(|a, b| a.avg_logprob.partial_cmp(&b.avg_logprob).unwrap())
};
best_attempt
.map(|a| TranscriptionResult {
text: a.text,
segments: a.segments,
})
.ok_or_else(|| WhisperError::TranscriptionError(
"Failed to produce acceptable transcription with any temperature".into()
))
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// The ratio is positive for ordinary text and grows with repetition.
    #[test]
    fn test_compression_ratio_calculation() {
        let pangram = "The quick brown fox jumps over the lazy dog";
        assert!(calculate_compression_ratio(pangram) > 0.0);

        let repeated_sentence = "The quick brown fox jumps over the lazy dog. ".repeat(10);
        assert!(calculate_compression_ratio(&repeated_sentence) > 1.0);

        let single_char_run = "a".repeat(1000);
        assert!(calculate_compression_ratio(&single_char_run) > 5.0);
    }

    /// An attempt inside every limit passes; one over the compression-ratio
    /// limit does not.
    #[test]
    fn test_quality_threshold_checking() {
        let limits = QualityThresholds {
            compression_ratio_threshold: Some(2.4),
            log_prob_threshold: Some(-1.0),
            no_speech_threshold: Some(0.6),
        };

        let within_limits = TranscriptionAttempt {
            text: "Hello world".to_string(),
            segments: vec![],
            temperature: 0.0,
            compression_ratio: 1.5,
            avg_logprob: -0.5,
            no_speech_prob: 0.1,
        };
        assert!(within_limits.meets_thresholds(&limits));

        let over_compression_limit = TranscriptionAttempt {
            text: "a".repeat(100),
            segments: vec![],
            temperature: 0.0,
            compression_ratio: 10.0,
            avg_logprob: -0.5,
            no_speech_prob: 0.1,
        };
        assert!(!over_compression_limit.meets_thresholds(&limits));
    }

    /// `from_base` installs the stock temperatures and thresholds.
    #[test]
    fn test_enhanced_params_from_base() {
        let english = FullParams::default().language("en");
        let wrapped = EnhancedTranscriptionParams::from_base(english);

        assert_eq!(wrapped.temperatures.len(), 6);
        assert_eq!(wrapped.temperatures[0], 0.0);
        assert_eq!(wrapped.prompt_reset_on_temperature, 0.5);
        assert!(wrapped.thresholds.compression_ratio_threshold.is_some());
    }

    /// Values set through the builder appear in the built parameters.
    #[test]
    fn test_enhanced_transcription_params_builder() {
        let builder = EnhancedTranscriptionParamsBuilder::new()
            .language("en")
            .temperatures(vec![0.0, 0.5, 1.0])
            .compression_ratio_threshold(Some(3.0));
        let built = builder.build();

        assert_eq!(built.temperatures.len(), 3);
        assert_eq!(built.thresholds.compression_ratio_threshold, Some(3.0));
    }
}