Skip to main content

ic_bn_lib/utils/
backend_router.rs

1use std::{
2    fmt::{Debug, Display},
3    sync::Arc,
4    time::Duration,
5};
6
7use arc_swap::{ArcSwap, ArcSwapOption};
8use ic_bn_lib_common::{
9    traits::utils::{ChecksTarget, ExecutesRequest},
10    types::utils::TargetState,
11};
12use tokio::{select, sync::watch::Receiver};
13use tokio_util::{sync::CancellationToken, task::TaskTracker};
14
15use crate::utils::{
16    distributor::{self, Distributor, Strategy},
17    health_check::{self, HealthChecker},
18};
19
20#[derive(thiserror::Error)]
21pub enum Error<E> {
22    #[error("No healthy nodes")]
23    NoHealthyNodes,
24    #[error("{0}")]
25    Inner(E),
26}
27
28struct Actor<T, RQ = (), RS = (), E = ()> {
29    weights: Vec<usize>,
30    health_checker: Arc<HealthChecker<T>>,
31    strategy: Strategy,
32    executor: Arc<dyn ExecutesRequest<T, Request = RQ, Response = RS, Error = E>>,
33    distributor: Arc<ArcSwapOption<Distributor<T, RQ, RS, E>>>,
34    distributor_metrics: distributor::Metrics,
35    healthy: Arc<ArcSwap<Vec<T>>>,
36}
37
38impl<T, RQ, RS, E> Actor<T, RQ, RS, E>
39where
40    T: Clone + Display + Debug + Send + Sync + 'static,
41    RQ: Send + 'static,
42    RS: Send + 'static,
43    E: Send + 'static,
44{
45    /// Create a new Distributor with a healthy node set
46    async fn process(&self, backends: Arc<Vec<(T, TargetState)>>) {
47        // Combine the nodes with their weights
48        // and filter out unhealthy ones.
49        let healthy = backends
50            .iter()
51            .zip(&self.weights)
52            .filter(|x| x.0.1 == TargetState::Healthy)
53            .map(|x| (x.0.0.clone(), *x.1))
54            .collect::<Vec<_>>();
55
56        // If there are no healthy nodes - remove the distributor
57        if healthy.is_empty() {
58            self.distributor.store(None);
59            return;
60        }
61
62        let distributor = Distributor::new(
63            &healthy,
64            self.strategy,
65            self.executor.clone(),
66            self.distributor_metrics.clone(),
67        );
68        self.distributor.store(Some(Arc::new(distributor)));
69        self.healthy
70            .store(Arc::new(healthy.into_iter().map(|x| x.0).collect()));
71    }
72
73    async fn run(&self, token: CancellationToken) {
74        // Subscribe to state notifications from HealthChecker
75        let mut rx = self.health_checker.subscribe();
76
77        loop {
78            select! {
79                biased;
80
81                // Check if we need to shut down
82                _ = token.cancelled() => {
83                    self.health_checker.stop().await;
84                    return;
85                }
86
87                // Process the changes in the set of healthy backends
88                Ok(()) = rx.changed() => {
89                    let backends = rx.borrow_and_update().clone();
90                    self.process(backends).await;
91                }
92            }
93        }
94    }
95}
96
97/// Routes the request to healthy nodes provided by HealthChecker.
98/// Uses Distributor with given Strategy to distribute them.
99#[derive(Debug)]
100pub struct BackendRouter<T, RQ = (), RS = (), E = ()> {
101    token: CancellationToken,
102    tracker: TaskTracker,
103    distributor: Arc<ArcSwapOption<Distributor<T, RQ, RS, E>>>,
104    notify: Receiver<Arc<Vec<(T, TargetState)>>>,
105    healthy: Arc<ArcSwap<Vec<T>>>,
106}
107
108impl<T, RQ, RS, E> BackendRouter<T, RQ, RS, E>
109where
110    T: Clone + Display + Debug + Send + Sync + 'static,
111    RQ: Send + 'static,
112    RS: Send + 'static,
113    E: Send + 'static,
114{
115    /// Create a new BackendRouter
116    pub fn new(
117        backends: &[(T, usize)],
118        executor: Arc<dyn ExecutesRequest<T, Request = RQ, Response = RS, Error = E>>,
119        checker: Arc<dyn ChecksTarget<T>>,
120        strategy: Strategy,
121        check_interval: Duration,
122        health_check_metrics: health_check::Metrics,
123        distributor_metrics: distributor::Metrics,
124    ) -> Self {
125        // Collect the weights for the Actor
126        let weights = backends.iter().map(|x| x.1).collect();
127        // Collect backends w/o weights for the HealthChecker
128        let backends = backends.iter().map(|x| x.0.clone()).collect::<Vec<_>>();
129
130        let health_checker = Arc::new(HealthChecker::new(
131            &backends,
132            checker,
133            check_interval,
134            health_check_metrics,
135        ));
136        let notify = health_checker.subscribe();
137
138        let distributor = Arc::new(ArcSwapOption::empty());
139        let healthy = Arc::new(ArcSwap::new(Arc::new(vec![])));
140
141        let actor = Actor {
142            weights,
143            health_checker,
144            strategy,
145            executor,
146            distributor: distributor.clone(),
147            distributor_metrics,
148            healthy: healthy.clone(),
149        };
150
151        let token = CancellationToken::new();
152        let tracker = TaskTracker::new();
153
154        let child_token = token.child_token();
155        tracker.spawn(async move {
156            actor.run(child_token).await;
157        });
158
159        Self {
160            token,
161            tracker,
162            distributor,
163            notify,
164            healthy,
165        }
166    }
167
168    /// Executes the request
169    pub async fn execute(&self, request: RQ) -> Result<RS, Error<E>> {
170        let Some(distributor) = self.distributor.load_full() else {
171            return Err(Error::NoHealthyNodes);
172        };
173
174        distributor
175            .execute(request)
176            .await
177            .map_err(|e| Error::Inner(e))
178    }
179
180    /// Subscribes to notifications when the set of healthy nodes changes.
181    /// Returns a channel which emits a new set of healthy nodes.
182    pub fn subscribe(&self) -> Receiver<Arc<Vec<(T, TargetState)>>> {
183        self.notify.clone()
184    }
185
186    /// Returns the current set of healthy targets
187    pub fn get_healthy(&self) -> Arc<Vec<T>> {
188        self.healthy.load_full()
189    }
190
191    /// Stops the router
192    pub async fn stop(&self) {
193        self.token.cancel();
194        self.tracker.close();
195        self.tracker.wait().await;
196    }
197}
198
#[cfg(test)]
mod test {
    use std::{collections::HashMap, sync::Mutex};

