1use std::sync::{
2 atomic::{AtomicBool, AtomicU8},
3 Arc,
4};
5
6use thiserror::Error;
7use tokio::sync::RwLock;
8
9use crate::driver::Driver;
10
/// Tracks the cluster membership reported by a [`Driver`] and maps job names
/// onto member nodes through a hash ring.
pub struct NodePool<D>
where
    D: Driver + Send + Sync,
{
    /// Identifier of this node; filled from `driver.node_id()` in `new`.
    node_id: String,

    /// Node list the ring was last built from. Fresh driver output is compared
    /// against it to decide whether the ring must be rebuilt.
    pre_nodes: RwLock<Vec<String>>,
    /// Hash ring used to resolve which node owns a given job name.
    hash: RwLock<hashring::HashRing<String>>,
    /// Membership backend, shared behind an `Arc`.
    driver: Arc<D>,

    /// Flag held (set to `true`) while the ring is being rebuilt.
    state_lock: AtomicBool,
    /// Set by `stop` (and on drop); the `start` loop returns once it observes it.
    stop: AtomicBool,
}
25
26impl<D> std::fmt::Debug for NodePool<D>
27where
28 D: Driver + Send + Sync + std::fmt::Debug,
29{
30 fn fmt(
31 &self,
32 f: &mut std::fmt::Formatter<'_>,
33 ) -> std::fmt::Result {
34 f.debug_struct("NodePool")
35 .field("node_id", &self.node_id)
36 .field("pre_nodes", &self.pre_nodes)
37 .field("hash", &self.hash)
38 .field("driver", &self.driver)
39 .field("state_lock", &self.state_lock)
40 .field("stop", &self.stop)
41 .finish()
42 }
43}
44
45#[derive(Debug, Error)]
46pub enum Error<D>
47where
48 D: Driver + Send + Sync,
49{
50 #[error("No node available")]
51 NoNodeAvailable,
52 #[error("Driver error: {0}")]
53 DriverError(D::Error),
54}
55
56impl<D> NodePool<D>
57where
58 D: Driver + Send + Sync,
59{
60 pub async fn new(mut driver: D) -> Result<Self, Error<D>> {
62 driver.start().await.map_err(Error::DriverError)?;
63
64 let mut pre_nodes = Vec::new();
66 let mut hash = hashring::HashRing::new();
67 let state_lock = AtomicBool::new(false);
68
69 update_hash_ring::<D>(
70 &mut pre_nodes,
71 &state_lock,
72 &mut hash,
73 &driver.get_nodes().await.map_err(Error::DriverError)?,
74 )
75 .await?;
76
77 Ok(Self {
78 node_id: driver.node_id(),
79 pre_nodes: RwLock::new(pre_nodes),
80 hash: RwLock::new(hash),
81 driver: Arc::new(driver),
82 state_lock,
83 stop: AtomicBool::new(false),
84 })
85 }
86
87 pub async fn start(&self) -> Result<(), Error<D>> {
89 let mut interval = tokio::time::interval(std::time::Duration::from_secs(1));
90 let error_time = AtomicU8::new(0);
91
92 loop {
93 interval.tick().await;
94 if self.stop.load(std::sync::atomic::Ordering::SeqCst) {
95 return Ok(());
96 }
97
98 {
100 let nodes = match self.driver.get_nodes().await {
101 Ok(nodes) => nodes,
102 Err(_) => continue,
103 };
104
105 let mut pre_nodes = self.pre_nodes.write().await;
106 let mut hash = self.hash.write().await;
107
108 match update_hash_ring::<D>(&mut pre_nodes, &self.state_lock, &mut hash, &nodes).await {
109 Ok(_) => {
110 error_time.store(0, std::sync::atomic::Ordering::SeqCst);
111 }
112 Err(_) => {
113 error_time.fetch_add(1, std::sync::atomic::Ordering::SeqCst);
114 tracing::error!("Failed to update hash ring");
115 }
116 }
117 }
118
119 if error_time.load(std::sync::atomic::Ordering::SeqCst) >= 5 {
120 panic!("Failed to update hash ring 5 times")
121 }
122 }
123 }
124
125 pub(crate) async fn check_job_available(
127 &self,
128 job_name: &str,
129 ) -> Result<bool, Error<D>> {
130 let hash = self.hash.read().await;
131 match hash.get(&job_name) {
132 Some(node) if node == &self.node_id => Ok(true),
133 Some(_) => Ok(false),
134 None => Err(Error::NoNodeAvailable),
135 }
136 }
137
138 pub fn stop(&self) {
139 self.stop.store(true, std::sync::atomic::Ordering::SeqCst);
140 }
141}
142
143impl<D> Drop for NodePool<D>
144where
145 D: Driver + Send + Sync,
146{
147 fn drop(&mut self) {
148 self.stop();
149 }
150}
151
152async fn update_hash_ring<D>(
161 pre_nodes: &mut Vec<String>,
162 state_lock: &AtomicBool,
163 hash: &mut hashring::HashRing<String>,
164 nodes: &Vec<String>,
165) -> Result<(), Error<D>>
166where
167 D: Driver + Send + Sync,
168{
169 if equal_ring(nodes, pre_nodes) {
170 tracing::trace!("Nodes are equal, skipping update, nodes: {:?}", nodes);
171 return Ok(());
172 }
173
174 tracing::info!(
175 "Nodes detected, updating hash ring, pre_nodes: {:?}, now_nodes: {:?}",
176 pre_nodes,
177 nodes
178 );
179
180 if state_lock
182 .compare_exchange(
183 false,
184 true,
185 std::sync::atomic::Ordering::SeqCst,
186 std::sync::atomic::Ordering::SeqCst,
187 )
188 .is_err()
189 {
190 return Ok(());
191 }
192
193 pre_nodes.clone_from(nodes);
195
196 *hash = hashring::HashRing::new();
197 for node in nodes {
198 hash.add(node.clone());
199 }
200
201 state_lock.store(false, std::sync::atomic::Ordering::SeqCst);
203
204 Ok(())
205}
206
/// Returns `true` when `a` and `b` contain the same node names with the same
/// multiplicities, regardless of order.
fn equal_ring(a: &[String], b: &[String]) -> bool {
    if a.len() != b.len() {
        return false;
    }

    // Sort borrowed `&str`s instead of cloning every `String`; this comparison
    // runs on every membership poll.
    let mut a_sorted: Vec<&str> = a.iter().map(String::as_str).collect();
    let mut b_sorted: Vec<&str> = b.iter().map(String::as_str).collect();

    // Stability is irrelevant for equality of sorted sequences.
    a_sorted.sort_unstable();
    b_sorted.sort_unstable();

    a_sorted == b_sorted
}
224
#[cfg(test)]
mod tests {
    use super::*;
    use crate::driver::local::LocalDriver;

