transcribe-cli 0.0.5

Whisper CLI transcription pipeline on CTranslate2 with CPU and optional CUDA support
use std::collections::VecDeque;

use anyhow::{Result, bail};

/// Upper bound, in seconds, on the window length chosen by `StreamConfig::for_realtime`.
const DEFAULT_WINDOW_SECONDS: usize = 10;
/// Preferred overlap between consecutive windows, in seconds; `StreamConfig::for_model`
/// additionally caps the overlap at half the window.
const DEFAULT_OVERLAP_SECONDS: usize = 2;
/// Default value for `StreamConfig::max_dedup_words`.
const DEFAULT_MAX_DEDUP_WORDS: usize = 24;
/// Default value for `StreamConfig::hold_words`.
const DEFAULT_HOLD_WORDS: usize = 3;

/// Parameters that control how an audio stream is cut into overlapping windows
/// and how the transcripts of those windows are stitched back together.
#[derive(Clone, Debug)]
pub struct StreamConfig {
    /// Audio sample rate, in samples per second.
    pub sample_rate: usize,
    /// Number of samples in one full window.
    pub window_samples: usize,
    /// Number of samples shared by two consecutive windows.
    pub overlap_samples: usize,
    /// Maximum number of words compared when removing text repeated across windows.
    pub max_dedup_words: usize,
    /// Number of trailing words withheld from output until more text arrives.
    pub hold_words: usize,
}

impl StreamConfig {
    /// Builds a configuration whose window is exactly `window_samples` long.
    ///
    /// The overlap is `DEFAULT_OVERLAP_SECONDS` worth of samples, capped at half
    /// the window; the word limits take their `DEFAULT_*` values.
    ///
    /// # Errors
    ///
    /// Fails when `sample_rate` or `window_samples` is zero.
    pub fn for_model(sample_rate: usize, window_samples: usize) -> Result<Self> {
        if sample_rate == 0 {
            bail!("stream sample rate must be greater than zero");
        }
        if window_samples == 0 {
            bail!("stream window size must be greater than zero");
        }

        // Saturate rather than overflow: the product only serves as an upper
        // bound and is clamped to half the window right away, so saturation
        // never changes the result.
        let overlap_samples = sample_rate
            .saturating_mul(DEFAULT_OVERLAP_SECONDS)
            .min(window_samples / 2);

        Ok(Self {
            sample_rate,
            window_samples,
            overlap_samples,
            max_dedup_words: DEFAULT_MAX_DEDUP_WORDS,
            hold_words: DEFAULT_HOLD_WORDS,
        })
    }

    /// Builds a configuration for live input: the window is
    /// `DEFAULT_WINDOW_SECONDS` of audio, but never longer than
    /// `model_window_samples`.
    ///
    /// # Errors
    ///
    /// Fails when `sample_rate` or `model_window_samples` is zero.
    pub fn for_realtime(sample_rate: usize, model_window_samples: usize) -> Result<Self> {
        // A zero rate or a zero model window yields a zero window here, and
        // `for_model` rejects both (rate first) with the appropriate message,
        // so no separate validation is needed. Saturation is harmless because
        // the product is clamped to the model window.
        let preferred_window_samples = sample_rate.saturating_mul(DEFAULT_WINDOW_SECONDS);
        Self::for_model(
            sample_rate,
            preferred_window_samples.min(model_window_samples),
        )
    }

    /// Number of samples the stream advances between consecutive windows.
    ///
    /// Always at least one, even when the overlap is as large as the window.
    pub fn step_samples(&self) -> usize {
        self.window_samples
            .saturating_sub(self.overlap_samples)
            .max(1)
    }

    /// Overlap length in whole seconds (rounded down).
    ///
    /// Returns zero instead of panicking when `sample_rate` is zero, which is
    /// possible for a configuration assembled directly from its public fields.
    pub fn overlap_seconds(&self) -> usize {
        self.overlap_samples
            .checked_div(self.sample_rate)
            .unwrap_or(0)
    }

