use crate::types::{ShardId, ShardInfo, ShardStatus};
use std::collections::HashMap;
use std::time::{SystemTime, UNIX_EPOCH};
/// A shard as tracked by [`ShardRegistry`]: its descriptor plus the
/// registry-side state (status and reported memory usage).
#[derive(Debug, Clone)]
pub struct RegisteredShard {
/// Shard descriptor. The registry overwrites `id` and `last_heartbeat` when it
/// registers the shard and refreshes `last_heartbeat` on heartbeats.
pub info: ShardInfo,
/// Current status; `RegisteredShard::new` starts every shard as `Online`.
pub status: ShardStatus,
/// Memory usage last reported via `ShardRegistry::update_metrics`; 0 until then.
pub memory_bytes: u64,
}
impl RegisteredShard {
pub fn new(info: ShardInfo) -> Self {
Self {
info,
status: ShardStatus::Online,
memory_bytes: 0,
}
}
}
/// In-memory catalogue of shards keyed by [`ShardId`], together with the id
/// allocator and the heartbeat timeout used for liveness checks.
pub struct ShardRegistry {
    /// Milliseconds of heartbeat silence after which an `Online` shard is
    /// reported dead.
    heartbeat_timeout_ms: u64,
    /// Next id handed out by automatic registration.
    next_id: u32,
    /// Every known shard, keyed by its id.
    shards: HashMap<ShardId, RegisteredShard>,
}
impl ShardRegistry {
    /// Heartbeat timeout applied by [`ShardRegistry::new`], in milliseconds.
    const DEFAULT_HEARTBEAT_TIMEOUT_MS: u64 = 30_000;

    /// Creates an empty registry with the default heartbeat timeout (30 000 ms).
    pub fn new() -> Self {
        Self::with_heartbeat_timeout(Self::DEFAULT_HEARTBEAT_TIMEOUT_MS)
    }

    /// Creates an empty registry in which an `Online` shard is reported dead by
    /// [`ShardRegistry::check_dead_shards`] once more than `timeout_ms`
    /// milliseconds have passed since its last heartbeat.
    pub fn with_heartbeat_timeout(timeout_ms: u64) -> Self {
        Self {
            shards: HashMap::new(),
            next_id: 0,
            heartbeat_timeout_ms: timeout_ms,
        }
    }

    /// Registers a shard under a freshly allocated id and returns that id.
    ///
    /// The `id` and `last_heartbeat` fields of `info` are overwritten, and the
    /// shard starts out `Online`.
    ///
    /// # Panics
    ///
    /// Panics if the `u32` id space is exhausted. Without this check the
    /// counter would wrap in release builds and silently overwrite an existing
    /// shard. Nothing is inserted when this panic fires.
    pub fn register(&mut self, info: ShardInfo) -> ShardId {
        let raw = self.next_id;
        self.next_id = raw.checked_add(1).expect("shard id space exhausted");
        let id = ShardId::new(raw);
        self.insert_fresh(info, id);
        id
    }

    /// Registers a shard under the caller-chosen `id` and returns it.
    ///
    /// Any shard already registered under `id` is replaced. The id allocator
    /// is advanced past `id` so that later calls to
    /// [`ShardRegistry::register`] do not hand it out again.
    pub fn register_with_id(&mut self, info: ShardInfo, id: ShardId) -> ShardId {
        self.insert_fresh(info, id);
        // `saturating_add` avoids an overflow for `u32::MAX`; `register` then
        // reports exhaustion instead of reusing the id.
        self.next_id = self.next_id.max(id.0.saturating_add(1));
        id
    }

    /// Stamps `info` with `id` and the current time, marks it `Online`, and
    /// stores it, replacing any previous entry for `id`.
    fn insert_fresh(&mut self, info: ShardInfo, id: ShardId) {
        let mut registered = RegisteredShard::new(info);
        registered.info.id = id;
        registered.info.last_heartbeat = Self::current_timestamp();
        registered.status = ShardStatus::Online;
        self.shards.insert(id, registered);
    }

