mod hashring;
pub use hashring::HashRing;
use pollen_membership::Membership;
use pollen_types::{MembershipEvent, NodeId, TaskId};
use std::collections::HashSet;
use std::sync::Arc;
use tokio::sync::broadcast;
use parking_lot::RwLock;
use tracing::{debug, info};
/// Maps tasks onto cluster nodes and reports changes in local ownership.
///
/// Implementations must be shareable across threads (`Send + Sync + 'static`).
pub trait TaskRouter: Send + Sync + 'static {
    /// Returns the node responsible for `task_id`, or `None` if the router
    /// cannot resolve an owner.
    fn owner(&self, task_id: &TaskId) -> Option<NodeId>;

    /// Returns the nodes selected for `task_id`; `count` is the number of
    /// nodes requested.
    fn replicas(&self, task_id: &TaskId, count: usize) -> Vec<NodeId>;

    /// Returns `true` when the local node is the owner of `task_id`.
    fn is_local(&self, task_id: &TaskId) -> bool;

    /// Returns the registered tasks that the local node currently owns.
    fn local_tasks(&self) -> Vec<TaskId>;

    /// Returns a receiver on which [`OwnershipEvent`]s are delivered.
    fn subscribe(&self) -> broadcast::Receiver<OwnershipEvent>;

    /// Starts tracking `task_id` for ownership purposes.
    fn register_task(&self, task_id: TaskId);

    /// Stops tracking `task_id`.
    fn unregister_task(&self, task_id: &TaskId);
}
/// Notification about registered tasks whose ownership concerns the local
/// node, published after the hash ring changes.
#[derive(Clone, Debug)]
pub enum OwnershipEvent {
    /// Registered tasks reported as owned by the local node after a ring change.
    Acquired(Vec<TaskId>),
    /// Registered tasks the local node owned before a ring change and no
    /// longer owns afterwards.
    Released(Vec<TaskId>),
}
/// [`TaskRouter`] that assigns tasks to nodes through a [`HashRing`] built
/// from the cluster membership.
pub struct ConsistentHashRouter {
    /// Identity of the local node; ring lookups are compared against it.
    node_id: NodeId,
    /// Ring holding the nodes that tasks can be routed to.
    ring: RwLock<HashRing>,
    /// Tasks registered with this router.
    tasks: RwLock<HashSet<TaskId>>,
    /// Sending half of the channel that carries [`OwnershipEvent`]s.
    event_tx: broadcast::Sender<OwnershipEvent>,
    /// Source of the member list and of membership change events.
    membership: Arc<dyn Membership>,
}
impl ConsistentHashRouter {
pub fn new(node_id: NodeId, membership: Arc<dyn Membership>) -> Self {
let (event_tx, _) = broadcast::channel(100);
let router = Self {
node_id,
ring: RwLock::new(HashRing::new(150)), tasks: RwLock::new(HashSet::new()),
event_tx,
membership,
};
router.update_ring();
router
}
pub fn start(self: Arc<Self>) {
let router = Arc::clone(&self);
let mut rx = self.membership.subscribe();
tokio::spawn(async move {
loop {
match rx.recv().await {
Ok(event) => {
router.handle_membership_event(event);
}
Err(broadcast::error::RecvError::Lagged(_)) => {
router.update_ring();
}
Err(broadcast::error::RecvError::Closed) => {
break;
}
}
}
});
}
fn update_ring(&self) {
let members = self.membership.alive_members();
let mut ring = self.ring.write();
let old_local: HashSet<_> = self.tasks.read()
.iter()
.filter(|t| ring.get(t.to_string().as_bytes()).map(|n| *n == self.node_id).unwrap_or(false))
.cloned()
.collect();
ring.clear();
for member in members {
ring.add(member.id);
}
let new_local: HashSet<_> = self.tasks.read()
.iter()
.filter(|t| ring.get(t.to_string().as_bytes()).map(|n| *n == self.node_id).unwrap_or(false))
.cloned()
.collect();
let acquired: Vec<_> = new_local.difference(&old_local).cloned().collect();
let released: Vec<_> = old_local.difference(&new_local).cloned().collect();
if !acquired.is_empty() {
debug!("Acquired ownership of {} tasks", acquired.len());
let _ = self.event_tx.send(OwnershipEvent::Acquired(acquired));
}
if !released.is_empty() {
debug!("Released ownership of {} tasks", released.len());
let _ = self.event_tx.send(OwnershipEvent::Released(released));
}
}
fn handle_membership_event(&self, event: MembershipEvent) {
match event {
MembershipEvent::Joined(member) => {
info!("Node {} joined, updating ring", member.id);
self.ring.write().add(member.id);
self.recalculate_ownership();
}
MembershipEvent::Left(node_id) => {
info!("Node {} left, updating ring", node_id);
self.ring.write().remove(node_id);
self.recalculate_ownership();
}
MembershipEvent::StateChanged { id, old, new } => {
if new == pollen_types::MemberState::Dead {
info!("Node {} died, updating ring", id);
self.ring.write().remove(id);
self.recalculate_ownership();
} else if old == pollen_types::MemberState::Dead && new == pollen_types::MemberState::Alive {
info!("Node {} revived, updating ring", id);
self.ring.write().add(id);
self.recalculate_ownership();
}
}
_ => {}
}
}
fn recalculate_ownership(&self) {
let ring = self.ring.read();
let tasks = self.tasks.read();
let mut acquired = vec![];
let _released: Vec<pollen_types::TaskId> = vec![];
for task_id in tasks.iter() {
let key = task_id.to_string();
if let Some(owner) = ring.get(key.as_bytes()) {
if *owner == self.node_id {
acquired.push(task_id.clone());
}
}
}
if !acquired.is_empty() {
let _ = self.event_tx.send(OwnershipEvent::Acquired(acquired));
}
}
}
impl TaskRouter for ConsistentHashRouter {
    /// Looks up the ring entry for the string form of `task_id`.
    fn owner(&self, task_id: &TaskId) -> Option<NodeId> {
        let key = task_id.to_string();
        let ring = self.ring.read();
        ring.get(key.as_bytes()).copied()
    }

