scriptrs 0.2.0

Rust transcription with native CoreML Parakeet v2 inference
Documentation
use std::env;
use std::path::{Path, PathBuf};
use std::process::ExitCode;
use std::time::Instant;

use eyre::{Result, bail, eyre};
use hound::WavReader;

#[cfg(feature = "long-form-vad")]
use scriptrs::LongFormMode;
use scriptrs::TranscriptionPipeline;
#[cfg(feature = "long-form")]
use scriptrs::{LongFormConfig, LongFormTranscriptionPipeline};

fn main() -> ExitCode {
    match run() {
        Ok(()) => ExitCode::SUCCESS,
        Err(error) => {
            eprintln!("{error:?}");
            ExitCode::FAILURE
        }
    }
}

/// Parses the command line, loads the WAV file, transcribes it and prints a summary.
///
/// The pipeline load time is measured from just after the audio has been read
/// until the selected pipeline has been constructed, and is handed to the
/// runner so it can be reported next to the transcription timings.
///
/// # Errors
///
/// Propagates any failure from argument parsing, WAV decoding, pipeline
/// construction or transcription. Without the `long-form` feature, passing
/// `--long-form` is rejected.
fn run() -> Result<()> {
    let args = Args::parse(env::args().skip(1))?;
    let audio = read_mono_16khz_wav(&args.audio_path)?;
    // Started after decoding so that only pipeline construction is timed.
    let load_timer = Instant::now();

    #[cfg(feature = "long-form")]
    let transcript = if !args.long_form {
        let pipeline = build_pipeline(&args)?;
        let load_seconds = load_timer.elapsed().as_secs_f64();
        run_pipeline(&pipeline, &audio, &args, load_seconds)?
    } else {
        let pipeline = build_long_form_pipeline(&args)?;
        let load_seconds = load_timer.elapsed().as_secs_f64();
        run_long_form_pipeline(&pipeline, &audio, &args, load_seconds)?
    };

    #[cfg(not(feature = "long-form"))]
    let transcript = {
        // The long-form pipeline is not compiled in, so the flag cannot be honoured.
        if args.long_form {
            bail!("rebuild with --features long-form to use --long-form")
        }
        let pipeline = build_pipeline(&args)?;
        let load_seconds = load_timer.elapsed().as_secs_f64();
        run_pipeline(&pipeline, &audio, &args, load_seconds)?
    };

    // The reader only accepts 16 kHz mono input, so samples / 16000 is the duration.
    let audio_seconds = audio.len() as f64 / 16_000.0;

    println!("file: {}", args.audio_path.display());
    println!("audio_seconds: {:.1}", audio_seconds);
    println!("chunks: {}", transcript.chunks.len());
    println!("tokens: {}", transcript.tokens.len());
    println!("{}", preview(&transcript.text, args.preview_chars));
    Ok(())
}

/// Command-line options for the transcription example, as produced by `Args::parse`.
#[derive(Debug, Clone)]
struct Args {
    /// WAV file to transcribe; set by `--audio <path>` or a single positional argument.
    audio_path: PathBuf,
    /// Directory to load models from (`--models-dir <dir>`); `None` when not given.
    models_dir: Option<PathBuf>,
    /// Whether `--pretrained` was passed; rejected in combination with `--models-dir`.
    pretrained: bool,
    /// Whether `--long-form` was passed.
    long_form: bool,
    /// Whether `--vad-long-form` was passed; only accepted together with `--long-form`.
    vad_long_form: bool,
    /// Value of `--long-form-workers <n>`; when present it is at least 1.
    long_form_workers: Option<usize>,
    /// Value of `--warmup-runs <n>` (defaults to 1); used only when benchmarking.
    warmup_runs: usize,
    /// Value of `--benchmark-runs <n>` (defaults to 0); 0 selects a single timed run.
    benchmark_runs: usize,
    /// Value of `--preview-chars <n>` (defaults to 800); limit passed to `preview`.
    preview_chars: usize,
}

