use std::io::{self, IsTerminal, Write};
use std::path::Path;
use std::sync::{
Arc, Mutex,
atomic::{AtomicBool, Ordering},
};
use std::thread;
use std::time::Duration;
use anyhow::{Context, Result, bail};
use ct2rs::{ComputeType, Config, Device};
use indicatif::{ProgressBar, ProgressStyle};
use crate::audio::PreparedAudio;
use crate::model::ModelChoice;
use crate::whisper::{Whisper, WhisperOptions};
/// Resolved CTranslate2 execution settings derived from the CLI flags.
///
/// Built with [`ExecutionMode::from_cli`]. A GPU request that cannot be served
/// because CUDA reports no devices is represented as a CPU mode that carries a
/// warning message.
#[derive(Clone, Debug)]
pub struct ExecutionMode {
    device: Device,
    compute_type: ComputeType,
    /// CUDA device index; `Some` only for CUDA execution.
    gpu_device: Option<i32>,
    /// Set only for the CPU fallback taken when a GPU was requested but CUDA
    /// reported zero visible devices.
    warning: Option<String>,
}

impl ExecutionMode {
    /// Chooses the device and compute type for the given CLI flags.
    ///
    /// Without `use_gpu` this is always CPU with INT8 weights. With `use_gpu`,
    /// the `gpu` cargo feature must be enabled. If CUDA reports no devices, the
    /// mode falls back to CPU and records a warning. Otherwise it uses CUDA
    /// device `gpu_device` with `ComputeType::AUTO`.
    ///
    /// # Errors
    /// Fails if `use_gpu` is set in a build without the `gpu` feature, or if
    /// `gpu_device` is not in `0..visible_devices`.
    pub fn from_cli(use_gpu: bool, gpu_device: i32) -> Result<Self> {
        if !use_gpu {
            return Ok(Self::cpu(None));
        }
        Self::for_gpu_request(gpu_device)
    }

    /// CPU execution with INT8 weights, optionally tagged with a fallback warning.
    fn cpu(warning: Option<String>) -> Self {
        Self {
            device: Device::CPU,
            compute_type: ComputeType::INT8,
            gpu_device: None,
            warning,
        }
    }

    /// Handles `--gpu` in builds that include CUDA support.
    #[cfg(feature = "gpu")]
    fn for_gpu_request(gpu_device: i32) -> Result<Self> {
        let visible_devices = ct2rs::sys::get_device_count(Device::CUDA);
        if visible_devices <= 0 {
            return Ok(Self::cpu(Some(
                "CUDA backend is not available in this environment; falling back to CPU"
                    .to_string(),
            )));
        }
        if !(0..visible_devices).contains(&gpu_device) {
            bail!(
                "CUDA device {} is out of range; visible devices: 0..{}",
                gpu_device,
                visible_devices - 1
            );
        }
        Ok(Self {
            device: Device::CUDA,
            compute_type: ComputeType::AUTO,
            gpu_device: Some(gpu_device),
            warning: None,
        })
    }

    /// Handles `--gpu` in builds compiled without the `gpu` feature: always an error.
    #[cfg(not(feature = "gpu"))]
    fn for_gpu_request(_gpu_device: i32) -> Result<Self> {
        bail!("--gpu not supported");
    }