    /// Asks the ring for `count` nodes for the string form of `task_id`.
    fn replicas(&self, task_id: &TaskId, count: usize) -> Vec<NodeId> {
        let key = task_id.to_string();
        let ring = self.ring.read();
        ring.get_n(key.as_bytes(), count)
    }

    /// Returns `true` when the ring resolves `task_id` to the local node.
    fn is_local(&self, task_id: &TaskId) -> bool {
        matches!(self.owner(task_id), Some(owner) if owner == self.node_id)
    }

    /// Collects every registered task that the ring resolves to the local node.
    fn local_tasks(&self) -> Vec<TaskId> {
        // Lock order (ring, then tasks) matches the rest of the router.
        let ring = self.ring.read();
        let registered = self.tasks.read();

        let mut owned = Vec::new();
        for task in registered.iter() {
            let key = task.to_string();
            let is_mine = match ring.get(key.as_bytes()) {
                Some(node) => *node == self.node_id,
                None => false,
            };
            if is_mine {
                owned.push(task.clone());
            }
        }
        owned
    }

    /// Creates a new receiver for ownership events.
    fn subscribe(&self) -> broadcast::Receiver<OwnershipEvent> {
        self.event_tx.subscribe()
    }

    /// Adds `task_id` to the set of registered tasks.
    fn register_task(&self, task_id: TaskId) {
        let mut registered = self.tasks.write();
        registered.insert(task_id);
    }

    /// Removes `task_id` from the set of registered tasks.
    fn unregister_task(&self, task_id: &TaskId) {
        let mut registered = self.tasks.write();
        registered.remove(task_id);
    }
}
/// Reference-counted, dynamically dispatched handle to a [`TaskRouter`].
pub type SharedRouter = Arc<dyn TaskRouter>;
#[cfg(test)]
mod tests {
    use super::*;
    use pollen_types::{Member, Result};

    /// Membership stub that reports exactly one member: the local node.
    struct MockMembership {
        node_id: NodeId,
        event_tx: broadcast::Sender<MembershipEvent>,
    }

    impl MockMembership {
        /// Builds a stub whose only member is `node_id`.
        fn new(node_id: NodeId) -> Self {
            let (event_tx, _) = broadcast::channel(100);
            Self { node_id, event_tx }
        }
    }

    #[async_trait::async_trait]
    impl Membership for MockMembership {
        fn members(&self) -> Vec<Member> {
            // The only member is the local node.
            vec![self.local()]
        }

        fn alive_members(&self) -> Vec<Member> {
            self.members()
        }

        fn is_alive(&self, node_id: NodeId) -> bool {
            self.node_id == node_id
        }

        fn local(&self) -> Member {
            Member::new(self.node_id, "127.0.0.1:7000".parse().unwrap())
        }

        fn subscribe(&self) -> broadcast::Receiver<MembershipEvent> {
            self.event_tx.subscribe()
        }

        async fn set_metadata(&self, _key: String, _value: String) -> Result<()> {
            Ok(())
        }

        fn get_metadata(&self, _node_id: NodeId, _key: &str) -> Option<String> {
            None
        }

        async fn leave(&self) -> Result<()> {
            Ok(())
        }

        async fn shutdown(&self) {}
    }

    /// With a single member, every task must route to that member.
    #[test]
    fn test_single_node_routing() {
        let local = NodeId::new();
        let router = ConsistentHashRouter::new(local, Arc::new(MockMembership::new(local)));

        let task = TaskId::new();
        router.register_task(task.clone());

        assert!(router.is_local(&task));
        assert_eq!(router.owner(&task), Some(local));
    }
}