forge-guardrails 0.1.2

Foundation types for an LLM-agent workflow framework
Documentation
use std::sync::Arc;

#[cfg(feature = "classifier")]
use std::str::FromStr;

#[cfg(feature = "classifier")]
use forge_guardrails::{
    default_tool_call_classifier_artifact_dir, ensure_classifier_artifact_dir,
    ClassifierArtifactKind, ClassifierModelKind, FinalResponseContext, FinalResponseScore,
    OnnxScorerOptions, ScoringContext, ToolCall, ToolCallScore, DEFAULT_CLASSIFIER_REPO,
    DEFAULT_CLASSIFIER_REVISION,
};
use forge_guardrails::{FinalResponseScorer, ScorerMode, ToolCallScorer};

use crate::cli::Cli;
#[cfg(feature = "classifier")]
use crate::config::validate_nonempty;
use crate::config::ProxyConfig;

/// Ensures the tool-call classifier artifact directory is populated before startup.
///
/// This is a no-op unless `config.classifier_auto_download` is set. When it is set,
/// `config.classifier_dir` must be resolved. The artifact is then prepared through
/// `ensure_classifier_artifact_dir` from the default repo/revision, and progress lines
/// are written to stderr.
///
/// # Errors
/// Returns a message when no directory was resolved or when preparing the artifact fails.
/// Without the `classifier` feature, auto-download is always an error.
pub(super) fn prepare_classifier_artifact(config: &ProxyConfig) -> Result<(), String> {
    if !config.classifier_auto_download {
        // Auto-download not requested: leave whatever is on disk alone.
        return Ok(());
    }

    let dir = config
        .classifier_dir
        .as_deref()
        .ok_or_else(|| "--classify did not resolve a classifier artifact directory".to_string())?;

    #[cfg(feature = "classifier")]
    {
        ensure_classifier_artifact_dir(
            ClassifierArtifactKind::ToolCall,
            dir,
            DEFAULT_CLASSIFIER_REPO,
            DEFAULT_CLASSIFIER_REVISION,
            config.classifier_model,
            |line| eprintln!("{line}"),
        )
        .map(|_| ())
        .map_err(|err| format!("failed to prepare classifier artifact: {err}"))
    }

    #[cfg(not(feature = "classifier"))]
    {
        // `dir` is only consumed by the feature-gated branch; silence the unused warning.
        let _ = dir;
        Err("--classify requires building with --features classifier".to_string())
    }
}

/// Implements the `--classify-download` shortcut for the tool-call classifier artifact.
///
/// Resolves the target directory and model, then hands off to
/// `ensure_classifier_artifact_dir`, which presumably downloads the artifact when it is
/// missing. Progress lines are printed to stdout.
/// - The directory is `--classifier-dir` when given, and must be non-empty. Otherwise it
///   is `default_tool_call_classifier_artifact_dir()`.
/// - The model is `--classifier-model` parsed via `FromStr`, and defaults to
///   `ClassifierModelKind::Quantized`.
///
/// # Errors
/// Returns a message on an invalid directory or model, or on a failed download. Without
/// the `classifier` feature, the shortcut always errors.
pub(super) fn download_classifier_shortcut(cli: &Cli) -> Result<(), String> {
    #[cfg(feature = "classifier")]
    {
        let artifact_dir: String = if let Some(raw) = cli.classifier_dir.as_deref() {
            validate_nonempty(raw, "--classifier-dir")?.to_string()
        } else {
            default_tool_call_classifier_artifact_dir()
                .map_err(|err| err.to_string())?
                .to_string_lossy()
                .into_owned()
        };

        let model = cli
            .classifier_model
            .as_deref()
            .map(ClassifierModelKind::from_str)
            .transpose()?
            .unwrap_or(ClassifierModelKind::Quantized);

        ensure_classifier_artifact_dir(
            ClassifierArtifactKind::ToolCall,
            artifact_dir,
            DEFAULT_CLASSIFIER_REPO,
            DEFAULT_CLASSIFIER_REVISION,
            model,
            |line| println!("{line}"),
        )
        .map(|_| ())
        .map_err(|err| format!("failed to download classifier artifact: {err}"))
    }

    #[cfg(not(feature = "classifier"))]
    {
        // `cli` is only consumed by the feature-gated branch.
        let _ = cli;
        Err("--classify-download requires building with --features classifier".to_string())
    }
}