    /// Short lowercase name of the selected device, for display.
    pub fn device_label(&self) -> &'static str {
        match self.device {
            Device::CPU => "cpu",
            Device::CUDA => "cuda",
            _ => "unknown",
        }
    }

    /// Lowercase CTranslate2 name of the selected compute type, for display.
    pub fn compute_type_label(&self) -> &'static str {
        match self.compute_type {
            ComputeType::DEFAULT => "default",
            ComputeType::AUTO => "auto",
            ComputeType::FLOAT32 => "float32",
            ComputeType::FLOAT16 => "float16",
            ComputeType::BFLOAT16 => "bfloat16",
            ComputeType::INT8 => "int8",
            ComputeType::INT8_FLOAT32 => "int8_float32",
            ComputeType::INT8_FLOAT16 => "int8_float16",
            ComputeType::INT8_BFLOAT16 => "int8_bfloat16",
            ComputeType::INT16 => "int16",
            _ => "unknown",
        }
    }

    /// The CPU-fallback warning, if one was recorded.
    pub fn warning(&self) -> Option<&str> {
        self.warning.as_deref()
    }

    /// True when a GPU was requested but the mode fell back to CPU; this is
    /// the only situation in which a warning is recorded.
    pub fn gpu_requested_but_unavailable(&self) -> bool {
        self.warning.is_some()
    }
}
/// Frames cycled by the interactive streaming spinner, one per redraw tick.
const STREAM_SPINNER_FRAMES: &[&str] = &["|", "/", "-", r"\"];
/// Seconds of audio shared by consecutive streaming windows. The streaming
/// loop additionally caps the overlap at half a window.
const STREAM_OVERLAP_SECONDS: usize = 5;
/// Number of most recent words withheld from the streaming display until more
/// text arrives. Words still held when the audio ends are flushed.
const STREAM_HOLD_WORDS: usize = 3;
/// Upper bound on the number of words kept in the dedup tail and compared when
/// trimming re-transcribed overlap.
const MAX_DEDUP_WORDS: usize = 24;
pub async fn run_transcription(
audio: &PreparedAudio,
model_choice: ModelChoice,
model_dir: &Path,
_models_root: &Path,
execution: &ExecutionMode,
stream: bool,
) -> Result<()> {
let whisper = load_whisper_model(model_choice, model_dir, execution)?;
let options = WhisperOptions {
beam_size: 5,
..Default::default()
};
println!();
println!("Audio parameters");
println!(
"backend: CTranslate2 / {} ({})",
execution.device_label(),
execution.compute_type_label()
);
if let Some(gpu_device) = execution.gpu_device {
println!("gpu id : {gpu_device}");
}
println!("source : {}", audio.display_name);
if let Some(duration) = audio.metadata.duration {
println!("length : {}", format_duration(duration.as_secs_f64()));
}
if let Some(source_rate) = audio.metadata.source_sample_rate {
println!("input rate : {source_rate} Hz");
}
println!("model rate : {} Hz", audio.metadata.target_sample_rate);
if let Some(channels) = audio.metadata.channels {
println!("channels : {channels} -> mono");
}
println!("codec : {}", audio.metadata.codec);
println!("model path : {}", model_dir.display());
println!();
if whisper.sampling_rate() != audio.metadata.target_sample_rate as usize {
anyhow::bail!(
"audio was resampled to {} Hz but the model expects {} Hz",
audio.metadata.target_sample_rate,
whisper.sampling_rate()
);
}
if stream {
stream_transcription(&whisper, audio, &options)
} else {
println!("{}", "=".repeat(72));
let processing = ProgressBar::new_spinner();
processing.set_style(
ProgressStyle::with_template(" processing {spinner:.green} {msg}")
.context("failed to configure processing spinner")?,
);
processing.enable_steady_tick(Duration::from_millis(80));
processing.set_message(audio.display_name.clone());
let lines = whisper
.generate(&audio.samples, None, false, &options)
.context("CTranslate2 transcription failed")?;
processing.finish_with_message("processing complete");
print_transcript_divider();
for line in lines {
let line = line.trim();
if !line.is_empty() {
println!("{line}");
}
}
Ok(())
}
}
/// Loads the CTranslate2 Whisper model from `model_dir` with a spinner on
/// screen, configured for the device in `execution`.
///
/// `model_choice` is used only to name the model in the error message.
///
/// # Errors
/// Fails if the spinner template is invalid or the model cannot be initialised.
fn load_whisper_model(
    model_choice: ModelChoice,
    model_dir: &Path,
    execution: &ExecutionMode,
) -> Result<Whisper> {
    let spinner = ProgressBar::new_spinner();
    let style = ProgressStyle::with_template(" loading model {spinner:.green} {msg}")
        .context("failed to configure model loading spinner")?;
    spinner.set_style(style);
    spinner.enable_steady_tick(Duration::from_millis(80));
    spinner.set_message(model_dir.display().to_string());

    let config = ctranslate2_config(execution);
    let loaded = Whisper::new(model_dir, config).with_context(|| {
        format!(
            "failed to initialize CTranslate2 model `{}` from `{}`",
            model_choice.cli_name(),
            model_dir.display()
        )
    })?;

    spinner.finish_with_message("model loaded");
    Ok(loaded)
}
/// Transcribes `audio` in overlapping windows and displays text as it arrives.
///
/// Each window is `whisper.n_samples()` samples long. Consecutive windows
/// share `STREAM_OVERLAP_SECONDS` of audio, capped at half a window. Text
/// from each window is trimmed against the recently seen words to drop the
/// re-transcribed overlap. The newest `STREAM_HOLD_WORDS` words are held back
/// until more text arrives and are flushed when the audio ends.
///
/// # Errors
/// Fails if any window fails to transcribe or the display cannot be written.
fn stream_transcription(
    whisper: &Whisper,
    audio: &PreparedAudio,
    options: &WhisperOptions,
) -> Result<()> {
    let total_samples = audio.samples.len();
    let sample_rate = audio.metadata.target_sample_rate as usize;
    let window_samples = whisper.n_samples();
    let overlap_samples = (sample_rate * STREAM_OVERLAP_SECONDS).min(window_samples / 2);
    let step_size = window_samples.saturating_sub(overlap_samples).max(1);
    // Stop at the first window that reaches the end of the audio. Any later
    // start would fall inside that window's overlap and only re-transcribe
    // audio already covered, which wastes a decode and duplicates text whenever
    // the prefix dedup below does not recognise the repeat.
    let window_count = if total_samples == 0 {
        0
    } else if total_samples <= window_samples {
        1
    } else {
        1 + (total_samples - window_samples).div_ceil(step_size)
    };
    let overlap_seconds = overlap_samples / sample_rate;
    let status_for = |index: usize| {
        format!(
            "chunk {}/{} / stream mode / {}s overlap",
            index + 1,
            window_count,
            overlap_seconds
        )
    };

    let mut display = StreamDisplay::new(status_for(0))?;
    // Rolling tail of recently seen words, used to trim re-transcribed overlap.
    let mut seen_tail = String::new();
    // Newest words, withheld from display until more text arrives.
    let mut held_words: Vec<String> = Vec::new();

    for index in 0..window_count {
        display.set_status(status_for(index))?;
        let start = index * step_size;
        let end = (start + window_samples).min(total_samples);
        let chunk = &audio.samples[start..end];
        let raw_text = whisper
            .generate(chunk, None, false, options)
            .context("streaming chunk transcription failed")?
            .into_iter()
            .map(|line| line.trim().to_string())
            .filter(|line| !line.is_empty())
            .collect::<Vec<_>>()
            .join(" ");
        let text = trim_stream_overlap(&seen_tail, &raw_text);
        if text.is_empty() {
            continue;
        }
        held_words.extend(text.split_whitespace().map(str::to_string));
        seen_tail = merge_stream_tail(&seen_tail, &text);
        let emit_now = held_words.len().saturating_sub(STREAM_HOLD_WORDS);
        if emit_now > 0 {
            let stable_text = held_words[..emit_now].join(" ");
            display.push_text(&stable_text)?;
            held_words.drain(..emit_now);
        }
    }

    if !held_words.is_empty() {
        let final_text = held_words.join(" ");
        display.push_text(&final_text)?;
    }
    display.finish("processing complete")
}
/// Formats a duration in seconds as `MM:SS.cc` (minutes, seconds, hundredths),
/// rounding to the nearest millisecond first. Minutes are not wrapped into
/// hours, and negative input saturates to zero via the `u64` cast.
fn format_duration(seconds: f64) -> String {
    let millis = (seconds * 1000.0).round() as u64;
    let (whole_minutes, within_minute) = (millis / 60_000, millis % 60_000);
    let whole_seconds = within_minute / 1000;
    let hundredths = (within_minute % 1000) / 10;
    format!("{:02}:{:02}.{:02}", whole_minutes, whole_seconds, hundredths)
}
/// Prints a 72-column `=` rule followed by two blank lines.
fn print_transcript_divider() {
    let rule = "=".repeat(72);
    println!("{rule}\n\n");
}
/// Removes the leading words of `current` that repeat the end of `previous_tail`.
///
/// Looks for the longest `k` (at most `MAX_DEDUP_WORDS`) such that the last
/// `k` words of `previous_tail` match the first `k` words of `current`
/// according to `words_match`, which ignores case and non-alphanumeric
/// characters. It then returns the remaining words of `current` joined by
/// single spaces. If no overlap is found, `current` is returned trimmed. An
/// empty or whitespace-only `current` yields an empty string.
fn trim_stream_overlap(previous_tail: &str, current: &str) -> String {
    let current_words = current.split_whitespace().collect::<Vec<_>>();
    if current_words.is_empty() {
        return String::new();
    }
    let previous_words = previous_tail.split_whitespace().collect::<Vec<_>>();
    let max_overlap = previous_words
        .len()
        .min(current_words.len())
        .min(MAX_DEDUP_WORDS);
    // Try the longest overlap first so a short accidental match (e.g. a
    // repeated "the") cannot leave a longer repeated phrase in place.
    for overlap in (1..=max_overlap).rev() {
        let previous_slice = &previous_words[previous_words.len() - overlap..];
        let current_slice = &current_words[..overlap];
        if words_match(previous_slice, current_slice) {
            return current_words[overlap..].join(" ");
        }
    }
    current.trim().to_string()
}
fn merge_stream_tail(previous_tail: &str, current: &str) -> String {
let mut words = previous_tail
.split_whitespace()
.chain(current.split_whitespace())
.collect::<Vec<_>>();
if words.len() > MAX_DEDUP_WORDS {
words = words.split_off(words.len() - MAX_DEDUP_WORDS);
}
words.join(" ")
}
/// Returns true when both slices have the same length and every pair of
/// words is equal after `normalize_word`.
fn words_match(previous: &[&str], current: &[&str]) -> bool {
    if previous.len() != current.len() {
        return false;
    }
    previous
        .iter()
        .zip(current)
        .all(|(ours, theirs)| normalize_word(ours) == normalize_word(theirs))
}
/// Drops every non-alphanumeric character from `word` and lowercases the rest
/// one character at a time (Unicode-aware), so `"Hello,"` and `"hello"`
/// normalise to the same string.
fn normalize_word(word: &str) -> String {
    let mut folded = String::with_capacity(word.len());
    for ch in word.chars().filter(|c| c.is_alphanumeric()) {
        folded.extend(ch.to_lowercase());
    }
    folded
}
fn ctranslate2_config(execution: &ExecutionMode) -> Config {
let threads = std::thread::available_parallelism()
.map(|parallelism| parallelism.get())
.unwrap_or(4);
Config {
device: execution.device,
compute_type: execution.compute_type,
device_indices: execution
.gpu_device
.map(|index| vec![index])
.unwrap_or_else(|| vec![0]),
num_threads_per_replica: threads,
max_queued_batches: 2,
..Default::default()
}
}
/// Output sink for streaming transcription. It picks a rendering strategy once,
/// at construction, based on whether stdout is a terminal.
///
/// * Terminal: delegates to `InteractiveStreamDisplay`, which repaints a
///   spinner/status header above the accumulated transcript.
/// * Otherwise (pipe, file): prints a static header once, then appends each
///   chunk of text to stdout, separated by single spaces.
struct StreamDisplay {
    /// `Some` only when stdout was a terminal at construction time.
    inner: Option<InteractiveStreamDisplay>,
    /// Whether any non-blank text has been pushed yet. In plain mode this
    /// decides whether the next chunk needs a leading space.
    wrote_text: bool,
}

impl StreamDisplay {
    /// Creates the display and shows `status` immediately.
    fn new(status: String) -> Result<Self> {
        let inner = if io::stdout().is_terminal() {
            Some(InteractiveStreamDisplay::new(status)?)
        } else {
            print_transcript_divider();
            println!(" processing {status}");
            println!();
            None
        };
        Ok(Self {
            inner,
            wrote_text: false,
        })
    }

    /// Updates the status line. This is a no-op in plain (non-terminal) mode.
    fn set_status(&mut self, status: String) -> Result<()> {
        match self.inner.as_mut() {
            Some(interactive) => interactive.set_status(status),
            None => Ok(()),
        }
    }

    /// Appends `text` to the transcript output.
    fn push_text(&mut self, text: &str) -> Result<()> {
        match self.inner.as_mut() {
            Some(interactive) => interactive.push_text(text)?,
            None => {
                let chunk = render_stream_chunk(text, self.wrote_text).unwrap_or_default();
                let mut out = io::stdout().lock();
                write!(out, "{chunk}").context("failed to write stream text")?;
                out.flush().context("failed to flush stream text")?;
            }
        }
        self.wrote_text |= !text.trim().is_empty();
        Ok(())
    }

    /// Ends the display. Interactive mode shows `final_status`; plain mode
    /// only terminates the transcript line.
    fn finish(self, final_status: &str) -> Result<()> {
        match self.inner {
            Some(interactive) => interactive.finish(final_status),
            None => {
                println!();
                Ok(())
            }
        }
    }
}
/// Terminal renderer used in streaming mode when stdout is a TTY.
///
/// A background thread advances the spinner every 80 ms while the caller
/// updates the status line and appends transcript text. Both sides repaint
/// through the shared `StreamDisplayState` while holding its mutex, so their
/// redraws never interleave.
///
/// The spinner thread is stopped and joined by `finish`. If `finish` is never
/// reached (e.g. an error propagates out of the streaming loop), `Drop` stops
/// it instead, so it cannot keep repainting over later output.
struct InteractiveStreamDisplay {
    state: Arc<Mutex<StreamDisplayState>>,
    /// Cleared to ask the spinner thread to leave its loop.
    running: Arc<AtomicBool>,
    /// Spinner thread handle; `None` once the thread has been stopped and joined.
    spinner_thread: Option<thread::JoinHandle<()>>,
}

impl InteractiveStreamDisplay {
    /// Draws the first frame with `initial_status` and starts the spinner thread.
    fn new(initial_status: String) -> Result<Self> {
        let state = Arc::new(Mutex::new(StreamDisplayState::new(initial_status)));
        {
            let mut guard = state
                .lock()
                .map_err(|_| anyhow::anyhow!("failed to lock stream display state"))?;
            guard.redraw()?;
        }
        let running = Arc::new(AtomicBool::new(true));
        let state_for_thread = Arc::clone(&state);
        let running_for_thread = Arc::clone(&running);
        let spinner_thread = thread::spawn(move || {
            let mut frame_index = 0usize;
            while running_for_thread.load(Ordering::Relaxed) {
                // A poisoned lock just skips this frame. Draw errors are
                // ignored because this thread has nowhere to report them.
                if let Ok(mut guard) = state_for_thread.lock() {
                    guard.spinner = STREAM_SPINNER_FRAMES
                        [frame_index % STREAM_SPINNER_FRAMES.len()]
                    .to_string();
                    let _ = guard.redraw();
                }
                frame_index = frame_index.wrapping_add(1);
                thread::sleep(Duration::from_millis(80));
            }
        });
        Ok(Self {
            state,
            running,
            spinner_thread: Some(spinner_thread),
        })
    }

    /// Replaces the status line and repaints.
    fn set_status(&mut self, status: String) -> Result<()> {
        let mut guard = self.lock_state()?;
        guard.status = status;
        guard.redraw()
    }

    /// Appends `text` to the transcript (space-separated) and repaints.
    fn push_text(&mut self, text: &str) -> Result<()> {
        let mut guard = self.lock_state()?;
        if !guard.transcript.is_empty() {
            guard.transcript.push(' ');
        }
        guard.transcript.push_str(text);
        guard.redraw()
    }

    /// Stops the spinner, paints a final frame showing `final_status` without
    /// a spinner glyph, and ends with a blank line.
    fn finish(mut self, final_status: &str) -> Result<()> {
        self.stop_spinner();
        let mut guard = self.lock_state()?;
        guard.spinner.clear();
        guard.status = final_status.to_string();
        guard.redraw()?;
        println!();
        Ok(())
    }

    /// Locks the shared state, mapping mutex poisoning to an error.
    fn lock_state(&self) -> Result<std::sync::MutexGuard<'_, StreamDisplayState>> {
        self.state
            .lock()
            .map_err(|_| anyhow::anyhow!("failed to lock stream display state"))
    }

    /// Signals the spinner thread to exit and waits for it. Calling this more
    /// than once is harmless.
    fn stop_spinner(&mut self) {
        self.running.store(false, Ordering::Relaxed);
        if let Some(handle) = self.spinner_thread.take() {
            let _ = handle.join();
        }
    }
}

