use std::borrow::Cow;
use std::fmt;
use std::str::FromStr;
use std::sync::{Arc, LazyLock};
use serde::Serialize;
use crate::clients::base::ToolCall;
use crate::guardrails::classifier_artifact::{
EXPECTED_LABELS, FINAL_RESPONSE_EXPECTED_LABELS, LEGACY_EXPECTED_LABELS,
};
use crate::guardrails::scoring_context::{
ScoringContext, ScoringMetadata, WorkflowStateForScoring,
};
/// Lazily-initialised executor built with `ScoringExecutor::default`.
///
/// `ScoringPipeline::new` clones it and the free functions
/// `score_tool_call_async` / `score_final_response_async` in this file run on it.
static DEFAULT_SCORING_EXECUTOR: LazyLock<ScoringExecutor> =
LazyLock::new(ScoringExecutor::default);
/// Operating mode of the classifier scorer.
///
/// The default mode is [`ScorerMode::Shadow`]. Every mode has a lowercase
/// textual form (see [`ScorerMode::as_str`]) that round-trips through
/// `Display` and `FromStr`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum ScorerMode {
    Disabled,
    #[default]
    Shadow,
    Advisory,
    Enforce,
}

impl ScorerMode {
    /// Returns the canonical lowercase name of this mode.
    pub fn as_str(self) -> &'static str {
        let name = match self {
            ScorerMode::Disabled => "disabled",
            ScorerMode::Shadow => "shadow",
            ScorerMode::Advisory => "advisory",
            ScorerMode::Enforce => "enforce",
        };
        name
    }
}

impl fmt::Display for ScorerMode {
    /// Writes the same text as [`ScorerMode::as_str`].
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let text = self.as_str();
        f.write_str(text)
    }
}

impl FromStr for ScorerMode {
    type Err = String;

    /// Parses a mode name, ignoring surrounding whitespace and ASCII case.
    ///
    /// # Errors
    ///
    /// Returns a message naming the accepted values and the normalised
    /// (trimmed, lowercased) input when it matches no mode.
    fn from_str(value: &str) -> Result<Self, Self::Err> {
        let normalized = value.trim().to_ascii_lowercase();
        [Self::Disabled, Self::Shadow, Self::Advisory, Self::Enforce]
            .iter()
            .copied()
            .find(|mode| mode.as_str() == normalized)
            .ok_or_else(|| {
                format!(
                    "classifier mode must be disabled, shadow, advisory, or enforce, got '{normalized}'"
                )
            })
    }
}
/// Classification assigned to a candidate tool call.
///
/// `Unknown` carries a label string that does not map to one of the named
/// variants.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ToolCallClass {
    Valid,
    WrongToolSemantic,
    WrongArgumentsSemantic,
    ToolNotNeeded,
    NeedsClarification,
    DeterministicInvalid,
    Unknown(String),
}

