1use std::{
2 fmt::{Debug, Display},
3 sync::Arc,
4 time::Duration,
5};
6
7use arc_swap::{ArcSwap, ArcSwapOption};
8use ic_bn_lib_common::{
9 traits::utils::{ChecksTarget, ExecutesRequest},
10 types::utils::TargetState,
11};
12use tokio::{select, sync::watch::Receiver};
13use tokio_util::{sync::CancellationToken, task::TaskTracker};
14
15use crate::utils::{
16 distributor::{self, Distributor, Strategy},
17 health_check::{self, HealthChecker},
18};
19
20#[derive(thiserror::Error)]
21pub enum Error<E> {
22 #[error("No healthy nodes")]
23 NoHealthyNodes,
24 #[error("{0}")]
25 Inner(E),
26}
27
28struct Actor<T, RQ = (), RS = (), E = ()> {
29 weights: Vec<usize>,
30 health_checker: Arc<HealthChecker<T>>,
31 strategy: Strategy,
32 executor: Arc<dyn ExecutesRequest<T, Request = RQ, Response = RS, Error = E>>,
33 distributor: Arc<ArcSwapOption<Distributor<T, RQ, RS, E>>>,
34 distributor_metrics: distributor::Metrics,
35 healthy: Arc<ArcSwap<Vec<T>>>,
36}
37
38impl<T, RQ, RS, E> Actor<T, RQ, RS, E>
39where
40 T: Clone + Display + Debug + Send + Sync + 'static,
41 RQ: Send + 'static,
42 RS: Send + 'static,
43 E: Send + 'static,
44{
45 async fn process(&self, backends: Arc<Vec<(T, TargetState)>>) {
47 let healthy = backends
50 .iter()
51 .zip(&self.weights)
52 .filter(|x| x.0.1 == TargetState::Healthy)
53 .map(|x| (x.0.0.clone(), *x.1))
54 .collect::<Vec<_>>();
55
56 if healthy.is_empty() {
58 self.distributor.store(None);
59 return;
60 }
61
62 let distributor = Distributor::new(
63 &healthy,
64 self.strategy,
65 self.executor.clone(),
66 self.distributor_metrics.clone(),
67 );
68 self.distributor.store(Some(Arc::new(distributor)));
69 self.healthy
70 .store(Arc::new(healthy.into_iter().map(|x| x.0).collect()));
71 }
72
73 async fn run(&self, token: CancellationToken) {
74 let mut rx = self.health_checker.subscribe();
76
77 loop {
78 select! {
79 biased;
80
81 _ = token.cancelled() => {
83 self.health_checker.stop().await;
84 return;
85 }
86
87 Ok(()) = rx.changed() => {
89 let backends = rx.borrow_and_update().clone();
90 self.process(backends).await;
91 }
92 }
93 }
94 }
95}
96
97#[derive(Debug)]
100pub struct BackendRouter<T, RQ = (), RS = (), E = ()> {
101 token: CancellationToken,
102 tracker: TaskTracker,
103 distributor: Arc<ArcSwapOption<Distributor<T, RQ, RS, E>>>,
104 notify: Receiver<Arc<Vec<(T, TargetState)>>>,
105 healthy: Arc<ArcSwap<Vec<T>>>,
106}
107
108impl<T, RQ, RS, E> BackendRouter<T, RQ, RS, E>
109where
110 T: Clone + Display + Debug + Send + Sync + 'static,
111 RQ: Send + 'static,
112 RS: Send + 'static,
113 E: Send + 'static,
114{
115 pub fn new(
117 backends: &[(T, usize)],
118 executor: Arc<dyn ExecutesRequest<T, Request = RQ, Response = RS, Error = E>>,
119 checker: Arc<dyn ChecksTarget<T>>,
120 strategy: Strategy,
121 check_interval: Duration,
122 health_check_metrics: health_check::Metrics,
123 distributor_metrics: distributor::Metrics,
124 ) -> Self {
125 let weights = backends.iter().map(|x| x.1).collect();
127 let backends = backends.iter().map(|x| x.0.clone()).collect::<Vec<_>>();
129
130 let health_checker = Arc::new(HealthChecker::new(
131 &backends,
132 checker,
133 check_interval,
134 health_check_metrics,
135 ));
136 let notify = health_checker.subscribe();
137
138 let distributor = Arc::new(ArcSwapOption::empty());
139 let healthy = Arc::new(ArcSwap::new(Arc::new(vec![])));
140
141 let actor = Actor {
142 weights,
143 health_checker,
144 strategy,
145 executor,
146 distributor: distributor.clone(),
147 distributor_metrics,
148 healthy: healthy.clone(),
149 };
150
151 let token = CancellationToken::new();
152 let tracker = TaskTracker::new();
153
154 let child_token = token.child_token();
155 tracker.spawn(async move {
156 actor.run(child_token).await;
157 });
158
159 Self {
160 token,
161 tracker,
162 distributor,
163 notify,
164 healthy,
165 }
166 }
167
168 pub async fn execute(&self, request: RQ) -> Result<RS, Error<E>> {
170 let Some(distributor) = self.distributor.load_full() else {
171 return Err(Error::NoHealthyNodes);
172 };
173
174 distributor
175 .execute(request)
176 .await
177 .map_err(|e| Error::Inner(e))
178 }
179
180 pub fn subscribe(&self) -> Receiver<Arc<Vec<(T, TargetState)>>> {
183 self.notify.clone()
184 }
185
186 pub fn get_healthy(&self) -> Arc<Vec<T>> {
188 self.healthy.load_full()
189 }
190
191 pub async fn stop(&self) {
193 self.token.cancel();
194 self.tracker.close();
195 self.tracker.wait().await;
196 }
197}
198
#[cfg(test)]
mod test {
    use std::{collections::HashMap, sync::Mutex};