impl Drop for InteractiveStreamDisplay {
    /// Ensures the spinner thread stops even when `finish` was never called.
    fn drop(&mut self) {
        self.stop_spinner();
    }
}
/// Everything the interactive display needs to repaint itself. It also records
/// how many lines the previous repaint wrote, so that frame can be erased first.
struct StreamDisplayState {
    status: String,
    spinner: String,
    transcript: String,
    /// Wrap width: `$COLUMNS` when it parses to at least 40, otherwise 100.
    width: usize,
    /// Lines written by the last `redraw`; 0 before the first one.
    rendered_lines: usize,
}

impl StreamDisplayState {
    /// Fresh state showing `status`, the first spinner frame, and no transcript.
    fn new(status: String) -> Self {
        let columns = std::env::var("COLUMNS")
            .ok()
            .and_then(|value| value.parse::<usize>().ok());
        let width = match columns {
            Some(columns) if columns >= 40 => columns,
            _ => 100,
        };
        Self {
            status,
            spinner: STREAM_SPINNER_FRAMES[0].to_string(),
            transcript: String::new(),
            width,
            rendered_lines: 0,
        }
    }

    /// Erases the previously drawn frame (if any) and writes the current one.
    fn redraw(&mut self) -> Result<()> {
        let frame = self.render_lines();
        let mut out = io::stdout().lock();
        if self.rendered_lines != 0 {
            // CSI n F: move to the start of the line n rows up; CSI J: clear to end of screen.
            write!(out, "\x1b[{}F\x1b[J", self.rendered_lines)
                .context("failed to rewind stream display")?;
        }
        for row in frame.iter() {
            writeln!(out, "{row}").context("failed to render stream display")?;
        }
        out.flush().context("failed to flush stream display")?;
        self.rendered_lines = frame.len();
        Ok(())
    }

