car-voice 0.6.0

Voice I/O capability for CAR — mic capture, VAD, listener/speaker traits
Documentation
//! Chunk-overlap streaming wrapper around any [`SttProvider`].
//!
//! Whisper.cpp doesn't natively emit partials during decode — it
//! processes one fixed-size audio chunk at a time and returns a
//! single transcript. To approximate Granola/Meetily-style live
//! transcript UX (where words appear as they're spoken), we run
//! whisper on a rolling window of accumulated audio and treat each
//! re-transcription as a partial that supersedes the previous one.
//!
//! Stability detection: each partial is compared character-prefix to
//! the previous partial. The shared prefix is "stable" — it survived
//! one re-decode with more audio context and is unlikely to change
//! again. The diverging suffix is "in flight" and may revise.
//!
//! ```text
//!   feed(chunk_1)  →  whisper("hello")              partial: "hello"             stable_prefix=""
//!   feed(chunk_2)  →  whisper("hello world")        partial: "hello world"       stable_prefix="hello"
//!   feed(chunk_3)  →  whisper("hello world how")    partial: "hello world how"   stable_prefix="hello world"
//!   finalize()     →  whisper(everything)           transcript: "hello world how are you"
//! ```
//!
//! This is a real degradation vs native Parakeet streaming (each
//! re-decode burns CPU; the stable-prefix heuristic is char-based,
//! not token-based, so it can mis-stable on word boundaries). It's
//! the right v1 trade-off because it lets us ship live partials with
//! zero new model dependencies — Parakeet streaming lands when the
//! ONNX port does.

use crate::stt::SttProvider;
use crate::Result;
use std::sync::Arc;

/// Tunable parameters for the streamer.
///
/// All durations are wallclock milliseconds; conversion to sample
/// counts uses the sample rate passed to the streamer constructor.
#[derive(Debug, Clone)]
pub struct StreamingConfig {
    /// How much accumulated PCM (in ms) triggers the next partial
    /// re-decode. Smaller = faster partial updates, more CPU.
    /// Whisper.cpp processes ~5s of audio in ~0.5s on Apple Silicon
    /// Metal, so a 1500ms increment keeps roughly 3:1 wallclock
    /// margin.
    pub chunk_increment_ms: u32,

    /// Maximum window length to keep in the rolling buffer. When the
    /// buffer exceeds this, we drop the oldest samples — long
    /// utterances would otherwise re-decode an O(N²) amount of audio
    /// total over the segment. 30s window is enough to give whisper
    /// the context it wants without quadratic cost.
    pub max_window_ms: u32,
}

impl Default for StreamingConfig {
    fn default() -> Self {
        Self {
            chunk_increment_ms: 1500,
            max_window_ms: 30_000,
        }
    }
}

/// One partial transcription event.
#[derive(Debug, Clone, PartialEq)]
pub struct StreamingPartial {
    /// Full text of the latest re-decode.
    pub text: String,
    /// Prefix of `text` that has stabilized — same as the previous
    /// partial's prefix. Consumers can render this in a different
    /// style (e.g. opaque vs faded) to signal commitment.
    pub stable_prefix: String,
    /// Audio length (in ms) of the window that produced this
    /// partial. NOTE: once the rolling buffer has been trimmed to
    /// `max_window_ms`, this is the capped window length, not the
    /// cumulative utterance length.
    pub duration_ms: u64,
}

/// Streaming wrapper around an [`SttProvider`].
///
/// Single-utterance lifecycle: construct → feed PCM in chunks →
/// finalize. The streamer is not Send-shared between concurrent
/// utterances; each new utterance gets a fresh instance.
pub struct ChunkOverlapStreamer {
    /// Backend that performs every (re-)decode.
    provider: Arc<dyn SttProvider>,
    /// Tunables: partial cadence and rolling-window cap.
    config: StreamingConfig,
    /// Sample rate (Hz) of the PCM passed to `feed`; used for all
    /// ms ↔ sample-count conversions.
    sample_rate: u32,
    /// Rolling window of accumulated PCM samples.
    buffer: Vec<f32>,
    /// Text of the most recent partial decode — compared against the
    /// next decode to compute the stable prefix.
    last_partial: Option<String>,
    /// Length of the buffer the last time we ran a partial decode.
    /// Used to decide when the next chunk_increment_ms threshold
    /// has been crossed.
    last_decode_len: usize,
}

impl ChunkOverlapStreamer {
    /// Build a streamer for a single utterance. `sample_rate` is the
    /// rate (Hz) of the PCM that will be passed to [`Self::feed`].
    pub fn new(
        provider: Arc<dyn SttProvider>,
        sample_rate: u32,
        config: StreamingConfig,
    ) -> Self {
        Self {
            provider,
            sample_rate,
            config,
            last_partial: None,
            last_decode_len: 0,
            buffer: Vec::new(),
        }
    }

