load_balancer/
threshold.rs

1use crate::{BoxLoadBalancer, LoadBalancer};
2use async_trait::async_trait;
3use std::sync::atomic::Ordering::{Acquire, Release};
4use std::{
5    future::Future,
6    sync::{Arc, atomic::AtomicU64},
7    time::{Duration, Instant},
8};
9use tokio::{
10    spawn,
11    sync::{Mutex, RwLock},
12    task::{JoinHandle, yield_now},
13    time::sleep,
14};
15
/// Represents a single entry in the threshold load balancer.
///
/// Tracks the per-interval allocation limit, the error limit, the current allocation count,
/// the current error count, and the balanced value itself.
pub struct Entry<T>
where
    T: Send + Sync + Clone + 'static,
{
    /// Maximum number of allocations per interval (`0` means unlimited).
    pub max_count: u64,
    /// Maximum number of allowed errors before the entry is considered disabled
    /// (`0` means errors never disable the entry).
    pub max_error_count: u64,
    /// Current allocation count.
    pub count: AtomicU64,
    /// Current error count.
    pub error_count: AtomicU64,
    /// The actual value being balanced (e.g., client, resource).
    pub value: T,
}

impl<T> Entry<T>
where
    T: Send + Sync + Clone + 'static,
{
    /// Reset both the allocation count and the error count to zero.
    pub fn reset(&self) {
        self.count.store(0, Release);
        self.error_count.store(0, Release);
    }

    /// Disable this entry by raising its error count to `max_error_count`.
    ///
    /// Note: the balancer only treats an entry as disabled when `max_error_count != 0`,
    /// so calling this on an entry configured with `max_error_count == 0` has no effect
    /// on allocation.
    pub fn disable(&self) {
        self.error_count.store(self.max_error_count, Release);
    }
}

