1use std::{
2 fmt::{Debug, Display},
3 sync::{
4 Arc, Mutex,
5 atomic::{AtomicU8, AtomicUsize, Ordering},
6 },
7 time::Instant,
8};
9
10use ic_bn_lib_common::traits::utils::ExecutesRequest;
11use prometheus::{
12 HistogramVec, IntCounterVec, IntGaugeVec, Registry, register_histogram_vec_with_registry,
13 register_int_counter_vec_with_registry, register_int_gauge_vec_with_registry,
14};
15use scopeguard::defer;
16use serde::{Deserialize, Serialize};
17use strum::{Display, EnumString};
18
/// Computes the greatest common divisor of `x` and `y` using Euclid's algorithm.
///
/// A zero operand is treated as the identity: `calc_gcd(x, 0)` and
/// `calc_gcd(0, x)` both yield `|x|`, and `calc_gcd(0, 0)` is `0`. This makes
/// the function safe to fold over a list of values starting from `0`.
///
/// The sign of the inputs is ignored; the result is non-negative for every
/// input whose absolute value is representable as an `isize`.
const fn calc_gcd(x: isize, y: isize) -> isize {
    let mut a = x;
    let mut b = y;

    // Looping until the divisor is zero is the classic termination condition
    // and also guarantees that `a % 0` (which panics) is never evaluated.
    while b != 0 {
        let t = a % b;
        a = b;
        b = t;
    }

    // `%` keeps the sign of the dividend, so the last non-zero remainder may
    // be negative; normalise it.
    a.wrapping_abs()
}
35
36#[derive(Clone, Debug)]
37pub struct Metrics {
38 inflight: IntGaugeVec,
39 requests: IntCounterVec,
40 duration: HistogramVec,
41}
42
43impl Metrics {
44 pub fn new(registry: &Registry) -> Self {
45 Self {
46 inflight: register_int_gauge_vec_with_registry!(
47 format!("distributor_inflight"),
48 format!("Stores the current number of in-flight requests"),
49 &["target"],
50 registry
51 )
52 .unwrap(),
53
54 requests: register_int_counter_vec_with_registry!(
55 format!("distributor_requests"),
56 format!("Counts the number of requests and results"),
57 &["target", "result"],
58 registry
59 )
60 .unwrap(),
61
62 duration: register_histogram_vec_with_registry!(
63 format!("distributor_duration"),
64 format!("Records the duration of requests in seconds"),
65 &["target"],
66 [0.01, 0.05, 0.1, 0.2, 0.4, 0.8, 1.6, 3.2].to_vec(),
67 registry
68 )
69 .unwrap(),
70 }
71 }
72}
73
74#[derive(Debug, Clone, Copy, PartialEq, Eq, Display, EnumString, Serialize, Deserialize)]
76#[serde(rename_all = "snake_case")]
77pub enum Strategy {
78 #[strum(serialize = "wrr")]
79 #[serde(alias = "wrr")]
80 WeightedRoundRobin,
81 #[strum(serialize = "lor")]
82 #[serde(alias = "lor")]
83 LeastOutstandingRequests,
84}
85
/// A backend of type `T` together with the bookkeeping used for balancing.
#[derive(Debug)]
pub struct Backend<T> {
    /// The wrapped backend value.
    backend: T,
    /// `Display` rendering of the backend, captured once at construction.
    name: String,
    /// Relative weight of this backend.
    weight: usize,
    /// Number of requests currently in flight on this backend.
    inflight: AtomicUsize,
}

