use std::marker::PhantomData;
use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::Arc;
use tokio_util::sync::CancellationToken;
use crate::actor::{Actor, ActorRef, AskReply, ReduceHandler, Handler, ExpandHandler, TransformHandler};
use crate::errors::ActorSendError;
use crate::message::Message;
#[cfg(feature = "metrics")]
use crate::metrics::ActorMetricsHandle;
use crate::node::{ActorId, NodeId};
use crate::stream::{BatchConfig, BoxStream};
/// Strategy used by a pool to pick which worker receives each message.
#[derive(Debug, Clone)]
pub enum PoolRouting {
    /// Cycle through the workers in order.
    RoundRobin,
    /// Pick a pseudo-random worker per message (shared counter scrambled
    /// through `splitmix64`, so it is deterministic but well spread).
    Random,
    /// Route via `Keyed::routing_key` so equal keys always hit the same
    /// worker; non-keyed sends fall back to round-robin.
    KeyBased,
    /// Prefer the worker with the fewest pending mailbox messages.
    LeastLoaded,
}
/// Configuration bundle describing how to build an actor pool.
#[derive(Debug, Clone)]
pub struct PoolConfig {
    /// Number of worker actors in the pool.
    pub pool_size: usize,
    /// Routing strategy used to select a worker per message.
    pub routing: PoolRouting,
}
impl PoolConfig {
pub fn new(pool_size: usize, routing: PoolRouting) -> Self {
Self { pool_size, routing }
}
}
/// Implemented by messages that carry a routing key for `KeyBased` pools.
pub trait Keyed {
    /// Stable key used to pick a worker (`key % pool_size`), so equal keys
    /// always route to the same worker.
    fn routing_key(&self) -> u64;
}
/// Handle that fans messages out to a fixed set of worker actors according
/// to a [`PoolRouting`] strategy. Clones share the same workers and the
/// same routing counter (see the manual `Clone` impl).
pub struct PoolRef<A: Actor, R: ActorRef<A>> {
    /// Worker actor refs; never empty (enforced by the assert in `new`).
    workers: Vec<R>,
    /// Routing strategy for non-keyed sends and stream operations.
    routing: PoolRouting,
    /// Shared monotonic counter driving round-robin, random mixing, and
    /// least-loaded tie-breaking.
    counter: Arc<AtomicU64>,
    /// Process-unique id allocated from `NEXT_POOL_ID`.
    pool_id: u64,
    /// Display name derived from the first worker's name.
    name: String,
    /// Optional metrics handle attached via `with_metrics_handle`.
    #[cfg(feature = "metrics")]
    metrics_handle: Option<Arc<ActorMetricsHandle>>,
    /// Ties the pool to actor type `A` without owning a value of it.
    _phantom: PhantomData<fn() -> A>,
}
// Manual Clone impl: NOTE(review) — presumably hand-written because
// `#[derive(Clone)]` would add an `A: Clone` bound from the type parameter;
// only the worker refs and shared state need cloning.
impl<A: Actor, R: ActorRef<A>> Clone for PoolRef<A, R> {
    fn clone(&self) -> Self {
        Self {
            workers: self.workers.clone(),
            routing: self.routing.clone(),
            // Shared counter: clones continue the same routing sequence.
            counter: self.counter.clone(),
            pool_id: self.pool_id,
            name: self.name.clone(),
            #[cfg(feature = "metrics")]
            metrics_handle: self.metrics_handle.clone(),
            _phantom: PhantomData,
        }
    }
}
/// Process-wide source of unique pool ids (starts at 1, so 0 is never used).
static NEXT_POOL_ID: AtomicU64 = AtomicU64::new(1);
impl<A: Actor, R: ActorRef<A>> PoolRef<A, R> {
    /// Builds a pool over `workers` routed by `routing`.
    ///
    /// The pool name is derived from the first worker's name, and a
    /// process-unique id is assigned from `NEXT_POOL_ID`.
    ///
    /// # Panics
    /// Panics if `workers` is empty.
    pub fn new(workers: Vec<R>, routing: PoolRouting) -> Self {
        assert!(!workers.is_empty(), "pool must have at least one worker");
        let pool_id = NEXT_POOL_ID.fetch_add(1, Ordering::Relaxed);
        let name = format!("pool({})", workers[0].name());
        Self {
            workers,
            routing,
            counter: Arc::new(AtomicU64::new(0)),
            pool_id,
            name,
            #[cfg(feature = "metrics")]
            metrics_handle: None,
            _phantom: PhantomData,
        }
    }

