//! transcribe-cli 0.0.6
//!
//! Native Rust CLI transcription pipeline with GigaAM v3 ONNX.
mod audio;
mod dynamic_chunk;
mod gigaam;
mod model;
mod model_download;
mod onnx_ctc;
mod onnx_transducer;
mod parakeet;
mod rest;
mod server;
mod transcribe;
mod video;

use anyhow::Result;
use clap::{Parser, ValueEnum};
use indicatif::{ProgressBar, ProgressStyle};
use std::path::PathBuf;

use crate::audio::prepare_audio_source;
use crate::model::{
    ModelChoice, ModelComputeType, default_model_root_directory, read_model_config,
    remove_all_models, remove_model_with_artifacts,
};
use crate::model_download::ensure_model_downloaded;
#[cfg(target_os = "linux")]
use crate::model_download::ensure_ort_runtime_downloaded;
use crate::onnx_ctc::ExecutionMode;
use crate::server::run_server;
use crate::transcribe::{
    run_gigaam_transcription, run_live_gigaam_transcription, run_live_parakeet_transcription,
    run_parakeet_transcription,
};

// Command-line interface, parsed by clap's derive macro.
// Exactly one mode is used per invocation: REST server (--server),
// model maintenance (--remove-model / --remove-all), or a single
// transcription of the MEDIA argument.
#[derive(Debug, Parser)]
#[command(
    name = "transcribe-cli",
    version,
    about = "Native Rust transcription CLI with GigaAM v3 ONNX"
)]
struct Cli {
    // Server mode: when set, the one-shot CLI transcription path is skipped.
    #[arg(
        long,
        value_name = "PORT",
        help = "Start the REST server on the given port instead of running a single CLI transcription"
    )]
    server: Option<u16>,

    // Overrides the default models root resolved by
    // `default_model_root_directory()`.
    #[arg(
        long,
        value_name = "DIR",
        help = "Model storage directory; defaults to <binary_dir>/transcribe_sandbox/models"
    )]
    models_dir: Option<PathBuf>,

    #[arg(long, help = "Use Nvidia GPU execution")]
    gpu: bool,

    // Only meaningful together with --gpu; defaults to the first CUDA device.
    #[arg(
        long,
        default_value_t = 0,
        help = "CUDA device index to use with --gpu"
    )]
    gpu_device: i32,

    // `None` (flag omitted) means "auto": the execution backend decides.
    #[arg(
        long,
        value_enum,
        help = "Override compute type; accepted values are auto, int8, float32"
    )]
    compute_type: Option<CliComputeType>,

    // NOTE(review): the alias "laungage" looks like a misspelling accepted
    // for backward compatibility — confirm it is intentional before removing.
    #[arg(
        long,
        alias = "laungage",
        value_name = "LANG",
        help = "Language selects the backend: default ru uses GigaAM, en uses Parakeet-TDT v2"
    )]
    language: Option<String>,

    // Offline (non-live) transcription only; ignored by the --live paths.
    #[arg(
        long,
        help = "Print transcript chunk-by-chunk while the model is decoding"
    )]
    stream: bool,

    #[arg(
        long,
        help = "Treat MEDIA as a live HTTP/HTTPS media stream and transcribe it incrementally"
    )]
    live: bool,

    // Validated in run(): rejected unless --live is also given.
    #[arg(
        long,
        help = "Enable lightweight VAD segmentation for --live so only speech-like chunks are sent to the model"
    )]
    vad: bool,

    // Mutually exclusive with --remove-all (enforced by clap).
    #[arg(
        long,
        conflicts_with = "remove_all",
        help = "Remove the selected model and related leftovers from the models directory"
    )]
    remove_model: bool,

    #[arg(long, help = "Remove the entire models directory")]
    remove_all: bool,

    // Positional MEDIA argument; optional because the maintenance and server
    // modes do not need an input file.
    #[arg(value_name = "MEDIA", help = "Path or URL to an audio or video file")]
    audio: Option<String>,
}

// CLI-facing compute-type choices for --compute-type; mapped onto the
// model layer's `ModelComputeType` by `gigaam_type()`.
#[derive(Clone, Copy, Debug, ValueEnum)]
enum CliComputeType {
    // "auto": let the execution backend pick a default.
    Auto,
    // Explicit `value` names keep the accepted CLI strings lowercase.
    #[value(name = "int8")]
    Int8,
    #[value(name = "float32")]
    Float32,
}

