forge-guardrails 0.1.0

Foundation types for an LLM-agent workflow framework
Documentation
use super::files::required_files_for_model;
use super::mod_impl::classifier_artifact_needs_download;
use super::paths::{cache_root_from_values, sanitize_repo};
use super::types::ClassifierArtifactKind;
use crate::guardrails::ClassifierModelKind;
use serde_json::json;
use std::path::{Path, PathBuf};

/// Checks the precedence used when resolving the classifier cache root:
/// an explicit value wins, then the XDG value, then the home value, and
/// resolution fails when none of the three is supplied.
#[test]
fn cache_root_prefers_explicit_then_xdg_then_home() {
    // An explicit root is returned as-is, even when the other two are present.
    let from_explicit =
        cache_root_from_values(Some("/tmp/forge-cache"), Some("/tmp/xdg"), Some("/home/me"))
            .expect("explicit");
    assert_eq!(from_explicit, PathBuf::from("/tmp/forge-cache"));

    // Without an explicit root, the XDG value gets the crate-specific suffix.
    let from_xdg = cache_root_from_values(None, Some("/tmp/xdg"), Some("/home/me")).expect("xdg");
    assert_eq!(
        from_xdg,
        PathBuf::from("/tmp/xdg/forge-guardrails/classifiers")
    );

    // With only a home value, the root lives under `.cache` inside it.
    let from_home = cache_root_from_values(None, None, Some("/home/me")).expect("home");
    assert_eq!(
        from_home,
        PathBuf::from("/home/me/.cache/forge-guardrails/classifiers")
    );

    // Nothing to derive a root from: must be reported as an error.
    let from_nothing = cache_root_from_values(None, None, None);
    assert!(from_nothing.is_err());
}

/// Pins the on-disk layout `<root>/<kind>/<sanitized repo>/<revision>/onnx`,
/// and in particular that `sanitize_repo("owner/name")` yields the
/// single path segment `owner__name`.
#[test]
fn cached_artifact_dir_uses_stable_layout() {
    let root = cache_root_from_values(Some("/tmp/forge-cache"), None, None).expect("root");
    // Both paths share everything up to the artifact-kind segment.
    let kind_dir = root.join(ClassifierArtifactKind::ToolCall.as_str());

    // Hand-written expectation with the repo segment spelled out literally.
    let expected = kind_dir.join("owner__name").join("revision").join("onnx");
    // Same path, but with the repo segment produced by the sanitizer.
    let actual = kind_dir
        .join(sanitize_repo("owner/name"))
        .join("revision")
        .join("onnx");

    assert_eq!(actual, expected);
}

/// For the quantized tool-call model, the required-file list must name the
/// tokenizer and the quantized ONNX model, and must not name the
/// non-quantized `model.onnx`.
#[test]
fn required_quantized_tool_call_files_include_runtime_assets() {
    let files = required_files_for_model(
        ClassifierArtifactKind::ToolCall,
        ClassifierModelKind::Quantized,
    );

    let has_tokenizer = files.contains(&"onnx/tokenizer.json");
    let has_quantized_model = files.contains(&"onnx/model_quantized.onnx");
    let has_full_model = files.contains(&"onnx/model.onnx");

    assert!(has_tokenizer);
    assert!(has_quantized_model);
    assert!(!has_full_model);
}

/// An empty artifact directory has none of the required files, so the
/// classifier artifact must be reported as needing a download.
#[test]
fn missing_required_file_needs_download() {
    let temp = make_temp_dir("missing-required");
    let needs_download = classifier_artifact_needs_download(
        ClassifierArtifactKind::ToolCall,
        &temp,
        ClassifierModelKind::Quantized,
    );
    // Remove the scratch directory before asserting so it is cleaned up even
    // when the assertion fails; the directory name embeds the process id, so
    // otherwise every test run would leave a new directory behind.
    let _ = std::fs::remove_dir_all(&temp);
    assert!(needs_download);
}

/// A directory populated by `write_minimal_tool_call_artifact` must be
/// accepted as a complete quantized tool-call artifact, i.e. no download
/// is needed.
#[test]
fn valid_minimal_tool_call_artifact_skips_download() {
    let temp = make_temp_dir("valid-minimal");
    write_minimal_tool_call_artifact(&temp);
    let needs_download = classifier_artifact_needs_download(
        ClassifierArtifactKind::ToolCall,
        &temp,
        ClassifierModelKind::Quantized,
    );
    // Remove the scratch directory before asserting so it is cleaned up even
    // when the assertion fails; the directory name embeds the process id, so
    // otherwise every test run would leave a new populated directory behind.
    let _ = std::fs::remove_dir_all(&temp);
    assert!(!needs_download);
}