    /// Returns the descriptor of shard `id`, if registered.
    pub fn get(&self, id: &ShardId) -> Option<&ShardInfo> {
        self.shards.get(id).map(|r| &r.info)
    }

    /// Returns the full registry entry (descriptor, status, memory) for `id`.
    pub fn get_registered(&self, id: &ShardId) -> Option<&RegisteredShard> {
        self.shards.get(id)
    }

    /// Mutable access to the full registry entry for `id`.
    pub fn get_registered_mut(&mut self, id: &ShardId) -> Option<&mut RegisteredShard> {
        self.shards.get_mut(id)
    }

    /// Removes shard `id`, returning its descriptor if it was registered.
    pub fn remove(&mut self, id: &ShardId) -> Option<ShardInfo> {
        self.shards.remove(id).map(|r| r.info)
    }

    /// Clones every registered descriptor, in unspecified (hash-map) order.
    pub fn all(&self) -> Vec<ShardInfo> {
        self.shards.values().map(|r| r.info.clone()).collect()
    }

    /// Returns every registered id, in unspecified (hash-map) order.
    pub fn all_ids(&self) -> Vec<ShardId> {
        self.shards.keys().copied().collect()
    }

    /// Number of registered shards, regardless of status.
    pub fn count(&self) -> usize {
        self.shards.len()
    }

    /// Whether a shard with `id` is registered.
    pub fn contains(&self, id: &ShardId) -> bool {
        self.shards.contains_key(id)
    }

    /// Records a heartbeat for `id` at the current wall-clock time.
    ///
    /// Behaves like [`ShardRegistry::heartbeat_with_timestamp`]: a
    /// `Recovering` shard is promoted back to `Online`, and unknown ids are
    /// ignored.
    pub fn heartbeat(&mut self, id: &ShardId) {
        self.heartbeat_with_timestamp(id, Self::current_timestamp());
    }

    /// Records a heartbeat for `id` at an explicit `timestamp` (milliseconds
    /// since the Unix epoch, the same scale as the registry's own clock).
    ///
    /// A `Recovering` shard is promoted back to `Online`, consistent with
    /// [`ShardRegistry::heartbeat`]. Unknown ids are ignored.
    pub fn heartbeat_with_timestamp(&mut self, id: &ShardId, timestamp: u64) {
        if let Some(registered) = self.shards.get_mut(id) {
            registered.info.last_heartbeat = timestamp;
            if registered.status == ShardStatus::Recovering {
                registered.status = ShardStatus::Online;
            }
        }
    }

    /// Overwrites the status of `id`; unknown ids are ignored.
    pub fn set_status(&mut self, id: &ShardId, status: ShardStatus) {
        if let Some(registered) = self.shards.get_mut(id) {
            registered.status = status;
        }
    }

    /// Returns the status of `id`, if registered.
    pub fn get_status(&self, id: &ShardId) -> Option<ShardStatus> {
        self.shards.get(id).map(|r| r.status)
    }

    /// Overwrites the reported document count and memory usage of `id`.
    /// Unknown ids are ignored.
    pub fn update_metrics(&mut self, id: &ShardId, document_count: usize, memory_bytes: u64) {
        if let Some(registered) = self.shards.get_mut(id) {
            registered.info.document_count = document_count;
            registered.memory_bytes = memory_bytes;
        }
    }

    /// Clones the descriptors of all `Online` shards, in unspecified order.
    pub fn online_shards(&self) -> Vec<ShardInfo> {
        self.shards_with_status(ShardStatus::Online)
    }

    /// Clones the descriptors of all shards whose status equals `status`, in
    /// unspecified order.
    pub fn shards_with_status(&self, status: ShardStatus) -> Vec<ShardInfo> {
        self.shards
            .values()
            .filter(|r| r.status == status)
            .map(|r| r.info.clone())
            .collect()
    }

