load_balancer/
threshold.rs1use crate::{BoxLoadBalancer, LoadBalancer};
2use async_trait::async_trait;
3use std::sync::atomic::Ordering::{Acquire, Release};
4use std::{
5 future::Future,
6 sync::{Arc, atomic::AtomicU64},
7 time::Duration,
8};
9use tokio::{
10 spawn,
11 sync::{Mutex, RwLock},
12 task::{JoinHandle, yield_now},
13 time::sleep,
14};
15
/// One backend managed by the threshold balancer, together with its limits and live counters.
#[derive(Debug)]
pub struct Entry<T: Send + Sync + Clone + 'static> {
    /// Allocations permitted between two counter resets; `0` removes the cap.
    pub max_count: u64,
    /// Net failure count at which allocation skips this entry; `0` means failures never do.
    pub max_error_count: u64,
    /// Allocations counted against `max_count` since the last reset
    /// (left untouched when `max_count` is `0`).
    pub count: AtomicU64,
    /// Net failures recorded for this entry (raised by `failure`, lowered by `success`).
    pub error_count: AtomicU64,
    /// The value cloned out to callers whenever this entry is allocated.
    pub value: T,
}
34
35impl<T> Entry<T>
36where
37 T: Send + Sync + Clone + 'static,
38{
39 pub fn reset(&self) {
41 self.count.store(0, Release);
42 self.error_count.store(0, Release);
43 }
44
45 pub fn disable(&self) {
47 self.error_count.store(self.max_error_count, Release);
48 }
49}
50
51pub struct ThresholdLoadBalancerRef<T>
53where
54 T: Send + Sync + Clone + 'static,
55{
56 pub entries: RwLock<Vec<Entry<T>>>,
58 pub timer: Mutex<Option<JoinHandle<()>>>,
60 pub interval: RwLock<Duration>,
62}
63
64#[derive(Clone)]
66pub struct ThresholdLoadBalancer<T>
67where
68 T: Send + Sync + Clone + 'static,
69{
70 inner: Arc<ThresholdLoadBalancerRef<T>>,
71}
72
73impl<T> ThresholdLoadBalancer<T>
74where
75 T: Send + Sync + Clone + 'static,
76{
77 pub fn new(entries: Vec<(u64, u64, T)>, interval: Duration) -> Self {
79 Self {
80 inner: Arc::new(ThresholdLoadBalancerRef {
81 entries: entries
82 .into_iter()
83 .map(|(max_count, max_error_count, value)| Entry {
84 max_count,
85 max_error_count,
86 value,
87 count: 0.into(),
88 error_count: 0.into(),
89 })
90 .collect::<Vec<_>>()
91 .into(),
92 timer: Mutex::new(None),
93 interval: interval.into(),
94 }),
95 }
96 }
97
98 pub fn new_interval(entries: Vec<(u64, u64, T)>, interval: Duration) -> Self {
100 Self::new(entries, interval)
101 }
102
103 pub async fn update<F, R>(&self, handle: F) -> anyhow::Result<()>
105 where
106 F: Fn(Arc<ThresholdLoadBalancerRef<T>>) -> R,
107 R: Future<Output = anyhow::Result<()>>,
108 {
109 handle(self.inner.clone()).await
110 }
111
112 async fn alloc_skip(&self, skip_index: usize) -> (usize, T) {
114 loop {
115 match self.try_alloc_skip(skip_index) {
116 Some(v) => return v,
117 None => yield_now().await,
118 };
119 }
120 }
121
122 fn try_alloc_skip(&self, skip_index: usize) -> Option<(usize, T)> {
124 if let Ok(mut v) = self.inner.timer.try_lock() {
126 if v.is_none() {
127 let this = self.inner.clone();
128
129 *v = Some(spawn(async move {
130 let mut interval = *this.interval.read().await;
131
132 loop {
133 sleep(match this.interval.try_read() {
134 Ok(v) => {
135 interval = *v;
136 interval
137 }
138 Err(_) => interval,
139 })
140 .await;
141
142 for i in this.entries.read().await.iter() {
144 i.count.store(0, Release);
145 }
146 }
147 }));
148 }
149 }
150
151 if let Ok(entries) = self.inner.entries.try_read() {
153 let mut skip_count = 0;
154
155 for (i, entry) in entries.iter().enumerate() {
156 if i == skip_index {
157 continue;
158 }
159
160 if entry.max_error_count != 0
161 && entry.error_count.load(Acquire) >= entry.max_error_count
162 {
163 skip_count += 1;
164 continue;
165 }
166
167 let count = entry.count.load(Acquire);
168
169 if entry.max_count == 0
170 || (count < entry.max_count
171 && entry
172 .count
173 .compare_exchange(count, count + 1, Release, Acquire)
174 .is_ok())
175 {
176 return Some((i, entry.value.clone()));
177 }
178 }
179
180 if skip_count == entries.len() {
182 return None;
183 }
184 }
185
186 None
187 }
188
189 pub fn success(&self, index: usize) {
191 if let Ok(entries) = self.inner.entries.try_read() {
192 if let Some(entry) = entries.get(index) {
193 let current = entry.error_count.load(Acquire);
194
195 if current != 0 {
196 let _ =
197 entry
198 .error_count
199 .compare_exchange(current, current - 1, Release, Acquire);
200 }
201 }
202 }
203 }
204
205 pub fn failure(&self, index: usize) {
207 if let Ok(entries) = self.inner.entries.try_read() {
208 if let Some(entry) = entries.get(index) {
209 entry.error_count.fetch_add(1, Release);
210 }
211 }
212 }
213}
214
215impl<T> LoadBalancer<T> for ThresholdLoadBalancer<T>
216where
217 T: Clone + Send + Sync + 'static,
218{
219 fn alloc(&self) -> impl Future<Output = T> + Send {
221 async move { self.alloc_skip(usize::MAX).await.1 }
222 }
223
224 fn try_alloc(&self) -> Option<T> {
226 self.try_alloc_skip(usize::MAX).map(|v| v.1)
227 }
228}
229
230#[async_trait]
231impl<T> BoxLoadBalancer<T> for ThresholdLoadBalancer<T>
232where
233 T: Send + Sync + Clone + 'static,
234{
235 async fn alloc(&self) -> T {
237 self.alloc_skip(usize::MAX).await.1
238 }
239
240 fn try_alloc(&self) -> Option<T> {
242 self.try_alloc_skip(usize::MAX).map(|v| v.1)
243 }
244}