use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
use std::sync::Arc;
use async_trait::async_trait;
use crate::event::AgentEvent;
use crate::plugin::{EventObserver, Plugin, PluginCapabilities, SteeringSource};
use crate::types::AgentMessage;
/// Builds the wrap-up text at the moment the wrap-up message is injected.
type MessageProvider = Arc<dyn Fn() -> String + Send + Sync + 'static>;
/// Supplies the grace window (turns reserved for wrapping up); it is re-invoked
/// every time the soft limit is computed, so the window can change at runtime.
type GraceProvider = Arc<dyn Fn() -> usize + Send + Sync + 'static>;
/// Built-in wrap-up text injected by [`GracefulTurnLimit::from_hard_cap`].
///
/// It is tool-agnostic on purpose: it tells the model to stop calling work tools
/// and summarize, without naming any specific tool.
const DEFAULT_WRAP_UP_MESSAGE: &str = concat!(
    "You have used your turn budget. Stop calling work tools and deliver your ",
    "final result now. Summarize:\n",
    "- What you accomplished.\n",
    "- What remains unfinished, if anything.\n",
    "- Any partial findings the caller should know about.\n",
    "\n",
    "Then stop. Do not call any more work tools. Ask for input only if a user ",
    "answer is genuinely required before you can continue.",
);
pub struct GracefulTurnLimit {
max_iterations: usize,
default_grace: usize,
grace_provider: Option<GraceProvider>,
turns_completed: AtomicUsize,
fired: Arc<AtomicBool>,
message_provider: MessageProvider,
}
impl GracefulTurnLimit {
    /// Creates a limiter for a hard cap of `max_iterations` turns.
    ///
    /// The built-in wrap-up message (see
    /// [`GracefulTurnLimit::default_wrap_up_message`]) is injected once only
    /// `grace_iterations` turns remain.
    ///
    /// Returns `None` when `grace_iterations` is `0` or is not strictly smaller
    /// than `max_iterations`.
    pub fn from_hard_cap(max_iterations: usize, grace_iterations: usize) -> Option<Self> {
        let builtin_message: MessageProvider =
            Arc::new(|| DEFAULT_WRAP_UP_MESSAGE.to_string());
        Self::from_hard_cap_with_message_provider(
            max_iterations,
            grace_iterations,
            builtin_message,
        )
    }

    /// Like [`GracefulTurnLimit::from_hard_cap`], but the wrap-up text comes from
    /// `message_provider`, which is invoked when the message is injected.
    ///
    /// Returns `None` under the same conditions as `from_hard_cap`.
    pub fn from_hard_cap_with_message_provider(
        max_iterations: usize,
        grace_iterations: usize,
        message_provider: MessageProvider,
    ) -> Option<Self> {
        let static_grace_only: Option<GraceProvider> = None;
        Self::from_hard_cap_with_providers(
            max_iterations,
            grace_iterations,
            message_provider,
            static_grace_only,
        )
    }

    /// Fully configurable constructor.
    ///
    /// * `max_iterations`: hard turn cap the soft limit is derived from.
    /// * `default_grace`: turns reserved for wrapping up when `grace_provider` is
    ///   `None`. It must lie in `1..max_iterations`, otherwise `None` is returned.
    /// * `message_provider`: builds the wrap-up text when it is injected.
    /// * `grace_provider`: optional dynamic grace window. It is re-read every time
    ///   the soft limit is computed, and its result is clamped rather than
    ///   validated here.
    pub fn from_hard_cap_with_providers(
        max_iterations: usize,
        default_grace: usize,
        message_provider: MessageProvider,
        grace_provider: Option<GraceProvider>,
    ) -> Option<Self> {
        let grace_is_usable = default_grace > 0 && default_grace < max_iterations;
        grace_is_usable.then(|| Self {
            max_iterations,
            default_grace,
            grace_provider,
            turns_completed: AtomicUsize::new(0),
            fired: Arc::new(AtomicBool::new(false)),
            message_provider,
        })
    }

