voice-typing 0.1.0

Fast local voice typing for any focused window or browser input
use anyhow::{Context, Result, anyhow};
use bzip2::read::BzDecoder;
use reqwest::blocking::Client;
use std::fs::{self, File};
use std::io::{Read, Write};
use std::path::{Path, PathBuf};
use std::sync::Arc;
use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
use tar::Archive;
use voice_typing_asr::CurrentWiredModel;

// Base URL; `download_model_files` appends `/<file name>` to fetch each model file individually.
const MODEL_REPO: &str = "https://huggingface.co/csukuangfj2/sherpa-onnx-moonshine-base-en-quantized-2026-02-27/resolve/main";
// Archive that `download_model_archive` downloads and unpacks when the per-file download fails.
const MODEL_ARCHIVE_URL: &str = "https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-moonshine-base-en-quantized-2026-02-27.tar.bz2";
// URL that `ensure_vad_file` downloads into `CurrentWiredModel::auto_vad_path()`.
const VAD_URL: &str =
    "https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/silero_vad.onnx";

/// Progress counters for a multi-file download.
///
/// Every field is an atomic and every method takes `&self`, so one instance
/// can be updated by the code performing the download while other code polls
/// [`DownloadProgress::fraction`] and [`DownloadProgress::is_running`].
#[derive(Debug, Default)]
pub struct DownloadProgress {
    /// Number of files the current run is expected to complete.
    total_files: AtomicU64,
    /// Number of files reported as finished so far.
    completed_files: AtomicU64,
    /// Size in bytes of the file in flight; `0` when the size is unknown.
    current_total_bytes: AtomicU64,
    /// Bytes received so far for the file in flight.
    current_downloaded_bytes: AtomicU64,
    /// `true` between `start` and `finish`.
    running: AtomicBool,
}

impl DownloadProgress {
    /// Resets all counters and marks a run of `total_files` files as running.
    pub fn start(&self, total_files: u64) {
        self.total_files.store(total_files, Ordering::Relaxed);
        self.completed_files.store(0, Ordering::Relaxed);
        self.current_total_bytes.store(0, Ordering::Relaxed);
        self.current_downloaded_bytes.store(0, Ordering::Relaxed);
        self.running.store(true, Ordering::Relaxed);
    }

    /// Starts tracking a new file of `total_bytes` bytes.
    ///
    /// `None` is stored as `0`, in which case the file contributes nothing to
    /// [`DownloadProgress::fraction`] until it is reported finished.
    pub fn begin_file(&self, total_bytes: Option<u64>) {
        self.current_downloaded_bytes.store(0, Ordering::Relaxed);
        self.current_total_bytes
            .store(total_bytes.unwrap_or(0), Ordering::Relaxed);
    }

    /// Records `count` additional bytes received for the file in flight.
    pub fn add_bytes(&self, count: u64) {
        self.current_downloaded_bytes
            .fetch_add(count, Ordering::Relaxed);
    }

    /// Counts the file in flight as completed and clears its byte counters.
    pub fn finish_file(&self) {
        self.completed_files.fetch_add(1, Ordering::Relaxed);
        self.current_downloaded_bytes.store(0, Ordering::Relaxed);
        self.current_total_bytes.store(0, Ordering::Relaxed);
    }

    /// Marks the run as no longer running and clears the per-file byte counters.
    ///
    /// The completed/total file counts are left untouched.
    pub fn finish(&self) {
        self.running.store(false, Ordering::Relaxed);
        self.current_downloaded_bytes.store(0, Ordering::Relaxed);
        self.current_total_bytes.store(0, Ordering::Relaxed);
    }

    /// Returns overall progress in the range `0.0..=1.0`.
    ///
    /// The value is the number of completed files plus the fraction of the
    /// file in flight, divided by the expected number of files. A total of
    /// zero files is treated as one to avoid dividing by zero.
    pub fn fraction(&self) -> f32 {
        let total = self.total_files.load(Ordering::Relaxed).max(1);
        let completed = self.completed_files.load(Ordering::Relaxed).min(total);
        let current_total = self.current_total_bytes.load(Ordering::Relaxed);
        let current_done = self.current_downloaded_bytes.load(Ordering::Relaxed);
        let current_fraction = if current_total > 0 {
            (current_done as f32 / current_total as f32).clamp(0.0, 1.0)
        } else {
            0.0
        };

        // `completed` may already equal `total` while another file is in
        // flight (more files were processed than announced to `start`), so
        // cap the sum instead of reporting more than 100%.
        (((completed as f32) + current_fraction) / total as f32).min(1.0)
    }

    /// Returns `true` while a run started by `start` has not been finished.
    pub fn is_running(&self) -> bool {
        self.running.load(Ordering::Relaxed)
    }
}

