load_balancer/
threshold.rs1use crate::{BoxLoadBalancer, LoadBalancer};
2use async_trait::async_trait;
3use std::sync::atomic::Ordering::{Acquire, Release};
4use std::{
5 future::Future,
6 sync::{Arc, atomic::AtomicU64},
7 time::Duration,
8};
9use tokio::{
10 spawn,
11 sync::{Mutex, RwLock},
12 task::{JoinHandle, yield_now},
13 time::sleep,
14};
15
/// One backend tracked by the threshold load balancer.
///
/// The counters are atomics so they can be updated through a shared reference while the
/// entry list is only read-locked.
pub struct Entry<T: Send + Sync + Clone + 'static> {
    /// Allocations allowed between two timer resets of `count`; `0` means unlimited.
    pub max_count: u64,
    /// Value of `error_count` at which the entry stops being allocated; `0` disables the check.
    pub max_error_count: u64,
    /// Allocations handed out since the timer last zeroed this counter.
    pub count: AtomicU64,
    /// Failure counter: raised by `failure`, lowered by `success`, zeroed by `reset`.
    pub error_count: AtomicU64,
    /// Payload handed out (as a clone) whenever this entry is allocated.
    pub value: T,
}
33
34impl<T> Entry<T>
35where
36 T: Send + Sync + Clone + 'static,
37{
38 pub fn reset(&self) {
40 self.count.store(0, Release);
41 self.error_count.store(0, Release);
42 }
43
44 pub fn disable(&self) {
46 self.error_count.store(self.max_error_count, Release);
47 }
48}
49
50impl<T> Clone for Entry<T>
51where
52 T: Send + Sync + Clone + 'static,
53{
54 fn clone(&self) -> Self {
55 Self {
56 max_count: self.max_count.clone(),
57 max_error_count: self.max_error_count.clone(),
58 count: self.count.load(Acquire).into(),
59 error_count: self.error_count.load(Acquire).into(),
60 value: self.value.clone(),
61 }
62 }
63}
64
65pub struct ThresholdLoadBalancerRef<T>
67where
68 T: Send + Sync + Clone + 'static,
69{
70 pub entries: RwLock<Vec<Entry<T>>>,
72 pub timer: Mutex<Option<JoinHandle<()>>>,
74 pub interval: RwLock<Duration>,
76}
77
78#[derive(Clone)]
80pub struct ThresholdLoadBalancer<T>
81where
82 T: Send + Sync + Clone + 'static,
83{
84 inner: Arc<ThresholdLoadBalancerRef<T>>,
85}
86
87impl<T> ThresholdLoadBalancer<T>
88where
89 T: Send + Sync + Clone + 'static,
90{
91 pub fn new(entries: Vec<(u64, u64, T)>, interval: Duration) -> Self {
93 Self {
94 inner: Arc::new(ThresholdLoadBalancerRef {
95 entries: entries
96 .into_iter()
97 .map(|(max_count, max_error_count, value)| Entry {
98 max_count,
99 max_error_count,
100 value,
101 count: 0.into(),
102 error_count: 0.into(),
103 })
104 .collect::<Vec<_>>()
105 .into(),
106 timer: Mutex::new(None),
107 interval: interval.into(),
108 }),
109 }
110 }
111
112 pub fn new_interval(entries: Vec<(u64, u64, T)>, interval: Duration) -> Self {
114 Self::new(entries, interval)
115 }
116
117 pub async fn update<F, R>(&self, handle: F) -> anyhow::Result<()>
119 where
120 F: Fn(Arc<ThresholdLoadBalancerRef<T>>) -> R,
121 R: Future<Output = anyhow::Result<()>>,
122 {
123 handle(self.inner.clone()).await
124 }
125
126 async fn alloc_skip(&self, skip_index: usize) -> (usize, T) {
128 loop {
129 match self.try_alloc_skip(skip_index) {
130 Some(v) => return v,
131 None => yield_now().await,
132 };
133 }
134 }
135
136 fn try_alloc_skip(&self, skip_index: usize) -> Option<(usize, T)> {
138 if let Ok(mut v) = self.inner.timer.try_lock() {
140 if v.is_none() {
141 let this = self.inner.clone();
142
143 *v = Some(spawn(async move {
144 let mut interval = *this.interval.read().await;
145
146 loop {
147 sleep(match this.interval.try_read() {
148 Ok(v) => {
149 interval = *v;
150 interval
151 }
152 Err(_) => interval,
153 })
154 .await;
155
156 for i in this.entries.read().await.iter() {
158 i.count.store(0, Release);
159 }
160 }
161 }));
162 }
163 }
164
165 if let Ok(entries) = self.inner.entries.try_read() {
167 let mut skip_count = 0;
168
169 for (i, entry) in entries.iter().enumerate() {
170 if i == skip_index {
171 continue;
172 }
173
174 if entry.max_error_count != 0
175 && entry.error_count.load(Acquire) >= entry.max_error_count
176 {
177 skip_count += 1;
178 continue;
179 }
180
181 let count = entry.count.load(Acquire);
182
183 if entry.max_count == 0
184 || (count < entry.max_count
185 && entry
186 .count
187 .compare_exchange(count, count + 1, Release, Acquire)
188 .is_ok())
189 {
190 return Some((i, entry.value.clone()));
191 }
192 }
193
194 if skip_count == entries.len() {
196 return None;
197 }
198 }
199
200 None
201 }
202
203 pub fn success(&self, index: usize) {
205 if let Ok(entries) = self.inner.entries.try_read() {
206 if let Some(entry) = entries.get(index) {
207 let current = entry.error_count.load(Acquire);
208
209 if current != 0 {
210 let _ =
211 entry
212 .error_count
213 .compare_exchange(current, current - 1, Release, Acquire);
214 }
215 }
216 }
217 }
218
219 pub fn failure(&self, index: usize) {
221 if let Ok(entries) = self.inner.entries.try_read() {
222 if let Some(entry) = entries.get(index) {
223 entry.error_count.fetch_add(1, Release);
224 }
225 }
226 }
227}
228
229impl<T> LoadBalancer<T> for ThresholdLoadBalancer<T>
230where
231 T: Clone + Send + Sync + 'static,
232{
233 fn alloc(&self) -> impl Future<Output = T> + Send {
235 async move { self.alloc_skip(usize::MAX).await.1 }
236 }
237
238 fn try_alloc(&self) -> Option<T> {
240 self.try_alloc_skip(usize::MAX).map(|v| v.1)
241 }
242}
243
244#[async_trait]
245impl<T> BoxLoadBalancer<T> for ThresholdLoadBalancer<T>
246where
247 T: Send + Sync + Clone + 'static,
248{
249 async fn alloc(&self) -> T {
251 self.alloc_skip(usize::MAX).await.1
252 }
253
254 fn try_alloc(&self) -> Option<T> {
256 self.try_alloc_skip(usize::MAX).map(|v| v.1)
257 }
258}