    /// Builds the frame: rule, status line, rule, two blank lines, then the
    /// transcript wrapped to one column less than `width` (minimum 40).
    fn render_lines(&self) -> Vec<String> {
        let rule = "=".repeat(72);
        let wrap_width = self.width.saturating_sub(1).max(40);
        let mut frame = Vec::new();
        frame.push(rule.clone());
        frame.push(format!(" processing {} {}", self.spinner, self.status));
        frame.push(rule);
        frame.push(String::new());
        frame.push(String::new());
        frame.extend(wrap_text(&self.transcript, wrap_width));
        frame
    }
}
/// Greedily word-wraps `text` into lines of at most `width` bytes.
///
/// Runs of whitespace collapse to single spaces, and empty or whitespace-only
/// input yields no lines. Length is measured in bytes, which for printable
/// UTF-8 text is generally no smaller than the terminal columns it occupies.
/// A single word longer than `width` (e.g. unspaced CJK output or a long URL)
/// is hard-broken at char boundaries. When `width` is narrower than one
/// character, that character is emitted alone so wrapping always makes progress.
fn wrap_text(text: &str, width: usize) -> Vec<String> {
    let mut lines = Vec::new();
    let mut current = String::new();
    for word in text.split_whitespace() {
        if !current.is_empty() && current.len() + 1 + word.len() <= width {
            current.push(' ');
            current.push_str(word);
            continue;
        }
        if !current.is_empty() {
            lines.push(std::mem::take(&mut current));
        }
        // The word starts a new line. Break it if it alone is too wide: an
        // over-long line would be wrapped again by the terminal, and a caller
        // that rewinds by counting returned lines would then under-shoot.
        let mut rest = word;
        while rest.len() > width {
            let mut cut = width;
            while !rest.is_char_boundary(cut) {
                cut -= 1;
            }
            if cut == 0 {
                cut = rest.chars().next().map_or(rest.len(), char::len_utf8);
            }
            lines.push(rest[..cut].to_string());
            rest = &rest[cut..];
        }
        current.push_str(rest);
    }
    if !current.is_empty() {
        lines.push(current);
    }
    lines
}
/// Prepares a chunk of plain-mode stream output.
///
/// Returns `None` when `text` is blank. Otherwise returns `text` trimmed,
/// prefixed with a single space when `needs_leading_space` is set.
fn render_stream_chunk(text: &str, needs_leading_space: bool) -> Option<String> {
    let body = text.trim();
    (!body.is_empty()).then(|| {
        let separator = if needs_leading_space { " " } else { "" };
        format!("{separator}{body}")
    })
}
#[cfg(test)]
mod tests {
    use super::{merge_stream_tail, render_stream_chunk, trim_stream_overlap};

