forge-guardrails 0.1.2

Foundation types for an LLM-agent workflow framework
Documentation
use std::sync::Arc;

use serde_json::{json, Value};

use crate::clients::base::ToolCall;
use crate::core::message::{Message, MessageRole, MessageType};
use crate::core::tool_spec::ToolSpec;
use crate::guardrails::{
    final_response_top_k_from_logits, recent_errors_from_messages, tool_call_top_k_from_logits,
    FinalResponseContext, FinalResponseScorer, FinalResponseToolResult, ScoringContext,
    ScoringPipeline, StepEnforcer, ToolCallScorer, WorkflowStateForScoring,
};
use crate::tools::respond::RESPOND_TOOL_NAME;

use super::classifier_log::{emit_proxy_classifier_jsonl, proxy_tool_call_for_json, unix_ms};
use super::telemetry::{
    capture_final_response_classifier_non_allow, capture_tool_call_classifier_non_allow,
};

/// Runs the optional tool-call classifier over `tool_calls`.
///
/// Returns `None` right away when `scorer` is `None`. Otherwise a
/// [`ScoringPipeline`] is created with only the tool-call scorer slot filled
/// and a [`ScoringContext`] is built from:
///
/// * the content of the latest user-role message (empty string if there is
///   none),
/// * `recent_errors_from_messages(messages, 8)`,
/// * the `step_enforcer` state when one is supplied, or empty step lists plus
///   the respond tool found in `tool_specs` when it is not.
///
/// The value returned by `ScoringPipeline::score_tool_calls` is passed through
/// unchanged.
///
/// For every produced score the callback reports to telemetry, logs under the
/// `forge.classifier` tracing target and emits a `"tool_call"` JSONL record.
/// Scoring errors are only logged as warnings.
pub(super) async fn score_proxy_tool_calls(
    scorer: Option<Arc<dyn ToolCallScorer>>,
    messages: &[Message],
    tool_calls: &[ToolCall],
    step_enforcer: Option<&StepEnforcer>,
    tool_specs: &[ToolSpec],
) -> Option<String> {
    let tool_call_scorer = scorer?;
    let pipeline = ScoringPipeline::new(Some(tool_call_scorer), None);

    let user_request = latest_proxy_user_request(messages).unwrap_or_default();
    let recent_errors = recent_errors_from_messages(messages, 8);

    let scoring_context = if let Some(enforcer) = step_enforcer {
        ScoringContext::from_step_enforcer(
            user_request,
            enforcer,
            &enforcer.terminal_tools,
            recent_errors,
            tool_specs,
        )
    } else {
        // No enforcer: there is no step bookkeeping to report, only the
        // terminal tools derived from the tool specs.
        ScoringContext::new(
            user_request,
            Vec::new(),
            Vec::new(),
            Vec::new(),
            proxy_terminal_tools_for_scoring(tool_specs),
            recent_errors,
            tool_specs,
        )
    };
    let ctx = Arc::new(scoring_context);

    // The pipeline takes ownership of `ctx`; the score callback keeps its own
    // handle so it can log the request and workflow state.
    let logged_ctx = Arc::clone(&ctx);

    pipeline
        .score_tool_calls(
            ctx,
            tool_calls,
            |tool_call, verdict| {
                capture_tool_call_classifier_non_allow(tool_call, verdict);
                tracing::info!(
                    target: "forge.classifier",
                    label = ?verdict.label,
                    confidence = verdict.confidence,
                    action = ?verdict.action,
                    latency_ms = verdict.latency_ms,
                    tool = %tool_call.tool,
                    "tool-call classifier score"
                );
                emit_proxy_classifier_jsonl(json!({
                    "kind": "tool_call",
                    "unix_ms": unix_ms(),
                    "user_request": logged_ctx.user_request.as_str(),
                    "initial_user_request": initial_proxy_user_request(messages).unwrap_or_default(),
                    "workflow_state": &logged_ctx.workflow_state,
                    "candidate_call": proxy_tool_call_for_json(tool_call),
                    "tool": tool_call.tool.as_str(),
                    "label": verdict.label.as_label().as_ref(),
                    "confidence": verdict.confidence,
                    "top_k": tool_call_top_k_from_logits(&verdict.logits),
                    "action": verdict.action.as_str(),
                    "latency_ms": verdict.latency_ms,
                    "model_version": verdict.model_version.as_str(),
                }));
            },
            |tool_call, failure| {
                tracing::warn!(
                    target: "forge.classifier",
                    error = %failure,
                    tool = %tool_call.tool,
                    "classifier scoring failed; allowing deterministic path"
                );
            },
        )
        .await
}

