use aws_smithy_async::time::TimeSource;
use aws_smithy_types::config_bag::{Storable, StoreReplace};
use aws_smithy_types::retry::ErrorKind;
use std::fmt;
use std::sync::atomic::AtomicU32;
use std::sync::atomic::Ordering;
use std::sync::Arc;
use std::time::{Duration, SystemTime};
use tokio::sync::{OwnedSemaphorePermit, Semaphore};
/// Default number of permits ("tokens") a standard bucket starts with.
const DEFAULT_CAPACITY: usize = 500;
/// Upper bound on bucket capacity; also the capacity used by `TokenBucket::unlimited()`.
pub const MAXIMUM_CAPACITY: usize = 500_000_000;
/// Permits withdrawn to retry a non-transient (e.g. throttling/server) error.
const DEFAULT_RETRY_COST: u32 = 5;
/// Permits withdrawn to retry a transient/timeout error (twice the normal cost).
const DEFAULT_RETRY_TIMEOUT_COST: u32 = DEFAULT_RETRY_COST * 2;
/// Permits returned to the bucket per regeneration event (see `regenerate_a_token`).
const PERMIT_REGENERATION_AMOUNT: usize = 1;
/// Fractional tokens granted per successful request by default (0.0 disables the reward).
const DEFAULT_SUCCESS_REWARD: f32 = 0.0;
/// A token bucket used for client-side retry throttling.
///
/// Retries withdraw whole tokens (semaphore permits); successful requests and
/// elapsed time can credit fractional tokens back, which are converted into
/// whole permits once they accumulate past 1.0.
#[derive(Clone, Debug)]
pub struct TokenBucket {
    /// Currently available whole tokens, held as semaphore permits.
    semaphore: Arc<Semaphore>,
    /// Maximum number of permits the bucket may hold.
    max_permits: usize,
    /// Permit cost of retrying after a transient (timeout-like) error.
    timeout_retry_cost: u32,
    /// Permit cost of retrying after any other error kind.
    retry_cost: u32,
    /// Fractional tokens credited per successful request (0.0 disables).
    success_reward: f32,
    /// Accumulated sub-token credit; converted to whole permits when >= 1.0.
    fractional_tokens: Arc<AtomicF32>,
    /// Tokens accrued per elapsed second (0.0 disables time-based refill).
    refill_rate: f32,
    /// Unix timestamp (seconds, truncated to u32) of the last time-based refill.
    last_refill_time_secs: Arc<AtomicU32>,
}
// `AtomicF32` only wraps an `AtomicU32`, so crossing an unwind boundary cannot
// leave it in a broken state. NOTE(review): these explicit impls look redundant —
// the auto traits should already apply structurally via `AtomicU32` — confirm
// before removing.
impl std::panic::UnwindSafe for AtomicF32 {}
impl std::panic::RefUnwindSafe for AtomicF32 {}
/// An `f32` stored atomically by keeping its IEEE-754 bit pattern in an `AtomicU32`.
struct AtomicF32 {
    // Raw bits of the float, as produced by `f32::to_bits`.
    storage: AtomicU32,
}
impl AtomicF32 {
    /// Creates a new atomic float holding `value`.
    fn new(value: f32) -> Self {
        Self {
            storage: AtomicU32::new(value.to_bits()),
        }
    }

    /// Atomically replaces the stored value (relaxed ordering).
    fn store(&self, value: f32) {
        self.storage.store(value.to_bits(), Ordering::Relaxed)
    }