// Environment variables that override ONNX runtime options per classifier.
// Each is read by `onnx_options_from_env`. Unset or blank values keep the default.

/// Overrides `session_pool_size` for the tool-call classifier.
#[cfg(feature = "classifier")]
const FORGE_CLASSIFIER_SESSION_POOL_SIZE: &str = "FORGE_CLASSIFIER_SESSION_POOL_SIZE";
/// Overrides `intra_threads` for the tool-call classifier.
#[cfg(feature = "classifier")]
const FORGE_CLASSIFIER_INTRA_THREADS: &str = "FORGE_CLASSIFIER_INTRA_THREADS";

/// Overrides `session_pool_size` for the final-response classifier.
#[cfg(feature = "classifier")]
const FORGE_FINAL_RESPONSE_CLASSIFIER_SESSION_POOL_SIZE: &str =
    "FORGE_FINAL_RESPONSE_CLASSIFIER_SESSION_POOL_SIZE";
/// Overrides `intra_threads` for the final-response classifier.
#[cfg(feature = "classifier")]
const FORGE_FINAL_RESPONSE_CLASSIFIER_INTRA_THREADS: &str =
    "FORGE_FINAL_RESPONSE_CLASSIFIER_INTRA_THREADS";

/// Builds the tool-call classifier scorer described by `config`, if any.
///
/// Returns `Ok(None)` when the classifier mode is `ScorerMode::Disabled` or no artifact
/// directory is configured. Otherwise it:
/// - loads an `OnnxToolCallScorer` from the directory, using ONNX options taken from the
///   `FORGE_CLASSIFIER_*` environment overrides;
/// - wraps the scorer with a latency-warning decorator when
///   `classifier_max_latency_ms` is set.
///
/// # Errors
/// Returns a message on invalid environment options or artifact load failure. Without
/// the `classifier` feature, an enabled and configured classifier is an error.
pub(super) fn build_classifier_scorer(
    config: &ProxyConfig,
) -> Result<Option<Arc<dyn ToolCallScorer>>, String> {
    let dir = match config.classifier_dir.as_deref() {
        Some(dir) if config.classifier_mode != ScorerMode::Disabled => dir,
        _ => return Ok(None),
    };

    #[cfg(feature = "classifier")]
    {
        let options = onnx_options_from_env(
            FORGE_CLASSIFIER_SESSION_POOL_SIZE,
            FORGE_CLASSIFIER_INTRA_THREADS,
            "classifier",
        )?;
        let onnx = forge_guardrails::OnnxToolCallScorer::from_dir_with_model_and_options(
            dir,
            Some(config.classifier_mode),
            config.classifier_model,
            options,
        )
        .map_err(|err| format!("failed to load classifier artifact: {err}"))?;
        let wrapped = wrap_tool_latency_warning(Arc::new(onnx), config.classifier_max_latency_ms);
        Ok(Some(wrapped))
    }

    #[cfg(not(feature = "classifier"))]
    {
        // `dir` is only consumed by the feature-gated branch.
        let _ = dir;
        Err("classifier support requires building with --features classifier".to_string())
    }
}