    /// Attaches a metrics handle (builder style).
    #[cfg(feature = "metrics")]
    pub fn with_metrics_handle(mut self, handle: Arc<ActorMetricsHandle>) -> Self {
        self.metrics_handle = Some(handle);
        self
    }

    /// Returns the attached metrics handle, if any.
    #[cfg(feature = "metrics")]
    pub fn metrics_handle(&self) -> Option<&Arc<ActorMetricsHandle>> {
        self.metrics_handle.as_ref()
    }

    /// Number of workers in the pool.
    pub fn pool_size(&self) -> usize {
        self.workers.len()
    }

    /// Next worker index in strict rotation order.
    fn round_robin_index(&self) -> usize {
        let idx = self.counter.fetch_add(1, Ordering::Relaxed);
        (idx % (self.workers.len() as u64)) as usize
    }

    /// Pseudo-random worker index: the shared counter is scrambled through
    /// `splitmix64` so consecutive sends don't follow rotation order while
    /// staying deterministic per process.
    fn random_index(&self) -> usize {
        let raw = self.counter.fetch_add(1, Ordering::Relaxed);
        let mixed = splitmix64(raw);
        (mixed % (self.workers.len() as u64)) as usize
    }

    /// Deterministic index for a routing key: `key % pool_size`.
    fn keyed_index(&self, key: u64) -> usize {
        (key % (self.workers.len() as u64)) as usize
    }

    /// Index of a worker with the fewest pending messages; ties are broken
    /// round-robin via the shared counter so equal-load workers alternate.
    fn least_loaded_index(&self) -> usize {
        // Snapshot all loads in a single pass. The previous implementation
        // called `pending_messages()` twice (once to find the minimum, once
        // to filter candidates); loads can change concurrently between the
        // two passes, which could leave the candidate set empty and panic
        // with a remainder-by-zero in the tie-break below.
        let loads: Vec<usize> = self.workers.iter().map(|w| w.pending_messages()).collect();
        // `workers` is non-empty (asserted in `new`), so a minimum exists.
        let min_load = *loads.iter().min().expect("pool is never empty");
        let candidates: Vec<usize> = loads
            .iter()
            .enumerate()
            .filter(|&(_, &load)| load == min_load)
            .map(|(i, _)| i)
            .collect();
        if candidates.len() == 1 {
            candidates[0]
        } else {
            let idx = self.counter.fetch_add(1, Ordering::Relaxed);
            candidates[(idx as usize) % candidates.len()]
        }
    }

    /// Picks a worker for a non-keyed send according to the routing strategy.
    /// `KeyBased` pools fall back to round-robin here because the message
    /// carries no key; keyed routing goes through `tell_keyed`/`ask_keyed`.
    fn select_worker(&self) -> &R {
        let idx = match &self.routing {
            PoolRouting::RoundRobin | PoolRouting::KeyBased => self.round_robin_index(),
            PoolRouting::Random => self.random_index(),
            PoolRouting::LeastLoaded => self.least_loaded_index(),
        };
        &self.workers[idx]
    }

    /// Returns a broadcast handle over the same workers, delivering each
    /// message to every worker instead of just one.
    pub fn to_broadcast(&self) -> crate::broadcast::BroadcastRef<A, R> {
        crate::broadcast::BroadcastRef::new(self.workers.clone())
    }

    /// Fire-and-forget send routed by the message's `routing_key`, so equal
    /// keys always reach the same worker regardless of the pool's strategy.
    pub fn tell_keyed<M>(&self, msg: M) -> Result<(), ActorSendError>
    where
        A: Handler<M>,
        M: Message<Reply = ()> + Keyed,
    {
        let idx = self.keyed_index(msg.routing_key());
        self.workers[idx].tell(msg)
    }

