memo_stt/engine.rs
use crate::Result;
use std::path::Path;
use std::sync::{Arc, Mutex};
use whisper_rs::{
    FullParams, SamplingStrategy, WhisperContext, WhisperContextParameters, WhisperState,
};

/// Speech-to-text engine optimized for speed and ease of use.
///
/// This is the main entry point for transcription. Create an engine, warm it up,
/// and start transcribing audio samples.
///
/// # Example
///
/// ```no_run
/// use memo_stt::SttEngine;
///
/// # fn main() -> Result<(), Box<dyn std::error::Error>> {
/// // Create engine with default model
/// let mut engine = SttEngine::new_default(16000)?;
///
/// // Warm up GPU (reduces first-transcription latency)
/// engine.warmup()?;
///
/// // Transcribe audio samples (16kHz, mono, i16 PCM)
/// let samples: Vec<i16> = vec![]; // Your audio data here
/// let text = engine.transcribe(&samples)?;
/// println!("Transcribed: {}", text);
/// # Ok(())
/// # }
/// ```
///
/// # Performance
///
/// - First transcription: ~500ms-1s (after warmup)
/// - Subsequent transcriptions: ~200-500ms
/// - GPU acceleration is automatic on supported platforms
pub struct SttEngine {
    state: Arc<Mutex<WhisperState>>,
    initial_prompt: Option<String>, // Cache prompt, recreate params each time
    input_sample_rate: u32,
    f32_buffer: Vec<f32>, // Reusable buffer
}

impl SttEngine {
    /// Create a new engine with the default model.
    ///
    /// The model will be automatically downloaded to the cache directory on first use.
    /// For custom model paths, use [`new`](Self::new).
    ///
    /// # Arguments
    ///
    /// * `input_sample_rate` - Sample rate of input audio (e.g., 16000, 48000)
    ///
    /// # Example
    ///
    /// ```no_run
    /// use memo_stt::SttEngine;
    /// let engine = SttEngine::new_default(16000)?;
    /// # Ok::<(), Box<dyn std::error::Error>>(())
    /// ```
    pub fn new_default(input_sample_rate: u32) -> Result<Self> {
        // Ensure default model is available (downloads if needed)
        let model_path = crate::ensure_model(crate::default_model_path())?;
        Self::new(model_path, input_sample_rate)
    }

    /// Create a new engine with a custom model path.
    ///
    /// If the model doesn't exist, the engine will attempt to download it automatically
    /// (if it's a known model name). Otherwise, you'll need to provide the full path
    /// to an existing model file.
    ///
    /// # Arguments
    ///
    /// * `model_path` - Path to a GGML speech model, or a known model name
    /// * `input_sample_rate` - Sample rate of input audio (e.g., 16000, 48000)
    ///
    /// # Example
    ///
    /// ```no_run
    /// use memo_stt::SttEngine;
    /// // Use default model (auto-downloads if needed)
    /// let engine = SttEngine::new_default(16000)?;
    ///
    /// // Or specify a custom path
    /// let engine = SttEngine::new("models/ggml-small.en-q5_1.bin", 16000)?;
    /// # Ok::<(), Box<dyn std::error::Error>>(())
    /// ```
    ///
    /// # Recommended Models
    ///
    /// - `ggml-small.en-q5_1.bin` (~500MB) - Best balance of speed and accuracy
    /// - `ggml-distil-large-v3-q5_1.bin` (~500MB) - Higher accuracy
    /// - `ggml-distil-large-v3-q8_0.bin` (~800MB) - Highest accuracy
    ///
    /// Models are downloaded from: <https://huggingface.co/ggerganov/whisper.cpp>
    pub fn new(model_path: impl AsRef<Path>, input_sample_rate: u32) -> Result<Self> {
        // Ensure model exists (may download if it's the default model)
        let path = crate::ensure_model(model_path)?;

        let path_str = path
            .to_str()
            .ok_or_else(|| crate::Error("Invalid model path".into()))?;

        // Enable GPU/ACCEL auto-detection. The local runtime will use Metal/CUDA/
        // Vulkan/OpenCL where available and fall back to CPU otherwise.
        let params = WhisperContextParameters {
            use_gpu: true,
            ..WhisperContextParameters::default()
        };

        let ctx = WhisperContext::new_with_params(path_str, params)
            .map_err(|e| crate::Error(format!("Failed to load model: {}", e)))?;

        let state = ctx
            .create_state()
            .map_err(|e| crate::Error(format!("Failed to create state: {}", e)))?;

        Ok(Self {
            state: Arc::new(Mutex::new(state)),
            initial_prompt: None,
            input_sample_rate,
            f32_buffer: Vec::with_capacity(48000), // Pre-allocate for common sizes
        })
    }

    /// Transcribe audio samples to text.
    ///
    /// Takes PCM audio samples (16-bit signed integers) and returns the transcribed text.
    ///
    /// # Arguments
    ///
    /// * `samples` - Audio samples as `i16` PCM data at the sample rate specified when creating the engine
    ///
    /// # Returns
    ///
    /// Transcribed text as a `String`. Returns an empty string if no speech is detected.
    ///
    /// # Example
    ///
    /// ```no_run
    /// use memo_stt::SttEngine;
    ///
    /// # fn main() -> Result<(), Box<dyn std::error::Error>> {
    /// let mut engine = SttEngine::new_default(16000)?;
    /// engine.warmup()?;
    ///
    /// // Your audio samples (16kHz, mono, i16 PCM)
    /// let samples: Vec<i16> = vec![]; // Replace with actual audio
    /// let text = engine.transcribe(&samples)?;
    /// println!("{}", text);
    /// # Ok(())
    /// # }
    /// ```
    ///
    /// # Audio Format Requirements
    ///
    /// - Format: 16-bit signed integer PCM (`i16`)
    /// - Channels: Mono
    /// - Sample rate: Must match the `input_sample_rate` provided to `new()` or `new_default()`;
    ///   rates other than 16 kHz are linearly resampled to 16 kHz internally
    /// - Minimum length: 1 second of audio (16000 samples after resampling to 16 kHz)
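    ///
    /// Many audio backends deliver `f32` samples; a minimal conversion sketch
    /// (assuming mono input already in the `-1.0..=1.0` range; `float_samples`
    /// stands in for your capture backend's output):
    ///
    /// ```no_run
    /// let float_samples: Vec<f32> = vec![]; // from your audio backend
    /// let samples: Vec<i16> = float_samples
    ///     .iter()
    ///     .map(|&s| (s.clamp(-1.0, 1.0) * i16::MAX as f32) as i16)
    ///     .collect();
    /// ```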
    pub fn transcribe(&mut self, samples: &[i16]) -> Result<String> {
        if samples.is_empty() {
            return Ok(String::new());
        }

        // Normalize to f32 and resample inline, reusing the scratch buffer
        self.f32_buffer.clear();
        if self.input_sample_rate == 16000 {
            // Direct normalization, no resampling needed
            self.f32_buffer.reserve(samples.len());
            for &s in samples {
                self.f32_buffer.push(s as f32 / 32768.0);
            }
        } else {
            // Linear-interpolation resampling to 16 kHz without an intermediate Vec:
            // each output sample blends the two nearest input samples.
            let ratio = self.input_sample_rate as f32 / 16000.0;
            let out_len = (samples.len() as f32 / ratio).max(1.0) as usize;
            self.f32_buffer.reserve(out_len);
            for i in 0..out_len {
                let pos = i as f32 * ratio;
                let i0 = pos.floor() as usize;
                let i1 = (i0 + 1).min(samples.len().saturating_sub(1));
                let t = pos - i0 as f32;
                let s0 = samples[i0] as f32 / 32768.0;
                let s1 = samples[i1] as f32 / 32768.0;
                self.f32_buffer.push(s0 * (1.0 - t) + s1 * t);
            }
        }

        if self.f32_buffer.len() < 16000 {
            return Err(crate::Error(format!(
                "Audio too short: {} samples (need at least 16000, i.e. 1 second at 16 kHz)",
                self.f32_buffer.len()
            )));
        }

        // Create decoding params; the configuration is rebuilt on every call,
        // only the initial prompt is cached on the engine.
        let mut params = FullParams::new(SamplingStrategy::Greedy { best_of: 1 });
        // Use up to 8 CPU threads; beyond that, thread overhead tends to
        // outweigh gains. On a Raspberry Pi, 4-6 threads is optimal.
        params.set_n_threads(num_cpus::get().min(8) as i32);
        params.set_translate(false);
        params.set_language(Some("en"));
        params.set_print_progress(false);
        params.set_print_special(false);
        params.set_print_realtime(false);
        params.set_print_timestamps(false);
        params.set_suppress_blank(true);
        params.set_suppress_non_speech_tokens(true);
        params.set_max_len(0);
        params.set_token_timestamps(false);
        params.set_speed_up(false);
        params.set_audio_ctx(0);
        params.set_temperature(0.0);
        params.set_max_initial_ts(1.0);
        params.set_length_penalty(-1.0);
        params.set_temperature_inc(0.2);
        params.set_entropy_thold(2.4);
        params.set_logprob_thold(-1.0);
        params.set_no_speech_thold(0.6);
        if let Some(ref prompt) = self.initial_prompt {
            if !prompt.trim().is_empty() {
                params.set_initial_prompt(prompt);
            }
        }

        // Lock state and run inference
        let mut state = self
            .state
            .lock()
            .map_err(|e| crate::Error(format!("State lock failed: {}", e)))?;
        state
            .full(params, &self.f32_buffer)
            .map_err(|e| crate::Error(format!("Inference failed: {}", e)))?;

        // Extract text, joining segments with single spaces
        let n = state
            .full_n_segments()
            .map_err(|e| crate::Error(format!("Failed to get segments: {}", e)))?;

        let mut text = String::new();
        for i in 0..n {
            if let Ok(seg) = state.full_get_segment_text(i) {
                if !text.is_empty() {
                    text.push(' ');
                }
                text.push_str(seg.trim());
            }
        }

        Ok(text)
    }

    /// Set initial prompt for custom vocabulary or context.
    ///
    /// Useful for improving accuracy with domain-specific terms, names, or technical vocabulary.
    ///
    /// # Example
    ///
    /// ```no_run
    /// use memo_stt::SttEngine;
    ///
    /// # fn main() -> Result<(), Box<dyn std::error::Error>> {
    /// let mut engine = SttEngine::new_default(16000)?;
    /// engine.set_prompt(Some("Rust programming language, cargo, crates.io".to_string()));
    /// # Ok(())
    /// # }
    /// ```
    pub fn set_prompt(&mut self, prompt: Option<String>) {
        self.initial_prompt = prompt;
    }

    /// Warm up the GPU to reduce first-transcription latency.
    ///
    /// Call this after creating the engine to pre-initialize GPU resources.
    /// The first transcription after warmup will be faster.
    ///
    /// # Example
    ///
    /// ```no_run
    /// use memo_stt::SttEngine;
    ///
    /// # fn main() -> Result<(), Box<dyn std::error::Error>> {
    /// let mut engine = SttEngine::new_default(16000)?;
    /// engine.warmup()?; // Pre-initialize GPU
    /// // Now transcriptions will be faster
    /// # Ok(())
    /// # }
    /// ```
    pub fn warmup(&self) -> Result<()> {
        let mut params = FullParams::new(SamplingStrategy::Greedy { best_of: 1 });
        params.set_n_threads(2);
        params.set_language(Some("en"));
        params.set_print_progress(false);
        params.set_print_special(false);
        params.set_print_realtime(false);
        let mut state = self
            .state
            .lock()
            .map_err(|e| crate::Error(format!("State lock failed: {}", e)))?;
        // Run inference over 0.1s of silence purely to initialize GPU buffers
        // and kernels; the result (and any error) is deliberately ignored.
        let _ = state.full(params, &[0.0f32; 1600]);
        Ok(())
    }
}
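
// Minimal smoke-test sketch: exercises construction, warmup, and the
// transcribe path end to end on synthetic silence. It assumes the default
// model can be fetched (downloaded on first use), so it is `#[ignore]`d by
// default; run it explicitly with `cargo test -- --ignored`.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    #[ignore]
    fn transcribe_runs_on_silence() {
        let mut engine = SttEngine::new_default(16000).expect("failed to create engine");
        engine.warmup().expect("warmup failed");
        // One second of silence at 16 kHz: just enough to pass the length check.
        let samples = vec![0i16; 16000];
        // Silence should transcribe without error; the text itself may be empty.
        let text = engine.transcribe(&samples).expect("transcription failed");
        println!("silence transcribed as: {:?}", text);
    }
}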