// ic_bn_lib/utils/distributor.rs

1use std::{
2    fmt::{Debug, Display},
3    sync::{
4        Arc, Mutex,
5        atomic::{AtomicU8, AtomicUsize, Ordering},
6    },
7    time::Instant,
8};
9
10use ic_bn_lib_common::traits::utils::ExecutesRequest;
11use prometheus::{
12    HistogramVec, IntCounterVec, IntGaugeVec, Registry, register_histogram_vec_with_registry,
13    register_int_counter_vec_with_registry, register_int_gauge_vec_with_registry,
14};
15use scopeguard::defer;
16use serde::{Deserialize, Serialize};
17use strum::{Display, EnumString};
18
/// Calculates the Greatest Common Divisor using the Euclidean algorithm.
///
/// Returns `x` when `y == 0` (hence `calc_gcd(0, 0) == 0`), instead of
/// hitting the division-by-zero panic the naive `a % b` loop would cause
/// when a zero weight follows a non-zero one.
const fn calc_gcd(x: isize, y: isize) -> isize {
    let mut a = x;
    let mut b = y;

    // Standard Euclidean reduction: gcd(a, b) == gcd(b, a mod b)
    while b != 0 {
        let t = a % b;
        a = b;
        b = t;
    }

    a
}
35
/// Prometheus metrics recorded by the `Distributor` for every request
#[derive(Clone, Debug)]
pub struct Metrics {
    // Current number of in-flight requests, labeled by target
    inflight: IntGaugeVec,
    // Total number of requests, labeled by target and result (ok/fail/cancel)
    requests: IntCounterVec,
    // Request duration histogram in seconds, labeled by target
    duration: HistogramVec,
}
42
43impl Metrics {
44    pub fn new(registry: &Registry) -> Self {
45        Self {
46            inflight: register_int_gauge_vec_with_registry!(
47                format!("distributor_inflight"),
48                format!("Stores the current number of in-flight requests"),
49                &["target"],
50                registry
51            )
52            .unwrap(),
53
54            requests: register_int_counter_vec_with_registry!(
55                format!("distributor_requests"),
56                format!("Counts the number of requests and results"),
57                &["target", "result"],
58                registry
59            )
60            .unwrap(),
61
62            duration: register_histogram_vec_with_registry!(
63                format!("distributor_duration"),
64                format!("Records the duration of requests in seconds"),
65                &["target"],
66                [0.01, 0.05, 0.1, 0.2, 0.4, 0.8, 1.6, 3.2].to_vec(),
67                registry
68            )
69            .unwrap(),
70        }
71    }
72}
73
/// Distribution strategy to use
#[derive(Debug, Clone, Copy, PartialEq, Eq, Display, EnumString, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum Strategy {
    /// Weighted round-robin: backends are picked cyclically, proportionally
    /// to their weights ("wrr" in string and serde form)
    #[strum(serialize = "wrr")]
    #[serde(alias = "wrr")]
    WeightedRoundRobin,
    /// Picks the backend with the fewest requests currently in flight ("lor")
    #[strum(serialize = "lor")]
    #[serde(alias = "lor")]
    LeastOutstandingRequests,
}
85
/// Backend that represents a target that receives the request
#[derive(Debug)]
pub struct Backend<T> {
    // The actual target the executor sends requests to
    backend: T,
    // Cached `Display` form of the backend, used as the metrics label
    name: String,
    // Relative weight for weighted round-robin selection
    weight: usize,
    // Number of requests currently executing on this backend (used by LOR)
    inflight: AtomicUsize,
}
94
95impl<T: Display + Send + Sync> Backend<T> {
96    pub fn new(backend: T, weight: usize) -> Self {
97        Self {
98            name: backend.to_string(),
99            backend,
100            weight,
101            inflight: AtomicUsize::new(0),
102        }
103    }
104}
105
/// Mutable state of the weighted round-robin scheduler
#[derive(Debug)]
struct Wrr {
    // Total number of backends
    n: isize,
    // Index of the last selected backend (-1 before the first pick)
    i: isize,
    // GCD of all backend weights; the step by which the threshold decreases
    gcd: isize,
    // Maximum weight across all backends
    max_weight: isize,
    // Current weight threshold a backend must meet to be selected
    curr_weight: isize,
}
114
115impl Wrr {
116    fn new<T>(backends: &[Backend<T>]) -> Self {
117        let mut gcd = 0;
118        let mut max_weight = 0;
119        for v in backends.iter() {
120            gcd = calc_gcd(gcd, v.weight as isize);
121
122            if v.weight > max_weight {
123                max_weight = v.weight;
124            }
125        }
126
127        Self {
128            n: backends.len() as isize,
129            i: -1,
130            gcd,
131            max_weight: max_weight as isize,
132            curr_weight: 0,
133        }
134    }
135}
136
/// Distributes the requests over backends using the given `Strategy`
#[derive(Debug)]
pub struct Distributor<T, RQ = (), RS = (), E = ()> {
    // The set of backends that requests are distributed over
    backends: Vec<Backend<T>>,
    // Selection algorithm (WRR or LOR)
    strategy: Strategy,
    // Executes the actual request against the selected backend
    executor: Arc<dyn ExecutesRequest<T, Request = RQ, Response = RS, Error = E>>,
    // Shared WRR scheduler state, guarded by a mutex
    wrr: Mutex<Wrr>,
    // Prometheus metrics recorded per request
    metrics: Metrics,
}
146
impl<T, RQ, RS, E> Distributor<T, RQ, RS, E>
where
    T: Clone + Display + Send + Sync,
    RQ: Send,
    RS: Send,
    E: Send,
{
    /// Creates a distributor over the given `(backend, weight)` pairs
    /// using the provided strategy and request executor.
    ///
    /// # Panics
    ///
    /// Panics if `backends` is empty.
    pub fn new(
        backends: &[(T, usize)],
        strategy: Strategy,
        executor: Arc<dyn ExecutesRequest<T, Request = RQ, Response = RS, Error = E>>,
        metrics: Metrics,
    ) -> Self {
        if backends.is_empty() {
            panic!("There must be at least one backend");
        }

        let backends = backends
            .iter()
            .map(|(b, w)| Backend::new(b.clone(), *w))
            .collect::<Vec<_>>();
        let wrr = Wrr::new(&backends);

        Self {
            backends,
            strategy,
            executor,
            wrr: Mutex::new(wrr),
            metrics,
        }
    }

    /// Picks the next backend to execute the request using the WRR algorithm.
    /// Based on http://kb.linuxvirtualserver.org/wiki/Weighted_Round-Robin_Scheduling
    fn next_wrr(&self) -> &Backend<T> {
        let mut wrr = self.wrr.lock().unwrap();

        loop {
            // Advance the cursor cyclically over the backends
            wrr.i = (wrr.i + 1) % wrr.n;
            if wrr.i == 0 {
                // Completed a full pass: lower the weight threshold by the GCD,
                // wrapping back to the maximum weight once it reaches zero
                wrr.curr_weight -= wrr.gcd;
                if wrr.curr_weight <= 0 {
                    wrr.curr_weight = wrr.max_weight;
                }
            }

            // Select the first backend whose weight meets the current threshold
            if (self.backends[wrr.i as usize].weight as isize) >= wrr.curr_weight {
                return &self.backends[wrr.i as usize];
            }
        }
    }

    /// Picks the next backend to execute the request using the Least Outstanding
    /// Requests algorithm. Ties are resolved in favor of the earliest backend
    /// in the list (`min_by_key` returns the first minimum).
    fn next_lor(&self) -> &Backend<T> {
        self.backends
            .iter()
            .min_by_key(|x| x.inflight.load(Ordering::SeqCst))
            .unwrap()
    }

    /// Executes the request using the next backend picked by the selected
    /// algorithm, recording in-flight, duration and result metrics for it —
    /// even if the returned future is cancelled mid-flight.
    pub async fn execute(&self, request: RQ) -> Result<RS, E> {
        let backend = match self.strategy {
            Strategy::LeastOutstandingRequests => self.next_lor(),
            Strategy::WeightedRoundRobin => self.next_wrr(),
        };

        // Track this request as in-flight on the chosen backend
        backend.inflight.fetch_add(1, Ordering::SeqCst);
        self.metrics
            .inflight
            .with_label_values(&[&backend.name])
            .inc();

        let start = Instant::now();
        // Result flag: 0 = cancelled (never set), 1 = ok, 2 = fail
        let ok = Arc::new(AtomicU8::new(0));
        let ok_clone = ok.clone();

        // Record metrics under defer to make sure they're recorded in case of future cancellation
        defer! {
            backend.inflight.fetch_sub(1, Ordering::SeqCst);
            self.metrics.inflight.with_label_values(&[&backend.name]).dec();
            self.metrics
                .duration
                .with_label_values(&[&backend.name])
                .observe(start.elapsed().as_secs_f64());
            // If the future was dropped before the store below ran, the flag
            // is still 0 and the request is counted as "cancel"
            self.metrics
                .requests
                .with_label_values(&[
                    backend.name.as_str(),
                    match ok_clone.load(Ordering::SeqCst) {
                        1 => "ok",
                        2 => "fail",
                        _ => "cancel"
                    }])
                .inc();
        }

        let res = self.executor.execute(&backend.backend, request).await;
        ok.store(if res.is_ok() { 1 } else { 2 }, Ordering::SeqCst);
        res
    }
}
249
#[cfg(test)]
pub(crate) mod test {
    use std::{collections::HashMap, time::Duration};

