use crate::actor::{
Actor, ActorRef, AskReply, ExpandHandler, Handler, ReduceHandler,
TransformHandler,
};
use crate::errors::ActorSendError;
use crate::message::Message;
use crate::node::ActorId;
use crate::remote_ref::RemoteActorRef;
use crate::stream::{BatchConfig, BoxStream};
use tokio_util::sync::CancellationToken;
/// A reference to a worker actor that is either reached through a local
/// reference type `L` or addressed through a [`RemoteActorRef`].
///
/// Holding either kind behind one type lets local and remote workers be mixed
/// in the same collection (for example the worker list handed to a `PoolRef`).
/// `WorkerRef` itself implements [`ActorRef<A>`] (when `A: Sync`) by
/// forwarding every call to whichever reference it wraps.
pub enum WorkerRef<A, L>
where
    A: Actor,
    L: ActorRef<A>,
{
    /// Wraps a reference of the local reference type `L`.
    Local(L),
    /// Wraps a [`RemoteActorRef`] to the actor.
    Remote(RemoteActorRef<A>),
}
// Implemented by hand instead of `#[derive(Clone)]`: the derive would also
// demand `A: Clone`, yet only the wrapped references are ever cloned. `L` is
// cloneable because `ActorRef` provides `Clone` (required for `inner.clone()`
// below to yield an `L`).
impl<A: Actor, L: ActorRef<A>> Clone for WorkerRef<A, L> {
    /// Returns a new `WorkerRef` of the same variant wrapping a clone of the
    /// inner reference.
    fn clone(&self) -> Self {
        match self {
            Self::Local(inner) => Self::Local(inner.clone()),
            Self::Remote(inner) => Self::Remote(inner.clone()),
        }
    }
}
// Expands to a `match` over both `WorkerRef` variants, binding the wrapped
// reference to `$inner` and evaluating `$call` with it. Each arm is
// type-checked on its own, so `$call` resolves against `L` in the first arm
// and against `RemoteActorRef<A>` in the second.
macro_rules! forward_to_inner {
    ($worker:expr, |$inner:ident| $call:expr) => {
        match $worker {
            WorkerRef::Local($inner) => $call,
            WorkerRef::Remote($inner) => $call,
        }
    };
}

// Every `ActorRef` operation is passed straight through to the wrapped local
// or remote reference; arguments and results are forwarded unchanged.
impl<A: Actor + Sync, L: ActorRef<A>> ActorRef<A> for WorkerRef<A, L> {
    fn id(&self) -> ActorId {
        forward_to_inner!(self, |inner| inner.id())
    }

    fn name(&self) -> String {
        forward_to_inner!(self, |inner| inner.name())
    }

    fn is_alive(&self) -> bool {
        forward_to_inner!(self, |inner| inner.is_alive())
    }

    fn pending_messages(&self) -> usize {
        forward_to_inner!(self, |inner| inner.pending_messages())
    }

    fn stop(&self) {
        forward_to_inner!(self, |inner| inner.stop())
    }

    fn tell<M>(&self, msg: M) -> Result<(), ActorSendError>
    where
        A: Handler<M>,
        M: Message<Reply = ()>,
    {
        forward_to_inner!(self, |inner| inner.tell(msg))
    }

    fn ask<M>(
        &self,
        msg: M,
        cancel: Option<CancellationToken>,
    ) -> Result<AskReply<M::Reply>, ActorSendError>
    where
        A: Handler<M>,
        M: Message,
    {
        forward_to_inner!(self, |inner| inner.ask(msg, cancel))
    }

    fn expand<M, OutputItem>(
        &self,
        msg: M,
        buffer: usize,
        batch_config: Option<BatchConfig>,
        cancel: Option<CancellationToken>,
    ) -> Result<BoxStream<OutputItem>, ActorSendError>
    where
        A: ExpandHandler<M, OutputItem>,
        M: Send + 'static,
        OutputItem: Send + 'static,
    {
        forward_to_inner!(self, |inner| inner.expand(msg, buffer, batch_config, cancel))
    }

    fn reduce<InputItem, Reply>(
        &self,
        input: BoxStream<InputItem>,
        buffer: usize,
        batch_config: Option<BatchConfig>,
        cancel: Option<CancellationToken>,
    ) -> Result<AskReply<Reply>, ActorSendError>
    where
        A: ReduceHandler<InputItem, Reply>,
        InputItem: Send + 'static,
        Reply: Send + 'static,
    {
        forward_to_inner!(self, |inner| inner.reduce(input, buffer, batch_config, cancel))
    }