    /// Append PCM samples and, if enough new audio has accumulated,
    /// re-decode and return a partial. Returns `Ok(None)` when the
    /// chunk is below the increment threshold (no decode triggered).
    pub async fn feed(&mut self, samples: &[f32]) -> Result<Option<StreamingPartial>> {
        self.buffer.extend_from_slice(samples);
        self.trim_to_max_window();

        // How many samples one increment represents at our rate.
        let threshold =
            (u64::from(self.config.chunk_increment_ms) * u64::from(self.sample_rate) / 1000)
                as usize;
        let fresh = self.buffer.len().saturating_sub(self.last_decode_len);
        if fresh < threshold {
            return Ok(None);
        }

        // Advance the decode watermark before awaiting the decode.
        self.last_decode_len = self.buffer.len();
        let text = self
            .provider
            .transcribe(&self.buffer, self.sample_rate)
            .await?;

        // The prefix shared with the previous partial is "stable";
        // the first partial has no prior, so nothing is stable yet.
        let stable_prefix = self
            .last_partial
            .as_deref()
            .map(|prev| longest_common_prefix(prev, &text))
            .unwrap_or_default();
        let duration_ms = self.buffer_duration_ms();
        self.last_partial = Some(text.clone());
        Ok(Some(StreamingPartial {
            text,
            stable_prefix,
            duration_ms,
        }))
    }

    /// Run a final decode on the full accumulated buffer. Returns the
    /// final transcript text. Resets the streamer's internal state so
    /// it can be reused for the next utterance. An empty buffer
    /// short-circuits to an empty transcript without calling the
    /// provider.
    pub async fn finalize(&mut self) -> Result<String> {
        let text = if !self.buffer.is_empty() {
            self.provider
                .transcribe(&self.buffer, self.sample_rate)
                .await?
        } else {
            String::new()
        };
        self.buffer.clear();
        self.last_partial = None;
        self.last_decode_len = 0;
        Ok(text)
    }

    /// Current rolling-window length in ms — useful for tests and
    /// logging. A zero sample rate yields 0 instead of dividing by
    /// zero.
    pub fn buffer_duration_ms(&self) -> u64 {
        match self.sample_rate {
            0 => 0,
            sr => self.buffer.len() as u64 * 1000 / u64::from(sr),
        }
    }

    /// Drop the oldest samples once the buffer exceeds the
    /// `max_window_ms` cap, keeping re-decode cost bounded.
    fn trim_to_max_window(&mut self) {
        let cap =
            (u64::from(self.config.max_window_ms) * u64::from(self.sample_rate) / 1000) as usize;
        let excess = self.buffer.len().saturating_sub(cap);
        if excess > 0 {
            self.buffer.drain(..excess);
            // The decode watermark is an absolute index into the (now
            // shorter) buffer — shift it so the next increment check
            // still measures "samples since last decode".
            self.last_decode_len = self.last_decode_len.saturating_sub(excess);
        }
    }
}

/// Longest common character prefix of two strings.
///
/// Walks both strings char-by-char (not byte-by-byte) so multi-byte
/// UTF-8 sequences (most languages whisper supports) never get split
/// mid-character — slicing at a non-boundary would panic when
/// constructing the prefix String.
fn longest_common_prefix(a: &str, b: &str) -> String {
    a.chars()
        .zip(b.chars())
        .take_while(|(x, y)| x == y)
        .map(|(x, _)| x)
        .collect()
}

#[cfg(test)]
mod tests {
    use super::*;
    use async_trait::async_trait;
    use std::sync::Mutex;

    /// Stub provider — returns a scripted transcript per call so we
    /// can drive the streamer without a real model.
    struct ScriptedProvider {
        /// Remaining scripted outputs, stored reversed so `pop`
        /// yields them in the original order.
        outputs: Mutex<Vec<String>>,
    }
    impl ScriptedProvider {
        fn new(outputs: Vec<&str>) -> Self {
            let mut scripted: Vec<String> = outputs.into_iter().map(String::from).collect();
            scripted.reverse();
            Self {
                outputs: Mutex::new(scripted),
            }
        }
    }
    #[async_trait]
    impl SttProvider for ScriptedProvider {
        /// Pops the next scripted transcript; returns "" once the
        /// script is exhausted. The audio input is ignored entirely.
        async fn transcribe(&self, _samples: &[f32], _sr: u32) -> Result<String> {
            let mut outs = self.outputs.lock().unwrap();
            Ok(outs.pop().unwrap_or_default())
        }
    }