impl<T> Clone for Entry<T>
where
    T: Send + Sync + Clone + 'static,
{
    /// Create an independent copy whose counters start from a snapshot of this entry's
    /// current counter values.
    fn clone(&self) -> Self {
        Self {
            // `u64` is `Copy`; no `.clone()` needed.
            max_count: self.max_count,
            max_error_count: self.max_error_count,
            count: AtomicU64::new(self.count.load(Acquire)),
            error_count: AtomicU64::new(self.error_count.load(Acquire)),
            value: self.value.clone(),
        }
    }
}
64
65/// Internal representation of the threshold load balancer.
66pub struct ThresholdLoadBalancerRef<T>
67where
68    T: Send + Sync + Clone + 'static,
69{
70    /// List of entries to balance between.
71    pub entries: RwLock<Vec<Entry<T>>>,
72    /// Optional background timer handle for resetting counts periodically.
73    pub timer: Mutex<Option<JoinHandle<()>>>,
74    /// Interval duration for resetting counts.
75    pub interval: RwLock<Duration>,
76    /// The next scheduled reset time.
77    pub next_reset: RwLock<Instant>,
78}
79
80/// Threshold-based load balancer that limits allocations per entry and handles failures.
81#[derive(Clone)]
82pub struct ThresholdLoadBalancer<T>
83where
84    T: Send + Sync + Clone + 'static,
85{
86    inner: Arc<ThresholdLoadBalancerRef<T>>,
87}
88
89impl<T> ThresholdLoadBalancer<T>
90where
91    T: Send + Sync + Clone + 'static,
92{
93    /// Create a new threshold load balancer with a fixed 1-second interval.
94    ///
95    /// # Arguments
96    ///
97    /// * `entries` - A vector of tuples `(max_count, max_error_count, value)`:
98    ///     - `max_count`: Maximum number of allocations allowed per interval.
99    ///     - `max_error_count`: Maximum number of errors allowed before disabling the entry.
100    ///     - `value`: value.
101    pub fn new(entries: Vec<(u64, u64, T)>) -> Self {
102        Self::new_interval(entries, Duration::from_secs(1))
103    }
104
105    /// Create a new threshold load balancer with a custom interval.
106    ///
107    /// # Arguments
108    ///
109    /// * `entries` - A vector of tuples `(max_count, max_error_count, value)`:
110    ///     - `max_count`: Maximum number of allocations allowed per interval.
111    ///     - `max_error_count`: Maximum number of errors allowed before disabling the entry.
112    ///     - `value`: value.
113    ///
114    /// * `interval` - Duration after which all allocation/error counts are reset.
115    pub fn new_interval(entries: Vec<(u64, u64, T)>, interval: Duration) -> Self {
116        Self {
117            inner: Arc::new(ThresholdLoadBalancerRef {
118                entries: entries
119                    .into_iter()
120                    .map(|(max_count, max_error_count, value)| Entry {
121                        max_count,
122                        max_error_count,
123                        value,
124                        count: 0.into(),
125                        error_count: 0.into(),
126                    })
127                    .collect::<Vec<_>>()
128                    .into(),
129                timer: Mutex::new(None),
130                interval: interval.into(),
131                next_reset: RwLock::new(Instant::now() + interval),
132            }),
133        }
134    }
135
136    /// Execute a custom async update on the internal reference.
137    pub async fn update<F, R>(&self, handle: F) -> anyhow::Result<()>
138    where
139        F: Fn(Arc<ThresholdLoadBalancerRef<T>>) -> R,
140        R: Future<Output = anyhow::Result<()>>,
141    {
142        handle(self.inner.clone()).await
143    }
144
145    /// Allocate an entry, skipping the specified index if provided.
146    pub async fn alloc_skip(&self, skip_index: usize) -> (usize, T) {
147        loop {
148            if let Some(v) = self.try_alloc_skip(skip_index) {
149                return v;
150            }
151
152            let now = Instant::now();
153
154            let next = *self.inner.next_reset.read().await;
155
156            let remaining = if now < next {
157                next - now
158            } else {
159                Duration::ZERO
160            };
161
162            if remaining > Duration::ZERO {
163                sleep(remaining).await;
164            } else {
165                yield_now().await;
166            }
167        }
168    }
169
170    /// Try to allocate an entry immediately, skipping the specified index if provided.
171    pub fn try_alloc_skip(&self, skip_index: usize) -> Option<(usize, T)> {
172        if let Ok(mut timer_guard) = self.inner.timer.try_lock() {
173            if timer_guard.is_none() {
174                let this = self.inner.clone();
175
176                *timer_guard = Some(spawn(async move {
177                    let mut interval = *this.interval.read().await;
178
179                    *this.next_reset.write().await = Instant::now() + interval;
180
181                    loop {
182                        sleep(match this.interval.try_read() {
183                            Ok(v) => {
184                                interval = *v;
185                                interval
186                            }
187                            Err(_) => interval,
188                        })
189                        .await;
190
191                        let now = Instant::now();
192
193                        let entries = this.entries.read().await;
194
195                        for entry in entries.iter() {
196                            entry.count.store(0, Release);
197                        }
198
199                        *this.next_reset.write().await = now + interval;
200                    }
201                }));
202            }
203        }
204
205        if let Ok(entries) = self.inner.entries.try_read() {
206            let mut skip_count = 0;
207
208            for (i, entry) in entries.iter().enumerate() {
209                if i == skip_index {
210                    continue;
211                }
212
213                if entry.max_error_count != 0
214                    && entry.error_count.load(Acquire) >= entry.max_error_count
215                {
216                    skip_count += 1;
217                    continue;
218                }
219
220                let count = entry.count.load(Acquire);
221
222                if entry.max_count == 0
223                    || (count < entry.max_count
224                        && entry
225                            .count
226                            .compare_exchange(count, count + 1, Release, Acquire)
227                            .is_ok())
228                {
229                    return Some((i, entry.value.clone()));
230                }
231            }
232
233            if skip_count == entries.len() {
234                return None;
235            }
236        }
237
238        None
239    }
240
241    /// Mark a successful usage for the entry at the given index.
242    pub fn success(&self, index: usize) {
243        if let Ok(entries) = self.inner.entries.try_read() {
244            if let Some(entry) = entries.get(index) {
245                let current = entry.error_count.load(Acquire);
246
247                if current != 0 {
248                    let _ =
249                        entry
250                            .error_count
251                            .compare_exchange(current, current - 1, Release, Acquire);
252                }
253            }
254        }
255    }
256
257    /// Mark a failure for the entry at the given index.
258    pub fn failure(&self, index: usize) {
259        if let Ok(entries) = self.inner.entries.try_read() {
260            if let Some(entry) = entries.get(index) {
261                entry.error_count.fetch_add(1, Release);
262            }
263        }
264    }
265}
266
267impl<T> LoadBalancer<T> for ThresholdLoadBalancer<T>
268where
269    T: Clone + Send + Sync + 'static,
270{
271    /// Allocate an entry asynchronously.
272    fn alloc(&self) -> impl Future<Output = T> + Send {
273        async move { self.alloc_skip(usize::MAX).await.1 }
274    }
275
276    /// Attempt to allocate an entry immediately.
277    fn try_alloc(&self) -> Option<T> {
278        self.try_alloc_skip(usize::MAX).map(|v| v.1)
279    }
280}
281
282#[async_trait]
283impl<T> BoxLoadBalancer<T> for ThresholdLoadBalancer<T>
284where
285    T: Send + Sync + Clone + 'static,
286{
287    /// Allocate an entry asynchronously.
288    async fn alloc(&self) -> T {
289        self.alloc_skip(usize::MAX).await.1
290    }
291
292    /// Attempt to allocate an entry immediately.
293    fn try_alloc(&self) -> Option<T> {
294        self.try_alloc_skip(usize::MAX).map(|v| v.1)
295    }
296}