forge-guardrails 0.1.0

Foundation types for an LLM-agent workflow framework
Documentation
use std::str::FromStr;

use crate::cli::Cli;
#[cfg(feature = "classifier")]
use forge_guardrails::default_tool_call_classifier_artifact_dir;
use forge_guardrails::{
    ClassifierModelKind, ScorerMode, ToolCallPolicyConfig, ToolCallPolicyMode,
    ToolOutputCompressionConfig, ToolOutputCompressionMethod, ToolOutputCompressionMode,
};

pub(crate) const DEFAULT_PROXY_PORT: u16 = 8081;
pub(crate) const DEFAULT_BACKEND_PORT: u16 = 8080;
pub(crate) const DEFAULT_ENV_CONTEXT_TOKENS: i64 = 128_000;
pub(crate) const DEFAULT_EXTERNAL_CONTEXT_TOKENS: i64 = 8192;
pub(crate) const DEFAULT_ENV_HOST: &str = "0.0.0.0";
pub(crate) const DEFAULT_CLI_HOST: &str = "127.0.0.1";
pub(crate) const DEFAULT_INTERNAL_MODEL: &str = "forge-guardrails-unset";
pub(crate) const DEFAULT_MAX_RETRIES: i32 = 3;

mod env_helpers;
mod validation;

#[cfg(test)]
mod tests;

use env_helpers::{
    env_bool, env_classifier_model, env_final_response_classifier_model, env_first_string, env_i32,
    env_i64, env_optional_string, env_optional_u64, env_scoring_mode, env_string,
    env_tool_call_policy, env_tool_output_compression,
};

#[cfg(test)]
pub(super) use env_helpers::tool_output_compression_from_env_values;

pub(crate) use validation::{
    cli_host, cli_max_retries, cli_model, cli_port, normalized_extra_flags, require_cli_gguf,
    require_cli_llamafile_runtime, require_cli_model, resolve_serialize, validate_nonempty,
    validate_nonnegative_i32, validate_nonzero_u16, validate_optional_positive_i64,
    validate_positive_i64,
};

#[derive(Clone)]
pub(crate) struct ProxyConfig {
    pub(crate) host: String,
    pub(crate) port: u16,
    pub(crate) default_model: String,
    pub(crate) default_model_explicit: bool,
    pub(crate) context_tokens: i64,
    pub(crate) max_retries: i32,
    pub(crate) rescue_enabled: bool,
    pub(crate) serialize_requests: bool,
    pub(crate) verbose: bool,
    pub(crate) classifier_dir: Option<String>,
    pub(crate) classifier_mode: ScorerMode,
    pub(crate) classifier_model: ClassifierModelKind,
    pub(crate) classifier_auto_download: bool,
    pub(crate) classifier_max_latency_ms: Option<u64>,
    pub(crate) final_response_classifier_dir: Option<String>,
    pub(crate) final_response_classifier_mode: ScorerMode,
    pub(crate) final_response_classifier_model: ClassifierModelKind,
    pub(crate) final_response_classifier_max_latency_ms: Option<u64>,
    pub(crate) tool_output_compression: ToolOutputCompressionConfig,
    pub(crate) tool_call_policy: ToolCallPolicyConfig,
}

