
memo_stt/engine.rs

use crate::Result;
use num_cpus;
use std::path::Path;
use std::sync::{Arc, Mutex};
use whisper_rs::{
    FullParams, SamplingStrategy, WhisperContext, WhisperContextParameters, WhisperState,
};

/// Speech-to-text engine optimized for speed and ease of use.
///
/// This is the main entry point for transcription. Create an engine, warm it up,
/// and start transcribing audio samples.
///
/// # Example
///
/// ```no_run
/// use memo_stt::SttEngine;
///
/// # fn main() -> Result<(), Box<dyn std::error::Error>> {
/// // Create engine with default model
/// let mut engine = SttEngine::new_default(16000)?;
///
/// // Warm up GPU (reduces first-transcription latency)
/// engine.warmup()?;
///
/// // Transcribe audio samples (16kHz, mono, i16 PCM)
/// let samples: Vec<i16> = vec![]; // Your audio data here
/// let text = engine.transcribe(&samples)?;
/// println!("Transcribed: {}", text);
/// # Ok(())
/// # }
/// ```
///
/// # Performance
///
/// - First transcription: ~500ms-1s (after warmup)
/// - Subsequent transcriptions: ~200-500ms
/// - GPU acceleration is automatic on supported platforms
pub struct SttEngine {
    state: Arc<Mutex<WhisperState>>,
    initial_prompt: Option<String>, // Cache prompt, recreate params each time
    input_sample_rate: u32,
    f32_buffer: Vec<f32>, // Reusable buffer
}

impl SttEngine {
    /// Create a new engine with the default model.
    ///
    /// The model will be automatically downloaded to the cache directory on first use.
    /// For custom model paths, use [`new`](Self::new).
    ///
    /// # Arguments
    ///
    /// * `input_sample_rate` - Sample rate of input audio (e.g., 16000, 48000)
    ///
    /// # Example
    ///
    /// ```no_run
    /// use memo_stt::SttEngine;
    /// let engine = SttEngine::new_default(16000)?;
    /// # Ok::<(), Box<dyn std::error::Error>>(())
    /// ```
    pub fn new_default(input_sample_rate: u32) -> Result<Self> {
        // Ensure default model is available (downloads if needed)
        let model_path = crate::ensure_model(crate::default_model_path())?;
        Self::new(model_path, input_sample_rate)
    }

    /// Create a new engine with a custom model path.
    ///
    /// If the model doesn't exist, it will attempt to download it automatically
    /// (if it's a known model name). Otherwise, you'll need to provide the full path
    /// to an existing model file.
    ///
    /// # Arguments
    ///
    /// * `model_path` - Path to a GGML speech model, or model name
    /// * `input_sample_rate` - Sample rate of input audio (e.g., 16000, 48000)
    ///
    /// # Example
    ///
    /// ```no_run
    /// use memo_stt::SttEngine;
    /// // Use default model (auto-downloads if needed)
    /// let engine = SttEngine::new_default(16000)?;
    ///
    /// // Or specify a custom path
    /// let engine = SttEngine::new("models/ggml-small.en-q5_1.bin", 16000)?;
    /// # Ok::<(), Box<dyn std::error::Error>>(())
    /// ```
    ///
    /// # Recommended Models
    ///
    /// - `ggml-small.en-q5_1.bin` (~500MB) - Best balance of speed and accuracy
    /// - `ggml-distil-large-v3-q5_1.bin` (~500MB) - Higher accuracy
    /// - `ggml-distil-large-v3-q8_0.bin` (~800MB) - Highest accuracy
    ///
    /// Models are downloaded from: <https://huggingface.co/ggerganov/whisper.cpp>
    pub fn new(model_path: impl AsRef<Path>, input_sample_rate: u32) -> Result<Self> {
        // Ensure model exists (may download if it's the default model)
        let path = crate::ensure_model(model_path)?;

        let path_str = path
            .to_str()
            .ok_or_else(|| crate::Error("Invalid model path".into()))?;

        // Enable GPU/ACCEL auto-detection. The local runtime will use Metal/CUDA/
        // Vulkan/OpenCL where available and fall back to CPU otherwise.
        let params = WhisperContextParameters {
            use_gpu: true,
            ..WhisperContextParameters::default()
        };

        let ctx = WhisperContext::new_with_params(path_str, params)
            .map_err(|e| crate::Error(format!("Failed to load model: {}", e)))?;

        let state = ctx
            .create_state()
            .map_err(|e| crate::Error(format!("Failed to create state: {}", e)))?;

        Ok(Self {
            state: Arc::new(Mutex::new(state)),
            initial_prompt: None,
            input_sample_rate,
            f32_buffer: Vec::with_capacity(48000), // Pre-allocate for common sizes
        })
    }

    /// Transcribe audio samples to text.
    ///
    /// Takes PCM audio samples (16-bit signed integers) and returns transcribed text.
    ///
    /// # Arguments
    ///
    /// * `samples` - Audio samples as `i16` PCM data at the sample rate specified when creating the engine
    ///
    /// # Returns
    ///
    /// Transcribed text as a `String`. Returns empty string if no speech detected.
    ///
    /// # Example
    ///
    /// ```no_run
    /// use memo_stt::SttEngine;
    ///
    /// # fn main() -> Result<(), Box<dyn std::error::Error>> {
    /// let mut engine = SttEngine::new_default(16000)?;
    /// engine.warmup()?;
    ///
    /// // Your audio samples (16kHz, mono, i16 PCM)
    /// let samples: Vec<i16> = vec![]; // Replace with actual audio
    /// let text = engine.transcribe(&samples)?;
    /// println!("{}", text);
    /// # Ok(())
    /// # }
    /// ```
    ///
    /// # Audio Format Requirements
    ///
    /// - Format: 16-bit signed integer PCM (`i16`)
    /// - Channels: Mono
    /// - Sample rate: Must match the `input_sample_rate` provided to `new()` or `new_default()`
    /// - Minimum length: 1 second (16000 samples at 16kHz)
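    ///
    /// Many capture pipelines produce `f32` samples; if yours does (assumed here
    /// only for illustration), convert to `i16` before calling this method:
    ///
    /// ```no_run
    /// let float_samples: Vec<f32> = vec![]; // e.g. from your audio capture
    /// let samples: Vec<i16> = float_samples
    ///     .iter()
    ///     .map(|&s| (s.clamp(-1.0, 1.0) * i16::MAX as f32) as i16)
    ///     .collect();
    /// ```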
    pub fn transcribe(&mut self, samples: &[i16]) -> Result<String> {
        if samples.is_empty() {
            return Ok(String::new());
        }

        // Normalize and resample inline
        self.f32_buffer.clear();
        if self.input_sample_rate == 16000 {
            // Direct normalization, no resampling
            self.f32_buffer.reserve(samples.len());
            for &s in samples {
                self.f32_buffer.push(s as f32 / 32768.0);
            }
        } else {
            // Resample directly without intermediate Vec
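            // Linear interpolation: each output sample is a weighted blend of the
            // two nearest input samples at the fractional source position.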
            let ratio = self.input_sample_rate as f32 / 16000.0;
            let out_len = (samples.len() as f32 / ratio).max(1.0) as usize;
            self.f32_buffer.reserve(out_len);
            for i in 0..out_len {
                let pos = i as f32 * ratio;
                let i0 = pos.floor() as usize;
                let i1 = (i0 + 1).min(samples.len().saturating_sub(1));
                let t = pos - i0 as f32;
                let s0 = samples[i0] as f32 / 32768.0;
                let s1 = samples[i1] as f32 / 32768.0;
                self.f32_buffer.push(s0 * (1.0 - t) + s1 * t);
            }
        }

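        // Enforce the documented 1-second minimum (16000 samples at 16kHz).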
        if self.f32_buffer.len() < 16000 {
            return Err(crate::Error(format!(
                "Audio too short: {} samples",
                self.f32_buffer.len()
            )));
        }

        // Create params (reuse configuration pattern)
        let mut params = FullParams::new(SamplingStrategy::Greedy { best_of: 1 });
        // Use the available CPU cores, capped at 8 (thread count is set per transcription).
        // On a Raspberry Pi, 4-6 threads is typically optimal.
        params.set_n_threads(num_cpus::get().min(8) as i32);
        params.set_translate(false);
        params.set_language(Some("en"));
        params.set_print_progress(false);
        params.set_print_special(false);
        params.set_print_realtime(false);
        params.set_print_timestamps(false);
        params.set_suppress_blank(true);
        params.set_suppress_non_speech_tokens(true);
        params.set_max_len(0);
        params.set_token_timestamps(false);
        params.set_speed_up(false);
        params.set_audio_ctx(0);
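        // Decoding thresholds: greedy decoding starts at temperature 0.0;
        // temperature_inc, entropy_thold and logprob_thold drive whisper.cpp's
        // fallback re-decoding when a segment looks unreliable, and
        // no_speech_thold gates silence detection.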
        params.set_temperature(0.0);
        params.set_max_initial_ts(1.0);
        params.set_length_penalty(-1.0);
        params.set_temperature_inc(0.2);
        params.set_entropy_thold(2.4);
        params.set_logprob_thold(-1.0);
        params.set_no_speech_thold(0.6);
        if let Some(ref prompt) = self.initial_prompt {
            if !prompt.trim().is_empty() {
                params.set_initial_prompt(prompt);
            }
        }

        // Lock state and run inference
        let mut state = self
            .state
            .lock()
            .map_err(|e| crate::Error(format!("State lock failed: {}", e)))?;
        state
            .full(params, &self.f32_buffer)
            .map_err(|e| crate::Error(format!("Inference failed: {}", e)))?;

        // Extract text
        let n = state
            .full_n_segments()
            .map_err(|e| crate::Error(format!("Failed to get segments: {}", e)))?;

        let mut text = String::new();
        for i in 0..n {
            if let Ok(seg) = state.full_get_segment_text(i) {
                if !text.is_empty() {
                    text.push(' ');
                }
                text.push_str(seg.trim());
            }
        }

        Ok(text)
    }

    /// Set initial prompt for custom vocabulary or context.
    ///
    /// Useful for improving accuracy with domain-specific terms, names, or technical vocabulary.
    ///
    /// # Example
    ///
    /// ```no_run
    /// use memo_stt::SttEngine;
    ///
    /// # fn main() -> Result<(), Box<dyn std::error::Error>> {
    /// let mut engine = SttEngine::new_default(16000)?;
    /// engine.set_prompt(Some("Rust programming language, cargo, crates.io".to_string()));
    /// # Ok(())
    /// # }
    /// ```
    pub fn set_prompt(&mut self, prompt: Option<String>) {
        self.initial_prompt = prompt;
    }

    /// Warm up the GPU to reduce first-transcription latency.
    ///
    /// Call this after creating the engine to pre-initialize GPU resources.
    /// The first transcription after warmup will be faster.
    ///
    /// # Example
    ///
    /// ```no_run
    /// use memo_stt::SttEngine;
    ///
    /// # fn main() -> Result<(), Box<dyn std::error::Error>> {
    /// let mut engine = SttEngine::new_default(16000)?;
    /// engine.warmup()?; // Pre-initialize GPU
    /// // Now transcriptions will be faster
    /// # Ok(())
    /// # }
    /// ```
    pub fn warmup(&self) -> Result<()> {
        let mut params = FullParams::new(SamplingStrategy::Greedy { best_of: 1 });
        params.set_n_threads(2);
        params.set_language(Some("en"));
        params.set_print_progress(false);
        params.set_print_special(false);
        params.set_print_realtime(false);
        let mut state = self
            .state
            .lock()
            .map_err(|e| crate::Error(format!("State lock failed: {}", e)))?;
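        // Run a tiny pass over 0.1s of silence (1600 zero samples at 16kHz) to
        // initialize GPU buffers and kernels; the result is intentionally ignored.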
        let _ = state.full(params, &vec![0.0f32; 1600]);
        Ok(())
    }
}