mod audio;
mod model;
mod model_download;
mod transcribe;
mod video;
mod whisper;

use anyhow::Result;
use clap::{Parser, ValueEnum};
use indicatif::{ProgressBar, ProgressStyle};
use std::path::PathBuf;
use crate::audio::prepare_audio_source;
use crate::model::{
ModelChoice, binary_directory, read_model_config, remove_all_models,
remove_model_with_artifacts,
};
use crate::model_download::ensure_model_downloaded;
use crate::transcribe::{ExecutionMode, run_transcription};
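
/// Command-line arguments, parsed with clap's derive API.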
#[derive(Debug, Parser)]
#[command(
name = "transcribe-cli",
version,
about = "CPU-optimized Whisper pipeline on CTranslate2 in Rust"
)]
struct Cli {
#[arg(long, value_enum, default_value_t = CliModel::Small)]
model: CliModel,
#[arg(
long,
value_name = "DIR",
help = "Model storage directory; defaults to <binary_dir>/models"
)]
models_dir: Option<PathBuf>,
#[arg(long, help = "Print transcript segments as soon as they are decoded")]
stream: bool,
#[arg(long, help = "Use Nvidia GPU through CTranslate2 CUDA backend")]
gpu: bool,
#[arg(
long,
default_value_t = 0,
help = "CUDA device index to use with --gpu"
)]
gpu_device: i32,
#[arg(
long,
conflicts_with = "remove_all",
help = "Remove the selected model and related leftovers from the models directory"
)]
remove_model: bool,
#[arg(long, help = "Remove the entire models directory")]
remove_all: bool,
#[arg(value_name = "MEDIA", help = "Path or URL to a media file")]
audio: Option<String>,
}
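
/// Whisper model sizes selectable on the command line.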
#[derive(Clone, Copy, Debug, ValueEnum)]
enum CliModel {
Tiny,
Small,
Medium,
Large,
}
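
// Map CLI-facing model names onto the internal `ModelChoice` variants.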
impl From<CliModel> for ModelChoice {
fn from(value: CliModel) -> Self {
match value {
CliModel::Tiny => Self::Tiny,
CliModel::Small => Self::Small,
CliModel::Medium => Self::Medium,
CliModel::Large => Self::Large,
}
}
}
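
/// Entry point: print the full anyhow error chain and exit non-zero on failure.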
#[tokio::main]
async fn main() {
if let Err(error) = run().await {
eprintln!("error: {error:#}");
std::process::exit(1);
}
}
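
/// Parse the CLI, handle the model-maintenance flags, or run a transcription.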
async fn run() -> Result<()> {
let cli = Cli::parse();
let model_choice = ModelChoice::from(cli.model);
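
    // Resolve the execution mode up front so a GPU-unavailable fallback warns early.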
let execution = ExecutionMode::from_cli(cli.gpu, cli.gpu_device)?;
if let Some(warning) = execution.warning() {
eprintln!("warning: {warning}");
}
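
    // An explicit --models-dir wins; otherwise models live next to the binary.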
let default_models_root = binary_directory()?.join("models");
let models_root = cli
.models_dir
.as_deref()
.unwrap_or(default_models_root.as_path());
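
    // Maintenance flags short-circuit: clean up on disk and exit without transcribing.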
if cli.remove_all {
let removed = remove_all_models(models_root)?;
if removed {
println!("removed models directory `{}`", models_root.display());
} else {
println!(
"models directory `{}` does not exist",
models_root.display()
);
}
return Ok(());
}
if cli.remove_model {
let removed_entries = remove_model_with_artifacts(model_choice, models_root)?;
if removed_entries > 0 {
println!(
"removed model `{}` and {} related artifact(s) from `{}`",
model_choice.cli_name(),
removed_entries,
models_root.display()
);
} else {
println!(
"no artifacts found for model `{}` in `{}`",
model_choice.cli_name(),
models_root.display()
);
}
return Ok(());
}
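
    // Past this point a media path or URL is mandatory.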
let audio_input = cli.audio.as_deref().ok_or_else(|| {
anyhow::anyhow!(
"a media path or URL is required unless --remove-model or --remove-all is used"
)
})?;
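
    // Make sure the model is present locally (downloading if needed), then read
    // its config for the parameter summary below.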
let model_dir = ensure_model_downloaded(model_choice, Some(models_root)).await?;
let model_config = read_model_config(&model_dir)?;
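
    // Report the resolved model and execution parameters before any heavy work.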
    println!();
    println!("Model parameters");
    println!(" requested model : {}", model_choice.cli_name());
    println!(" runtime model   : {}", model_choice.runtime_name());
    println!(" repo            : {}", model_choice.repo_id());
    println!(" models root     : {}", models_root.display());
    println!(" local path      : {}", model_dir.display());
    println!(" device          : {}", execution.device_label());
    println!(" compute type    : {}", execution.compute_type_label());
    if cli.gpu {
        if execution.gpu_requested_but_unavailable() {
            println!(" gpu request     : fallback to cpu");
        } else {
            println!(" gpu device      : {}", cli.gpu_device);
        }
    }
    if let Some(model_type) = model_config.model_type.as_deref() {
        println!(" model type      : {model_type}");
    }
    if let Some(d_model) = model_config.d_model {
        println!(" d_model         : {d_model}");
    }
    if let Some(mel_bins) = model_config.num_mel_bins {
        println!(" mel bins        : {mel_bins}");
    }
    if let Some(vocab_size) = model_config.vocab_size {
        println!(" vocab size      : {vocab_size}");
    }
    if let Some(encoder_layers) = model_config.encoder_layers {
        println!(" encoder layers  : {encoder_layers}");
    }
    if let Some(decoder_layers) = model_config.decoder_layers {
        println!(" decoder layers  : {decoder_layers}");
    }
    if let Some(encoder_heads) = model_config.encoder_attention_heads {
        println!(" encoder heads   : {encoder_heads}");
    }
    if let Some(decoder_heads) = model_config.decoder_attention_heads {
        println!(" decoder heads   : {decoder_heads}");
    }
    if let Some(max_source_positions) = model_config.max_source_positions {
        println!(" max source pos  : {max_source_positions}");
    }
    if let Some(max_target_positions) = model_config.max_target_positions {
        println!(" max target pos  : {max_target_positions}");
    }
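
    // Show a spinner while the media is fetched and decoded into an audio source.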
let prepare_bar = ProgressBar::new_spinner();
prepare_bar.set_style(ProgressStyle::with_template(
" preparing media {spinner:.green} {msg}",
)?);
prepare_bar.enable_steady_tick(std::time::Duration::from_millis(80));
prepare_bar.set_message(audio_input.to_string());
let audio = prepare_audio_source(audio_input).await?;
prepare_bar.finish_with_message("media ready");
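
    // Hand off to the transcription pipeline; --stream prints segments as they decode.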
run_transcription(
&audio,
model_choice,
&model_dir,
models_root,
&execution,
cli.stream,
)
.await
}