    #[test]
    fn lcp_handles_disjoint_prefix() {
        // Completely different strings share nothing.
        assert_eq!(longest_common_prefix("hello", "world"), "");
        // An empty input always yields an empty prefix.
        assert_eq!(longest_common_prefix("", "anything"), "");
        // Divergence mid-word cuts at the last shared char.
        assert_eq!(longest_common_prefix("hello", "hellp"), "hell");
        // One string being a prefix of the other yields that prefix.
        assert_eq!(longest_common_prefix("hello", "hello!"), "hello");
    }

    #[test]
    fn lcp_is_char_aware_for_multibyte() {
        // "héllo" vs "héllo!" — 'é' is 2 bytes in UTF-8.
        assert_eq!(longest_common_prefix("héllo", "héllo!"), "héllo");
        // Diverging AT the multi-byte char ('é' vs 'å') must not
        // panic and must cut cleanly on the char boundary. (The old
        // assertion compared two empty strings and expected "h",
        // which can never pass — the intended inputs were lost.)
        assert_eq!(longest_common_prefix("héllo", "hållo"), "h");
        // Two empty strings share the empty prefix.
        assert_eq!(longest_common_prefix("", ""), "");
    }

    #[tokio::test]
    async fn feed_returns_none_below_threshold() {
        // The scripted output must never be consumed: below the
        // increment threshold, no decode runs at all.
        let provider: Arc<dyn SttProvider> = Arc::new(ScriptedProvider::new(vec!["never"]));
        let config = StreamingConfig {
            chunk_increment_ms: 500,
            ..Default::default()
        };
        let mut streamer = ChunkOverlapStreamer::new(provider, 16_000, config);
        // 100ms = 1600 samples — far under the 500ms threshold
        // (8000 samples at 16kHz).
        let outcome = streamer.feed(&[0.0f32; 1600]).await.unwrap();
        assert!(outcome.is_none());
    }

    #[tokio::test]
    async fn feed_emits_partial_above_threshold_with_stable_prefix() {
        let provider: Arc<dyn SttProvider> = Arc::new(ScriptedProvider::new(vec![
            "hello",
            "hello world",
            "hello world how",
        ]));
        let config = StreamingConfig {
            chunk_increment_ms: 500,
            ..Default::default()
        };
        let mut streamer = ChunkOverlapStreamer::new(provider, 16_000, config);

        // Each 600ms feed crosses the 500ms threshold and triggers a
        // decode. The stable prefix of partial N is the full text of
        // partial N-1 (empty for the very first partial).
        let expectations = [
            ("hello", ""),
            ("hello world", "hello"),
            ("hello world how", "hello world"),
        ];
        for (expected_text, expected_stable) in expectations {
            let partial = streamer.feed(&[0.0f32; 9600]).await.unwrap().unwrap();
            assert_eq!(partial.text, expected_text);
            assert_eq!(partial.stable_prefix, expected_stable);
        }
    }

    #[tokio::test]
    async fn finalize_returns_full_transcript_and_resets() {
        let provider: Arc<dyn SttProvider> =
            Arc::new(ScriptedProvider::new(vec!["hello world how are you"]));
        let mut streamer = ChunkOverlapStreamer::new(provider, 16_000, StreamingConfig::default());
        // Pre-load 1s of audio directly into the rolling buffer.
        streamer.buffer.extend(std::iter::repeat(0.0f32).take(16_000));

        let transcript = streamer.finalize().await.unwrap();
        assert_eq!(transcript, "hello world how are you");

        // Finalize must leave the streamer ready for a fresh utterance.
        assert!(streamer.buffer.is_empty());
        assert!(streamer.last_partial.is_none());
        assert_eq!(streamer.last_decode_len, 0);
    }

    #[tokio::test]
    async fn finalize_on_empty_buffer_returns_empty_string() {
        // The script is deliberately empty: with no audio fed,
        // finalize must short-circuit without calling the provider.
        let provider: Arc<dyn SttProvider> = Arc::new(ScriptedProvider::new(Vec::new()));
        let mut streamer = ChunkOverlapStreamer::new(provider, 16_000, StreamingConfig::default());
        assert_eq!(streamer.finalize().await.unwrap(), "");
    }

    #[tokio::test]
    async fn buffer_trims_to_max_window() {
        let provider: Arc<dyn SttProvider> = Arc::new(ScriptedProvider::new(vec!["x"; 100]));
        let config = StreamingConfig {
            chunk_increment_ms: 100,
            max_window_ms: 1000,
        };
        let mut streamer = ChunkOverlapStreamer::new(provider, 16_000, config);
        // Feed 3s of audio as 30 × 100ms chunks. The rolling window
        // must be capped at max_window_ms = 1s = 16_000 samples.
        for _ in 0..30 {
            let _ = streamer.feed(&[0.0f32; 1600]).await.unwrap();
        }
        assert!(streamer.buffer.len() <= 16_000);
    }
}