    /// Request/response send routed by the message's `routing_key`.
    pub fn ask_keyed<M>(
        &self,
        msg: M,
        cancel: Option<CancellationToken>,
    ) -> Result<AskReply<M::Reply>, ActorSendError>
    where
        A: Handler<M>,
        M: Message + Keyed,
    {
        let idx = self.keyed_index(msg.routing_key());
        self.workers[idx].ask(msg, cancel)
    }
}
/// SplitMix64 finalizer: bijectively scrambles `seed` into a
/// well-distributed 64-bit value (used for cheap pseudo-random routing).
fn splitmix64(seed: u64) -> u64 {
    let mut z = seed.wrapping_add(0x9e37_79b9_7f4a_7c15);
    z = (z ^ (z >> 30)).wrapping_mul(0xbf58_476d_1ce4_e5b9);
    z = (z ^ (z >> 27)).wrapping_mul(0x94d0_49bb_1331_11eb);
    z ^ (z >> 31)
}
/// `ActorRef` implementation that lets a pool stand in anywhere a single
/// actor ref is expected; every operation delegates to one worker chosen
/// by the pool's routing strategy.
impl<A: Actor, R: ActorRef<A>> ActorRef<A> for PoolRef<A, R> {
    fn id(&self) -> ActorId {
        // Pools live on a synthetic "pool" node with their own unique id.
        let node = NodeId("pool".into());
        ActorId {
            node,
            local: self.pool_id,
        }
    }
    fn name(&self) -> String {
        self.name.clone()
    }
    fn is_alive(&self) -> bool {
        // The pool counts as alive while at least one worker still runs.
        for worker in &self.workers {
            if worker.is_alive() {
                return true;
            }
        }
        false
    }
    fn stop(&self) {
        // Stopping the pool stops every worker.
        self.workers.iter().for_each(|w| w.stop());
    }
    fn tell<M>(&self, msg: M) -> Result<(), ActorSendError>
    where
        A: Handler<M>,
        M: Message<Reply = ()>,
    {
        let worker = self.select_worker();
        worker.tell(msg)
    }
    fn ask<M>(
        &self,
        msg: M,
        cancel: Option<CancellationToken>,
    ) -> Result<AskReply<M::Reply>, ActorSendError>
    where
        A: Handler<M>,
        M: Message,
    {
        let worker = self.select_worker();
        worker.ask(msg, cancel)
    }
    fn expand<M, OutputItem>(
        &self,
        msg: M,
        buffer: usize,
        batch_config: Option<BatchConfig>,
        cancel: Option<CancellationToken>,
    ) -> Result<BoxStream<OutputItem>, ActorSendError>
    where
        A: ExpandHandler<M, OutputItem>,
        M: Send + 'static,
        OutputItem: Send + 'static,
    {
        let worker = self.select_worker();
        worker.expand(msg, buffer, batch_config, cancel)
    }
    fn reduce<InputItem, Reply>(
        &self,
        input: BoxStream<InputItem>,
        buffer: usize,
        batch_config: Option<BatchConfig>,
        cancel: Option<CancellationToken>,
    ) -> Result<AskReply<Reply>, ActorSendError>
    where
        A: ReduceHandler<InputItem, Reply>,
        InputItem: Send + 'static,
        Reply: Send + 'static,
    {
        let worker = self.select_worker();
        worker.reduce(input, buffer, batch_config, cancel)
    }
    fn transform<InputItem, OutputItem>(
        &self,
        input: BoxStream<InputItem>,
        buffer: usize,
        batch_config: Option<BatchConfig>,
        cancel: Option<CancellationToken>,
    ) -> Result<BoxStream<OutputItem>, ActorSendError>
    where
        A: TransformHandler<InputItem, OutputItem>,
        InputItem: Send + 'static,
        OutputItem: Send + 'static,
    {
        let worker = self.select_worker();
        worker.transform(input, buffer, batch_config, cancel)
    }
}
#[cfg(feature = "test-support")]
impl crate::test_support::test_runtime::TestRuntime {
    /// Spawns `pool_size` instances of actor `A` (named `{name}-0`,
    /// `{name}-1`, …) and wraps them in a [`PoolRef`] with the given
    /// routing strategy. Fails if any individual spawn fails.
    pub async fn spawn_pool<A>(
        &self,
        name: &str,
        pool_size: usize,
        routing: PoolRouting,
        args: A::Args,
    ) -> Result<PoolRef<A, crate::test_support::test_runtime::TestActorRef<A>>, crate::errors::RuntimeError>
    where
        A: Actor<Deps = ()> + 'static,
        A::Args: Clone,
    {
        let mut workers = Vec::with_capacity(pool_size);
        for idx in 0..pool_size {
            let worker_name = format!("{name}-{idx}");
            let worker = self.spawn(&worker_name, args.clone()).await?;
            workers.push(worker);
        }
        Ok(PoolRef::new(workers, routing))
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicU64, Ordering};
    use std::sync::Arc;
    use async_trait::async_trait;
    use crate::actor::{Actor, ActorContext, Handler};
    use crate::message::Message;
    use crate::test_support::test_runtime::TestRuntime;

