direct_play_nice 0.1.0-alpha.7

CLI program that converts video files to direct-play-compatible formats.
Documentation
/// Resolves the directory that holds OCR model files.
///
/// Lookup order:
/// 1. `DPN_OCR_MODEL_DIR`; the directory is created if it is missing.
/// 2. A `models` directory next to the current executable, if it already exists.
/// 3. `$XDG_CONFIG_HOME/direct-play-nice/models`, then
///    `$HOME/.config/direct-play-nice/models`, then
///    `<temp dir>/direct-play-nice-models`; the chosen one is created if missing.
///
/// Environment variables that are set but empty are treated as unset, and a
/// relative `XDG_CONFIG_HOME` is ignored (both as required by the XDG Base
/// Directory Specification). Without this, an empty value would silently
/// produce a path relative to the current working directory.
///
/// # Errors
/// Returns an error if the selected directory cannot be created.
fn resolve_model_dir() -> Result<PathBuf> {
    // A variable that is present but empty carries no usable path.
    let non_empty_var = |key: &str| env::var_os(key).filter(|value| !value.is_empty());

    if let Some(dir) = non_empty_var("DPN_OCR_MODEL_DIR") {
        let path = PathBuf::from(dir);
        fs::create_dir_all(&path)
            .with_context(|| format!("creating OCR model directory '{}'", path.display()))?;
        return Ok(path);
    }

    // Prefer models shipped alongside the binary, but never create that directory.
    if let Ok(exe) = env::current_exe() {
        if let Some(parent) = exe.parent() {
            let candidate = parent.join("models");
            if candidate.is_dir() {
                return Ok(candidate);
            }
        }
    }

    let xdg_config_home = non_empty_var("XDG_CONFIG_HOME")
        .map(PathBuf::from)
        .filter(|dir| dir.is_absolute());

    let fallback = if let Some(xdg) = xdg_config_home {
        xdg.join("direct-play-nice").join("models")
    } else if let Some(home) = non_empty_var("HOME") {
        PathBuf::from(home)
            .join(".config")
            .join("direct-play-nice")
            .join("models")
    } else {
        env::temp_dir().join("direct-play-nice-models")
    };

    fs::create_dir_all(&fallback)
        .with_context(|| format!("creating OCR model directory '{}'", fallback.display()))?;
    Ok(fallback)
}

/// Makes sure the detection, classification and recognition models for
/// `variant` are present in `model_dir`, fetching any that are missing.
///
/// `skip_cls` is currently informational only: the classifier model is always
/// resolved, and a debug message records that the request was not honoured.
///
/// # Errors
/// Propagates the first failure reported by `ensure_model_file`.
fn ensure_ppocr_models(
    model_dir: &Path,
    variant: PpOcrVariant,
    skip_cls: bool,
) -> Result<PpOcrModels> {
    let (det_spec, cls_spec, rec_spec) = variant.model_specs();

    if skip_cls {
        debug!(
            "Skipping classifier is requested, but this build uses mandatory classifier initialization; loading cls model."
        );
    }

    // Field initialisers run in written order, so models are still resolved
    // as det, then cls, then rec.
    Ok(PpOcrModels {
        det: ensure_model_file(model_dir, det_spec)?,
        cls: ensure_model_file(model_dir, cls_spec)?,
        rec: ensure_model_file(model_dir, rec_spec)?,
    })
}

fn resolve_optional_latin_rec_model(model_dir: &Path, variant: PpOcrVariant) -> Result<Option<PathBuf>> {
    if let Some(path) = env::var_os("DPN_OCR_REC_LATIN_MODEL") {
        let rec_path = PathBuf::from(path);
        if !rec_path.is_file() {
            bail!(
                "DPN_OCR_REC_LATIN_MODEL is set but file does not exist: '{}'",
                rec_path.display()
            );
        }
        return Ok(Some(rec_path));
    }

    let candidates: &[&str] = match variant {
        PpOcrVariant::V3 => &[
            "latin_PP-OCRv3_rec_infer.onnx",
            "latin_ppocr_mobile_v2.0_rec_infer.onnx",
            "multilingual_PP-OCRv3_rec_infer.onnx",
        ],
        PpOcrVariant::V4 => &[
            "latin_PP-OCRv4_rec_infer.onnx",
            "latin_ppocr_mobile_v2.0_rec_infer.onnx",
            "multilingual_PP-OCRv4_rec_infer.onnx",
        ],
    };

    for candidate in candidates {
        let path = model_dir.join(candidate);
        if path.is_file() {
            return Ok(Some(path));
        }
    }
    Ok(None)
}