    /// Returns the built-in wrap-up text used by `from_hard_cap`.
    pub fn default_wrap_up_message() -> &'static str {
        DEFAULT_WRAP_UP_MESSAGE
    }

    /// Returns a shared handle to the one-shot flag that is set when the wrap-up
    /// message is emitted.
    pub fn signal(&self) -> Arc<AtomicBool> {
        Arc::clone(&self.fired)
    }

    /// Current grace window.
    ///
    /// Uses the dynamic provider when present, otherwise `default_grace`. The
    /// result is clamped into `1..=max_iterations - 1`; the `.max(1)` keeps the
    /// clamp bounds ordered even for degenerate caps.
    fn effective_grace(&self) -> usize {
        let requested = match &self.grace_provider {
            Some(provider) => provider(),
            None => self.default_grace,
        };
        let ceiling = self.max_iterations.saturating_sub(1).max(1);
        requested.clamp(1, ceiling)
    }

    /// Number of completed turns at which the wrap-up message becomes due.
    ///
    /// This is `max_iterations` minus the current grace window, and it is
    /// recomputed on every call.
    pub fn soft_limit(&self) -> usize {
        let grace = self.effective_grace();
        self.max_iterations.saturating_sub(grace)
    }
}
impl Plugin for GracefulTurnLimit {
    /// Identifier this plugin reports to the host.
    fn name(&self) -> &'static str {
        const PLUGIN_NAME: &str = "graceful_turn_limit";
        PLUGIN_NAME
    }

    /// Declares the two hooks this plugin implements and leaves every other
    /// capability at its default:
    /// * event observation, to count finished turns;
    /// * steering, to inject the wrap-up message.
    fn capabilities(&self) -> PluginCapabilities {
        let baseline = PluginCapabilities::default();
        PluginCapabilities {
            steering: true,
            event_observer: true,
            ..baseline
        }
    }
}
#[async_trait]
impl EventObserver for GracefulTurnLimit {
    /// Counts completed turns.
    ///
    /// Only `AgentEvent::TurnEnd` advances the counter; every other event is
    /// ignored, so the soft limit is measured in finished turns.
    async fn on_event(&self, event: &AgentEvent) {
        if let AgentEvent::TurnEnd { .. } = event {
            self.turns_completed.fetch_add(1, Ordering::Relaxed);
        }
    }
}
#[async_trait]
impl SteeringSource for GracefulTurnLimit {
    /// Emits the wrap-up system message exactly once.
    ///
    /// The message goes out on the first poll at which the completed-turn count
    /// is at least `soft_limit()`. The soft limit is recomputed on each poll, so a
    /// dynamic grace provider can move the threshold. Every other poll, before the
    /// limit or after the message was already emitted, returns an empty vector.
    async fn next_steering_messages(&self) -> Vec<AgentMessage> {
        let completed = self.turns_completed.load(Ordering::Relaxed);
        let budget_reached = completed >= self.soft_limit();
        // `swap` reads and sets the one-shot flag in one atomic step, so only a
        // single poll can observe `false` here and emit the message. The
        // short-circuit keeps the flag untouched until the budget is reached.
        let emit_now = budget_reached && !self.fired.swap(true, Ordering::Relaxed);
        if !emit_now {
            return Vec::new();
        }
        let wrap_up = AgentMessage::System {
            content: (self.message_provider)(),
            timestamp: None,
        };
        vec![wrap_up]
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn from_hard_cap_rejects_zero_grace() {
        assert!(GracefulTurnLimit::from_hard_cap(50, 0).is_none());
    }

    #[test]
    fn from_hard_cap_rejects_grace_at_or_above_cap() {
        assert!(GracefulTurnLimit::from_hard_cap(10, 10).is_none());
        assert!(GracefulTurnLimit::from_hard_cap(10, 99).is_none());
    }

    #[test]
    fn from_hard_cap_computes_soft_limit() {
        let plugin = GracefulTurnLimit::from_hard_cap(50, 5).unwrap();
        assert_eq!(plugin.soft_limit(), 45);
    }

    #[tokio::test]
    async fn does_not_fire_before_soft_limit() {
        let plugin = GracefulTurnLimit::from_hard_cap(10, 3).unwrap();
        for _ in 0..6 {
            plugin.on_event(&turn_end()).await;
        }
        let msgs = plugin.next_steering_messages().await;
        assert!(msgs.is_empty());
        assert!(!plugin.fired.load(Ordering::Relaxed));
    }

    #[tokio::test]
    async fn fires_once_at_soft_limit() {
        let plugin = GracefulTurnLimit::from_hard_cap(10, 3).unwrap();
        for _ in 0..7 {
            plugin.on_event(&turn_end()).await;
        }
        let first = plugin.next_steering_messages().await;
        assert_eq!(first.len(), 1, "should emit one wrap-up message");
        match &first[0] {
            AgentMessage::System { content, .. } => {
                assert!(content.starts_with("You have used your turn budget."))
            }
            other => panic!("expected system wrap-up message, got {other:?}"),
        }
        assert!(plugin.fired.load(Ordering::Relaxed));
        let second = plugin.next_steering_messages().await;
        assert!(second.is_empty(), "wrap-up must be one-shot");
    }

    /// Only `TurnEnd` advances the turn counter; agent starts and turn starts
    /// must not count toward the budget.
    #[tokio::test]
    async fn ignores_non_turn_end_events() {
        let plugin = GracefulTurnLimit::from_hard_cap(10, 3).unwrap();
        for _ in 0..20 {
            plugin.on_event(&AgentEvent::AgentStart).await;
            plugin.on_event(&AgentEvent::TurnStart).await;
        }
        assert_eq!(plugin.turns_completed.load(Ordering::Relaxed), 0);
        let msgs = plugin.next_steering_messages().await;
        assert!(msgs.is_empty());
    }

    #[tokio::test]
    async fn default_wrap_up_is_generic_and_directs_delivery() {
        let plugin = GracefulTurnLimit::from_hard_cap(10, 3).unwrap();
        for _ in 0..7 {
            plugin.on_event(&turn_end()).await;
        }
        let msgs = plugin.next_steering_messages().await;
        let AgentMessage::System { content: text, .. } = &msgs[0] else {
            panic!("expected system wrap-up message");
        };
        assert!(text.contains("deliver your final result"), "{text}");
        assert!(text.contains("Stop calling work tools"), "{text}");
        assert!(!text.contains("message_result"), "{text}");
        assert!(!text.contains("message_ask"), "{text}");
    }

    #[tokio::test]
    async fn default_wrap_up_message_accessor_matches_injected_text() {
        let plugin = GracefulTurnLimit::from_hard_cap(10, 3).unwrap();
        for _ in 0..7 {
            plugin.on_event(&turn_end()).await;
        }
        let msgs = plugin.next_steering_messages().await;
        let AgentMessage::System { content, .. } = &msgs[0] else {
            panic!("expected system wrap-up message");
        };
        assert_eq!(content, GracefulTurnLimit::default_wrap_up_message());
    }

    #[tokio::test]
    async fn wrap_up_uses_custom_message_provider_when_supplied() {
        let plugin = GracefulTurnLimit::from_hard_cap_with_message_provider(
            10,
            3,
            Arc::new(|| "custom wrap-up".to_string()),
        )
        .unwrap();
        for _ in 0..7 {
            plugin.on_event(&turn_end()).await;
        }
        let msgs = plugin.next_steering_messages().await;
        let AgentMessage::System { content, .. } = &msgs[0] else {
            panic!("expected system wrap-up message");
        };
        assert_eq!(content, "custom wrap-up");
    }

    #[tokio::test]
    async fn does_not_fire_before_first_completed_turn() {
        let plugin = GracefulTurnLimit::from_hard_cap(6, 5).unwrap();
        plugin.on_event(&AgentEvent::TurnStart).await;
        let msgs = plugin.next_steering_messages().await;
        assert!(msgs.is_empty());
        assert!(!plugin.fired.load(Ordering::Relaxed));
    }

    /// The handle returned by `signal()` shares state with the plugin, so an
    /// external observer sees the flag flip when the wrap-up is emitted.
    #[tokio::test]
    async fn signal_observes_the_one_shot_flag() {
        let plugin = GracefulTurnLimit::from_hard_cap(10, 3).unwrap();
        let signal = plugin.signal();
        assert!(!signal.load(Ordering::Relaxed));
        for _ in 0..7 {
            plugin.on_event(&turn_end()).await;
        }
        assert_eq!(plugin.next_steering_messages().await.len(), 1);
        assert!(signal.load(Ordering::Relaxed));
    }

    #[tokio::test]
    async fn dynamic_grace_provider_widens_wrap_up_window_for_bigger_jobs() {
        let grace = Arc::new(std::sync::Mutex::new(3usize));
        let grace_for_provider = grace.clone();
        let plugin = GracefulTurnLimit::from_hard_cap_with_providers(
            20,
            3,
            Arc::new(|| "wrap".to_string()),
            Some(Arc::new(move || *grace_for_provider.lock().unwrap())),
        )
        .unwrap();
        for _ in 0..16 {
            plugin.on_event(&turn_end()).await;
        }
        let early = plugin.next_steering_messages().await;
        assert!(early.is_empty(), "should not fire at 16 turns with grace=3");
        assert_eq!(plugin.soft_limit(), 17);
        *grace.lock().unwrap() = 8;
        assert_eq!(plugin.soft_limit(), 12);
        let fired = plugin.next_steering_messages().await;
        assert_eq!(
            fired.len(),
            1,
            "widened grace must let the plugin fire on the next poll"
        );
        assert!(plugin.fired.load(Ordering::Relaxed));
    }

    #[tokio::test]
    async fn dynamic_grace_provider_clamps_out_of_range_returns() {
        let plugin = GracefulTurnLimit::from_hard_cap_with_providers(
            10,
            3,
            Arc::new(|| "wrap".to_string()),
            Some(Arc::new(|| 0)),
        )
        .unwrap();
        assert_eq!(plugin.soft_limit(), 9);
        let plugin = GracefulTurnLimit::from_hard_cap_with_providers(
            10,
            3,
            Arc::new(|| "wrap".to_string()),
            Some(Arc::new(|| 999)),
        )
        .unwrap();
        assert_eq!(plugin.soft_limit(), 1);
    }

    /// Builds a minimal `TurnEnd` event (empty assistant text, tool-use stop) for
    /// driving the turn counter.
    fn turn_end() -> AgentEvent {
        AgentEvent::TurnEnd {
            message: AgentMessage::Assistant {
                content: crate::types::AssistantContent::text(""),
                stop_reason: crate::types::StopReason::ToolUse,
                error_message: None,
                timestamp: None,
                usage: None,
            },
            tool_results: Vec::new(),
        }
    }
}