    /// Marks every `Online` shard whose last heartbeat is more than the
    /// configured timeout in the past as `Offline`, and returns their ids.
    ///
    /// Shards in any other status are left untouched. The returned order is
    /// unspecified.
    pub fn check_dead_shards(&mut self) -> Vec<ShardId> {
        let now = Self::current_timestamp();
        let timeout = self.heartbeat_timeout_ms;
        let mut dead_shards = Vec::new();
        for (id, registered) in self.shards.iter_mut() {
            // A heartbeat stamped in the future (clock skew, or a caller-supplied
            // timestamp) counts as fresh. Plain subtraction would underflow:
            // a panic in debug builds, a bogus "dead" verdict in release.
            let silent_for = now.saturating_sub(registered.info.last_heartbeat);
            if registered.status == ShardStatus::Online && silent_for > timeout {
                registered.status = ShardStatus::Offline;
                dead_shards.push(*id);
            }
        }
        dead_shards
    }

    /// Sum of reported document counts across all shards, regardless of status.
    pub fn total_documents(&self) -> u64 {
        self.shards
            .values()
            .map(|r| r.info.document_count as u64)
            .sum()
    }

    /// Sum of reported memory usage across all shards, regardless of status.
    pub fn total_memory(&self) -> u64 {
        self.shards.values().map(|r| r.memory_bytes).sum()
    }

    /// Returns the `Online` shard with the fewest documents, or `None` when no
    /// shard is `Online`.
    ///
    /// Ties are broken by the lowest id, so the result does not depend on
    /// hash-map iteration order.
    pub fn least_loaded_shard(&self) -> Option<ShardId> {
        self.shards
            .values()
            .filter(|r| r.status == ShardStatus::Online)
            .min_by_key(|r| (r.info.document_count, r.info.id.0))
            .map(|r| r.info.id)
    }

    /// Milliseconds since the Unix epoch. Returns 0 if the system clock reads
    /// earlier than the epoch.
    fn current_timestamp() -> u64 {
        SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .unwrap_or_default()
            .as_millis() as u64
    }
}
impl Default for ShardRegistry {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A descriptor whose id the registry will overwrite on registration.
    fn test_shard_info() -> ShardInfo {
        ShardInfo::new(ShardId::new(0), "127.0.0.1:8080".to_string())
    }

    #[test]
    fn test_registry_creation() {
        let registry = ShardRegistry::new();
        assert_eq!(registry.count(), 0);
    }

    #[test]
    fn test_register_shard() {
        let mut registry = ShardRegistry::new();
        let info = test_shard_info();
        let id = registry.register(info);
        assert_eq!(id, ShardId::new(0));
        assert_eq!(registry.count(), 1);
        let info2 = test_shard_info();
        let id2 = registry.register(info2);
        assert_eq!(id2, ShardId::new(1));
        assert_eq!(registry.count(), 2);
    }

    #[test]
    fn test_get_shard() {
        let mut registry = ShardRegistry::new();
        let info = test_shard_info();
        let id = registry.register(info);
        let retrieved = registry.get(&id).unwrap();
        assert_eq!(retrieved.id, id);
        assert_eq!(registry.get_status(&id), Some(ShardStatus::Online));
    }

    #[test]
    fn test_remove_shard() {
        let mut registry = ShardRegistry::new();
        let info = test_shard_info();
        let id = registry.register(info);
        assert!(registry.contains(&id));
        let removed = registry.remove(&id);
        assert!(removed.is_some());
        assert!(!registry.contains(&id));
    }

    #[test]
    fn test_set_status() {
        let mut registry = ShardRegistry::new();
        let info = test_shard_info();
        let id = registry.register(info);
        assert_eq!(registry.get_status(&id), Some(ShardStatus::Online));
        registry.set_status(&id, ShardStatus::Draining);
        assert_eq!(registry.get_status(&id), Some(ShardStatus::Draining));
    }