/// Locates an optional Japanese recognition model.
///
/// Checks the `DPN_OCR_REC_JAPANESE_MODEL` override first and otherwise looks
/// for the first existing well-known filename for `variant` in `model_dir`.
///
/// # Errors
/// Returns an error if the override is set but is not an existing file.
fn resolve_optional_japanese_rec_model(
    model_dir: &Path,
    variant: PpOcrVariant,
) -> Result<Option<PathBuf>> {
    // Filenames in priority order for the selected PP-OCR generation.
    let filenames: &[&str] = match variant {
        PpOcrVariant::V3 => &[
            "japan_PP-OCRv3_rec_infer.onnx",
            "japanese_PP-OCRv3_rec_infer.onnx",
            "ja_PP-OCRv3_rec_infer.onnx",
        ],
        PpOcrVariant::V4 => &[
            "japan_PP-OCRv4_rec_infer.onnx",
            "japanese_PP-OCRv4_rec_infer.onnx",
            "ja_PP-OCRv4_rec_infer.onnx",
        ],
    };

    resolve_optional_rec_model_with_candidates("DPN_OCR_REC_JAPANESE_MODEL", model_dir, filenames)
}

/// Locates an optional Korean recognition model.
///
/// Checks the `DPN_OCR_REC_KOREAN_MODEL` override first and otherwise looks
/// for the first existing well-known filename for `variant` in `model_dir`.
///
/// # Errors
/// Returns an error if the override is set but is not an existing file.
fn resolve_optional_korean_rec_model(
    model_dir: &Path,
    variant: PpOcrVariant,
) -> Result<Option<PathBuf>> {
    // Filenames in priority order for the selected PP-OCR generation.
    let filenames: &[&str] = match variant {
        PpOcrVariant::V3 => &[
            "korean_PP-OCRv3_rec_infer.onnx",
            "ko_PP-OCRv3_rec_infer.onnx",
        ],
        PpOcrVariant::V4 => &[
            "korean_PP-OCRv4_rec_infer.onnx",
            "ko_PP-OCRv4_rec_infer.onnx",
        ],
    };

    resolve_optional_rec_model_with_candidates("DPN_OCR_REC_KOREAN_MODEL", model_dir, filenames)
}

/// Locates an optional CJK (Chinese) recognition model.
///
/// Checks the `DPN_OCR_REC_CJK_MODEL` override first and otherwise looks for
/// the first existing well-known filename for `variant` in `model_dir`.
///
/// # Errors
/// Returns an error if the override is set but is not an existing file.
fn resolve_optional_cjk_rec_model(
    model_dir: &Path,
    variant: PpOcrVariant,
) -> Result<Option<PathBuf>> {
    // Filenames in priority order for the selected PP-OCR generation.
    let filenames: &[&str] = match variant {
        PpOcrVariant::V3 => &[
            "cjk_PP-OCRv3_rec_infer.onnx",
            "chinese_PP-OCRv3_rec_infer.onnx",
            "zh_PP-OCRv3_rec_infer.onnx",
        ],
        PpOcrVariant::V4 => &[
            "cjk_PP-OCRv4_rec_infer.onnx",
            "chinese_PP-OCRv4_rec_infer.onnx",
            "zh_PP-OCRv4_rec_infer.onnx",
        ],
    };

    resolve_optional_rec_model_with_candidates("DPN_OCR_REC_CJK_MODEL", model_dir, filenames)
}

/// Shared lookup for optional recognition models.
///
/// * `env_key` - environment variable that may name an explicit model file.
/// * `model_dir` - directory searched when the variable is unset.
/// * `candidates` - filenames tried inside `model_dir`, in priority order.
///
/// Returns `Ok(Some(path))` for the override or the first candidate that is an
/// existing file, and `Ok(None)` when nothing is found.
///
/// # Errors
/// Returns an error if `env_key` is set but does not point to an existing file.
fn resolve_optional_rec_model_with_candidates(
    env_key: &str,
    model_dir: &Path,
    candidates: &[&str],
) -> Result<Option<PathBuf>> {
    match env::var_os(env_key).map(PathBuf::from) {
        // An explicit override always wins, but it must be a real file.
        Some(overridden) if overridden.is_file() => Ok(Some(overridden)),
        Some(overridden) => bail!(
            "{} is set but file does not exist: '{}'",
            env_key,
            overridden.display()
        ),
        None => Ok(candidates
            .iter()
            .map(|name| model_dir.join(name))
            .find(|location| location.is_file())),
    }
}