impl Args {
    fn parse(args: impl IntoIterator<Item = String>) -> Result<Self> {
        let mut args = args.into_iter();
        let mut audio_path = None;
        let mut models_dir = None;
        let mut pretrained = false;
        let mut long_form = false;
        let mut vad_long_form = false;
        let mut long_form_workers = None;
        let mut warmup_runs = 1usize;
        let mut benchmark_runs = 0usize;
        let mut preview_chars = 800usize;

        while let Some(arg) = args.next() {
            match arg.as_str() {
                "--audio" => audio_path = Some(next_path(&mut args, "--audio")?),
                "--models-dir" => models_dir = Some(next_path(&mut args, "--models-dir")?),
                "--pretrained" => pretrained = true,
                "--long-form" => long_form = true,
                "--vad-long-form" => vad_long_form = true,
                "--long-form-workers" => {
                    let worker_count = next_value(&mut args, "--long-form-workers")?
                        .parse()
                        .map_err(|error| eyre!("invalid --long-form-workers value: {error}"))?;
                    long_form_workers = Some(worker_count);
                }
                "--warmup-runs" => {
                    warmup_runs = next_value(&mut args, "--warmup-runs")?
                        .parse()
                        .map_err(|error| eyre!("invalid --warmup-runs value: {error}"))?;
                }
                "--benchmark-runs" => {
                    benchmark_runs = next_value(&mut args, "--benchmark-runs")?
                        .parse()
                        .map_err(|error| eyre!("invalid --benchmark-runs value: {error}"))?;
                }
                "--preview-chars" => {
                    preview_chars = next_value(&mut args, "--preview-chars")?
                        .parse()
                        .map_err(|error| eyre!("invalid --preview-chars value: {error}"))?;
                }
                "--help" | "-h" => {
                    print_usage();
                    std::process::exit(0);
                }
                flag if flag.starts_with('-') => bail!("unknown flag: {flag}"),
                path => {
                    if audio_path.is_some() {
                        bail!("unexpected positional argument: {path}")
                    }
                    audio_path = Some(PathBuf::from(path));
                }
            }
        }

        let Some(audio_path) = audio_path else {
            bail!("missing --audio <path.wav>")
        };
        if pretrained && models_dir.is_some() {
            bail!("use either --pretrained or --models-dir, not both")
        }
        if benchmark_runs > 0 && warmup_runs == 0 {
            bail!("--warmup-runs must be at least 1 when benchmarking")
        }
        if vad_long_form && !long_form {
            bail!("--vad-long-form requires --long-form")
        }
        if let Some(worker_count) = long_form_workers
            && worker_count == 0
        {
            bail!("--long-form-workers must be at least 1")
        }

        Ok(Self {
            audio_path,
            models_dir,
            pretrained,
            long_form,
            vad_long_form,
            long_form_workers,
            warmup_runs,
            benchmark_runs,
            preview_chars,
        })
    }
}

/// Prints the command-line help text to stderr.
///
/// The option list mirrors the flags accepted by `Args::parse`, including the
/// audio path (flag or positional), the help flag itself, and the defaults
/// that `Args::parse` applies when a flag is omitted.
fn print_usage() {
    eprintln!(
        "Usage:
  cargo run --example transcribe_wav -- --audio <path.wav>
  cargo run --example transcribe_wav -- --audio <path.wav> --pretrained
  cargo run --example transcribe_wav --features long-form -- --audio <path.wav> --pretrained --long-form

Options:
  --audio <path.wav>         16 kHz mono WAV file to transcribe (may also be given positionally)
  --models-dir <dir>         local scriptrs model bundle directory (not with --pretrained)
  --pretrained               download models via the online feature
  --long-form                use LongFormTranscriptionPipeline
  --vad-long-form            use VAD-backed region planning (requires --long-form)
  --long-form-workers <n>    parallel long-form workers (default: 4)
  --warmup-runs <n>          warmup runs before timing (default: 1)
  --benchmark-runs <n>       timed runs after warmup on one loaded pipeline (default: 0)
  --preview-chars <n>        text preview limit (default: 800)
  -h, --help                 print this help and exit"
    );
}

/// Transcribes `audio` with the standard pipeline and prints timing lines to stdout.
///
/// When `args.benchmark_runs` is 0 the audio is transcribed once and that
/// single run is timed. Otherwise `args.warmup_runs` untimed runs are followed
/// by `args.benchmark_runs` timed runs, and the total and mean durations are
/// reported. `pipeline_elapsed` is the pipeline load time in seconds and is
/// only echoed back.
///
/// Returns the result of the last transcription performed.
///
/// # Errors
///
/// Propagates the first error returned by the pipeline.
fn run_pipeline(
    pipeline: &TranscriptionPipeline,
    audio: &[f32],
    args: &Args,
    pipeline_elapsed: f64,
) -> Result<scriptrs::TranscriptionResult> {
    if args.benchmark_runs > 0 {
        // Warmup results are discarded; only their side effect on the pipeline matters.
        for _ in 0..args.warmup_runs {
            drop(pipeline.run(audio)?);
        }

        let mut total_seconds = 0.0;
        let mut last_output = None;
        for _ in 0..args.benchmark_runs {
            let timer = Instant::now();
            let output = pipeline.run(audio)?;
            total_seconds += timer.elapsed().as_secs_f64();
            last_output = Some(output);
        }

        // The loop above ran at least once because `benchmark_runs > 0`.
        let output = last_output.expect("benchmark_runs should be positive");
        let mean_seconds = total_seconds / args.benchmark_runs as f64;

        println!("pipeline_load_seconds: {:.2}", pipeline_elapsed);
        println!("warmup_runs: {}", args.warmup_runs);
        println!("benchmark_runs: {}", args.benchmark_runs);
        println!("elapsed_seconds: {:.2}", total_seconds);
        println!("mean_elapsed_seconds: {:.2}", mean_seconds);
        return Ok(output);
    }

    // Single-shot mode: one timed run, no warmup.
    let timer = Instant::now();
    let output = pipeline.run(audio)?;
    println!("pipeline_load_seconds: {:.2}", pipeline_elapsed);
    println!("elapsed_seconds: {:.2}", timer.elapsed().as_secs_f64());
    Ok(output)
}