    /// The first plain-mode chunk has no separator; later chunks get one leading space.
    #[test]
    fn renders_stream_chunks_with_separator_after_first_chunk() {
        let first = render_stream_chunk("hello world", false);
        let later = render_stream_chunk("again there", true);
        assert_eq!(first.as_deref(), Some("hello world"));
        assert_eq!(later.as_deref(), Some(" again there"));
    }

    /// Words repeated from the end of the previous tail are dropped.
    #[test]
    fn trims_repeated_overlap_prefix() {
        let trimmed = trim_stream_overlap("hello brave new world", "brave new world again");
        assert_eq!(trimmed, "again");
    }

    /// 12 + 13 = 25 words go in; only the newest 24 survive.
    #[test]
    fn keeps_recent_tail_words_only() {
        let earlier = "one two three four five six seven eight nine ten eleven twelve";
        let later = "thirteen fourteen fifteen sixteen seventeen eighteen nineteen twenty twentyone twentytwo twentythree twentyfour twentyfive";
        let expected = "two three four five six seven eight nine ten eleven twelve thirteen fourteen fifteen sixteen seventeen eighteen nineteen twenty twentyone twentytwo twentythree twentyfour twentyfive";
        assert_eq!(merge_stream_tail(earlier, later), expected);
    }
}