    /// Number of windows needed for `total_samples` samples, computed as
    /// `ceil(total_samples / step_samples())`; zero for empty input.
    ///
    /// NOTE(review): this counts one window per step. `StreamEngine` can emit
    /// fewer chunks for the same audio (with a 6-sample window and a 2-sample
    /// overlap, 9 samples yield 2 chunks while this returns 3) — confirm which
    /// count callers expect.
    pub fn window_count(&self, total_samples: usize) -> usize {
        if total_samples == 0 {
            return 0;
        }
        total_samples.div_ceil(self.step_samples())
    }
}

/// One window of audio cut from the stream, together with its position.
#[derive(Clone, Debug)]
pub struct StreamChunk {
    /// Zero-based position of this chunk in the sequence of emitted chunks.
    pub index: usize,
    /// Absolute index of the first sample in this chunk.
    pub start_sample: usize,
    /// Absolute index one past the last sample in this chunk.
    pub end_sample: usize,
    /// The audio samples of this chunk.
    pub samples: Vec<f32>,
    /// `true` for the shorter tail chunk produced when the stream is flushed.
    pub is_partial: bool,
}

impl StreamChunk {
    /// Renders a one-line progress description of this chunk.
    ///
    /// The chunk number is shown one-based, `total_chunks` is displayed as at
    /// least 1, and partial chunks get a trailing ` / tail` marker.
    pub fn status(&self, total_chunks: usize, overlap_seconds: usize) -> String {
        let Self {
            index,
            start_sample,
            end_sample,
            is_partial,
            ..
        } = self;

        let ordinal = index + 1;
        let denominator = total_chunks.max(1);
        let tail_marker = if *is_partial { " / tail" } else { "" };

        format!(
            "chunk {}/{} / samples {}..{} / {}s overlap{}",
            ordinal, denominator, start_sample, end_sample, overlap_seconds, tail_marker
        )
    }
}

/// Incremental state for one stream: buffers incoming audio into overlapping
/// windows and stitches the per-window transcripts into continuous text.
#[derive(Debug)]
pub struct StreamEngine {
    /// Windowing and de-duplication parameters.
    config: StreamConfig,
    /// Samples received but not yet dropped; the front is at `next_start_sample`.
    audio_buffer: VecDeque<f32>,
    /// Index assigned to the next emitted chunk.
    next_chunk_index: usize,
    /// Absolute position of the first buffered sample.
    next_start_sample: usize,
    /// Most recently accepted words, used to spot repeats in the next transcript.
    seen_tail: String,
    /// Accepted words that have not been returned to the caller yet.
    held_words: Vec<String>,
}

impl StreamEngine {
    /// Creates an engine with empty audio and text state.
    pub fn new(config: StreamConfig) -> Self {
        Self {
            config,
            audio_buffer: VecDeque::new(),
            next_chunk_index: 0,
            next_start_sample: 0,
            seen_tail: String::new(),
            held_words: Vec::new(),
        }
    }

    /// Returns the configuration this engine was created with.
    pub fn config(&self) -> &StreamConfig {
        &self.config
    }

    /// Buffers `samples` and returns every full window that is now available.
    ///
    /// After each window the buffer advances by `StreamConfig::step_samples`,
    /// so consecutive windows share the configured overlap. Samples that do
    /// not fill a window stay buffered for the next call or for
    /// [`finish_audio`](Self::finish_audio).
    pub fn push_audio(&mut self, samples: &[f32]) -> Vec<StreamChunk> {
        self.audio_buffer.extend(samples.iter().copied());

        let window = self.config.window_samples;
        let mut chunks = Vec::new();
        if window == 0 {
            // The config fields are public, so a zero window is possible; the
            // loop below would then never terminate. Keep the audio buffered
            // and let `finish_audio` flush it.
            return chunks;
        }

        while self.audio_buffer.len() >= window {
            chunks.push(self.build_chunk(window, false));
            self.advance_audio();
        }

        chunks
    }

