load_balancer/
threshold.rs

1use crate::{BoxLoadBalancer, LoadBalancer};
2use async_trait::async_trait;
3use std::sync::atomic::Ordering::{Acquire, Release};
4use std::{
5    future::Future,
6    sync::{Arc, atomic::AtomicU64},
7    time::Duration,
8};
9use tokio::{
10    spawn,
11    sync::{Mutex, RwLock},
12    task::{JoinHandle, yield_now},
13    time::sleep,
14};
15
/// One balanced value together with its limits and live counters.
///
/// `count` is compared against `max_count` to cap how many times the entry is
/// handed out per reset interval, and `error_count` is compared against
/// `max_error_count` to take the entry out of rotation after repeated
/// failures. A limit of `0` turns the corresponding check off.
pub struct Entry<T: Send + Sync + Clone + 'static> {
    /// Allocations permitted per interval (`0` = no limit).
    pub max_count: u64,
    /// Error count at which the entry stops being selected (`0` = never).
    pub max_error_count: u64,
    /// Allocations made so far in the current interval.
    pub count: AtomicU64,
    /// Errors recorded so far.
    pub error_count: AtomicU64,
    /// The item handed out to callers (e.g. a client or other resource).
    pub value: T,
}
33
34impl<T> Entry<T>
35where
36    T: Send + Sync + Clone + 'static,
37{
38    /// Reset both the allocation count and the error count.
39    pub fn reset(&self) {
40        self.count.store(0, Release);
41        self.error_count.store(0, Release);
42    }
43
44    /// Disable this entry by setting its error count to the maximum.
45    pub fn disable(&self) {
46        self.error_count.store(self.max_error_count, Release);
47    }
48}
49
50impl<T> Clone for Entry<T>
51where
52    T: Send + Sync + Clone + 'static,
53{
54    fn clone(&self) -> Self {
55        Self {
56            max_count: self.max_count.clone(),
57            max_error_count: self.max_error_count.clone(),
58            count: self.count.load(Acquire).into(),
59            error_count: self.error_count.load(Acquire).into(),
60            value: self.value.clone(),
61        }
62    }
63}
64
65/// Internal representation of the threshold load balancer.
66pub struct ThresholdLoadBalancerRef<T>
67where
68    T: Send + Sync + Clone + 'static,
69{
70    /// List of entries to balance between.
71    pub entries: RwLock<Vec<Entry<T>>>,
72    /// Optional background timer handle for resetting counts periodically.
73    pub timer: Mutex<Option<JoinHandle<()>>>,
74    /// Interval duration for resetting counts.
75    pub interval: RwLock<Duration>,
76}
77
78/// Threshold-based load balancer that limits allocations per entry and handles failures.
79#[derive(Clone)]
80pub struct ThresholdLoadBalancer<T>
81where
82    T: Send + Sync + Clone + 'static,
83{
84    inner: Arc<ThresholdLoadBalancerRef<T>>,
85}
86
87impl<T> ThresholdLoadBalancer<T>
88where
89    T: Send + Sync + Clone + 'static,
90{
91    /// Create a new threshold load balancer with a fixed interval.
92    pub fn new(entries: Vec<(u64, u64, T)>, interval: Duration) -> Self {
93        Self {
94            inner: Arc::new(ThresholdLoadBalancerRef {
95                entries: entries
96                    .into_iter()
97                    .map(|(max_count, max_error_count, value)| Entry {
98                        max_count,
99                        max_error_count,
100                        value,
101                        count: 0.into(),
102                        error_count: 0.into(),
103                    })
104                    .collect::<Vec<_>>()
105                    .into(),
106                timer: Mutex::new(None),
107                interval: interval.into(),
108            }),
109        }
110    }
111
112    /// Create a new threshold load balancer with a custom interval.
113    pub fn new_interval(entries: Vec<(u64, u64, T)>, interval: Duration) -> Self {
114        Self::new(entries, interval)
115    }
116
117    /// Execute a custom async update on the internal reference.
118    pub async fn update<F, R>(&self, handle: F) -> anyhow::Result<()>
119    where
120        F: Fn(Arc<ThresholdLoadBalancerRef<T>>) -> R,
121        R: Future<Output = anyhow::Result<()>>,
122    {
123        handle(self.inner.clone()).await
124    }
125
126    /// Allocate an entry, skipping the specified index if provided.
127    async fn alloc_skip(&self, skip_index: usize) -> (usize, T) {
128        loop {
129            match self.try_alloc_skip(skip_index) {
130                Some(v) => return v,
131                None => yield_now().await,
132            };
133        }
134    }
135
136    /// Try to allocate an entry immediately, skipping the specified index if provided.
137    fn try_alloc_skip(&self, skip_index: usize) -> Option<(usize, T)> {
138        // Start the background timer if it is not already running.
139        if let Ok(mut v) = self.inner.timer.try_lock() {
140            if v.is_none() {
141                let this = self.inner.clone();
142
143                *v = Some(spawn(async move {
144                    let mut interval = *this.interval.read().await;
145
146                    loop {
147                        sleep(match this.interval.try_read() {
148                            Ok(v) => {
149                                interval = *v;
150                                interval
151                            }
152                            Err(_) => interval,
153                        })
154                        .await;
155
156                        // Reset the allocation count for all entries.
157                        for i in this.entries.read().await.iter() {
158                            i.count.store(0, Release);
159                        }
160                    }
161                }));
162            }
163        }
164
165        // Attempt to select a valid entry.
166        if let Ok(entries) = self.inner.entries.try_read() {
167            let mut skip_count = 0;
168
169            for (i, entry) in entries.iter().enumerate() {
170                if i == skip_index {
171                    continue;
172                }
173
174                if entry.max_error_count != 0
175                    && entry.error_count.load(Acquire) >= entry.max_error_count
176                {
177                    skip_count += 1;
178                    continue;
179                }
180
181                let count = entry.count.load(Acquire);
182
183                if entry.max_count == 0
184                    || (count < entry.max_count
185                        && entry
186                            .count
187                            .compare_exchange(count, count + 1, Release, Acquire)
188                            .is_ok())
189                {
190                    return Some((i, entry.value.clone()));
191                }
192            }
193
194            // All entries are skipped due to errors.
195            if skip_count == entries.len() {
196                return None;
197            }
198        }
199
200        None
201    }
202
203    /// Mark a successful usage for the entry at the given index.
204    pub fn success(&self, index: usize) {
205        if let Ok(entries) = self.inner.entries.try_read() {
206            if let Some(entry) = entries.get(index) {
207                let current = entry.error_count.load(Acquire);
208
209                if current != 0 {
210                    let _ =
211                        entry
212                            .error_count
213                            .compare_exchange(current, current - 1, Release, Acquire);
214                }
215            }
216        }
217    }
218
219    /// Mark a failure for the entry at the given index.
220    pub fn failure(&self, index: usize) {
221        if let Ok(entries) = self.inner.entries.try_read() {
222            if let Some(entry) = entries.get(index) {
223                entry.error_count.fetch_add(1, Release);
224            }
225        }
226    }
227}
228
229impl<T> LoadBalancer<T> for ThresholdLoadBalancer<T>
230where
231    T: Clone + Send + Sync + 'static,
232{
233    /// Allocate an entry asynchronously.
234    fn alloc(&self) -> impl Future<Output = T> + Send {
235        async move { self.alloc_skip(usize::MAX).await.1 }
236    }
237
238    /// Attempt to allocate an entry immediately.
239    fn try_alloc(&self) -> Option<T> {
240        self.try_alloc_skip(usize::MAX).map(|v| v.1)
241    }
242}
243
244#[async_trait]
245impl<T> BoxLoadBalancer<T> for ThresholdLoadBalancer<T>
246where
247    T: Send + Sync + Clone + 'static,
248{
249    /// Allocate an entry asynchronously.
250    async fn alloc(&self) -> T {
251        self.alloc_skip(usize::MAX).await.1
252    }
253
254    /// Attempt to allocate an entry immediately.
255    fn try_alloc(&self) -> Option<T> {
256        self.try_alloc_skip(usize::MAX).map(|v| v.1)
257    }
258}