load_balancer/
threshold.rs

1use crate::{BoxLoadBalancer, LoadBalancer};
2use async_trait::async_trait;
3use std::sync::atomic::Ordering::{Acquire, Release};
4use std::{
5    future::Future,
6    sync::{Arc, atomic::AtomicU64},
7    time::Duration,
8};
9use tokio::{
10    spawn,
11    sync::{Mutex, RwLock},
12    task::{JoinHandle, yield_now},
13    time::sleep,
14};
15
/// One balanced item together with its limits and its live counters.
///
/// The counters are atomics so they can be updated through a shared reference.
#[derive(Debug)]
pub struct Entry<T: Send + Sync + Clone + 'static> {
    /// Allocations allowed per interval; `0` means no limit.
    pub max_count: u64,
    /// Error count at which the entry stops being selected; `0` turns error tracking off.
    pub max_error_count: u64,
    /// Allocations handed out during the current interval.
    pub count: AtomicU64,
    /// Errors recorded for this entry.
    pub error_count: AtomicU64,
    /// The item handed out by the balancer (e.g. a client or other resource).
    pub value: T,
}
34
35impl<T> Entry<T>
36where
37    T: Send + Sync + Clone + 'static,
38{
39    /// Reset both the allocation count and the error count.
40    pub fn reset(&self) {
41        self.count.store(0, Release);
42        self.error_count.store(0, Release);
43    }
44
45    /// Disable this entry by setting its error count to the maximum.
46    pub fn disable(&self) {
47        self.error_count.store(self.max_error_count, Release);
48    }
49}
50
51/// Internal representation of the threshold load balancer.
52pub struct ThresholdLoadBalancerRef<T>
53where
54    T: Send + Sync + Clone + 'static,
55{
56    /// List of entries to balance between.
57    pub entries: RwLock<Vec<Entry<T>>>,
58    /// Optional background timer handle for resetting counts periodically.
59    pub timer: Mutex<Option<JoinHandle<()>>>,
60    /// Interval duration for resetting counts.
61    pub interval: RwLock<Duration>,
62}
63
64/// Threshold-based load balancer that limits allocations per entry and handles failures.
65#[derive(Clone)]
66pub struct ThresholdLoadBalancer<T>
67where
68    T: Send + Sync + Clone + 'static,
69{
70    inner: Arc<ThresholdLoadBalancerRef<T>>,
71}
72
73impl<T> ThresholdLoadBalancer<T>
74where
75    T: Send + Sync + Clone + 'static,
76{
77    /// Create a new threshold load balancer with a fixed interval.
78    pub fn new(entries: Vec<(u64, u64, T)>, interval: Duration) -> Self {
79        Self {
80            inner: Arc::new(ThresholdLoadBalancerRef {
81                entries: entries
82                    .into_iter()
83                    .map(|(max_count, max_error_count, value)| Entry {
84                        max_count,
85                        max_error_count,
86                        value,
87                        count: 0.into(),
88                        error_count: 0.into(),
89                    })
90                    .collect::<Vec<_>>()
91                    .into(),
92                timer: Mutex::new(None),
93                interval: interval.into(),
94            }),
95        }
96    }
97
98    /// Create a new threshold load balancer with a custom interval.
99    pub fn new_interval(entries: Vec<(u64, u64, T)>, interval: Duration) -> Self {
100        Self::new(entries, interval)
101    }
102
103    /// Execute a custom async update on the internal reference.
104    pub async fn update<F, R>(&self, handle: F) -> anyhow::Result<()>
105    where
106        F: Fn(Arc<ThresholdLoadBalancerRef<T>>) -> R,
107        R: Future<Output = anyhow::Result<()>>,
108    {
109        handle(self.inner.clone()).await
110    }
111
112    /// Allocate an entry, skipping the specified index if provided.
113    async fn alloc_skip(&self, skip_index: usize) -> (usize, T) {
114        loop {
115            match self.try_alloc_skip(skip_index) {
116                Some(v) => return v,
117                None => yield_now().await,
118            };
119        }
120    }
121
122    /// Try to allocate an entry immediately, skipping the specified index if provided.
123    fn try_alloc_skip(&self, skip_index: usize) -> Option<(usize, T)> {
124        // Start the background timer if it is not already running.
125        if let Ok(mut v) = self.inner.timer.try_lock() {
126            if v.is_none() {
127                let this = self.inner.clone();
128
129                *v = Some(spawn(async move {
130                    let mut interval = *this.interval.read().await;
131
132                    loop {
133                        sleep(match this.interval.try_read() {
134                            Ok(v) => {
135                                interval = *v;
136                                interval
137                            }
138                            Err(_) => interval,
139                        })
140                        .await;
141
142                        // Reset the allocation count for all entries.
143                        for i in this.entries.read().await.iter() {
144                            i.count.store(0, Release);
145                        }
146                    }
147                }));
148            }
149        }
150
151        // Attempt to select a valid entry.
152        if let Ok(entries) = self.inner.entries.try_read() {
153            let mut skip_count = 0;
154
155            for (i, entry) in entries.iter().enumerate() {
156                if i == skip_index {
157                    continue;
158                }
159
160                if entry.max_error_count != 0
161                    && entry.error_count.load(Acquire) >= entry.max_error_count
162                {
163                    skip_count += 1;
164                    continue;
165                }
166
167                let count = entry.count.load(Acquire);
168
169                if entry.max_count == 0
170                    || (count < entry.max_count
171                        && entry
172                            .count
173                            .compare_exchange(count, count + 1, Release, Acquire)
174                            .is_ok())
175                {
176                    return Some((i, entry.value.clone()));
177                }
178            }
179
180            // All entries are skipped due to errors.
181            if skip_count == entries.len() {
182                return None;
183            }
184        }
185
186        None
187    }
188
189    /// Mark a successful usage for the entry at the given index.
190    pub fn success(&self, index: usize) {
191        if let Ok(entries) = self.inner.entries.try_read() {
192            if let Some(entry) = entries.get(index) {
193                let current = entry.error_count.load(Acquire);
194
195                if current != 0 {
196                    let _ =
197                        entry
198                            .error_count
199                            .compare_exchange(current, current - 1, Release, Acquire);
200                }
201            }
202        }
203    }
204
205    /// Mark a failure for the entry at the given index.
206    pub fn failure(&self, index: usize) {
207        if let Ok(entries) = self.inner.entries.try_read() {
208            if let Some(entry) = entries.get(index) {
209                entry.error_count.fetch_add(1, Release);
210            }
211        }
212    }
213}
214
215impl<T> LoadBalancer<T> for ThresholdLoadBalancer<T>
216where
217    T: Clone + Send + Sync + 'static,
218{
219    /// Allocate an entry asynchronously.
220    fn alloc(&self) -> impl Future<Output = T> + Send {
221        async move { self.alloc_skip(usize::MAX).await.1 }
222    }
223
224    /// Attempt to allocate an entry immediately.
225    fn try_alloc(&self) -> Option<T> {
226        self.try_alloc_skip(usize::MAX).map(|v| v.1)
227    }
228}
229
230#[async_trait]
231impl<T> BoxLoadBalancer<T> for ThresholdLoadBalancer<T>
232where
233    T: Send + Sync + Clone + 'static,
234{
235    /// Allocate an entry asynchronously.
236    async fn alloc(&self) -> T {
237        self.alloc_skip(usize::MAX).await.1
238    }
239
240    /// Attempt to allocate an entry immediately.
241    fn try_alloc(&self) -> Option<T> {
242        self.try_alloc_skip(usize::MAX).map(|v| v.1)
243    }
244}