impl ToolCallClass {
    /// Returns the snake_case label for this class.
    ///
    /// For `Unknown` the stored string is returned as-is. The result always
    /// borrows (from a literal or from `self`); nothing is allocated.
    pub fn as_label(&self) -> Cow<'_, str> {
        let text: &str = match self {
            ToolCallClass::Valid => "valid",
            ToolCallClass::WrongToolSemantic => "wrong_tool_semantic",
            ToolCallClass::WrongArgumentsSemantic => "wrong_arguments_semantic",
            ToolCallClass::ToolNotNeeded => "tool_not_needed",
            ToolCallClass::NeedsClarification => "needs_clarification",
            ToolCallClass::DeterministicInvalid => "deterministic_invalid",
            ToolCallClass::Unknown(raw) => raw.as_str(),
        };
        Cow::Borrowed(text)
    }
}
/// One ranked entry produced by `top_k_from_logits`: a label together with
/// its raw logit and the softmax confidence derived from it.
#[derive(Debug, Clone, PartialEq, Serialize)]
pub struct ClassifierTopKEntry {
/// Label name taken from the label table passed to `top_k_from_logits`.
pub label: String,
/// Value from `softmax_for_telemetry` at this label's position.
pub confidence: f32,
/// Raw logit at this label's position.
pub logit: f32,
}
/// Ranks tool-call classifier logits against the matching label table.
///
/// The table is chosen by the number of logits: `EXPECTED_LABELS` is tried
/// first, then `LEGACY_EXPECTED_LABELS`. If neither length matches, an empty
/// vector is returned.
pub fn tool_call_top_k_from_logits(logits: &[f32]) -> Vec<ClassifierTopKEntry> {
    match logits.len() {
        count if count == EXPECTED_LABELS.len() => top_k_from_logits(&EXPECTED_LABELS, logits),
        count if count == LEGACY_EXPECTED_LABELS.len() => {
            top_k_from_logits(&LEGACY_EXPECTED_LABELS, logits)
        }
        _ => Vec::new(),
    }
}
/// Ranks final-response classifier logits against `FINAL_RESPONSE_EXPECTED_LABELS`.
///
/// Returns an empty vector when the number of logits differs from the number
/// of labels.
pub fn final_response_top_k_from_logits(logits: &[f32]) -> Vec<ClassifierTopKEntry> {
    if logits.len() != FINAL_RESPONSE_EXPECTED_LABELS.len() {
        return Vec::new();
    }
    top_k_from_logits(&FINAL_RESPONSE_EXPECTED_LABELS, logits)
}
/// Pairs each label with its logit and softmax confidence, then returns the
/// highest-confidence entries.
///
/// Entries are ordered by descending confidence; ties are broken by ascending
/// label. At most eight entries are kept. Pairing stops at the shortest of the
/// label, logit, and confidence sequences.
fn top_k_from_logits(labels: &[&str], logits: &[f32]) -> Vec<ClassifierTopKEntry> {
    const MAX_ENTRIES: usize = 8;

    let confidences = softmax_for_telemetry(logits);
    let mut ranked: Vec<ClassifierTopKEntry> = Vec::new();
    for ((name, raw), probability) in labels.iter().zip(logits).zip(&confidences) {
        ranked.push(ClassifierTopKEntry {
            label: String::from(*name),
            confidence: *probability,
            logit: *raw,
        });
    }

    // `total_cmp` gives a total order over f32, so the comparator is well-defined
    // for every input.
    ranked.sort_by(|a, b| match b.confidence.total_cmp(&a.confidence) {
        std::cmp::Ordering::Equal => a.label.cmp(&b.label),
        decided => decided,
    });
    ranked.truncate(MAX_ENTRIES);
    ranked
}
/// Computes a max-shifted softmax over `logits`.
///
/// Returns an empty vector for empty input. If the sum of exponentials is
/// zero or not finite (for example when an input is NaN or infinite), a
/// vector of zeros of the same length is returned.
fn softmax_for_telemetry(logits: &[f32]) -> Vec<f32> {
    if logits.is_empty() {
        return Vec::new();
    }

    // Subtracting the largest logit keeps every exponent <= 0 and avoids overflow.
    let peak = logits
        .iter()
        .fold(f32::NEG_INFINITY, |best, &logit| best.max(logit));
    let mut weights: Vec<f32> = logits.iter().map(|&logit| (logit - peak).exp()).collect();
    let total = weights.iter().sum::<f32>();

    if total != 0.0 && total.is_finite() {
        for weight in &mut weights {
            *weight /= total;
        }
        weights
    } else {
        vec![0.0; logits.len()]
    }
}
/// Action attached to a classifier score.
///
/// Within this file, `AdvisoryNudge` and `Block` are the two actions that make
/// `ScoringPipeline` produce nudge text.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ClassifierAction {
    Allow,
    ShadowOnly,
    AdvisoryNudge,
    Block,
}