    /// Builds an owned node list from string literals.
    fn node_list(names: &[&str]) -> Vec<String> {
        names.iter().map(|name| name.to_string()).collect()
    }

    #[tokio::test]
    async fn test_update_hash_ring() {
        let mut pre_nodes = Vec::new();
        let state_lock = AtomicBool::new(false);
        let mut hash = hashring::HashRing::new();

        // Each step: (membership fed to the ring, node expected to own "test").
        let steps: [(&[&str], &str); 4] = [
            (&["node1", "node2"], "node2"),
            // Adding node3 leaves "test" on node2.
            (&["node1", "node2", "node3"], "node2"),
            // Dropping node2 hands "test" to node3.
            (&["node1", "node3"], "node3"),
            // Re-submitting unchanged membership keeps the ring as it was.
            (&["node1", "node3"], "node3"),
        ];

        for (members, owner) in steps {
            let nodes = node_list(members);
            update_hash_ring::<LocalDriver>(&mut pre_nodes, &state_lock, &mut hash, &nodes)
                .await
                .unwrap();
            assert_eq!(pre_nodes, nodes);
            assert_eq!(hash.get(&"test"), Some(&owner.to_string()));
        }
    }

    #[test]
    fn test_equal_ring() {
        assert!(equal_ring(
            &node_list(&["node1", "node2"]),
            &node_list(&["node1", "node2"])
        ));
        assert!(equal_ring(
            &node_list(&["node1", "node2"]),
            &node_list(&["node2", "node1"])
        ));
        assert!(!equal_ring(
            &node_list(&["node1", "node2"]),
            &node_list(&["node1", "node3"])
        ));
    }
}