use std::sync::Arc;
#[cfg(feature = "classifier")]
use std::str::FromStr;
#[cfg(feature = "classifier")]
use forge_guardrails::{
default_tool_call_classifier_artifact_dir, ensure_classifier_artifact_dir,
ClassifierArtifactKind, ClassifierModelKind, FinalResponseContext, FinalResponseScore,
OnnxScorerOptions, ScoringContext, ToolCall, ToolCallScore, DEFAULT_CLASSIFIER_REPO,
DEFAULT_CLASSIFIER_REVISION,
};
use forge_guardrails::{FinalResponseScorer, ScorerMode, ToolCallScorer};
use crate::cli::Cli;
#[cfg(feature = "classifier")]
use crate::config::validate_nonempty;
use crate::config::ProxyConfig;
/// Fetches or verifies the tool-call classifier artifact ahead of proxy startup.
///
/// This is a no-op unless `config.classifier_auto_download` is set. When it is set,
/// `config.classifier_dir` must be present. The artifact for `config.classifier_model`
/// is then prepared from `DEFAULT_CLASSIFIER_REPO` at `DEFAULT_CLASSIFIER_REVISION`,
/// and progress lines are written to stderr.
///
/// # Errors
///
/// Returns a human-readable message in three cases:
/// - no artifact directory was resolved;
/// - `ensure_classifier_artifact_dir` fails;
/// - the binary was built without the `classifier` feature.
pub(super) fn prepare_classifier_artifact(config: &ProxyConfig) -> Result<(), String> {
    if !config.classifier_auto_download {
        return Ok(());
    }
    let dir = match config.classifier_dir.as_deref() {
        Some(dir) => dir,
        None => {
            return Err("--classify did not resolve a classifier artifact directory".to_string());
        }
    };

    #[cfg(feature = "classifier")]
    {
        let outcome = ensure_classifier_artifact_dir(
            ClassifierArtifactKind::ToolCall,
            dir,
            DEFAULT_CLASSIFIER_REPO,
            DEFAULT_CLASSIFIER_REVISION,
            config.classifier_model,
            |progress| eprintln!("{progress}"),
        );
        match outcome {
            Ok(_) => Ok(()),
            Err(err) => Err(format!("failed to prepare classifier artifact: {err}")),
        }
    }
    #[cfg(not(feature = "classifier"))]
    {
        // Only the feature-enabled branch consumes `dir`; silence the unused binding.
        let _ = dir;
        Err("--classify requires building with --features classifier".to_string())
    }
}
/// Handles the `--classify-download` shortcut: download the tool-call classifier artifact.
///
/// The target directory is `--classifier-dir` when given, which must be non-empty.
/// Otherwise it is the default tool-call artifact directory. The model variant comes
/// from `--classifier-model`, defaulting to `ClassifierModelKind::Quantized`.
/// Progress lines are printed to stdout.
///
/// # Errors
///
/// Returns a message in any of these cases:
/// - the directory or model argument is invalid;
/// - the default directory cannot be determined;
/// - the download fails;
/// - the binary was built without the `classifier` feature.
pub(super) fn download_classifier_shortcut(cli: &Cli) -> Result<(), String> {
    #[cfg(feature = "classifier")]
    {
        let artifact_dir = if let Some(raw) = cli.classifier_dir.as_deref() {
            validate_nonempty(raw, "--classifier-dir")?.to_string()
        } else {
            let default_dir =
                default_tool_call_classifier_artifact_dir().map_err(|err| err.to_string())?;
            default_dir.to_string_lossy().into_owned()
        };
        let model = if let Some(raw) = cli.classifier_model.as_deref() {
            ClassifierModelKind::from_str(raw)?
        } else {
            ClassifierModelKind::Quantized
        };
        let outcome = ensure_classifier_artifact_dir(
            ClassifierArtifactKind::ToolCall,
            artifact_dir,
            DEFAULT_CLASSIFIER_REPO,
            DEFAULT_CLASSIFIER_REVISION,
            model,
            |progress| println!("{progress}"),
        );
        outcome
            .map(|_| ())
            .map_err(|err| format!("failed to download classifier artifact: {err}"))
    }
    #[cfg(not(feature = "classifier"))]
    {
        // `cli` is only read when the classifier feature is compiled in.
        let _ = cli;
        Err("--classify-download requires building with --features classifier".to_string())
    }
}
/// Environment variable overriding `OnnxScorerOptions::session_pool_size` for the
/// tool-call classifier. An unset or blank value keeps the default.
#[cfg(feature = "classifier")]
const FORGE_CLASSIFIER_SESSION_POOL_SIZE: &str = "FORGE_CLASSIFIER_SESSION_POOL_SIZE";
/// Environment variable overriding `OnnxScorerOptions::session_pool_size` for the
/// final-response classifier. An unset or blank value keeps the default.
#[cfg(feature = "classifier")]
const FORGE_FINAL_RESPONSE_CLASSIFIER_SESSION_POOL_SIZE: &str =
    "FORGE_FINAL_RESPONSE_CLASSIFIER_SESSION_POOL_SIZE";