    use async_trait::async_trait;
    use prometheus::Registry;

    use crate::utils::distributor::test::TestExecutor;

    use super::*;

    /// Health-check stub: only "foo" and "bar" are reported as healthy,
    /// every other target is reported as degraded.
    struct TestChecker;

    #[async_trait]
    impl ChecksTarget<String> for TestChecker {
        async fn check(&self, target: &String) -> TargetState {
            match target.as_str() {
                "foo" | "bar" => TargetState::Healthy,
                _ => TargetState::Degraded,
            }
        }
    }

    /// With two healthy backends and one degraded backend, requests must be
    /// split between the healthy ones in proportion to their weights.
    #[tokio::test]
    async fn test_request_router_somewhat_healthy() {
        let executor = Arc::new(TestExecutor(Duration::ZERO, Mutex::new(HashMap::new())));

        let backends = [
            ("foo".to_string(), 1),
            ("bar".to_string(), 2),
            ("baz".to_string(), 3),
        ];
        let router = BackendRouter::new(
            &backends,
            executor.clone(),
            Arc::new(TestChecker),
            Strategy::WeightedRoundRobin,
            Duration::from_millis(1),
            health_check::Metrics::new(&Registry::new()),
            distributor::Metrics::new(&Registry::new()),
        );

        // Give the health checks some time to run
        tokio::time::sleep(Duration::from_millis(100)).await;

        // Fire 900 requests, all of which must succeed
        for _ in 0..900 {
            let outcome = router.execute(()).await;
            assert!(outcome.is_ok());
        }

        // The per-backend counters must follow the 1:2 weights
        let counts = executor.1.lock().unwrap();
        assert_eq!(counts.get("foo"), Some(&300));
        assert_eq!(counts.get("bar"), Some(&600));
        // The degraded backend must not have received anything
        assert!(counts.get("baz").is_none());
    }

    /// With only a degraded backend, every request must be rejected with
    /// `Error::NoHealthyNodes`.
    #[tokio::test]
    async fn test_request_router_unhealthy() {
        let executor = Arc::new(TestExecutor(Duration::ZERO, Mutex::new(HashMap::new())));

        let backends = [("baz".to_string(), 3)];
        let router = BackendRouter::new(
            &backends,
            executor.clone(),
            Arc::new(TestChecker),
            Strategy::WeightedRoundRobin,
            Duration::from_millis(1),
            health_check::Metrics::new(&Registry::new()),
            distributor::Metrics::new(&Registry::new()),
        );

        // Give the health checks some time to run
        tokio::time::sleep(Duration::from_millis(100)).await;

        let err = router.execute(()).await.unwrap_err();
        assert!(matches!(err, Error::NoHealthyNodes));
    }
}