mod common;
use std::sync::Arc;
use common::TestModels;
use whisper_cpp_plus::{
FullParams, SamplingStrategy, TranscriptionParams, WhisperContext, WhisperError,
};
#[test]
fn test_model_loading() {
let Some(model_path) = TestModels::tiny_en() else {
eprintln!("Skipping: model not found. Run `cargo xtask test-setup`");
return;
};
let ctx = WhisperContext::new(&model_path);
assert!(ctx.is_ok(), "Failed to load model from {:?}", model_path);
}
#[test]
fn test_invalid_model_path() {
let ctx = WhisperContext::new("nonexistent_model.bin");
assert!(ctx.is_err());
match ctx {
Err(WhisperError::ModelLoadError(_)) => {}
_ => panic!("Expected ModelLoadError"),
}
}
#[test]
fn test_silence_handling() {
let Some(model_path) = TestModels::tiny_en() else {
eprintln!("Skipping: model not found. Run `cargo xtask test-setup`");
return;
};
let ctx = WhisperContext::new(&model_path).unwrap();
let silence = vec![0.0f32; 16000]; let result = ctx.transcribe(&silence);
assert!(result.is_ok());
let text = result.unwrap();
assert!(text.len() < 100); }
#[test]
fn test_concurrent_states() {
let Some(model_path) = TestModels::tiny_en() else {
eprintln!("Skipping: model not found. Run `cargo xtask test-setup`");
return;
};
let ctx = Arc::new(WhisperContext::new(&model_path).unwrap());
let handles: Vec<_> = (0..4)
.map(|_| {
let ctx = Arc::clone(&ctx);
std::thread::spawn(move || {
let audio = vec![0.0f32; 16000];
ctx.transcribe(&audio)
})
})
.collect();
for handle in handles {
assert!(handle.join().unwrap().is_ok());
}
}
#[test]
fn test_transcription_with_params() {
let Some(model_path) = TestModels::tiny_en() else {
eprintln!("Skipping: model not found. Run `cargo xtask test-setup`");
return;
};
let ctx = WhisperContext::new(&model_path).unwrap();
let audio = vec![0.0f32; 16000 * 3];
let params = TranscriptionParams::builder()
.language("en")
.temperature(0.0)
.enable_timestamps()
.n_threads(2)
.build();
let result = ctx.transcribe_with_params(&audio, params);
assert!(result.is_ok());
let transcription = result.unwrap();
assert!(transcription.segments.len() >= 0);
assert_eq!(
transcription.text.split_whitespace().collect::<Vec<_>>().len(),
transcription
.segments
.iter()
.map(|s| s.text.split_whitespace().count())
.sum::<usize>()
);
}
#[test]
fn test_full_params_configuration() {
let params = FullParams::new(SamplingStrategy::Greedy { best_of: 1 })
.language("en")
.translate(false)
.no_context(false)
.no_timestamps(false)
.single_segment(false)
.temperature(0.8)
.n_threads(4)
.offset_ms(0)
.duration_ms(0);
}
#[test]
fn test_beam_search_strategy() {
let params = FullParams::new(SamplingStrategy::BeamSearch { beam_size: 5 })
.language("en")
.temperature(0.5);
}
#[test]
fn test_model_info() {
let Some(model_path) = TestModels::tiny_en() else {
eprintln!("Skipping: model not found. Run `cargo xtask test-setup`");
return;
};
let ctx = WhisperContext::new(&model_path).unwrap();
assert!(ctx.n_vocab() > 0, "n_vocab should be positive");
assert!(ctx.n_audio_ctx() > 0, "n_audio_ctx should be positive");
assert!(ctx.n_text_ctx() > 0, "n_text_ctx should be positive");
assert!(!ctx.is_multilingual());
}
#[test]
fn test_segment_timestamps() {
let Some(model_path) = TestModels::tiny_en() else {
eprintln!("Skipping: model not found. Run `cargo xtask test-setup`");
return;
};
let ctx = WhisperContext::new(&model_path).unwrap();
let audio = vec![0.0f32; 16000 * 2];
let params = FullParams::new(SamplingStrategy::Greedy { best_of: 1 })
.no_timestamps(false);
let result = ctx.transcribe_with_full_params(&audio, params).unwrap();
for segment in &result.segments {
assert!(segment.start_ms >= 0);
assert!(segment.end_ms >= 0);
assert!(segment.end_ms >= segment.start_ms);
assert!(segment.start_seconds() >= 0.0);
assert!(segment.end_seconds() >= 0.0);
}
}
#[test]
fn test_state_reuse() {
let Some(model_path) = TestModels::tiny_en() else {
eprintln!("Skipping: model not found. Run `cargo xtask test-setup`");
return;
};
let ctx = WhisperContext::new(&model_path).unwrap();
let mut state = ctx.create_state().unwrap();
let audio1 = vec![0.0f32; 16000]; let audio2 = vec![0.0f32; 16000 * 2];
let params = FullParams::new(SamplingStrategy::Greedy { best_of: 1 });
assert!(state.full(params.clone(), &audio1).is_ok());
let n_segments1 = state.full_n_segments();
assert!(state.full(params, &audio2).is_ok());
let n_segments2 = state.full_n_segments();
assert!(n_segments1 >= 0);
assert!(n_segments2 >= 0);
}
#[test]
fn test_empty_audio_error() {
let Some(model_path) = TestModels::tiny_en() else {
eprintln!("Skipping: model not found. Run `cargo xtask test-setup`");
return;
};
let ctx = WhisperContext::new(&model_path).unwrap();
let empty_audio = vec![];
let result = ctx.transcribe(&empty_audio);
assert!(result.is_err());
match result {
Err(WhisperError::InvalidAudioFormat) => {}
_ => panic!("Expected InvalidAudioFormat error"),
}
}