mod context;
mod error;
mod params;
mod state;
mod stream;
mod stream_pcm;
mod vad;
pub mod enhanced;
#[cfg(feature = "quantization")]
mod quantize;
#[cfg(feature = "async")]
mod async_api;
pub use context::WhisperContext;
pub use error::{Result, WhisperError};
pub use params::{
FullParams, SamplingStrategy, TranscriptionParams, TranscriptionParamsBuilder,
};
pub use state::{Segment, TranscriptionResult, WhisperState};
pub use stream::{WhisperStream, WhisperStreamConfig};
pub use stream_pcm::{
PcmFormat, PcmReader, PcmReaderConfig, WhisperStreamPcm, WhisperStreamPcmConfig, vad_simple,
};
pub use vad::{
VadContextParams, VadParams, VadParamsBuilder, WhisperVadProcessor, VadSegments,
};
#[cfg(feature = "quantization")]
pub use quantize::{WhisperQuantize, QuantizationType, QuantizeError};
#[doc(hidden)]
pub mod bench_helpers {
pub use crate::vad::{WhisperVadProcessor, VadParams};
}
#[cfg(feature = "async")]
pub use async_api::{AsyncWhisperStream, SharedAsyncStream};
pub use whisper_cpp_plus_sys;
impl WhisperContext {
    /// Transcribes `audio` with greedy sampling and returns only the text.
    ///
    /// A fresh [`WhisperState`] is created for the call, the audio is processed
    /// with `SamplingStrategy::Greedy { best_of: 1 }`, and the text of every
    /// resulting segment is joined with a single space.
    ///
    /// # Errors
    ///
    /// Propagates any error from creating the state, running the transcription,
    /// or reading a segment's text.
    pub fn transcribe(&self, audio: &[f32]) -> Result<String> {
        let mut state = self.create_state()?;
        let params = FullParams::new(SamplingStrategy::Greedy { best_of: 1 });
        state.full(params, audio)?;

        let n_segments = state.full_n_segments();
        let mut text = String::new();
        for i in 0..n_segments {
            // Separate consecutive segments with exactly one space.
            if i > 0 {
                text.push(' ');
            }
            text.push_str(&state.full_get_segment_text(i)?);
        }
        Ok(text)
    }

    /// Transcribes `audio` using high-level [`TranscriptionParams`].
    ///
    /// The parameters are converted with `into_full_params` and forwarded to
    /// [`WhisperContext::transcribe_with_full_params`].
    ///
    /// # Errors
    ///
    /// Returns whatever error [`WhisperContext::transcribe_with_full_params`]
    /// returns.
    pub fn transcribe_with_params(
        &self,
        audio: &[f32],
        params: TranscriptionParams,
    ) -> Result<TranscriptionResult> {
        self.transcribe_with_full_params(audio, params.into_full_params())
    }

    /// Transcribes `audio` with explicit [`FullParams`] and returns per-segment
    /// details.
    ///
    /// A fresh [`WhisperState`] is created for the call. For every segment the
    /// text, the start/end timestamps and the "speaker turn next" flag are read
    /// from the state and stored in a [`Segment`]. The returned
    /// [`TranscriptionResult`] holds those segments plus the segment texts
    /// joined with a single space.
    ///
    /// # Errors
    ///
    /// Propagates any error from creating the state, running the transcription,
    /// or reading a segment's text.
    pub fn transcribe_with_full_params(
        &self,
        audio: &[f32],
        params: FullParams,
    ) -> Result<TranscriptionResult> {
        let mut state = self.create_state()?;
        state.full(params, audio)?;

        let n_segments = state.full_n_segments();
        // The count is only used as a capacity hint here. The explicit cast
        // suggests it is a signed integer (not visible from this file), so clamp
        // it at zero first: a negative value cast straight to `usize` would wrap
        // to an enormous capacity and make `Vec::with_capacity` panic.
        let mut segments = Vec::with_capacity(n_segments.max(0) as usize);
        let mut full_text = String::new();

        for i in 0..n_segments {
            let text = state.full_get_segment_text(i)?;
            let (start_ms, end_ms) = state.full_get_segment_timestamps(i);
            let speaker_turn_next = state.full_get_segment_speaker_turn_next(i);

            // Separate consecutive segments with exactly one space.
            if i > 0 {
                full_text.push(' ');
            }
            full_text.push_str(&text);

            segments.push(Segment {
                start_ms,
                end_ms,
                text,
                speaker_turn_next,
            });
        }

        Ok(TranscriptionResult {
            text: full_text,
            segments,
        })
    }