impl ClassifierAction {
    /// Returns the snake_case name of this action.
    pub fn as_str(self) -> &'static str {
        let name = match self {
            ClassifierAction::Allow => "allow",
            ClassifierAction::ShadowOnly => "shadow_only",
            ClassifierAction::AdvisoryNudge => "advisory_nudge",
            ClassifierAction::Block => "block",
        };
        name
    }
}
/// Result returned by a `ToolCallScorer` for one candidate tool call.
#[derive(Debug, Clone, PartialEq)]
pub struct ToolCallScore {
/// Class assigned to the candidate.
pub label: ToolCallClass,
/// Confidence reported by the scorer (presumably for `label`; range not
/// enforced here).
pub confidence: f32,
/// Logits reported by the scorer; may be empty (the no-op scorer returns none).
pub logits: Vec<f32>,
/// Action the scorer attached to this score; `ScoringPipeline` builds nudge
/// text for `AdvisoryNudge` and `Block`.
pub action: ClassifierAction,
/// Identifier of the model that produced the score (`"noop"` for the no-op scorer).
pub model_version: String,
/// Scoring latency, in milliseconds per the field name.
pub latency_ms: f64,
}
/// Synchronous scorer for a candidate tool call.
///
/// `ScoringExecutor::score_tool_call_async` invokes `score` inside
/// `tokio::task::spawn_blocking`, hence the `Send + Sync` bound.
pub trait ToolCallScorer: Send + Sync {
/// Scores `candidate` against `ctx`, or returns an error if scoring fails.
fn score(&self, ctx: &ScoringContext, candidate: &ToolCall) -> anyhow::Result<ToolCallScore>;
}
/// Runs scoring closures on tokio's blocking pool, bounded by a semaphore.
///
/// Clones share the same semaphore (it is held in an `Arc`), so they share
/// one concurrency limit.
#[derive(Clone)]
pub struct ScoringExecutor {
// One permit is held for the duration of each blocking scoring task.
semaphore: Arc<tokio::sync::Semaphore>,
}
impl ScoringExecutor {
    /// Creates an executor that runs at most `max_concurrency` scoring tasks at once.
    ///
    /// The value is clamped to `1..=Semaphore::MAX_PERMITS`. Zero would create a
    /// semaphore that never admits a task, and a value above `MAX_PERMITS`
    /// (for example `usize::MAX` passed to mean "unbounded") would make
    /// `Semaphore::new` panic.
    pub fn new(max_concurrency: usize) -> Self {
        let permits = max_concurrency.clamp(1, tokio::sync::Semaphore::MAX_PERMITS);
        Self {
            semaphore: Arc::new(tokio::sync::Semaphore::new(permits)),
        }
    }

    /// Returns the default concurrency limit: the available parallelism
    /// (falling back to 1 when it cannot be determined), clamped to `1..=4`.
    pub fn default_concurrency() -> usize {
        std::thread::available_parallelism()
            .map(|parallelism| parallelism.get())
            .unwrap_or(1)
            .clamp(1, 4)
    }

    /// Scores `candidate` with `scorer` on the blocking pool.
    ///
    /// # Errors
    ///
    /// Returns the scorer's own error, or an error if the semaphore is closed
    /// or the blocking task fails to complete.
    pub async fn score_tool_call_async(
        &self,
        scorer: Arc<dyn ToolCallScorer>,
        ctx: Arc<ScoringContext>,
        candidate: ToolCall,
    ) -> anyhow::Result<ToolCallScore> {
        self.run_blocking("classifier scoring task failed", move || {
            scorer.score(&ctx, &candidate)
        })
        .await
    }