/// Transcribes `audio` with the long-form pipeline and prints timing lines to stdout.
///
/// The run configuration starts from `LongFormConfig::default()`; the worker
/// count is overridden when `--long-form-workers` was given. With the
/// `long-form-vad` feature the mode is set to `Vad` or `Fast` depending on
/// `--vad-long-form`; without that feature the flag is rejected.
///
/// Timing follows the same scheme as the standard runner: a single timed run
/// when `args.benchmark_runs` is 0, otherwise `args.warmup_runs` untimed runs
/// followed by `args.benchmark_runs` timed runs. `pipeline_elapsed` is the
/// pipeline load time in seconds and is only echoed back.
///
/// Returns the result of the last transcription performed.
///
/// # Errors
///
/// Propagates the first error returned by the pipeline, or fails when
/// `--vad-long-form` is requested without the `long-form-vad` feature.
#[cfg(feature = "long-form")]
fn run_long_form_pipeline(
    pipeline: &LongFormTranscriptionPipeline,
    audio: &[f32],
    args: &Args,
    pipeline_elapsed: f64,
) -> Result<scriptrs::TranscriptionResult> {
    let mut config = LongFormConfig::default();
    if let Some(workers) = args.long_form_workers {
        config.worker_count = workers;
    }

    #[cfg(not(feature = "long-form-vad"))]
    if args.vad_long_form {
        bail!("rebuild with --features long-form-vad to use --vad-long-form")
    }

    #[cfg(feature = "long-form-vad")]
    {
        config.mode = if args.vad_long_form {
            LongFormMode::Vad
        } else {
            LongFormMode::Fast
        };
    }

    if args.benchmark_runs > 0 {
        // Warmup results are discarded; only their side effect on the pipeline matters.
        for _ in 0..args.warmup_runs {
            drop(pipeline.run_with_config(audio, &config)?);
        }

        let mut total_seconds = 0.0;
        let mut last_output = None;
        for _ in 0..args.benchmark_runs {
            let timer = Instant::now();
            let output = pipeline.run_with_config(audio, &config)?;
            total_seconds += timer.elapsed().as_secs_f64();
            last_output = Some(output);
        }

        // The loop above ran at least once because `benchmark_runs > 0`.
        let output = last_output.expect("benchmark_runs should be positive");
        let mean_seconds = total_seconds / args.benchmark_runs as f64;

        println!("pipeline_load_seconds: {:.2}", pipeline_elapsed);
        println!("warmup_runs: {}", args.warmup_runs);
        println!("benchmark_runs: {}", args.benchmark_runs);
        println!("elapsed_seconds: {:.2}", total_seconds);
        println!("mean_elapsed_seconds: {:.2}", mean_seconds);
        return Ok(output);
    }

    // Single-shot mode: one timed run, no warmup.
    let timer = Instant::now();
    let output = pipeline.run_with_config(audio, &config)?;
    println!("pipeline_load_seconds: {:.2}", pipeline_elapsed);
    println!("elapsed_seconds: {:.2}", timer.elapsed().as_secs_f64());
    Ok(output)
}

/// Constructs the standard [`TranscriptionPipeline`].
///
/// Models are loaded from `args.models_dir` when it is set. Otherwise the
/// pretrained models are used if the `online` feature is compiled in, and an
/// error is returned if it is not.
///
/// # Errors
///
/// Propagates pipeline construction failures, or fails when neither a models
/// directory nor the `online` feature is available.
fn build_pipeline(args: &Args) -> Result<TranscriptionPipeline> {
    // The flag does not influence the choice below; it is only read here,
    // presumably so the field is not reported as unused.
    let _ = args.pretrained;

    match &args.models_dir {
        Some(dir) => Ok(TranscriptionPipeline::from_dir(dir)?),
        None => {
            #[cfg(feature = "online")]
            {
                Ok(TranscriptionPipeline::from_pretrained()?)
            }

            #[cfg(not(feature = "online"))]
            {
                bail!("rebuild with the default online feature or pass --models-dir")
            }
        }
    }
}