    /// Flushes whatever audio is still buffered as a single partial chunk.
    ///
    /// Returns `None` when the buffer is empty.
    pub fn finish_audio(&mut self) -> Option<StreamChunk> {
        if self.audio_buffer.is_empty() {
            return None;
        }

        let remaining = self.audio_buffer.len();
        let chunk = self.build_chunk(remaining, true);
        self.audio_buffer.clear();
        self.next_chunk_index += 1;
        // The whole buffer was consumed and nothing is kept as overlap, so the
        // next sample to arrive sits right after this chunk. Advancing by one
        // step instead would misreport positions of any later chunk.
        self.next_start_sample = chunk.end_sample;
        Some(chunk)
    }

    /// Merges the transcript of the latest window into the running text and
    /// returns the words that are now considered stable.
    ///
    /// Words at the start of `raw_text` that repeat the end of previously
    /// accepted text are dropped. The rest is queued, and everything except
    /// the last `hold_words` queued words is returned. Returns `None` when
    /// nothing new remains after trimming or when no word is released yet.
    pub fn stabilize_text(&mut self, raw_text: &str) -> Option<String> {
        let text = trim_stream_overlap(&self.seen_tail, raw_text, self.config.max_dedup_words);
        if text.is_empty() {
            return None;
        }

        self.held_words
            .extend(text.split_whitespace().map(str::to_string));
        self.seen_tail = merge_stream_tail(&self.seen_tail, &text, self.config.max_dedup_words);

        let emit_now = self.held_words.len().saturating_sub(self.config.hold_words);
        if emit_now == 0 {
            return None;
        }

        // Move the released words out of the queue instead of copying them.
        let stable_words: Vec<String> = self.held_words.drain(..emit_now).collect();
        Some(stable_words.join(" "))
    }

    /// Returns all words still being held back, or `None` when there are none.
    pub fn finish_text(&mut self) -> Option<String> {
        if self.held_words.is_empty() {
            return None;
        }

        let final_text = self.held_words.join(" ");
        self.held_words.clear();
        Some(final_text)
    }

    /// Copies the first `len` buffered samples into a chunk positioned at the
    /// current stream offset. Does not modify the buffer or the counters.
    fn build_chunk(&self, len: usize, is_partial: bool) -> StreamChunk {
        let samples = self
            .audio_buffer
            .iter()
            .take(len)
            .copied()
            .collect::<Vec<_>>();
        StreamChunk {
            index: self.next_chunk_index,
            start_sample: self.next_start_sample,
            end_sample: self.next_start_sample + len,
            samples,
            is_partial,
        }
    }

    /// Drops one step of samples from the front of the buffer and moves the
    /// chunk index and stream offset forward accordingly.
    fn advance_audio(&mut self) {
        let step = self.config.step_samples();
        let to_drop = step.min(self.audio_buffer.len());
        self.audio_buffer.drain(..to_drop);
        self.next_chunk_index += 1;
        self.next_start_sample += step;
    }
}

/// Removes from the start of `current` the longest run of words (at most
/// `max_dedup_words`) that repeats the end of `previous_tail`.
///
/// Words are compared through `words_match`. When a repeat is found the
/// remaining words are returned joined by single spaces; otherwise `current`
/// is returned trimmed. Input without any words yields an empty string.
fn trim_stream_overlap(previous_tail: &str, current: &str, max_dedup_words: usize) -> String {
    let incoming: Vec<&str> = current.split_whitespace().collect();
    if incoming.is_empty() {
        return String::new();
    }

    let earlier: Vec<&str> = previous_tail.split_whitespace().collect();
    let longest = earlier.len().min(incoming.len()).min(max_dedup_words);

    // Try the longest candidate first so the largest repeat wins.
    let repeated = (1..=longest)
        .rev()
        .find(|&count| words_match(&earlier[earlier.len() - count..], &incoming[..count]));

    match repeated {
        Some(count) => incoming[count..].join(" "),
        None => current.trim().to_string(),
    }
}

/// Appends the words of `current` to those of `previous_tail` and keeps only
/// the last `max_dedup_words` of them, joined by single spaces.
fn merge_stream_tail(previous_tail: &str, current: &str, max_dedup_words: usize) -> String {
    let combined: Vec<&str> = previous_tail
        .split_whitespace()
        .chain(current.split_whitespace())
        .collect();

    // Number of leading words that do not fit into the retained tail.
    let surplus = combined.len().saturating_sub(max_dedup_words);
    combined[surplus..].join(" ")
}