/// Async entry point: runs `ensure_auto_model_blocking` through
/// `tokio::task::spawn_blocking` and awaits its outcome.
///
/// # Errors
///
/// Returns the worker's own error unchanged, or a
/// "model download worker failed" error when awaiting the spawned task fails.
pub async fn ensure_auto_model(progress: Arc<DownloadProgress>) -> Result<()> {
    let worker = tokio::task::spawn_blocking(move || ensure_auto_model_blocking(progress));

    match worker.await {
        Ok(outcome) => outcome,
        Err(err) => Err(anyhow!("model download worker failed: {err}")),
    }
}

/// Makes sure the auto model cache is populated, downloading what is missing.
///
/// Returns immediately when `CurrentWiredModel::auto_assets_ready()` already
/// holds. Otherwise the cache is filled by `populate_auto_model_cache`.
///
/// `progress.finish()` is called on every exit path, including failures, so
/// `DownloadProgress::is_running` does not keep reporting an active download
/// after this function has returned an error.
///
/// # Errors
///
/// Propagates any error from `populate_auto_model_cache`.
fn ensure_auto_model_blocking(progress: Arc<DownloadProgress>) -> Result<()> {
    let outcome = if CurrentWiredModel::auto_assets_ready() {
        Ok(())
    } else {
        populate_auto_model_cache(&progress)
    };

    progress.finish();
    outcome
}

/// Downloads the model files and the VAD file into the auto model cache.
///
/// Order of attempts: individual model files, then the model archive if that
/// fails, then the VAD file, and finally a copy of bundled assets if the cache
/// is still not ready.
///
/// # Errors
///
/// Fails when a directory or the HTTP client cannot be created, when both the
/// per-file download and the archive fallback fail, when the VAD download
/// fails, when the bundled fallback fails, or when the cache is still
/// incomplete afterwards.
fn populate_auto_model_cache(progress: &DownloadProgress) -> Result<()> {
    fs::create_dir_all(CurrentWiredModel::auto_models_root())
        .context("failed to create model cache directory")?;

    // Three model files plus the VAD file.
    progress.start(4);

    let client = Client::builder()
        .user_agent("voice-typing/0.1")
        .build()
        .context("failed to build HTTP client")?;

    let model_dir = CurrentWiredModel::auto_model_dir();
    fs::create_dir_all(&model_dir).with_context(|| {
        format!(
            "failed to create auto model directory {}",
            model_dir.display()
        )
    })?;

    if let Err(primary_err) = download_model_files(&client, progress, &model_dir) {
        // `{:#}` keeps the primary error's cause chain; plain `{}` would only
        // print its outermost context message.
        download_model_archive(&client, progress, &model_dir).with_context(|| {
            format!("failed primary file download and archive fallback: {primary_err:#}")
        })?;
    }

    ensure_vad_file(&client, progress)?;

    if !CurrentWiredModel::auto_assets_ready() {
        copy_bundled_assets(progress)
            .context("downloaded assets were incomplete and bundled fallback failed")?;
    }

    if !CurrentWiredModel::auto_assets_ready() {
        anyhow::bail!(
            "auto model cache is still incomplete under {}",
            CurrentWiredModel::auto_models_root().display()
        );
    }

    Ok(())
}

/// Fetches each required model file from `MODEL_REPO` into `model_dir`.
///
/// A file that already exists in `model_dir` is not downloaded again. Either
/// way the file is reported to `progress` as finished.
///
/// # Errors
///
/// Stops at the first file whose download fails and returns that error.
fn download_model_files(
    client: &Client,
    progress: &DownloadProgress,
    model_dir: &Path,
) -> Result<()> {
    const REQUIRED_FILES: [&str; 3] = [
        "encoder_model.ort",
        "decoder_model_merged.ort",
        "tokens.txt",
    ];

    for name in REQUIRED_FILES {
        let destination = model_dir.join(name);

        if !destination.exists() {
            let source_url = format!("{MODEL_REPO}/{name}");
            download_to_path(client, progress, &source_url, &destination)?;
        }

        progress.finish_file();
    }

    Ok(())
}

fn download_model_archive(
    client: &Client,
    progress: &DownloadProgress,
    model_dir: &Path,
) -> Result<()> {
    let archive_path = CurrentWiredModel::auto_models_root()
        .join(format!("{}.tar.bz2", CurrentWiredModel::MODEL_NAME));
    download_to_path(client, progress, MODEL_ARCHIVE_URL, &archive_path)?;

    let file = File::open(&archive_path).with_context(|| {
        format!(
            "failed to open downloaded archive {}",
            archive_path.display()
        )
    })?;
    let decoder = BzDecoder::new(file);
    let mut archive = Archive::new(decoder);
    archive
        .unpack(CurrentWiredModel::auto_models_root())
        .with_context(|| {
            format!(
                "failed to extract model archive into {}",
                CurrentWiredModel::auto_models_root().display()
            )
        })?;

    for required in [
        "encoder_model.ort",
        "decoder_model_merged.ort",
        "tokens.txt",
    ] {
        let path = model_dir.join(required);
        path.metadata()
            .with_context(|| format!("archive did not provide {}", path.display()))?;
    }

    progress.finish_file();
    Ok(())
}