impl<T> Backend<T>
where
    T: Display + Send + Sync,
{
    /// Wraps `backend` with the given `weight` and an in-flight counter of zero.
    ///
    /// The backend's name is taken from its `Display` implementation.
    pub fn new(backend: T, weight: usize) -> Self {
        // Render the name before `backend` is moved into the struct.
        let name = backend.to_string();
        let inflight = AtomicUsize::new(0);

        Self {
            backend,
            name,
            weight,
            inflight,
        }
    }
}
105
106#[derive(Debug)]
107struct Wrr {
108 n: isize,
109 i: isize,
110 gcd: isize,
111 max_weight: isize,
112 curr_weight: isize,
113}
114
115impl Wrr {
116 fn new<T>(backends: &[Backend<T>]) -> Self {
117 let mut gcd = 0;
118 let mut max_weight = 0;
119 for v in backends.iter() {
120 gcd = calc_gcd(gcd, v.weight as isize);
121
122 if v.weight > max_weight {
123 max_weight = v.weight;
124 }
125 }
126
127 Self {
128 n: backends.len() as isize,
129 i: -1,
130 gcd,
131 max_weight: max_weight as isize,
132 curr_weight: 0,
133 }
134 }
135}
136
137#[derive(Debug)]
139pub struct Distributor<T, RQ = (), RS = (), E = ()> {
140 backends: Vec<Backend<T>>,
141 strategy: Strategy,
142 executor: Arc<dyn ExecutesRequest<T, Request = RQ, Response = RS, Error = E>>,
143 wrr: Mutex<Wrr>,
144 metrics: Metrics,
145}
146
147impl<T, RQ, RS, E> Distributor<T, RQ, RS, E>
148where
149 T: Clone + Display + Send + Sync,
150 RQ: Send,
151 RS: Send,
152 E: Send,
153{
154 pub fn new(
155 backends: &[(T, usize)],
156 strategy: Strategy,
157 executor: Arc<dyn ExecutesRequest<T, Request = RQ, Response = RS, Error = E>>,
158 metrics: Metrics,
159 ) -> Self {
160 if backends.is_empty() {
161 panic!("There must be at least one backend");
162 }
163
164 let backends = backends
165 .iter()
166 .map(|(b, w)| Backend::new(b.clone(), *w))
167 .collect::<Vec<_>>();
168 let wrr = Wrr::new(&backends);
169
170 Self {
171 backends,
172 strategy,
173 executor,
174 wrr: Mutex::new(wrr),
175 metrics,
176 }
177 }
178
179 fn next_wrr(&self) -> &Backend<T> {
182 let mut wrr = self.wrr.lock().unwrap();
183
184 loop {
185 wrr.i = (wrr.i + 1) % wrr.n;
186 if wrr.i == 0 {
187 wrr.curr_weight -= wrr.gcd;
188 if wrr.curr_weight <= 0 {
189 wrr.curr_weight = wrr.max_weight;
190 }
191 }
192
193 if (self.backends[wrr.i as usize].weight as isize) >= wrr.curr_weight {
194 return &self.backends[wrr.i as usize];
195 }
196 }
197 }
198
199 fn next_lor(&self) -> &Backend<T> {
201 self.backends
202 .iter()
203 .min_by_key(|x| x.inflight.load(Ordering::SeqCst))
204 .unwrap()
205 }
206
207 pub async fn execute(&self, request: RQ) -> Result<RS, E> {
209 let backend = match self.strategy {
210 Strategy::LeastOutstandingRequests => self.next_lor(),
211 Strategy::WeightedRoundRobin => self.next_wrr(),
212 };
213
214 backend.inflight.fetch_add(1, Ordering::SeqCst);
215 self.metrics
216 .inflight
217 .with_label_values(&[&backend.name])
218 .inc();
219
220 let start = Instant::now();
221 let ok = Arc::new(AtomicU8::new(0));
222 let ok_clone = ok.clone();
223
224 defer! {
226 backend.inflight.fetch_sub(1, Ordering::SeqCst);
227 self.metrics.inflight.with_label_values(&[&backend.name]).dec();
228 self.metrics
229 .duration
230 .with_label_values(&[&backend.name])
231 .observe(start.elapsed().as_secs_f64());
232 self.metrics
233 .requests
234 .with_label_values(&[
235 backend.name.as_str(),
236 match ok_clone.load(Ordering::SeqCst) {
237 1 => "ok",
238 2 => "fail",
239 _ => "cancel"
240 }])
241 .inc();
242 }
243
244 let res = self.executor.execute(&backend.backend, request).await;
245 ok.store(if res.is_ok() { 1 } else { 2 }, Ordering::SeqCst);
246 res
247 }
248}
249
#[cfg(test)]
pub(crate) mod test {
    use std::{collections::HashMap, time::Duration};

    use async_trait::async_trait;
    use tokio::task::JoinSet;

    use super::*;

    /// Executor used in tests.
    ///
    /// Field `0` is the delay applied to every request, field `1` counts the
    /// number of requests seen per backend.
    #[derive(Debug)]
    pub struct TestExecutor(pub Duration, pub Mutex<HashMap<String, usize>>);

    #[async_trait]
    impl ExecutesRequest<String> for TestExecutor {
        type Error = ();
        type Request = ();
        type Response = ();

        async fn execute(
            &self,
            backend: &String,
            _req: Self::Request,
        ) -> Result<Self::Response, Self::Error> {
            // Scope the guard so it is released before the await below.
            {
                let mut calls = self.1.lock().unwrap();
                *calls.entry(backend.clone()).or_default() += 1;
            }

            if !self.0.is_zero() {
                tokio::time::sleep(self.0).await;
            }

            Ok(())
        }
    }

    /// Backends shared by the tests, with weights 2, 3 and 5.
    fn weighted_backends() -> Vec<(String, usize)> {
        vec![
            (String::from("foo"), 2),
            (String::from("bar"), 3),
            (String::from("baz"), 5),
        ]
    }

    /// Builds a test executor with the given per-request delay and no recorded calls.
    fn new_executor(delay: Duration) -> Arc<TestExecutor> {
        Arc::new(TestExecutor(delay, Mutex::new(HashMap::new())))
    }

    #[tokio::test]
    async fn test_distributor_wrr() {
        let backends = weighted_backends();
        let executor = new_executor(Duration::ZERO);
        let metrics = Metrics::new(&Registry::new());
        let d = Distributor::new(
            &backends,
            Strategy::WeightedRoundRobin,
            executor.clone(),
            metrics,
        );

        // Sequential requests: 1000 is a multiple of the weight sum (10).
        for _ in 0..1000 {
            let _ = d.execute(()).await;
        }

        // The split must follow the 2:3:5 weights exactly.
        let counts = executor.1.lock().unwrap();
        for (name, expected) in [("foo", 200), ("bar", 300), ("baz", 500)] {
            assert_eq!(counts[name], expected);
        }
        drop(counts);
    }

    #[tokio::test]
    async fn test_distributor_lor() {
        let backends = weighted_backends();
        let executor = new_executor(Duration::from_secs(1));
        let metrics = Metrics::new(&Registry::new());
        let d = Arc::new(Distributor::new(
            &backends,
            Strategy::LeastOutstandingRequests,
            executor.clone(),
            metrics,
        ));

        // Each request sleeps inside the executor, so the requests overlap
        // and the in-flight counters drive the selection.
        let mut js = JoinSet::new();
        for _ in 0..60 {
            let d = Arc::clone(&d);
            js.spawn(async move {
                let _ = d.execute(()).await;
            });
        }

        js.join_all().await;

        // Least-outstanding-requests ignores weights: an even three-way split.
        let counts = executor.1.lock().unwrap();
        for name in ["foo", "bar", "baz"] {
            assert_eq!(counts[name], 20);
        }
        drop(counts);
    }
}