load_balancer/
threshold.rs1use crate::{BoxLoadBalancer, LoadBalancer};
2use async_trait::async_trait;
3use std::sync::atomic::Ordering::{Acquire, Release};
4use std::{
5 future::Future,
6 sync::{Arc, atomic::AtomicU64},
7 time::{Duration, Instant},
8};
9use tokio::{
10 spawn,
11 sync::{Mutex, RwLock},
12 task::{JoinHandle, yield_now},
13 time::sleep,
14};
15
/// One value tracked by [`ThresholdLoadBalancer`], together with its limits and counters.
pub struct Entry<T: Send + Sync + Clone + 'static> {
    /// Maximum allocations between two counter resets; `0` means unlimited.
    pub max_count: u64,
    /// Error count at which the entry stops being allocated; `0` means it is never skipped.
    pub max_error_count: u64,
    /// Allocations made since the allocation counts were last cleared.
    pub count: AtomicU64,
    /// Current error count (raised by `failure`, lowered by `success`).
    pub error_count: AtomicU64,
    /// The value handed out when this entry is allocated.
    pub value: T,
}
33
34impl<T> Entry<T>
35where
36 T: Send + Sync + Clone + 'static,
37{
38 pub fn reset(&self) {
40 self.count.store(0, Release);
41 self.error_count.store(0, Release);
42 }
43
44 pub fn disable(&self) {
46 self.error_count.store(self.max_error_count, Release);
47 }
48}
49
50impl<T> Clone for Entry<T>
51where
52 T: Send + Sync + Clone + 'static,
53{
54 fn clone(&self) -> Self {
55 Self {
56 max_count: self.max_count.clone(),
57 max_error_count: self.max_error_count.clone(),
58 count: self.count.load(Acquire).into(),
59 error_count: self.error_count.load(Acquire).into(),
60 value: self.value.clone(),
61 }
62 }
63}
64
65pub struct ThresholdLoadBalancerRef<T>
67where
68 T: Send + Sync + Clone + 'static,
69{
70 pub entries: RwLock<Vec<Entry<T>>>,
72 pub timer: Mutex<Option<JoinHandle<()>>>,
74 pub interval: RwLock<Duration>,
76 pub next_reset: RwLock<Instant>,
78}
79
80#[derive(Clone)]
82pub struct ThresholdLoadBalancer<T>
83where
84 T: Send + Sync + Clone + 'static,
85{
86 inner: Arc<ThresholdLoadBalancerRef<T>>,
87}
88
89impl<T> ThresholdLoadBalancer<T>
90where
91 T: Send + Sync + Clone + 'static,
92{
93 pub fn new(entries: Vec<(u64, u64, T)>) -> Self {
102 Self::new_interval(entries, Duration::from_secs(1))
103 }
104
105 pub fn new_interval(entries: Vec<(u64, u64, T)>, interval: Duration) -> Self {
116 Self {
117 inner: Arc::new(ThresholdLoadBalancerRef {
118 entries: entries
119 .into_iter()
120 .map(|(max_count, max_error_count, value)| Entry {
121 max_count,
122 max_error_count,
123 value,
124 count: 0.into(),
125 error_count: 0.into(),
126 })
127 .collect::<Vec<_>>()
128 .into(),
129 timer: Mutex::new(None),
130 interval: interval.into(),
131 next_reset: RwLock::new(Instant::now() + interval),
132 }),
133 }
134 }
135
136 pub async fn update<F, R>(&self, handle: F) -> anyhow::Result<()>
138 where
139 F: Fn(Arc<ThresholdLoadBalancerRef<T>>) -> R,
140 R: Future<Output = anyhow::Result<()>>,
141 {
142 handle(self.inner.clone()).await
143 }
144
145 pub async fn alloc_skip(&self, skip_index: usize) -> (usize, T) {
147 loop {
148 if let Some(v) = self.try_alloc_skip(skip_index) {
149 return v;
150 }
151
152 let now = Instant::now();
153
154 let next = *self.inner.next_reset.read().await;
155
156 let remaining = if now < next {
157 next - now
158 } else {
159 Duration::ZERO
160 };
161
162 if remaining > Duration::ZERO {
163 sleep(remaining).await;
164 } else {
165 yield_now().await;
166 }
167 }
168 }
169
170 pub fn try_alloc_skip(&self, skip_index: usize) -> Option<(usize, T)> {
172 if let Ok(mut timer_guard) = self.inner.timer.try_lock() {
173 if timer_guard.is_none() {
174 let this = self.inner.clone();
175
176 *timer_guard = Some(spawn(async move {
177 let mut interval = *this.interval.read().await;
178
179 *this.next_reset.write().await = Instant::now() + interval;
180
181 loop {
182 sleep(match this.interval.try_read() {
183 Ok(v) => {
184 interval = *v;
185 interval
186 }
187 Err(_) => interval,
188 })
189 .await;
190
191 let now = Instant::now();
192
193 let entries = this.entries.read().await;
194
195 for entry in entries.iter() {
196 entry.count.store(0, Release);
197 }
198
199 *this.next_reset.write().await = now + interval;
200 }
201 }));
202 }
203 }
204
205 if let Ok(entries) = self.inner.entries.try_read() {
206 let mut skip_count = 0;
207
208 for (i, entry) in entries.iter().enumerate() {
209 if i == skip_index {
210 continue;
211 }
212
213 if entry.max_error_count != 0
214 && entry.error_count.load(Acquire) >= entry.max_error_count
215 {
216 skip_count += 1;
217 continue;
218 }
219
220 let count = entry.count.load(Acquire);
221
222 if entry.max_count == 0
223 || (count < entry.max_count
224 && entry
225 .count
226 .compare_exchange(count, count + 1, Release, Acquire)
227 .is_ok())
228 {
229 return Some((i, entry.value.clone()));
230 }
231 }
232
233 if skip_count == entries.len() {
234 return None;
235 }
236 }
237
238 None
239 }
240
241 pub fn success(&self, index: usize) {
243 if let Ok(entries) = self.inner.entries.try_read() {
244 if let Some(entry) = entries.get(index) {
245 let current = entry.error_count.load(Acquire);
246
247 if current != 0 {
248 let _ =
249 entry
250 .error_count
251 .compare_exchange(current, current - 1, Release, Acquire);
252 }
253 }
254 }
255 }
256
257 pub fn failure(&self, index: usize) {
259 if let Ok(entries) = self.inner.entries.try_read() {
260 if let Some(entry) = entries.get(index) {
261 entry.error_count.fetch_add(1, Release);
262 }
263 }
264 }
265}
266
267impl<T> LoadBalancer<T> for ThresholdLoadBalancer<T>
268where
269 T: Clone + Send + Sync + 'static,
270{
271 fn alloc(&self) -> impl Future<Output = T> + Send {
273 async move { self.alloc_skip(usize::MAX).await.1 }
274 }
275
276 fn try_alloc(&self) -> Option<T> {
278 self.try_alloc_skip(usize::MAX).map(|v| v.1)
279 }
280}
281
282#[async_trait]
283impl<T> BoxLoadBalancer<T> for ThresholdLoadBalancer<T>
284where
285 T: Send + Sync + Clone + 'static,
286{
287 async fn alloc(&self) -> T {
289 self.alloc_skip(usize::MAX).await.1
290 }
291
292 fn try_alloc(&self) -> Option<T> {
294 self.try_alloc_skip(usize::MAX).map(|v| v.1)
295 }
296}