Skip to main content

aws_smithy_runtime/client/retries/
token_bucket.rs

/*
 * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
 * SPDX-License-Identifier: Apache-2.0
 */

6use aws_smithy_async::time::TimeSource;
7use aws_smithy_types::config_bag::{Storable, StoreReplace};
8use aws_smithy_types::retry::ErrorKind;
9use std::fmt;
10use std::sync::atomic::AtomicU32;
11use std::sync::atomic::Ordering;
12use std::sync::Arc;
13use std::time::{Duration, SystemTime};
14use tokio::sync::{OwnedSemaphorePermit, Semaphore};
15
/// Permits a bucket starts with (and is capped at) unless configured otherwise.
const DEFAULT_CAPACITY: usize = 500;

// On a 32-bit architecture `Semaphore::MAX_PERMITS` is 536,870,911, so the cap is kept
// below that value to make behavior identical across platforms. The headroom also leaves
// room for slight overfill when a bucket is already at maximum capacity and another
// thread drops a permit it was holding.
/// The maximum number of permits a token bucket can have.
pub const MAXIMUM_CAPACITY: usize = 500_000_000;

/// Permits consumed when retrying an error that is not a transient error.
const DEFAULT_RETRY_COST: u32 = 5;
/// Permits consumed when retrying a transient error: twice the regular retry cost.
const DEFAULT_RETRY_TIMEOUT_COST: u32 = 2 * DEFAULT_RETRY_COST;
/// Permits handed back to the bucket by a single regeneration.
const PERMIT_REGENERATION_AMOUNT: usize = 1;
/// Fractional tokens credited per successful request; zero disables success rewards.
const DEFAULT_SUCCESS_REWARD: f32 = 0.0;
28
29/// Token bucket used for standard and adaptive retry.
30#[derive(Clone, Debug)]
31pub struct TokenBucket {
32    semaphore: Arc<Semaphore>,
33    max_permits: usize,
34    timeout_retry_cost: u32,
35    retry_cost: u32,
36    success_reward: f32,
37    fractional_tokens: Arc<AtomicF32>,
38    refill_rate: f32,
39    // Note this value is only an AtomicU32 so it works on 32bit powerpc architectures.
40    // If we ever remove the need for that compatibility it should become an AtomicU64
41    last_refill_time_secs: Arc<AtomicU32>,
42}
43
/// An `f32` that can be read and written from several threads at once.
///
/// The float is kept as its raw IEEE-754 bit pattern inside an `AtomicU32`, so every
/// value, including NaN payloads and signed zeros, round-trips bit-for-bit. All accesses
/// use `Ordering::Relaxed`.
struct AtomicF32 {
    storage: AtomicU32,
}

// Explicit unwind-safety markers: the only state is a single `AtomicU32`, which each
// operation updates in one atomic step.
impl std::panic::UnwindSafe for AtomicF32 {}
impl std::panic::RefUnwindSafe for AtomicF32 {}

impl AtomicF32 {
    /// Wraps `value` in a new atomic cell.
    fn new(value: f32) -> Self {
        AtomicF32 {
            storage: AtomicU32::new(f32::to_bits(value)),
        }
    }

    /// Replaces the stored value with `value`.
    fn store(&self, value: f32) {
        self.storage.store(f32::to_bits(value), Ordering::Relaxed)
    }

    /// Returns the currently stored value.
    fn load(&self) -> f32 {
        f32::from_bits(self.storage.load(Ordering::Relaxed))
    }
}

impl fmt::Debug for AtomicF32 {
    /// Formats as `AtomicF32 { value: <current value> }`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let value = self.load();
        f.debug_struct("AtomicF32").field("value", &value).finish()
    }
}

