forge-guardrails 0.1.2

Foundation types for an LLM-agent workflow framework
Documentation
use crate::clients::base::{LLMClient, ToolCall};
use crate::core::message::{Message, MessageType};
use crate::core::tool_spec::ToolSpec;
use crate::guardrails::{
    recent_errors_from_messages, ClassifierAction, FinalResponseContext, FinalResponseToolResult,
    ScoringContext, ScoringPipeline, StepEnforcer, WorkflowStateForScoring,
};
use serde_json::Value;
use std::sync::Arc;

use super::{latest_user_request, WorkflowRunner};

impl<C: LLMClient> WorkflowRunner<C> {
    /// Runs the tool-call classifier over a batch of proposed tool calls.
    ///
    /// Returns `None` right away when `self.scorer` is not set. Otherwise a
    /// [`ScoringPipeline`] carrying only that scorer is built and whatever its
    /// `score_tool_calls` resolves to is returned unchanged.
    ///
    /// The user request handed to the classifier is the result of
    /// `latest_user_request(messages)`, falling back to `fallback_user_message`
    /// when that yields nothing.
    ///
    /// Each score is logged under the `forge.classifier` target and forwarded to
    /// `self.on_tool_call_score` when that callback is set. Scoring errors are
    /// only logged as warnings here.
    pub(super) async fn score_tool_calls(
        &self,
        fallback_user_message: &str,
        messages: &[Message],
        tool_calls: &[ToolCall],
        step_enforcer: &StepEnforcer,
        tool_specs: &[ToolSpec],
    ) -> Option<String> {
        let classifier = self.scorer.clone()?;

        // Assemble the shared scoring context before touching the pipeline.
        let request = latest_user_request(messages).unwrap_or(fallback_user_message);
        let recent_errors = recent_errors_from_messages(messages, 8);
        let context = Arc::new(ScoringContext::from_step_enforcer(
            request,
            step_enforcer,
            &step_enforcer.terminal_tools,
            recent_errors,
            tool_specs,
        ));

        // Only the tool-call scorer slot is populated for this pipeline.
        let tool_call_pipeline = ScoringPipeline::new(Some(classifier), None);
        tool_call_pipeline
            .score_tool_calls(
                context,
                tool_calls,
                // Success path: log the score, then notify the optional observer.
                |call, score| {
                    tracing::info!(
                        target: "forge.classifier",
                        label = ?score.label,
                        confidence = score.confidence,
                        action = ?score.action,
                        latency_ms = score.latency_ms,
                        tool = %call.tool,
                        "tool-call classifier score"
                    );
                    if let Some(observer) = &self.on_tool_call_score {
                        observer(call, score);
                    }
                },
                // Failure path: a scoring error is reported but not propagated from here.
                |call, err| {
                    tracing::warn!(
                        target: "forge.classifier",
                        error = %err,
                        tool = %call.tool,
                        "classifier scoring failed; allowing deterministic path"
                    );
                },
            )
            .await
    }