    /// Atomically reads the stored value (relaxed ordering).
    fn load(&self) -> f32 {
        f32::from_bits(self.storage.load(Ordering::Relaxed))
    }
}
impl fmt::Debug for AtomicF32 {
    /// Renders as `AtomicF32 { value: <f32> }` using the decoded float rather
    /// than the raw bits.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let value = self.load();
        f.debug_struct("AtomicF32").field("value", &value).finish()
    }
}
impl Clone for AtomicF32 {
fn clone(&self) -> Self {
AtomicF32 {
storage: AtomicU32::new(self.storage.load(Ordering::Relaxed)),
}
}
}
// A bucket can be stored in a config bag; storing a new one replaces the old.
impl Storable for TokenBucket {
    type Storer = StoreReplace<Self>;
}
impl Default for TokenBucket {
fn default() -> Self {
Self {
semaphore: Arc::new(Semaphore::new(DEFAULT_CAPACITY)),
max_permits: DEFAULT_CAPACITY,
timeout_retry_cost: DEFAULT_RETRY_TIMEOUT_COST,
retry_cost: DEFAULT_RETRY_COST,
success_reward: DEFAULT_SUCCESS_REWARD,
fractional_tokens: Arc::new(AtomicF32::new(0.0)),
refill_rate: 0.0,
last_refill_time_secs: Arc::new(AtomicU32::new(0)),
}
}
}
impl TokenBucket {
pub fn new(initial_quota: usize) -> Self {
Self {
semaphore: Arc::new(Semaphore::new(initial_quota)),
max_permits: initial_quota,
..Default::default()
}
}
/// Creates an effectively unlimited bucket: maximum capacity with zero retry
/// costs, so acquiring never drains permits.
pub fn unlimited() -> Self {
    Self {
        semaphore: Arc::new(Semaphore::new(MAXIMUM_CAPACITY)),
        max_permits: MAXIMUM_CAPACITY,
        timeout_retry_cost: 0,
        retry_cost: 0,
        // Remaining fields (reward/refill disabled, fresh counters) match the defaults.
        ..Default::default()
    }
}
/// Returns a builder for configuring a custom bucket.
pub fn builder() -> TokenBucketBuilder {
    TokenBucketBuilder::new()
}
pub(crate) fn acquire(
&self,
err: &ErrorKind,
time_source: &impl TimeSource,
) -> Option<OwnedSemaphorePermit> {
self.refill_tokens_based_on_time(time_source);
self.convert_fractional_tokens();
let retry_cost = if err == &ErrorKind::TransientError {
self.timeout_retry_cost
} else {
self.retry_cost
};
self.semaphore
.clone()
.try_acquire_many_owned(retry_cost)
.ok()
}
/// Returns the fractional token reward credited per successful request.
pub(crate) fn success_reward(&self) -> f32 {
    self.success_reward
}
/// Returns a single permit to the bucket (clamped at `max_permits` by `add_permits`).
pub(crate) fn regenerate_a_token(&self) {
    self.add_permits(PERMIT_REGENERATION_AMOUNT);
}
/// Converts accumulated fractional tokens into whole semaphore permits.
///
/// The whole-number part of `fractional_tokens` is moved into the semaphore
/// (clamped to capacity by `add_permits`), leaving only the sub-1.0 remainder
/// stored. A non-finite value (NaN/inf) is treated as corruption, logged, and
/// reset to 0.0.
///
/// NOTE(review): the load/modify/store sequence below is not a single atomic
/// operation, so a concurrent `reward_success`/refill landing between the load
/// and the store can be overwritten — presumably acceptable as a lossy
/// heuristic; confirm.
#[inline]
fn convert_fractional_tokens(&self) {
    let mut calc_fractional_tokens = self.fractional_tokens.load();
    if !calc_fractional_tokens.is_finite() {
        tracing::error!(
            "Fractional tokens corrupted to: {}, resetting to 0.0",
            calc_fractional_tokens
        );
        self.fractional_tokens.store(0.0);
        return;
    }
    let full_tokens_accumulated = calc_fractional_tokens.floor();
    if full_tokens_accumulated >= 1.0 {
        // `add_permits` clamps to `max_permits`, so an oversized cast result is harmless.
        self.add_permits(full_tokens_accumulated as usize);
        calc_fractional_tokens -= full_tokens_accumulated;
    }
    self.fractional_tokens.store(calc_fractional_tokens);
}
/// Accrues fractional tokens based on the wall-clock time elapsed since the
/// last refill, at `refill_rate` tokens per second. No-op when the rate is 0.
///
/// A compare-exchange on `last_refill_time_secs` ensures that for any given
/// second boundary only one caller performs the accrual; racing callers return
/// without adding tokens. The accumulated total is capped at `max_permits`.
///
/// NOTE(review): timestamps are truncated to `u32` seconds since the Unix
/// epoch (wraps in 2106), and a `now()` before the epoch is clamped to zero —
/// confirm this is acceptable for the intended deployment horizon.
#[inline]
fn refill_tokens_based_on_time(&self, time_source: &impl TimeSource) {
    if self.refill_rate > 0.0 {
        let current_time_secs = time_source
            .now()
            .duration_since(SystemTime::UNIX_EPOCH)
            .unwrap_or(Duration::ZERO)
            .as_secs() as u32;
        let last_refill_secs = self.last_refill_time_secs.load(Ordering::Relaxed);
        // Still within the same second: nothing to accrue.
        if current_time_secs == last_refill_secs {
            return;
        }
        // Claim this refill window; if another thread won the race, bail out.
        if self
            .last_refill_time_secs
            .compare_exchange(
                last_refill_secs,
                current_time_secs,
                Ordering::Relaxed,
                Ordering::Relaxed,
            )
            .is_err()
        {
            return;
        }
        let current_fractional = self.fractional_tokens.load();
        let max_fractional = self.max_permits as f32;
        // Already at capacity; the timestamp was still advanced above.
        if current_fractional >= max_fractional {
            return;
        }
        let elapsed_secs = current_time_secs.saturating_sub(last_refill_secs);
        let tokens_to_add = elapsed_secs as f32 * self.refill_rate;
        let new_fractional = (current_fractional + tokens_to_add).min(max_fractional);
        self.fractional_tokens.store(new_fractional);
    }
}
/// Credits `success_reward` fractional tokens after a successful request,
/// capped at `max_permits`. No-op when the reward is not strictly positive
/// (which also skips NaN, since `NaN > 0.0` is false).
#[inline]
pub(crate) fn reward_success(&self) {
    if self.success_reward > 0.0 {
        let cap = self.max_permits as f32;
        let before = self.fractional_tokens.load();
        if before < cap {
            let after = (before + self.success_reward).min(cap);
            self.fractional_tokens.store(after);
        }
    }
}
/// Returns up to `amount` permits to the semaphore, never letting the
/// available count exceed `max_permits`.
pub(crate) fn add_permits(&self, amount: usize) {
    let headroom = self
        .max_permits
        .saturating_sub(self.semaphore.available_permits());
    if headroom > 0 {
        self.semaphore.add_permits(amount.min(headroom));
    }
}
/// Returns `true` when every permit (including any just-converted fractional
/// tokens) is available.
pub fn is_full(&self) -> bool {
    self.convert_fractional_tokens();
    let available = self.semaphore.available_permits();
    available >= self.max_permits
}
/// Returns `true` when no whole permits remain after converting accumulated
/// fractional tokens.
pub fn is_empty(&self) -> bool {
    self.convert_fractional_tokens();
    let available = self.semaphore.available_permits();
    available == 0
}
/// Test-only accessor for the number of currently available permits.
#[allow(dead_code)] #[cfg(any(test, feature = "test-util", feature = "legacy-test-util"))]
pub(crate) fn available_permits(&self) -> usize {
    self.semaphore.available_permits()
}
/// Test-only accessor for the shared last-refill timestamp (seconds since the
/// Unix epoch, truncated to `u32`).
#[allow(dead_code)]
#[doc(hidden)]
#[cfg(any(test, feature = "test-util", feature = "legacy-test-util"))]
pub fn last_refill_time_secs(&self) -> Arc<AtomicU32> {
    self.last_refill_time_secs.clone()
}
}
/// Builder for [`TokenBucket`]; any unset field falls back to the module defaults.
#[derive(Clone, Debug, Default)]
pub struct TokenBucketBuilder {
    // Maximum/initial permit count (clamped to `MAXIMUM_CAPACITY` by `capacity`).
    capacity: Option<usize>,
    // Permit cost of a standard retry.
    retry_cost: Option<u32>,
    // Permit cost of a transient/timeout retry.
    timeout_retry_cost: Option<u32>,
    // Fractional tokens credited per successful request.
    success_reward: Option<f32>,
    // Tokens accrued per second of elapsed time.
    refill_rate: Option<f32>,
}
impl TokenBucketBuilder {
/// Creates a builder with no options set (equivalent to `Default::default()`).
pub fn new() -> Self {
    Self::default()
}
/// Sets the bucket capacity, clamped to `MAXIMUM_CAPACITY`.
pub fn capacity(mut self, capacity: usize) -> Self {
    self.capacity = Some(capacity.min(MAXIMUM_CAPACITY));
    self
}
pub fn retry_cost(mut self, retry_cost: u32) -> Self {
self.retry_cost = Some(retry_cost);
self
}
pub fn timeout_retry_cost(mut self, timeout_retry_cost: u32) -> Self {
self.timeout_retry_cost = Some(timeout_retry_cost);
self
}
/// Sets the fractional token reward granted per successful request.
///
/// Mirrors the validation applied by [`Self::refill_rate`]: non-finite
/// (NaN/infinite) or negative rewards are coerced to `0.0`, which disables
/// the reward. Previously an infinite reward was accepted and would fill the
/// bucket's fractional accumulator to capacity on the first success, and the
/// bucket elsewhere treats non-finite fractional values as corruption.
pub fn success_reward(mut self, reward: f32) -> Self {
    let validated_reward = if reward.is_finite() { reward.max(0.0) } else { 0.0 };
    self.success_reward = Some(validated_reward);
    self
}
/// Sets the time-based refill rate in tokens per second. Non-finite or
/// non-positive rates are coerced to `0.0`, disabling time-based refill.
pub fn refill_rate(mut self, rate: f32) -> Self {
    self.refill_rate = Some(match rate {
        r if r.is_finite() && r > 0.0 => r,
        _ => 0.0,
    });
    self
}
/// Consumes the builder, producing a `TokenBucket` with defaults filled in
/// for every unset option.
pub fn build(self) -> TokenBucket {
    // Compute the capacity once; it seeds both the semaphore and the cap.
    let max_permits = self.capacity.unwrap_or(DEFAULT_CAPACITY);
    TokenBucket {
        semaphore: Arc::new(Semaphore::new(max_permits)),
        max_permits,
        retry_cost: self.retry_cost.unwrap_or(DEFAULT_RETRY_COST),
        timeout_retry_cost: self.timeout_retry_cost.unwrap_or(DEFAULT_RETRY_TIMEOUT_COST),
        success_reward: self.success_reward.unwrap_or(DEFAULT_SUCCESS_REWARD),
        fractional_tokens: Arc::new(AtomicF32::new(0.0)),
        refill_rate: self.refill_rate.unwrap_or(0.0),
        last_refill_time_secs: Arc::new(AtomicU32::new(0)),
    }
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use aws_smithy_async::test_util::ManualTimeSource;
    use std::{sync::LazyLock, time::UNIX_EPOCH};

    // Shared fixed-time source for tests that do not exercise time-based refill.
    static TIME_SOURCE: LazyLock<ManualTimeSource> =
        LazyLock::new(|| ManualTimeSource::new(UNIX_EPOCH + Duration::from_secs(12344321)));

    // An unlimited bucket has zero retry costs, so acquiring never drains permits.
    #[test]
    fn test_unlimited_token_bucket() {
        let bucket = TokenBucket::unlimited();
        assert!(bucket
            .acquire(&ErrorKind::ThrottlingError, &*TIME_SOURCE)
            .is_some());
        assert!(bucket
            .acquire(&ErrorKind::TransientError, &*TIME_SOURCE)
            .is_some());
        assert_eq!(bucket.max_permits, MAXIMUM_CAPACITY);
        assert_eq!(bucket.retry_cost, 0);
        assert_eq!(bucket.timeout_retry_cost, 0);
        let mut permits = Vec::new();
        for _ in 0..100 {
            let permit = bucket.acquire(&ErrorKind::ThrottlingError, &*TIME_SOURCE);
            assert!(permit.is_some());
            permits.push(permit);
            // Zero-cost acquisitions leave the available count untouched.
            assert_eq!(MAXIMUM_CAPACITY, bucket.semaphore.available_permits());
        }
    }

    // Capacity 10 with default throttling cost 5 allows exactly two retries.
    #[test]
    fn test_bounded_permits_exhaustion() {
        let bucket = TokenBucket::new(10);
        let mut permits = Vec::new();
        for _ in 0..100 {
            let permit = bucket.acquire(&ErrorKind::ThrottlingError, &*TIME_SOURCE);
            if let Some(p) = permit {
                permits.push(p);
            } else {
                break;
            }
        }
        assert_eq!(permits.len(), 2);
        assert!(bucket
            .acquire(&ErrorKind::ThrottlingError, &*TIME_SOURCE)
            .is_none());
    }

    // Three 0.4 rewards accumulate past 1.0 and yield a single whole permit.
    #[test]
    fn test_fractional_tokens_accumulate_and_convert() {
        let bucket = TokenBucket::builder()
            .capacity(10)
            .success_reward(0.4)
            .build();
        // Transient retry costs 10 (the default), draining the bucket entirely.
        let _hold_permit = bucket.acquire(&ErrorKind::TransientError, &*TIME_SOURCE);
        assert_eq!(bucket.semaphore.available_permits(), 0);
        bucket.reward_success();
        bucket.convert_fractional_tokens();
        assert_eq!(bucket.semaphore.available_permits(), 0);
        bucket.reward_success();
        bucket.convert_fractional_tokens();
        assert_eq!(bucket.semaphore.available_permits(), 0);
        bucket.reward_success();
        bucket.convert_fractional_tokens();
        assert_eq!(bucket.semaphore.available_permits(), 1);
    }

    // Rewards never push the available permit count past the configured capacity.
    #[test]
    fn test_fractional_tokens_respect_max_capacity() {
        let bucket = TokenBucket::builder()
            .capacity(10)
            .success_reward(2.0)
            .build();
        for _ in 0..20 {
            bucket.reward_success();
        }
        assert!(bucket.semaphore.available_permits() == 10);
    }

    // Table-driven check of whole/fractional splitting, including corruption reset.
    #[test]
    fn test_convert_fractional_tokens() {
        // (stored fractional value, expected permits added, expected remainder)
        let test_cases = [
            (0.7, 0, 0.7),
            (1.0, 1, 0.0),
            (2.3, 2, 0.3),
            (5.8, 5, 0.8),
            (10.0, 10, 0.0),
            (f32::NAN, 0, 0.0),
            (f32::INFINITY, 0, 0.0),
        ];
        for (input, expected_permits, expected_remaining) in test_cases {
            let bucket = TokenBucket::builder().capacity(10).build();
            // Drain the bucket so converted permits are observable.
            let _hold_permit = bucket.acquire(&ErrorKind::TransientError, &*TIME_SOURCE);
            let initial = bucket.semaphore.available_permits();
            bucket.fractional_tokens.store(input);
            bucket.convert_fractional_tokens();
            assert_eq!(
                bucket.semaphore.available_permits() - initial,
                expected_permits
            );
            assert!((bucket.fractional_tokens.load() - expected_remaining).abs() < 0.0001);
        }
    }

    // Every builder setter should land in the corresponding bucket field.
    #[cfg(any(feature = "test-util", feature = "legacy-test-util"))]
    #[test]
    fn test_builder_with_custom_values() {
        let bucket = TokenBucket::builder()
            .capacity(100)
            .retry_cost(10)
            .timeout_retry_cost(20)
            .success_reward(0.5)
            .refill_rate(2.5)
            .build();
        assert_eq!(bucket.max_permits, 100);
        assert_eq!(bucket.retry_cost, 10);
        assert_eq!(bucket.timeout_retry_cost, 20);
        assert_eq!(bucket.success_reward, 0.5);
        assert_eq!(bucket.refill_rate, 2.5);
    }

    // Negative refill rates are coerced to 0.0; valid rates pass through.
    #[test]
    fn test_builder_refill_rate_validation() {
        let bucket = TokenBucket::builder().refill_rate(-5.0).build();
        assert_eq!(bucket.refill_rate, 0.0);
        let bucket = TokenBucket::builder().refill_rate(1.5).build();
        assert_eq!(bucket.refill_rate, 1.5);
        let bucket = TokenBucket::builder().refill_rate(0.0).build();
        assert_eq!(bucket.refill_rate, 0.0);
    }

    // Advancing a manual clock by 5s at 1 token/s refills 5 permits.
    #[cfg(any(feature = "test-util", feature = "legacy-test-util"))]
    #[test]
    fn test_builder_custom_time_source() {
        use aws_smithy_async::test_util::ManualTimeSource;
        use std::time::UNIX_EPOCH;
        let manual_time = ManualTimeSource::new(UNIX_EPOCH);
        let bucket = TokenBucket::builder()
            .capacity(100)
            .refill_rate(1.0)
            .build();
        let _permits = bucket.semaphore.try_acquire_many(100).unwrap();
        assert_eq!(bucket.available_permits(), 0);
        manual_time.advance(Duration::from_secs(5));
        bucket.refill_tokens_based_on_time(&manual_time);
        bucket.convert_fractional_tokens();
        assert_eq!(bucket.available_permits(), 5);
    }

    // Round-tripping through AtomicF32 must preserve exact bit patterns,
    // including NaN, infinities, and subnormals.
    #[test]
    fn test_atomicf32_f32_to_bits_conversion_correctness() {
        let test_values = vec![
            0.0,
            -0.0,
            1.0,
            -1.0,
            f32::INFINITY,
            f32::NEG_INFINITY,
            f32::NAN,
            f32::MIN,
            f32::MAX,
            f32::MIN_POSITIVE,
            f32::EPSILON,
            std::f32::consts::PI,
            std::f32::consts::E,
            1.23456789e-38,
            1.23456789e38,
            1.1754944e-38,
        ];
        for &expected in &test_values {
            let atomic = AtomicF32::new(expected);
            let actual = atomic.load();
            if expected.is_nan() {
                assert!(actual.is_nan(), "Expected NaN, got {}", actual);
                assert_eq!(expected.to_bits(), actual.to_bits());
            } else {
                assert_eq!(expected.to_bits(), actual.to_bits());
            }
        }
    }

    // Store/load must preserve the exact raw bits of special IEEE-754 patterns.
    #[cfg(any(feature = "test-util", feature = "legacy-test-util"))]
    #[test]
    fn test_atomicf32_store_load_preserves_exact_bits() {
        let atomic = AtomicF32::new(0.0);
        // +0, -0, +inf, -inf, quiet NaN, signaling-style NaN, smallest
        // subnormal, largest subnormal, smallest normal.
        let critical_bit_patterns = vec![
            0x00000000u32,
            0x80000000u32,
            0x7F800000u32,
            0xFF800000u32,
            0x7FC00000u32,
            0x7FA00000u32,
            0x00000001u32,
            0x007FFFFFu32,
            0x00800000u32,
        ];
        for &expected_bits in &critical_bit_patterns {
            let expected_f32 = f32::from_bits(expected_bits);
            atomic.store(expected_f32);
            let loaded_f32 = atomic.load();
            let actual_bits = loaded_f32.to_bits();
            assert_eq!(expected_bits, actual_bits);
        }
    }

    // A reader racing several writers must only ever observe fully written values.
    #[cfg(any(feature = "test-util", feature = "legacy-test-util"))]
    #[test]
    fn test_atomicf32_concurrent_store_load_safety() {
        use std::sync::Arc;
        use std::thread;
        let atomic = Arc::new(AtomicF32::new(0.0));
        let test_values = vec![1.0, 2.0, 3.0, 4.0, 5.0];
        let mut handles = Vec::new();
        for &value in &test_values {
            let atomic_clone = Arc::clone(&atomic);
            let handle = thread::spawn(move || {
                for _ in 0..1000 {
                    atomic_clone.store(value);
                }
            });
            handles.push(handle);
        }
        let atomic_reader = Arc::clone(&atomic);
        let reader_handle = thread::spawn(move || {
            let mut readings = Vec::new();
            for _ in 0..5000 {
                let value = atomic_reader.load();
                readings.push(value);
            }
            readings
        });
        for handle in handles {
            handle.join().expect("Writer thread panicked");
        }
        let readings = reader_handle.join().expect("Reader thread panicked");
        for &reading in &readings {
            // Every reading is either the initial 0.0 or one of the written values.
            assert!(test_values.contains(&reading) || reading == 0.0);
            assert!(
                reading.is_finite() || reading == 0.0,
                "Corrupted reading detected"
            );
        }
    }

    // Ten threads hammer the same atomic; no torn/partial values may appear.
    #[cfg(any(feature = "test-util", feature = "legacy-test-util"))]
    #[test]
    fn test_atomicf32_stress_concurrent_access() {
        use std::sync::{Arc, Barrier};
        use std::thread;
        let expected_values = [0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0];
        let atomic = Arc::new(AtomicF32::new(0.0));
        let barrier = Arc::new(Barrier::new(10));
        let mut handles = Vec::new();
        for i in 0..10 {
            let atomic_clone = Arc::clone(&atomic);
            let barrier_clone = Arc::clone(&barrier);
            let handle = thread::spawn(move || {
                barrier_clone.wait();
                for _ in 0..10000 {
                    let value = i as f32;
                    atomic_clone.store(value);
                    let loaded = atomic_clone.load();
                    assert!(loaded >= 0.0 && loaded <= 9.0);
                    assert!(
                        expected_values.contains(&loaded),
                        "Got unexpected value: {}, expected one of {:?}",
                        loaded,
                        expected_values
                    );
                }
            });
            handles.push(handle);
        }
        for handle in handles {
            handle.join().unwrap();
        }
    }

    // Simulates the bucket's accumulate-then-split usage pattern on a bare AtomicF32.
    #[test]
    fn test_atomicf32_integration_with_token_bucket_usage() {
        let atomic = AtomicF32::new(0.0);
        let success_reward = 0.3;
        let iterations = 5;
        for _ in 1..=iterations {
            let current = atomic.load();
            atomic.store(current + success_reward);
        }
        let accumulated = atomic.load();
        let expected_total = iterations as f32 * success_reward;
        let full_tokens = accumulated.floor();
        atomic.store(accumulated - full_tokens);
        let remaining = atomic.load();
        assert_eq!(full_tokens, expected_total.floor());
        assert!(remaining >= 0.0 && remaining < 1.0);
        assert_eq!(remaining, expected_total - expected_total.floor());
    }

    // Cloning copies the value; subsequent writes to the original don't leak through.
    #[cfg(any(feature = "test-util", feature = "legacy-test-util"))]
    #[test]
    fn test_atomicf32_clone_creates_independent_copy() {
        let original = AtomicF32::new(123.456);
        let cloned = original.clone();
        assert_eq!(original.load(), cloned.load());
        original.store(999.0);
        assert_eq!(
            cloned.load(),
            123.456,
            "Clone should be unaffected by original changes"
        );
        assert_eq!(original.load(), 999.0, "Original should have new value");
    }

    // 2 x 0.5 success rewards + 2s at 1 token/s = 3 whole permits, no remainder.
    #[test]
    fn test_combined_time_and_success_rewards() {
        use aws_smithy_async::test_util::ManualTimeSource;
        use std::time::UNIX_EPOCH;
        let time_source = ManualTimeSource::new(UNIX_EPOCH);
        let current_time_secs = UNIX_EPOCH
            .duration_since(SystemTime::UNIX_EPOCH)
            .unwrap()
            .as_secs() as u32;
        let bucket = TokenBucket {
            refill_rate: 1.0,
            success_reward: 0.5,
            last_refill_time_secs: Arc::new(AtomicU32::new(current_time_secs)),
            semaphore: Arc::new(Semaphore::new(0)),
            max_permits: 100,
            ..Default::default()
        };
        bucket.reward_success();
        bucket.reward_success();
        time_source.advance(Duration::from_secs(2));
        bucket.refill_tokens_based_on_time(&time_source);
        bucket.convert_fractional_tokens();
        assert_eq!(bucket.available_permits(), 3);
        assert!(bucket.fractional_tokens.load().abs() < 0.0001);
    }

    // Table-driven check of time-based refill across rates and durations.
    #[test]
    fn test_refill_rates() {
        use aws_smithy_async::test_util::ManualTimeSource;
        use std::time::UNIX_EPOCH;
        // (refill rate, elapsed seconds, expected whole permits, expected fractional remainder)
        let test_cases = [
            (10.0, 2, 20, 0.0),
            (0.001, 1100, 1, 0.1),
            (0.0001, 11000, 1, 0.1),
            (0.001, 1200, 1, 0.2),
            (0.0001, 10000, 1, 0.0),
            (0.001, 500, 0, 0.5),
        ];
        for (refill_rate, elapsed_secs, expected_permits, expected_fractional) in test_cases {
            let time_source = ManualTimeSource::new(UNIX_EPOCH);
            let current_time_secs = UNIX_EPOCH
                .duration_since(SystemTime::UNIX_EPOCH)
                .unwrap()
                .as_secs() as u32;
            let bucket = TokenBucket {
                refill_rate,
                last_refill_time_secs: Arc::new(AtomicU32::new(current_time_secs)),
                semaphore: Arc::new(Semaphore::new(0)),
                max_permits: 100,
                ..Default::default()
            };
            time_source.advance(Duration::from_secs(elapsed_secs));
            bucket.refill_tokens_based_on_time(&time_source);
            bucket.convert_fractional_tokens();
            assert_eq!(
                bucket.available_permits(),
                expected_permits,
                "Rate {}: After {}s expected {} permits",
                refill_rate,
                elapsed_secs,
                expected_permits
            );
            assert!(
                (bucket.fractional_tokens.load() - expected_fractional).abs() < 0.0001,
                "Rate {}: After {}s expected {} fractional, got {}",
                refill_rate,
                elapsed_secs,
                expected_fractional,
                bucket.fractional_tokens.load()
            );
        }
    }

    // Neither success rewards nor aggressive time refill may exceed max_permits.
    #[cfg(any(feature = "test-util", feature = "legacy-test-util"))]
    #[test]
    fn test_rewards_capped_at_max_capacity() {
        use aws_smithy_async::test_util::ManualTimeSource;
        use std::time::UNIX_EPOCH;
        let time_source = ManualTimeSource::new(UNIX_EPOCH);
        let current_time_secs = UNIX_EPOCH
            .duration_since(SystemTime::UNIX_EPOCH)
            .unwrap()
            .as_secs() as u32;
        let bucket = TokenBucket {
            refill_rate: 50.0,
            success_reward: 2.0,
            last_refill_time_secs: Arc::new(AtomicU32::new(current_time_secs)),
            semaphore: Arc::new(Semaphore::new(5)),
            max_permits: 10,
            ..Default::default()
        };
        for _ in 0..50 {
            bucket.reward_success();
        }
        assert_eq!(bucket.fractional_tokens.load(), 10.0);
        time_source.advance(Duration::from_secs(100));
        bucket.refill_tokens_based_on_time(&time_source);
        assert_eq!(
            bucket.fractional_tokens.load(),
            10.0,
            "Fractional tokens should be capped at max_permits"
        );
        bucket.convert_fractional_tokens();
        assert_eq!(bucket.available_permits(), 10);
    }

    // 200 threads race the same refill window; the CAS must let exactly one win,
    // so only one window's worth of tokens (10) is generated.
    #[cfg(any(feature = "test-util", feature = "legacy-test-util"))]
    #[test]
    fn test_concurrent_time_based_refill_no_over_generation() {
        use aws_smithy_async::test_util::ManualTimeSource;
        use std::sync::{Arc, Barrier};
        use std::thread;
        use std::time::UNIX_EPOCH;
        let time_source = ManualTimeSource::new(UNIX_EPOCH);
        let current_time_secs = UNIX_EPOCH
            .duration_since(SystemTime::UNIX_EPOCH)
            .unwrap()
            .as_secs() as u32;
        let bucket = Arc::new(TokenBucket {
            refill_rate: 1.0,
            last_refill_time_secs: Arc::new(AtomicU32::new(current_time_secs)),
            semaphore: Arc::new(Semaphore::new(0)),
            max_permits: 100,
            ..Default::default()
        });
        time_source.advance(Duration::from_secs(10));
        let shared_time_source = aws_smithy_async::time::SharedTimeSource::new(time_source);
        let barrier = Arc::new(Barrier::new(100));
        let mut handles = Vec::new();
        for _ in 0..100 {
            let bucket_clone1 = Arc::clone(&bucket);
            let barrier_clone1 = Arc::clone(&barrier);
            let time_source_clone1 = shared_time_source.clone();
            let bucket_clone2 = Arc::clone(&bucket);
            let barrier_clone2 = Arc::clone(&barrier);
            let time_source_clone2 = shared_time_source.clone();
            let handle1 = thread::spawn(move || {
                barrier_clone1.wait();
                bucket_clone1.refill_tokens_based_on_time(&time_source_clone1);
            });
            let handle2 = thread::spawn(move || {
                barrier_clone2.wait();
                bucket_clone2.refill_tokens_based_on_time(&time_source_clone2);
            });
            handles.push(handle1);
            handles.push(handle2);
        }
        for handle in handles {
            handle.join().unwrap();
        }
        bucket.convert_fractional_tokens();
        assert_eq!(
            bucket.available_permits(),
            10,
            "Only one thread should have added tokens, not all 100"
        );
        assert!(bucket.fractional_tokens.load().abs() < 0.0001);
    }

    // is_full converts pending fractional tokens before checking capacity.
    #[test]
    fn test_is_full_accounts_for_fractional_tokens() {
        let bucket = TokenBucket::builder()
            .capacity(2)
            .retry_cost(1)
            .success_reward(0.9)
            .build();
        assert!(bucket.is_full());
        let _p1 = bucket
            .acquire(&ErrorKind::ServerError, &*TIME_SOURCE)
            .unwrap();
        let _p2 = bucket
            .acquire(&ErrorKind::ServerError, &*TIME_SOURCE)
            .unwrap();
        assert!(bucket.is_empty());
        // 3 x 0.9 = 2.7, capped at capacity 2 -> two whole permits back.
        bucket.reward_success();
        bucket.reward_success();
        bucket.reward_success();
        assert!(bucket.is_full());
        assert!(!bucket.is_empty());
    }

    // is_empty converts pending fractional tokens before checking for zero permits.
    #[test]
    fn test_is_empty_accounts_for_fractional_tokens() {
        let bucket = TokenBucket::builder()
            .capacity(10)
            .retry_cost(10)
            .success_reward(0.5)
            .build();
        let _p = bucket
            .acquire(&ErrorKind::ServerError, &*TIME_SOURCE)
            .unwrap();
        assert_eq!(bucket.semaphore.available_permits(), 0);
        // One 0.5 reward is not yet a whole token...
        bucket.reward_success();
        assert!(bucket.is_empty());
        // ...but a second reward crosses 1.0 and converts to a permit.
        bucket.reward_success();
        assert!(!bucket.is_empty());
    }
}