use crate::entity_manager::EntityManager;
use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::Arc;
use std::time::Duration;
use tokio::sync::{Notify, RwLock};
use tokio_util::sync::CancellationToken;
use tracing::{debug, instrument};
/// Lower bound, in milliseconds, on the interval between reap passes.
const MIN_RESOLUTION_MS: u64 = 5 * 1_000;
/// Interval between reap passes, in milliseconds, before any registration lowers it.
const DEFAULT_RESOLUTION_MS: u64 = 30 * 1_000;
/// A manager the reaper sweeps, paired with the idle threshold used for it.
struct RegisteredManager {
    /// Manager whose idle entities get reaped.
    manager: Arc<EntityManager>,
    /// Threshold handed to `EntityManager::reap_idle` on every sweep.
    max_idle: Duration,
}

/// Periodically reaps idle entities from every registered [`EntityManager`].
pub struct EntityReaper {
    /// Managers to sweep, each with its own idle threshold.
    managers: RwLock<Vec<RegisteredManager>>,
    /// Cancelling this token ends the run loop.
    cancel: CancellationToken,
    /// Current sweep interval in milliseconds; only ever lowered by registrations.
    current_resolution_ms: AtomicU64,
    /// Signalled when the first manager is registered.
    latch: Notify,
}
impl EntityReaper {
    /// Creates a reaper with no registered managers.
    ///
    /// The sweep interval starts at `DEFAULT_RESOLUTION_MS`; cancelling
    /// `cancel` stops [`EntityReaper::run`].
    pub fn new(cancel: CancellationToken) -> Self {
        Self {
            managers: RwLock::new(Vec::new()),
            cancel,
            current_resolution_ms: AtomicU64::new(DEFAULT_RESOLUTION_MS),
            latch: Notify::new(),
        }
    }

    /// Registers `manager` so that its entities idle longer than `max_idle`
    /// are reaped on every sweep.
    ///
    /// Lowers the sweep interval to `max_idle` when that is shorter than the
    /// current interval, but never below `MIN_RESOLUTION_MS`. The first
    /// registration wakes a [`EntityReaper::run`] call that is waiting for
    /// managers.
    pub async fn register(&self, manager: Arc<EntityManager>, max_idle: Duration) {
        // `as_millis` yields a u128; saturate rather than truncate so a huge
        // "effectively never" timeout cannot wrap around to a tiny value.
        let max_idle_ms = max_idle.as_millis().min(u128::from(u64::MAX)) as u64;
        // The stored interval starts at DEFAULT_RESOLUTION_MS and is only ever
        // replaced by values >= MIN_RESOLUTION_MS, so clamping the candidate
        // first and taking the minimum equals `min(current, max_idle).max(MIN)`.
        self.current_resolution_ms
            .fetch_min(max_idle_ms.max(MIN_RESOLUTION_MS), Ordering::AcqRel);

        let mut managers = self.managers.write().await;
        let was_empty = managers.is_empty();
        managers.push(RegisteredManager { manager, max_idle });
        if was_empty {
            self.latch.notify_waiters();
        }
    }

    /// Runs the periodic reap loop until the cancellation token fires.
    ///
    /// If no manager is registered yet, first waits for one (or for
    /// cancellation). Each iteration then sleeps for the current resolution
    /// and calls [`EntityReaper::reap_all`]. The resolution is re-read before
    /// every sleep, so a change made by `register` takes effect once the
    /// sleep in progress finishes.
    pub async fn run(&self) {
        loop {
            // Create the `Notified` future *before* checking for managers:
            // `notify_waiters` only wakes futures that already exist and
            // stores no permit, so checking first could miss a registration
            // that lands between the check and the wait.
            let registered = self.latch.notified();
            if !self.managers.read().await.is_empty() {
                break;
            }
            tokio::select! {
                _ = self.cancel.cancelled() => return,
                _ = registered => {}
            }
        }
        loop {
            let interval = self.current_resolution();
            tokio::select! {
                _ = self.cancel.cancelled() => break,
                _ = tokio::time::sleep(interval) => {
                    self.reap_all().await;
                }
            }
        }
    }

    /// Returns the interval currently used between reap passes.
    pub fn current_resolution(&self) -> Duration {
        Duration::from_millis(self.current_resolution_ms.load(Ordering::Acquire))
    }