/// Scores terminal tool calls as final-response candidates.
///
/// A call is considered terminal when its tool name is in the set returned by
/// `proxy_terminal_tool_set`. For each such call, in order, the candidate text
/// is extracted from the call arguments, the tool trace from `messages` is
/// extended with the call's tool name, and the pair is handed to
/// `score_proxy_final_candidate`.
///
/// Returns the first `Some` produced by the scorer (remaining calls are not
/// scored), or `None` when no scorer is configured, no call is terminal, or
/// every scored call yields `None`.
pub(super) async fn score_proxy_final_tool_calls(
    scorer: Option<Arc<dyn FinalResponseScorer>>,
    messages: &[Message],
    tool_calls: &[ToolCall],
    step_enforcer: Option<&StepEnforcer>,
    tool_specs: &[ToolSpec],
) -> Option<String> {
    // Without a scorer every candidate would come back as `None`, so skip the
    // candidate extraction and trace building entirely.
    let scorer = scorer?;

    let terminal_tools = proxy_terminal_tool_set(step_enforcer, tool_specs);
    for call in tool_calls
        .iter()
        .filter(|call| terminal_tools.contains(call.tool.as_str()))
    {
        let candidate = proxy_candidate_final_response_from_call(call);

        // The trace handed to the scorer includes the terminal call itself.
        let mut trace = proxy_tool_trace_from_messages(messages);
        trace.push(call.tool.clone());

        let outcome = score_proxy_final_candidate(
            Some(Arc::clone(&scorer)),
            messages,
            &candidate,
            trace,
            step_enforcer,
            tool_specs,
            Some(call.tool.as_str()),
        )
        .await;

        // First non-`None` result wins.
        if outcome.is_some() {
            return outcome;
        }
    }
    None
}

/// Scores a plain-text final response candidate.
///
/// Delegates to `score_proxy_final_candidate` with the tool trace taken from
/// `messages` as is and no terminal tool name, and returns its result
/// unchanged.
pub(super) async fn score_proxy_final_text(
    scorer: Option<Arc<dyn FinalResponseScorer>>,
    messages: &[Message],
    candidate: &str,
    step_enforcer: Option<&StepEnforcer>,
    tool_specs: &[ToolSpec],
) -> Option<String> {
    let tool_trace = proxy_tool_trace_from_messages(messages);
    // A text answer is not tied to any terminal tool call.
    let terminal_tool: Option<&str> = None;

    score_proxy_final_candidate(
        scorer,
        messages,
        candidate,
        tool_trace,
        step_enforcer,
        tool_specs,
        terminal_tool,
    )
    .await
}