    use async_trait::async_trait;
    use tokio::task::JoinSet;

    use super::*;

    /// Test executor that sleeps for the configured duration (field 0) and
    /// counts the number of calls per backend name (field 1)
    #[derive(Debug)]
    pub struct TestExecutor(pub Duration, pub Mutex<HashMap<String, usize>>);

    #[async_trait]
    impl ExecutesRequest<String> for TestExecutor {
        type Error = ();
        type Request = ();
        type Response = ();

        async fn execute(
            &self,
            backend: &String,
            _req: Self::Request,
        ) -> Result<Self::Response, Self::Error> {
            // Count the call for this backend
            *self.1.lock().unwrap().entry(backend.clone()).or_insert(0) += 1;
            // Optionally sleep to keep the request outstanding for a while
            if self.0 > Duration::ZERO {
                tokio::time::sleep(self.0).await;
            }
            Ok(())
        }
    }

    #[tokio::test]
    async fn test_distributor_wrr() {
        let backends = vec![
            ("foo".to_string(), 2),
            ("bar".to_string(), 3),
            ("baz".to_string(), 5),
        ];

        let executor = Arc::new(TestExecutor(Duration::ZERO, Mutex::new(HashMap::new())));
        let metrics = Metrics::new(&Registry::new());
        let d = Distributor::new(
            &backends,
            Strategy::WeightedRoundRobin,
            executor.clone(),
            metrics,
        );

        // Do 1k backend selections
        for _ in 0..1000 {
            let _ = d.execute(()).await;
        }

        // Make sure that we get the distribution according to the weights
        let h = executor.1.lock().unwrap();
        assert_eq!(h["foo"], 200);
        assert_eq!(h["bar"], 300);
        assert_eq!(h["baz"], 500);
        drop(h)
    }

    #[tokio::test]
    async fn test_distributor_lor() {
        let backends = vec![
            ("foo".to_string(), 2),
            ("bar".to_string(), 3),
            ("baz".to_string(), 5),
        ];

        // 1s sleep keeps every request in-flight so requests pile up evenly
        let executor = Arc::new(TestExecutor(
            Duration::from_secs(1),
            Mutex::new(HashMap::new()),
        ));

        let metrics = Metrics::new(&Registry::new());
        let d = Arc::new(Distributor::new(
            &backends,
            Strategy::LeastOutstandingRequests,
            executor.clone(),
            metrics,
        ));

        let mut js = JoinSet::new();
        // Spawn 60 concurrent requests
        for _ in 0..60 {
            let d = d.clone();
            js.spawn(async move {
                let _ = d.execute(()).await;
            });
        }

        js.join_all().await;

        // Make sure that we get even distribution since the requests are accumulated on each node evenly
        // due to sleep
        let h = executor.1.lock().unwrap();
        assert_eq!(h["foo"], 20);
        assert_eq!(h["bar"], 20);
        assert_eq!(h["baz"], 20);
        drop(h)
    }
}