load_balancer/
threshold.rs

1use crate::{BoxLoadBalancer, LoadBalancer};
2use async_trait::async_trait;
3use std::sync::atomic::Ordering::{Acquire, Release};
4use std::{
5    future::Future,
6    sync::{Arc, atomic::AtomicU64},
7    time::Duration,
8};
9use tokio::{
10    spawn,
11    sync::{Mutex, RwLock},
12    task::{JoinHandle, yield_now},
13    time::sleep,
14};
15
/// One balanced value together with its per-interval quota and error budget.
///
/// The limits (`max_count`, `max_error_count`) are plain integers, while the
/// live counters are atomics so they can be updated through a shared reference.
pub struct Entry<T: Send + Sync + Clone + 'static> {
    /// Upper bound on allocations handed out per interval.
    pub max_count: u64,
    /// Number of recorded errors at which the entry is treated as disabled.
    pub max_error_count: u64,
    /// Allocations handed out so far.
    pub count: AtomicU64,
    /// Errors recorded so far.
    pub error_count: AtomicU64,
    /// The balanced value itself (e.g. a client or other resource).
    pub value: T,
}
33
34impl<T> Entry<T>
35where
36    T: Send + Sync + Clone + 'static,
37{
38    /// Reset both the allocation count and the error count.
39    pub fn reset(&self) {
40        self.count.store(0, Release);
41        self.error_count.store(0, Release);
42    }
43
44    /// Disable this entry by setting its error count to the maximum.
45    pub fn disable(&self) {
46        self.error_count.store(self.max_error_count, Release);
47    }
48}
49
50impl<T> Clone for Entry<T>
51where
52    T: Send + Sync + Clone + 'static,
53{
54    fn clone(&self) -> Self {
55        Self {
56            max_count: self.max_count.clone(),
57            max_error_count: self.max_error_count.clone(),
58            count: self.count.load(Acquire).into(),
59            error_count: self.error_count.load(Acquire).into(),
60            value: self.value.clone(),
61        }
62    }
63}
64
65/// Internal representation of the threshold load balancer.
66pub struct ThresholdLoadBalancerRef<T>
67where
68    T: Send + Sync + Clone + 'static,
69{
70    /// List of entries to balance between.
71    pub entries: RwLock<Vec<Entry<T>>>,
72    /// Optional background timer handle for resetting counts periodically.
73    pub timer: Mutex<Option<JoinHandle<()>>>,
74    /// Interval duration for resetting counts.
75    pub interval: RwLock<Duration>,
76}
77
78/// Threshold-based load balancer that limits allocations per entry and handles failures.
79#[derive(Clone)]
80pub struct ThresholdLoadBalancer<T>
81where
82    T: Send + Sync + Clone + 'static,
83{
84    inner: Arc<ThresholdLoadBalancerRef<T>>,
85}
86
87impl<T> ThresholdLoadBalancer<T>
88where
89    T: Send + Sync + Clone + 'static,
90{
91    /// Create a new threshold load balancer with a fixed 1-second interval.
92    ///
93    /// # Arguments
94    ///
95    /// * `entries` - A vector of tuples `(max_count, max_error_count, value)`:
96    ///     - `max_count`: Maximum number of allocations allowed per interval.
97    ///     - `max_error_count`: Maximum number of errors allowed before disabling the entry.
98    ///     - `value`: value.
99    pub fn new(entries: Vec<(u64, u64, T)>) -> Self {
100        Self::new_interval(entries, Duration::from_secs(1))
101    }
102
103    /// Create a new threshold load balancer with a custom interval.
104    ///
105    /// # Arguments
106    ///
107    /// * `entries` - A vector of tuples `(max_count, max_error_count, value)`:
108    ///     - `max_count`: Maximum number of allocations allowed per interval.
109    ///     - `max_error_count`: Maximum number of errors allowed before disabling the entry.
110    ///     - `value`: value.
111    ///
112    /// * `interval` - Duration after which all allocation/error counts are reset.
113    pub fn new_interval(entries: Vec<(u64, u64, T)>, interval: Duration) -> Self {
114        Self {
115            inner: Arc::new(ThresholdLoadBalancerRef {
116                entries: entries
117                    .into_iter()
118                    .map(|(max_count, max_error_count, value)| Entry {
119                        max_count,
120                        max_error_count,
121                        value,
122                        count: 0.into(),
123                        error_count: 0.into(),
124                    })
125                    .collect::<Vec<_>>()
126                    .into(),
127                timer: Mutex::new(None),
128                interval: interval.into(),
129            }),
130        }
131    }
132
133    /// Execute a custom async update on the internal reference.
134    pub async fn update<F, R>(&self, handle: F) -> anyhow::Result<()>
135    where
136        F: Fn(Arc<ThresholdLoadBalancerRef<T>>) -> R,
137        R: Future<Output = anyhow::Result<()>>,
138    {
139        handle(self.inner.clone()).await
140    }
141
142    /// Allocate an entry, skipping the specified index if provided.
143    pub async fn alloc_skip(&self, skip_index: usize) -> (usize, T) {
144        loop {
145            match self.try_alloc_skip(skip_index) {
146                Some(v) => return v,
147                None => yield_now().await,
148            };
149        }
150    }
151
152    /// Try to allocate an entry immediately, skipping the specified index if provided.
153    pub fn try_alloc_skip(&self, skip_index: usize) -> Option<(usize, T)> {
154        if let Ok(mut v) = self.inner.timer.try_lock() {
155            if v.is_none() {
156                let this = self.inner.clone();
157
158                *v = Some(spawn(async move {
159                    let mut interval = *this.interval.read().await;
160
161                    loop {
162                        sleep(match this.interval.try_read() {
163                            Ok(v) => {
164                                interval = *v;
165                                interval
166                            }
167                            Err(_) => interval,
168                        })
169                        .await;
170
171                        // Reset the allocation count for all entries.
172                        for i in this.entries.read().await.iter() {
173                            i.count.store(0, Release);
174                        }
175                    }
176                }));
177            }
178        }
179
180        if let Ok(entries) = self.inner.entries.try_read() {
181            let mut skip_count = 0;
182
183            for (i, entry) in entries.iter().enumerate() {
184                if i == skip_index {
185                    continue;
186                }
187
188                if entry.max_error_count != 0
189                    && entry.error_count.load(Acquire) >= entry.max_error_count
190                {
191                    skip_count += 1;
192                    continue;
193                }
194
195                let count = entry.count.load(Acquire);
196
197                if entry.max_count == 0
198                    || (count < entry.max_count
199                        && entry
200                            .count
201                            .compare_exchange(count, count + 1, Release, Acquire)
202                            .is_ok())
203                {
204                    return Some((i, entry.value.clone()));
205                }
206            }
207
208            // All entries are skipped due to errors.
209            if skip_count == entries.len() {
210                return None;
211            }
212        }
213
214        None
215    }
216
217    /// Mark a successful usage for the entry at the given index.
218    pub fn success(&self, index: usize) {
219        if let Ok(entries) = self.inner.entries.try_read() {
220            if let Some(entry) = entries.get(index) {
221                let current = entry.error_count.load(Acquire);
222
223                if current != 0 {
224                    let _ =
225                        entry
226                            .error_count
227                            .compare_exchange(current, current - 1, Release, Acquire);
228                }
229            }
230        }
231    }
232
233    /// Mark a failure for the entry at the given index.
234    pub fn failure(&self, index: usize) {
235        if let Ok(entries) = self.inner.entries.try_read() {
236            if let Some(entry) = entries.get(index) {
237                entry.error_count.fetch_add(1, Release);
238            }
239        }
240    }
241}
242
243impl<T> LoadBalancer<T> for ThresholdLoadBalancer<T>
244where
245    T: Clone + Send + Sync + 'static,
246{
247    /// Allocate an entry asynchronously.
248    fn alloc(&self) -> impl Future<Output = T> + Send {
249        async move { self.alloc_skip(usize::MAX).await.1 }
250    }
251
252    /// Attempt to allocate an entry immediately.
253    fn try_alloc(&self) -> Option<T> {
254        self.try_alloc_skip(usize::MAX).map(|v| v.1)
255    }
256}
257
258#[async_trait]
259impl<T> BoxLoadBalancer<T> for ThresholdLoadBalancer<T>
260where
261    T: Send + Sync + Clone + 'static,
262{
263    /// Allocate an entry asynchronously.
264    async fn alloc(&self) -> T {
265        self.alloc_skip(usize::MAX).await.1
266    }
267
268    /// Attempt to allocate an entry immediately.
269    fn try_alloc(&self) -> Option<T> {
270        self.try_alloc_skip(usize::MAX).map(|v| v.1)
271    }
272}