Skip to main content

whisper_cpp_plus/
lib.rs

1//! Safe, idiomatic Rust bindings for whisper.cpp
2//!
3//! This crate provides high-level, safe Rust bindings for whisper.cpp,
4//! OpenAI's Whisper automatic speech recognition (ASR) model implementation in C++.
5//!
6//! # Quick Start
7//!
8//! ```no_run
9//! use whisper_cpp_plus::{WhisperContext, FullParams, SamplingStrategy};
10//!
11//! # fn main() -> Result<(), Box<dyn std::error::Error>> {
12//! // Load a Whisper model
13//! let ctx = WhisperContext::new("path/to/model.bin")?;
14//!
15//! // Transcribe audio (must be 16kHz mono f32 samples)
16//! let audio = vec![0.0f32; 16000]; // 1 second of silence
17//! let text = ctx.transcribe(&audio)?;
18//! println!("Transcription: {}", text);
19//! # Ok(())
20//! # }
21//! ```
22//!
23//! # Advanced Usage
24//!
25//! ```no_run
26//! use whisper_cpp_plus::{WhisperContext, FullParams, SamplingStrategy, TranscriptionParams};
27//!
28//! # fn main() -> Result<(), Box<dyn std::error::Error>> {
29//! let ctx = WhisperContext::new("path/to/model.bin")?;
30//! let audio = vec![0.0f32; 16000]; // 1 second of audio
31//!
32//! // Configure parameters using builder pattern
33//! let params = TranscriptionParams::builder()
34//!     .language("en")
35//!     .temperature(0.8)
36//!     .enable_timestamps()
37//!     .build();
38//!
39//! // Transcribe with custom parameters
40//! let result = ctx.transcribe_with_params(&audio, params)?;
41//!
42//! // Access segments with timestamps
43//! for segment in result.segments {
44//!     println!("[{}-{}]: {}", segment.start_seconds(), segment.end_seconds(), segment.text);
45//! }
46//! # Ok(())
47//! # }
48//! ```
49
50mod context;
51mod error;
52mod params;
53mod state;
54mod stream;
55mod stream_pcm;
56mod vad;
57
58pub mod enhanced;
59
60#[cfg(feature = "quantization")]
61mod quantize;
62
63#[cfg(feature = "async")]
64mod async_api;
65
66pub use context::WhisperContext;
67pub use error::{Result, WhisperError};
68pub use params::{FullParams, SamplingStrategy, TranscriptionParams, TranscriptionParamsBuilder};
69#[cfg(feature = "quantization")]
70pub use quantize::{QuantizationType, QuantizeError, WhisperQuantize};
71pub use state::{Segment, TranscriptionResult, WhisperState};
72pub use stream::{WhisperStream, WhisperStreamConfig};
73pub use stream_pcm::{
74    vad_simple, PcmFormat, PcmReader, PcmReaderConfig, WhisperStreamPcm, WhisperStreamPcmConfig,
75};
76pub use vad::{VadContextParams, VadParams, VadParamsBuilder, VadSegments, WhisperVadProcessor};
77
78// Re-export for benchmarks
79#[doc(hidden)]
80pub mod bench_helpers {
81    pub use crate::vad::{VadParams, WhisperVadProcessor};
82}
83
84#[cfg(feature = "async")]
85pub use async_api::{AsyncWhisperStream, SharedAsyncStream};
86
87// Re-export the sys crate for advanced users who need lower-level access
88pub use whisper_cpp_plus_sys;
89
90impl WhisperContext {
91    /// Transcribe audio using default parameters
92    ///
93    /// # Arguments
94    /// * `audio` - Audio samples (must be 16kHz mono f32)
95    ///
96    /// # Returns
97    /// The transcribed text as a string
98    ///
99    /// # Example
100    /// ```no_run
101    /// # use whisper_cpp_plus::WhisperContext;
102    /// # fn main() -> whisper_cpp_plus::Result<()> {
103    /// let ctx = WhisperContext::new("model.bin")?;
104    /// let audio = vec![0.0f32; 16000]; // 1 second
105    /// let text = ctx.transcribe(&audio)?;
106    /// # Ok(())
107    /// # }
108    /// ```
109    pub fn transcribe(&self, audio: &[f32]) -> Result<String> {
110        let mut state = WhisperState::new(self)?;
111        let params = FullParams::new(SamplingStrategy::Greedy { best_of: 1 });
112
113        state.full(params, audio)?;
114
115        let n_segments = state.full_n_segments();
116        let mut text = String::new();
117
118        for i in 0..n_segments {
119            if i > 0 {
120                text.push(' ');
121            }
122            text.push_str(&state.full_get_segment_text(i)?);
123        }
124
125        Ok(text)
126    }
127
128    /// Transcribe audio with custom parameters
129    ///
130    /// # Arguments
131    /// * `audio` - Audio samples (must be 16kHz mono f32)
132    /// * `params` - Custom transcription parameters
133    ///
134    /// # Returns
135    /// A `TranscriptionResult` containing the full text and individual segments
136    pub fn transcribe_with_params(
137        &self,
138        audio: &[f32],
139        params: TranscriptionParams,
140    ) -> Result<TranscriptionResult> {
141        self.transcribe_with_full_params(audio, params.into_full_params())
142    }
143
144    /// Transcribe audio with full control over parameters
145    ///
146    /// # Arguments
147    /// * `audio` - Audio samples (must be 16kHz mono f32)
148    /// * `params` - Full parameter configuration
149    ///
150    /// # Returns
151    /// A `TranscriptionResult` containing the full text and individual segments
152    pub fn transcribe_with_full_params(
153        &self,
154        audio: &[f32],
155        params: FullParams,
156    ) -> Result<TranscriptionResult> {
157        let mut state = WhisperState::new(self)?;
158        state.full(params, audio)?;
159
160        let n_segments = state.full_n_segments();
161        let mut segments = Vec::with_capacity(n_segments as usize);
162        let mut full_text = String::new();
163
164        for i in 0..n_segments {
165            let text = state.full_get_segment_text(i)?;
166            let (start_ms, end_ms) = state.full_get_segment_timestamps(i);
167            let speaker_turn_next = state.full_get_segment_speaker_turn_next(i);
168
169            if i > 0 {
170                full_text.push(' ');
171            }
172            full_text.push_str(&text);
173
174            segments.push(Segment {
175                start_ms,
176                end_ms,
177                text,
178                speaker_turn_next,
179            });
180        }
181
182        Ok(TranscriptionResult {
183            text: full_text,
184            segments,
185        })
186    }
187
188    /// Create a new state for manual transcription control
189    ///
190    /// This allows you to reuse a state for multiple transcriptions,
191    /// which can be more efficient than creating a new state each time.
192    pub fn create_state(&self) -> Result<WhisperState> {
193        WhisperState::new(self)
194    }
195
196    /// Enhanced transcription with custom parameters and temperature fallback
197    ///
198    /// This method provides quality-based retry with multiple temperatures
199    /// if the initial transcription doesn't meet quality thresholds.
200    ///
201    /// # Arguments
202    /// * `audio` - Audio samples (must be 16kHz mono f32)
203    /// * `params` - Custom transcription parameters
204    ///
205    /// # Returns
206    /// A `TranscriptionResult` containing the full text and individual segments
207    ///
208    /// # Example
209    /// ```no_run
210    /// # use whisper_cpp_plus::{WhisperContext, TranscriptionParams};
211    /// # fn main() -> whisper_cpp_plus::Result<()> {
212    /// let ctx = WhisperContext::new("model.bin")?;
213    /// let params = TranscriptionParams::builder()
214    ///     .language("en")
215    ///     .build();
216    /// let audio = vec![0.0f32; 16000];
217    /// let result = ctx.transcribe_with_params_enhanced(&audio, params)?;
218    /// # Ok(())
219    /// # }
220    /// ```
221    pub fn transcribe_with_params_enhanced(
222        &self,
223        audio: &[f32],
224        params: TranscriptionParams,
225    ) -> Result<TranscriptionResult> {
226        self.transcribe_with_full_params_enhanced(audio, params.into_full_params())
227    }
228
229    /// Enhanced transcription with full parameters and temperature fallback
230    ///
231    /// This method provides quality-based retry with multiple temperatures
232    /// if the initial transcription doesn't meet quality thresholds.
233    ///
234    /// # Arguments
235    /// * `audio` - Audio samples (must be 16kHz mono f32)
236    /// * `params` - Full parameter configuration
237    ///
238    /// # Returns
239    /// A `TranscriptionResult` containing the full text and individual segments
240    pub fn transcribe_with_full_params_enhanced(
241        &self,
242        audio: &[f32],
243        params: FullParams,
244    ) -> Result<TranscriptionResult> {
245        use crate::enhanced::fallback::{EnhancedTranscriptionParams, EnhancedWhisperState};
246
247        // Convert to enhanced params with default fallback settings
248        let enhanced_params = EnhancedTranscriptionParams::from_base(params);
249
250        // Use enhanced state with temperature fallback logic
251        let mut state = self.create_state()?;
252        let mut enhanced_state = EnhancedWhisperState::new(&mut state);
253        enhanced_state.transcribe_with_fallback(enhanced_params, audio)
254    }
255}
256
#[cfg(test)]
mod tests {
    use super::*;
    use std::path::Path;
    use std::sync::Arc;