/// Downloads the VAD file from `VAD_URL` unless it already exists at
/// `CurrentWiredModel::auto_vad_path()`, then reports one finished file.
///
/// # Errors
///
/// Returns the download error; in that case no file is reported as finished.
fn ensure_vad_file(client: &Client, progress: &DownloadProgress) -> Result<()> {
    let vad_path = CurrentWiredModel::auto_vad_path();

    if !vad_path.exists() {
        download_to_path(client, progress, VAD_URL, &vad_path)?;
    }

    progress.finish_file();
    Ok(())
}

fn download_to_path(
    client: &Client,
    progress: &DownloadProgress,
    url: &str,
    target: &Path,
) -> Result<()> {
    if let Some(parent) = target.parent() {
        fs::create_dir_all(parent)
            .with_context(|| format!("failed to create {}", parent.display()))?;
    }

    let response = client
        .get(url)
        .send()
        .with_context(|| format!("failed GET {url}"))?
        .error_for_status()
        .with_context(|| format!("upstream rejected {url}"))?;

    progress.begin_file(response.content_length());

    let part_path = target.with_extension(format!(
        "{}.part",
        target
            .extension()
            .and_then(|ext| ext.to_str())
            .unwrap_or("download")
    ));

    let mut reader = response;
    let mut writer = File::create(&part_path)
        .with_context(|| format!("failed to create {}", part_path.display()))?;
    let mut buffer = [0_u8; 64 * 1024];

    loop {
        let count = reader
            .read(&mut buffer)
            .with_context(|| format!("failed while reading {url}"))?;
        if count == 0 {
            break;
        }
        writer
            .write_all(&buffer[..count])
            .with_context(|| format!("failed writing {}", part_path.display()))?;
        progress.add_bytes(count as u64);
    }

    writer
        .flush()
        .with_context(|| format!("failed flushing {}", part_path.display()))?;
    fs::rename(&part_path, target).with_context(|| {
        format!(
            "failed to move downloaded file {} into place {}",
            part_path.display(),
            target.display()
        )
    })?;

    Ok(())
}

fn copy_bundled_assets(progress: &DownloadProgress) -> Result<()> {
    let roots = candidate_roots();
    let bundled_model = roots
        .iter()
        .map(|root| root.join(CurrentWiredModel::MODEL_DIR))
        .find(|path| path.exists())
        .context("unable to locate bundled model assets")?;
    let bundled_vad = roots
        .iter()
        .map(|root| root.join(CurrentWiredModel::VAD_PATH))
        .find(|path| path.exists())
        .context("unable to locate bundled VAD asset")?;

    fs::create_dir_all(CurrentWiredModel::auto_model_dir()).with_context(|| {
        format!(
            "failed to create bundled fallback directory {}",
            CurrentWiredModel::auto_model_dir().display()
        )
    })?;

    for file_name in [
        "encoder_model.ort",
        "decoder_model_merged.ort",
        "tokens.txt",
    ] {
        let source = bundled_model.join(file_name);
        let target = CurrentWiredModel::auto_model_dir().join(file_name);
        if !target.exists() {
            fs::copy(&source, &target).with_context(|| {
                format!(
                    "failed to copy bundled asset {} -> {}",
                    source.display(),
                    target.display()
                )
            })?;
        }
    }
    if !CurrentWiredModel::auto_vad_path().exists() {
        fs::copy(&bundled_vad, CurrentWiredModel::auto_vad_path()).with_context(|| {
            format!(
                "failed to copy bundled VAD {} -> {}",
                bundled_vad.display(),
                CurrentWiredModel::auto_vad_path().display()
            )
        })?;
    }

    while progress.completed_files.load(Ordering::Relaxed) < 4 {
        progress.finish_file();
    }

    Ok(())
}

/// Lists the directories searched for bundled assets, in priority order:
/// the current working directory first, then the directory containing the
/// running executable. A location that cannot be determined is left out.
fn candidate_roots() -> Vec<PathBuf> {
    let working_dir = std::env::current_dir().ok();
    let executable_dir = std::env::current_exe()
        .ok()
        .and_then(|exe| exe.parent().map(|dir| dir.to_path_buf()));

    working_dir.into_iter().chain(executable_dir).collect()
}

#[cfg(test)]
mod tests {
    use super::*;

    /// Clears the cached model directory and VAD file, runs the blocking
    /// downloader, and checks that the cache is reported ready again.
    /// Ignored by default because it needs network access.
    #[test]
    #[ignore = "network smoke test"]
    fn auto_model_download_smoke() {
        // Removal failures are ignored on purpose: the paths may not exist yet.
        fs::remove_dir_all(CurrentWiredModel::auto_model_dir()).ok();
        fs::remove_file(CurrentWiredModel::auto_vad_path()).ok();

        let tracker = Arc::new(DownloadProgress::default());
        let outcome = ensure_auto_model_blocking(tracker);

        outcome.expect("auto downloader should repopulate cache");
        assert!(CurrentWiredModel::auto_assets_ready());
    }
}