impl CliComputeType {
    /// Translates the CLI compute-type choice into the model layer's enum.
    /// `Auto` maps to `None`, meaning the backend should pick its own default.
    fn gigaam_type(self) -> Option<ModelComputeType> {
        let resolved = match self {
            Self::Auto => return None,
            Self::Int8 => ModelComputeType::Int8,
            Self::Float32 => ModelComputeType::Float32,
        };
        Some(resolved)
    }
}

/// Process entry point: delegates to `run()` and converts any error into a
/// non-zero exit code after printing it (with the full anyhow context chain).
#[tokio::main]
async fn main() {
    match run().await {
        Ok(()) => {}
        Err(error) => {
            eprintln!("error: {error:#}");
            std::process::exit(1);
        }
    }
}

async fn run() -> Result<()> {
    let cli = Cli::parse();
    let language = resolve_requested_language(cli.language.as_deref())?;
    let model_choice = select_model_for_language(&language);

    let default_models_root = default_model_root_directory()?;
    let models_root = cli
        .models_dir
        .as_deref()
        .unwrap_or(default_models_root.as_path());

    if cli.remove_all {
        let removed = remove_all_models(models_root)?;
        if removed {
            println!("removed models directory `{}`", models_root.display());
        } else {
            println!(
                "models directory `{}` does not exist",
                models_root.display()
            );
        }
        return Ok(());
    }

    if cli.remove_model {
        let removed_entries = remove_model_with_artifacts(model_choice, models_root)?;
        if removed_entries > 0 {
            println!(
                "removed model `{}` and {} related artifact(s) from `{}`",
                model_choice.cli_name(),
                removed_entries,
                models_root.display()
            );
        } else {
            println!(
                "no artifacts found for model `{}` in `{}`",
                model_choice.cli_name(),
                models_root.display()
            );
        }
        return Ok(());
    }

    if let Some(port) = cli.server {
        run_server(port, models_root.to_path_buf()).await?;
        return Ok(());
    }

    let audio_input = cli.audio.as_deref().ok_or_else(|| {
        anyhow::anyhow!(
            "a media path or URL is required unless --server, --remove-model or --remove-all is used"
        )
    })?;

    if cli.vad && !cli.live {
        anyhow::bail!("`--vad` is currently supported only together with `--live`");
    }

    let execution = ExecutionMode::from_cli(
        cli.gpu,
        cli.gpu_device,
        cli.compute_type.and_then(CliComputeType::gigaam_type),
    )?;
    #[cfg(target_os = "linux")]
    ensure_ort_runtime_downloaded().await?;
    let model_dir = ensure_model_downloaded(
        model_choice,
        Some(execution.compute_type()),
        Some(models_root),
    )
    .await?;
    let model_config = read_model_config(&model_dir, model_choice)?;
    print_model_parameters(
        model_choice,
        &model_dir,
        models_root,
        execution.device_label(),
        execution.compute_type_label(),
        &language,
        execution.gpu_device(),
        &model_config,
    );

    match model_choice {
        ModelChoice::GigaamV3 if cli.live => {
            run_live_gigaam_transcription(
                audio_input,
                model_choice,
                &model_dir,
                &execution,
                Some(&language),
                cli.vad,
            )
            .await
        }
        ModelChoice::ParakeetTdt06bV2 if cli.live => {
            run_live_parakeet_transcription(
                audio_input,
                model_choice,
                &model_dir,
                &execution,
                Some(&language),
                cli.vad,
            )
            .await
        }
        ModelChoice::GigaamV3 => {
            let prepare_bar = ProgressBar::new_spinner();
            prepare_bar.set_style(ProgressStyle::with_template(
                "  preparing media  {spinner:.green} {msg}",
            )?);
            prepare_bar.enable_steady_tick(std::time::Duration::from_millis(80));
            prepare_bar.set_message(audio_input.to_string());
            let audio = prepare_audio_source(audio_input).await?;
            prepare_bar.finish_with_message("media ready");

            run_gigaam_transcription(
                &audio,
                model_choice,
                &model_dir,
                &execution,
                Some(&language),
                cli.stream,
            )
            .await
        }
        ModelChoice::ParakeetTdt06bV2 => {
            let prepare_bar = ProgressBar::new_spinner();
            prepare_bar.set_style(ProgressStyle::with_template(
                "  preparing media  {spinner:.green} {msg}",
            )?);
            prepare_bar.enable_steady_tick(std::time::Duration::from_millis(80));
            prepare_bar.set_message(audio_input.to_string());
            let audio = prepare_audio_source(audio_input).await?;
            prepare_bar.finish_with_message("media ready");

            run_parakeet_transcription(
                &audio,
                model_choice,
                &model_dir,
                &execution,
                Some(&language),
                cli.stream,
            )
            .await
        }
    }
}

