use super::nudge::Nudge;
use super::policy::GuardrailState;
use crate::clients::base::ToolCall;
use crate::core::steps::{Prerequisite, StepTracker};
use crate::prompts::nudges;
use indexmap::{IndexMap, IndexSet};
/// A prerequisite attached to a tool in the enforcer's configuration.
///
/// Each variant is converted one-to-one into the tracker-side `Prerequisite`
/// through the `From<&StepPrerequisite>` impl in this file.
///
/// All payloads are plain `String`s, so the type is totally comparable and
/// hashable; `Eq` and `Hash` are derived so values can be de-duplicated in
/// sets or used as map keys.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum StepPrerequisite {
    /// A prerequisite identified by a name only.
    NameOnly(String),
    /// A prerequisite identified by a tool name together with the name of an
    /// argument to match on (matching itself is done by the step tracker).
    ArgMatched {
        /// Name of the prerequisite tool.
        tool: String,
        /// Name of the argument used for matching.
        match_arg: String,
    },
}
impl From<&StepPrerequisite> for Prerequisite {
fn from(sp: &StepPrerequisite) -> Self {
match sp {
StepPrerequisite::NameOnly(name) => Prerequisite::NameOnly(name.clone()),
StepPrerequisite::ArgMatched { tool, match_arg } => Prerequisite::ArgMatched {
tool: tool.clone(),
match_arg: match_arg.clone(),
},
}
}
}
/// Result of an enforcement check: either a pass, or a block carrying the
/// nudge that should be delivered.
#[derive(Debug, Clone, PartialEq)]
pub struct StepCheck {
/// The nudge to deliver; `None` for results built by `StepCheck::ok`,
/// `Some` for results built by `StepCheck::blocked`.
pub nudge: Option<Nudge>,
/// `true` for results built by `StepCheck::blocked`, `false` for
/// results built by `StepCheck::ok`.
pub needs_nudge: bool,
}
impl StepCheck {
    /// Builds a passing result: no nudge attached and `needs_nudge` cleared.
    pub fn ok() -> Self {
        let nudge = None;
        Self {
            needs_nudge: false,
            nudge,
        }
    }

