//! Safe, idiomatic Rust bindings for whisper.cpp
//!
//! This crate provides high-level, safe Rust bindings for whisper.cpp,
//! a C++ implementation of OpenAI's Whisper automatic speech recognition (ASR) model.
//!
//! # Quick Start
//!
//! ```no_run
//! use whisper_cpp_plus::WhisperContext;
//!
//! # fn main() -> Result<(), Box<dyn std::error::Error>> {
//! // Load a Whisper model
//! let ctx = WhisperContext::new("path/to/model.bin")?;
//!
//! // Transcribe audio (must be 16kHz mono f32 samples)
//! let audio = vec![0.0f32; 16000]; // 1 second of silence
//! let text = ctx.transcribe(&audio)?;
//! println!("Transcription: {}", text);
//! # Ok(())
//! # }
//! ```
//!
//! # Advanced Usage
//!
//! ```no_run
//! use whisper_cpp_plus::{WhisperContext, TranscriptionParams};
//!
//! # fn main() -> Result<(), Box<dyn std::error::Error>> {
//! let ctx = WhisperContext::new("path/to/model.bin")?;
//! let audio = vec![0.0f32; 16000]; // 1 second of audio
//!
//! // Configure parameters using the builder pattern
//! let params = TranscriptionParams::builder()
//!     .language("en")
//!     .temperature(0.8)
//!     .enable_timestamps()
//!     .build();
//!
//! // Transcribe with custom parameters
//! let result = ctx.transcribe_with_params(&audio, params)?;
//!
//! // Access segments with timestamps
//! for segment in result.segments {
//!     println!("[{}-{}]: {}", segment.start_seconds(), segment.end_seconds(), segment.text);
//! }
//! # Ok(())
//! # }
//! ```

mod context;
mod error;
mod params;
mod state;
mod stream;
mod stream_pcm;
mod vad;

pub mod enhanced;

#[cfg(feature = "quantization")]
mod quantize;

#[cfg(feature = "async")]
mod async_api;

pub use context::WhisperContext;
pub use error::{Result, WhisperError};
pub use params::{
    FullParams, SamplingStrategy, TranscriptionParams, TranscriptionParamsBuilder,
};
pub use state::{Segment, TranscriptionResult, WhisperState};
pub use stream::{WhisperStream, WhisperStreamConfig};
pub use stream_pcm::{
    PcmFormat, PcmReader, PcmReaderConfig, WhisperStreamPcm, WhisperStreamPcmConfig, vad_simple,
};
pub use vad::{
    VadContextParams, VadParams, VadParamsBuilder, WhisperVadProcessor, VadSegments,
};
#[cfg(feature = "quantization")]
pub use quantize::{WhisperQuantize, QuantizationType, QuantizeError};

// Re-export for benchmarks
#[doc(hidden)]
pub mod bench_helpers {
    pub use crate::vad::{WhisperVadProcessor, VadParams};
}

#[cfg(feature = "async")]
pub use async_api::{AsyncWhisperStream, SharedAsyncStream};

// Re-export the sys crate for advanced users who need lower-level access
pub use whisper_cpp_plus_sys;

impl WhisperContext {
    /// Transcribe audio using default parameters
    ///
    /// # Arguments
    /// * `audio` - Audio samples (must be 16kHz mono f32)
    ///
    /// # Returns
    /// The transcribed text as a string
    ///
    /// # Example
    /// ```no_run
    /// # use whisper_cpp_plus::WhisperContext;
    /// # fn main() -> whisper_cpp_plus::Result<()> {
    /// let ctx = WhisperContext::new("model.bin")?;
    /// let audio = vec![0.0f32; 16000]; // 1 second
    /// let text = ctx.transcribe(&audio)?;
    /// # Ok(())
    /// # }
    /// ```
    pub fn transcribe(&self, audio: &[f32]) -> Result<String> {
        let mut state = WhisperState::new(self)?;
        let params = FullParams::new(SamplingStrategy::Greedy { best_of: 1 });

        state.full(params, audio)?;

        let n_segments = state.full_n_segments();
        let mut text = String::new();

        for i in 0..n_segments {
            if i > 0 {
                text.push(' ');
            }
            text.push_str(&state.full_get_segment_text(i)?);
        }

        Ok(text)
    }

    /// Transcribe audio with custom parameters
    ///
    /// # Arguments
    /// * `audio` - Audio samples (must be 16kHz mono f32)
    /// * `params` - Custom transcription parameters
    ///
    /// # Returns
    /// A `TranscriptionResult` containing the full text and individual segments
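    ///
    /// # Example
    /// A minimal usage sketch (the model path is a placeholder):
    /// ```no_run
    /// # use whisper_cpp_plus::{WhisperContext, TranscriptionParams};
    /// # fn main() -> whisper_cpp_plus::Result<()> {
    /// let ctx = WhisperContext::new("model.bin")?;
    /// let params = TranscriptionParams::builder()
    ///     .language("en")
    ///     .build();
    /// let audio = vec![0.0f32; 16000]; // 1 second of 16kHz mono audio
    /// let result = ctx.transcribe_with_params(&audio, params)?;
    /// println!("{}", result.text);
    /// # Ok(())
    /// # }
    /// ```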
    pub fn transcribe_with_params(
        &self,
        audio: &[f32],
        params: TranscriptionParams,
    ) -> Result<TranscriptionResult> {
        self.transcribe_with_full_params(audio, params.into_full_params())
    }

    /// Transcribe audio with full control over parameters
    ///
    /// # Arguments
    /// * `audio` - Audio samples (must be 16kHz mono f32)
    /// * `params` - Full parameter configuration
    ///
    /// # Returns
    /// A `TranscriptionResult` containing the full text and individual segments
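    ///
    /// # Example
    /// A minimal sketch using the greedy sampling strategy (the model path is a placeholder):
    /// ```no_run
    /// # use whisper_cpp_plus::{WhisperContext, FullParams, SamplingStrategy};
    /// # fn main() -> whisper_cpp_plus::Result<()> {
    /// let ctx = WhisperContext::new("model.bin")?;
    /// let params = FullParams::new(SamplingStrategy::Greedy { best_of: 1 });
    /// let audio = vec![0.0f32; 16000];
    /// let result = ctx.transcribe_with_full_params(&audio, params)?;
    /// for segment in result.segments {
    ///     println!("[{}-{}]: {}", segment.start_seconds(), segment.end_seconds(), segment.text);
    /// }
    /// # Ok(())
    /// # }
    /// ```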
    pub fn transcribe_with_full_params(
        &self,
        audio: &[f32],
        params: FullParams,
    ) -> Result<TranscriptionResult> {
        let mut state = WhisperState::new(self)?;
        state.full(params, audio)?;

        let n_segments = state.full_n_segments();
        let mut segments = Vec::with_capacity(n_segments as usize);
        let mut full_text = String::new();

        for i in 0..n_segments {
            let text = state.full_get_segment_text(i)?;
            let (start_ms, end_ms) = state.full_get_segment_timestamps(i);
            let speaker_turn_next = state.full_get_segment_speaker_turn_next(i);

            if i > 0 {
                full_text.push(' ');
            }
            full_text.push_str(&text);

            segments.push(Segment {
                start_ms,
                end_ms,
                text,
                speaker_turn_next,
            });
        }

        Ok(TranscriptionResult {
            text: full_text,
            segments,
        })
    }

    /// Create a new state for manual transcription control
    ///
    /// This allows you to reuse a state for multiple transcriptions,
    /// which can be more efficient than creating a new state each time.
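    ///
    /// # Example
    /// A minimal sketch that reuses one state across two clips (the model path is a placeholder):
    /// ```no_run
    /// # use whisper_cpp_plus::{WhisperContext, FullParams, SamplingStrategy};
    /// # fn main() -> whisper_cpp_plus::Result<()> {
    /// let ctx = WhisperContext::new("model.bin")?;
    /// let mut state = ctx.create_state()?;
    /// for audio in [vec![0.0f32; 16000], vec![0.0f32; 32000]] {
    ///     let params = FullParams::new(SamplingStrategy::Greedy { best_of: 1 });
    ///     state.full(params, &audio)?;
    ///     for i in 0..state.full_n_segments() {
    ///         println!("{}", state.full_get_segment_text(i)?);
    ///     }
    /// }
    /// # Ok(())
    /// # }
    /// ```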
    pub fn create_state(&self) -> Result<WhisperState> {
        WhisperState::new(self)
    }

    /// Enhanced transcription with custom parameters and temperature fallback
    ///
    /// This method provides quality-based retry with multiple temperatures
    /// if the initial transcription doesn't meet quality thresholds.
    ///
    /// # Arguments
    /// * `audio` - Audio samples (must be 16kHz mono f32)
    /// * `params` - Custom transcription parameters
    ///
    /// # Returns
    /// A `TranscriptionResult` containing the full text and individual segments
    ///
    /// # Example
    /// ```no_run
    /// # use whisper_cpp_plus::{WhisperContext, TranscriptionParams};
    /// # fn main() -> whisper_cpp_plus::Result<()> {
    /// let ctx = WhisperContext::new("model.bin")?;
    /// let params = TranscriptionParams::builder()
    ///     .language("en")
    ///     .build();
    /// let audio = vec![0.0f32; 16000];
    /// let result = ctx.transcribe_with_params_enhanced(&audio, params)?;
    /// # Ok(())
    /// # }
    /// ```
    pub fn transcribe_with_params_enhanced(
        &self,
        audio: &[f32],
        params: TranscriptionParams,
    ) -> Result<TranscriptionResult> {
        self.transcribe_with_full_params_enhanced(audio, params.into_full_params())
    }

    /// Enhanced transcription with full parameters and temperature fallback
    ///
    /// This method provides quality-based retry with multiple temperatures
    /// if the initial transcription doesn't meet quality thresholds.
    ///
    /// # Arguments
    /// * `audio` - Audio samples (must be 16kHz mono f32)
    /// * `params` - Full parameter configuration
    ///
    /// # Returns
    /// A `TranscriptionResult` containing the full text and individual segments
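    ///
    /// # Example
    /// A minimal sketch (the model path is a placeholder):
    /// ```no_run
    /// # use whisper_cpp_plus::{WhisperContext, FullParams, SamplingStrategy};
    /// # fn main() -> whisper_cpp_plus::Result<()> {
    /// let ctx = WhisperContext::new("model.bin")?;
    /// let params = FullParams::new(SamplingStrategy::Greedy { best_of: 1 });
    /// let audio = vec![0.0f32; 16000];
    /// let result = ctx.transcribe_with_full_params_enhanced(&audio, params)?;
    /// println!("{}", result.text);
    /// # Ok(())
    /// # }
    /// ```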
    pub fn transcribe_with_full_params_enhanced(
        &self,
        audio: &[f32],
        params: FullParams,
    ) -> Result<TranscriptionResult> {
        use crate::enhanced::fallback::{EnhancedTranscriptionParams, EnhancedWhisperState};

        // Convert to enhanced params with default fallback settings
        let enhanced_params = EnhancedTranscriptionParams::from_base(params);

        // Use enhanced state with temperature fallback logic
        let mut state = self.create_state()?;
        let mut enhanced_state = EnhancedWhisperState::new(&mut state);
        enhanced_state.transcribe_with_fallback(enhanced_params, audio)
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use std::path::Path;
    use std::sync::Arc;

    #[test]
    fn test_error_on_invalid_model() {
        let result = WhisperContext::new("nonexistent_model.bin");
        assert!(result.is_err());
    }

    #[test]
    fn test_model_loading() {
        let model_path = "tests/models/ggml-tiny.en.bin";
        if Path::new(model_path).exists() {
            let ctx = WhisperContext::new(model_path);
            assert!(ctx.is_ok());
        } else {
            eprintln!("Skipping test_model_loading: model file not found");
        }
    }

    #[test]
    fn test_silence_handling() {
        let model_path = "tests/models/ggml-tiny.en.bin";
        if Path::new(model_path).exists() {
            let ctx = WhisperContext::new(model_path).unwrap();
            let silence = vec![0.0f32; 16000]; // 1 second of silence
            let result = ctx.transcribe(&silence);
            assert!(result.is_ok());
        } else {
            eprintln!("Skipping test_silence_handling: model file not found");
        }
    }

    #[test]
    fn test_concurrent_states() {
        let model_path = "tests/models/ggml-tiny.en.bin";
        if Path::new(model_path).exists() {
            let ctx = Arc::new(WhisperContext::new(model_path).unwrap());
            let handles: Vec<_> = (0..4)
                .map(|_| {
                    let ctx = Arc::clone(&ctx);
                    std::thread::spawn(move || {
                        let audio = vec![0.0f32; 16000];
                        ctx.transcribe(&audio)
                    })
                })
                .collect();

            for handle in handles {
                assert!(handle.join().unwrap().is_ok());
            }
        } else {
            eprintln!("Skipping test_concurrent_states: model file not found");
        }
    }

    #[test]
    fn test_params_builder() {
        let params = TranscriptionParams::builder()
            .language("en")
            .temperature(0.8)
            .enable_timestamps()
            .n_threads(4)
            .build();

        // Just ensure it builds without panic
        let _ = params.into_full_params();
    }
}