/// Normalizes the user-supplied language flag to a canonical code.
///
/// Missing, empty, or whitespace-only input defaults to Russian ("ru");
/// matching is case-insensitive and accepts a few common spellings.
///
/// # Errors
/// Fails for any language other than the supported `ru`/`en` family.
fn resolve_requested_language(language: Option<&str>) -> Result<String> {
    // Fall back to Russian when the flag is absent or blank.
    let normalized = match language.map(str::trim) {
        Some(value) if !value.is_empty() => value.to_ascii_lowercase(),
        _ => String::from("ru"),
    };

    if matches!(normalized.as_str(), "ru" | "rus" | "auto") {
        return Ok(String::from("ru"));
    }
    if matches!(normalized.as_str(), "en" | "eng") {
        return Ok(String::from("en"));
    }
    let other = normalized.as_str();
    anyhow::bail!("unsupported language `{other}`; currently supported values are `ru` and `en`")
}

/// Chooses the ASR backend for a canonical language code:
/// English uses Parakeet-TDT, everything else (i.e. "ru") uses GigaAM.
fn select_model_for_language(language: &str) -> ModelChoice {
    if language == "en" {
        ModelChoice::ParakeetTdt06bV2
    } else {
        ModelChoice::GigaamV3
    }
}

/// Prints a human-readable summary of the resolved run configuration.
///
/// Always prints the fixed rows (model identity, paths, device, language);
/// rows backed by optional `ModelConfig` fields are emitted only when the
/// config actually provided a value. Labels are padded to a 16-column width
/// so the colons line up.
fn print_model_parameters(
    model_choice: ModelChoice,
    model_dir: &std::path::Path,
    models_root: &std::path::Path,
    device_label: &str,
    compute_type_label: &str,
    language_label: &str,
    gpu_device: Option<i32>,
    model_config: &crate::model::ModelConfig,
) {
    println!();
    println!("Model parameters");

    // Rows that are always known once a model has been resolved.
    let fixed_rows = [
        ("requested model", model_choice.cli_name().to_string()),
        ("runtime model", model_choice.runtime_name().to_string()),
        ("repo", model_choice.repo_id().to_string()),
        ("models root", models_root.display().to_string()),
        ("local path", model_dir.display().to_string()),
        ("device", device_label.to_string()),
        ("compute type", compute_type_label.to_string()),
        ("language", language_label.to_string()),
    ];
    for (label, value) in fixed_rows {
        println!("  {label:<16}: {value}");
    }

    // Optional rows: skipped entirely when the underlying value is absent.
    let optional_rows = [
        ("gpu device", gpu_device.map(|device| device.to_string())),
        (
            "model name",
            model_config.model_name.as_deref().map(str::to_string),
        ),
        (
            "model class",
            model_config.model_class.as_deref().map(str::to_string),
        ),
        (
            "model type",
            model_config.model_type.as_deref().map(str::to_string),
        ),
        (
            "sample rate",
            model_config.sample_rate.map(|rate| format!("{rate} Hz")),
        ),
        ("mel bins", model_config.features.map(|v| v.to_string())),
        ("win length", model_config.win_length.map(|v| v.to_string())),
        ("hop length", model_config.hop_length.map(|v| v.to_string())),
        ("n_fft", model_config.n_fft.map(|v| v.to_string())),
        ("center", model_config.center.map(|v| v.to_string())),
        (
            "encoder layers",
            model_config.encoder_layers.map(|v| v.to_string()),
        ),
        ("d_model", model_config.d_model.map(|v| v.to_string())),
        ("num classes", model_config.num_classes.map(|v| v.to_string())),
        (
            "subsampling",
            model_config.subsampling_factor.map(|v| v.to_string()),
        ),
    ];
    for (label, value) in optional_rows {
        if let Some(value) = value {
            println!("  {label:<16}: {value}");
        }
    }
}