//! transcribe-cli 0.0.4
//!
//! Whisper CLI transcription pipeline on CTranslate2 with CPU and optional
//! CUDA support.
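//!
//! Example invocations (illustrative; the flags match the `Cli` definition
//! below, but the file names and URL are placeholders):
//!
//! ```text
//! transcribe-cli --model small talk.mp4
//! transcribe-cli --model medium --gpu --gpu-device 0 --stream https://example.com/talk.mp4
//! transcribe-cli --model tiny --remove-model
//! ```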
mod audio;
mod model;
mod model_download;
mod transcribe;
mod video;
mod whisper;

use anyhow::Result;
use clap::{Parser, ValueEnum};
use indicatif::{ProgressBar, ProgressStyle};
use std::path::PathBuf;

use crate::audio::prepare_audio_source;
use crate::model::{
    ModelChoice, binary_directory, read_model_config, remove_all_models,
    remove_model_with_artifacts,
};
use crate::model_download::ensure_model_downloaded;
use crate::transcribe::{ExecutionMode, run_transcription};

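/// Command-line arguments parsed by clap.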
#[derive(Debug, Parser)]
#[command(
    name = "transcribe-cli",
    version,
    about = "Whisper transcription pipeline on CTranslate2 with CPU and optional CUDA support"
)]
struct Cli {
    #[arg(long, value_enum, default_value_t = CliModel::Small)]
    model: CliModel,

    #[arg(
        long,
        value_name = "DIR",
        help = "Model storage directory; defaults to <binary_dir>/models"
    )]
    models_dir: Option<PathBuf>,

    #[arg(long, help = "Print transcript segments as soon as they are decoded")]
    stream: bool,

    #[arg(long, help = "Use Nvidia GPU through CTranslate2 CUDA backend")]
    gpu: bool,

    #[arg(
        long,
        default_value_t = 0,
        help = "CUDA device index to use with --gpu"
    )]
    gpu_device: i32,

    #[arg(
        long,
        conflicts_with = "remove_all",
        help = "Remove the selected model and related leftovers from the models directory"
    )]
    remove_model: bool,

    #[arg(long, help = "Remove the entire models directory")]
    remove_all: bool,

    #[arg(value_name = "MEDIA", help = "Path or URL to a media file")]
    audio: Option<String>,
}

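/// Whisper model size selectable from the CLI; converted into `ModelChoice`.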
#[derive(Clone, Copy, Debug, ValueEnum)]
enum CliModel {
    Tiny,
    Small,
    Medium,
    Large,
}

impl From<CliModel> for ModelChoice {
    fn from(value: CliModel) -> Self {
        match value {
            CliModel::Tiny => Self::Tiny,
            CliModel::Small => Self::Small,
            CliModel::Medium => Self::Medium,
            CliModel::Large => Self::Large,
        }
    }
}

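/// Entry point: report any pipeline error on stderr and exit non-zero.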
#[tokio::main]
async fn main() {
    if let Err(error) = run().await {
        eprintln!("error: {error:#}");
        std::process::exit(1);
    }
}

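/// Full pipeline: parse arguments, handle model removal, ensure the model is
/// downloaded, print its parameters, prepare the media, and transcribe it.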
async fn run() -> Result<()> {
    let cli = Cli::parse();
    let model_choice = ModelChoice::from(cli.model);
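    // Resolve CPU vs. CUDA execution up front; the mode may carry a warning,
    // e.g. when a GPU was requested but is unavailable.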
    let execution = ExecutionMode::from_cli(cli.gpu, cli.gpu_device)?;

    if let Some(warning) = execution.warning() {
        eprintln!("warning: {warning}");
    }

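    // Models are stored next to the binary unless --models-dir overrides it.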
    let default_models_root = binary_directory()?.join("models");
    let models_root = cli
        .models_dir
        .as_deref()
        .unwrap_or(default_models_root.as_path());

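    // Model-management flags short-circuit: both branches return before any
    // media input is required.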
    if cli.remove_all {
        let removed = remove_all_models(models_root)?;
        if removed {
            println!("removed models directory `{}`", models_root.display());
        } else {
            println!(
                "models directory `{}` does not exist",
                models_root.display()
            );
        }
        return Ok(());
    }

    if cli.remove_model {
        let removed_entries = remove_model_with_artifacts(model_choice, models_root)?;
        if removed_entries > 0 {
            println!(
                "removed model `{}` and {} related artifact(s) from `{}`",
                model_choice.cli_name(),
                removed_entries,
                models_root.display()
            );
        } else {
            println!(
                "no artifacts found for model `{}` in `{}`",
                model_choice.cli_name(),
                models_root.display()
            );
        }
        return Ok(());
    }

    let audio_input = cli.audio.as_deref().ok_or_else(|| {
        anyhow::anyhow!(
            "a media path or URL is required unless --remove-model or --remove-all is used"
        )
    })?;
    let model_dir = ensure_model_downloaded(model_choice, Some(models_root)).await?;
    let model_config = read_model_config(&model_dir)?;
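    // Echo the effective model and execution settings before decoding starts.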
    println!();
    println!("Model parameters");
    println!("  requested model : {}", model_choice.cli_name());
    println!("  runtime model   : {}", model_choice.runtime_name());
    println!("  repo            : {}", model_choice.repo_id());
    println!("  models root     : {}", models_root.display());
    println!("  local path      : {}", model_dir.display());
    println!("  device          : {}", execution.device_label());
    println!("  compute type    : {}", execution.compute_type_label());
    if cli.gpu {
        if execution.gpu_requested_but_unavailable() {
            println!("  gpu request     : fallback to cpu");
        } else {
            println!("  gpu device      : {}", cli.gpu_device);
        }
    }
    if let Some(model_type) = model_config.model_type.as_deref() {
        println!("  model type      : {model_type}");
    }
    if let Some(d_model) = model_config.d_model {
        println!("  d_model         : {d_model}");
    }
    if let Some(mel_bins) = model_config.num_mel_bins {
        println!("  mel bins        : {mel_bins}");
    }
    if let Some(vocab_size) = model_config.vocab_size {
        println!("  vocab size      : {vocab_size}");
    }
    if let Some(encoder_layers) = model_config.encoder_layers {
        println!("  encoder layers  : {encoder_layers}");
    }
    if let Some(decoder_layers) = model_config.decoder_layers {
        println!("  decoder layers  : {decoder_layers}");
    }
    if let Some(encoder_heads) = model_config.encoder_attention_heads {
        println!("  encoder heads   : {encoder_heads}");
    }
    if let Some(decoder_heads) = model_config.decoder_attention_heads {
        println!("  decoder heads   : {decoder_heads}");
    }
    if let Some(max_source_positions) = model_config.max_source_positions {
        println!("  max source pos  : {max_source_positions}");
    }
    if let Some(max_target_positions) = model_config.max_target_positions {
        println!("  max target pos  : {max_target_positions}");
    }

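    // Show a spinner while the input (local file or URL) is turned into a
    // transcription-ready audio source.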
    let prepare_bar = ProgressBar::new_spinner();
    prepare_bar.set_style(ProgressStyle::with_template(
        "  preparing media  {spinner:.green} {msg}",
    )?);
    prepare_bar.enable_steady_tick(std::time::Duration::from_millis(80));
    prepare_bar.set_message(audio_input.to_string());
    let audio = prepare_audio_source(audio_input).await?;
    prepare_bar.finish_with_message("media ready");
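    // Hand off to the transcription pipeline; with --stream, segments are
    // printed as soon as they are decoded.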
    run_transcription(
        &audio,
        model_choice,
        &model_dir,
        models_root,
        &execution,
        cli.stream,
    )
    .await
}