impl ProxyConfig {
    pub(crate) fn from_env() -> Result<Self, String> {
        let env_model = env_first_string(&["FORGE_MODEL", "SMALL_MODEL"]);
        Ok(Self {
            host: env_string(&["FORGE_HOST"], DEFAULT_ENV_HOST),
            port: env_helpers::env_u16(
                &["FORGE_PORT", "PORT", "LISTEN_PORT"],
                DEFAULT_PROXY_PORT,
                "FORGE_PORT",
            )?,
            default_model: env_model
                .clone()
                .unwrap_or_else(|| DEFAULT_INTERNAL_MODEL.to_string()),
            default_model_explicit: env_model.is_some(),
            context_tokens: env_i64(
                &["FORGE_CONTEXT_TOKENS"],
                DEFAULT_ENV_CONTEXT_TOKENS,
                "FORGE_CONTEXT_TOKENS",
            )?,
            max_retries: env_i32(
                &["FORGE_MAX_RETRIES"],
                DEFAULT_MAX_RETRIES,
                "FORGE_MAX_RETRIES",
            )?,
            rescue_enabled: env_bool("FORGE_RESCUE_ENABLED", true)?,
            serialize_requests: env_bool("FORGE_SERIALIZE_REQUESTS", false)?,
            verbose: false,
            classifier_dir: env_optional_string("FORGE_CLASSIFIER_DIR"),
            classifier_mode: env_scoring_mode("FORGE_CLASSIFIER_MODE", ScorerMode::Shadow)?,
            classifier_model: env_classifier_model()?,
            classifier_auto_download: false,
            classifier_max_latency_ms: env_optional_u64("FORGE_CLASSIFIER_MAX_LATENCY_MS")?,
            final_response_classifier_dir: env_optional_string(
                "FORGE_FINAL_RESPONSE_CLASSIFIER_DIR",
            ),
            final_response_classifier_mode: env_scoring_mode(
                "FORGE_FINAL_RESPONSE_CLASSIFIER_MODE",
                ScorerMode::Shadow,
            )?,
            final_response_classifier_model: env_final_response_classifier_model()?,
            final_response_classifier_max_latency_ms: env_optional_u64(
                "FORGE_FINAL_RESPONSE_CLASSIFIER_MAX_LATENCY_MS",
            )?,
            tool_output_compression: env_tool_output_compression()?,
            tool_call_policy: env_tool_call_policy()?,
        })
    }
}

pub(crate) fn apply_env_cli_overrides(config: &mut ProxyConfig, cli: &Cli) -> Result<(), String> {
    if let Some(host) = cli.host.as_deref() {
        config.host = validate_nonempty(host, "--host")?.to_string();
    }
    if let Some(port) = cli.port {
        config.port = validate_nonzero_u16(port, "--port")?;
    }
    if let Some(model) = cli.model.as_deref() {
        config.default_model = validate_nonempty(model, "--model")?.to_string();
        config.default_model_explicit = true;
    }
    if let Some(tokens) = cli.budget_tokens {
        config.context_tokens = validate_positive_i64(tokens, "--budget-tokens")?;
    }
    if let Some(max_retries) = cli.max_retries {
        config.max_retries = validate_nonnegative_i32(max_retries, "--max-retries")?;
    }
    if cli.no_rescue {
        config.rescue_enabled = false;
    }
    if cli.serialize {
        config.serialize_requests = true;
    } else if cli.no_serialize {
        config.serialize_requests = false;
    }
    if cli.verbose {
        config.verbose = true;
    }
    let (classifier_dir, classifier_mode, classifier_model, classifier_auto_download) =
        classifier_settings_from_env_cli(cli)?;
    config.classifier_dir = classifier_dir;
    config.classifier_mode = classifier_mode;
    config.classifier_model = classifier_model;
    config.classifier_auto_download = classifier_auto_download;
    if let Some(value) = cli.classifier_max_latency_ms {
        config.classifier_max_latency_ms = Some(value);
    }
    if let Some(dir) = cli.final_response_classifier_dir.as_deref() {
        config.final_response_classifier_dir =
            Some(validate_nonempty(dir, "--final-response-classifier-dir")?.to_string());
    }
    if let Some(mode) = cli.final_response_classifier_mode.as_deref() {
        config.final_response_classifier_mode = ScorerMode::from_str(mode)?;
    }
    if let Some(model) = cli.final_response_classifier_model.as_deref() {
        config.final_response_classifier_model = ClassifierModelKind::from_str(model)?;
    }
    if let Some(value) = cli.final_response_classifier_max_latency_ms {
        config.final_response_classifier_max_latency_ms = Some(value);
    }
    config.tool_output_compression = tool_output_compression_from_env_cli(cli)?;
    config.tool_call_policy = tool_call_policy_from_env_cli(cli)?;
    Ok(())
}