    /// Waits for a permit, then runs `task` via `spawn_blocking`.
    ///
    /// `task_error` prefixes the error reported when the spawned task cannot
    /// be joined.
    async fn run_blocking<T, F>(&self, task_error: &'static str, task: F) -> anyhow::Result<T>
    where
        T: Send + 'static,
        F: FnOnce() -> anyhow::Result<T> + Send + 'static,
    {
        let permit = self
            .semaphore
            .clone()
            .acquire_owned()
            .await
            .map_err(|err| anyhow::anyhow!("classifier scoring semaphore closed: {err}"))?;
        tokio::task::spawn_blocking(move || {
            // The permit moves into the closure so it is released only when the
            // blocking work returns.
            let _permit = permit;
            task()
        })
        .await
        // The outer `Result` is the join result; `?` unwraps it and leaves the
        // task's own `anyhow::Result<T>` as the return value.
        .map_err(|err| anyhow::anyhow!("{task_error}: {err}"))?
    }
}
impl Default for ScoringExecutor {
    /// Builds an executor limited to `ScoringExecutor::default_concurrency()` tasks.
    fn default() -> Self {
        let limit = Self::default_concurrency();
        Self::new(limit)
    }
}
/// Bundles the optional tool-call and final-response scorers with the
/// executor they run on.
#[derive(Clone)]
pub struct ScoringPipeline {
// `None` makes `score_tool_calls` return `None` without scoring anything.
tool_call_scorer: Option<Arc<dyn ToolCallScorer>>,
// `None` makes `score_final_response` return `None` without scoring anything.
final_response_scorer: Option<Arc<dyn FinalResponseScorer>>,
// Executor used for every scoring call made through this pipeline.
executor: ScoringExecutor,
}
impl ScoringPipeline {
    /// Creates a pipeline that runs on a clone of the shared default executor.
    pub fn new(
        tool_call_scorer: Option<Arc<dyn ToolCallScorer>>,
        final_response_scorer: Option<Arc<dyn FinalResponseScorer>>,
    ) -> Self {
        let executor = (*DEFAULT_SCORING_EXECUTOR).clone();
        Self::with_executor(tool_call_scorer, final_response_scorer, executor)
    }

    /// Creates a pipeline that runs on the supplied `executor`.
    pub fn with_executor(
        tool_call_scorer: Option<Arc<dyn ToolCallScorer>>,
        final_response_scorer: Option<Arc<dyn FinalResponseScorer>>,
        executor: ScoringExecutor,
    ) -> Self {
        Self {
            executor,
            final_response_scorer,
            tool_call_scorer,
        }
    }

    /// Scores each candidate in order, one at a time, and returns nudge text if
    /// any candidate was flagged.
    ///
    /// `on_score` is called for every successful score and `on_error` for every
    /// failure; a failure does not stop the remaining candidates from being
    /// scored. Returns `None` when no tool-call scorer is configured or no
    /// score carried `AdvisoryNudge` / `Block`.
    ///
    /// When several candidates are flagged, the first flagged candidate's nudge
    /// is kept, except that a `Block` always replaces whatever was chosen
    /// before it.
    pub async fn score_tool_calls<F, E>(
        &self,
        ctx: Arc<ScoringContext>,
        candidates: &[ToolCall],
        mut on_score: F,
        mut on_error: E,
    ) -> Option<String>
    where
        F: FnMut(&ToolCall, &ToolCallScore),
        E: FnMut(&ToolCall, &anyhow::Error),
    {
        let scorer = match &self.tool_call_scorer {
            Some(configured) => Arc::clone(configured),
            None => return None,
        };

        let mut selected: Option<String> = None;
        for candidate in candidates {
            let outcome = self
                .executor
                .score_tool_call_async(Arc::clone(&scorer), Arc::clone(&ctx), candidate.clone())
                .await;
            let score = match outcome {
                Ok(score) => score,
                Err(err) => {
                    on_error(candidate, &err);
                    continue;
                }
            };

            on_score(candidate, &score);

            let is_block = score.action == ClassifierAction::Block;
            if !(is_block || score.action == ClassifierAction::AdvisoryNudge) {
                continue;
            }
            let content = crate::prompts::classifier_nudge(score.label.as_label().as_ref());
            if is_block || selected.is_none() {
                selected = Some(content);
            }
        }
        selected
    }