/// Environment variable overriding `OnnxScorerOptions::intra_threads` for the
/// tool-call classifier. An unset or blank value keeps the default.
#[cfg(feature = "classifier")]
const FORGE_CLASSIFIER_INTRA_THREADS: &str = "FORGE_CLASSIFIER_INTRA_THREADS";
/// Environment variable overriding `OnnxScorerOptions::intra_threads` for the
/// final-response classifier. An unset or blank value keeps the default.
#[cfg(feature = "classifier")]
const FORGE_FINAL_RESPONSE_CLASSIFIER_INTRA_THREADS: &str =
    "FORGE_FINAL_RESPONSE_CLASSIFIER_INTRA_THREADS";
/// Builds the ONNX tool-call classifier scorer described by `config`, if one is enabled.
///
/// Returns `Ok(None)` when `config.classifier_mode` is `ScorerMode::Disabled` or when
/// no `config.classifier_dir` is configured.
///
/// Otherwise the scorer is loaded from that directory. ONNX runtime options come from
/// the `FORGE_CLASSIFIER_SESSION_POOL_SIZE` and `FORGE_CLASSIFIER_INTRA_THREADS`
/// environment variables. When `config.classifier_max_latency_ms` is set, the scorer
/// is wrapped with a latency warning.
///
/// # Errors
///
/// Returns a message in any of these cases:
/// - an environment override is invalid;
/// - the artifact fails to load;
/// - the binary was built without the `classifier` feature.
pub(super) fn build_classifier_scorer(
    config: &ProxyConfig,
) -> Result<Option<Arc<dyn ToolCallScorer>>, String> {
    // Both "disabled" and "no directory" mean there is simply no scorer to build.
    let dir = match config.classifier_dir.as_deref() {
        Some(dir) if config.classifier_mode != ScorerMode::Disabled => dir,
        _ => return Ok(None),
    };

    #[cfg(feature = "classifier")]
    {
        let runtime_options = onnx_options_from_env(
            FORGE_CLASSIFIER_SESSION_POOL_SIZE,
            FORGE_CLASSIFIER_INTRA_THREADS,
            "classifier",
        )?;
        let loaded = forge_guardrails::OnnxToolCallScorer::from_dir_with_model_and_options(
            dir,
            Some(config.classifier_mode),
            config.classifier_model,
            runtime_options,
        )
        .map_err(|err| format!("failed to load classifier artifact: {err}"))?;
        let base: Arc<dyn ToolCallScorer> = Arc::new(loaded);
        let wrapped = wrap_tool_latency_warning(base, config.classifier_max_latency_ms);
        Ok(Some(wrapped))
    }
    #[cfg(not(feature = "classifier"))]
    {
        let _ = dir;
        Err("classifier support requires building with --features classifier".to_string())
    }
}
/// Builds the ONNX final-response classifier scorer described by `config`, if enabled.
///
/// Returns `Ok(None)` when `config.final_response_classifier_mode` is
/// `ScorerMode::Disabled` or when no `config.final_response_classifier_dir` is set.
///
/// Otherwise the scorer is loaded from that directory. ONNX runtime options come from
/// the `FORGE_FINAL_RESPONSE_CLASSIFIER_SESSION_POOL_SIZE` and
/// `FORGE_FINAL_RESPONSE_CLASSIFIER_INTRA_THREADS` environment variables. When
/// `config.final_response_classifier_max_latency_ms` is set, the scorer is wrapped
/// with a latency warning.
///
/// # Errors
///
/// Returns a message in any of these cases:
/// - an environment override is invalid;
/// - the artifact fails to load;
/// - the binary was built without the `classifier` feature.
pub(super) fn build_final_response_classifier_scorer(
    config: &ProxyConfig,
) -> Result<Option<Arc<dyn FinalResponseScorer>>, String> {
    // Both "disabled" and "no directory" mean there is simply no scorer to build.
    let dir = match config.final_response_classifier_dir.as_deref() {
        Some(dir) if config.final_response_classifier_mode != ScorerMode::Disabled => dir,
        _ => return Ok(None),
    };

    #[cfg(feature = "classifier")]
    {
        let runtime_options = onnx_options_from_env(
            FORGE_FINAL_RESPONSE_CLASSIFIER_SESSION_POOL_SIZE,
            FORGE_FINAL_RESPONSE_CLASSIFIER_INTRA_THREADS,
            "final-response classifier",
        )?;
        let loaded = forge_guardrails::OnnxFinalResponseScorer::from_dir_with_model_and_options(
            dir,
            Some(config.final_response_classifier_mode),
            config.final_response_classifier_model,
            runtime_options,
        )
        .map_err(|err| format!("failed to load final-response classifier artifact: {err}"))?;
        let base: Arc<dyn FinalResponseScorer> = Arc::new(loaded);
        let wrapped = wrap_final_response_latency_warning(
            base,
            config.final_response_classifier_max_latency_ms,
        );
        Ok(Some(wrapped))
    }
    #[cfg(not(feature = "classifier"))]
    {
        let _ = dir;
        Err(
            "final-response classifier support requires building with --features classifier"
                .to_string(),
        )
    }
}
/// Builds ONNX scorer options from the defaults plus optional environment overrides.
///
/// `pool_var` overrides `session_pool_size` and `intra_var` overrides `intra_threads`.
/// Unset or blank variables leave the default value in place. The combined options
/// are then checked with `OnnxScorerOptions::validate`. `label` only appears in the
/// validation error message.
///
/// # Errors
///
/// Returns the first override that fails to read or parse, or a
/// `"invalid {label} ONNX runtime options: ..."` message if validation fails.
#[cfg(feature = "classifier")]
fn onnx_options_from_env(
    pool_var: &str,
    intra_var: &str,
    label: &str,
) -> Result<OnnxScorerOptions, String> {
    let mut opts = OnnxScorerOptions::default();
    // Overrides are applied in order; the first unreadable variable aborts the build.
    for (var, slot) in [
        (pool_var, &mut opts.session_pool_size),
        (intra_var, &mut opts.intra_threads),
    ] {
        if let Some(value) = optional_usize_env(var)? {
            *slot = value;
        }
    }
    opts.validate()
        .map_err(|err| format!("invalid {label} ONNX runtime options: {err}"))
}
#[cfg(feature = "classifier")]
/// Reads an optional `usize` override from the environment variable `name`.
///
/// Surrounding whitespace is ignored. An unset variable and one that is empty or
/// whitespace-only both yield `Ok(None)`; any other value must parse as a `usize`.
/// Previously only the emptiness check trimmed, so a value like `" 4"` was rejected.
/// Note that `0` parses successfully here; any range check is left to the caller.
///
/// # Errors
///
/// Returns a message when the value is not a valid unsigned integer, or when the
/// variable is set but cannot be read (e.g. it is not valid Unicode).
fn optional_usize_env(name: &str) -> Result<Option<usize>, String> {
    let raw = match std::env::var(name) {
        Ok(raw) => raw,
        Err(std::env::VarError::NotPresent) => return Ok(None),
        Err(err) => return Err(format!("failed to read {name}: {err}")),
    };
    // Trim once and use the same slice for both the blank check and the parse.
    let value = raw.trim();
    if value.is_empty() {
        return Ok(None);
    }
    value
        .parse::<usize>()
        .map(Some)
        .map_err(|err| format!("{name} must be a positive integer: {err}"))
}
/// A [`ToolCallScorer`] decorator that passes scores through from `inner` unchanged.
///
/// It emits a `forge.classifier` warning whenever a score reports a `latency_ms`
/// above `max_latency_ms`.
#[cfg(feature = "classifier")]
struct ToolCallLatencyWarningScorer {
    /// Scorer whose results (and errors) are returned as-is.
    inner: Arc<dyn ToolCallScorer>,
    /// Warning threshold in milliseconds, compared against the score's `latency_ms`.
    max_latency_ms: u64,
}
#[cfg(feature = "classifier")]
impl ToolCallScorer for ToolCallLatencyWarningScorer {
    fn score(&self, ctx: &ScoringContext, candidate: &ToolCall) -> anyhow::Result<ToolCallScore> {
        let threshold_ms = self.max_latency_ms as f64;
        // Errors from the inner scorer propagate untouched; only successes are inspected.
        self.inner.score(ctx, candidate).map(|result| {
            if result.latency_ms > threshold_ms {
                tracing::warn!(
                    target: "forge.classifier",
                    tool = %candidate.tool,
                    latency_ms = result.latency_ms,
                    max_latency_ms = self.max_latency_ms,
                    "tool-call classifier latency exceeded configured warning limit"
                );
            }
            result
        })
    }
}
/// Wraps `scorer` in a [`ToolCallLatencyWarningScorer`] when a threshold is configured.
///
/// With `max_latency_ms == None`, the original `scorer` is returned unwrapped.
#[cfg(feature = "classifier")]
fn wrap_tool_latency_warning(
    scorer: Arc<dyn ToolCallScorer>,
    max_latency_ms: Option<u64>,
) -> Arc<dyn ToolCallScorer> {
    let Some(limit) = max_latency_ms else {
        return scorer;
    };
    Arc::new(ToolCallLatencyWarningScorer {
        inner: scorer,
        max_latency_ms: limit,
    })
}
/// A [`FinalResponseScorer`] decorator that passes scores through from `inner` unchanged.
///
/// It emits a `forge.classifier` warning whenever a score reports a `latency_ms`
/// above `max_latency_ms`.
#[cfg(feature = "classifier")]
struct FinalResponseLatencyWarningScorer {
    /// Scorer whose results (and errors) are returned as-is.
    inner: Arc<dyn FinalResponseScorer>,
    /// Warning threshold in milliseconds, compared against the score's `latency_ms`.
    max_latency_ms: u64,
}
#[cfg(feature = "classifier")]
impl FinalResponseScorer for FinalResponseLatencyWarningScorer {
    fn score(&self, ctx: &FinalResponseContext) -> anyhow::Result<FinalResponseScore> {
        let threshold_ms = self.max_latency_ms as f64;
        // Errors from the inner scorer propagate untouched; only successes are inspected.
        self.inner.score(ctx).map(|result| {
            if result.latency_ms > threshold_ms {
                tracing::warn!(
                    target: "forge.classifier",
                    latency_ms = result.latency_ms,
                    max_latency_ms = self.max_latency_ms,
                    "final-response classifier latency exceeded configured warning limit"
                );
            }
            result
        })
    }
}
/// Wraps `scorer` in a [`FinalResponseLatencyWarningScorer`] when a threshold is configured.
///
/// With `max_latency_ms == None`, the original `scorer` is returned unwrapped.
#[cfg(feature = "classifier")]
fn wrap_final_response_latency_warning(
    scorer: Arc<dyn FinalResponseScorer>,
    max_latency_ms: Option<u64>,
) -> Arc<dyn FinalResponseScorer> {
    let Some(limit) = max_latency_ms else {
        return scorer;
    };
    Arc::new(FinalResponseLatencyWarningScorer {
        inner: scorer,
        max_latency_ms: limit,
    })
}