1mod hashring;
6
7pub use hashring::HashRing;
8
9use pollen_membership::Membership;
10use pollen_types::{MembershipEvent, NodeId, TaskId};
11use std::collections::HashSet;
12use std::sync::Arc;
13use tokio::sync::broadcast;
14use parking_lot::RwLock;
15use tracing::{debug, info};
16
17pub trait TaskRouter: Send + Sync + 'static {
19 fn owner(&self, task_id: &TaskId) -> Option<NodeId>;
21
22 fn replicas(&self, task_id: &TaskId, count: usize) -> Vec<NodeId>;
24
25 fn is_local(&self, task_id: &TaskId) -> bool;
27
28 fn local_tasks(&self) -> Vec<TaskId>;
30
31 fn subscribe(&self) -> broadcast::Receiver<OwnershipEvent>;
33
34 fn register_task(&self, task_id: TaskId);
36
37 fn unregister_task(&self, task_id: &TaskId);
39}
40
41#[derive(Clone, Debug)]
43pub enum OwnershipEvent {
44 Acquired(Vec<TaskId>),
46 Released(Vec<TaskId>),
48}
49
50pub struct ConsistentHashRouter {
52 node_id: NodeId,
54 ring: RwLock<HashRing>,
56 tasks: RwLock<HashSet<TaskId>>,
58 event_tx: broadcast::Sender<OwnershipEvent>,
60 membership: Arc<dyn Membership>,
62}
63
64impl ConsistentHashRouter {
65 pub fn new(node_id: NodeId, membership: Arc<dyn Membership>) -> Self {
67 let (event_tx, _) = broadcast::channel(100);
68
69 let router = Self {
70 node_id,
71 ring: RwLock::new(HashRing::new(150)), tasks: RwLock::new(HashSet::new()),
73 event_tx,
74 membership,
75 };
76
77 router.update_ring();
79
80 router
81 }
82
83 pub fn start(self: Arc<Self>) {
85 let router = Arc::clone(&self);
86 let mut rx = self.membership.subscribe();
87
88 tokio::spawn(async move {
89 loop {
90 match rx.recv().await {
91 Ok(event) => {
92 router.handle_membership_event(event);
93 }
94 Err(broadcast::error::RecvError::Lagged(_)) => {
95 router.update_ring();
97 }
98 Err(broadcast::error::RecvError::Closed) => {
99 break;
100 }
101 }
102 }
103 });
104 }
105
106 fn update_ring(&self) {
108 let members = self.membership.alive_members();
109 let mut ring = self.ring.write();
110
111 let old_local: HashSet<_> = self.tasks.read()
113 .iter()
114 .filter(|t| ring.get(t.to_string().as_bytes()).map(|n| *n == self.node_id).unwrap_or(false))
115 .cloned()
116 .collect();
117
118 ring.clear();
120 for member in members {
121 ring.add(member.id);
122 }
123
124 let new_local: HashSet<_> = self.tasks.read()
126 .iter()
127 .filter(|t| ring.get(t.to_string().as_bytes()).map(|n| *n == self.node_id).unwrap_or(false))
128 .cloned()
129 .collect();
130
131 let acquired: Vec<_> = new_local.difference(&old_local).cloned().collect();
133 let released: Vec<_> = old_local.difference(&new_local).cloned().collect();
134
135 if !acquired.is_empty() {
136 debug!("Acquired ownership of {} tasks", acquired.len());
137 let _ = self.event_tx.send(OwnershipEvent::Acquired(acquired));
138 }
139
140 if !released.is_empty() {
141 debug!("Released ownership of {} tasks", released.len());
142 let _ = self.event_tx.send(OwnershipEvent::Released(released));
143 }
144 }
145
146 fn handle_membership_event(&self, event: MembershipEvent) {
148 match event {
149 MembershipEvent::Joined(member) => {
150 info!("Node {} joined, updating ring", member.id);
151 self.ring.write().add(member.id);
152 self.recalculate_ownership();
153 }
154 MembershipEvent::Left(node_id) => {
155 info!("Node {} left, updating ring", node_id);
156 self.ring.write().remove(node_id);
157 self.recalculate_ownership();
158 }
159 MembershipEvent::StateChanged { id, old, new } => {
160 if new == pollen_types::MemberState::Dead {
161 info!("Node {} died, updating ring", id);
162 self.ring.write().remove(id);
163 self.recalculate_ownership();
164 } else if old == pollen_types::MemberState::Dead && new == pollen_types::MemberState::Alive {
165 info!("Node {} revived, updating ring", id);
166 self.ring.write().add(id);
167 self.recalculate_ownership();
168 }
169 }
170 _ => {}
171 }
172 }
173
174 fn recalculate_ownership(&self) {
176 let ring = self.ring.read();
177 let tasks = self.tasks.read();
178
179 let mut acquired = vec![];
180 let _released: Vec<pollen_types::TaskId> = vec![];
181
182 for task_id in tasks.iter() {
183 let key = task_id.to_string();
184 if let Some(owner) = ring.get(key.as_bytes()) {
185 if *owner == self.node_id {
186 acquired.push(task_id.clone());
188 }
189 }
190 }
191
192 if !acquired.is_empty() {
196 let _ = self.event_tx.send(OwnershipEvent::Acquired(acquired));
197 }
198 }
199}
200
201impl TaskRouter for ConsistentHashRouter {
202 fn owner(&self, task_id: &TaskId) -> Option<NodeId> {
203 let key = task_id.to_string();
204 self.ring.read().get(key.as_bytes()).copied()
205 }
206
207 fn replicas(&self, task_id: &TaskId, count: usize) -> Vec<NodeId> {
208 let key = task_id.to_string();
209 self.ring.read().get_n(key.as_bytes(), count)
210 }
211
212 fn is_local(&self, task_id: &TaskId) -> bool {
213 self.owner(task_id).map(|n| n == self.node_id).unwrap_or(false)
214 }
215
216 fn local_tasks(&self) -> Vec<TaskId> {
217 let ring = self.ring.read();
218 self.tasks
219 .read()
220 .iter()
221 .filter(|t| {
222 ring.get(t.to_string().as_bytes())
223 .map(|n| *n == self.node_id)
224 .unwrap_or(false)
225 })
226 .cloned()
227 .collect()
228 }
229
230 fn subscribe(&self) -> broadcast::Receiver<OwnershipEvent> {
231 self.event_tx.subscribe()
232 }
233
234 fn register_task(&self, task_id: TaskId) {
235 self.tasks.write().insert(task_id);
236 }
237
238 fn unregister_task(&self, task_id: &TaskId) {
239 self.tasks.write().remove(task_id);
240 }
241}
242
243pub type SharedRouter = Arc<dyn TaskRouter>;
245
#[cfg(test)]
mod tests {
    use super::*;
    use pollen_types::Result;