/// Returns `true` when both slices have the same length and every pair of
/// words at the same position is equal after `normalize_word`.
fn words_match(previous: &[&str], current: &[&str]) -> bool {
    if previous.len() != current.len() {
        return false;
    }

    for (left, right) in previous.iter().zip(current) {
        if normalize_word(left) != normalize_word(right) {
            return false;
        }
    }

    true
}

/// Reduces a word to its comparison form: every non-alphanumeric character is
/// removed and the remaining characters are lowercased.
fn normalize_word(word: &str) -> String {
    let mut folded = String::with_capacity(word.len());
    for character in word.chars() {
        if character.is_alphanumeric() {
            // Lowercasing one character may produce several characters.
            folded.extend(character.to_lowercase());
        }
    }
    folded
}

// Unit tests for the windowing and transcript-stitching behaviour.
#[cfg(test)]
mod tests {
    use super::{StreamConfig, StreamEngine};

    /// Nine samples with a 6-sample window and 2-sample overlap (step 4) give
    /// one full window plus a partial tail.
    #[test]
    fn emits_full_and_partial_audio_windows() {
        let config = StreamConfig {
            sample_rate: 16_000,
            window_samples: 6,
            overlap_samples: 2,
            max_dedup_words: 24,
            hold_words: 3,
        };
        let mut engine = StreamEngine::new(config);

        let full = engine.push_audio(&[0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]);
        let tail = engine.finish_audio().expect("tail chunk");

        assert_eq!(full.len(), 1);
        assert_eq!(full[0].samples, vec![0.0, 1.0, 2.0, 3.0, 4.0, 5.0]);
        assert_eq!(full[0].start_sample, 0);
        assert_eq!(full[0].end_sample, 6);
        // After the first window the buffer advanced by one step (4 samples),
        // so the tail starts at sample 4 and repeats the 2 overlap samples.
        assert_eq!(tail.samples, vec![4.0, 5.0, 6.0, 7.0, 8.0]);
        assert_eq!(tail.start_sample, 4);
        assert!(tail.is_partial);
    }

    /// Repeated words at a window boundary are emitted once, and the last
    /// `hold_words` words are withheld until `finish_text`.
    #[test]
    fn stabilizes_transcript_overlap_and_holds_tail_words() {
        let config = StreamConfig {
            sample_rate: 16_000,
            window_samples: 6,
            overlap_samples: 2,
            max_dedup_words: 24,
            hold_words: 2,
        };
        let mut engine = StreamEngine::new(config);

        // Four new words, two held back: the first two are released.
        assert_eq!(
            engine.stabilize_text("hello brave new world"),
            Some("hello brave".to_string())
        );
        // "new world" repeats the previous tail and is trimmed; "again there"
        // joins the queue, which releases the previously held "new world".
        assert_eq!(
            engine.stabilize_text("new world again there"),
            Some("new world".to_string())
        );
        assert_eq!(engine.finish_text(), Some("again there".to_string()));
    }

    /// With a 16 000-sample window the overlap is capped at half the window
    /// (8 000), so the step is 8 000 and the count is `ceil(total / 8 000)`.
    #[test]
    fn computes_window_count_for_incremental_stream() {
        let config = StreamConfig::for_model(16_000, 16_000).expect("config");

        assert_eq!(config.window_count(0), 0);
        assert_eq!(config.window_count(16_000), 2);
        assert_eq!(config.window_count(24_000), 3);
    }

    /// A 30-second model window at 16 kHz is reduced to the 10-second realtime
    /// window; the 2-second overlap (32 000 samples) fits under half of it.
    #[test]
    fn caps_realtime_window_to_ten_seconds() {
        let config = StreamConfig::for_realtime(16_000, 480_000).expect("config");

        assert_eq!(config.window_samples, 160_000);
        assert_eq!(config.overlap_samples, 32_000);
    }
}