use super::error_tracker::ErrorTracker;
use super::nudge::Nudge;
use super::response_validator::{ResponseValidator, RetryNudgeFn};
use super::scoring::{ClassifierAction, ToolCallScore, ToolCallScorer};
use super::scoring_context::ScoringContext;
use super::step_enforcer::{StepEnforcer, StepPrerequisite};
use crate::clients::base::{LLMResponse, ToolCall};
use crate::prompts::nudges;
use indexmap::IndexSet;
use std::sync::Arc;
/// What the caller should do with a model response after the guardrails ran.
///
/// The enum carries no data, so it is `Copy` and `Eq` in addition to the
/// original `Debug`/`Clone`/`PartialEq`; existing `.clone()` calls keep
/// compiling.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum GuardAction {
    /// Every check passed; the accompanying tool calls may be executed.
    Execute,
    /// The response was rejected and the model should be asked again using
    /// the accompanying nudge.
    Retry,
    /// A step or prerequisite check rejected the tool calls; the
    /// accompanying nudge explains what has to happen first.
    StepBlocked,
    /// One of the retry / violation limits was exhausted; the run should stop.
    Fatal,
}
/// Outcome of running one model response through the guardrail checks.
///
/// `action` tells the caller what to do next; the three optional fields carry
/// the payload that belongs to that action. The constructors on this type
/// (`execute`, `retry`, `step_blocked`, `fatal`) each fill exactly one of them.
#[derive(Debug, Clone, PartialEq)]
pub struct CheckResult {
/// What the caller should do with the response.
pub action: GuardAction,
/// Tool calls cleared for execution. Filled by `CheckResult::execute`;
/// the other constructors leave it `None`.
pub tool_calls: Option<Vec<ToolCall>>,
/// Message to feed back to the model. Filled by `CheckResult::retry` and
/// `CheckResult::step_blocked`; the other constructors leave it `None`.
pub nudge: Option<Nudge>,
/// Human-readable explanation of a fatal outcome. Filled by
/// `CheckResult::fatal`; the other constructors leave it `None`.
pub reason: Option<String>,
}
impl CheckResult {
pub fn execute(tool_calls: Vec<ToolCall>) -> Self {
Self {
action: GuardAction::Execute,
tool_calls: Some(tool_calls),
nudge: None,
reason: None,
}
}
pub fn retry(nudge: Nudge) -> Self {
Self {
action: GuardAction::Retry,
tool_calls: None,
nudge: Some(nudge),
reason: None,
}
}
pub fn step_blocked(nudge: Nudge) -> Self {
Self {
action: GuardAction::StepBlocked,
tool_calls: None,
nudge: Some(nudge),
reason: None,
}
}
pub fn fatal(reason: impl Into<String>) -> Self {
Self {
action: GuardAction::Fatal,
tool_calls: None,
nudge: None,
reason: Some(reason.into()),
}
}
}
/// Terminal-tool configuration accepted by `Guardrails::new`: either a single
/// tool name or a set of names.
///
/// Derives the same `Debug`/`Clone`/`PartialEq` set as the other public types
/// in this file (plus `Eq`), so configurations can be logged, copied and
/// compared.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum TerminalTool {
    /// Exactly one terminal tool, given by name.
    Single(String),
    /// Several terminal tools, given as an insertion-ordered set of names.
    Multiple(IndexSet<String>),
}
/// Bundles the individual guardrail components that a model response is run
/// through: response validation, step / prerequisite enforcement, retry and
/// error bookkeeping, and an optional tool-call classifier.
pub struct Guardrails {
/// Retry / tool-error bookkeeping, built from `max_retries` and
/// `max_tool_errors` in `Guardrails::new`.
pub error_tracker: ErrorTracker,
/// Validates a raw model response before any other check runs.
pub validator: ResponseValidator,
/// Runs the step and prerequisite checks on validated tool calls and
/// records executed tools.
pub step_enforcer: StepEnforcer,
/// Names of the terminal tools; `record` consults this set when deciding
/// whether the run is complete.
pub terminal_tools: IndexSet<String>,
/// Optional classifier, installed with `with_scorer`; `None` after `new`.
scorer: Option<Arc<dyn ToolCallScorer>>,
/// Scores produced by the most recent check; cleared at the start of each
/// check and exposed through `last_scores`.
last_scores: Vec<ToolCallScore>,
}
impl Guardrails {
#[allow(clippy::too_many_arguments)]
pub fn new(
tool_names: Vec<String>,
terminal_tool: TerminalTool,
required_steps: Option<Vec<String>>,
tool_prerequisites: Option<indexmap::IndexMap<String, Vec<StepPrerequisite>>>,
max_retries: i32,
max_tool_errors: i32,
rescue_enabled: bool,
max_premature_attempts: i32,
retry_nudge: Option<RetryNudgeFn>,
) -> Self {
let terminal_set: IndexSet<String> = match terminal_tool {
TerminalTool::Single(name) => {
let mut s = IndexSet::new();
s.insert(name);
s
}
TerminalTool::Multiple(names) => names,
};
let steps = required_steps.unwrap_or_default();
Self {
error_tracker: ErrorTracker::new(max_retries, max_tool_errors),
validator: ResponseValidator::new(tool_names, rescue_enabled, retry_nudge),
step_enforcer: StepEnforcer::new(
steps,
terminal_set.clone(),
tool_prerequisites,
max_premature_attempts,
2,
),
terminal_tools: terminal_set,
scorer: None,
last_scores: Vec::new(),
}
}
pub fn with_scorer(mut self, scorer: Arc<dyn ToolCallScorer>) -> Self {
self.scorer = Some(scorer);
self
}
pub fn last_scores(&self) -> &[ToolCallScore] {
&self.last_scores
}
pub fn check(&mut self, response: &LLMResponse) -> CheckResult {
self.check_inner(response, None)
}
pub fn check_with_scoring_context(
&mut self,
response: &LLMResponse,
scoring_context: &ScoringContext,
) -> CheckResult {
self.check_inner(response, Some(scoring_context))
}
fn check_inner(
&mut self,
response: &LLMResponse,
scoring_context: Option<&ScoringContext>,
) -> CheckResult {
self.last_scores.clear();
let validation = self.validator.validate(response);
if validation.needs_retry {
self.error_tracker.record_retry();
if self.error_tracker.retries_exhausted() {
return CheckResult::fatal("Too many bad responses");
}
return CheckResult::retry(validation.nudge.expect("needs_retry requires a nudge"));
}
self.error_tracker.reset_retries();
let tool_calls = validation
.tool_calls
.expect("valid response requires tool_calls");
let step_check = self.step_enforcer.check(&tool_calls);
if step_check.needs_nudge {
if self.step_enforcer.premature_exhausted() {
return CheckResult::fatal("Too many skipped required steps");
}
return CheckResult::step_blocked(
step_check.nudge.expect("needs_nudge requires a nudge"),
);
}
let prereq_check = self.step_enforcer.check_prerequisites(&tool_calls);
if prereq_check.needs_nudge {
if self.step_enforcer.prereq_exhausted() {
return CheckResult::fatal("Too many prerequisite violations");
}
return CheckResult::step_blocked(
prereq_check.nudge.expect("needs_nudge requires a nudge"),
);
}
if let (Some(scorer), Some(ctx)) = (self.scorer.as_ref(), scoring_context) {
let mut classifier_nudge: Option<Nudge> = None;
for call in &tool_calls {
match scorer.score(ctx, call) {
Ok(score) => {
tracing::info!(
target: "forge.classifier",
tool = %call.tool,
label = %score.label.as_label(),
confidence = score.confidence,
action = %score.action.as_str(),
latency_ms = score.latency_ms,
model_version = %score.model_version,
"tool-call classifier score"
);
if matches!(
score.action,
ClassifierAction::AdvisoryNudge | ClassifierAction::Block
) {
let label = score.label.as_label();
let nudge = Nudge::new(
"user",
nudges::classifier_nudge(label.as_ref()),
"classifier",
);
if score.action == ClassifierAction::Block || classifier_nudge.is_none()
{
classifier_nudge = Some(nudge);
}
}
self.last_scores.push(score);
}
Err(err) => {
tracing::warn!(
target: "forge.classifier",
tool = %call.tool,
error = %err,
"classifier scoring failed; allowing deterministic path"
);
}
}
}
if let Some(nudge) = classifier_nudge {
self.error_tracker.record_retry();
if self.error_tracker.retries_exhausted() {
return CheckResult::fatal("Too many classifier objections");
}
return CheckResult::retry(nudge);
}
}
CheckResult::execute(tool_calls)
}
pub fn record(&mut self, executed: &[&str]) -> bool {
for name in executed {
self.step_enforcer.record(name, None);
}
self.error_tracker.reset_retries();
self.error_tracker.reset_errors();
self.step_enforcer.reset_premature();
self.step_enforcer.reset_prereq_violations();
let has_terminal = executed
.iter()
.any(|name| self.terminal_tools.contains(*name));
has_terminal && self.step_enforcer.is_satisfied()
}
pub fn completed_steps(&self) -> indexmap::IndexMap<String, ()> {
self.step_enforcer.completed_steps()
}
pub fn pending_steps(&self) -> Vec<String> {
self.step_enforcer.pending()
}
pub fn premature_attempts(&self) -> i32 {
self.step_enforcer.premature_attempts()
}
}