kbolt-core 0.1.7

Core engine for kbolt local-first retrieval
Documentation
use std::collections::HashSet;
use std::fs;
use std::path::Path;

use hf_hub::api::sync::ApiBuilder;
use hf_hub::{Repo, RepoType};
use kbolt_types::KboltError;

use super::super::provider::{ModelArtifactProvider, ModelDownloadRequest, ModelFileRequirement};
use crate::Result;

/// Model artifact provider backed by the Hugging Face Hub client (`hf_hub`).
///
/// Carries no state; all behavior lives in its [`ModelArtifactProvider`] impl.
pub(crate) struct HfHubDownloader;

impl ModelArtifactProvider for HfHubDownloader {
    /// Fetches the files that `request` requires from its Hugging Face model repo.
    ///
    /// `target_dir` is created if it does not exist and is handed to the hub
    /// client as its cache directory. The repo's file listing is matched
    /// against `request.requirements`, and every selected file is fetched in
    /// the order it was resolved.
    ///
    /// Returns the (saturating) sum of the on-disk sizes of the fetched files.
    ///
    /// # Errors
    ///
    /// Fails when the directory cannot be created, when any hub call fails,
    /// when a requirement cannot be resolved, when no file ends up selected,
    /// or when the fetched files add up to zero bytes.
    fn download_model(&self, request: &ModelDownloadRequest, target_dir: &Path) -> Result<u64> {
        // Every hub-side failure is reported as `<model id>: <cause>`.
        let hub_failure =
            |cause: String| KboltError::ModelDownload(format!("{}: {cause}", request.model_id));

        fs::create_dir_all(target_dir)?;

        let builder = ApiBuilder::new().with_cache_dir(target_dir.to_path_buf());
        let api = builder
            .build()
            .map_err(|err| hub_failure(err.to_string()))?;
        let repo = api.repo(Repo::new(request.model_id.clone(), RepoType::Model));
        let info = repo.info().map_err(|err| hub_failure(err.to_string()))?;

        // Only the repo-relative file names are needed to resolve requirements.
        let repo_files: Vec<String> = info
            .siblings
            .into_iter()
            .map(|entry| entry.rfilename)
            .collect();
        let wanted = resolve_required_files(&request.model_id, &repo_files, &request.requirements)?;
        if wanted.is_empty() {
            return Err(KboltError::ModelDownload(format!(
                "{}: no files selected for download",
                request.model_id
            ))
            .into());
        }

        let mut downloaded_bytes: u64 = 0;
        for name in &wanted {
            let local_path = repo
                .get(name)
                .map_err(|err| hub_failure(err.to_string()))?;
            let size = file_size_bytes(&local_path)?;
            downloaded_bytes = downloaded_bytes.saturating_add(size);
        }

        // A zero-byte total is treated as a failed download rather than success.
        match downloaded_bytes {
            0 => Err(KboltError::ModelDownload(format!(
                "{}: no files were downloaded",
                request.model_id
            ))
            .into()),
            bytes => Ok(bytes),
        }
    }
}

/// Resolves every requirement to a concrete repo file name.
///
/// The result keeps the order in which requirements were given and lists each
/// file once, even when several requirements resolve to the same file.
///
/// # Errors
///
/// Stops at, and returns, the first error produced while resolving a
/// requirement.
fn resolve_required_files(
    model_id: &str,
    siblings: &[String],
    requirements: &[ModelFileRequirement],
) -> Result<Vec<String>> {
    let mut ordered: Vec<String> = Vec::new();
    let mut already_picked: HashSet<String> = HashSet::new();

    for requirement in requirements {
        let file = resolve_requirement(model_id, siblings, requirement)?;
        if already_picked.contains(&file) {
            // Another requirement already selected this file.
            continue;
        }
        already_picked.insert(file.clone());
        ordered.push(file);
    }

    Ok(ordered)
}

/// Maps one file requirement onto exactly one entry of `siblings`.
///
/// * `ExactPath` — the configured path (with any leading `./` removed) must
///   appear verbatim in `siblings`.
/// * `SingleExtension` — exactly one sibling may carry the extension.
/// * `SingleTokenizerJson` — exactly one sibling may be named
///   `tokenizer.json`.
///
/// # Errors
///
/// Returns `KboltError::ModelDownload` when the exact path is absent, or when
/// a "single" requirement matches zero or more than one sibling; the message
/// names the config field the user can set to resolve the situation.
fn resolve_requirement(
    model_id: &str,
    siblings: &[String],
    requirement: &ModelFileRequirement,
) -> Result<String> {
    match requirement {
        ModelFileRequirement::ExactPath { path, config_field } => {
            let wanted = normalize_repo_path(path);
            if !siblings.contains(&wanted) {
                return Err(KboltError::ModelDownload(format!(
                    "{model_id}: configured file for {config_field} not found in repo: {path}"
                ))
                .into());
            }
            Ok(wanted)
        }
        ModelFileRequirement::SingleExtension {
            extension,
            config_field,
        } => {
            let candidates: Vec<&String> = siblings
                .iter()
                .filter(|name| has_extension(name, extension))
                .collect();
            match candidates.as_slice() {
                [] => Err(KboltError::ModelDownload(format!(
                    "{model_id}: missing .{extension} artifact. set {config_field} to the desired file"
                ))
                .into()),
                [only] => Ok((*only).clone()),
                _ => Err(KboltError::ModelDownload(format!(
                    "{model_id}: multiple .{extension} artifacts found. set {config_field} to choose one"
                ))
                .into()),
            }
        }
        ModelFileRequirement::SingleTokenizerJson { config_field } => {
            let candidates: Vec<&String> = siblings
                .iter()
                .filter(|name| is_tokenizer_json(name))
                .collect();
            match candidates.as_slice() {
                [] => Err(KboltError::ModelDownload(format!(
                    "{model_id}: missing tokenizer.json. set {config_field} to the tokenizer path"
                ))
                .into()),
                [only] => Ok((*only).clone()),
                _ => Err(KboltError::ModelDownload(format!(
                    "{model_id}: multiple tokenizer.json files found. set {config_field} to choose one"
                ))
                .into()),
            }
        }
    }
}

/// Returns `path` with every leading `./` removed.
///
/// Only the start of the string is affected; a `./` that appears later in the
/// path is left as is.
fn normalize_repo_path(path: &str) -> String {
    let mut rest = path;
    while let Some(stripped) = rest.strip_prefix("./") {
        rest = stripped;
    }
    rest.to_owned()
}

/// Reports whether the extension of `path` equals `extension`, ignoring ASCII
/// case.
///
/// `extension` is compared against what [`Path::extension`] yields, so it is
/// expected without a leading dot. A path with no extension never matches.
fn has_extension(path: &str, extension: &str) -> bool {
    match Path::new(path).extension().and_then(|raw| raw.to_str()) {
        Some(found) => found.eq_ignore_ascii_case(extension),
        None => false,
    }
}

/// Reports whether the final component of `path` is `tokenizer.json`,
/// ignoring ASCII case.
///
/// Directories in front of the file name are irrelevant; only the last
/// component is inspected.
fn is_tokenizer_json(path: &str) -> bool {
    if let Some(file_name) = Path::new(path).file_name().and_then(|raw| raw.to_str()) {
        return file_name.eq_ignore_ascii_case("tokenizer.json");
    }
    false
}

/// Returns the length in bytes that the filesystem reports for `path`.
///
/// # Errors
///
/// Propagates the I/O error raised when the metadata for `path` cannot be
/// read.
fn file_size_bytes(path: &Path) -> Result<u64> {
    let metadata = fs::metadata(path)?;
    Ok(metadata.len())
}