    /// Test actor that remembers its id and counts handled `Ping`s via a
    /// shared atomic so tests can observe delivery from outside.
    struct PoolWorker {
        id: u64,
        call_count: Arc<AtomicU64>,
    }
    impl Actor for PoolWorker {
        type Args = (u64, Arc<AtomicU64>);
        type Deps = ();
        fn create(args: Self::Args, _deps: ()) -> Self {
            Self {
                id: args.0,
                call_count: args.1,
            }
        }
    }

    /// Fire-and-forget message; increments the worker's call counter.
    struct Ping;
    impl Message for Ping {
        type Reply = ();
    }

    /// Ask message; replies with the worker's id.
    struct WhoAreYou;
    impl Message for WhoAreYou {
        type Reply = u64;
    }

    /// Keyed ask message; replies with the worker's id so tests can check
    /// that equal keys land on the same worker.
    struct KeyedPing {
        key: u64,
    }
    impl Message for KeyedPing {
        type Reply = u64;
    }
    impl Keyed for KeyedPing {
        fn routing_key(&self) -> u64 {
            self.key
        }
    }

    #[async_trait]
    impl Handler<Ping> for PoolWorker {
        async fn handle(&mut self, _msg: Ping, _ctx: &mut ActorContext) {
            self.call_count.fetch_add(1, Ordering::Relaxed);
        }
    }
    #[async_trait]
    impl Handler<WhoAreYou> for PoolWorker {
        async fn handle(&mut self, _msg: WhoAreYou, _ctx: &mut ActorContext) -> u64 {
            self.id
        }
    }
    #[async_trait]
    impl Handler<KeyedPing> for PoolWorker {
        async fn handle(&mut self, _msg: KeyedPing, _ctx: &mut ActorContext) -> u64 {
            self.id
        }
    }

    /// Spawns `size` `PoolWorker`s and returns the pool plus one call
    /// counter per worker (index-aligned with the workers).
    async fn make_pool(
        rt: &TestRuntime,
        size: usize,
        routing: PoolRouting,
    ) -> (
        PoolRef<PoolWorker, crate::test_support::test_runtime::TestActorRef<PoolWorker>>,
        Vec<Arc<AtomicU64>>,
    ) {
        let mut counters = Vec::new();
        let mut workers = Vec::new();
        for i in 0..size {
            let ctr = Arc::new(AtomicU64::new(0));
            counters.push(ctr.clone());
            let r = rt.spawn::<PoolWorker>(&format!("w-{}", i), (i as u64, ctr)).await.unwrap();
            workers.push(r);
        }
        (PoolRef::new(workers, routing), counters)
    }

    /// Round-robin must spread 9 messages evenly over 3 workers (3 each).
    /// The sleep gives the async workers time to drain their mailboxes.
    #[tokio::test]
    async fn round_robin_distributes_across_workers() {
        let rt = TestRuntime::new();
        let (pool, counters) = make_pool(&rt, 3, PoolRouting::RoundRobin).await;
        for _ in 0..9 {
            pool.tell(Ping).unwrap();
        }
        tokio::time::sleep(std::time::Duration::from_millis(50)).await;
        for (i, ctr) in counters.iter().enumerate() {
            assert_eq!(
                ctr.load(Ordering::Relaxed),
                3,
                "worker {} should have received 3 messages",
                i
            );
        }
    }