    /// Builds a blocking result that carries `nudge` and sets `needs_nudge`.
    pub fn blocked(nudge: Nudge) -> Self {
        let nudge = Some(nudge);
        Self {
            needs_nudge: true,
            nudge,
        }
    }
}
/// Guards terminal tool calls behind required steps and per-tool
/// prerequisites, counting how often each rule is violated.
pub struct StepEnforcer {
/// Tracker built from the required steps; consulted for satisfaction,
/// pending steps and prerequisite checks, and updated by `record`.
pub tracker: StepTracker,
/// Names of tools treated as terminal; `check` blocks a batch that
/// contains one of them while the tracker is not yet satisfied.
pub terminal_tools: IndexSet<String>,
/// Prerequisites per tool, keyed by tool name.
pub tool_prerequisites: IndexMap<String, Vec<StepPrerequisite>>,
/// Threshold that `premature_attempts` must exceed for
/// `premature_exhausted` to report `true`.
pub max_premature_attempts: i32,
/// Threshold that `prereq_violations` must exceed for
/// `prereq_exhausted` to report `true`.
pub max_prereq_violations: i32,
/// Number of blocked terminal attempts counted by `check` since
/// construction or the last `reset_premature`.
pub premature_attempts: i32,
/// Number of failed prerequisite checks counted by
/// `check_prerequisites` since construction or the last
/// `reset_prereq_violations`.
pub prereq_violations: i32,
}
impl StepEnforcer {
pub fn new(
required_steps: Vec<String>,
terminal_tools: IndexSet<String>,
tool_prerequisites: Option<IndexMap<String, Vec<StepPrerequisite>>>,
max_premature_attempts: i32,
max_prereq_violations: i32,
) -> Self {
Self {
tracker: StepTracker::new(required_steps),
terminal_tools,
tool_prerequisites: tool_prerequisites.unwrap_or_default(),
max_premature_attempts,
max_prereq_violations,
premature_attempts: 0,
prereq_violations: 0,
}
}
pub fn check(&mut self, tool_calls: &[ToolCall]) -> StepCheck {
if self.tracker.is_satisfied() {
return StepCheck::ok();
}
let has_terminal = tool_calls
.iter()
.any(|c| self.terminal_tools.contains(&c.tool));
if !has_terminal {
return StepCheck::ok();
}
self.premature_attempts += 1;
let tier = std::cmp::min(self.premature_attempts, 3);
let pending = self.tracker.pending();
let pending_refs: Vec<&str> = pending.iter().map(|s| s.as_str()).collect();
let mixed_terminal_batch = tool_calls
.iter()
.any(|c| !self.terminal_tools.contains(&c.tool));
if mixed_terminal_batch {
let blocked: Vec<&str> = self.terminal_tools.iter().map(|s| s.as_str()).collect();
let content = nudges::unsafe_batch_nudge(&pending_refs, &blocked);
let nudge = Nudge::new("user", content, "unsafe_batch").with_tier(tier);
return StepCheck::blocked(nudge);
}
let terminal_name = tool_calls
.iter()
.find(|c| self.terminal_tools.contains(&c.tool))
.map(|c| c.tool.as_str())
.unwrap_or("terminal");
let content = nudges::step_nudge(terminal_name, &pending_refs, tier);
let nudge = Nudge::new("user", content, "step").with_tier(tier);
StepCheck::blocked(nudge)
}
pub fn check_prerequisites(&mut self, tool_calls: &[ToolCall]) -> StepCheck {
for tc in tool_calls {
if let Some(prereqs) = self.tool_prerequisites.get(&tc.tool) {
let rust_prereqs: Vec<Prerequisite> = prereqs.iter().map(|p| p.into()).collect();
let result = self
.tracker
.check_prerequisites(&tc.tool, &tc.args, &rust_prereqs);
if !result.satisfied {
self.prereq_violations += 1;
let missing_refs: Vec<&str> =
result.missing.iter().map(|s| s.as_str()).collect();
let content = nudges::prerequisite_nudge(&tc.tool, &missing_refs);
let nudge = Nudge::new("user", content, "prerequisite");
return StepCheck::blocked(nudge);
}
}
}
StepCheck::ok()
}
pub fn record(&mut self, tool_name: &str, args: Option<&IndexMap<String, serde_json::Value>>) {
self.tracker.record(tool_name, args);
}
pub fn is_satisfied(&self) -> bool {
self.tracker.is_satisfied()
}
pub fn pending(&self) -> Vec<String> {
self.tracker.pending()
}
pub fn terminal_reached(&self, tool_calls: &[ToolCall]) -> bool {
let has_terminal = tool_calls
.iter()
.any(|c| self.terminal_tools.contains(&c.tool));
has_terminal && self.tracker.is_satisfied()
}
pub fn reset_premature(&mut self) {
self.premature_attempts = 0;
}
pub fn reset_prereq_violations(&mut self) {
self.prereq_violations = 0;
}
pub fn summary_hint(&self) -> String {
self.tracker.summary_hint()
}
pub fn guardrail_state(&self, tool_names: &[String]) -> GuardrailState {
let completed_steps = self.completed_steps().keys().cloned().collect();
GuardrailState::from_parts(
completed_steps,
self.pending(),
tool_names,
&self.terminal_tools,
)
}
pub fn premature_attempts(&self) -> i32 {
self.premature_attempts
}
pub fn premature_exhausted(&self) -> bool {
self.premature_attempts > self.max_premature_attempts
}
pub fn prereq_violations(&self) -> i32 {
self.prereq_violations
}
pub fn prereq_exhausted(&self) -> bool {
self.prereq_violations > self.max_prereq_violations
}
pub fn completed_steps(&self) -> IndexMap<String, ()> {
let all_required = self.tracker.required_steps();
let pending: IndexSet<String> = self.tracker.pending().into_iter().collect();
all_required
.iter()
.filter(|s| !pending.contains(s.as_str()))
.map(|s| (s.clone(), ()))
.collect()
}
}