    /// Runs one reap pass over every registered manager, using each manager's
    /// own `max_idle`, and returns the total number of entities reaped.
    ///
    /// Managers registered while a pass is in progress are picked up by the
    /// next pass.
    #[instrument(level = "debug", skip(self))]
    pub async fn reap_all(&self) -> usize {
        // Snapshot the registrations so the read lock is not held across the
        // `reap_idle` awaits below, which would block `register` for the
        // whole pass.
        let snapshot: Vec<(Arc<EntityManager>, Duration)> = self
            .managers
            .read()
            .await
            .iter()
            .map(|entry| (Arc::clone(&entry.manager), entry.max_idle))
            .collect();

        let mut total_reaped = 0;
        for (manager, max_idle) in snapshot {
            let reaped = manager.reap_idle(max_idle).await;
            if reaped > 0 {
                debug!(
                    entity_type = %manager.entity().entity_type(),
                    reaped,
                    "reaped idle entities"
                );
            }
            total_reaped += reaped;
        }
        total_reaped
    }

    /// Returns the number of registered managers.
    pub async fn manager_count(&self) -> usize {
        self.managers.read().await.len()
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::config::ShardingConfig;
    use crate::entity::{Entity, EntityContext, EntityHandler};
    use crate::envelope::EnvelopeRequest;
    use crate::error::ClusterError;
    use crate::message::IncomingMessage;
    use crate::snowflake::{Snowflake, SnowflakeGenerator};
    use crate::types::{EntityAddress, EntityId, EntityType, RunnerAddress, ShardId};
    use async_trait::async_trait;
    use std::collections::HashMap;
    use tokio::sync::mpsc;

    /// Test entity of type `"Simple"` that spawns a [`SimpleHandler`].
    struct SimpleEntity;

    #[async_trait]
    impl Entity for SimpleEntity {
        fn entity_type(&self) -> EntityType {
            EntityType::new("Simple")
        }

        async fn spawn(&self, _ctx: EntityContext) -> Result<Box<dyn EntityHandler>, ClusterError> {
            Ok(Box::new(SimpleHandler))
        }
    }

    /// Handler that answers every request with an empty body.
    struct SimpleHandler;

    #[async_trait]
    impl EntityHandler for SimpleHandler {
        async fn handle_request(
            &self,
            _tag: &str,
            _payload: &[u8],
            _headers: &HashMap<String, String>,
        ) -> Result<Vec<u8>, ClusterError> {
            Ok(Vec::new())
        }
    }

    /// Builds a fresh manager for [`SimpleEntity`] with default sharding config.
    fn make_manager() -> Arc<EntityManager> {
        Arc::new(EntityManager::new(
            Arc::new(SimpleEntity),
            Arc::new(ShardingConfig::default()),
            RunnerAddress::new("127.0.0.1", 9000),
            Arc::new(SnowflakeGenerator::new()),
            None,
        ))
    }

    /// Builds a `"test"`-tagged request for the `"Simple"` entity `entity_id`
    /// on `ShardId::new("default", 0)`, plus the receiver its reply arrives on.
    fn test_request(entity_id: &str) -> (IncomingMessage, mpsc::Receiver<crate::reply::Reply>) {
        let (reply_tx, reply_rx) = mpsc::channel(1);
        let request = EnvelopeRequest {
            request_id: Snowflake(1),
            address: EntityAddress {
                shard_id: ShardId::new("default", 0),
                entity_type: EntityType::new("Simple"),
                entity_id: EntityId::new(entity_id),
            },
            tag: "test".to_string(),
            payload: Vec::new(),
            headers: HashMap::new(),
            span_id: None,
            trace_id: None,
            sampled: None,
            persisted: false,
            uninterruptible: Default::default(),
            deliver_at: None,
        };
        (IncomingMessage::Request { request, reply_tx }, reply_rx)
    }

    /// Sends one request for `entity_id` through `mgr` and waits for its
    /// reply, so the entity has been spawned by the time this returns.
    async fn activate(mgr: &Arc<EntityManager>, entity_id: &str) {
        let (msg, mut reply_rx) = test_request(entity_id);
        mgr.send_local(msg).await.unwrap();
        reply_rx.recv().await.unwrap();
    }

    #[tokio::test]
    async fn register_adds_manager() {
        let reaper = EntityReaper::new(CancellationToken::new());
        assert_eq!(reaper.manager_count().await, 0);

        reaper.register(make_manager(), Duration::from_secs(60)).await;
        assert_eq!(reaper.manager_count().await, 1);
    }

    #[tokio::test]
    async fn reap_all_removes_idle_entities() {
        let reaper = EntityReaper::new(CancellationToken::new());
        let mgr = make_manager();
        activate(&mgr, "e-1").await;
        assert_eq!(mgr.active_count(), 1);

        reaper.register(Arc::clone(&mgr), Duration::from_millis(1)).await;
        tokio::time::sleep(Duration::from_millis(10)).await;

        assert_eq!(reaper.reap_all().await, 1);
        assert_eq!(mgr.active_count(), 0);
    }

    #[tokio::test]
    async fn reap_all_skips_non_idle_entities() {
        let reaper = EntityReaper::new(CancellationToken::new());
        let mgr = make_manager();
        activate(&mgr, "e-1").await;

        reaper.register(Arc::clone(&mgr), Duration::from_secs(3600)).await;

        assert_eq!(reaper.reap_all().await, 0);
        assert_eq!(mgr.active_count(), 1);
    }

    #[tokio::test]
    async fn run_loop_cancels_cleanly() {
        let cancel = CancellationToken::new();
        let reaper = Arc::new(EntityReaper::new(cancel.clone()));
        reaper.register(make_manager(), Duration::from_secs(60)).await;

        let handle = tokio::spawn({
            let reaper = Arc::clone(&reaper);
            async move { reaper.run().await }
        });
        tokio::time::sleep(Duration::from_millis(50)).await;
        cancel.cancel();

        tokio::time::timeout(Duration::from_secs(1), handle)
            .await
            .expect("reaper should stop")
            .expect("task should not panic");
    }

    #[tokio::test]
    async fn multiple_managers_reaped() {
        let reaper = EntityReaper::new(CancellationToken::new());
        let (mgr1, mgr2) = (make_manager(), make_manager());
        activate(&mgr1, "e-1").await;
        activate(&mgr2, "e-2").await;

        for mgr in [&mgr1, &mgr2] {
            reaper.register(Arc::clone(mgr), Duration::from_millis(1)).await;
        }
        tokio::time::sleep(Duration::from_millis(10)).await;

        assert_eq!(reaper.reap_all().await, 2);
        for mgr in [&mgr1, &mgr2] {
            assert_eq!(mgr.active_count(), 0);
        }
    }

    #[tokio::test]
    async fn dynamic_resolution_adapts_to_shortest_idle_time() {
        let reaper = EntityReaper::new(CancellationToken::new());
        assert_eq!(reaper.current_resolution(), Duration::from_secs(30));

        // (max_idle seconds to register, expected resolution seconds afterwards)
        let steps = [(20, 20), (10, 10), (60, 10)];
        for (max_idle_secs, expected_secs) in steps {
            reaper
                .register(make_manager(), Duration::from_secs(max_idle_secs))
                .await;
            assert_eq!(
                reaper.current_resolution(),
                Duration::from_secs(expected_secs)
            );
        }
    }

    #[tokio::test]
    async fn dynamic_resolution_floored_at_5_seconds() {
        let reaper = EntityReaper::new(CancellationToken::new());
        for max_idle in [Duration::from_secs(1), Duration::from_millis(100)] {
            reaper.register(make_manager(), max_idle).await;
            assert_eq!(reaper.current_resolution(), Duration::from_secs(5));
        }
    }

    #[tokio::test]
    async fn run_waits_for_first_registration() {
        let cancel = CancellationToken::new();
        let reaper = Arc::new(EntityReaper::new(cancel.clone()));

        let handle = tokio::spawn({
            let reaper = Arc::clone(&reaper);
            async move { reaper.run().await }
        });
        tokio::time::sleep(Duration::from_millis(50)).await;

        reaper.register(make_manager(), Duration::from_secs(60)).await;
        tokio::time::sleep(Duration::from_millis(50)).await;
        cancel.cancel();

        tokio::time::timeout(Duration::from_secs(1), handle)
            .await
            .expect("reaper should stop")
            .expect("task should not panic");
    }
}