    /// Location of the optional model file used by the model-backed tests.
    const MODEL_PATH: &str = "tests/models/ggml-tiny.en.bin";

    /// Reports whether the optional test model is present on disk.
    fn model_available() -> bool {
        Path::new(MODEL_PATH).exists()
    }

    /// Loading a model from a path that does not exist must fail.
    #[test]
    fn test_error_on_invalid_model() {
        assert!(WhisperContext::new("nonexistent_model.bin").is_err());
    }

    /// The test model loads successfully when it is available.
    #[test]
    fn test_model_loading() {
        if !model_available() {
            eprintln!("Skipping test_model_loading: model file not found");
            return;
        }

        assert!(WhisperContext::new(MODEL_PATH).is_ok());
    }

    /// Transcribing one second of silence succeeds.
    #[test]
    fn test_silence_handling() {
        if !model_available() {
            eprintln!("Skipping test_silence_handling: model file not found");
            return;
        }

        let ctx = WhisperContext::new(MODEL_PATH).unwrap();
        let silence = vec![0.0f32; 16000]; // 1 second of silence
        assert!(ctx.transcribe(&silence).is_ok());
    }

    /// Four threads sharing one context can each transcribe successfully.
    #[test]
    fn test_concurrent_states() {
        if !model_available() {
            eprintln!("Skipping test_concurrent_states: model file not found");
            return;
        }

        let ctx = Arc::new(WhisperContext::new(MODEL_PATH).unwrap());

        // Start every worker before joining any of them so they overlap.
        let mut workers = Vec::new();
        for _ in 0..4 {
            let shared = Arc::clone(&ctx);
            workers.push(std::thread::spawn(move || {
                let audio = vec![0.0f32; 16000];
                shared.transcribe(&audio)
            }));
        }

        for worker in workers {
            assert!(worker.join().unwrap().is_ok());
        }
    }

    /// The builder chain and the conversion to full params run without panicking.
    #[test]
    fn test_params_builder() {
        let _ = TranscriptionParams::builder()
            .language("en")
            .temperature(0.8)
            .enable_timestamps()
            .n_threads(4)
            .build()
            .into_full_params();
    }
}
328}