/// Constructs the [`LongFormTranscriptionPipeline`].
///
/// Models are loaded from `args.models_dir` when it is set. Otherwise the
/// pretrained models are used if the `online` feature is compiled in, and an
/// error is returned if it is not.
///
/// # Errors
///
/// Propagates pipeline construction failures, or fails when neither a models
/// directory nor the `online` feature is available.
#[cfg(feature = "long-form")]
fn build_long_form_pipeline(args: &Args) -> Result<LongFormTranscriptionPipeline> {
    // The flag does not influence the choice below; it is only read here,
    // presumably so the field is not reported as unused.
    let _ = args.pretrained;

    match &args.models_dir {
        Some(dir) => Ok(LongFormTranscriptionPipeline::from_dir(dir)?),
        None => {
            #[cfg(feature = "online")]
            {
                Ok(LongFormTranscriptionPipeline::from_pretrained()?)
            }

            #[cfg(not(feature = "online"))]
            {
                bail!("rebuild with the default online feature or pass --models-dir")
            }
        }
    }
}

/// Takes the next argument as the value of `flag` and converts it to a `PathBuf`.
///
/// # Errors
///
/// Fails when the argument list is exhausted.
fn next_path(args: &mut impl Iterator<Item = String>, flag: &str) -> Result<PathBuf> {
    next_value(args, flag).map(PathBuf::from)
}

/// Takes the next argument as the value of `flag`.
///
/// # Errors
///
/// Fails with a message naming `flag` when the argument list is exhausted.
fn next_value(args: &mut impl Iterator<Item = String>, flag: &str) -> Result<String> {
    match args.next() {
        Some(value) => Ok(value),
        None => Err(eyre!("missing value for {flag}")),
    }
}

/// Reads a WAV file and returns its samples as `f32` values.
///
/// Only 16 kHz single-channel files are accepted. 32-bit float samples are
/// returned unchanged; integer samples are divided by `2^(bits - 1) - 1`.
///
/// # Errors
///
/// Fails when the file cannot be opened or decoded, when the sample rate is
/// not 16000 Hz, when there is more than one channel, or when the sample
/// format / bit depth combination is not handled.
fn read_mono_16khz_wav(path: &Path) -> Result<Vec<f32>> {
    let mut reader = WavReader::open(path)?;
    let spec = reader.spec();

    if spec.sample_rate != 16_000 {
        bail!(
            "expected 16kHz audio, got {} Hz in {}",
            spec.sample_rate,
            path.display()
        )
    }
    if spec.channels != 1 {
        bail!(
            "expected mono audio, got {} channels in {}",
            spec.channels,
            path.display()
        )
    }

    let bits = spec.bits_per_sample;
    match spec.sample_format {
        hound::SampleFormat::Float if bits == 32 => {
            let samples = reader.samples::<f32>().collect::<Result<Vec<_>, _>>()?;
            Ok(samples)
        }
        hound::SampleFormat::Int if (1..=32).contains(&bits) => {
            // Largest positive value representable at this bit depth.
            // NOTE(review): for `bits == 1` this is 0, so the division below would
            // not yield finite samples — confirm whether such files can reach here.
            let scale = ((1_i64 << (bits - 1)) - 1) as f32;
            let mut samples = Vec::new();
            for sample in reader.samples::<i32>() {
                samples.push(sample? as f32 / scale);
            }
            Ok(samples)
        }
        _ => bail!(
            "unsupported WAV format: {:?} {}-bit",
            spec.sample_format,
            spec.bits_per_sample
        ),
    }
}

/// Returns `text` shortened for display.
///
/// Text of at most `limit` bytes is returned unchanged. Longer text is cut to
/// at most `limit` bytes and suffixed with `...`. The cut is moved back to the
/// nearest UTF-8 character boundary, so a limit that falls inside a multi-byte
/// character never splits it (slicing there would panic).
fn preview(text: &str, limit: usize) -> String {
    if text.len() <= limit {
        return text.to_owned();
    }
    // `limit < text.len()` here, so every probed index is in range; index 0 is
    // always a boundary, which guarantees the search succeeds.
    let cut = (0..=limit)
        .rev()
        .find(|&index| text.is_char_boundary(index))
        .unwrap_or(0);
    format!("{}...", &text[..cut])
}