    /// Creates a new [`WhisperState`] bound to this context.
    ///
    /// # Errors
    ///
    /// Returns the error produced by `WhisperState::new` if the state cannot be
    /// created.
    pub fn create_state(&self) -> Result<WhisperState> {
        WhisperState::new(self)
    }

    /// Transcribes `audio` through the enhanced fallback path using high-level
    /// [`TranscriptionParams`].
    ///
    /// The parameters are converted with `into_full_params` and forwarded to
    /// [`WhisperContext::transcribe_with_full_params_enhanced`].
    ///
    /// # Errors
    ///
    /// Returns whatever error
    /// [`WhisperContext::transcribe_with_full_params_enhanced`] returns.
    pub fn transcribe_with_params_enhanced(
        &self,
        audio: &[f32],
        params: TranscriptionParams,
    ) -> Result<TranscriptionResult> {
        self.transcribe_with_full_params_enhanced(audio, params.into_full_params())
    }

    /// Transcribes `audio` through the enhanced fallback path using explicit
    /// [`FullParams`].
    ///
    /// The parameters are wrapped with `EnhancedTranscriptionParams::from_base`,
    /// a fresh state is created and wrapped in an `EnhancedWhisperState`, and
    /// the work is delegated to its `transcribe_with_fallback` method.
    ///
    /// # Errors
    ///
    /// Propagates any error from creating the state or from
    /// `transcribe_with_fallback`.
    pub fn transcribe_with_full_params_enhanced(
        &self,
        audio: &[f32],
        params: FullParams,
    ) -> Result<TranscriptionResult> {
        use crate::enhanced::fallback::{EnhancedTranscriptionParams, EnhancedWhisperState};

        let enhanced_params = EnhancedTranscriptionParams::from_base(params);
        let mut state = self.create_state()?;
        let mut enhanced_state = EnhancedWhisperState::new(&mut state);
        enhanced_state.transcribe_with_fallback(enhanced_params, audio)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::path::Path;
    use std::sync::Arc;

    /// Path of the model file used by the model-dependent tests.
    const MODEL_PATH: &str = "tests/models/ggml-tiny.en.bin";

    /// Reports whether the model file at `MODEL_PATH` exists on disk.
    fn model_available() -> bool {
        Path::new(MODEL_PATH).exists()
    }

    /// Loading a model from a path that does not exist must fail.
    #[test]
    fn test_error_on_invalid_model() {
        assert!(WhisperContext::new("nonexistent_model.bin").is_err());
    }

    /// Loading the test model succeeds; skipped when the file is missing.
    #[test]
    fn test_model_loading() {
        if !model_available() {
            eprintln!("Skipping test_model_loading: model file not found");
            return;
        }

        let loaded = WhisperContext::new(MODEL_PATH);
        assert!(loaded.is_ok());
    }

    /// Transcribing an all-zero buffer returns `Ok`; skipped when the model
    /// file is missing.
    #[test]
    fn test_silence_handling() {
        if !model_available() {
            eprintln!("Skipping test_silence_handling: model file not found");
            return;
        }

        let ctx = WhisperContext::new(MODEL_PATH).unwrap();
        let silence = vec![0.0f32; 16000];
        let outcome = ctx.transcribe(&silence);
        assert!(outcome.is_ok());
    }

    /// Four threads sharing one context through `Arc` can each transcribe;
    /// skipped when the model file is missing.
    #[test]
    fn test_concurrent_states() {
        if !model_available() {
            eprintln!("Skipping test_concurrent_states: model file not found");
            return;
        }

        let shared = Arc::new(WhisperContext::new(MODEL_PATH).unwrap());

        // Spawn every worker before joining any of them so they run concurrently.
        let workers: Vec<_> = (0..4)
            .map(|_| {
                let ctx = Arc::clone(&shared);
                std::thread::spawn(move || {
                    let audio = vec![0.0f32; 16000];
                    ctx.transcribe(&audio)
                })
            })
            .collect();

        for worker in workers {
            let outcome = worker.join().unwrap();
            assert!(outcome.is_ok());
        }
    }

    /// The builder chain produces parameters that convert into `FullParams`.
    #[test]
    fn test_params_builder() {
        let _ = TranscriptionParams::builder()
            .language("en")
            .temperature(0.8)
            .enable_timestamps()
            .n_threads(4)
            .build()
            .into_full_params();
    }
}