impl Clone for AtomicF32 {
    /// Snapshots the current bit pattern into a new, independent cell.
    fn clone(&self) -> Self {
        let bits = self.storage.load(Ordering::Relaxed);
        AtomicF32 {
            storage: AtomicU32::new(bits),
        }
    }
}
83
84impl Storable for TokenBucket {
85    type Storer = StoreReplace<Self>;
86}
87
88impl Default for TokenBucket {
89    fn default() -> Self {
90        Self {
91            semaphore: Arc::new(Semaphore::new(DEFAULT_CAPACITY)),
92            max_permits: DEFAULT_CAPACITY,
93            timeout_retry_cost: DEFAULT_RETRY_TIMEOUT_COST,
94            retry_cost: DEFAULT_RETRY_COST,
95            success_reward: DEFAULT_SUCCESS_REWARD,
96            fractional_tokens: Arc::new(AtomicF32::new(0.0)),
97            refill_rate: 0.0,
98            last_refill_time_secs: Arc::new(AtomicU32::new(0)),
99        }
100    }
101}
102
103impl TokenBucket {
104    /// Creates a new `TokenBucket` with the given initial quota.
105    pub fn new(initial_quota: usize) -> Self {
106        Self {
107            semaphore: Arc::new(Semaphore::new(initial_quota)),
108            max_permits: initial_quota,
109            ..Default::default()
110        }
111    }
112
113    /// A token bucket with unlimited capacity that allows retries at no cost.
114    pub fn unlimited() -> Self {
115        Self {
116            semaphore: Arc::new(Semaphore::new(MAXIMUM_CAPACITY)),
117            max_permits: MAXIMUM_CAPACITY,
118            timeout_retry_cost: 0,
119            retry_cost: 0,
120            success_reward: 0.0,
121            fractional_tokens: Arc::new(AtomicF32::new(0.0)),
122            refill_rate: 0.0,
123            last_refill_time_secs: Arc::new(AtomicU32::new(0)),
124        }
125    }
126
127    /// Creates a builder for constructing a `TokenBucket`.
128    pub fn builder() -> TokenBucketBuilder {
129        TokenBucketBuilder::default()
130    }
131
132    pub(crate) fn acquire(
133        &self,
134        err: &ErrorKind,
135        time_source: &impl TimeSource,
136    ) -> Option<OwnedSemaphorePermit> {
137        // Add time-based tokens to fractional accumulator
138        self.refill_tokens_based_on_time(time_source);
139        // Convert accumulated fractional tokens to whole tokens
140        self.convert_fractional_tokens();
141
142        let retry_cost = if err == &ErrorKind::TransientError {
143            self.timeout_retry_cost
144        } else {
145            self.retry_cost
146        };
147
148        self.semaphore
149            .clone()
150            .try_acquire_many_owned(retry_cost)
151            .ok()
152    }
153
154    pub(crate) fn success_reward(&self) -> f32 {
155        self.success_reward
156    }
157
158    pub(crate) fn regenerate_a_token(&self) {
159        self.add_permits(PERMIT_REGENERATION_AMOUNT);
160    }
161
162    /// Converts accumulated fractional tokens to whole tokens and adds them as permits.
163    /// Stores the remaining fractional amount back.
164    /// This is shared by both time-based refill and success rewards.
165    #[inline]
166    fn convert_fractional_tokens(&self) {
167        let mut calc_fractional_tokens = self.fractional_tokens.load();
168        // Verify that fractional tokens have not become corrupted - if they have, reset to zero
169        if !calc_fractional_tokens.is_finite() {
170            tracing::error!(
171                "Fractional tokens corrupted to: {}, resetting to 0.0",
172                calc_fractional_tokens
173            );
174            self.fractional_tokens.store(0.0);
175            return;
176        }
177
178        let full_tokens_accumulated = calc_fractional_tokens.floor();
179        if full_tokens_accumulated >= 1.0 {
180            self.add_permits(full_tokens_accumulated as usize);
181            calc_fractional_tokens -= full_tokens_accumulated;
182        }
183        // Always store the updated fractional tokens back, even if no conversion happened
184        self.fractional_tokens.store(calc_fractional_tokens);
185    }
186
187    /// Refills tokens based on elapsed time since last refill.
188    /// This method implements lazy evaluation - tokens are only calculated when accessed.
189    /// Uses a single compare-and-swap to ensure only one thread processes each time window.
190    #[inline]
191    fn refill_tokens_based_on_time(&self, time_source: &impl TimeSource) {
192        if self.refill_rate > 0.0 {
193            // The cast to u32 here is safe until 2106, and I will be long dead then so ¯\_(ツ)_/¯
194            let current_time_secs = time_source
195                .now()
196                .duration_since(SystemTime::UNIX_EPOCH)
197                .unwrap_or(Duration::ZERO)
198                .as_secs() as u32;
199
200            let last_refill_secs = self.last_refill_time_secs.load(Ordering::Relaxed);
201
202            // Early exit if no time elapsed - most threads take this path
203            if current_time_secs == last_refill_secs {
204                return;
205            }
206
207            // Try to atomically claim this time window with a single CAS
208            // If we lose, another thread is handling the refill, so we can exit
209            if self
210                .last_refill_time_secs
211                .compare_exchange(
212                    last_refill_secs,
213                    current_time_secs,
214                    Ordering::Relaxed,
215                    Ordering::Relaxed,
216                )
217                .is_err()
218            {
219                // Another thread claimed this time window, we're done
220                return;
221            }
222
223            // We won the CAS - we're responsible for adding tokens for this time window
224            let current_fractional = self.fractional_tokens.load();
225            let max_fractional = self.max_permits as f32;
226
227            // Skip token addition if already at cap
228            if current_fractional >= max_fractional {
229                return;
230            }
231
232            let elapsed_secs = current_time_secs.saturating_sub(last_refill_secs);
233            let tokens_to_add = elapsed_secs as f32 * self.refill_rate;
234
235            // Add tokens to fractional accumulator, capping at max_permits to prevent unbounded growth
236            let new_fractional = (current_fractional + tokens_to_add).min(max_fractional);
237            self.fractional_tokens.store(new_fractional);
238        }
239    }
240
241    #[inline]
242    pub(crate) fn reward_success(&self) {
243        if self.success_reward > 0.0 {
244            let current = self.fractional_tokens.load();
245            let max_fractional = self.max_permits as f32;
246            // Early exit if already at cap - no point calculating
247            if current >= max_fractional {
248                return;
249            }
250            // Cap fractional tokens at max_permits to prevent unbounded growth
251            let new_fractional = (current + self.success_reward).min(max_fractional);
252            self.fractional_tokens.store(new_fractional);
253        }
254    }
255
256    pub(crate) fn add_permits(&self, amount: usize) {
257        let available = self.semaphore.available_permits();
258        if available >= self.max_permits {
259            return;
260        }
261        self.semaphore
262            .add_permits(amount.min(self.max_permits - available));
263    }
264
265    /// Returns true if the token bucket is full, false otherwise
266    pub fn is_full(&self) -> bool {
267        self.convert_fractional_tokens();
268        self.semaphore.available_permits() >= self.max_permits
269    }
270
271    /// Returns true if the token bucket is empty, false otherwise
272    pub fn is_empty(&self) -> bool {
273        self.convert_fractional_tokens();
274        self.semaphore.available_permits() == 0
275    }
276
277    #[allow(dead_code)] // only used in tests
278    #[cfg(any(test, feature = "test-util", feature = "legacy-test-util"))]
279    pub(crate) fn available_permits(&self) -> usize {
280        self.semaphore.available_permits()
281    }
282
283    /// Only used in tests
284    #[allow(dead_code)]
285    #[doc(hidden)]
286    #[cfg(any(test, feature = "test-util", feature = "legacy-test-util"))]
287    pub fn last_refill_time_secs(&self) -> Arc<AtomicU32> {
288        self.last_refill_time_secs.clone()
289    }
290}
291
/// Builder for constructing a `TokenBucket`.
///
/// Every setting is optional; any setting left unset falls back to a default value when
/// the bucket is built.
#[derive(Clone, Debug, Default)]
pub struct TokenBucketBuilder {
    /// Maximum (and initial) number of permits.
    capacity: Option<usize>,
    /// Permits consumed per retry of a non-transient error.
    retry_cost: Option<u32>,
    /// Permits consumed per retry of a transient error.
    timeout_retry_cost: Option<u32>,
    /// Fractional tokens credited per successful request.
    success_reward: Option<f32>,
    /// Tokens regenerated per elapsed second.
    refill_rate: Option<f32>,
}
301
302impl TokenBucketBuilder {
303    /// Creates a new `TokenBucketBuilder` with default values.
304    pub fn new() -> Self {
305        Self::default()
306    }
307
308    /// Sets the maximum bucket capacity for the builder.
309    pub fn capacity(mut self, mut capacity: usize) -> Self {
310        if capacity > MAXIMUM_CAPACITY {
311            capacity = MAXIMUM_CAPACITY;
312        }
313        self.capacity = Some(capacity);
314        self
315    }
316
317    /// Sets the specified retry cost for the builder.
318    pub fn retry_cost(mut self, retry_cost: u32) -> Self {
319        self.retry_cost = Some(retry_cost);
320        self
321    }
322
323    /// Sets the specified timeout retry cost for the builder.
324    pub fn timeout_retry_cost(mut self, timeout_retry_cost: u32) -> Self {
325        self.timeout_retry_cost = Some(timeout_retry_cost);
326        self
327    }
328
329    /// Sets the reward for any successful request for the builder.
330    pub fn success_reward(mut self, reward: f32) -> Self {
331        self.success_reward = Some(reward);
332        self
333    }
334
335    /// Sets the refill rate (tokens per second) for time-based token regeneration.
336    ///
337    /// Negative values are clamped to 0.0. A refill rate of 0.0 disables time-based regeneration.
338    /// Non-finite values (NaN, infinity) are treated as 0.0.
339    pub fn refill_rate(mut self, rate: f32) -> Self {
340        let validated_rate = if rate.is_finite() { rate.max(0.0) } else { 0.0 };
341        self.refill_rate = Some(validated_rate);
342        self
343    }
344
345    /// Builds a `TokenBucket`.
346    pub fn build(self) -> TokenBucket {
347        TokenBucket {
348            semaphore: Arc::new(Semaphore::new(self.capacity.unwrap_or(DEFAULT_CAPACITY))),
349            max_permits: self.capacity.unwrap_or(DEFAULT_CAPACITY),
350            retry_cost: self.retry_cost.unwrap_or(DEFAULT_RETRY_COST),
351            timeout_retry_cost: self
352                .timeout_retry_cost
353                .unwrap_or(DEFAULT_RETRY_TIMEOUT_COST),
354            success_reward: self.success_reward.unwrap_or(DEFAULT_SUCCESS_REWARD),
355            fractional_tokens: Arc::new(AtomicF32::new(0.0)),
356            refill_rate: self.refill_rate.unwrap_or(0.0),
357            last_refill_time_secs: Arc::new(AtomicU32::new(0)),
358        }
359    }
360}
361
362#[cfg(test)]
363mod tests {
364
365    use super::*;
366    use aws_smithy_async::test_util::ManualTimeSource;
367    use std::{sync::LazyLock, time::UNIX_EPOCH};
368
369    static TIME_SOURCE: LazyLock<ManualTimeSource> =
370        LazyLock::new(|| ManualTimeSource::new(UNIX_EPOCH + Duration::from_secs(12344321)));
371
372    #[test]
373    fn test_unlimited_token_bucket() {
374        let bucket = TokenBucket::unlimited();
375
376        // Should always acquire permits regardless of error type
377        assert!(bucket
378            .acquire(&ErrorKind::ThrottlingError, &*TIME_SOURCE)
379            .is_some());
380        assert!(bucket
381            .acquire(&ErrorKind::TransientError, &*TIME_SOURCE)
382            .is_some());
383
384        // Should have maximum capacity
385        assert_eq!(bucket.max_permits, MAXIMUM_CAPACITY);
386
387        // Should have zero retry costs
388        assert_eq!(bucket.retry_cost, 0);
389        assert_eq!(bucket.timeout_retry_cost, 0);
390
391        // The loop count is arbitrary; should obtain permits without limit
392        let mut permits = Vec::new();
393        for _ in 0..100 {
394            let permit = bucket.acquire(&ErrorKind::ThrottlingError, &*TIME_SOURCE);
395            assert!(permit.is_some());
396            permits.push(permit);
397            // Available permits should stay constant
398            assert_eq!(MAXIMUM_CAPACITY, bucket.semaphore.available_permits());
399        }
400    }
401
402    #[test]
403    fn test_bounded_permits_exhaustion() {
404        let bucket = TokenBucket::new(10);
405        let mut permits = Vec::new();
406
407        for _ in 0..100 {
408            let permit = bucket.acquire(&ErrorKind::ThrottlingError, &*TIME_SOURCE);
409            if let Some(p) = permit {
410                permits.push(p);
411            } else {
412                break;
413            }
414        }
415
416        assert_eq!(permits.len(), 2); // 10 capacity / 5 retry cost = 2 permits
417
418        // Verify next acquisition fails
419        assert!(bucket
420            .acquire(&ErrorKind::ThrottlingError, &*TIME_SOURCE)
421            .is_none());
422    }
423
424    #[test]
425    fn test_fractional_tokens_accumulate_and_convert() {
426        let bucket = TokenBucket::builder()
427            .capacity(10)
428            .success_reward(0.4)
429            .build();
430
431        // acquire 10 tokens to bring capacity below max so we can test accumulation
432        let _hold_permit = bucket.acquire(&ErrorKind::TransientError, &*TIME_SOURCE);
433        assert_eq!(bucket.semaphore.available_permits(), 0);
434
435        // First success: 0.4 fractional tokens
436        bucket.reward_success();
437        bucket.convert_fractional_tokens();
438        assert_eq!(bucket.semaphore.available_permits(), 0);
439
440        // Second success: 0.8 fractional tokens
441        bucket.reward_success();
442        bucket.convert_fractional_tokens();
443        assert_eq!(bucket.semaphore.available_permits(), 0);
444
445        // Third success: 1.2 fractional tokens -> 1 full token added
446        bucket.reward_success();
447        bucket.convert_fractional_tokens();
448        assert_eq!(bucket.semaphore.available_permits(), 1);
449    }
450
451    #[test]
452    fn test_fractional_tokens_respect_max_capacity() {
453        let bucket = TokenBucket::builder()
454            .capacity(10)
455            .success_reward(2.0)
456            .build();
457
458        for _ in 0..20 {
459            bucket.reward_success();
460        }
461
462        assert!(bucket.semaphore.available_permits() == 10);
463    }
464
465    #[test]
466    fn test_convert_fractional_tokens() {
467        // (input, expected_permits_added, expected_remaining)
468        let test_cases = [
469            (0.7, 0, 0.7),
470            (1.0, 1, 0.0),
471            (2.3, 2, 0.3),
472            (5.8, 5, 0.8),
473            (10.0, 10, 0.0),
474            // verify that if fractional permits are corrupted, we reset to 0 gracefully
475            (f32::NAN, 0, 0.0),
476            (f32::INFINITY, 0, 0.0),
477        ];
478
479        for (input, expected_permits, expected_remaining) in test_cases {
480            let bucket = TokenBucket::builder().capacity(10).build();
481            let _hold_permit = bucket.acquire(&ErrorKind::TransientError, &*TIME_SOURCE);
482            let initial = bucket.semaphore.available_permits();
483
484            bucket.fractional_tokens.store(input);
485            bucket.convert_fractional_tokens();
486
487            assert_eq!(
488                bucket.semaphore.available_permits() - initial,
489                expected_permits
490            );
491            assert!((bucket.fractional_tokens.load() - expected_remaining).abs() < 0.0001);
492        }
493    }
494
495    #[cfg(any(feature = "test-util", feature = "legacy-test-util"))]
496    #[test]
497    fn test_builder_with_custom_values() {
498        let bucket = TokenBucket::builder()
499            .capacity(100)
500            .retry_cost(10)
501            .timeout_retry_cost(20)
502            .success_reward(0.5)
503            .refill_rate(2.5)
504            .build();
505
506        assert_eq!(bucket.max_permits, 100);
507        assert_eq!(bucket.retry_cost, 10);
508        assert_eq!(bucket.timeout_retry_cost, 20);
509        assert_eq!(bucket.success_reward, 0.5);
510        assert_eq!(bucket.refill_rate, 2.5);
511    }
512
513    #[test]
514    fn test_builder_refill_rate_validation() {
515        // Test negative values are clamped to 0.0
516        let bucket = TokenBucket::builder().refill_rate(-5.0).build();
517        assert_eq!(bucket.refill_rate, 0.0);
518
519        // Test valid positive value
520        let bucket = TokenBucket::builder().refill_rate(1.5).build();
521        assert_eq!(bucket.refill_rate, 1.5);
522
523        // Test zero is valid
524        let bucket = TokenBucket::builder().refill_rate(0.0).build();
525        assert_eq!(bucket.refill_rate, 0.0);
526    }
527
528    #[cfg(any(feature = "test-util", feature = "legacy-test-util"))]
529    #[test]
530    fn test_builder_custom_time_source() {
531        use aws_smithy_async::test_util::ManualTimeSource;
532        use std::time::UNIX_EPOCH;
533
534        // Test that TokenBucket uses provided TimeSource when specified via builder
535        let manual_time = ManualTimeSource::new(UNIX_EPOCH);
536        let bucket = TokenBucket::builder()
537            .capacity(100)
538            .refill_rate(1.0)
539            .build();
540
541        // Consume all tokens to test refill from empty state
542        let _permits = bucket.semaphore.try_acquire_many(100).unwrap();
543        assert_eq!(bucket.available_permits(), 0);
544
545        // Advance time and verify tokens are added based on manual time
546        manual_time.advance(Duration::from_secs(5));
547
548        bucket.refill_tokens_based_on_time(&manual_time);
549        bucket.convert_fractional_tokens();
550
551        // Should have 5 tokens (5 seconds * 1 token/sec)
552        assert_eq!(bucket.available_permits(), 5);
553    }
554
555    #[test]
556    fn test_atomicf32_f32_to_bits_conversion_correctness() {
557        // This is the core functionality
558        let test_values = vec![
559            0.0,
560            -0.0,
561            1.0,
562            -1.0,
563            f32::INFINITY,
564            f32::NEG_INFINITY,
565            f32::NAN,
566            f32::MIN,
567            f32::MAX,
568            f32::MIN_POSITIVE,
569            f32::EPSILON,
570            std::f32::consts::PI,
571            std::f32::consts::E,
572            // Test values that could expose bit manipulation bugs
573            1.23456789e-38, // Very small normal number
574            1.23456789e38,  // Very large number (within f32 range)
575            1.1754944e-38,  // Near MIN_POSITIVE for f32
576        ];
577
578        for &expected in &test_values {
579            let atomic = AtomicF32::new(expected);
580            let actual = atomic.load();
581
582            // For NaN, we can't use == but must check bit patterns
583            if expected.is_nan() {
584                assert!(actual.is_nan(), "Expected NaN, got {}", actual);
585                // Different NaN bit patterns should be preserved exactly
586                assert_eq!(expected.to_bits(), actual.to_bits());
587            } else {
588                assert_eq!(expected.to_bits(), actual.to_bits());
589            }
590        }
591    }
592
593    #[cfg(any(feature = "test-util", feature = "legacy-test-util"))]
594    #[test]
595    fn test_atomicf32_store_load_preserves_exact_bits() {
596        let atomic = AtomicF32::new(0.0);
597
598        // Test that store/load cycle preserves EXACT bit patterns
599        // This would catch bugs in the to_bits/from_bits conversion
600        let critical_bit_patterns = vec![
601            0x00000000u32, // +0.0
602            0x80000000u32, // -0.0
603            0x7F800000u32, // +infinity
604            0xFF800000u32, // -infinity
605            0x7FC00000u32, // Quiet NaN
606            0x7FA00000u32, // Signaling NaN
607            0x00000001u32, // Smallest positive subnormal
608            0x007FFFFFu32, // Largest subnormal
609            0x00800000u32, // Smallest positive normal (MIN_POSITIVE)
610        ];
611
612        for &expected_bits in &critical_bit_patterns {
613            let expected_f32 = f32::from_bits(expected_bits);
614            atomic.store(expected_f32);
615            let loaded_f32 = atomic.load();
616            let actual_bits = loaded_f32.to_bits();
617
618            assert_eq!(expected_bits, actual_bits);
619        }
620    }
621
622    #[cfg(any(feature = "test-util", feature = "legacy-test-util"))]
623    #[test]
624    fn test_atomicf32_concurrent_store_load_safety() {
625        use std::sync::Arc;
626        use std::thread;
627
628        let atomic = Arc::new(AtomicF32::new(0.0));
629        let test_values = vec![1.0, 2.0, 3.0, 4.0, 5.0];
630        let mut handles = Vec::new();
631
632        // Start multiple threads that continuously write different values
633        for &value in &test_values {
634            let atomic_clone = Arc::clone(&atomic);
635            let handle = thread::spawn(move || {
636                for _ in 0..1000 {
637                    atomic_clone.store(value);
638                }
639            });
640            handles.push(handle);
641        }
642
643        // Start a reader thread that continuously reads
644        let atomic_reader = Arc::clone(&atomic);
645        let reader_handle = thread::spawn(move || {
646            let mut readings = Vec::new();
647            for _ in 0..5000 {
648                let value = atomic_reader.load();
649                readings.push(value);
650            }
651            readings
652        });
653
654        // Wait for all writers to complete
655        for handle in handles {
656            handle.join().expect("Writer thread panicked");
657        }
658
659        let readings = reader_handle.join().expect("Reader thread panicked");
660
661        // Verify that all read values are valid (one of the written values)
662        // This tests that there's no data corruption from concurrent access
663        for &reading in &readings {
664            assert!(test_values.contains(&reading) || reading == 0.0);
665
666            // More importantly, verify the reading is a valid f32
667            // (not corrupted bits that happen to parse as valid)
668            assert!(
669                reading.is_finite() || reading == 0.0,
670                "Corrupted reading detected"
671            );
672        }
673    }
674
675    #[cfg(any(feature = "test-util", feature = "legacy-test-util"))]
676    #[test]
677    fn test_atomicf32_stress_concurrent_access() {
678        use std::sync::{Arc, Barrier};
679        use std::thread;
680
681        let expected_values = [0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0];
682        let atomic = Arc::new(AtomicF32::new(0.0));
683        let barrier = Arc::new(Barrier::new(10)); // Synchronize all threads
684        let mut handles = Vec::new();
685
686        // Launch threads that all start simultaneously
687        for i in 0..10 {
688            let atomic_clone = Arc::clone(&atomic);
689            let barrier_clone = Arc::clone(&barrier);
690            let handle = thread::spawn(move || {
691                barrier_clone.wait(); // All threads start at same time
692
693                // Tight loop increases chance of race conditions
694                for _ in 0..10000 {
695                    let value = i as f32;
696                    atomic_clone.store(value);
697                    let loaded = atomic_clone.load();
698                    // Verify no corruption occurred
699                    assert!(loaded >= 0.0 && loaded <= 9.0);
700                    assert!(
701                        expected_values.contains(&loaded),
702                        "Got unexpected value: {}, expected one of {:?}",
703                        loaded,
704                        expected_values
705                    );
706                }
707            });
708            handles.push(handle);
709        }
710
711        for handle in handles {
712            handle.join().unwrap();
713        }
714    }
715
716    #[test]
717    fn test_atomicf32_integration_with_token_bucket_usage() {
718        let atomic = AtomicF32::new(0.0);
719        let success_reward = 0.3;
720        let iterations = 5;
721
722        // Accumulate fractional tokens
723        for _ in 1..=iterations {
724            let current = atomic.load();
725            atomic.store(current + success_reward);
726        }
727
728        let accumulated = atomic.load();
729        let expected_total = iterations as f32 * success_reward; // 1.5
730
731        // Test the floor() operation pattern
732        let full_tokens = accumulated.floor();
733        atomic.store(accumulated - full_tokens);
734        let remaining = atomic.load();
735
736        // These assertions should be general:
737        assert_eq!(full_tokens, expected_total.floor()); // Could be 1.0, 2.0, 3.0, etc.
738        assert!(remaining >= 0.0 && remaining < 1.0);
739        assert_eq!(remaining, expected_total - expected_total.floor());
740    }
741
742    #[cfg(any(feature = "test-util", feature = "legacy-test-util"))]
743    #[test]
744    fn test_atomicf32_clone_creates_independent_copy() {
745        let original = AtomicF32::new(123.456);
746        let cloned = original.clone();
747
748        // Verify they start with the same value
749        assert_eq!(original.load(), cloned.load());
750
751        // Verify they're independent - modifying one doesn't affect the other
752        original.store(999.0);
753        assert_eq!(
754            cloned.load(),
755            123.456,
756            "Clone should be unaffected by original changes"
757        );
758        assert_eq!(original.load(), 999.0, "Original should have new value");
759    }
760
761    #[test]
762    fn test_combined_time_and_success_rewards() {
763        use aws_smithy_async::test_util::ManualTimeSource;
764        use std::time::UNIX_EPOCH;
765
766        let time_source = ManualTimeSource::new(UNIX_EPOCH);
767        let current_time_secs = UNIX_EPOCH
768            .duration_since(SystemTime::UNIX_EPOCH)
769            .unwrap()
770            .as_secs() as u32;
771
772        let bucket = TokenBucket {
773            refill_rate: 1.0,
774            success_reward: 0.5,
775            last_refill_time_secs: Arc::new(AtomicU32::new(current_time_secs)),
776            semaphore: Arc::new(Semaphore::new(0)),
777            max_permits: 100,
778            ..Default::default()
779        };
780
781        // Add success rewards: 2 * 0.5 = 1.0 token
782        bucket.reward_success();
783        bucket.reward_success();
784
785        // Advance time by 2 seconds
786        time_source.advance(Duration::from_secs(2));
787
788        // Trigger time-based refill: 2 sec * 1.0 = 2.0 tokens
789        // Total: 1.0 + 2.0 = 3.0 tokens
790        bucket.refill_tokens_based_on_time(&time_source);
791        bucket.convert_fractional_tokens();
792
793        assert_eq!(bucket.available_permits(), 3);
794        assert!(bucket.fractional_tokens.load().abs() < 0.0001);
795    }
796
797    #[test]
798    fn test_refill_rates() {
799        use aws_smithy_async::test_util::ManualTimeSource;
800        use std::time::UNIX_EPOCH;
801        // (refill_rate, elapsed_secs, expected_permits, expected_fractional)
802        let test_cases = [
803            (10.0, 2, 20, 0.0),      // Basic: 2 sec * 10 tokens/sec = 20 tokens
804            (0.001, 1100, 1, 0.1),   // Small: 1100 * 0.001 = 1.1 tokens
805            (0.0001, 11000, 1, 0.1), // Tiny: 11000 * 0.0001 = 1.1 tokens
806            (0.001, 1200, 1, 0.2),   // 1200 * 0.001 = 1.2 tokens
807            (0.0001, 10000, 1, 0.0), // 10000 * 0.0001 = 1.0 tokens
808            (0.001, 500, 0, 0.5),    // Fractional only: 500 * 0.001 = 0.5 tokens
809        ];
810
811        for (refill_rate, elapsed_secs, expected_permits, expected_fractional) in test_cases {
812            let time_source = ManualTimeSource::new(UNIX_EPOCH);
813            let current_time_secs = UNIX_EPOCH
814                .duration_since(SystemTime::UNIX_EPOCH)
815                .unwrap()
816                .as_secs() as u32;
817
818            let bucket = TokenBucket {
819                refill_rate,
820                last_refill_time_secs: Arc::new(AtomicU32::new(current_time_secs)),
821                semaphore: Arc::new(Semaphore::new(0)),
822                max_permits: 100,
823                ..Default::default()
824            };
825
826            // Advance time by the specified duration
827            time_source.advance(Duration::from_secs(elapsed_secs));
828
829            bucket.refill_tokens_based_on_time(&time_source);
830            bucket.convert_fractional_tokens();
831
832            assert_eq!(
833                bucket.available_permits(),
834                expected_permits,
835                "Rate {}: After {}s expected {} permits",
836                refill_rate,
837                elapsed_secs,
838                expected_permits
839            );
840            assert!(
841                (bucket.fractional_tokens.load() - expected_fractional).abs() < 0.0001,
842                "Rate {}: After {}s expected {} fractional, got {}",
843                refill_rate,
844                elapsed_secs,
845                expected_fractional,
846                bucket.fractional_tokens.load()
847            );
848        }
849    }
850
851    #[cfg(any(feature = "test-util", feature = "legacy-test-util"))]
852    #[test]
853    fn test_rewards_capped_at_max_capacity() {
854        use aws_smithy_async::test_util::ManualTimeSource;
855        use std::time::UNIX_EPOCH;
856
857        let time_source = ManualTimeSource::new(UNIX_EPOCH);
858        let current_time_secs = UNIX_EPOCH
859            .duration_since(SystemTime::UNIX_EPOCH)
860            .unwrap()
861            .as_secs() as u32;
862
863        let bucket = TokenBucket {
864            refill_rate: 50.0,
865            success_reward: 2.0,
866            last_refill_time_secs: Arc::new(AtomicU32::new(current_time_secs)),
867            semaphore: Arc::new(Semaphore::new(5)),
868            max_permits: 10,
869            ..Default::default()
870        };
871
872        // Add success rewards: 50 * 2.0 = 100 tokens (without cap)
873        for _ in 0..50 {
874            bucket.reward_success();
875        }
876
877        // Fractional tokens capped at 10 from success rewards
878        assert_eq!(bucket.fractional_tokens.load(), 10.0);
879
880        // Advance time by 100 seconds
881        time_source.advance(Duration::from_secs(100));
882
883        // Time-based refill: 100 * 50 = 5000 tokens (without cap)
884        // But fractional is already at 10, so it stays at 10
885        bucket.refill_tokens_based_on_time(&time_source);
886
887        // Fractional tokens should be capped at max_permits (10)
888        assert_eq!(
889            bucket.fractional_tokens.load(),
890            10.0,
891            "Fractional tokens should be capped at max_permits"
892        );
893        // Convert should add 5 tokens (bucket at 5, can add 5 more to reach max 10)
894        bucket.convert_fractional_tokens();
895        assert_eq!(bucket.available_permits(), 10);
896    }
897
898    #[cfg(any(feature = "test-util", feature = "legacy-test-util"))]
899    #[test]
900    fn test_concurrent_time_based_refill_no_over_generation() {
901        use aws_smithy_async::test_util::ManualTimeSource;
902        use std::sync::{Arc, Barrier};
903        use std::thread;
904        use std::time::UNIX_EPOCH;
905
906        let time_source = ManualTimeSource::new(UNIX_EPOCH);
907        let current_time_secs = UNIX_EPOCH
908            .duration_since(SystemTime::UNIX_EPOCH)
909            .unwrap()
910            .as_secs() as u32;
911
912        // Create bucket with 1 token/sec refill
913        let bucket = Arc::new(TokenBucket {
914            refill_rate: 1.0,
915            last_refill_time_secs: Arc::new(AtomicU32::new(current_time_secs)),
916            semaphore: Arc::new(Semaphore::new(0)),
917            max_permits: 100,
918            ..Default::default()
919        });
920
921        // Advance time by 10 seconds
922        time_source.advance(Duration::from_secs(10));
923        let shared_time_source = aws_smithy_async::time::SharedTimeSource::new(time_source);
924
925        // Launch 100 threads that all try to refill simultaneously
926        let barrier = Arc::new(Barrier::new(100));
927        let mut handles = Vec::new();
928
929        for _ in 0..100 {
930            let bucket_clone1 = Arc::clone(&bucket);
931            let barrier_clone1 = Arc::clone(&barrier);
932            let time_source_clone1 = shared_time_source.clone();
933            let bucket_clone2 = Arc::clone(&bucket);
934            let barrier_clone2 = Arc::clone(&barrier);
935            let time_source_clone2 = shared_time_source.clone();
936
937            let handle1 = thread::spawn(move || {
938                // Wait for all threads to be ready
939                barrier_clone1.wait();
940
941                // All threads call refill at the same time
942                bucket_clone1.refill_tokens_based_on_time(&time_source_clone1);
943            });
944
945            let handle2 = thread::spawn(move || {
946                // Wait for all threads to be ready
947                barrier_clone2.wait();
948
949                // All threads call refill at the same time
950                bucket_clone2.refill_tokens_based_on_time(&time_source_clone2);
951            });
952            handles.push(handle1);
953            handles.push(handle2);
954        }
955
956        // Wait for all threads to complete
957        for handle in handles {
958            handle.join().unwrap();
959        }
960
961        // Convert fractional tokens to whole tokens
962        bucket.convert_fractional_tokens();
963
964        // Should have exactly 10 tokens (10 seconds * 1 token/sec)
965        // Not 1000 tokens (100 threads * 10 tokens each)
966        assert_eq!(
967            bucket.available_permits(),
968            10,
969            "Only one thread should have added tokens, not all 100"
970        );
971
972        // Fractional should be 0 after conversion
973        assert!(bucket.fractional_tokens.load().abs() < 0.0001);
974    }
975
976    /// Regression test for https://github.com/awslabs/aws-sdk-rust/issues/1423
977    #[test]
978    fn test_is_full_accounts_for_fractional_tokens() {
979        let bucket = TokenBucket::builder()
980            .capacity(2)
981            .retry_cost(1)
982            .success_reward(0.9)
983            .build();
984
985        assert!(bucket.is_full());
986
987        let _p1 = bucket
988            .acquire(&ErrorKind::ServerError, &*TIME_SOURCE)
989            .unwrap();
990        let _p2 = bucket
991            .acquire(&ErrorKind::ServerError, &*TIME_SOURCE)
992            .unwrap();
993
994        assert!(bucket.is_empty());
995
996        // 3 rewards of 0.9 = 2.7 fractional tokens, which converts to 2 whole
997        // permits — enough to fill the bucket (capacity 2).
998        bucket.reward_success();
999        bucket.reward_success();
1000        bucket.reward_success();
1001
1002        // Before the fix, is_full() returned false here because fractional
1003        // tokens hadn't been converted to real permits.
1004        assert!(bucket.is_full());
1005        assert!(!bucket.is_empty());
1006    }
1007
1008    #[test]
1009    fn test_is_empty_accounts_for_fractional_tokens() {
1010        let bucket = TokenBucket::builder()
1011            .capacity(10)
1012            .retry_cost(10)
1013            .success_reward(0.5)
1014            .build();
1015
1016        let _p = bucket
1017            .acquire(&ErrorKind::ServerError, &*TIME_SOURCE)
1018            .unwrap();
1019        assert_eq!(bucket.semaphore.available_permits(), 0);
1020
1021        // 0.5 fractional tokens can't convert to a whole permit
1022        bucket.reward_success();
1023        assert!(bucket.is_empty());
1024
1025        // 1.0 fractional tokens converts to a permit
1026        bucket.reward_success();
1027        assert!(!bucket.is_empty());
1028    }
1029}