    /// Scores the candidate final response and returns nudge text if it was flagged.
    ///
    /// Calls `on_score` on success and `on_error` on failure. Returns `None`
    /// when no final-response scorer is configured, when scoring fails, or when
    /// the score's action is neither `AdvisoryNudge` nor `Block`.
    pub async fn score_final_response<F, E>(
        &self,
        ctx: Arc<FinalResponseContext>,
        mut on_score: F,
        mut on_error: E,
    ) -> Option<String>
    where
        F: FnMut(&FinalResponseScore),
        E: FnMut(&anyhow::Error),
    {
        let scorer = match &self.final_response_scorer {
            Some(configured) => Arc::clone(configured),
            None => return None,
        };

        let score = match self.executor.score_final_response_async(scorer, ctx).await {
            Ok(score) => score,
            Err(err) => {
                on_error(&err);
                return None;
            }
        };

        on_score(&score);

        let flagged = matches!(
            score.action,
            ClassifierAction::AdvisoryNudge | ClassifierAction::Block
        );
        if !flagged {
            return None;
        }
        Some(crate::prompts::classifier_nudge(
            score.label.as_label().as_ref(),
        ))
    }
}
impl Default for ScoringPipeline {
    /// Builds a pipeline with no scorers configured, via `ScoringPipeline::new`.
    fn default() -> Self {
        let (tool_call_scorer, final_response_scorer) = (None, None);
        Self::new(tool_call_scorer, final_response_scorer)
    }
}
/// A tool result included in the final-response scoring context.
///
/// `serialize_final_response_state_v1` renders each one as
/// `tool_name: <content as a JSON string>`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct FinalResponseToolResult {
/// Name of the tool; serialized verbatim.
pub tool_name: String,
/// Tool output; serialized as a JSON string literal.
pub content: String,
}
/// Input handed to a `FinalResponseScorer`.
///
/// `serialize_final_response_state_v1` renders every field into a sectioned
/// text form; the section names below refer to that output.
#[derive(Debug, Clone, PartialEq)]
pub struct FinalResponseContext {
/// Emitted verbatim under `SCHEMA_VERSION`.
pub schema_version: String,
/// Emitted verbatim under `USER_REQUEST`.
pub user_request: String,
/// Step/tool/error lists emitted under `WORKFLOW_STATE`.
pub workflow_state: WorkflowStateForScoring,
/// Emitted as a quoted list under `REQUIRED_FACTS`.
pub required_facts: Vec<String>,
/// Emitted as a quoted list under `TOOL_TRACE`.
pub tool_trace: Vec<String>,
/// Emitted one per line under `TOOL_RESULTS`.
pub tool_results: Vec<FinalResponseToolResult>,
/// The response being scored; emitted verbatim under `CANDIDATE_FINAL_RESPONSE`.
pub candidate_final_response: String,
/// Optional metadata; when `None`, every `SCORING_METADATA` value is `null`.
pub metadata: Option<ScoringMetadata>,
}
/// Classification assigned to a candidate final response.
///
/// `Unknown` carries a label string that does not map to one of the named
/// variants.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum FinalResponseClass {
    ValidFinalResponse,
    MissingToolFact,
    ContradictsToolResult,
    UnsupportedClaim,
    FailedToAcknowledgeDataGap,
    Unknown(String),
}