    /// Random routing is lossless (all 300 delivered) and, with 300 sends
    /// over 3 workers, should touch every worker at least once.
    #[tokio::test]
    async fn random_routing_delivers_to_all() {
        let rt = TestRuntime::new();
        let (pool, counters) = make_pool(&rt, 3, PoolRouting::Random).await;
        for _ in 0..300 {
            pool.tell(Ping).unwrap();
        }
        tokio::time::sleep(std::time::Duration::from_millis(100)).await;
        let total: u64 = counters.iter().map(|c| c.load(Ordering::Relaxed)).sum();
        assert_eq!(total, 300, "all messages should be delivered");
        for (i, ctr) in counters.iter().enumerate() {
            assert!(
                ctr.load(Ordering::Relaxed) > 0,
                "worker {} should have received at least 1 message with random routing",
                i,
            );
        }
    }

    /// For each key, 5 repeated asks must all report the same worker id —
    /// i.e. key-based routing is stable per key.
    #[tokio::test]
    async fn key_based_routes_same_key_to_same_worker() {
        let rt = TestRuntime::new();
        let (pool, _counters) = make_pool(&rt, 4, PoolRouting::KeyBased).await;
        for key in [10u64, 42, 99, 1000] {
            let mut results = Vec::new();
            for _ in 0..5 {
                let id = pool
                    .ask_keyed(KeyedPing { key }, None)
                    .unwrap()
                    .await
                    .unwrap();
                results.push(id);
            }
            assert!(
                results.windows(2).all(|w| w[0] == w[1]),
                "key {} should always route to the same worker, got {:?}",
                key,
                results,
            );
        }
    }

    /// `ask` through the pool reaches a worker and returns its reply.
    #[tokio::test]
    async fn pool_ref_ask_works() {
        let rt = TestRuntime::new();
        let ctr = Arc::new(AtomicU64::new(0));
        let worker = rt.spawn::<PoolWorker>("solo", (42, ctr)).await.unwrap();
        let pool = PoolRef::new(vec![worker], PoolRouting::RoundRobin);
        let id = pool.ask(WhoAreYou, None).unwrap().await.unwrap();
        assert_eq!(id, 42);
    }

    /// Degenerate single-worker pool: everything routes to that worker.
    #[tokio::test]
    async fn pool_size_one_works() {
        let rt = TestRuntime::new();
        let (pool, counters) = make_pool(&rt, 1, PoolRouting::RoundRobin).await;
        for _ in 0..5 {
            pool.tell(Ping).unwrap();
        }
        tokio::time::sleep(std::time::Duration::from_millis(50)).await;
        assert_eq!(counters[0].load(Ordering::Relaxed), 5);
    }

    /// The pool is alive while any worker is; `stop` stops all workers.
    #[tokio::test]
    async fn pool_is_alive_and_stop() {
        let rt = TestRuntime::new();
        let (pool, _) = make_pool(&rt, 2, PoolRouting::RoundRobin).await;
        assert!(pool.is_alive());
        pool.stop();
        tokio::time::sleep(std::time::Duration::from_millis(50)).await;
        assert!(!pool.is_alive());
    }

    /// `TestRuntime::spawn_pool` wires up the requested number of workers
    /// and the pool delivers all messages (shared counter sums to 6).
    #[tokio::test]
    async fn spawn_pool_helper() {
        let rt = TestRuntime::new();
        let ctr = Arc::new(AtomicU64::new(0));
        let pool = rt.spawn_pool::<PoolWorker>("sp", 3, PoolRouting::RoundRobin, (0, ctr.clone())).await.unwrap();
        assert_eq!(pool.pool_size(), 3);
        for _ in 0..6 {
            pool.tell(Ping).unwrap();
        }
        tokio::time::sleep(std::time::Duration::from_millis(50)).await;
        assert_eq!(ctr.load(Ordering::Relaxed), 6);
    }

