// opencode_voice/audio/capture.rs

1//! cpal-based microphone capture: 16kHz mono i16 PCM audio recording.
2//!
3//! Tries the ideal config (16kHz mono) first.  When the device doesn't support
4//! it — common on macOS — falls back to the device's native sample-rate and
5//! channel count and resamples to 16kHz mono in the audio callback.
6
7use anyhow::{Context, Result};
8use cpal::traits::{DeviceTrait, HostTrait, StreamTrait};
9use cpal::{SampleFormat, StreamConfig};
10use std::sync::{Arc, Mutex};
11use std::time::Instant;
12
13/// Lists all available audio input device names.
14pub fn list_devices() -> Result<Vec<String>> {
15    let host = cpal::default_host();
16    let devices = host
17        .input_devices()
18        .context("Failed to enumerate input devices")?;
19    Ok(devices.filter_map(|d| d.name().ok()).collect())
20}
21
/// Resampling state carried across audio callbacks.
///
/// One instance is created per resampling stream and updated by that stream's
/// data callback each time a buffer of input samples arrives.
struct ResampleState {
    /// Input samples consumed per output sample: `native_sample_rate / 16_000`.
    ratio: f64,
    /// Fractional input-sample position carried between callbacks, measured
    /// from the start of the next input buffer.
    phase: f64,
}
29
/// Records microphone audio via cpal, always producing 16kHz mono i16 output.
pub struct CpalRecorder {
    /// Requested device name (`None` = host default); overwritten with the
    /// resolved device's name by `start()`.
    device_name: Option<String>,
    /// Captured 16kHz mono samples, shared with the stream's data callback.
    samples: Arc<Mutex<Vec<i16>>>,
    /// The input stream while recording; `None` before `start()` and after `stop()`.
    stream: Option<cpal::Stream>,
    /// Instant at which the most recent `start()` succeeded.
    start_time: Option<Instant>,
    /// Sender half of the RMS energy channel, held between `start()` and `stop()`.
    energy_tx: Option<tokio::sync::mpsc::UnboundedSender<f32>>,
}
38
39impl CpalRecorder {
40    /// Creates a new recorder for the given device (or default if None).
41    pub fn new(device: Option<&str>) -> Result<Self> {
42        Ok(CpalRecorder {
43            device_name: device.map(|s| s.to_string()),
44            samples: Arc::new(Mutex::new(Vec::new())),
45            stream: None,
46            start_time: None,
47            energy_tx: None,
48        })
49    }
50
51    /// Starts recording. Returns a receiver for RMS energy updates (0.0–1.0).
52    pub fn start(&mut self) -> Result<tokio::sync::mpsc::UnboundedReceiver<f32>> {
53        let host = cpal::default_host();
54
55        // Find device
56        let device = if let Some(ref name) = self.device_name {
57            host.input_devices()
58                .context("Failed to enumerate devices")?
59                .find(|d| d.name().map(|n| n == *name).unwrap_or(false))
60                .with_context(|| format!("Audio device '{}' not found", name))?
61        } else {
62            host.default_input_device().context(
63                "No default audio input device found. Please check microphone connection.",
64            )?
65        };
66
67        // Store the resolved device name for debugging.
68        self.device_name = device.name().ok();
69
70        let (energy_tx, energy_rx) = tokio::sync::mpsc::unbounded_channel::<f32>();
71
72        let stream = self.build_stream(&device, energy_tx.clone())?;
73
74        stream.play().context("Failed to start audio stream")?;
75
76        self.stream = Some(stream);
77        self.start_time = Some(Instant::now());
78        self.energy_tx = Some(energy_tx);
79
80        Ok(energy_rx)
81    }
82
83    /// Builds the cpal input stream.
84    ///
85    /// 1. Try 16kHz mono i16 (zero conversion — ideal).
86    /// 2. Try 16kHz mono f32 (format conversion only).
87    /// 3. Fall back to the device's native config and resample in the callback.
88    fn build_stream(
89        &self,
90        device: &cpal::Device,
91        energy_tx: tokio::sync::mpsc::UnboundedSender<f32>,
92    ) -> Result<cpal::Stream> {
93        let ideal_config = StreamConfig {
94            channels: 1,
95            sample_rate: cpal::SampleRate(16_000),
96            buffer_size: cpal::BufferSize::Default,
97        };
98
99        let debug = std::env::var("RUST_LOG").is_ok();
100
101        // --- Strategy 1: 16kHz mono i16 (ideal) ---
102        if let Ok(stream) = self.build_direct_i16_stream(device, &ideal_config, energy_tx.clone()) {
103            if debug {
104                eprintln!("[audio] Using 16kHz mono i16 (ideal)");
105            }
106            return Ok(stream);
107        }
108
109        // --- Strategy 2: 16kHz mono f32 ---
110        if let Ok(stream) = self.build_direct_f32_stream(device, &ideal_config, energy_tx.clone()) {
111            if debug {
112                eprintln!("[audio] Using 16kHz mono f32");
113            }
114            return Ok(stream);
115        }
116
117        // --- Strategy 3: native config + resample ---
118        let default_config = device
119            .default_input_config()
120            .context("Failed to get any supported input config from audio device")?;
121
122        let native_rate = default_config.sample_rate().0;
123        let native_channels = default_config.channels();
124        let native_format = default_config.sample_format();
125
126        if debug {
127            eprintln!(
128                "[audio] Capturing at native {}Hz {}ch {:?}, resampling to 16kHz",
129                native_rate, native_channels, native_format
130            );
131        }
132
133        let stream_config = StreamConfig {
134            channels: native_channels,
135            sample_rate: cpal::SampleRate(native_rate),
136            buffer_size: cpal::BufferSize::Default,
137        };
138
139        match native_format {
140            SampleFormat::I16 => self.build_resampling_i16_stream(
141                device,
142                &stream_config,
143                native_rate,
144                native_channels,
145                energy_tx,
146            ),
147            _ => self.build_resampling_f32_stream(
148                device,
149                &stream_config,
150                native_rate,
151                native_channels,
152                energy_tx,
153            ),
154        }
155        .context("Failed to build audio input stream with any supported configuration. Check microphone permissions.")
156    }
157
158    // ---------------------------------------------------------------
159    //  Direct streams (16kHz mono, no resampling)
160    // ---------------------------------------------------------------
161
162    /// 16kHz mono i16 — no conversion needed.
163    fn build_direct_i16_stream(
164        &self,
165        device: &cpal::Device,
166        config: &StreamConfig,
167        energy_tx: tokio::sync::mpsc::UnboundedSender<f32>,
168    ) -> Result<cpal::Stream> {
169        let samples_arc = Arc::clone(&self.samples);
170
171        let stream = device
172            .build_input_stream(
173                config,
174                move |data: &[i16], _: &cpal::InputCallbackInfo| {
175                    if !data.is_empty() {
176                        let sum_sq: f64 = data
177                            .iter()
178                            .map(|&s| {
179                                let f = s as f64 / 32768.0;
180                                f * f
181                            })
182                            .sum();
183                        let rms = (sum_sq / data.len() as f64).sqrt() as f32;
184                        let _ = energy_tx.send(rms.min(1.0));
185                    }
186                    if let Ok(mut guard) = samples_arc.try_lock() {
187                        guard.extend_from_slice(data);
188                    }
189                },
190                |err| eprintln!("Audio stream error: {}", err),
191                None,
192            )
193            .map_err(|e| anyhow::anyhow!("i16 stream: {}", e))?;
194
195        Ok(stream)
196    }
197
198    /// 16kHz mono f32 — format conversion only (f32 → i16).
199    fn build_direct_f32_stream(
200        &self,
201        device: &cpal::Device,
202        config: &StreamConfig,
203        energy_tx: tokio::sync::mpsc::UnboundedSender<f32>,
204    ) -> Result<cpal::Stream> {
205        let samples_arc = Arc::clone(&self.samples);
206
207        let stream = device
208            .build_input_stream(
209                config,
210                move |data: &[f32], _: &cpal::InputCallbackInfo| {
211                    if !data.is_empty() {
212                        let sum_sq: f64 = data.iter().map(|&s| (s as f64) * (s as f64)).sum();
213                        let rms = (sum_sq / data.len() as f64).sqrt() as f32;
214                        let _ = energy_tx.send(rms.min(1.0));
215                    }
216                    if let Ok(mut guard) = samples_arc.try_lock() {
217                        for &s in data {
218                            let clamped = s.clamp(-1.0, 1.0);
219                            guard.push((clamped * 32767.0) as i16);
220                        }
221                    }
222                },
223                |err| eprintln!("Audio stream error: {}", err),
224                None,
225            )
226            .map_err(|e| anyhow::anyhow!("f32 stream: {}", e))?;
227
228        Ok(stream)
229    }
230
231    // ---------------------------------------------------------------
232    //  Resampling streams (native rate/channels → 16kHz mono)
233    // ---------------------------------------------------------------
234
235    /// Native-rate f32 stream with downmix + resample to 16kHz mono i16.
236    fn build_resampling_f32_stream(
237        &self,
238        device: &cpal::Device,
239        config: &StreamConfig,
240        native_rate: u32,
241        native_channels: u16,
242        energy_tx: tokio::sync::mpsc::UnboundedSender<f32>,
243    ) -> Result<cpal::Stream> {
244        let samples_arc = Arc::clone(&self.samples);
245        let state = Arc::new(Mutex::new(ResampleState {
246            ratio: native_rate as f64 / 16_000.0,
247            phase: 0.0,
248        }));
249
250        let stream = device
251            .build_input_stream(
252                config,
253                move |data: &[f32], _: &cpal::InputCallbackInfo| {
254                    let ch = native_channels as usize;
255
256                    // --- Downmix to mono ---
257                    let mono: Vec<f32> = if ch > 1 {
258                        data.chunks(ch)
259                            .map(|frame| frame.iter().sum::<f32>() / ch as f32)
260                            .collect()
261                    } else {
262                        data.to_vec()
263                    };
264
265                    // --- RMS energy ---
266                    if !mono.is_empty() {
267                        let sum_sq: f64 = mono.iter().map(|&s| (s as f64) * (s as f64)).sum();
268                        let rms = (sum_sq / mono.len() as f64).sqrt() as f32;
269                        let _ = energy_tx.send(rms.min(1.0));
270                    }
271
272                    // --- Resample (linear interpolation) ---
273                    if let Ok(mut st) = state.lock() {
274                        let ratio = st.ratio;
275                        let mut phase = st.phase;
276                        let len = mono.len() as f64;
277                        let mut resampled = Vec::new();
278
279                        while phase < len {
280                            let idx = phase as usize;
281                            let frac = (phase - idx as f64) as f32;
282                            let a = mono[idx];
283                            let b = if idx + 1 < mono.len() {
284                                mono[idx + 1]
285                            } else {
286                                a
287                            };
288                            let sample = a + (b - a) * frac;
289                            let clamped = sample.clamp(-1.0, 1.0);
290                            resampled.push((clamped * 32767.0) as i16);
291                            phase += ratio;
292                        }
293
294                        st.phase = phase - len;
295
296                        if let Ok(mut guard) = samples_arc.try_lock() {
297                            guard.extend_from_slice(&resampled);
298                        }
299                    }
300                },
301                |err| eprintln!("Audio stream error: {}", err),
302                None,
303            )
304            .map_err(|e| anyhow::anyhow!("Resampling f32 stream: {}", e))?;
305
306        Ok(stream)
307    }
308
309    /// Native-rate i16 stream with downmix + resample to 16kHz mono i16.
310    fn build_resampling_i16_stream(
311        &self,
312        device: &cpal::Device,
313        config: &StreamConfig,
314        native_rate: u32,
315        native_channels: u16,
316        energy_tx: tokio::sync::mpsc::UnboundedSender<f32>,
317    ) -> Result<cpal::Stream> {
318        let samples_arc = Arc::clone(&self.samples);
319        let state = Arc::new(Mutex::new(ResampleState {
320            ratio: native_rate as f64 / 16_000.0,
321            phase: 0.0,
322        }));
323
324        let stream = device
325            .build_input_stream(
326                config,
327                move |data: &[i16], _: &cpal::InputCallbackInfo| {
328                    let ch = native_channels as usize;
329
330                    // --- Convert to f32 and downmix to mono ---
331                    let mono: Vec<f32> = if ch > 1 {
332                        data.chunks(ch)
333                            .map(|frame| {
334                                let sum: f32 = frame.iter().map(|&s| s as f32 / 32768.0).sum();
335                                sum / ch as f32
336                            })
337                            .collect()
338                    } else {
339                        data.iter().map(|&s| s as f32 / 32768.0).collect()
340                    };
341
342                    // --- RMS energy ---
343                    if !mono.is_empty() {
344                        let sum_sq: f64 = mono.iter().map(|&s| (s as f64) * (s as f64)).sum();
345                        let rms = (sum_sq / mono.len() as f64).sqrt() as f32;
346                        let _ = energy_tx.send(rms.min(1.0));
347                    }
348
349                    // --- Resample (linear interpolation) ---
350                    if let Ok(mut st) = state.lock() {
351                        let ratio = st.ratio;
352                        let mut phase = st.phase;
353                        let len = mono.len() as f64;
354                        let mut resampled = Vec::new();
355
356                        while phase < len {
357                            let idx = phase as usize;
358                            let frac = (phase - idx as f64) as f32;
359                            let a = mono[idx];
360                            let b = if idx + 1 < mono.len() {
361                                mono[idx + 1]
362                            } else {
363                                a
364                            };
365                            let sample = a + (b - a) * frac;
366                            let clamped = sample.clamp(-1.0, 1.0);
367                            resampled.push((clamped * 32767.0) as i16);
368                            phase += ratio;
369                        }
370
371                        st.phase = phase - len;
372
373                        if let Ok(mut guard) = samples_arc.try_lock() {
374                            guard.extend_from_slice(&resampled);
375                        }
376                    }
377                },
378                |err| eprintln!("Audio stream error: {}", err),
379                None,
380            )
381            .map_err(|e| anyhow::anyhow!("Resampling i16 stream: {}", e))?;
382
383        Ok(stream)
384    }
385
386    /// Stops recording and returns all captured samples (16kHz mono i16).
387    pub fn stop(&mut self) -> Result<Vec<i16>> {
388        // Drop the stream to stop recording
389        self.stream = None;
390        self.energy_tx = None;
391
392        let samples = {
393            let guard = self
394                .samples
395                .lock()
396                .map_err(|_| anyhow::anyhow!("Failed to lock samples buffer"))?;
397            guard.clone()
398        };
399
400        // Clear for next use
401        if let Ok(mut guard) = self.samples.lock() {
402            guard.clear();
403        }
404
405        Ok(samples)
406    }
407
408    /// Returns the resolved audio device name (available after `start()`).
409    pub fn device_name(&self) -> Option<&str> {
410        self.device_name.as_deref()
411    }
412
413    /// Returns the elapsed recording duration in seconds.
414    pub fn duration(&self) -> f64 {
415        self.start_time
416            .map(|t| t.elapsed().as_secs_f64())
417            .unwrap_or(0.0)
418    }
419}
420
// SAFETY: the captured samples are shared through Arc<Mutex<>>, which is safe
// to use across threads. cpal::Stream is not Send on all platforms (e.g. macOS
// CoreAudio), so this impl relies on a usage convention rather than on the type
// system: after creation the stream is never accessed from another thread; it
// is only dropped (in stop(), when start() replaces it, or when the recorder
// itself is dropped).
// NOTE(review): confirm that dropping a stream on a thread other than the one
// that created it is sound on every backend this crate targets.
unsafe impl Send for CpalRecorder {}