    /// Address reported for the single mock member.
    const MOCK_ADDR: &str = "127.0.0.1:7000";

    /// Membership view containing only the local node, which is always alive.
    struct MockMembership {
        node_id: NodeId,
        /// Held so `subscribe` can hand out receivers; tests never send on it.
        event_tx: broadcast::Sender<MembershipEvent>,
    }

    impl MockMembership {
        fn new(node_id: NodeId) -> Self {
            let (event_tx, _) = broadcast::channel(100);
            Self { node_id, event_tx }
        }
    }

    #[async_trait::async_trait]
    impl Membership for MockMembership {
        fn members(&self) -> Vec<pollen_types::Member> {
            vec![self.local()]
        }

        fn alive_members(&self) -> Vec<pollen_types::Member> {
            self.members()
        }

        fn is_alive(&self, node_id: NodeId) -> bool {
            node_id == self.node_id
        }

        fn local(&self) -> pollen_types::Member {
            pollen_types::Member::new(self.node_id, MOCK_ADDR.parse().unwrap())
        }

        fn subscribe(&self) -> broadcast::Receiver<MembershipEvent> {
            self.event_tx.subscribe()
        }

        async fn set_metadata(&self, _key: String, _value: String) -> Result<()> {
            Ok(())
        }

        fn get_metadata(&self, _node_id: NodeId, _key: &str) -> Option<String> {
            None
        }

        async fn leave(&self) -> Result<()> {
            Ok(())
        }

        async fn shutdown(&self) {}
    }

    /// Builds a router whose cluster consists solely of `node_id`.
    fn single_node_router(node_id: NodeId) -> ConsistentHashRouter {
        ConsistentHashRouter::new(node_id, Arc::new(MockMembership::new(node_id)))
    }

    #[test]
    fn test_single_node_routing() {
        let node_id = NodeId::new();
        let router = single_node_router(node_id);

        let task_id = TaskId::new();
        router.register_task(task_id.clone());

        assert!(router.is_local(&task_id));
        assert_eq!(router.owner(&task_id), Some(node_id));
    }

    #[test]
    fn test_local_tasks_follow_registration() {
        let node_id = NodeId::new();
        let router = single_node_router(node_id);
        assert!(router.local_tasks().is_empty());

        let task_id = TaskId::new();
        router.register_task(task_id.clone());
        let local = router.local_tasks();
        assert_eq!(local.len(), 1);
        assert!(local.contains(&task_id));

        router.unregister_task(&task_id);
        assert!(router.local_tasks().is_empty());
    }
}