use criterion::{black_box, criterion_group, criterion_main, Criterion};
use whisper_cpp_plus::enhanced::fallback::calculate_compression_ratio;
use std::time::Duration;
/// Benchmarks `calculate_compression_ratio` over inputs of differing length and
/// repetitiveness, registering one benchmark per input in the
/// `compression_ratio` group.
fn benchmark_compression_ratio(c: &mut Criterion) {
    // Owned inputs are built up front so the benchmark table can borrow them.
    let medium = "The quick brown fox jumps over the lazy dog. ".repeat(10);
    let repetitive = "a".repeat(1000);
    let mixed = "aaaaa bbbbb ccccc ddddd eeeee fffff ".repeat(50);

    // (benchmark id, input text), in registration order.
    let cases: [(&str, &str); 5] = [
        ("short_text", "Hello world"),
        ("medium_text", medium.as_str()),
        ("long_text", include_str!("../src/lib.rs")),
        ("repetitive_text", repetitive.as_str()),
        ("mixed_text", mixed.as_str()),
    ];

    let mut group = c.benchmark_group("compression_ratio");
    for (id, text) in cases {
        group.bench_function(id, |b| {
            b.iter(|| calculate_compression_ratio(black_box(text)))
        });
    }
    group.finish();
}
/// Benchmarks `TranscriptionAttempt::meets_thresholds` against one fixed set of
/// thresholds for three attempts whose metrics sit well inside, well outside,
/// and close to those thresholds.
fn benchmark_quality_checks(c: &mut Criterion) {
    use whisper_cpp_plus::enhanced::fallback::{QualityThresholds, TranscriptionAttempt};

    let thresholds = QualityThresholds {
        compression_ratio_threshold: Some(2.4),
        log_prob_threshold: Some(-1.0),
        no_speech_threshold: Some(0.6),
    };

    // (benchmark id, attempt under test), in registration order.
    let cases = [
        (
            "good_quality_check",
            TranscriptionAttempt {
                text: String::from("This is a clear transcription with good quality metrics."),
                segments: Vec::new(),
                temperature: 0.0,
                compression_ratio: 1.5,
                avg_logprob: -0.5,
                no_speech_prob: 0.1,
            },
        ),
        (
            "poor_quality_check",
            TranscriptionAttempt {
                text: "a".repeat(500),
                segments: Vec::new(),
                temperature: 0.0,
                compression_ratio: 15.0,
                avg_logprob: -2.0,
                no_speech_prob: 0.1,
            },
        ),
        (
            "borderline_quality_check",
            TranscriptionAttempt {
                text: String::from(
                    "This text has borderline quality metrics that might trigger fallback.",
                ),
                segments: Vec::new(),
                temperature: 0.0,
                compression_ratio: 2.3,
                avg_logprob: -0.95,
                no_speech_prob: 0.55,
            },
        ),
    ];

    let mut group = c.benchmark_group("quality_checks");
    for (id, attempt) in &cases {
        group.bench_function(*id, |b| {
            b.iter(|| attempt.meets_thresholds(black_box(&thresholds)))
        });
    }
    group.finish();
}
/// Locates the `ggml-base.en.bin` model file.
///
/// The directory named by `WHISPER_TEST_MODEL_DIR` is checked first; if the
/// variable is unset or the file is not there, a fixed list of relative paths
/// is probed in order. Returns the first path that exists, or `None`.
fn find_whisper_model() -> Option<String> {
    use std::path::Path;

    // Relative locations probed when the environment variable does not yield a file.
    const FALLBACK_LOCATIONS: [&str; 3] = [
        "tests/models/ggml-base.en.bin",
        "../whisper-cpp-plus-sys/whisper.cpp/models/ggml-base.en.bin",
        "whisper-cpp-plus-sys/whisper.cpp/models/ggml-base.en.bin",
    ];

    if let Ok(model_dir) = std::env::var("WHISPER_TEST_MODEL_DIR") {
        let candidate = format!("{}/ggml-base.en.bin", model_dir);
        if Path::new(&candidate).exists() {
            return Some(candidate);
        }
    }

    for location in FALLBACK_LOCATIONS.iter() {
        if Path::new(location).exists() {
            return Some(location.to_string());
        }
    }
    None
}
fn benchmark_transcription_with_fallback_simulation(c: &mut Criterion) {
let model_path = find_whisper_model();
if model_path.is_none() {
eprintln!("Model not found. Set WHISPER_TEST_MODEL_DIR or run `cargo xtask test-setup`.");
return;
}
let model_path = model_path.unwrap();
use whisper_cpp_plus::{WhisperContext, TranscriptionParams};
use std::sync::Arc;
let ctx = Arc::new(WhisperContext::new(&model_path).unwrap());
let mut group = c.benchmark_group("transcription_fallback");
group.measurement_time(Duration::from_secs(15));
let audio = load_benchmark_audio().unwrap_or_else(|e| {
eprintln!("Failed to load audio: {}. Skipping transcription benchmarks.", e);
vec![0.0f32; 16000] });
let noisy_audio = add_noise_to_audio(&audio);
let params = TranscriptionParams::builder()
.language("en")
.build();
group.bench_function("standard_clear", |b| {
let ctx = Arc::clone(&ctx);
let audio = audio.clone();
b.iter(|| {
ctx.transcribe_with_params(black_box(&audio), params.clone()).unwrap()
})
});
group.bench_function("standard_noisy", |b| {
let ctx = Arc::clone(&ctx);
let audio = noisy_audio.clone();
b.iter(|| {
ctx.transcribe_with_params(black_box(&audio), params.clone()).unwrap()
})
});
group.bench_function("enhanced_clear", |b| {
let ctx = Arc::clone(&ctx);
let audio = audio.clone();
b.iter(|| {
ctx.transcribe_with_params_enhanced(black_box(&audio), params.clone()).unwrap()
})
});
group.bench_function("enhanced_noisy", |b| {
let ctx = Arc::clone(&ctx);
let audio = noisy_audio.clone();
b.iter(|| {
ctx.transcribe_with_params_enhanced(black_box(&audio), params.clone()).unwrap()
})
});
group.finish();
}
/// Loads the audio sample used by the transcription benchmarks.
///
/// `jfk.wav` inside the directory named by `WHISPER_TEST_AUDIO_DIR` is tried
/// first; otherwise a fixed list of relative paths is probed in order and the
/// first existing file is loaded via `load_wav_file`.
///
/// # Errors
///
/// Returns an error if none of the candidate files exists, or whatever error
/// `load_wav_file` reports for the chosen file.
fn load_benchmark_audio() -> Result<Vec<f32>, Box<dyn std::error::Error>> {
    use std::path::Path;

    let from_env = std::env::var("WHISPER_TEST_AUDIO_DIR")
        .ok()
        .map(|dir| format!("{}/jfk.wav", dir))
        .filter(|candidate| Path::new(candidate).exists());
    if let Some(candidate) = from_env {
        return load_wav_file(&candidate);
    }

    let fallback_locations = [
        "../whisper-cpp-plus-sys/whisper.cpp/samples/jfk.wav",
        "whisper-cpp-plus-sys/whisper.cpp/samples/jfk.wav",
        "samples/benchmark_audio.wav",
    ];
    let first_existing = fallback_locations
        .iter()
        .copied()
        .find(|location| Path::new(location).exists());

    match first_existing {
        Some(location) => load_wav_file(location),
        None => Err("No audio files found. Set WHISPER_TEST_AUDIO_DIR.".into()),
    }
}
/// Reads a WAV file as `i16` samples and returns up to three seconds of audio
/// (at the expected 16 kHz rate) scaled by `1 / 32768` into `f32`.
///
/// Only every `channels`-th sample is kept, i.e. the first channel of
/// interleaved multi-channel audio. A sample rate other than 16000 Hz is
/// reported on stderr but the data is returned unchanged (no resampling).
///
/// # Errors
///
/// Returns an error if the file cannot be opened or if a sample within the
/// returned range cannot be decoded as `i16`.
fn load_wav_file(path: &str) -> Result<Vec<f32>, Box<dyn std::error::Error>> {
    use hound;

    // Cap on the number of returned samples: 3 s at 16 kHz.
    const MAX_SAMPLES: usize = 16000 * 3;

    let mut reader = hound::WavReader::open(path)?;
    let spec = reader.spec();
    if spec.sample_rate != 16000 {
        eprintln!("Warning: Audio sample rate is {}Hz, expected 16000Hz", spec.sample_rate);
    }

    // `step_by` panics on a zero step, so never pass it less than 1.
    let stride = usize::from(spec.channels).max(1);

    // Stop decoding once the cap is reached instead of reading the whole file,
    // and surface decode failures to the caller rather than panicking.
    let samples = reader
        .samples::<i16>()
        .step_by(stride)
        .take(MAX_SAMPLES)
        .map(|sample| sample.map(|value| f32::from(value) / 32768.0))
        .collect::<Result<Vec<f32>, _>>()?;
    Ok(samples)
}
/// Returns a copy of `audio` with pseudo-random noise in `[-0.05, 0.05]` added
/// to every sample, limited to `[-1.0, 1.0]`.
///
/// The noise is derived from a hasher created with `DefaultHasher::new()`,
/// which is not randomly seeded, so repeated calls on the same input produce
/// the same output and the "noisy" benchmark input is stable across runs.
fn add_noise_to_audio(audio: &[f32]) -> Vec<f32> {
    use std::collections::hash_map::DefaultHasher;
    use std::hash::{Hash, Hasher};

    let mut rng = DefaultHasher::new();
    audio
        .iter()
        .enumerate()
        .map(|(i, &sample)| {
            // Feeding each index into the running hasher yields a new
            // pseudo-random 64-bit value per sample.
            i.hash(&mut rng);
            // Map the hash to [0, 1], then centre and scale to [-0.05, 0.05].
            let unit = rng.finish() as f32 / u64::MAX as f32;
            let noise_val = (unit - 0.5) * 0.1;
            (sample + noise_val).max(-1.0).min(1.0)
        })
        .collect()
}
// Collects the three benchmark functions into the `benches` group and hands
// that group to `criterion_main!`, which provides the benchmark binary's
// entry point.
criterion_group!(
benches,
benchmark_compression_ratio,
benchmark_quality_checks,
benchmark_transcription_with_fallback_simulation
);
criterion_main!(benches);