/// Runs the optional final-response classifier over one candidate answer.
///
/// Returns `None` right away when `scorer` is `None`. Otherwise a
/// [`ScoringPipeline`] is created with only the final-response scorer slot
/// filled and a [`FinalResponseContext`] is assembled from:
///
/// * the content of the latest user-role message (empty string if none),
/// * a [`WorkflowStateForScoring`] taken from `step_enforcer`, or — without an
///   enforcer — empty step lists plus the respond tool found in `tool_specs`,
/// * the supplied `tool_trace`, the tool results found in `messages`, and the
///   `candidate` text.
///
/// `required_facts` is always empty and `metadata` is always `None` here.
/// `terminal_tool` is used only for telemetry and logging; `"text"` is
/// reported when it is `None`.
///
/// The value returned by `ScoringPipeline::score_final_response` is passed
/// through unchanged. Scores are reported to telemetry, logged under the
/// `forge.classifier` tracing target and emitted as a `"final_response"` JSONL
/// record; scoring errors are only logged as warnings.
async fn score_proxy_final_candidate(
    scorer: Option<Arc<dyn FinalResponseScorer>>,
    messages: &[Message],
    candidate: &str,
    tool_trace: Vec<String>,
    step_enforcer: Option<&StepEnforcer>,
    tool_specs: &[ToolSpec],
    terminal_tool: Option<&str>,
) -> Option<String> {
    let final_scorer = scorer?;
    let pipeline = ScoringPipeline::new(None, Some(final_scorer));

    let user_request = latest_proxy_user_request(messages).unwrap_or_default();
    // Both workflow-state variants carry the same recent-errors list.
    let recent_errors = recent_errors_from_messages(messages, 8);

    let workflow_state = if let Some(enforcer) = step_enforcer {
        WorkflowStateForScoring {
            required_steps: enforcer.tracker.required_steps().to_vec(),
            completed_steps: enforcer.completed_steps().keys().cloned().collect(),
            pending_steps: enforcer.pending(),
            terminal_tools: enforcer.terminal_tools.iter().cloned().collect(),
            recent_errors,
        }
    } else {
        WorkflowStateForScoring {
            required_steps: Vec::new(),
            completed_steps: Vec::new(),
            pending_steps: Vec::new(),
            terminal_tools: proxy_terminal_tools_for_scoring(tool_specs),
            recent_errors,
        }
    };

    let ctx = Arc::new(FinalResponseContext {
        schema_version: String::from("final-response-verifier-input/v1"),
        user_request: user_request.to_owned(),
        workflow_state,
        required_facts: Vec::new(),
        tool_trace,
        tool_results: proxy_tool_results_from_messages(messages),
        candidate_final_response: candidate.to_owned(),
        metadata: None,
    });

    // The pipeline takes ownership of `ctx`; the score callback keeps its own
    // handle so it can log the full classifier input.
    let logged_ctx = Arc::clone(&ctx);
    let terminal_tool_name = terminal_tool.unwrap_or("text");

    pipeline
        .score_final_response(
            ctx,
            |verdict| {
                capture_final_response_classifier_non_allow(terminal_tool_name, verdict);
                tracing::info!(
                    target: "forge.classifier",
                    label = %verdict.label.as_label(),
                    confidence = verdict.confidence,
                    action = %verdict.action.as_str(),
                    latency_ms = verdict.latency_ms,
                    terminal_tool = terminal_tool_name,
                    "final-response classifier score"
                );
                emit_proxy_classifier_jsonl(json!({
                    "kind": "final_response",
                    "unix_ms": unix_ms(),
                    "user_request": logged_ctx.user_request.as_str(),
                    "initial_user_request": initial_proxy_user_request(messages).unwrap_or_default(),
                    "workflow_state": &logged_ctx.workflow_state,
                    "required_facts": &logged_ctx.required_facts,
                    "tool_trace": &logged_ctx.tool_trace,
                    "tool_results": logged_ctx.tool_results.iter().map(|entry| {
                        json!({"tool_name": entry.tool_name.as_str(), "content": entry.content.as_str()})
                    }).collect::<Vec<_>>(),
                    "candidate_final_response": logged_ctx.candidate_final_response.as_str(),
                    "terminal_tool": terminal_tool_name,
                    "label": verdict.label.as_label().as_ref(),
                    "confidence": verdict.confidence,
                    "top_k": final_response_top_k_from_logits(&verdict.logits),
                    "action": verdict.action.as_str(),
                    "latency_ms": verdict.latency_ms,
                    "model_version": verdict.model_version.as_str(),
                }));
            },
            |failure| {
                tracing::warn!(
                    target: "forge.classifier",
                    error = %failure,
                    terminal_tool = terminal_tool_name,
                    "final-response classifier scoring failed; allowing deterministic path"
                );
            },
        )
        .await
}