    fn transform<InputItem, OutputItem>(
        &self,
        input: BoxStream<InputItem>,
        buffer: usize,
        batch_config: Option<BatchConfig>,
        cancel: Option<CancellationToken>,
    ) -> Result<BoxStream<OutputItem>, ActorSendError>
    where
        A: TransformHandler<InputItem, OutputItem>,
        InputItem: Send + 'static,
        OutputItem: Send + 'static,
    {
        forward_to_inner!(self, |inner| inner.transform(input, buffer, batch_config, cancel))
    }
}
impl<A: Actor, L: ActorRef<A>> WorkerRef<A, L> {
    /// Reports whether this handle is the [`WorkerRef::Local`] variant.
    #[must_use]
    pub fn is_local(&self) -> bool {
        matches!(self, Self::Local(..))
    }

    /// Reports whether this handle is the [`WorkerRef::Remote`] variant.
    #[must_use]
    pub fn is_remote(&self) -> bool {
        matches!(self, Self::Remote(..))
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::actor::ActorContext;
    use crate::node::{ActorId, NodeId};
    use crate::pool::{PoolRef, PoolRouting};
    use crate::remote_ref::RemoteActorRefBuilder;
    use crate::test_support::test_runtime::TestRuntime;
    use crate::transport::InMemoryTransport;
    use std::sync::Arc;
    use std::time::{Duration, Instant};

    /// Upper bound on how long a test waits for asynchronous actor work to
    /// become observable. Generous so a loaded CI machine doesn't cause
    /// spurious failures; the helpers return as soon as the state is seen.
    const SETTLE_TIMEOUT: Duration = Duration::from_secs(2);
    /// Delay between successive polls while waiting for state to settle.
    const POLL_INTERVAL: Duration = Duration::from_millis(5);

    /// Minimal test actor holding a running total.
    struct Counter {
        count: i64,
    }
    impl Actor for Counter {
        type Args = i64;
        type Deps = ();
        fn create(args: i64, _deps: ()) -> Self {
            Counter { count: args }
        }
    }

    /// Fire-and-forget message that adds its payload to the counter.
    struct Increment(i64);
    impl Message for Increment {
        type Reply = ();
    }
    #[async_trait::async_trait]
    impl Handler<Increment> for Counter {
        async fn handle(&mut self, msg: Increment, _ctx: &mut ActorContext) {
            self.count += msg.0;
        }
    }

    /// Request/response message returning the current total.
    struct GetCount;
    impl Message for GetCount {
        type Reply = i64;
    }
    #[async_trait::async_trait]
    impl Handler<GetCount> for Counter {
        async fn handle(&mut self, _msg: GetCount, _ctx: &mut ActorContext) -> i64 {
            self.count
        }
    }

    /// Builds a `RemoteActorRef` to a (non-existent) counter on "remote-node",
    /// backed by an in-memory transport.
    fn make_remote_ref() -> RemoteActorRef<Counter> {
        let transport = Arc::new(InMemoryTransport::new(NodeId("test-node".into())));
        RemoteActorRefBuilder::<Counter>::new(
            ActorId {
                node: NodeId("remote-node".into()),
                local: 99,
            },
            "remote-counter",
            transport,
        )
        .build()
    }

    /// Polls `cond` until it returns `true` or `SETTLE_TIMEOUT` elapses.
    /// Returns whether the condition was observed to hold. Used instead of a
    /// fixed sleep so tests neither race nor wait longer than necessary.
    async fn wait_until(mut cond: impl FnMut() -> bool) -> bool {
        let deadline = Instant::now() + SETTLE_TIMEOUT;
        loop {
            if cond() {
                return true;
            }
            if Instant::now() >= deadline {
                return false;
            }
            tokio::time::sleep(POLL_INTERVAL).await;
        }
    }

    /// Asks `worker` for its count, re-asking until it equals `expected` or
    /// `SETTLE_TIMEOUT` elapses, and returns the last count observed. The
    /// caller still asserts on the result, so a wrong value fails the test.
    async fn count_eventually<R: ActorRef<Counter>>(worker: &R, expected: i64) -> i64 {
        let deadline = Instant::now() + SETTLE_TIMEOUT;
        loop {
            let count = worker.ask(GetCount, None).unwrap().await.unwrap();
            if count == expected || Instant::now() >= deadline {
                return count;
            }
            tokio::time::sleep(POLL_INTERVAL).await;
        }
    }

    #[test]
    fn worker_ref_is_local_and_is_remote() {
        let remote = make_remote_ref();
        let worker: WorkerRef<Counter, RemoteActorRef<Counter>> =
            WorkerRef::Remote(remote);
        assert!(worker.is_remote());
        assert!(!worker.is_local());
    }

    #[test]
    fn worker_ref_delegates_id_and_name() {
        let remote = make_remote_ref();
        let worker: WorkerRef<Counter, RemoteActorRef<Counter>> =
            WorkerRef::Remote(remote.clone());
        assert_eq!(worker.id(), remote.id());
        assert_eq!(worker.name(), remote.name());
    }

    #[test]
    fn worker_ref_clone_preserves_variant_and_identity() {
        let remote = make_remote_ref();
        let worker: WorkerRef<Counter, RemoteActorRef<Counter>> =
            WorkerRef::Remote(remote.clone());
        let cloned = worker.clone();
        assert!(cloned.is_remote());
        assert_eq!(cloned.id(), remote.id());
        assert_eq!(cloned.name(), remote.name());
    }

    #[tokio::test]
    async fn distributed_pool_with_local_workers() {
        let rt = TestRuntime::new();
        let w1 = rt.spawn::<Counter>("c1", 0).await.unwrap();
        let w2 = rt.spawn::<Counter>("c2", 0).await.unwrap();
        let w1_check = w1.clone();
        let w2_check = w2.clone();
        let workers = vec![
            WorkerRef::Local(w1),
            WorkerRef::Local(w2),
        ];
        let pool = PoolRef::new(workers, PoolRouting::RoundRobin);
        pool.tell(Increment(10)).unwrap();
        pool.tell(Increment(20)).unwrap();
        let c1 = count_eventually(&w1_check, 10).await;
        let c2 = count_eventually(&w2_check, 20).await;
        assert_eq!(c1, 10, "w1 should have received Increment(10)");
        assert_eq!(c2, 20, "w2 should have received Increment(20)");
    }

    #[tokio::test]
    async fn distributed_pool_ask_round_robin() {
        let rt = TestRuntime::new();
        let w1 = rt.spawn::<Counter>("ask-c1", 100).await.unwrap();
        let w2 = rt.spawn::<Counter>("ask-c2", 200).await.unwrap();
        let workers = vec![
            WorkerRef::Local(w1),
            WorkerRef::Local(w2),
        ];
        let pool = PoolRef::new(workers, PoolRouting::RoundRobin);
        let count1 = pool.ask(GetCount, None).unwrap().await.unwrap();
        let count2 = pool.ask(GetCount, None).unwrap().await.unwrap();
        assert_eq!(count1, 100);
        assert_eq!(count2, 200);
    }

    #[tokio::test]
    async fn distributed_pool_mixed_local_remote_creation() {
        let rt = TestRuntime::new();
        let local = rt.spawn::<Counter>("local-w", 0).await.unwrap();
        let local_check = local.clone();
        let remote = make_remote_ref();
        let workers = vec![
            WorkerRef::Local(local),
            WorkerRef::Remote(remote),
        ];
        let pool = PoolRef::new(workers, PoolRouting::RoundRobin);
        assert!(pool.is_alive());
        pool.tell(Increment(42)).unwrap();
        let count = count_eventually(&local_check, 42).await;
        assert_eq!(count, 42, "local worker should have received the tell");
    }

    #[tokio::test]
    async fn worker_ref_stop_delegates() {
        let rt = TestRuntime::new();
        let w = rt.spawn::<Counter>("stop-w", 0).await.unwrap();
        let worker = WorkerRef::<Counter, _>::Local(w);
        assert!(worker.is_local());
        assert!(worker.is_alive());
        worker.stop();
        // Shutdown is asynchronous; poll rather than sleeping a fixed amount.
        assert!(
            wait_until(|| !worker.is_alive()).await,
            "worker should report not alive after stop()"
        );
    }

    #[test]
    fn worker_ref_pending_messages_delegates() {
        let remote = make_remote_ref();
        let worker: WorkerRef<Counter, RemoteActorRef<Counter>> =
            WorkerRef::Remote(remote.clone());
        assert_eq!(worker.pending_messages(), remote.pending_messages());
        assert_eq!(worker.pending_messages(), 0);
    }
}