    use async_trait::async_trait;
    use prometheus::Registry;

    use crate::utils::distributor::test::TestExecutor;

    use super::*;

    /// Health-check stub: only `foo` and `bar` are reported as healthy.
    struct TestChecker;

    #[async_trait]
    impl ChecksTarget<String> for TestChecker {
        async fn check(&self, target: &String) -> TargetState {
            match target.as_str() {
                "foo" | "bar" => TargetState::Healthy,
                _ => TargetState::Degraded,
            }
        }
    }

    /// With two of three backends healthy, requests must be spread over the
    /// healthy ones in proportion to their weights.
    #[tokio::test]
    async fn test_request_router_somewhat_healthy() {
        let executor = Arc::new(TestExecutor(Duration::ZERO, Mutex::new(HashMap::new())));

        let router = BackendRouter::new(
            &[
                ("foo".to_string(), 1),
                ("bar".to_string(), 2),
                ("baz".to_string(), 3),
            ],
            executor.clone(),
            Arc::new(TestChecker),
            Strategy::WeightedRoundRobin,
            Duration::from_millis(1),
            health_check::Metrics::new(&Registry::new()),
            distributor::Metrics::new(&Registry::new()),
        );

        // Give the health checker time to run and the actor time to react.
        tokio::time::sleep(Duration::from_millis(100)).await;

        for _ in 0..900 {
            assert!(router.execute(()).await.is_ok());
        }

        // Keep the lock guard confined to the assertions.
        {
            let hits = executor.1.lock().unwrap();
            assert_eq!(hits["foo"], 300);
            assert_eq!(hits["bar"], 600);
            assert!(!hits.contains_key("baz"));
        }
    }

    /// With no healthy backend, every request must fail with `NoHealthyNodes`.
    #[tokio::test]
    async fn test_request_router_unhealthy() {
        let executor = Arc::new(TestExecutor(Duration::ZERO, Mutex::new(HashMap::new())));

        let router = BackendRouter::new(
            &[("baz".to_string(), 3)],
            executor.clone(),
            Arc::new(TestChecker),
            Strategy::WeightedRoundRobin,
            Duration::from_millis(1),
            health_check::Metrics::new(&Registry::new()),
            distributor::Metrics::new(&Registry::new()),
        );

        // Give the health checker time to run and the actor time to react.
        tokio::time::sleep(Duration::from_millis(100)).await;

        assert!(matches!(
            router.execute(()).await,
            Err(Error::NoHealthyNodes)
        ));
    }
}