1use aws_smithy_async::time::TimeSource;
7use aws_smithy_types::config_bag::{Storable, StoreReplace};
8use aws_smithy_types::retry::ErrorKind;
9use std::fmt;
10use std::sync::atomic::AtomicU32;
11use std::sync::atomic::Ordering;
12use std::sync::Arc;
13use std::time::{Duration, SystemTime};
14use tokio::sync::{OwnedSemaphorePermit, Semaphore};
15
/// Number of permits in a bucket when no capacity is configured.
const DEFAULT_CAPACITY: usize = 500;
/// Largest capacity a bucket may have: the builder clamps larger requests to this value, and
/// `TokenBucket::unlimited` starts with exactly this many permits.
pub const MAXIMUM_CAPACITY: usize = 500_000_000;
/// Default number of permits `acquire` takes for errors other than `ErrorKind::TransientError`.
const DEFAULT_RETRY_COST: u32 = 5;
/// Default number of permits `acquire` takes for `ErrorKind::TransientError`.
const DEFAULT_RETRY_TIMEOUT_COST: u32 = DEFAULT_RETRY_COST * 2;
/// Number of permits handed back to the bucket by `regenerate_a_token`.
const PERMIT_REGENERATION_AMOUNT: usize = 1;
/// Default fractional-token reward per success; `reward_success` does nothing while it is `0.0`.
const DEFAULT_SUCCESS_REWARD: f32 = 0.0;
28
/// Token bucket used to decide whether a retry may be attempted.
///
/// Whole tokens live in a semaphore; partial tokens earned through success rewards or
/// time-based refill accumulate in `fractional_tokens` until they add up to whole permits.
/// Cloning shares the semaphore, the fractional balance and the refill timestamp, because
/// all three are held behind `Arc`s.
#[derive(Clone, Debug)]
pub struct TokenBucket {
    /// Permits currently available for retries.
    semaphore: Arc<Semaphore>,
    /// Upper bound applied when permits are added back and when the fractional balance grows.
    max_permits: usize,
    /// Permits taken by `acquire` for `ErrorKind::TransientError`.
    timeout_retry_cost: u32,
    /// Permits taken by `acquire` for every other error kind.
    retry_cost: u32,
    /// Fractional tokens credited by each `reward_success` call (only when greater than zero).
    success_reward: f32,
    /// Fractional balance waiting to be converted into whole permits.
    fractional_tokens: Arc<AtomicF32>,
    /// Tokens credited per elapsed second; time-based refill is skipped unless this is positive.
    refill_rate: f32,
    /// Seconds since the UNIX epoch (truncated to `u32`) at which tokens were last refilled.
    last_refill_time_secs: Arc<AtomicU32>,
}
43
/// An `f32` that can be read and written through a shared reference.
///
/// The value is kept as its raw IEEE-754 bit pattern inside an [`AtomicU32`]; every access
/// uses `Ordering::Relaxed`, so individual loads and stores are atomic but a load followed by
/// a store is not a single atomic update.
struct AtomicF32 {
    storage: AtomicU32,
}

impl std::panic::UnwindSafe for AtomicF32 {}
impl std::panic::RefUnwindSafe for AtomicF32 {}

impl AtomicF32 {
    /// Creates a cell holding `value`.
    fn new(value: f32) -> Self {
        Self {
            storage: AtomicU32::new(f32::to_bits(value)),
        }
    }

    /// Overwrites the stored value with `value`.
    fn store(&self, value: f32) {
        self.storage.store(f32::to_bits(value), Ordering::Relaxed);
    }