/// Creates an empty scratch directory for a test and returns its path.
///
/// The directory lives under the system temp dir and is named
/// `forge-classifier-download-<name>-<pid>`. Anything already at that path
/// is removed first, so the returned directory always starts out empty.
///
/// # Panics
///
/// Panics if the directory cannot be created.
fn make_temp_dir(name: &str) -> PathBuf {
    let pid = std::process::id();
    let mut scratch = std::env::temp_dir();
    scratch.push(format!("forge-classifier-download-{name}-{pid}"));

    // Best-effort wipe: the path usually does not exist yet, and that error
    // is deliberately ignored.
    let _ = std::fs::remove_dir_all(&scratch);
    std::fs::create_dir_all(&scratch).expect("tempdir");

    scratch
}

/// Fills `dir` with a minimal tool-call classifier artifact for tests.
///
/// Written, in order: `artifact_manifest.json`, `labels.json`,
/// `thresholds.json`, a set of sidecar files each containing `{}`, and a
/// placeholder `model_quantized.onnx` containing the bytes `onnx`.
/// The directory is created if it does not exist.
///
/// # Panics
///
/// Panics if the directory or any of the files cannot be written.
fn write_minimal_tool_call_artifact(dir: &Path) {
    std::fs::create_dir_all(dir).expect("dir");

    let labels = [
        "valid",
        "wrong_tool_semantic",
        "wrong_arguments_semantic",
        "tool_not_needed",
        "needs_clarification",
        "deterministic_invalid",
    ];

    // One pass over the labels builds all three per-label lookup tables:
    // name -> index, index -> name, and name -> threshold policy.
    let mut label2id = serde_json::Map::new();
    let mut id2label = serde_json::Map::new();
    let mut thresholds = serde_json::Map::new();
    for (index, label) in labels.iter().enumerate() {
        label2id.insert(label.to_string(), json!(index));
        id2label.insert(index.to_string(), json!(label));
        thresholds.insert(
            label.to_string(),
            json!({
                "action": "shadow_only",
                "advisory_min_confidence": 1.01,
                "enforce_min_confidence": 1.01
            }),
        );
    }

    let manifest = json!({
        "artifact_schema_version": "toolcall-verifier-artifact/v1",
        "input_schema_version": "toolcall-verifier-input/v1",
        "serializer": "serialize_state_v1",
        "max_length": 1280,
        "onnx_file": "model.onnx",
        "quantized_onnx_file": "model_quantized.onnx",
        "labels": labels,
    });
    write_json(dir.join("artifact_manifest.json"), manifest);

    let label_table = json!({
        "label_mode": "six_label",
        "labels": labels,
        "label2id": label2id,
        "id2label": id2label,
    });
    write_json(dir.join("labels.json"), label_table);

    let threshold_table = json!({
        "schema_version": "toolcall-verifier-thresholds/v1",
        "mode": "shadow",
        "default_action": "shadow_only",
        "labels": thresholds,
    });
    write_json(dir.join("thresholds.json"), threshold_table);

    // Sidecar files only get a placeholder `{}` body here.
    let sidecars: &[&str] = &[
        "input_schema.json",
        "serializer_fixture.json",
        "tokenizer.json",
        "tokenizer_config.json",
        "special_tokens_map.json",
        "added_tokens.json",
        "spm.model",
        "config.json",
        "test_metrics.json",
        "training_metrics.json",
        "training_run_summary.json",
    ];
    for sidecar in sidecars {
        std::fs::write(dir.join(sidecar), "{}").expect("sidecar");
    }

    // Placeholder bytes, not a real ONNX model.
    std::fs::write(dir.join("model_quantized.onnx"), b"onnx").expect("model");
}

/// Serializes `value` as compact JSON and writes it to `path`.
///
/// # Panics
///
/// Panics if serialization fails or the file cannot be written.
fn write_json(path: PathBuf, value: serde_json::Value) {
    let encoded = serde_json::to_vec(&value).expect("json");
    std::fs::write(path, encoded).expect("write");
}