use std::sync::atomic::{AtomicU32, Ordering};
/// Counts tool calls against an optional upper bound.
///
/// The counter is an [`AtomicU32`] and every method takes `&self`, so a single
/// tracker can be updated from several places without `&mut` access. All
/// accesses use `Ordering::Relaxed`: only the atomicity of the counter itself
/// is guaranteed, not ordering relative to other memory.
#[derive(Debug)]
pub struct ToolCallLimitTracker {
    /// Maximum number of calls allowed; `None` means unlimited.
    limit: Option<u32>,
    /// Number of calls recorded so far; saturates at `u32::MAX`.
    count: AtomicU32,
}

impl ToolCallLimitTracker {
    /// Creates a tracker with a count of zero.
    ///
    /// `limit` is the number of calls at which the tracker reports itself as
    /// exceeded; `None` disables the limit entirely.
    pub fn new(limit: Option<u32>) -> Self {
        Self {
            limit,
            count: AtomicU32::new(0),
        }
    }

    /// Records `n` additional tool calls.
    ///
    /// The count saturates at `u32::MAX` instead of wrapping. A plain
    /// `fetch_add` wraps on overflow, which would drop the count back below the
    /// limit and make [`is_exceeded`](Self::is_exceeded) report `false` again.
    pub fn increment(&self, n: u32) {
        // The closure always returns `Some`, so `fetch_update` never yields `Err`;
        // it simply retries its compare-and-swap until the update lands.
        let _ = self
            .count
            .fetch_update(Ordering::Relaxed, Ordering::Relaxed, |current| {
                Some(current.saturating_add(n))
            });
    }

    /// Returns `true` once the recorded count has reached the limit.
    ///
    /// Always `false` when no limit is set. A limit of `0` counts as exceeded
    /// before any call is recorded.
    #[must_use]
    pub fn is_exceeded(&self) -> bool {
        match self.limit {
            None => false,
            Some(max) => self.count() >= max,
        }
    }

    /// Returns the number of calls recorded so far.
    #[must_use]
    pub fn count(&self) -> u32 {
        self.count.load(Ordering::Relaxed)
    }

    /// Returns the configured limit (`None` = unlimited).
    #[must_use]
    pub fn limit(&self) -> Option<u32> {
        self.limit
    }

    /// Resets the recorded count to zero; the limit is left unchanged.
    pub fn reset(&self) {
        self.count.store(0, Ordering::Relaxed);
    }

    /// Returns how many calls remain before the limit is reached.
    ///
    /// Returns `None` when unlimited. Returns `Some(0)`, never an underflowed
    /// value, once the count has reached or passed the limit.
    #[must_use]
    pub fn remaining(&self) -> Option<u32> {
        self.limit.map(|max| max.saturating_sub(self.count()))
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Five single-call increments against a limit of five trip the limit.
    #[test]
    fn test_tool_call_limit_tracker() {
        let tracker = ToolCallLimitTracker::new(Some(5));
        assert!(!tracker.is_exceeded());
        (0..5).for_each(|_| tracker.increment(1));
        assert!(tracker.is_exceeded());
    }

    /// Without a limit, no number of calls is ever reported as exceeding it.
    #[test]
    fn test_tool_call_limit_tracker_unlimited() {
        let tracker = ToolCallLimitTracker::new(None);
        (0..1000).for_each(|_| tracker.increment(1));
        assert!(!tracker.is_exceeded());
    }

    /// Multi-call increments accumulate toward the limit.
    #[test]
    fn test_tool_call_limit_tracker_batch_increment() {
        let tracker = ToolCallLimitTracker::new(Some(10));
        for batch in [3, 4] {
            tracker.increment(batch);
        }
        assert!(!tracker.is_exceeded());

        tracker.increment(3);
        assert!(tracker.is_exceeded());
    }

    /// A limit of zero is already exceeded before any call is recorded.
    #[test]
    fn test_tool_call_limit_zero_means_immediate() {
        assert!(ToolCallLimitTracker::new(Some(0)).is_exceeded());
    }

    /// `count` starts at zero and tracks increments; `limit` echoes the constructor.
    #[test]
    fn test_tool_call_limit_count_and_limit() {
        let tracker = ToolCallLimitTracker::new(Some(10));
        assert_eq!((tracker.count(), tracker.limit()), (0, Some(10)));

        tracker.increment(5);
        assert_eq!(tracker.count(), 5);
    }

    /// `reset` clears the count but keeps the configured limit.
    #[test]
    fn test_reset() {
        let tracker = ToolCallLimitTracker::new(Some(5));
        tracker.increment(5);
        assert!(tracker.is_exceeded());

        tracker.reset();
        assert_eq!(tracker.count(), 0);
        assert!(!tracker.is_exceeded());
        assert_eq!(tracker.remaining(), Some(5));
        assert_eq!(tracker.limit(), Some(5));
    }

    /// An unlimited tracker never reports a remaining budget.
    #[test]
    fn test_remaining_unlimited() {
        let tracker = ToolCallLimitTracker::new(None);
        assert!(tracker.remaining().is_none());

        tracker.increment(100);
        assert!(tracker.remaining().is_none());
    }

    /// The remaining budget shrinks with each increment and bottoms out at zero.
    #[test]
    fn test_remaining_with_limit() {
        let tracker = ToolCallLimitTracker::new(Some(10));
        assert_eq!(tracker.remaining(), Some(10));

        // (calls added, expected remaining afterwards); the final step overshoots the limit.
        for (added, expected) in [(3, 7), (7, 0), (5, 0)] {
            tracker.increment(added);
            assert_eq!(tracker.remaining(), Some(expected), "after adding {}", added);
        }
    }

    /// A zero limit leaves nothing remaining from the start.
    #[test]
    fn test_remaining_zero_limit() {
        assert_eq!(ToolCallLimitTracker::new(Some(0)).remaining(), Some(0));
    }
}