use std::io::{self, Write};
use std::path::Path;
use std::time::Duration;
use anyhow::{Context, Result, bail};
use ct2rs::{ComputeType, Config, Device};
use indicatif::{MultiProgress, ProgressBar, ProgressStyle};
use crate::audio::PreparedAudio;
use crate::model::ModelChoice;
use crate::whisper::{Whisper, WhisperOptions};
/// Resolved execution target for CTranslate2 inference.
///
/// Built from CLI flags via [`ExecutionMode::from_cli`]; records the chosen
/// device and compute type, plus an optional warning when a GPU request had
/// to fall back to CPU.
#[derive(Clone, Debug)]
pub struct ExecutionMode {
    // Backend device (CPU or CUDA).
    device: Device,
    // Precision/quantization handed to CTranslate2 (INT8 on CPU, AUTO on CUDA).
    compute_type: ComputeType,
    // CUDA device index when running on GPU; `None` on CPU.
    gpu_device: Option<i32>,
    // Human-readable note set when --gpu was requested but unavailable.
    warning: Option<String>,
}
impl ExecutionMode {
    /// Resolve the execution mode from CLI flags.
    ///
    /// With the `cuda` feature compiled in and `use_gpu` set, this validates
    /// `gpu_device` against the number of visible CUDA devices and selects
    /// CUDA with AUTO precision; when no devices are visible it falls back
    /// to CPU/INT8 and records a warning instead of failing. Without the
    /// `cuda` feature, `--gpu` is a hard error. The default (no `use_gpu`)
    /// is CPU with INT8.
    ///
    /// # Errors
    /// Fails when `gpu_device` is out of range, or when GPU support is not
    /// compiled into this binary.
    pub fn from_cli(use_gpu: bool, gpu_device: i32) -> Result<Self> {
        if use_gpu {
            #[cfg(feature = "cuda")]
            {
                let visible_devices = ct2rs::sys::get_device_count(Device::CUDA);
                if visible_devices <= 0 {
                    // GPU requested but none visible: degrade gracefully to
                    // CPU and surface the reason via `warning()`.
                    return Ok(Self {
                        device: Device::CPU,
                        compute_type: ComputeType::INT8,
                        gpu_device: None,
                        warning: Some(
                            "CUDA backend is not available in this environment; falling back to CPU"
                                .to_string(),
                        ),
                    });
                }
                if gpu_device < 0 || gpu_device >= visible_devices {
                    bail!(
                        "CUDA device {} is out of range; visible devices: 0..{}",
                        gpu_device,
                        visible_devices - 1
                    );
                }
                return Ok(Self {
                    device: Device::CUDA,
                    // Let CTranslate2 choose the best precision for the GPU.
                    compute_type: ComputeType::AUTO,
                    gpu_device: Some(gpu_device),
                    warning: None,
                });
            }
            #[cfg(not(feature = "cuda"))]
            {
                // Silence the unused-parameter warning in non-CUDA builds.
                let _ = gpu_device;
                bail!("--gpu not supported");
            }
        }
        // Default: CPU inference with INT8 quantization.
        Ok(Self {
            device: Device::CPU,
            compute_type: ComputeType::INT8,
            gpu_device: None,
            warning: None,
        })
    }