pub(crate) fn tool_output_compression_from_env_cli(
    cli: &Cli,
) -> Result<ToolOutputCompressionConfig, String> {
    let mut config = env_tool_output_compression()?;
    apply_tool_output_compression_cli_overrides(&mut config, cli)?;
    Ok(config)
}

pub(super) fn apply_tool_output_compression_cli_overrides(
    config: &mut ToolOutputCompressionConfig,
    cli: &Cli,
) -> Result<(), String> {
    if let Some(mode) = cli.tool_output_compression.as_deref() {
        config.mode = ToolOutputCompressionMode::from_str(mode)?;
    }
    if let Some(method) = cli.tool_output_compression_method.as_deref() {
        config.method = ToolOutputCompressionMethod::from_str(method)?;
    }
    Ok(())
}

pub(crate) fn tool_call_policy_from_env_cli(cli: &Cli) -> Result<ToolCallPolicyConfig, String> {
    let mut config = env_tool_call_policy()?;
    if let Some(mode) = cli.tool_call_policy.as_deref() {
        config = ToolCallPolicyConfig::from_mode(ToolCallPolicyMode::from_str(mode)?);
    }
    Ok(config)
}

pub(crate) fn classifier_settings_from_env_cli(
    cli: &Cli,
) -> Result<(Option<String>, ScorerMode, ClassifierModelKind, bool), String> {
    if cli.classify {
        #[cfg(not(feature = "classifier"))]
        {
            return Err("--classify requires building with --features classifier".to_string());
        }

        #[cfg(feature = "classifier")]
        {
            let dir = match cli.classifier_dir.as_deref() {
                Some(raw) => validate_nonempty(raw, "--classifier-dir")?.to_string(),
                None => default_tool_call_classifier_artifact_dir()
                    .map_err(|err| err.to_string())?
                    .to_string_lossy()
                    .into_owned(),
            };
            let mode = match cli.classifier_mode.as_deref() {
                Some(raw) => ScorerMode::from_str(raw)?,
                None => ScorerMode::Advisory,
            };
            if mode == ScorerMode::Disabled {
                return Err(
                    "--classify cannot be combined with --classifier-mode disabled".to_string(),
                );
            }
            let model = match cli.classifier_model.as_deref() {
                Some(raw) => ClassifierModelKind::from_str(raw)?,
                None => ClassifierModelKind::Quantized,
            };
            return Ok((Some(dir), mode, model, true));
        }
    }

    let mut dir = env_optional_string("FORGE_CLASSIFIER_DIR");
    let mut mode = env_scoring_mode("FORGE_CLASSIFIER_MODE", ScorerMode::Shadow)?;
    let mut model = env_classifier_model()?;

    if let Some(raw) = cli.classifier_dir.as_deref() {
        dir = Some(validate_nonempty(raw, "--classifier-dir")?.to_string());
    }
    if let Some(raw) = cli.classifier_mode.as_deref() {
        mode = ScorerMode::from_str(raw)?;
    }
    if let Some(raw) = cli.classifier_model.as_deref() {
        model = ClassifierModelKind::from_str(raw)?;
    }

    Ok((dir, mode, model, false))
}

pub(crate) fn final_response_classifier_settings_from_env_cli(
    cli: &Cli,
) -> Result<(Option<String>, ScorerMode, ClassifierModelKind), String> {
    let mut dir = env_optional_string("FORGE_FINAL_RESPONSE_CLASSIFIER_DIR");
    let mut mode = env_scoring_mode("FORGE_FINAL_RESPONSE_CLASSIFIER_MODE", ScorerMode::Shadow)?;
    let mut model = env_final_response_classifier_model()?;

    if let Some(raw) = cli.final_response_classifier_dir.as_deref() {
        dir = Some(validate_nonempty(raw, "--final-response-classifier-dir")?.to_string());
    }
    if let Some(raw) = cli.final_response_classifier_mode.as_deref() {
        mode = ScorerMode::from_str(raw)?;
    }
    if let Some(raw) = cli.final_response_classifier_model.as_deref() {
        model = ClassifierModelKind::from_str(raw)?;
    }

    Ok((dir, mode, model))
}