    /// Non-keyed `tell` on a KeyBased pool uses `select_worker`, which
    /// falls back to round-robin — so 6 messages split 2/2/2.
    #[tokio::test]
    async fn key_based_falls_back_to_round_robin_for_non_keyed_messages() {
        let rt = TestRuntime::new();
        let (pool, counters) = make_pool(&rt, 3, PoolRouting::KeyBased).await;
        for _ in 0..6 {
            pool.tell(Ping).unwrap();
        }
        tokio::time::sleep(std::time::Duration::from_millis(50)).await;
        for (i, ctr) in counters.iter().enumerate() {
            assert_eq!(
                ctr.load(Ordering::Relaxed),
                2,
                "worker {} should have received 2 messages via round-robin fallback",
                i
            );
        }
    }

    /// Constructing a pool with zero workers must panic (asserted in `new`).
    #[test]
    #[should_panic(expected = "pool must have at least one worker")]
    fn empty_pool_panics() {
        let workers: Vec<crate::test_support::test_runtime::TestActorRef<PoolWorker>> = vec![];
        PoolRef::new(workers, PoolRouting::RoundRobin);
    }

    /// With equal loads, least-loaded tie-breaking round-robins, so 9
    /// messages over 3 workers still split 3/3/3.
    #[tokio::test]
    async fn least_loaded_distributes_when_equal() {
        let rt = TestRuntime::new();
        let (pool, counters) = make_pool(&rt, 3, PoolRouting::LeastLoaded).await;
        for _ in 0..9 {
            pool.tell(Ping).unwrap();
        }
        tokio::time::sleep(std::time::Duration::from_millis(50)).await;
        let total: u64 = counters.iter().map(|c| c.load(Ordering::Relaxed)).sum();
        assert_eq!(total, 9, "all messages should be delivered");
        for (i, ctr) in counters.iter().enumerate() {
            assert_eq!(
                ctr.load(Ordering::Relaxed),
                3,
                "worker {} should have received 3 messages",
                i
            );
        }
    }

    /// Uses deliberately slow handlers and bounded mailboxes so messages
    /// actually pend, then checks least-loaded routing spreads the backlog
    /// across more than one worker rather than piling onto one.
    #[tokio::test]
    async fn least_loaded_prefers_emptier_workers() {
        use crate::mailbox::{MailboxConfig, OverflowStrategy};
        use crate::test_support::test_runtime::SpawnOptions;
        let rt = TestRuntime::new();
        // Worker whose handler sleeps 200ms, guaranteeing a visible backlog.
        struct SlowWorker;
        impl Actor for SlowWorker {
            type Args = ();
            type Deps = ();
            fn create(_args: (), _deps: ()) -> Self {
                SlowWorker
            }
        }
        #[derive(Clone)]
        struct SlowPing;
        impl Message for SlowPing {
            type Reply = ();
        }
        #[async_trait]
        impl Handler<SlowPing> for SlowWorker {
            async fn handle(&mut self, _msg: SlowPing, _ctx: &mut ActorContext) {
                tokio::time::sleep(std::time::Duration::from_millis(200)).await;
            }
        }
        let mut workers = Vec::new();
        for i in 0..3 {
            // Bounded mailbox so `pending_messages()` reflects real pressure.
            let opts = SpawnOptions {
                interceptors: Vec::new(),
                mailbox: MailboxConfig::Bounded {
                    capacity: 100,
                    overflow: OverflowStrategy::RejectWithError,
                },
            };
            let r = rt
                .spawn_with_options::<SlowWorker>(&format!("sw-{i}"), (), opts)
                .await
                .unwrap();
            workers.push(r);
        }
        let pool = PoolRef::new(workers, PoolRouting::LeastLoaded);
        for _ in 0..6 {
            pool.tell(SlowPing).unwrap();
        }
        // Short sleep: handlers are still busy, so queues are observable.
        tokio::time::sleep(std::time::Duration::from_millis(10)).await;
        let loads: Vec<usize> = pool
            .workers
            .iter()
            .map(|w| w.pending_messages())
            .collect();
        let total_pending: usize = loads.iter().sum();
        assert!(
            total_pending > 0,
            "some messages should be pending due to slow handlers, got {:?}",
            loads
        );
        let workers_with_messages = loads.iter().filter(|&&l| l > 0).count();
        assert!(
            workers_with_messages > 1,
            "messages should be distributed across workers, got loads: {:?}",
            loads
        );
    }
}