impl FinalResponseClass {
    /// Returns the snake_case label for this class.
    ///
    /// For `Unknown` the stored string is returned as-is. The result always
    /// borrows (from a literal or from `self`); nothing is allocated.
    pub fn as_label(&self) -> Cow<'_, str> {
        let text: &str = match self {
            FinalResponseClass::ValidFinalResponse => "valid_final_response",
            FinalResponseClass::MissingToolFact => "missing_tool_fact",
            FinalResponseClass::ContradictsToolResult => "contradicts_tool_result",
            FinalResponseClass::UnsupportedClaim => "unsupported_claim",
            FinalResponseClass::FailedToAcknowledgeDataGap => "failed_to_acknowledge_data_gap",
            FinalResponseClass::Unknown(raw) => raw.as_str(),
        };
        Cow::Borrowed(text)
    }
}
/// Result returned by a `FinalResponseScorer` for one candidate response.
#[derive(Debug, Clone, PartialEq)]
pub struct FinalResponseScore {
/// Class assigned to the candidate response.
pub label: FinalResponseClass,
/// Confidence reported by the scorer (presumably for `label`; range not
/// enforced here).
pub confidence: f32,
/// Logits reported by the scorer; may be empty (the no-op scorer returns none).
pub logits: Vec<f32>,
/// Action the scorer attached to this score; `ScoringPipeline` builds nudge
/// text for `AdvisoryNudge` and `Block`.
pub action: ClassifierAction,
/// Identifier of the model that produced the score (`"noop"` for the no-op scorer).
pub model_version: String,
/// Scoring latency, in milliseconds per the field name.
pub latency_ms: f64,
}
/// Synchronous scorer for a candidate final response.
///
/// `ScoringExecutor::score_final_response_async` invokes `score` inside
/// `tokio::task::spawn_blocking`, hence the `Send + Sync` bound.
pub trait FinalResponseScorer: Send + Sync {
/// Scores the response described by `ctx`, or returns an error if scoring fails.
fn score(&self, ctx: &FinalResponseContext) -> anyhow::Result<FinalResponseScore>;
}
impl ScoringExecutor {
    /// Scores the final response described by `ctx` with `scorer` on the
    /// blocking pool, subject to this executor's concurrency limit.
    ///
    /// # Errors
    ///
    /// Returns the scorer's own error, or an error if the blocking task could
    /// not be run to completion.
    pub async fn score_final_response_async(
        &self,
        scorer: Arc<dyn FinalResponseScorer>,
        ctx: Arc<FinalResponseContext>,
    ) -> anyhow::Result<FinalResponseScore> {
        let task = move || scorer.score(&ctx);
        self.run_blocking("final-response scoring task failed", task)
            .await
    }
}
pub async fn score_tool_call_async(
scorer: Arc<dyn ToolCallScorer>,
ctx: Arc<ScoringContext>,
candidate: ToolCall,
) -> anyhow::Result<ToolCallScore> {
DEFAULT_SCORING_EXECUTOR
.score_tool_call_async(scorer, ctx, candidate)
.await
}
pub async fn score_final_response_async(
scorer: Arc<dyn FinalResponseScorer>,
ctx: Arc<FinalResponseContext>,
) -> anyhow::Result<FinalResponseScore> {
DEFAULT_SCORING_EXECUTOR
.score_final_response_async(scorer, ctx)
.await
}
/// Renders `ctx` into the sectioned plain-text "v1" layout.
///
/// Sections appear in this order: `SCHEMA_VERSION`, `USER_REQUEST`,
/// `WORKFLOW_STATE`, `REQUIRED_FACTS`, `TOOL_TRACE`, `TOOL_RESULTS`,
/// `CANDIDATE_FINAL_RESPONSE`, `SCORING_METADATA`. String lists are rendered
/// by `py_list`. Tool results are written one per line as
/// `name: <JSON string>`. Metadata values are JSON literals, and all are
/// `null` when `ctx.metadata` is `None`.
pub fn serialize_final_response_state_v1(ctx: &FinalResponseContext) -> String {
    let state = &ctx.workflow_state;

    // One "tool_name: <json content>" line per result, newline-separated.
    let mut tool_results = String::new();
    for (index, entry) in ctx.tool_results.iter().enumerate() {
        if index > 0 {
            tool_results.push('\n');
        }
        tool_results.push_str(&entry.tool_name);
        tool_results.push_str(": ");
        tool_results.push_str(&json_string(&entry.content));
    }

    let (scenario_family, requires_transform, requires_synthesis, requires_all, must_ack) =
        match ctx.metadata.as_ref() {
            Some(meta) => (
                meta.scenario_family.as_deref(),
                meta.requires_transform,
                meta.requires_synthesis,
                meta.requires_all_tool_facts,
                meta.must_acknowledge_missing_data,
            ),
            None => (None, None, None, None, None),
        };

    format!(
        "SCHEMA_VERSION:\n{}\n\nUSER_REQUEST:\n{}\n\nWORKFLOW_STATE:\nrequired_steps={}\ncompleted_steps={}\npending_steps={}\nterminal_tools={}\nrecent_errors={}\n\nREQUIRED_FACTS:\n{}\n\nTOOL_TRACE:\n{}\n\nTOOL_RESULTS:\n{}\n\nCANDIDATE_FINAL_RESPONSE:\n{}\n\nSCORING_METADATA:\nscenario_family={}\nrequires_transform={}\nrequires_synthesis={}\nrequires_all_tool_facts={}\nmust_acknowledge_missing_data={}",
        ctx.schema_version,
        ctx.user_request,
        py_list(&state.required_steps),
        py_list(&state.completed_steps),
        py_list(&state.pending_steps),
        py_list(&state.terminal_tools),
        py_list(&state.recent_errors),
        py_list(&ctx.required_facts),
        py_list(&ctx.tool_trace),
        tool_results,
        ctx.candidate_final_response,
        optional_json_string(scenario_family),
        optional_json_bool(requires_transform),
        optional_json_bool(requires_synthesis),
        optional_json_bool(requires_all),
        optional_json_bool(must_ack),
    )
}
/// Renders `values` as a bracketed, comma-separated list of single-quoted strings.
///
/// Backslashes and single quotes inside a value are backslash-escaped; every
/// other character is copied unchanged. An empty slice renders as `[]`.
///
/// NOTE(review): the name suggests this mirrors Python's list formatting, but
/// only `\` and `'` are escaped here — confirm against whatever consumes it.
fn py_list(values: &[String]) -> String {
    let mut rendered = String::from("[");
    for (index, value) in values.iter().enumerate() {
        if index > 0 {
            rendered.push_str(", ");
        }
        rendered.push('\'');
        for ch in value.chars() {
            if ch == '\\' || ch == '\'' {
                rendered.push('\\');
            }
            rendered.push(ch);
        }
        rendered.push('\'');
    }
    rendered.push(']');
    rendered
}
/// Encodes `value` as a JSON string literal, quotes included.
///
/// Falls back to the empty JSON string `""` if serialization reports an error.
fn json_string(value: &str) -> String {
    match serde_json::to_string(value) {
        Ok(encoded) => encoded,
        Err(_) => String::from("\"\""),
    }
}
fn optional_json_string(value: Option<&str>) -> String {
value.map(json_string).unwrap_or_else(|| "null".to_string())
}
/// Encodes an optional boolean as the JSON literal `true`, `false`, or `null`.
fn optional_json_bool(value: Option<bool>) -> &'static str {
    value.map_or("null", |flag| if flag { "true" } else { "false" })
}
/// Tool-call scorer that accepts every candidate without inspecting it.
#[derive(Debug, Default)]
pub struct NoopToolCallScorer;