    /// Short lowercase device name for display purposes.
    pub fn device_label(&self) -> &'static str {
        match self.device {
            Device::CPU => "cpu",
            Device::CUDA => "cuda",
            _ => "unknown",
        }
    }

    /// Short lowercase compute-type name for display purposes.
    pub fn compute_type_label(&self) -> &'static str {
        match self.compute_type {
            ComputeType::DEFAULT => "default",
            ComputeType::AUTO => "auto",
            ComputeType::FLOAT32 => "float32",
            ComputeType::INT8 => "int8",
            ComputeType::INT8_FLOAT32 => "int8_float32",
            ComputeType::INT8_FLOAT16 => "int8_float16",
            ComputeType::INT8_BFLOAT16 => "int8_bfloat16",
            ComputeType::INT16 => "int16",
            ComputeType::FLOAT16 => "float16",
            ComputeType::BFLOAT16 => "bfloat16",
            _ => "unknown",
        }
    }

    /// Warning produced during mode resolution (e.g. GPU fallback), if any.
    pub fn warning(&self) -> Option<&str> {
        self.warning.as_deref()
    }

    /// True when `--gpu` was requested but the run fell back to CPU.
    pub fn gpu_requested_but_unavailable(&self) -> bool {
        // A warning is only ever set on the GPU-to-CPU fallback path, so its
        // presence doubles as the fallback flag.
        self.warning.is_some()
    }
}
/// Window granularity used when streaming partial transcripts.
#[derive(Clone, Copy, Debug)]
pub enum StreamMode {
    // One full model window per decode step.
    Chunk,
    // Shorter windows for lower-latency partial output.
    SubChunk,
}
// Seconds of audio shared between consecutive full-size stream windows.
const STREAM_OVERLAP_SECONDS: usize = 5;
// Window length (seconds) in sub-chunk streaming mode.
const STREAM_SUBCHUNK_SECONDS: usize = 8;
// Seconds of audio shared between consecutive sub-chunk windows.
const STREAM_SUBCHUNK_OVERLAP_SECONDS: usize = 4;
// Maximum trailing-word overlap considered when deduplicating chunk joins.
const MAX_DEDUP_WORDS: usize = 24;
// Words withheld from output until a later chunk confirms them.
const STREAM_HOLD_WORDS: usize = 3;
// Spinner refresh interval (ms) for the streaming progress bars.
const STREAM_ACTIVITY_TICK_MS: u64 = 80;
pub async fn run_transcription(
audio: &PreparedAudio,
model_choice: ModelChoice,
model_dir: &Path,
_models_root: &Path,
execution: &ExecutionMode,
stream_mode: Option<StreamMode>,
) -> Result<()> {
let whisper = load_whisper_model(model_choice, model_dir, execution)?;
let options = WhisperOptions {
beam_size: 5,
..Default::default()
};
println!(
"backend: CTranslate2 / {} ({})",
execution.device_label(),
execution.compute_type_label()
);
if let Some(gpu_device) = execution.gpu_device {
println!("gpu id : {gpu_device}");
}
println!("source : {}", audio.display_name);
if let Some(duration) = audio.metadata.duration {
println!("length : {}", format_duration(duration.as_secs_f64()));
}
if let Some(source_rate) = audio.metadata.source_sample_rate {
println!("input rate : {source_rate} Hz");
}
println!("model rate : {} Hz", audio.metadata.target_sample_rate);
if let Some(channels) = audio.metadata.channels {
println!("channels : {channels} -> mono");
}
println!("codec : {}", audio.metadata.codec);
println!("model path : {}", model_dir.display());
println!();
if whisper.sampling_rate() != audio.metadata.target_sample_rate as usize {
anyhow::bail!(
"audio was resampled to {} Hz but the model expects {} Hz",
audio.metadata.target_sample_rate,
whisper.sampling_rate()
);
}
if let Some(stream_mode) = stream_mode {
stream_transcription(&whisper, audio, &options, stream_mode)
} else {
let progress = ProgressBar::new_spinner();
progress.set_style(
ProgressStyle::with_template(" transcribing {spinner:.green} {msg}")
.context("failed to configure transcription spinner")?,
);
progress.enable_steady_tick(Duration::from_millis(80));
progress.set_message(audio.display_name.clone());
let lines = whisper
.generate(&audio.samples, None, false, &options)
.context("CTranslate2 transcription failed")?;
progress.finish_with_message("transcription complete");
for line in lines {
let line = line.trim();
if !line.is_empty() {
println!("{line}");
}
}
Ok(())
}
}
/// Initialize the CTranslate2 Whisper model from `model_dir`, showing a
/// spinner (with the model path as its message) while loading runs.
///
/// # Errors
/// Fails when the spinner template is invalid or model initialization fails;
/// the error names both the chosen model and its directory.
fn load_whisper_model(
    model_choice: ModelChoice,
    model_dir: &Path,
    execution: &ExecutionMode,
) -> Result<Whisper> {
    let spinner = ProgressBar::new_spinner();
    let style = ProgressStyle::with_template(" loading model {spinner:.green} {msg}")
        .context("failed to configure model loading spinner")?;
    spinner.set_style(style);
    spinner.enable_steady_tick(Duration::from_millis(80));
    spinner.set_message(model_dir.display().to_string());
    let config = ctranslate2_config(execution);
    let model = Whisper::new(model_dir, config).with_context(|| {
        format!(
            "failed to initialize CTranslate2 model `{}` from `{}`",
            model_choice.cli_name(),
            model_dir.display()
        )
    })?;
    spinner.finish_with_message("model loaded");
    Ok(model)
}
/// Stream the transcript to stdout window by window instead of decoding the
/// whole file at once.
///
/// The audio is sliced into overlapping windows (sized by `stream_mode`),
/// each window is decoded independently, words duplicated by the overlap are
/// trimmed against the previous windows' tail, and the last
/// `STREAM_HOLD_WORDS` words are withheld until a later window confirms
/// them. Two spinners show which window is decoding and which is queued.
///
/// # Errors
/// Fails when a spinner template is invalid, a chunk fails to decode, or
/// stdout cannot be written.
fn stream_transcription(
    whisper: &Whisper,
    audio: &PreparedAudio,
    options: &WhisperOptions,
    stream_mode: StreamMode,
) -> Result<()> {
    // The model's native window size (in samples) caps every window below.
    let chunk_size = whisper.n_samples();
    let sample_rate = audio.metadata.target_sample_rate as usize;
    let (window_samples, overlap_samples) = match stream_mode {
        StreamMode::Chunk => (
            chunk_size,
            // Cap overlap at half a window so windows always advance.
            (sample_rate * STREAM_OVERLAP_SECONDS).min(chunk_size / 2),
        ),
        StreamMode::SubChunk => {
            let subchunk_samples = (sample_rate * STREAM_SUBCHUNK_SECONDS).min(chunk_size);
            let overlap = (sample_rate * STREAM_SUBCHUNK_OVERLAP_SECONDS).min(subchunk_samples / 2);
            (subchunk_samples, overlap)
        }
    };
    // Each new window starts `step_size` samples after the previous one;
    // `.max(1)` guards against a zero step when overlap == window.
    let step_size = window_samples.saturating_sub(overlap_samples).max(1);
    let window_count = audio.samples.len().div_ceil(step_size);
    let multi = MultiProgress::new();
    multi.set_move_cursor(true);
    // Spinner for the window currently being decoded.
    let current_progress = multi.add(ProgressBar::new_spinner());
    current_progress.set_style(
        ProgressStyle::with_template(" current chunk {spinner:.cyan} {msg}")
            .context("failed to configure current chunk progress spinner")?,
    );
    current_progress.enable_steady_tick(Duration::from_millis(STREAM_ACTIVITY_TICK_MS));
    current_progress.set_message(format!("1/{window_count} decoding"));
    // Spinner for the window queued next (or "none" at the end).
    let next_progress = multi.add(ProgressBar::new_spinner());
    next_progress.set_style(
        ProgressStyle::with_template(" next chunk {spinner:.yellow} {msg}")
            .context("failed to configure next chunk progress spinner")?,
    );
    next_progress.enable_steady_tick(Duration::from_millis(STREAM_ACTIVITY_TICK_MS));
    if window_count > 1 {
        next_progress.set_message(format!("2/{window_count} queued"));
    } else {
        next_progress.set_message("none");
    }
    // Tail of recently emitted words, used to trim overlap duplicates.
    let mut seen_tail = String::new();
    // Words decoded but not yet printed (held back for confirmation).
    let mut held_words: Vec<String> = Vec::new();
    let mut has_output = false;
    for (index, start) in (0..audio.samples.len()).step_by(step_size).enumerate() {
        let current_window = index + 1;
        let next_window = if current_window < window_count {
            Some(current_window + 1)
        } else {
            None
        };
        current_progress.set_message(format!("{current_window}/{window_count} decoding"));
        if let Some(next_window) = next_window {
            next_progress.set_message(format!("{next_window}/{window_count} queued"));
        } else {
            next_progress.set_message("none");
        }
        // The final window may be shorter than `window_samples`.
        let end = (start + window_samples).min(audio.samples.len());
        let chunk = &audio.samples[start..end];
        let raw_text = whisper
            .generate(chunk, None, false, options)
            .context("streaming chunk transcription failed")?
            .into_iter()
            .map(|line| line.trim().to_string())
            .filter(|line| !line.is_empty())
            .collect::<Vec<_>>()
            .join(" ");
        // Drop words already covered by the previous window's tail.
        let text = trim_stream_overlap(&seen_tail, &raw_text);
        if text.is_empty() {
            continue;
        }
        held_words.extend(text.split_whitespace().map(str::to_string));
        seen_tail = merge_stream_tail(&seen_tail, &text);
        // Emit everything except the last STREAM_HOLD_WORDS words, which stay
        // held until a later window confirms them.
        let emit_now = held_words.len().saturating_sub(STREAM_HOLD_WORDS);
        if emit_now > 0 {
            let stable_text = held_words[..emit_now].join(" ");
            emit_stream_text(&multi, &stable_text, &mut has_output)?;
            held_words.drain(..emit_now);
        }
    }
    // Flush any words still held after the final window.
    if !held_words.is_empty() {
        let final_text = held_words.join(" ");
        emit_stream_text(&multi, &final_text, &mut has_output)?;
    }
    if has_output {
        // Terminate the streamed line before clearing the spinners.
        multi.suspend(|| {
            println!();
        });
    }
    current_progress.finish_and_clear();
    next_progress.finish_and_clear();
    Ok(())
}
/// Render a duration given in seconds as `MM:SS.CC`
/// (zero-padded minutes, seconds, and centiseconds).
fn format_duration(seconds: f64) -> String {
    let millis = (seconds * 1000.0).round() as u64;
    let minutes = millis / 60_000;
    let remainder = millis % 60_000;
    format!("{:02}:{:02}.{:02}", minutes, remainder / 1000, remainder % 1000 / 10)
}
fn emit_stream_text(progress: &MultiProgress, text: &str, has_output: &mut bool) -> Result<()> {
progress.suspend(|| -> Result<()> {
let mut stdout = io::stdout().lock();
if *has_output {
write!(stdout, " ").context("failed to write stream separator")?;
}
write!(stdout, "{text}").context("failed to write stream text")?;
stdout.flush().context("failed to flush stdout")?;
*has_output = true;
Ok(())
})
}
/// Remove the longest word-overlap between the end of `previous_tail` and
/// the start of `current`, returning only the new (non-duplicated) words
/// joined by single spaces.
///
/// Overlaps of up to `MAX_DEDUP_WORDS` words are considered; words are
/// compared via `words_match` (case- and punctuation-insensitive). When no
/// overlap is found, the trimmed `current` text is returned unchanged.
fn trim_stream_overlap(previous_tail: &str, current: &str) -> String {
    let current_words = current.split_whitespace().collect::<Vec<_>>();
    if current_words.is_empty() {
        return String::new();
    }
    let previous_words = previous_tail.split_whitespace().collect::<Vec<_>>();
    let max_overlap = previous_words
        .len()
        .min(current_words.len())
        .min(MAX_DEDUP_WORDS);
    // Try the longest candidate overlap first so we drop as much duplicated
    // text as possible.
    for overlap in (1..=max_overlap).rev() {
        let previous_slice = &previous_words[previous_words.len() - overlap..];
        // Fixed: the original contained mojibake (`¤t_words`, a mangled
        // `&current_words`) that did not compile.
        let current_slice = &current_words[..overlap];
        if words_match(previous_slice, current_slice) {
            return current_words[overlap..].join(" ");
        }
    }
    current.trim().to_string()
}
fn merge_stream_tail(previous_tail: &str, current: &str) -> String {
let mut words = previous_tail
.split_whitespace()
.chain(current.split_whitespace())
.collect::<Vec<_>>();
if words.len() > MAX_DEDUP_WORDS {
words = words.split_off(words.len() - MAX_DEDUP_WORDS);
}
words.join(" ")
}
/// Compare two word slices for equality after normalization (lowercased,
/// alphanumeric characters only). Slices of different lengths never match.
fn words_match(previous: &[&str], current: &[&str]) -> bool {
    if previous.len() != current.len() {
        return false;
    }
    previous
        .iter()
        .zip(current)
        .all(|(left, right)| normalize_word(left) == normalize_word(right))
}
/// Lowercase a word and strip everything that is not alphanumeric, so that
/// punctuation and capitalization differences do not defeat deduplication.
fn normalize_word(word: &str) -> String {
    let mut normalized = String::new();
    for character in word.chars() {
        if character.is_alphanumeric() {
            // `to_lowercase` may yield multiple chars (e.g. some Unicode
            // mappings), so extend rather than push.
            normalized.extend(character.to_lowercase());
        }
    }
    normalized
}
/// Build the CTranslate2 runtime configuration for the chosen execution
/// mode. Thread count follows the machine's available parallelism, falling
/// back to 4 when it cannot be determined.
fn ctranslate2_config(execution: &ExecutionMode) -> Config {
    let num_threads = match std::thread::available_parallelism() {
        Ok(parallelism) => parallelism.get(),
        Err(_) => 4,
    };
    // Device index 0 is the conventional default when no GPU was selected.
    let device_indices = match execution.gpu_device {
        Some(index) => vec![index],
        None => vec![0],
    };
    Config {
        device: execution.device,
        compute_type: execution.compute_type,
        device_indices,
        num_threads_per_replica: num_threads,
        max_queued_batches: 2,
        ..Default::default()
    }
}