use std::path::{Path, PathBuf};
use whisper_cpp_plus::{WhisperContext, TranscriptionParams, FullParams, SamplingStrategy};
use whisper_cpp_plus::enhanced::fallback::{
EnhancedTranscriptionParams, EnhancedTranscriptionParamsBuilder,
QualityThresholds, EnhancedWhisperState
};
/// Runs the enhanced-transcription demo end to end.
///
/// Locates the `ggml-tiny.en.bin` model, builds a [`WhisperContext`] from it,
/// loads the clear and noisy audio samples, and then runs the four examples
/// in order.
///
/// # Errors
/// Returns an error when the model cannot be found, the context cannot be
/// created, the audio cannot be loaded, or any of the examples fails.
fn main() -> Result<(), Box<dyn std::error::Error>> {
    let model_path = match find_model("ggml-tiny.en.bin") {
        Some(path) => path,
        None => return Err("Model not found. Run: cargo xtask test-setup".into()),
    };
    println!("Loading model from {:?}...", model_path);
    let ctx = WhisperContext::new(&model_path)?;

    let (clear_audio, noisy_audio) = load_audio_examples()?;

    // Examples 1 and 2 run the same comparison, only on different input.
    let comparisons = [
        ("\n=== Example 1: Clear Audio ===", &clear_audio),
        ("\n=== Example 2: Noisy/Difficult Audio ===", &noisy_audio),
    ];
    for &(banner, samples) in comparisons.iter() {
        println!("{}", banner);
        compare_transcription_methods(&ctx, samples)?;
    }

    println!("\n=== Example 3: Custom Quality Thresholds ===");
    demonstrate_custom_thresholds(&ctx, &noisy_audio)?;

    println!("\n=== Example 4: Direct Enhanced State Control ===");
    demonstrate_direct_enhanced_state(&ctx, &noisy_audio)?;

    Ok(())
}
/// Transcribes `audio` twice — once with the plain `transcribe` call and once
/// with `transcribe_with_params_enhanced` — and prints text and wall-clock
/// time for both, plus a note when the two texts differ.
///
/// # Errors
/// Propagates any error returned by either transcription call.
fn compare_transcription_methods(
    ctx: &WhisperContext,
    audio: &[f32],
) -> Result<(), Box<dyn std::error::Error>> {
    use std::time::Instant;

    println!("1. Standard transcription:");
    let baseline_timer = Instant::now();
    let baseline_text = ctx.transcribe(audio)?;
    let baseline_elapsed = baseline_timer.elapsed();
    println!(" Text: {}", baseline_text);
    println!(" Time: {:?}", baseline_elapsed);

    println!("\n2. Enhanced transcription with temperature fallback:");
    let english_params = TranscriptionParams::builder().language("en").build();
    // The timer starts after the parameters are built, so only the
    // transcription call itself is measured.
    let enhanced_timer = Instant::now();
    let enhanced = ctx.transcribe_with_params_enhanced(audio, english_params)?;
    let enhanced_elapsed = enhanced_timer.elapsed();
    println!(" Text: {}", enhanced.text);
    println!(" Time: {:?}", enhanced_elapsed);

    if enhanced.text != baseline_text {
        println!(" Note: Enhanced version produced different (likely better) result!");
    }

    Ok(())
}
/// Builds enhanced parameters with caller-chosen quality thresholds via
/// [`EnhancedTranscriptionParamsBuilder`], prints the configuration, runs a
/// fallback transcription on `audio`, and prints the text and every segment
/// with its start/end time.
///
/// # Errors
/// Propagates errors from state creation and from the fallback transcription.
fn demonstrate_custom_thresholds(
    ctx: &WhisperContext,
    audio: &[f32],
) -> Result<(), Box<dyn std::error::Error>> {
    println!("Creating enhanced parameters with custom quality thresholds...");

    let greedy_english = FullParams::new(SamplingStrategy::Greedy { best_of: 1 }).language("en");
    let fallback_config = EnhancedTranscriptionParamsBuilder::new()
        .base_params(greedy_english)
        .temperatures(vec![0.0, 0.2, 0.4, 0.6, 0.8, 1.0])
        .compression_ratio_threshold(Some(2.0))
        .log_prob_threshold(Some(-0.5))
        .build();

    println!("Quality thresholds:");
    println!(
        " - Max compression ratio: {:?}",
        fallback_config.thresholds.compression_ratio_threshold
    );
    println!(
        " - Min log probability: {:?}",
        fallback_config.thresholds.log_prob_threshold
    );
    println!(" - Temperature sequence: {:?}", fallback_config.temperatures);

    let mut state = ctx.create_state()?;
    let mut fallback_state = EnhancedWhisperState::new(&mut state);
    let outcome = fallback_state.transcribe_with_fallback(fallback_config, audio)?;

    println!("\nTranscription result:");
    println!(" Text: {}", outcome.text);
    println!(" Segments: {}", outcome.segments.len());
    // Segments are numbered from 1 in the printed output.
    for (segment, ordinal) in outcome.segments.iter().zip(1usize..) {
        println!(
            " Segment {}: [{:.2}s - {:.2}s] {}",
            ordinal,
            segment.start_seconds(),
            segment.end_seconds(),
            segment.text
        );
    }

    Ok(())
}
/// Drives [`EnhancedWhisperState`] directly, transcribing `audio` twice on the
/// same state: first with relaxed quality thresholds and a short temperature
/// ladder, then with strict thresholds and a longer ladder. Prints both
/// results.
///
/// # Errors
/// Propagates errors from state creation and from either transcription.
fn demonstrate_direct_enhanced_state(
    ctx: &WhisperContext,
    audio: &[f32],
) -> Result<(), Box<dyn std::error::Error>> {
    println!("Using enhanced state directly for fine control...");
    let mut state = ctx.create_state()?;

    println!("\n1. With relaxed thresholds:");
    let relaxed_run = EnhancedTranscriptionParams {
        base: FullParams::default().language("en"),
        temperatures: vec![0.0, 0.5, 1.0],
        thresholds: QualityThresholds {
            compression_ratio_threshold: Some(3.0),
            log_prob_threshold: Some(-2.0),
            no_speech_threshold: Some(0.8),
        },
        prompt_reset_on_temperature: 0.5,
    };
    // One wrapper around the state serves both runs below.
    let mut fallback_state = EnhancedWhisperState::new(&mut state);
    let relaxed_outcome = fallback_state.transcribe_with_fallback(relaxed_run, audio)?;
    println!(" Result: {}", relaxed_outcome.text);

    println!("\n2. With strict thresholds:");
    let strict_run = EnhancedTranscriptionParams {
        base: FullParams::default().language("en"),
        temperatures: vec![0.0, 0.2, 0.4, 0.6, 0.8, 1.0],
        thresholds: QualityThresholds {
            compression_ratio_threshold: Some(1.5),
            log_prob_threshold: Some(-0.3),
            no_speech_threshold: Some(0.4),
        },
        prompt_reset_on_temperature: 0.5,
    };
    let strict_outcome = fallback_state.transcribe_with_fallback(strict_run, audio)?;
    println!(" Result: {}", strict_outcome.text);
    println!(" Note: Stricter thresholds may have triggered more temperature fallbacks");

    Ok(())
}
/// Loads the two audio buffers used by the demos and returns them as
/// `(clear_audio, noisy_audio)`.
///
/// The clear sample is the first existing file among
/// `$WHISPER_TEST_AUDIO_DIR/jfk.wav` and a fixed list of relative paths.
/// The noisy sample is `samples/noisy_speech.wav` when that file exists;
/// otherwise it is synthesized from the clear sample with
/// `add_noise_to_audio`.
///
/// # Errors
/// Returns an error when no clear sample can be found or when a WAV file
/// fails to load.
fn load_audio_examples() -> Result<(Vec<f32>, Vec<f32>), Box<dyn std::error::Error>> {
    const FALLBACK_CLEAR_PATHS: [&str; 3] = [
        "../whisper-cpp-plus-sys/whisper.cpp/samples/jfk.wav",
        "whisper-cpp-plus-sys/whisper.cpp/samples/jfk.wav",
        "samples/clear_speech.wav",
    ];
    const NOISY_PATH: &str = "samples/noisy_speech.wav";

    // The environment override wins, but only if the file is actually there.
    let env_candidate = match std::env::var("WHISPER_TEST_AUDIO_DIR") {
        Ok(dir) => {
            let candidate = format!("{}/jfk.wav", dir);
            if Path::new(&candidate).exists() {
                Some(candidate)
            } else {
                None
            }
        }
        Err(_) => None,
    };
    let clear_source = env_candidate.or_else(|| {
        FALLBACK_CLEAR_PATHS
            .iter()
            .copied()
            .find(|candidate| Path::new(candidate).exists())
            .map(String::from)
    });

    let clear_audio = match clear_source {
        Some(location) => {
            println!("Loading clear audio from: {}", location);
            load_wav_file(&location)?
        }
        None => {
            eprintln!("\nError: No audio files found!");
            eprintln!("Set WHISPER_TEST_AUDIO_DIR env var or provide audio.");
            return Err("No audio files found".into());
        }
    };

    let noisy_audio = if !Path::new(NOISY_PATH).exists() {
        println!("Creating noisy version from clear audio for demonstration...");
        add_noise_to_audio(&clear_audio)
    } else {
        println!("Loading noisy audio from: {}", NOISY_PATH);
        load_wav_file(NOISY_PATH)?
    };

    Ok((clear_audio, noisy_audio))
}
/// Returns a copy of `audio` with a small pseudo-random offset added to every
/// sample, each result limited to the `[-1.0, 1.0]` range.
///
/// The offsets come from a std `RandomState` hasher that is fed each sample
/// index in turn, so no external RNG crate is needed. The hasher is randomly
/// keyed, so the output differs from run to run. Each offset lies in
/// `[-0.075, 0.075]`.
fn add_noise_to_audio(audio: &[f32]) -> Vec<f32> {
    use std::collections::hash_map::RandomState;
    use std::hash::{BuildHasher, Hash, Hasher};

    let mut hasher = RandomState::new().build_hasher();
    let mut noisy = Vec::with_capacity(audio.len());
    for (index, &sample) in audio.iter().enumerate() {
        // The hasher is never reset, so each value depends on all indices so far.
        index.hash(&mut hasher);
        // Map the 64-bit hash to [0, 1], centre it on zero, then scale.
        let unit = hasher.finish() as f32 / u64::MAX as f32;
        let offset = (unit - 0.5) * 0.15;
        noisy.push((sample + offset).max(-1.0).min(1.0));
    }
    noisy
}
/// Reads the WAV file at `path` and returns the first channel as `f32`
/// samples scaled by `1 / 32768`.
///
/// Samples are decoded as 16-bit integers. A warning is printed to stderr
/// when the sample rate is not 16000 Hz or when the file has more than one
/// channel. For multi-channel files only every `channels`-th sample (the
/// first channel) is kept.
///
/// # Errors
/// Returns an error when the file cannot be opened or when a kept sample
/// fails to decode. Decode failures are propagated to the caller instead of
/// panicking.
fn load_wav_file(path: &str) -> Result<Vec<f32>, Box<dyn std::error::Error>> {
    use hound;
    let mut reader = hound::WavReader::open(path)?;
    let spec = reader.spec();
    if spec.sample_rate != 16000 {
        eprintln!("Warning: Audio sample rate is {}Hz, expected 16000Hz", spec.sample_rate);
    }
    if spec.channels != 1 {
        eprintln!("Warning: Audio has {} channels, using first channel only", spec.channels);
    }
    // Samples are interleaved, so stepping by the channel count keeps channel 0.
    // Collecting into `Result` stops at the first decode error and returns it.
    let samples = reader
        .samples::<i16>()
        .step_by(usize::from(spec.channels))
        .map(|sample| sample.map(|value| f32::from(value) / 32768.0))
        .collect::<Result<Vec<f32>, _>>()?;
    Ok(samples)
}
/// Searches for a model file called `name` and returns the first candidate
/// path that exists, or `None` when none does.
///
/// Search order:
/// 1. `<dir>/<name>` for each of the environment variables
///    `WHISPER_TEST_MODEL_DIR` and `WHISPER_MODEL_PATH` that is set.
/// 2. A fixed list of paths relative to the current working directory.
fn find_model(name: &str) -> Option<PathBuf> {
    // Both chains are lazy: a variable is read and a path is probed only
    // when every earlier candidate has missed.
    let from_env = ["WHISPER_TEST_MODEL_DIR", "WHISPER_MODEL_PATH"]
        .iter()
        .filter_map(|var| std::env::var(var).ok())
        .map(|dir| Path::new(&dir).join(name));

    let relative = [
        "tests/models",
        "whisper-cpp-plus/tests/models",
        "../whisper-cpp-plus-sys/whisper.cpp/models",
        "whisper-cpp-plus-sys/whisper.cpp/models",
    ]
    .iter()
    .map(|dir| PathBuf::from(format!("{}/{}", dir, name)));

    from_env.chain(relative).find(|candidate| candidate.exists())
}