/// Returns the names of the tools treated as terminal, as borrowed strings.
///
/// With a `step_enforcer` the set is exactly its `terminal_tools`. Without
/// one, the set contains the name of every spec in `tool_specs` whose name
/// equals [`RESPOND_TOOL_NAME`] (so it is empty when that tool is absent).
fn proxy_terminal_tool_set<'a>(
    step_enforcer: Option<&'a StepEnforcer>,
    tool_specs: &'a [ToolSpec],
) -> std::collections::HashSet<&'a str> {
    if let Some(enforcer) = step_enforcer {
        return enforcer
            .terminal_tools
            .iter()
            .map(|name| name.as_str())
            .collect();
    }

    let mut names = std::collections::HashSet::new();
    for spec in tool_specs {
        if spec.name == RESPOND_TOOL_NAME {
            names.insert(spec.name.as_str());
        }
    }
    names
}

/// Extracts the text to score as a final response from a terminal tool call.
///
/// The argument keys `message`, `answer`, `content`, `report` and `summary`
/// are checked in that order; the first one present wins. A JSON string value
/// is returned verbatim, any other JSON value is returned in its serialized
/// form. When none of the keys is present, all arguments are serialized as a
/// single JSON object.
fn proxy_candidate_final_response_from_call(call: &ToolCall) -> String {
    const PREFERRED_KEYS: [&str; 5] = ["message", "answer", "content", "report", "summary"];

    let preferred = PREFERRED_KEYS
        .iter()
        .find_map(|key| call.args.get(*key));

    match preferred {
        Some(Value::String(text)) => text.clone(),
        Some(other) => other.to_string(),
        None => {
            let all_args = call
                .args
                .iter()
                .map(|(name, arg)| (name.clone(), arg.clone()))
                .collect();
            Value::Object(all_args).to_string()
        }
    }
}

/// Collects the names of all tool calls recorded on `messages`, in message
/// order and, within a message, in call order. Messages without `tool_calls`
/// contribute nothing.
fn proxy_tool_trace_from_messages(messages: &[Message]) -> Vec<String> {
    let mut trace = Vec::new();
    for message in messages {
        if let Some(calls) = message.tool_calls.as_ref() {
            trace.extend(calls.iter().map(|call| call.name.clone()));
        }
    }
    trace
}

/// Builds the list of tool results passed to the final-response scorer.
///
/// Only messages whose metadata type is `MessageType::ToolResult` are
/// considered, and of those only the ones that carry a `tool_name`; each
/// yields the tool name together with a copy of the message content.
fn proxy_tool_results_from_messages(messages: &[Message]) -> Vec<FinalResponseToolResult> {
    let mut results = Vec::new();
    for message in messages {
        if message.metadata.msg_type != MessageType::ToolResult {
            continue;
        }
        // Tool results without a tool name are dropped.
        if let Some(tool_name) = message.tool_name.clone() {
            results.push(FinalResponseToolResult {
                tool_name,
                content: message.content.clone(),
            });
        }
    }
    results
}

/// Returns the content of the last message in `messages` whose role is
/// `MessageRole::User`, or `None` when there is no such message.
///
/// NOTE(review): unlike `initial_proxy_user_request`, this does not check the
/// message type or skip blank content, so any user-role message qualifies —
/// confirm that is intended.
fn latest_proxy_user_request(messages: &[Message]) -> Option<&str> {
    for message in messages.iter().rev() {
        if message.role == MessageRole::User {
            return Some(message.content.as_str());
        }
    }
    None
}

/// Returns the content of the first message in `messages` that has the user
/// role, the `MessageType::UserInput` metadata type, and content that is not
/// blank after trimming; `None` when no message satisfies all three.
fn initial_proxy_user_request(messages: &[Message]) -> Option<&str> {
    messages
        .iter()
        .filter(|message| message.role == MessageRole::User)
        .filter(|message| message.metadata.msg_type == MessageType::UserInput)
        .map(|message| message.content.as_str())
        .find(|content| !content.trim().is_empty())
}

/// Returns owned copies of the names of all specs in `tool_specs` whose name
/// equals [`RESPOND_TOOL_NAME`]; used as the terminal-tool list when no step
/// enforcer is available.
fn proxy_terminal_tools_for_scoring(tool_specs: &[ToolSpec]) -> Vec<String> {
    tool_specs
        .iter()
        .filter_map(|spec| {
            if spec.name == RESPOND_TOOL_NAME {
                Some(spec.name.clone())
            } else {
                None
            }
        })
        .collect()
}