/// Builds the final-response classifier scorer described by `config`, if any.
///
/// Mirrors `build_classifier_scorer`, using the `final_response_classifier_*` settings
/// and the `FORGE_FINAL_RESPONSE_CLASSIFIER_*` environment overrides. It returns
/// `Ok(None)` when the mode is `ScorerMode::Disabled` or no directory is configured.
///
/// # Errors
/// Returns a message on invalid environment options or artifact load failure. Without
/// the `classifier` feature, an enabled and configured classifier is an error.
pub(super) fn build_final_response_classifier_scorer(
    config: &ProxyConfig,
) -> Result<Option<Arc<dyn FinalResponseScorer>>, String> {
    let dir = match config.final_response_classifier_dir.as_deref() {
        Some(dir) if config.final_response_classifier_mode != ScorerMode::Disabled => dir,
        _ => return Ok(None),
    };

    #[cfg(feature = "classifier")]
    {
        let options = onnx_options_from_env(
            FORGE_FINAL_RESPONSE_CLASSIFIER_SESSION_POOL_SIZE,
            FORGE_FINAL_RESPONSE_CLASSIFIER_INTRA_THREADS,
            "final-response classifier",
        )?;
        let onnx = forge_guardrails::OnnxFinalResponseScorer::from_dir_with_model_and_options(
            dir,
            Some(config.final_response_classifier_mode),
            config.final_response_classifier_model,
            options,
        )
        .map_err(|err| format!("failed to load final-response classifier artifact: {err}"))?;
        let wrapped = wrap_final_response_latency_warning(
            Arc::new(onnx),
            config.final_response_classifier_max_latency_ms,
        );
        Ok(Some(wrapped))
    }

    #[cfg(not(feature = "classifier"))]
    {
        // `dir` is only consumed by the feature-gated branch.
        let _ = dir;
        Err(
            "final-response classifier support requires building with --features classifier"
                .to_string(),
        )
    }
}

/// Starts from `OnnxScorerOptions::default()` and applies optional environment overrides.
///
/// - `pool_var` names the variable overriding `session_pool_size`.
/// - `intra_var` names the variable overriding `intra_threads`.
/// - `label` is used only in the error text when validation fails.
///
/// Unset or blank variables leave the default in place.
///
/// # Errors
/// Propagates parse/read errors from `optional_usize_env`, and wraps any error from
/// `OnnxScorerOptions::validate`.
#[cfg(feature = "classifier")]
fn onnx_options_from_env(
    pool_var: &str,
    intra_var: &str,
    label: &str,
) -> Result<OnnxScorerOptions, String> {
    let mut options = OnnxScorerOptions::default();
    let pool_override = optional_usize_env(pool_var)?;
    let intra_override = optional_usize_env(intra_var)?;

    if let Some(pool_size) = pool_override {
        options.session_pool_size = pool_size;
    }
    if let Some(thread_count) = intra_override {
        options.intra_threads = thread_count;
    }

    options
        .validate()
        .map_err(|err| format!("invalid {label} ONNX runtime options: {err}"))
}

/// Reads an optional `usize` override from the environment variable `name`.
///
/// Returns:
/// - `Ok(None)` when the variable is unset or contains only whitespace;
/// - `Ok(Some(n))` when its value, with surrounding whitespace ignored, parses as a `usize`.
///
/// # Errors
/// Returns a message when the value is not valid Unicode or is not an unsigned integer.
#[cfg(feature = "classifier")]
fn optional_usize_env(name: &str) -> Result<Option<usize>, String> {
    match std::env::var(name) {
        Ok(raw) => {
            // Trim once and parse the trimmed text. Whitespace-only values already count
            // as "unset", so padded numbers like " 4 " must be accepted too;
            // `usize::from_str` itself rejects surrounding whitespace.
            let value = raw.trim();
            if value.is_empty() {
                return Ok(None);
            }
            value
                .parse::<usize>()
                .map(Some)
                .map_err(|err| format!("{name} must be a positive integer: {err}"))
        }
        Err(std::env::VarError::NotPresent) => Ok(None),
        Err(err) => Err(format!("failed to read {name}: {err}")),
    }
}

/// A `ToolCallScorer` decorator that forwards every call to `inner` and logs a warning
/// whenever the reported latency exceeds `max_latency_ms`.
#[cfg(feature = "classifier")]
struct ToolCallLatencyWarningScorer {
    /// Warning threshold in milliseconds; it is compared against `ToolCallScore::latency_ms`.
    max_latency_ms: u64,
    /// The scorer whose results are passed through unchanged.
    inner: Arc<dyn ToolCallScorer>,
}