    /// Scores every terminal tool call in `tool_calls` as a candidate final response.
    ///
    /// Returns `None` right away when `self.final_response_scorer` is not set.
    /// Calls whose `tool` is not in `step_enforcer.terminal_tools` are skipped.
    /// For each remaining call a [`FinalResponseContext`] is built from the
    /// message history (tool trace, tool results, recent errors), the step
    /// enforcer's state and the candidate text extracted from the call, and is
    /// passed to the pipeline's `score_final_response`.
    ///
    /// Selection rule for the returned text: the first text the pipeline
    /// produces is kept, and a later one replaces it only if its score carried
    /// [`ClassifierAction::Block`].
    ///
    /// Each score is logged under the `forge.classifier` target and forwarded to
    /// `self.on_final_response_score` when that callback is set. Scoring errors
    /// are only logged as warnings here.
    pub(super) async fn score_candidate_final_responses(
        &self,
        fallback_user_message: &str,
        messages: &[Message],
        tool_calls: &[ToolCall],
        step_enforcer: &StepEnforcer,
    ) -> Option<String> {
        let verifier = self.final_response_scorer.clone()?;
        // Only the final-response scorer slot is populated for this pipeline.
        let pipeline = ScoringPipeline::new(None, Some(verifier));
        let request = latest_user_request(messages).unwrap_or(fallback_user_message);
        let prior_trace = Self::tool_trace_from_messages(messages);
        let prior_results = Self::tool_results_from_messages(messages);

        let mut selected: Option<String> = None;
        for call in tool_calls {
            if !step_enforcer.terminal_tools.contains(&call.tool) {
                continue;
            }

            // The trace shown to the scorer is the history plus this candidate call.
            let mut tool_trace = prior_trace.clone();
            tool_trace.push(call.tool.clone());

            let workflow_state = WorkflowStateForScoring {
                required_steps: step_enforcer.tracker.required_steps().to_vec(),
                completed_steps: step_enforcer.completed_steps().keys().cloned().collect(),
                pending_steps: step_enforcer.pending(),
                terminal_tools: step_enforcer.terminal_tools.iter().cloned().collect(),
                recent_errors: recent_errors_from_messages(messages, 8),
            };
            let context = Arc::new(FinalResponseContext {
                schema_version: "final-response-verifier-input/v1".to_string(),
                user_request: request.to_string(),
                workflow_state,
                required_facts: Vec::new(),
                tool_trace,
                tool_results: prior_results.clone(),
                candidate_final_response: Self::candidate_final_response_from_call(call),
                metadata: None,
            });

            // Filled in by the score callback; stays `None` if that callback never runs.
            let mut verdict = None;
            let outcome = pipeline
                .score_final_response(
                    context,
                    |score| {
                        verdict = Some(score.action);
                        tracing::info!(
                            target: "forge.classifier",
                            label = %score.label.as_label(),
                            confidence = score.confidence,
                            action = %score.action.as_str(),
                            latency_ms = score.latency_ms,
                            terminal_tool = %call.tool,
                            "final-response classifier score"
                        );
                        if let Some(observer) = &self.on_final_response_score {
                            observer(score);
                        }
                    },
                    |err| {
                        tracing::warn!(
                            target: "forge.classifier",
                            error = %err,
                            terminal_tool = %call.tool,
                            "final-response classifier scoring failed; allowing deterministic path"
                        );
                    },
                )
                .await;

            if let Some(content) = outcome {
                // A blocking verdict always wins; anything else only fills an empty slot.
                let is_block = verdict == Some(ClassifierAction::Block);
                if is_block || selected.is_none() {
                    selected = Some(content);
                }
            }
        }
        selected
    }

    /// Extracts the candidate final-response text from a tool call's arguments.
    ///
    /// The argument keys `message`, `answer`, `content`, `report` and `summary`
    /// are probed in that order and the first key that is present wins. A JSON
    /// string value is returned verbatim; any other value is rendered through
    /// its `to_string()`. When none of the keys is present, all arguments are
    /// collected into one JSON object and that object's `to_string()` is returned.
    fn candidate_final_response_from_call(call: &ToolCall) -> String {
        const PREFERRED_KEYS: [&str; 5] = ["message", "answer", "content", "report", "summary"];

        let preferred = PREFERRED_KEYS
            .iter()
            .find_map(|key| call.args.get(*key));

        match preferred {
            Some(value) => value
                .as_str()
                .map_or_else(|| value.to_string(), str::to_owned),
            None => {
                // No well-known key: fall back to serializing every argument.
                let all_args = call
                    .args
                    .iter()
                    .map(|(name, value)| (name.clone(), value.clone()))
                    .collect();
                Value::Object(all_args).to_string()
            }
        }
    }

    /// Collects the names of all tool calls recorded on `messages`.
    ///
    /// Messages are visited in slice order and, within a message, calls are
    /// visited in their stored order. Messages whose `tool_calls` is `None`
    /// contribute nothing.
    fn tool_trace_from_messages(messages: &[Message]) -> Vec<String> {
        let mut trace = Vec::new();
        for message in messages {
            if let Some(calls) = message.tool_calls.as_ref() {
                trace.extend(calls.iter().map(|call| call.name.clone()));
            }
        }
        trace
    }

    /// Gathers tool results from the message history, in slice order.
    ///
    /// A message contributes an entry only when its `metadata.msg_type` is
    /// [`MessageType::ToolResult`] and its `tool_name` is `Some`; the entry
    /// holds clones of that tool name and of the message content.
    fn tool_results_from_messages(messages: &[Message]) -> Vec<FinalResponseToolResult> {
        let mut results = Vec::new();
        for message in messages {
            if message.metadata.msg_type == MessageType::ToolResult {
                // Tool results without a tool name are dropped.
                if let Some(tool_name) = message.tool_name.clone() {
                    results.push(FinalResponseToolResult {
                        tool_name,
                        content: message.content.clone(),
                    });
                }
            }
        }
        results
    }
}