/// Returns the path of the model described by `spec` inside `model_dir`,
/// downloading it first when it is not present.
///
/// An existing file is always used as-is. Its SHA-256 is compared with
/// `spec.sha256` on a best-effort basis: a mismatch, or a failure to compute
/// the hash, only produces a warning.
///
/// # Errors
/// Returns an error if the model is missing and the download fails.
fn ensure_model_file(model_dir: &Path, spec: &ModelSpec) -> Result<PathBuf> {
    let path = model_dir.join(spec.filename);
    if path.exists() {
        match sha256_file(&path) {
            Ok(hash) => {
                if hash != spec.sha256.to_ascii_lowercase() {
                    warn!(
                        "OCR model '{}' hash mismatch (got {}, expected {}). Using existing file anyway.",
                        path.display(),
                        hash,
                        spec.sha256
                    );
                }
            }
            // Verification is advisory, so an unreadable file is reported
            // rather than silently skipped; the caller still gets the path.
            Err(err) => {
                warn!(
                    "Could not verify hash of OCR model '{}': {:#}. Using existing file anyway.",
                    path.display(),
                    err
                );
            }
        }
        return Ok(path);
    }

    info!(
        "Downloading OCR model '{}' from {}",
        spec.filename, spec.url
    );
    download_model_with_values(&path, spec.url, spec.sha256, spec.filename)?;
    Ok(path)
}

/// Test-only variant of `ensure_model_file` that takes the model description
/// as plain values.
///
/// Returns the path `model_dir/filename`; the file is downloaded from `url`
/// and checked against `sha256` only when nothing exists at that path yet.
#[cfg(test)]
fn ensure_model_file_with_values(
    model_dir: &Path,
    filename: &str,
    url: &str,
    sha256: &str,
) -> Result<PathBuf> {
    let target = model_dir.join(filename);
    if !target.exists() {
        info!("Downloading OCR model '{}' from {}", filename, url);
        download_model_with_values(&target, url, sha256, filename)?;
    }
    Ok(target)
}

fn download_model_with_values(path: &Path, url: &str, sha256: &str, filename: &str) -> Result<()> {
    let parent = path
        .parent()
        .ok_or_else(|| anyhow!("invalid OCR model path '{}'", path.display()))?;
    fs::create_dir_all(parent)
        .with_context(|| format!("creating OCR model directory '{}'", parent.display()))?;

    let tmp_path = path.with_extension("download");
    let response = ureq::get(url)
        .call()
        .with_context(|| format!("downloading OCR model from {}", url))?;
    let mut reader = response.into_reader();
    let mut file = fs::File::create(&tmp_path)
        .with_context(|| format!("creating '{}'", tmp_path.display()))?;
    let mut hasher = Sha256::new();
    let mut buf = [0u8; 16 * 1024];
    let mut download_result = Ok(());
    loop {
        let read = match reader.read(&mut buf) {
            Ok(read) => read,
            Err(err) => {
                download_result = Err(anyhow!(err).context("reading OCR model download"));
                break;
            }
        };
        if read == 0 {
            break;
        }
        hasher.update(&buf[..read]);
        if let Err(err) = file.write_all(&buf[..read]) {
            download_result = Err(anyhow!(err).context("writing OCR model download"));
            break;
        }
    }
    let _ = file.flush();

    if let Err(err) = download_result {
        let _ = fs::remove_file(&tmp_path);
        return Err(err);
    }

    let actual_hash = to_hex_lower(&hasher.finalize());
    let expected = sha256.to_ascii_lowercase();
    if actual_hash != expected {
        let _ = fs::remove_file(&tmp_path);
        bail!(
            "downloaded OCR model '{}' hash mismatch (got {}, expected {})",
            filename,
            actual_hash,
            expected
        );
    }

    fs::rename(&tmp_path, path).with_context(|| {
        format!(
            "renaming OCR model download '{}' -> '{}'",
            tmp_path.display(),
            path.display()
        )
    })?;
    Ok(())
}

/// Computes the SHA-256 of the file at `path` and returns it as lowercase hex.
///
/// The file is read in 16 KiB chunks, so it is never loaded into memory whole.
///
/// # Errors
/// Returns an error if the file cannot be opened or read.
fn sha256_file(path: &Path) -> Result<String> {
    let mut source = fs::File::open(path)?;
    let mut digest = Sha256::new();
    let mut chunk = [0u8; 16 * 1024];
    loop {
        match source.read(&mut chunk)? {
            // A zero-length read signals end of file.
            0 => break,
            filled => digest.update(&chunk[..filled]),
        }
    }
    Ok(to_hex_lower(&digest.finalize()))
}

/// Encodes `bytes` as a lowercase hexadecimal string, two digits per byte.
fn to_hex_lower(bytes: &[u8]) -> String {
    const DIGITS: &[u8; 16] = b"0123456789abcdef";

    let mut hex = String::with_capacity(bytes.len() * 2);
    for &byte in bytes {
        // Each nibble is in 0..16, so the table lookups cannot go out of range.
        hex.push(char::from(DIGITS[usize::from(byte >> 4)]));
        hex.push(char::from(DIGITS[usize::from(byte & 0x0f)]));
    }
    hex
}