#[cfg(feature = "classifier")]
impl ToolCallScorer for ToolCallLatencyWarningScorer {
    /// Delegates to the wrapped scorer. When the returned latency is strictly greater
    /// than the limit, it emits a warning on the `forge.classifier` tracing target.
    /// Errors from the inner scorer are propagated, and the score is returned unmodified.
    fn score(&self, ctx: &ScoringContext, candidate: &ToolCall) -> anyhow::Result<ToolCallScore> {
        let verdict = self.inner.score(ctx, candidate)?;
        let limit_ms = self.max_latency_ms as f64;
        // Keep `>` rather than negating `<=`, so a NaN latency never triggers a warning.
        let too_slow = verdict.latency_ms > limit_ms;
        if too_slow {
            tracing::warn!(
                target: "forge.classifier",
                tool = %candidate.tool,
                latency_ms = verdict.latency_ms,
                max_latency_ms = self.max_latency_ms,
                "tool-call classifier latency exceeded configured warning limit"
            );
        }
        Ok(verdict)
    }
}

/// Wraps `scorer` in a `ToolCallLatencyWarningScorer` when a limit is configured.
///
/// With `max_latency_ms == None`, the original `Arc` is returned as-is.
#[cfg(feature = "classifier")]
fn wrap_tool_latency_warning(
    scorer: Arc<dyn ToolCallScorer>,
    max_latency_ms: Option<u64>,
) -> Arc<dyn ToolCallScorer> {
    let Some(limit) = max_latency_ms else {
        return scorer;
    };
    Arc::new(ToolCallLatencyWarningScorer {
        max_latency_ms: limit,
        inner: scorer,
    })
}

/// A `FinalResponseScorer` decorator that forwards every call to `inner` and logs a
/// warning whenever the reported latency exceeds `max_latency_ms`.
#[cfg(feature = "classifier")]
struct FinalResponseLatencyWarningScorer {
    /// Warning threshold in milliseconds; it is compared against `FinalResponseScore::latency_ms`.
    max_latency_ms: u64,
    /// The scorer whose results are passed through unchanged.
    inner: Arc<dyn FinalResponseScorer>,
}

#[cfg(feature = "classifier")]
impl FinalResponseScorer for FinalResponseLatencyWarningScorer {
    /// Delegates to the wrapped scorer. When the returned latency is strictly greater
    /// than the limit, it emits a warning on the `forge.classifier` tracing target.
    /// Errors from the inner scorer are propagated, and the score is returned unmodified.
    fn score(&self, ctx: &FinalResponseContext) -> anyhow::Result<FinalResponseScore> {
        let verdict = self.inner.score(ctx)?;
        let limit_ms = self.max_latency_ms as f64;
        // Keep `>` rather than negating `<=`, so a NaN latency never triggers a warning.
        let too_slow = verdict.latency_ms > limit_ms;
        if too_slow {
            tracing::warn!(
                target: "forge.classifier",
                latency_ms = verdict.latency_ms,
                max_latency_ms = self.max_latency_ms,
                "final-response classifier latency exceeded configured warning limit"
            );
        }
        Ok(verdict)
    }
}

/// Wraps `scorer` in a `FinalResponseLatencyWarningScorer` when a limit is configured.
///
/// With `max_latency_ms == None`, the original `Arc` is returned as-is.
#[cfg(feature = "classifier")]
fn wrap_final_response_latency_warning(
    scorer: Arc<dyn FinalResponseScorer>,
    max_latency_ms: Option<u64>,
) -> Arc<dyn FinalResponseScorer> {
    let Some(limit) = max_latency_ms else {
        return scorer;
    };
    Arc::new(FinalResponseLatencyWarningScorer {
        max_latency_ms: limit,
        inner: scorer,
    })
}