use std::collections::hash_map::DefaultHasher;
use std::hash::{Hash, Hasher};
use std::sync::Arc;
use std::time::{Duration, Instant};
use dashmap::DashMap;
use tokio::sync::RwLock;
use tracing::{debug, warn};
use crate::error::{Error, Result};
use crate::proto::grpc::block::WorkerInfo;
use crate::proto::grpc::WorkerNetAddress;
/// Failure TTL used by `WorkerRouter::new`: how long a worker marked via
/// `mark_failed` is excluded from selection.
const DEFAULT_FAILURE_TTL: Duration = Duration::from_secs(60);
/// Refresh TTL used by `WorkerRouter::new` and `WorkerRouter::with_failure_ttl`:
/// the age of the last refresh at which `needs_refresh` starts returning `true`.
const DEFAULT_WORKER_REFRESH_TTL: Duration = Duration::from_secs(30);
/// Number of hash-ring entries created for a worker that does not report its
/// own `virtual_node_num`.
const VIRTUAL_NODES_PER_WORKER: u32 = 100;
/// Chooses which worker should serve a block.
///
/// Keeps a snapshot of the known workers, a map of recently failed worker
/// addresses, and a cached id of the worker detected on the local host.
pub struct WorkerRouter {
/// Current worker list; swapped wholesale by `update_workers`.
workers: RwLock<Arc<Vec<WorkerInfo>>>,
/// Time at which each worker was marked failed, keyed by the string
/// produced by `worker_addr_key` (`"host:rpc_port"`).
failed_workers: DashMap<String, Instant>,
/// How long an entry in `failed_workers` keeps its worker excluded.
failure_ttl: Duration,
/// Time of the last refresh attempt; set on construction, by
/// `update_workers`, and by a failed `refresh_workers`.
last_refresh: RwLock<Instant>,
/// Age of `last_refresh` at which `needs_refresh` reports `true`.
worker_refresh_ttl: Duration,
/// Cached id of the local worker. `0` means "none detected"; while it is
/// `0`, `select_worker` re-runs detection on every call.
local_worker_id: RwLock<i64>,
}
impl WorkerRouter {
pub fn new() -> Self {
Self {
workers: RwLock::new(Arc::new(Vec::new())),
failed_workers: DashMap::new(),
failure_ttl: DEFAULT_FAILURE_TTL,
last_refresh: RwLock::new(Instant::now()),
worker_refresh_ttl: DEFAULT_WORKER_REFRESH_TTL,
local_worker_id: RwLock::new(0),
}
}
pub fn with_failure_ttl(failure_ttl: Duration) -> Self {
Self {
workers: RwLock::new(Arc::new(Vec::new())),
failed_workers: DashMap::new(),
failure_ttl,
last_refresh: RwLock::new(Instant::now()),
worker_refresh_ttl: DEFAULT_WORKER_REFRESH_TTL,
local_worker_id: RwLock::new(0),
}
}
pub fn with_ttls(failure_ttl: Duration, worker_refresh_ttl: Duration) -> Self {
Self {
workers: RwLock::new(Arc::new(Vec::new())),
failed_workers: DashMap::new(),
failure_ttl,
last_refresh: RwLock::new(Instant::now()),
worker_refresh_ttl,
local_worker_id: RwLock::new(0),
}
}
/// Replaces the worker snapshot with `workers`.
///
/// Also restarts the refresh clock and clears the cached local-worker id so
/// that the next `select_worker` call detects it again against the new list.
pub async fn update_workers(&self, workers: Vec<WorkerInfo>) {
    let snapshot = Arc::new(workers);
    // The write guard on the list stays held until the end of the function,
    // so the clock and the cached id are reset before readers see the list.
    let mut current = self.workers.write().await;
    *current = snapshot;
    *self.last_refresh.write().await = Instant::now();
    *self.local_worker_id.write().await = 0;
}
/// Returns the current worker snapshot (a reference-count bump, not a copy
/// of the list).
pub async fn get_workers(&self) -> Arc<Vec<WorkerInfo>> {
    let guard = self.workers.read().await;
    Arc::clone(&guard)
}
/// Reports whether at least `worker_refresh_ttl` has passed since the last
/// refresh attempt.
pub async fn needs_refresh(&self) -> bool {
    let last_attempt = *self.last_refresh.read().await;
    last_attempt.elapsed() >= self.worker_refresh_ttl
}
/// Fetches the worker list from `wm` and installs it.
///
/// A fetch failure is logged and the stale list is kept; the refresh clock
/// is still restarted so the next attempt waits for another refresh TTL.
/// As written, this function always returns `Ok(())`.
pub async fn refresh_workers(&self, wm: &crate::client::WorkerManagerClient) -> Result<()> {
    match wm.get_worker_info_list().await {
        Err(e) => {
            warn!("worker list refresh failed, keeping stale list: {}", e);
            *self.last_refresh.write().await = Instant::now();
        }
        Ok(workers) => {
            debug!(count = workers.len(), "worker list refreshed");
            self.update_workers(workers).await;
        }
    }
    Ok(())
}
/// Returns the id of the first worker whose host matches one of the names
/// from `local_hostnames`.
///
/// Workers without an address are skipped; a missing host is compared as
/// the empty string. Returns `0` when nothing matches, and also when the
/// matching worker carries no id.
async fn detect_local_worker(workers: &[WorkerInfo]) -> i64 {
    let local_names = Self::local_hostnames();
    let hit = workers.iter().find_map(|w| {
        let addr = w.address.as_ref()?;
        let host = addr.host.as_deref().unwrap_or("");
        local_names
            .iter()
            .any(|n| n == host)
            .then(|| (host, w.id.unwrap_or(0)))
    });
    match hit {
        Some((host, id)) => {
            debug!(host = %host, worker_id = id, "detected local worker");
            id
        }
        None => 0,
    }
}
/// Lists the names under which this machine may appear in a worker address:
/// the loopback names, then the system hostname and its first dot-separated
/// label (when the hostname can be read and is valid UTF-8).
fn local_hostnames() -> Vec<String> {
    let mut names: Vec<String> = ["localhost", "127.0.0.1", "::1"]
        .iter()
        .map(|name| String::from(*name))
        .collect();
    let system_name = hostname::get().ok().and_then(|raw| raw.into_string().ok());
    if let Some(full) = system_name {
        // `split` always yields at least one piece, so the short form is
        // pushed even when the hostname contains no dot.
        let short = full.split('.').next().map(str::to_string);
        names.push(full);
        names.extend(short);
    }
    names
}
pub async fn select_worker(&self, block_id: i64) -> Result<WorkerInfo> {
let workers = self.workers.read().await.clone();
if workers.is_empty() {
return Err(Error::NoWorkerAvailable {
message: "no workers registered".to_string(),
});
}
self.cleanup_expired_failures();
{
let cached_id = *self.local_worker_id.read().await;
if cached_id == 0 {
let id = Self::detect_local_worker(&workers).await;
*self.local_worker_id.write().await = id;
}
}
{
let local_id = *self.local_worker_id.read().await;
if local_id > 0 {
if let Some(local_w) = workers.iter().find(|w| w.id == Some(local_id)) {
if let Some(addr) = &local_w.address {
if !self.is_failed(&worker_addr_key(addr)) {
return Ok(local_w.clone());
}
}
}
}
}
let eligible: Vec<&WorkerInfo> = workers
.iter()
.filter(|w| {
if let Some(addr) = w.address.as_ref() {
let key = worker_addr_key(addr);
!self.is_failed(&key)
} else {
false
}
})
.collect();
if eligible.is_empty() {
return self
.consistent_hash_select(block_id, &workers)
.ok_or_else(|| Error::NoWorkerAvailable {
message: "all workers are marked as failed".to_string(),
});
}
let worker_infos: Vec<WorkerInfo> = eligible.into_iter().cloned().collect();
self.consistent_hash_select(block_id, &worker_infos)
.ok_or_else(|| Error::NoWorkerAvailable {
message: format!("no suitable worker for block_id={}", block_id),
})
}
/// Records `addr` as failed as of now; it stays excluded from selection
/// until the failure TTL has elapsed. Marking again restarts the timer.
pub fn mark_failed(&self, addr: &WorkerNetAddress) {
    self.failed_workers
        .insert(worker_addr_key(addr), Instant::now());
}
pub async fn pick_any_worker(&self) -> Result<WorkerInfo> {
let workers = self.workers.read().await.clone();
if workers.is_empty() {
return Err(Error::NoWorkerAvailable {
message: "no workers registered".to_string(),
});
}
self.cleanup_expired_failures();
let eligible: Vec<WorkerInfo> = workers
.iter()
.filter(|w| {
if let Some(addr) = w.address.as_ref() {
let key = worker_addr_key(addr);
!self.is_failed(&key)
} else {
false
}
})
.cloned()
.collect();
let pool = if eligible.is_empty() {
(*workers).clone()
} else {
eligible
};
if pool.is_empty() {
return Err(Error::NoWorkerAvailable {
message: "no eligible workers".to_string(),
});
}
let nanos = std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.map(|d| d.subsec_nanos() as usize)
.unwrap_or(0);
let idx = nanos % pool.len();
Ok(pool[idx].clone())
}
/// Returns `true` when `key` was marked failed less than `failure_ttl` ago.
/// Unknown keys and expired entries count as not failed.
fn is_failed(&self, key: &str) -> bool {
    self.failed_workers
        .get(key)
        .map_or(false, |marked| marked.value().elapsed() < self.failure_ttl)
}
/// Drops every failure record that is at least `failure_ttl` old.
fn cleanup_expired_failures(&self) {
    let ttl = self.failure_ttl;
    self.failed_workers
        .retain(|_, marked_at| marked_at.elapsed() < ttl);
}
/// Maps `block_id` onto `workers` with a consistent-hash ring.
///
/// Each worker contributes `virtual_node_num` ring entries (default
/// `VIRTUAL_NODES_PER_WORKER`), hashed from `"<worker id>:<n>"`; a worker
/// without an id uses its index in `workers`. The block is served by the
/// first entry whose hash is not below the block's hash, wrapping around.
///
/// Returns `None` when `workers` is empty or when no worker contributes any
/// ring entry (every worker reports zero virtual nodes).
///
/// NOTE(review): the ring is rebuilt and sorted on every call; if this shows
/// up in profiles, cache it per worker snapshot.
fn consistent_hash_select(&self, block_id: i64, workers: &[WorkerInfo]) -> Option<WorkerInfo> {
    if workers.is_empty() {
        return None;
    }
    let mut ring: Vec<(u64, usize)> = Vec::new();
    for (idx, worker) in workers.iter().enumerate() {
        let worker_id = worker.id.unwrap_or(idx as i64);
        // A negative count is invalid; casting it with `as u32` would wrap
        // to billions of entries, so fall back to the default instead.
        let virtual_nodes = match worker.virtual_node_num {
            Some(n) if n >= 0 => n as u32,
            _ => VIRTUAL_NODES_PER_WORKER,
        };
        ring.reserve(virtual_nodes as usize);
        for vn in 0..virtual_nodes {
            let hash = hash_key(&format!("{}:{}", worker_id, vn));
            ring.push((hash, idx));
        }
    }
    // With zero virtual nodes everywhere the ring is empty and the modulo
    // below would divide by zero.
    if ring.is_empty() {
        return None;
    }
    // Entries were pushed in ascending `idx`, so ordering by (hash, idx)
    // gives the same sequence as a stable sort by hash.
    ring.sort_unstable();
    let target = hash_key(&block_id.to_string());
    let pos = ring
        .binary_search_by_key(&target, |(h, _)| *h)
        .unwrap_or_else(|insert_at| insert_at);
    // An insertion point past the last entry wraps to the start of the ring.
    let pos = pos % ring.len();
    Some(workers[ring[pos].1].clone())
}
}
impl Default for WorkerRouter {
fn default() -> Self {
Self::new()
}
}
/// Builds the `"host:rpc_port"` key used to track failed workers.
///
/// A missing host becomes `"unknown"` and a missing port becomes `0`, so
/// addresses lacking both share the key `"unknown:0"`.
fn worker_addr_key(addr: &WorkerNetAddress) -> String {
    let host = addr.host.as_deref().unwrap_or("unknown");
    let port = addr.rpc_port.unwrap_or(0);
    format!("{}:{}", host, port)
}
/// Hashes `key` to a 64-bit ring position using a freshly created
/// `DefaultHasher`.
///
/// NOTE(review): std documents `DefaultHasher`'s algorithm as unspecified
/// across Rust releases — confirm that ring positions never need to agree
/// between binaries built with different toolchains.
fn hash_key(key: &str) -> u64 {
    let mut state = DefaultHasher::new();
    Hash::hash(key, &mut state);
    state.finish()
}
#[cfg(test)]
mod tests {
use super::*;
/// Builds a `WorkerInfo` with the given id, host and RPC port; every other
/// field keeps its default value.
fn make_worker(id: i64, host: &str, port: i32) -> WorkerInfo {
    let address = WorkerNetAddress {
        host: Some(host.to_string()),
        rpc_port: Some(port),
        ..Default::default()
    };
    WorkerInfo {
        id: Some(id),
        address: Some(address),
        ..Default::default()
    }
}
/// Selecting from a router with no workers is an error.
#[tokio::test]
async fn test_select_worker_empty() {
    let router = WorkerRouter::new();
    let outcome = router.select_worker(123).await;
    assert!(outcome.is_err());
}
/// The same block id maps to the same worker on repeated calls.
#[tokio::test]
async fn test_select_worker_deterministic() {
    let router = WorkerRouter::new();
    router
        .update_workers(vec![
            make_worker(1, "w1", 9203),
            make_worker(2, "w2", 9203),
            make_worker(3, "w3", 9203),
        ])
        .await;
    let first = router.select_worker(42).await.unwrap();
    let second = router.select_worker(42).await.unwrap();
    assert_eq!(first.id, second.id);
}
/// A worker marked failed is not selected while another one is available.
#[tokio::test]
async fn test_failed_worker_filtered() {
    let router = WorkerRouter::with_failure_ttl(Duration::from_secs(3600));
    let broken = make_worker(1, "w1", 9203);
    let healthy = make_worker(2, "w2", 9203);
    router.update_workers(vec![broken.clone(), healthy]).await;
    router.mark_failed(broken.address.as_ref().unwrap());
    let selected = router.select_worker(42).await.unwrap();
    assert_eq!(selected.id, Some(2));
}
/// Picking from a router with no workers is an error.
#[tokio::test]
async fn test_pick_any_worker_empty() {
    let router = WorkerRouter::new();
    let outcome = router.pick_any_worker().await;
    assert!(outcome.is_err());
}
#[tokio::test]
async fn test_pick_any_worker_returns_eligible() {
let router = WorkerRouter::with_failure_ttl(Duration::from_secs(3600));
let workers = vec![
make_worker(1, "w1", 9203),
make_worker(2, "w2", 9203),
make_worker(3, "w3", 9203),
];
router.update_workers(workers.clone()).await;
router.mark_failed(workers[0].address.as_ref().unwrap());
router.mark_failed(workers[1].address.as_ref().unwrap());
for _ in 0..10 {
let picked = router.pick_any_worker().await.unwrap();
assert_eq!(picked.id, Some(3));
}
}
/// When every worker is marked failed, one of them is still returned.
#[tokio::test]
async fn test_pick_any_worker_fallback_when_all_failed() {
    let router = WorkerRouter::with_failure_ttl(Duration::from_secs(3600));
    let workers = vec![make_worker(1, "w1", 9203), make_worker(2, "w2", 9203)];
    router.update_workers(workers.clone()).await;
    for worker in &workers {
        router.mark_failed(worker.address.as_ref().unwrap());
    }
    let picked = router.pick_any_worker().await.unwrap();
    assert!(matches!(picked.id, Some(1) | Some(2)));
}
/// A freshly built router with the default refresh TTL is not yet stale.
#[tokio::test]
async fn test_needs_refresh_false_after_new() {
    let router = WorkerRouter::new();
    let stale = router.needs_refresh().await;
    assert!(!stale);
}
/// With a zero refresh TTL the router is stale as soon as any time passes.
#[tokio::test]
async fn test_needs_refresh_true_with_zero_ttl() {
    let router = WorkerRouter::with_ttls(DEFAULT_FAILURE_TTL, Duration::ZERO);
    tokio::time::sleep(Duration::from_millis(1)).await;
    let stale = router.needs_refresh().await;
    assert!(stale);
}
/// `with_ttls` stores both TTLs unchanged.
#[tokio::test]
async fn test_with_ttls_stores_values() {
    let (failure, refresh) = (Duration::from_secs(10), Duration::from_secs(5));
    let router = WorkerRouter::with_ttls(failure, refresh);
    assert_eq!(router.failure_ttl, failure);
    assert_eq!(router.worker_refresh_ttl, refresh);
}
/// `update_workers` restarts the refresh clock.
///
/// The same router is first allowed to go stale and is then updated, so the
/// final assertion can only pass if `update_workers` actually reset
/// `last_refresh`.
#[tokio::test]
async fn test_update_workers_resets_refresh_clock() {
    let refresh_ttl = Duration::from_millis(100);
    let router = WorkerRouter::with_ttls(DEFAULT_FAILURE_TTL, refresh_ttl);
    tokio::time::sleep(refresh_ttl + Duration::from_millis(10)).await;
    assert!(
        router.needs_refresh().await,
        "should need refresh before update"
    );
    router
        .update_workers(vec![make_worker(1, "w1", 9203)])
        .await;
    assert!(
        !router.needs_refresh().await,
        "update_workers should reset the refresh clock"
    );
}
/// With a `localhost` worker in the list, every block id is routed to it.
#[tokio::test]
async fn test_local_worker_preferred() {
    let router = WorkerRouter::new();
    router
        .update_workers(vec![
            make_worker(1, "remote1", 9203),
            make_worker(2, "localhost", 9203),
            make_worker(3, "remote2", 9203),
        ])
        .await;
    let block_ids: [i64; 5] = [1, 42, 100, 999, 10_000];
    for block_id in block_ids.iter().copied() {
        let chosen = router.select_worker(block_id).await.unwrap();
        assert_eq!(
            chosen.id,
            Some(2),
            "block_id={} should route to local worker",
            block_id
        );
    }
}
/// A local worker that is marked failed loses its preference.
#[tokio::test]
async fn test_local_worker_skipped_when_failed() {
    let router = WorkerRouter::with_failure_ttl(Duration::from_secs(3600));
    let local = make_worker(2, "localhost", 9203);
    router
        .update_workers(vec![
            make_worker(1, "remote1", 9203),
            local.clone(),
            make_worker(3, "remote2", 9203),
        ])
        .await;
    router.mark_failed(local.address.as_ref().unwrap());
    let chosen = router.select_worker(42).await.unwrap();
    assert_ne!(
        chosen.id,
        Some(2),
        "failed local worker should not be selected"
    );
}
/// Detection yields `0` when no worker host matches a local name.
#[tokio::test]
async fn test_detect_local_worker_none() {
    let remote_only = [
        make_worker(1, "remote-host-a.example.com", 9203),
        make_worker(2, "remote-host-b.example.com", 9203),
    ];
    let detected = WorkerRouter::detect_local_worker(&remote_only).await;
    assert_eq!(detected, 0);
}
/// A worker listening on `127.0.0.1` is recognised as local.
#[tokio::test]
async fn test_detect_local_worker_loopback() {
    let candidates = [
        make_worker(1, "10.0.0.1", 9203),
        make_worker(2, "127.0.0.1", 9203),
    ];
    let detected = WorkerRouter::detect_local_worker(&candidates).await;
    assert_eq!(detected, 2);
}
/// Updating the worker list clears the cached local id, so a local worker
/// that appears later is detected and preferred.
#[tokio::test]
async fn test_local_worker_cache_invalidated_on_update() {
    let router = WorkerRouter::new();

    // Only a remote worker: detection runs but finds nothing.
    let remote_only = vec![make_worker(1, "remote1", 9203)];
    router.update_workers(remote_only).await;
    let _ = router.select_worker(1).await;
    let cached = *router.local_worker_id.read().await;
    assert_eq!(cached, 0);

    // A loopback worker joins; the cache is still empty right after update.
    let with_local = vec![
        make_worker(1, "remote1", 9203),
        make_worker(2, "127.0.0.1", 9203),
    ];
    router.update_workers(with_local).await;
    let cached = *router.local_worker_id.read().await;
    assert_eq!(cached, 0);

    let chosen = router.select_worker(1).await.unwrap();
    assert_eq!(chosen.id, Some(2), "new local worker should be preferred");
}
}