    #[test]
    fn test_update_metrics() {
        let mut registry = ShardRegistry::new();
        let info = test_shard_info();
        let id = registry.register(info);
        registry.update_metrics(&id, 100, 1024 * 1024);
        let shard = registry.get(&id).unwrap();
        assert_eq!(shard.document_count, 100);
        let registered = registry.get_registered(&id).unwrap();
        assert_eq!(registered.memory_bytes, 1024 * 1024);
    }

    #[test]
    fn test_online_shards() {
        let mut registry = ShardRegistry::new();
        let id1 = registry.register(test_shard_info());
        let id2 = registry.register(test_shard_info());
        let _id3 = registry.register(test_shard_info());
        registry.set_status(&id2, ShardStatus::Offline);
        let online = registry.online_shards();
        assert_eq!(online.len(), 2);
        assert!(online.iter().all(|s| s.id != id2));
        // Previously `id1` was bound but never checked (unused-variable warning).
        assert!(online.iter().any(|s| s.id == id1));
    }

    #[test]
    fn test_total_documents() {
        let mut registry = ShardRegistry::new();
        let id1 = registry.register(test_shard_info());
        let id2 = registry.register(test_shard_info());
        registry.update_metrics(&id1, 100, 1000);
        registry.update_metrics(&id2, 200, 2000);
        assert_eq!(registry.total_documents(), 300);
        assert_eq!(registry.total_memory(), 3000);
    }

    #[test]
    fn test_least_loaded_shard() {
        let mut registry = ShardRegistry::new();
        let id1 = registry.register(test_shard_info());
        let id2 = registry.register(test_shard_info());
        let id3 = registry.register(test_shard_info());
        registry.update_metrics(&id1, 100, 1000);
        registry.update_metrics(&id2, 50, 500);
        registry.update_metrics(&id3, 200, 2000);
        assert_eq!(registry.least_loaded_shard(), Some(id2));
    }

    #[test]
    fn test_least_loaded_skips_non_online() {
        let mut registry = ShardRegistry::new();
        let id1 = registry.register(test_shard_info());
        let id2 = registry.register(test_shard_info());
        registry.update_metrics(&id1, 10, 0);
        registry.update_metrics(&id2, 500, 0);
        registry.set_status(&id1, ShardStatus::Draining);
        assert_eq!(registry.least_loaded_shard(), Some(id2));
    }

    #[test]
    fn test_check_dead_shards() {
        let mut registry = ShardRegistry::with_heartbeat_timeout(100);
        let info = test_shard_info();
        let id = registry.register(info);
        registry.heartbeat_with_timestamp(&id, 0);
        let dead = registry.check_dead_shards();
        assert_eq!(dead.len(), 1);
        assert_eq!(dead[0], id);
        assert_eq!(registry.get_status(&id), Some(ShardStatus::Offline));
    }

    #[test]
    fn test_heartbeat_recovers_shard() {
        let mut registry = ShardRegistry::new();
        let id = registry.register(test_shard_info());
        registry.set_status(&id, ShardStatus::Recovering);
        registry.heartbeat(&id);
        assert_eq!(registry.get_status(&id), Some(ShardStatus::Online));
    }

    #[test]
    fn test_operations_on_unknown_id_are_noops() {
        let mut registry = ShardRegistry::new();
        let missing = ShardId::new(7);
        registry.heartbeat(&missing);
        registry.set_status(&missing, ShardStatus::Offline);
        registry.update_metrics(&missing, 1, 1);
        assert_eq!(registry.count(), 0);
        assert_eq!(registry.get_status(&missing), None);
        assert!(registry.remove(&missing).is_none());
    }

    #[test]
    fn test_register_with_specific_id() {
        let mut registry = ShardRegistry::new();
        let info = test_shard_info();
        let id = registry.register_with_id(info, ShardId::new(42));
        assert_eq!(id, ShardId::new(42));
        assert!(registry.contains(&id));
        let info2 = test_shard_info();
        let id2 = registry.register(info2);
        assert_eq!(id2, ShardId::new(43));
    }
}