load_balancer/
threshold.rs1use crate::{BoxLoadBalancer, LoadBalancer};
2use async_trait::async_trait;
3use std::sync::atomic::Ordering::{Acquire, Release};
4use std::{
5 future::Future,
6 sync::{Arc, atomic::AtomicU64},
7 time::Duration,
8};
9use tokio::{
10 spawn,
11 sync::{Mutex, RwLock},
12 task::{JoinHandle, yield_now},
13 time::sleep,
14};
15
/// One backend of a [`ThresholdLoadBalancer`] together with its usage counters.
///
/// A limit of `0` (for either `max_count` or `max_error_count`) disables that check.
pub struct Entry<T: Send + Sync + Clone + 'static> {
    /// Allocations allowed between two count resets (`0` = unlimited).
    pub max_count: u64,
    /// Error count at which allocation skips this entry (`0` = never skipped).
    pub max_error_count: u64,
    /// Allocations made since the last reset.
    pub count: AtomicU64,
    /// Recorded failures, decremented again on recorded successes.
    pub error_count: AtomicU64,
    /// The value handed out (cloned) when this entry is allocated.
    pub value: T,
}
33
34impl<T> Entry<T>
35where
36 T: Send + Sync + Clone + 'static,
37{
38 pub fn reset(&self) {
40 self.count.store(0, Release);
41 self.error_count.store(0, Release);
42 }
43
44 pub fn disable(&self) {
46 self.error_count.store(self.max_error_count, Release);
47 }
48}
49
50impl<T> Clone for Entry<T>
51where
52 T: Send + Sync + Clone + 'static,
53{
54 fn clone(&self) -> Self {
55 Self {
56 max_count: self.max_count.clone(),
57 max_error_count: self.max_error_count.clone(),
58 count: self.count.load(Acquire).into(),
59 error_count: self.error_count.load(Acquire).into(),
60 value: self.value.clone(),
61 }
62 }
63}
64
65pub struct ThresholdLoadBalancerRef<T>
67where
68 T: Send + Sync + Clone + 'static,
69{
70 pub entries: RwLock<Vec<Entry<T>>>,
72 pub timer: Mutex<Option<JoinHandle<()>>>,
74 pub interval: RwLock<Duration>,
76}
77
78#[derive(Clone)]
80pub struct ThresholdLoadBalancer<T>
81where
82 T: Send + Sync + Clone + 'static,
83{
84 inner: Arc<ThresholdLoadBalancerRef<T>>,
85}
86
87impl<T> ThresholdLoadBalancer<T>
88where
89 T: Send + Sync + Clone + 'static,
90{
91 pub fn new(entries: Vec<(u64, u64, T)>) -> Self {
100 Self::new_interval(entries, Duration::from_secs(1))
101 }
102
103 pub fn new_interval(entries: Vec<(u64, u64, T)>, interval: Duration) -> Self {
114 Self {
115 inner: Arc::new(ThresholdLoadBalancerRef {
116 entries: entries
117 .into_iter()
118 .map(|(max_count, max_error_count, value)| Entry {
119 max_count,
120 max_error_count,
121 value,
122 count: 0.into(),
123 error_count: 0.into(),
124 })
125 .collect::<Vec<_>>()
126 .into(),
127 timer: Mutex::new(None),
128 interval: interval.into(),
129 }),
130 }
131 }
132
133 pub async fn update<F, R>(&self, handle: F) -> anyhow::Result<()>
135 where
136 F: Fn(Arc<ThresholdLoadBalancerRef<T>>) -> R,
137 R: Future<Output = anyhow::Result<()>>,
138 {
139 handle(self.inner.clone()).await
140 }
141
142 pub async fn alloc_skip(&self, skip_index: usize) -> (usize, T) {
144 loop {
145 match self.try_alloc_skip(skip_index) {
146 Some(v) => return v,
147 None => yield_now().await,
148 };
149 }
150 }
151
152 pub fn try_alloc_skip(&self, skip_index: usize) -> Option<(usize, T)> {
154 if let Ok(mut v) = self.inner.timer.try_lock() {
155 if v.is_none() {
156 let this = self.inner.clone();
157
158 *v = Some(spawn(async move {
159 let mut interval = *this.interval.read().await;
160
161 loop {
162 sleep(match this.interval.try_read() {
163 Ok(v) => {
164 interval = *v;
165 interval
166 }
167 Err(_) => interval,
168 })
169 .await;
170
171 for i in this.entries.read().await.iter() {
173 i.count.store(0, Release);
174 }
175 }
176 }));
177 }
178 }
179
180 if let Ok(entries) = self.inner.entries.try_read() {
181 let mut skip_count = 0;
182
183 for (i, entry) in entries.iter().enumerate() {
184 if i == skip_index {
185 continue;
186 }
187
188 if entry.max_error_count != 0
189 && entry.error_count.load(Acquire) >= entry.max_error_count
190 {
191 skip_count += 1;
192 continue;
193 }
194
195 let count = entry.count.load(Acquire);
196
197 if entry.max_count == 0
198 || (count < entry.max_count
199 && entry
200 .count
201 .compare_exchange(count, count + 1, Release, Acquire)
202 .is_ok())
203 {
204 return Some((i, entry.value.clone()));
205 }
206 }
207
208 if skip_count == entries.len() {
210 return None;
211 }
212 }
213
214 None
215 }
216
217 pub fn success(&self, index: usize) {
219 if let Ok(entries) = self.inner.entries.try_read() {
220 if let Some(entry) = entries.get(index) {
221 let current = entry.error_count.load(Acquire);
222
223 if current != 0 {
224 let _ =
225 entry
226 .error_count
227 .compare_exchange(current, current - 1, Release, Acquire);
228 }
229 }
230 }
231 }
232
233 pub fn failure(&self, index: usize) {
235 if let Ok(entries) = self.inner.entries.try_read() {
236 if let Some(entry) = entries.get(index) {
237 entry.error_count.fetch_add(1, Release);
238 }
239 }
240 }
241}
242
243impl<T> LoadBalancer<T> for ThresholdLoadBalancer<T>
244where
245 T: Clone + Send + Sync + 'static,
246{
247 fn alloc(&self) -> impl Future<Output = T> + Send {
249 async move { self.alloc_skip(usize::MAX).await.1 }
250 }
251
252 fn try_alloc(&self) -> Option<T> {
254 self.try_alloc_skip(usize::MAX).map(|v| v.1)
255 }
256}
257
258#[async_trait]
259impl<T> BoxLoadBalancer<T> for ThresholdLoadBalancer<T>
260where
261 T: Send + Sync + Clone + 'static,
262{
263 async fn alloc(&self) -> T {
265 self.alloc_skip(usize::MAX).await.1
266 }
267
268 fn try_alloc(&self) -> Option<T> {
270 self.try_alloc_skip(usize::MAX).map(|v| v.1)
271 }
272}