    /// Returns the value most recently stored.
    fn load(&self) -> f32 {
        f32::from_bits(self.storage.load(Ordering::Relaxed))
    }
}
65
66impl fmt::Debug for AtomicF32 {
67 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
68 f.debug_struct("AtomicF32")
70 .field("value", &self.load())
71 .finish()
72 }
73}
74
75impl Clone for AtomicF32 {
76 fn clone(&self) -> Self {
77 AtomicF32 {
79 storage: AtomicU32::new(self.storage.load(Ordering::Relaxed)),
80 }
81 }
82}
83
// Lets a `TokenBucket` be stored as a `Storable` item.
// NOTE(review): `StoreReplace` presumably means a newly stored bucket replaces any previous
// one rather than being merged with it — confirm against `aws_smithy_types::config_bag`.
impl Storable for TokenBucket {
    type Storer = StoreReplace<Self>;
}
87
88impl Default for TokenBucket {
89 fn default() -> Self {
90 Self {
91 semaphore: Arc::new(Semaphore::new(DEFAULT_CAPACITY)),
92 max_permits: DEFAULT_CAPACITY,
93 timeout_retry_cost: DEFAULT_RETRY_TIMEOUT_COST,
94 retry_cost: DEFAULT_RETRY_COST,
95 success_reward: DEFAULT_SUCCESS_REWARD,
96 fractional_tokens: Arc::new(AtomicF32::new(0.0)),
97 refill_rate: 0.0,
98 last_refill_time_secs: Arc::new(AtomicU32::new(0)),
99 }
100 }
101}
102
103impl TokenBucket {
104 pub fn new(initial_quota: usize) -> Self {
106 Self {
107 semaphore: Arc::new(Semaphore::new(initial_quota)),
108 max_permits: initial_quota,
109 ..Default::default()
110 }
111 }
112
113 pub fn unlimited() -> Self {
115 Self {
116 semaphore: Arc::new(Semaphore::new(MAXIMUM_CAPACITY)),
117 max_permits: MAXIMUM_CAPACITY,
118 timeout_retry_cost: 0,
119 retry_cost: 0,
120 success_reward: 0.0,
121 fractional_tokens: Arc::new(AtomicF32::new(0.0)),
122 refill_rate: 0.0,
123 last_refill_time_secs: Arc::new(AtomicU32::new(0)),
124 }
125 }
126
127 pub fn builder() -> TokenBucketBuilder {
129 TokenBucketBuilder::default()
130 }
131
132 pub(crate) fn acquire(
133 &self,
134 err: &ErrorKind,
135 time_source: &impl TimeSource,
136 ) -> Option<OwnedSemaphorePermit> {
137 self.refill_tokens_based_on_time(time_source);
139 self.convert_fractional_tokens();
141
142 let retry_cost = if err == &ErrorKind::TransientError {
143 self.timeout_retry_cost
144 } else {
145 self.retry_cost
146 };
147
148 self.semaphore
149 .clone()
150 .try_acquire_many_owned(retry_cost)
151 .ok()
152 }
153
    /// Returns the fractional-token reward configured for this bucket.
    pub(crate) fn success_reward(&self) -> f32 {
        self.success_reward
    }
157
    /// Hands `PERMIT_REGENERATION_AMOUNT` permits back to the bucket through `add_permits`,
    /// which skips or trims the addition once the bucket already holds `max_permits`.
    pub(crate) fn regenerate_a_token(&self) {
        self.add_permits(PERMIT_REGENERATION_AMOUNT);
    }
161
162 #[inline]
166 fn convert_fractional_tokens(&self) {
167 let mut calc_fractional_tokens = self.fractional_tokens.load();
168 if !calc_fractional_tokens.is_finite() {
170 tracing::error!(
171 "Fractional tokens corrupted to: {}, resetting to 0.0",
172 calc_fractional_tokens
173 );
174 self.fractional_tokens.store(0.0);
175 return;
176 }
177
178 let full_tokens_accumulated = calc_fractional_tokens.floor();
179 if full_tokens_accumulated >= 1.0 {
180 self.add_permits(full_tokens_accumulated as usize);
181 calc_fractional_tokens -= full_tokens_accumulated;
182 }
183 self.fractional_tokens.store(calc_fractional_tokens);
185 }
186
187 #[inline]
191 fn refill_tokens_based_on_time(&self, time_source: &impl TimeSource) {
192 if self.refill_rate > 0.0 {
193 let current_time_secs = time_source
195 .now()
196 .duration_since(SystemTime::UNIX_EPOCH)
197 .unwrap_or(Duration::ZERO)
198 .as_secs() as u32;
199
200 let last_refill_secs = self.last_refill_time_secs.load(Ordering::Relaxed);
201
202 if current_time_secs == last_refill_secs {
204 return;
205 }
206
207 if self
210 .last_refill_time_secs
211 .compare_exchange(
212 last_refill_secs,
213 current_time_secs,
214 Ordering::Relaxed,
215 Ordering::Relaxed,
216 )
217 .is_err()
218 {
219 return;
221 }
222
223 let current_fractional = self.fractional_tokens.load();
225 let max_fractional = self.max_permits as f32;
226
227 if current_fractional >= max_fractional {
229 return;
230 }
231
232 let elapsed_secs = current_time_secs.saturating_sub(last_refill_secs);
233 let tokens_to_add = elapsed_secs as f32 * self.refill_rate;
234
235 let new_fractional = (current_fractional + tokens_to_add).min(max_fractional);
237 self.fractional_tokens.store(new_fractional);
238 }
239 }
240
241 #[inline]
242 pub(crate) fn reward_success(&self) {
243 if self.success_reward > 0.0 {
244 let current = self.fractional_tokens.load();
245 let max_fractional = self.max_permits as f32;
246 if current >= max_fractional {
248 return;
249 }
250 let new_fractional = (current + self.success_reward).min(max_fractional);
252 self.fractional_tokens.store(new_fractional);
253 }
254 }
255
256 pub(crate) fn add_permits(&self, amount: usize) {
257 let available = self.semaphore.available_permits();
258 if available >= self.max_permits {
259 return;
260 }
261 self.semaphore
262 .add_permits(amount.min(self.max_permits - available));
263 }
264
265 pub fn is_full(&self) -> bool {
267 self.convert_fractional_tokens();
268 self.semaphore.available_permits() >= self.max_permits
269 }
270
271 pub fn is_empty(&self) -> bool {
273 self.convert_fractional_tokens();
274 self.semaphore.available_permits() == 0
275 }
276
277 #[allow(dead_code)] #[cfg(any(test, feature = "test-util", feature = "legacy-test-util"))]
279 pub(crate) fn available_permits(&self) -> usize {
280 self.semaphore.available_permits()
281 }
282
283 #[allow(dead_code)]
285 #[doc(hidden)]
286 #[cfg(any(test, feature = "test-util", feature = "legacy-test-util"))]
287 pub fn last_refill_time_secs(&self) -> Arc<AtomicU32> {
288 self.last_refill_time_secs.clone()
289 }
290}
291
/// Builder for [`TokenBucket`]. Every option left as `None` falls back to its default when
/// `build` is called.
#[derive(Clone, Debug, Default)]
pub struct TokenBucketBuilder {
    /// Initial and maximum number of permits.
    capacity: Option<usize>,
    /// Permits charged for retrying errors other than `ErrorKind::TransientError`.
    retry_cost: Option<u32>,
    /// Permits charged for retrying `ErrorKind::TransientError`.
    timeout_retry_cost: Option<u32>,
    /// Fractional tokens credited per successful request.
    success_reward: Option<f32>,
    /// Fractional tokens credited per elapsed second.
    refill_rate: Option<f32>,
}
301
302impl TokenBucketBuilder {
    /// Creates a builder with no options set.
    pub fn new() -> Self {
        Self::default()
    }
307
308 pub fn capacity(mut self, mut capacity: usize) -> Self {
310 if capacity > MAXIMUM_CAPACITY {
311 capacity = MAXIMUM_CAPACITY;
312 }
313 self.capacity = Some(capacity);
314 self
315 }
316
317 pub fn retry_cost(mut self, retry_cost: u32) -> Self {
319 self.retry_cost = Some(retry_cost);
320 self
321 }
322
323 pub fn timeout_retry_cost(mut self, timeout_retry_cost: u32) -> Self {
325 self.timeout_retry_cost = Some(timeout_retry_cost);
326 self
327 }
328
329 pub fn success_reward(mut self, reward: f32) -> Self {
331 self.success_reward = Some(reward);
332 self
333 }
334
335 pub fn refill_rate(mut self, rate: f32) -> Self {
340 let validated_rate = if rate.is_finite() { rate.max(0.0) } else { 0.0 };
341 self.refill_rate = Some(validated_rate);
342 self
343 }
344
345 pub fn build(self) -> TokenBucket {
347 TokenBucket {
348 semaphore: Arc::new(Semaphore::new(self.capacity.unwrap_or(DEFAULT_CAPACITY))),
349 max_permits: self.capacity.unwrap_or(DEFAULT_CAPACITY),
350 retry_cost: self.retry_cost.unwrap_or(DEFAULT_RETRY_COST),
351 timeout_retry_cost: self
352 .timeout_retry_cost
353 .unwrap_or(DEFAULT_RETRY_TIMEOUT_COST),
354 success_reward: self.success_reward.unwrap_or(DEFAULT_SUCCESS_REWARD),
355 fractional_tokens: Arc::new(AtomicF32::new(0.0)),
356 refill_rate: self.refill_rate.unwrap_or(0.0),
357 last_refill_time_secs: Arc::new(AtomicU32::new(0)),
358 }
359 }
360}
361
362#[cfg(test)]
363mod tests {
364
365 use super::*;
366 use aws_smithy_async::test_util::ManualTimeSource;
367 use std::{sync::LazyLock, time::UNIX_EPOCH};
368
    // Shared clock for tests that need a `TimeSource` but do not exercise time-based refill.
    // It starts at a fixed, arbitrary instant and none of the tests below advance it.
    static TIME_SOURCE: LazyLock<ManualTimeSource> =
        LazyLock::new(|| ManualTimeSource::new(UNIX_EPOCH + Duration::from_secs(12344321)));
371
372 #[test]
373 fn test_unlimited_token_bucket() {
374 let bucket = TokenBucket::unlimited();
375
376 assert!(bucket
378 .acquire(&ErrorKind::ThrottlingError, &*TIME_SOURCE)
379 .is_some());
380 assert!(bucket
381 .acquire(&ErrorKind::TransientError, &*TIME_SOURCE)
382 .is_some());
383
384 assert_eq!(bucket.max_permits, MAXIMUM_CAPACITY);
386
387 assert_eq!(bucket.retry_cost, 0);
389 assert_eq!(bucket.timeout_retry_cost, 0);
390
391 let mut permits = Vec::new();
393 for _ in 0..100 {
394 let permit = bucket.acquire(&ErrorKind::ThrottlingError, &*TIME_SOURCE);
395 assert!(permit.is_some());
396 permits.push(permit);
397 assert_eq!(MAXIMUM_CAPACITY, bucket.semaphore.available_permits());
399 }
400 }
401
402 #[test]
403 fn test_bounded_permits_exhaustion() {
404 let bucket = TokenBucket::new(10);
405 let mut permits = Vec::new();
406
407 for _ in 0..100 {
408 let permit = bucket.acquire(&ErrorKind::ThrottlingError, &*TIME_SOURCE);
409 if let Some(p) = permit {
410 permits.push(p);
411 } else {
412 break;
413 }
414 }
415
416 assert_eq!(permits.len(), 2); assert!(bucket
420 .acquire(&ErrorKind::ThrottlingError, &*TIME_SOURCE)
421 .is_none());
422 }
423
424 #[test]
425 fn test_fractional_tokens_accumulate_and_convert() {
426 let bucket = TokenBucket::builder()
427 .capacity(10)
428 .success_reward(0.4)
429 .build();
430
431 let _hold_permit = bucket.acquire(&ErrorKind::TransientError, &*TIME_SOURCE);
433 assert_eq!(bucket.semaphore.available_permits(), 0);
434
435 bucket.reward_success();
437 bucket.convert_fractional_tokens();
438 assert_eq!(bucket.semaphore.available_permits(), 0);
439
440 bucket.reward_success();
442 bucket.convert_fractional_tokens();
443 assert_eq!(bucket.semaphore.available_permits(), 0);
444
445 bucket.reward_success();
447 bucket.convert_fractional_tokens();
448 assert_eq!(bucket.semaphore.available_permits(), 1);
449 }
450
451 #[test]
452 fn test_fractional_tokens_respect_max_capacity() {
453 let bucket = TokenBucket::builder()
454 .capacity(10)
455 .success_reward(2.0)
456 .build();
457
458 for _ in 0..20 {
459 bucket.reward_success();
460 }
461
462 assert!(bucket.semaphore.available_permits() == 10);
463 }
464
465 #[test]
466 fn test_convert_fractional_tokens() {
467 let test_cases = [
469 (0.7, 0, 0.7),
470 (1.0, 1, 0.0),
471 (2.3, 2, 0.3),
472 (5.8, 5, 0.8),
473 (10.0, 10, 0.0),
474 (f32::NAN, 0, 0.0),
476 (f32::INFINITY, 0, 0.0),
477 ];
478
479 for (input, expected_permits, expected_remaining) in test_cases {
480 let bucket = TokenBucket::builder().capacity(10).build();
481 let _hold_permit = bucket.acquire(&ErrorKind::TransientError, &*TIME_SOURCE);
482 let initial = bucket.semaphore.available_permits();
483
484 bucket.fractional_tokens.store(input);
485 bucket.convert_fractional_tokens();
486
487 assert_eq!(
488 bucket.semaphore.available_permits() - initial,
489 expected_permits
490 );
491 assert!((bucket.fractional_tokens.load() - expected_remaining).abs() < 0.0001);
492 }
493 }
494
495 #[cfg(any(feature = "test-util", feature = "legacy-test-util"))]
496 #[test]
497 fn test_builder_with_custom_values() {
498 let bucket = TokenBucket::builder()
499 .capacity(100)
500 .retry_cost(10)
501 .timeout_retry_cost(20)
502 .success_reward(0.5)
503 .refill_rate(2.5)
504 .build();
505
506 assert_eq!(bucket.max_permits, 100);
507 assert_eq!(bucket.retry_cost, 10);
508 assert_eq!(bucket.timeout_retry_cost, 20);
509 assert_eq!(bucket.success_reward, 0.5);
510 assert_eq!(bucket.refill_rate, 2.5);
511 }
512
513 #[test]
514 fn test_builder_refill_rate_validation() {
515 let bucket = TokenBucket::builder().refill_rate(-5.0).build();
517 assert_eq!(bucket.refill_rate, 0.0);
518
519 let bucket = TokenBucket::builder().refill_rate(1.5).build();
521 assert_eq!(bucket.refill_rate, 1.5);
522
523 let bucket = TokenBucket::builder().refill_rate(0.0).build();
525 assert_eq!(bucket.refill_rate, 0.0);
526 }
527
528 #[cfg(any(feature = "test-util", feature = "legacy-test-util"))]
529 #[test]
530 fn test_builder_custom_time_source() {
531 use aws_smithy_async::test_util::ManualTimeSource;
532 use std::time::UNIX_EPOCH;
533
534 let manual_time = ManualTimeSource::new(UNIX_EPOCH);
536 let bucket = TokenBucket::builder()
537 .capacity(100)
538 .refill_rate(1.0)
539 .build();
540
541 let _permits = bucket.semaphore.try_acquire_many(100).unwrap();
543 assert_eq!(bucket.available_permits(), 0);
544
545 manual_time.advance(Duration::from_secs(5));
547
548 bucket.refill_tokens_based_on_time(&manual_time);
549 bucket.convert_fractional_tokens();
550
551 assert_eq!(bucket.available_permits(), 5);
553 }
554
555 #[test]
556 fn test_atomicf32_f32_to_bits_conversion_correctness() {
557 let test_values = vec![
559 0.0,
560 -0.0,
561 1.0,
562 -1.0,
563 f32::INFINITY,
564 f32::NEG_INFINITY,
565 f32::NAN,
566 f32::MIN,
567 f32::MAX,
568 f32::MIN_POSITIVE,
569 f32::EPSILON,
570 std::f32::consts::PI,
571 std::f32::consts::E,
572 1.23456789e-38, 1.23456789e38, 1.1754944e-38, ];
577
578 for &expected in &test_values {
579 let atomic = AtomicF32::new(expected);
580 let actual = atomic.load();
581
582 if expected.is_nan() {
584 assert!(actual.is_nan(), "Expected NaN, got {}", actual);
585 assert_eq!(expected.to_bits(), actual.to_bits());
587 } else {
588 assert_eq!(expected.to_bits(), actual.to_bits());
589 }
590 }
591 }
592
593 #[cfg(any(feature = "test-util", feature = "legacy-test-util"))]
594 #[test]
595 fn test_atomicf32_store_load_preserves_exact_bits() {
596 let atomic = AtomicF32::new(0.0);
597
598 let critical_bit_patterns = vec![
601 0x00000000u32, 0x80000000u32, 0x7F800000u32, 0xFF800000u32, 0x7FC00000u32, 0x7FA00000u32, 0x00000001u32, 0x007FFFFFu32, 0x00800000u32, ];
611
612 for &expected_bits in &critical_bit_patterns {
613 let expected_f32 = f32::from_bits(expected_bits);
614 atomic.store(expected_f32);
615 let loaded_f32 = atomic.load();
616 let actual_bits = loaded_f32.to_bits();
617
618 assert_eq!(expected_bits, actual_bits);
619 }
620 }
621
622 #[cfg(any(feature = "test-util", feature = "legacy-test-util"))]
623 #[test]
624 fn test_atomicf32_concurrent_store_load_safety() {
625 use std::sync::Arc;
626 use std::thread;
627
628 let atomic = Arc::new(AtomicF32::new(0.0));
629 let test_values = vec![1.0, 2.0, 3.0, 4.0, 5.0];
630 let mut handles = Vec::new();
631
632 for &value in &test_values {
634 let atomic_clone = Arc::clone(&atomic);
635 let handle = thread::spawn(move || {
636 for _ in 0..1000 {
637 atomic_clone.store(value);
638 }
639 });
640 handles.push(handle);
641 }
642
643 let atomic_reader = Arc::clone(&atomic);
645 let reader_handle = thread::spawn(move || {
646 let mut readings = Vec::new();
647 for _ in 0..5000 {
648 let value = atomic_reader.load();
649 readings.push(value);
650 }
651 readings
652 });
653
654 for handle in handles {
656 handle.join().expect("Writer thread panicked");
657 }
658
659 let readings = reader_handle.join().expect("Reader thread panicked");
660
661 for &reading in &readings {
664 assert!(test_values.contains(&reading) || reading == 0.0);
665
666 assert!(
669 reading.is_finite() || reading == 0.0,
670 "Corrupted reading detected"
671 );
672 }
673 }
674
675 #[cfg(any(feature = "test-util", feature = "legacy-test-util"))]
676 #[test]
677 fn test_atomicf32_stress_concurrent_access() {
678 use std::sync::{Arc, Barrier};
679 use std::thread;
680
681 let expected_values = [0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0];
682 let atomic = Arc::new(AtomicF32::new(0.0));
683 let barrier = Arc::new(Barrier::new(10)); let mut handles = Vec::new();
685
686 for i in 0..10 {
688 let atomic_clone = Arc::clone(&atomic);
689 let barrier_clone = Arc::clone(&barrier);
690 let handle = thread::spawn(move || {
691 barrier_clone.wait(); for _ in 0..10000 {
695 let value = i as f32;
696 atomic_clone.store(value);
697 let loaded = atomic_clone.load();
698 assert!(loaded >= 0.0 && loaded <= 9.0);
700 assert!(
701 expected_values.contains(&loaded),
702 "Got unexpected value: {}, expected one of {:?}",
703 loaded,
704 expected_values
705 );
706 }
707 });
708 handles.push(handle);
709 }
710
711 for handle in handles {
712 handle.join().unwrap();
713 }
714 }
715
716 #[test]
717 fn test_atomicf32_integration_with_token_bucket_usage() {
718 let atomic = AtomicF32::new(0.0);
719 let success_reward = 0.3;
720 let iterations = 5;
721
722 for _ in 1..=iterations {
724 let current = atomic.load();
725 atomic.store(current + success_reward);
726 }
727
728 let accumulated = atomic.load();
729 let expected_total = iterations as f32 * success_reward; let full_tokens = accumulated.floor();
733 atomic.store(accumulated - full_tokens);
734 let remaining = atomic.load();
735
736 assert_eq!(full_tokens, expected_total.floor()); assert!(remaining >= 0.0 && remaining < 1.0);
739 assert_eq!(remaining, expected_total - expected_total.floor());
740 }
741
742 #[cfg(any(feature = "test-util", feature = "legacy-test-util"))]
743 #[test]
744 fn test_atomicf32_clone_creates_independent_copy() {
745 let original = AtomicF32::new(123.456);
746 let cloned = original.clone();
747
748 assert_eq!(original.load(), cloned.load());
750
751 original.store(999.0);
753 assert_eq!(
754 cloned.load(),
755 123.456,
756 "Clone should be unaffected by original changes"
757 );
758 assert_eq!(original.load(), 999.0, "Original should have new value");
759 }
760
761 #[test]
762 fn test_combined_time_and_success_rewards() {
763 use aws_smithy_async::test_util::ManualTimeSource;
764 use std::time::UNIX_EPOCH;
765
766 let time_source = ManualTimeSource::new(UNIX_EPOCH);
767 let current_time_secs = UNIX_EPOCH
768 .duration_since(SystemTime::UNIX_EPOCH)
769 .unwrap()
770 .as_secs() as u32;
771
772 let bucket = TokenBucket {
773 refill_rate: 1.0,
774 success_reward: 0.5,
775 last_refill_time_secs: Arc::new(AtomicU32::new(current_time_secs)),
776 semaphore: Arc::new(Semaphore::new(0)),
777 max_permits: 100,
778 ..Default::default()
779 };
780
781 bucket.reward_success();
783 bucket.reward_success();
784
785 time_source.advance(Duration::from_secs(2));
787
788 bucket.refill_tokens_based_on_time(&time_source);
791 bucket.convert_fractional_tokens();
792
793 assert_eq!(bucket.available_permits(), 3);
794 assert!(bucket.fractional_tokens.load().abs() < 0.0001);
795 }
796
797 #[test]
798 fn test_refill_rates() {
799 use aws_smithy_async::test_util::ManualTimeSource;
800 use std::time::UNIX_EPOCH;
801 let test_cases = [
803 (10.0, 2, 20, 0.0), (0.001, 1100, 1, 0.1), (0.0001, 11000, 1, 0.1), (0.001, 1200, 1, 0.2), (0.0001, 10000, 1, 0.0), (0.001, 500, 0, 0.5), ];
810
811 for (refill_rate, elapsed_secs, expected_permits, expected_fractional) in test_cases {
812 let time_source = ManualTimeSource::new(UNIX_EPOCH);
813 let current_time_secs = UNIX_EPOCH
814 .duration_since(SystemTime::UNIX_EPOCH)
815 .unwrap()
816 .as_secs() as u32;
817
818 let bucket = TokenBucket {
819 refill_rate,
820 last_refill_time_secs: Arc::new(AtomicU32::new(current_time_secs)),
821 semaphore: Arc::new(Semaphore::new(0)),
822 max_permits: 100,
823 ..Default::default()
824 };
825
826 time_source.advance(Duration::from_secs(elapsed_secs));
828
829 bucket.refill_tokens_based_on_time(&time_source);
830 bucket.convert_fractional_tokens();
831
832 assert_eq!(
833 bucket.available_permits(),
834 expected_permits,
835 "Rate {}: After {}s expected {} permits",
836 refill_rate,
837 elapsed_secs,
838 expected_permits
839 );
840 assert!(
841 (bucket.fractional_tokens.load() - expected_fractional).abs() < 0.0001,
842 "Rate {}: After {}s expected {} fractional, got {}",
843 refill_rate,
844 elapsed_secs,
845 expected_fractional,
846 bucket.fractional_tokens.load()
847 );
848 }
849 }
850
851 #[cfg(any(feature = "test-util", feature = "legacy-test-util"))]
852 #[test]
853 fn test_rewards_capped_at_max_capacity() {
854 use aws_smithy_async::test_util::ManualTimeSource;
855 use std::time::UNIX_EPOCH;
856
857 let time_source = ManualTimeSource::new(UNIX_EPOCH);
858 let current_time_secs = UNIX_EPOCH
859 .duration_since(SystemTime::UNIX_EPOCH)
860 .unwrap()
861 .as_secs() as u32;
862
863 let bucket = TokenBucket {
864 refill_rate: 50.0,
865 success_reward: 2.0,
866 last_refill_time_secs: Arc::new(AtomicU32::new(current_time_secs)),
867 semaphore: Arc::new(Semaphore::new(5)),
868 max_permits: 10,
869 ..Default::default()
870 };
871
872 for _ in 0..50 {
874 bucket.reward_success();
875 }
876
877 assert_eq!(bucket.fractional_tokens.load(), 10.0);
879
880 time_source.advance(Duration::from_secs(100));
882
883 bucket.refill_tokens_based_on_time(&time_source);
886
887 assert_eq!(
889 bucket.fractional_tokens.load(),
890 10.0,
891 "Fractional tokens should be capped at max_permits"
892 );
893 bucket.convert_fractional_tokens();
895 assert_eq!(bucket.available_permits(), 10);
896 }
897
898 #[cfg(any(feature = "test-util", feature = "legacy-test-util"))]
899 #[test]
900 fn test_concurrent_time_based_refill_no_over_generation() {
901 use aws_smithy_async::test_util::ManualTimeSource;
902 use std::sync::{Arc, Barrier};
903 use std::thread;
904 use std::time::UNIX_EPOCH;
905
906 let time_source = ManualTimeSource::new(UNIX_EPOCH);
907 let current_time_secs = UNIX_EPOCH
908 .duration_since(SystemTime::UNIX_EPOCH)
909 .unwrap()
910 .as_secs() as u32;
911
912 let bucket = Arc::new(TokenBucket {
914 refill_rate: 1.0,
915 last_refill_time_secs: Arc::new(AtomicU32::new(current_time_secs)),
916 semaphore: Arc::new(Semaphore::new(0)),
917 max_permits: 100,
918 ..Default::default()
919 });
920
921 time_source.advance(Duration::from_secs(10));
923 let shared_time_source = aws_smithy_async::time::SharedTimeSource::new(time_source);
924
925 let barrier = Arc::new(Barrier::new(100));
927 let mut handles = Vec::new();
928
929 for _ in 0..100 {
930 let bucket_clone1 = Arc::clone(&bucket);
931 let barrier_clone1 = Arc::clone(&barrier);
932 let time_source_clone1 = shared_time_source.clone();
933 let bucket_clone2 = Arc::clone(&bucket);
934 let barrier_clone2 = Arc::clone(&barrier);
935 let time_source_clone2 = shared_time_source.clone();
936
937 let handle1 = thread::spawn(move || {
938 barrier_clone1.wait();
940
941 bucket_clone1.refill_tokens_based_on_time(&time_source_clone1);
943 });
944
945 let handle2 = thread::spawn(move || {
946 barrier_clone2.wait();
948
949 bucket_clone2.refill_tokens_based_on_time(&time_source_clone2);
951 });
952 handles.push(handle1);
953 handles.push(handle2);
954 }
955
956 for handle in handles {
958 handle.join().unwrap();
959 }
960
961 bucket.convert_fractional_tokens();
963
964 assert_eq!(
967 bucket.available_permits(),
968 10,
969 "Only one thread should have added tokens, not all 100"
970 );
971
972 assert!(bucket.fractional_tokens.load().abs() < 0.0001);
974 }
975
976 #[test]
978 fn test_is_full_accounts_for_fractional_tokens() {
979 let bucket = TokenBucket::builder()
980 .capacity(2)
981 .retry_cost(1)
982 .success_reward(0.9)
983 .build();
984
985 assert!(bucket.is_full());
986
987 let _p1 = bucket
988 .acquire(&ErrorKind::ServerError, &*TIME_SOURCE)
989 .unwrap();
990 let _p2 = bucket
991 .acquire(&ErrorKind::ServerError, &*TIME_SOURCE)
992 .unwrap();
993
994 assert!(bucket.is_empty());
995
996 bucket.reward_success();
999 bucket.reward_success();
1000 bucket.reward_success();
1001
1002 assert!(bucket.is_full());
1005 assert!(!bucket.is_empty());
1006 }
1007
1008 #[test]
1009 fn test_is_empty_accounts_for_fractional_tokens() {
1010 let bucket = TokenBucket::builder()
1011 .capacity(10)
1012 .retry_cost(10)
1013 .success_reward(0.5)
1014 .build();
1015
1016 let _p = bucket
1017 .acquire(&ErrorKind::ServerError, &*TIME_SOURCE)
1018 .unwrap();
1019 assert_eq!(bucket.semaphore.available_permits(), 0);
1020
1021 bucket.reward_success();
1023 assert!(bucket.is_empty());
1024
1025 bucket.reward_success();
1027 assert!(!bucket.is_empty());
1028 }
1029}