kbolt-core 0.1.7

Core engine for kbolt local-first retrieval
Documentation
use std::path::{Path, PathBuf};

use kbolt_types::KboltError;
use walkdir::WalkDir;

use crate::Result;

pub(super) fn resolve_tokenizer_file(
    artifact_dir: &Path,
    configured: Option<&str>,
    config_field: &str,
) -> Result<PathBuf> {
    if let Some(configured) = configured {
        return resolve_configured_file(artifact_dir, configured, config_field);
    }

    let matches = discover_files(artifact_dir, |path| {
        path.file_name()
            .and_then(|name| name.to_str())
            .map(|name| name.eq_ignore_ascii_case("tokenizer.json"))
            .unwrap_or(false)
    })?;
    match matches.len() {
        0 => Err(KboltError::Inference(format!(
            "missing tokenizer.json under {}",
            artifact_dir.display()
        ))
        .into()),
        1 => Ok(matches[0].clone()),
        _ => Err(KboltError::Inference(format!(
            "multiple tokenizer.json files found under {}. set {config_field} to choose one",
            artifact_dir.display()
        ))
        .into()),
    }
}

pub(super) fn resolve_file_with_extension(
    artifact_dir: &Path,
    configured: Option<&str>,
    extension: &str,
    config_field: &str,
) -> Result<PathBuf> {
    if let Some(configured) = configured {
        return resolve_configured_file(artifact_dir, configured, config_field);
    }

    let matches = discover_files(artifact_dir, |path| {
        path.extension()
            .and_then(|ext| ext.to_str())
            .map(|ext| ext.eq_ignore_ascii_case(extension))
            .unwrap_or(false)
    })?;
    match matches.len() {
        0 => Err(KboltError::Inference(format!(
            "missing .{extension} artifact under {}",
            artifact_dir.display()
        ))
        .into()),
        1 => Ok(matches[0].clone()),
        _ => Err(KboltError::Inference(format!(
            "multiple .{extension} artifacts found under {}. set {config_field} to choose one",
            artifact_dir.display()
        ))
        .into()),
    }
}

fn resolve_configured_file(
    artifact_dir: &Path,
    configured: &str,
    config_field: &str,
) -> Result<PathBuf> {
    let configured_path = Path::new(configured);
    let path = if configured_path.is_absolute() {
        configured_path.to_path_buf()
    } else {
        artifact_dir.join(configured_path)
    };
    if !path.is_file() {
        return Err(KboltError::Inference(format!(
            "{config_field} does not point to an existing file: {}",
            path.display()
        ))
        .into());
    }
    Ok(path)
}

fn discover_files<F>(root: &Path, mut predicate: F) -> Result<Vec<PathBuf>>
where
    F: FnMut(&Path) -> bool,
{
    let mut results = Vec::new();
    for entry in WalkDir::new(root).follow_links(true) {
        let entry = entry.map_err(|err| {
            KboltError::Inference(format!(
                "failed to walk model artifact directory {}: {err}",
                root.display()
            ))
        })?;
        if !entry.file_type().is_file() {
            continue;
        }
        let path = entry.path();
        if predicate(path) {
            results.push(path.to_path_buf());
        }
    }
    results.sort();
    Ok(results)
}

#[cfg(test)]
mod tests {
    use std::fs;

    use tempfile::tempdir;

    use super::{resolve_file_with_extension, resolve_tokenizer_file};

    fn write_file(path: &Path) {
        if let Some(parent) = path.parent() {
            fs::create_dir_all(parent).expect("create parent directories");
        }
        fs::write(path, b"fixture").expect("write fixture file");
    }

    use std::path::Path;

    #[test]
    fn resolve_file_with_extension_returns_single_match() {
        let root = tempdir().expect("create tempdir");
        let model = root.path().join("model.onnx");
        write_file(&model);

        let resolved =
            resolve_file_with_extension(root.path(), None, "onnx", "embeddings.onnx_file")
                .expect("resolve .onnx file");
        assert_eq!(resolved, model);
    }

    #[test]
    fn resolve_file_with_extension_errors_on_ambiguous_match() {
        let root = tempdir().expect("create tempdir");
        write_file(&root.path().join("a/model.onnx"));
        write_file(&root.path().join("b/model.onnx"));

        let err = resolve_file_with_extension(root.path(), None, "onnx", "embeddings.onnx_file")
            .expect_err("ambiguous .onnx files should error");
        assert!(err.to_string().contains("multiple .onnx artifacts found"));
    }

    #[test]
    fn resolve_file_with_extension_uses_configured_override() {
        let root = tempdir().expect("create tempdir");
        write_file(&root.path().join("a/model.onnx"));
        let preferred = root.path().join("b/preferred.onnx");
        write_file(&preferred);

        let resolved = resolve_file_with_extension(
            root.path(),
            Some("b/preferred.onnx"),
            "onnx",
            "embeddings.onnx_file",
        )
        .expect("resolve configured onnx file");
        assert_eq!(resolved, preferred);
    }

    #[test]
    fn resolve_tokenizer_file_errors_when_missing() {
        let root = tempdir().expect("create tempdir");
        let err = resolve_tokenizer_file(root.path(), None, "embeddings.tokenizer_file")
            .expect_err("missing tokenizer should error");
        assert!(err.to_string().contains("missing tokenizer.json"));
    }
}