impl ToolCallScorer for NoopToolCallScorer {
    /// Always succeeds with a `Valid` / `Allow` score at confidence 1.0, no
    /// logits, model version `"noop"`, and zero latency.
    fn score(&self, _ctx: &ScoringContext, _candidate: &ToolCall) -> anyhow::Result<ToolCallScore> {
        let accepted = ToolCallScore {
            model_version: String::from("noop"),
            label: ToolCallClass::Valid,
            action: ClassifierAction::Allow,
            confidence: 1.0,
            latency_ms: 0.0,
            logits: vec![],
        };
        Ok(accepted)
    }
}
/// Final-response scorer that accepts every response without inspecting it.
#[derive(Debug, Default)]
pub struct NoopFinalResponseScorer;

impl FinalResponseScorer for NoopFinalResponseScorer {
    /// Always succeeds with a `ValidFinalResponse` / `Allow` score at
    /// confidence 1.0, no logits, model version `"noop"`, and zero latency.
    fn score(&self, _ctx: &FinalResponseContext) -> anyhow::Result<FinalResponseScore> {
        let accepted = FinalResponseScore {
            model_version: String::from("noop"),
            label: FinalResponseClass::ValidFinalResponse,
            action: ClassifierAction::Allow,
            confidence: 1.0,
            latency_ms: 0.0,
            logits: vec![],
        };
        Ok(accepted)
    }
}