Skip to main content

dactor/
worker_ref.rs

1//! Worker reference for distributed actor pools.
2//!
3//! [`WorkerRef`] wraps either a local [`ActorRef`] or a
4//! [`RemoteActorRef`](crate::remote_ref::RemoteActorRef), implementing
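//! [`ActorRef<A>`] by delegating to the inner reference. This enables
//! [`PoolRef`](crate::pool::PoolRef) to route messages across local and
//! remote workers transparently, subject to the current limitations of
//! remote workers (streaming and `LeastLoaded` routing) documented on
//! [`WorkerRef`].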
8//!
9//! # Example
10//!
11//! ```ignore
12//! use dactor::worker_ref::WorkerRef;
13//! use dactor::pool::{PoolRef, PoolRouting};
14//!
15//! let local_ref: LocalActorRef<Counter> = runtime.spawn("local", 0).await?;
16//! let remote_ref: RemoteActorRef<Counter> = /* ... */;
17//!
18//! let workers = vec![
19//!     WorkerRef::Local(local_ref),
20//!     WorkerRef::Remote(remote_ref),
21//! ];
22//! let pool = PoolRef::new(workers, PoolRouting::RoundRobin);
23//! pool.tell(Increment(1))?; // routes to local or remote
24//! ```
25
26use crate::actor::{
27    Actor, ActorRef, AskReply, ExpandHandler, Handler, ReduceHandler,
28    TransformHandler,
29};
30use crate::errors::ActorSendError;
31use crate::message::Message;
32use crate::node::ActorId;
33use crate::remote_ref::RemoteActorRef;
34use crate::stream::{BatchConfig, BoxStream};
35use tokio_util::sync::CancellationToken;
36
/// A worker reference that can be either local or remote.
///
/// Implements [`ActorRef<A>`] by delegating to the inner reference,
/// allowing [`PoolRef`](crate::pool::PoolRef) to mix local and remote
/// workers in a single pool.
///
/// # Type parameters
///
/// - `A`: the actor type addressed by both variants.
/// - `L`: the adapter-specific local reference type (any [`ActorRef<A>`]
///   implementation, e.g. whatever a runtime's `spawn` hands back).
///
/// # Limitations
///
/// - **Streaming**: `expand()`, `reduce()`, and `transform()` are not yet
///   supported on [`RemoteActorRef`]. Mixed pools should only use `tell()`
///   and `ask()` until streaming transport is implemented.
/// - **LeastLoaded routing**: [`RemoteActorRef`] returns `0` for
///   `pending_messages()`, so `LeastLoaded` routing will prefer remote
///   workers over busy local workers. Use `RoundRobin` or `Random` for
///   mixed pools until remote mailbox depth queries are available.
pub enum WorkerRef<A: Actor, L: ActorRef<A>> {
    /// A local actor reference (adapter-specific).
    Local(L),
    /// A remote actor reference (cross-node via transport).
    Remote(RemoteActorRef<A>),
}
58
59// Manual Clone: derive(Clone) would add an `A: Clone` bound, but
60// RemoteActorRef<A> implements Clone without requiring A: Clone.
61impl<A: Actor, L: ActorRef<A>> Clone for WorkerRef<A, L> {
62    fn clone(&self) -> Self {
63        match self {
64            WorkerRef::Local(r) => WorkerRef::Local(r.clone()),
65            WorkerRef::Remote(r) => WorkerRef::Remote(r.clone()),
66        }
67    }
68}
69
70impl<A: Actor + Sync, L: ActorRef<A>> ActorRef<A> for WorkerRef<A, L> {
71    fn id(&self) -> ActorId {
72        match self {
73            WorkerRef::Local(r) => r.id(),
74            WorkerRef::Remote(r) => r.id(),
75        }
76    }
77
78    fn name(&self) -> String {
79        match self {
80            WorkerRef::Local(r) => r.name(),
81            WorkerRef::Remote(r) => r.name(),
82        }
83    }
84
85    fn is_alive(&self) -> bool {
86        match self {
87            WorkerRef::Local(r) => r.is_alive(),
88            WorkerRef::Remote(r) => r.is_alive(),
89        }
90    }
91
92    fn pending_messages(&self) -> usize {
93        match self {
94            WorkerRef::Local(r) => r.pending_messages(),
95            WorkerRef::Remote(r) => r.pending_messages(),
96        }
97    }
98
99    fn stop(&self) {
100        match self {
101            WorkerRef::Local(r) => r.stop(),
102            WorkerRef::Remote(r) => r.stop(),
103        }
104    }
105
106    fn tell<M>(&self, msg: M) -> Result<(), ActorSendError>
107    where
108        A: Handler<M>,
109        M: Message<Reply = ()>,
110    {
111        match self {
112            WorkerRef::Local(r) => r.tell(msg),
113            WorkerRef::Remote(r) => r.tell(msg),
114        }
115    }
116
117    fn ask<M>(
118        &self,
119        msg: M,
120        cancel: Option<CancellationToken>,
121    ) -> Result<AskReply<M::Reply>, ActorSendError>
122    where
123        A: Handler<M>,
124        M: Message,
125    {
126        match self {
127            WorkerRef::Local(r) => r.ask(msg, cancel),
128            WorkerRef::Remote(r) => r.ask(msg, cancel),
129        }
130    }
131
132    fn expand<M, OutputItem>(
133        &self,
134        msg: M,
135        buffer: usize,
136        batch_config: Option<BatchConfig>,
137        cancel: Option<CancellationToken>,
138    ) -> Result<BoxStream<OutputItem>, ActorSendError>
139    where
140        A: ExpandHandler<M, OutputItem>,
141        M: Send + 'static,
142        OutputItem: Send + 'static,
143    {
144        match self {
145            WorkerRef::Local(r) => r.expand(msg, buffer, batch_config, cancel),
146            WorkerRef::Remote(r) => r.expand(msg, buffer, batch_config, cancel),
147        }
148    }
149
150    fn reduce<InputItem, Reply>(
151        &self,
152        input: BoxStream<InputItem>,
153        buffer: usize,
154        batch_config: Option<BatchConfig>,
155        cancel: Option<CancellationToken>,
156    ) -> Result<AskReply<Reply>, ActorSendError>
157    where
158        A: ReduceHandler<InputItem, Reply>,
159        InputItem: Send + 'static,
160        Reply: Send + 'static,
161    {
162        match self {
163            WorkerRef::Local(r) => r.reduce(input, buffer, batch_config, cancel),
164            WorkerRef::Remote(r) => r.reduce(input, buffer, batch_config, cancel),
165        }
166    }
167
168    fn transform<InputItem, OutputItem>(
169        &self,
170        input: BoxStream<InputItem>,
171        buffer: usize,
172        batch_config: Option<BatchConfig>,
173        cancel: Option<CancellationToken>,
174    ) -> Result<BoxStream<OutputItem>, ActorSendError>
175    where
176        A: TransformHandler<InputItem, OutputItem>,
177        InputItem: Send + 'static,
178        OutputItem: Send + 'static,
179    {
180        match self {
181            WorkerRef::Local(r) => r.transform(input, buffer, batch_config, cancel),
182            WorkerRef::Remote(r) => r.transform(input, buffer, batch_config, cancel),
183        }
184    }
185}
186
187impl<A: Actor, L: ActorRef<A>> WorkerRef<A, L> {
188    /// Returns `true` if this is a local worker.
189    #[must_use]
190    pub fn is_local(&self) -> bool {
191        matches!(self, WorkerRef::Local(_))
192    }
193
194    /// Returns `true` if this is a remote worker.
195    #[must_use]
196    pub fn is_remote(&self) -> bool {
197        matches!(self, WorkerRef::Remote(_))
198    }
199}
200
201// ---------------------------------------------------------------------------
202// Tests
203// ---------------------------------------------------------------------------
204
#[cfg(test)]
mod tests {
    use super::*;
    use crate::actor::ActorContext;
    use crate::node::{ActorId, NodeId};
    use crate::pool::{PoolRef, PoolRouting};
    use crate::remote_ref::RemoteActorRefBuilder;
    use crate::test_support::test_runtime::TestRuntime;
    use crate::transport::InMemoryTransport;
    use std::sync::Arc;
    use std::time::{Duration, Instant};

    /// Upper bound on how long a test waits for asynchronous actor work.
    ///
    /// Tests poll for the expected state instead of sleeping a fixed amount,
    /// so this only bounds the failure case; passing runs continue as soon as
    /// the condition is observed.
    const WAIT_TIMEOUT: Duration = Duration::from_secs(2);

    /// Delay between successive polls in the wait helpers.
    const POLL_INTERVAL: Duration = Duration::from_millis(5);

    // A simple counter actor for testing.
    struct Counter {
        count: i64,
    }

    impl Actor for Counter {
        type Args = i64;
        type Deps = ();
        fn create(args: i64, _deps: ()) -> Self {
            Counter { count: args }
        }
    }

    struct Increment(i64);
    impl Message for Increment {
        type Reply = ();
    }

    #[async_trait::async_trait]
    impl Handler<Increment> for Counter {
        async fn handle(&mut self, msg: Increment, _ctx: &mut ActorContext) {
            self.count += msg.0;
        }
    }

    struct GetCount;
    impl Message for GetCount {
        type Reply = i64;
    }

    #[async_trait::async_trait]
    impl Handler<GetCount> for Counter {
        async fn handle(&mut self, _msg: GetCount, _ctx: &mut ActorContext) -> i64 {
            self.count
        }
    }

    fn make_remote_ref() -> RemoteActorRef<Counter> {
        let transport = Arc::new(InMemoryTransport::new(NodeId("test-node".into())));
        RemoteActorRefBuilder::<Counter>::new(
            ActorId {
                node: NodeId("remote-node".into()),
                local: 99,
            },
            "remote-counter",
            transport,
        )
        .build()
    }

    /// Polls `cond` until it returns `true` or [`WAIT_TIMEOUT`] elapses.
    ///
    /// Returns whether the condition was observed to hold. Replaces fixed
    /// sleeps, which are flaky on loaded machines.
    async fn eventually(mut cond: impl FnMut() -> bool) -> bool {
        let deadline = Instant::now() + WAIT_TIMEOUT;
        loop {
            if cond() {
                return true;
            }
            if Instant::now() >= deadline {
                return false;
            }
            tokio::time::sleep(POLL_INTERVAL).await;
        }
    }

    /// Repeatedly asks `worker` for its count until it equals `expected` or
    /// [`WAIT_TIMEOUT`] elapses, and returns the last count observed.
    ///
    /// Tells are fire-and-forget, so the caller cannot otherwise know when
    /// the update has been applied.
    async fn wait_for_count<R: ActorRef<Counter>>(worker: &R, expected: i64) -> i64 {
        let deadline = Instant::now() + WAIT_TIMEOUT;
        loop {
            let count = worker.ask(GetCount, None).unwrap().await.unwrap();
            if count == expected || Instant::now() >= deadline {
                return count;
            }
            tokio::time::sleep(POLL_INTERVAL).await;
        }
    }

    #[test]
    fn worker_ref_is_local_and_is_remote() {
        let remote = make_remote_ref();
        let worker: WorkerRef<Counter, RemoteActorRef<Counter>> =
            WorkerRef::Remote(remote);
        assert!(worker.is_remote());
        assert!(!worker.is_local());
    }

    #[tokio::test]
    async fn worker_ref_local_variant_reports_local() {
        let rt = TestRuntime::new();
        let w = rt.spawn::<Counter>("local-flag", 0).await.unwrap();
        let worker = WorkerRef::<Counter, _>::Local(w);
        assert!(worker.is_local());
        assert!(!worker.is_remote());
    }

    #[test]
    fn worker_ref_delegates_id_and_name() {
        let remote = make_remote_ref();
        let worker: WorkerRef<Counter, RemoteActorRef<Counter>> =
            WorkerRef::Remote(remote.clone());
        assert_eq!(worker.id(), remote.id());
        assert_eq!(worker.name(), remote.name());
    }

    #[test]
    fn worker_ref_clone_preserves_variant_and_identity() {
        let worker: WorkerRef<Counter, RemoteActorRef<Counter>> =
            WorkerRef::Remote(make_remote_ref());
        let copy = worker.clone();
        assert!(copy.is_remote());
        assert_eq!(copy.id(), worker.id());
        assert_eq!(copy.name(), worker.name());
    }

    #[tokio::test]
    async fn distributed_pool_with_local_workers() {
        let rt = TestRuntime::new();
        let w1 = rt.spawn::<Counter>("c1", 0).await.unwrap();
        let w2 = rt.spawn::<Counter>("c2", 0).await.unwrap();

        // Keep refs to verify individual worker state
        let w1_check = w1.clone();
        let w2_check = w2.clone();

        let workers = vec![
            WorkerRef::Local(w1),
            WorkerRef::Local(w2),
        ];
        let pool = PoolRef::new(workers, PoolRouting::RoundRobin);

        // Tell goes to w1 (round-robin index 0)
        pool.tell(Increment(10)).unwrap();
        // Tell goes to w2 (round-robin index 1)
        pool.tell(Increment(20)).unwrap();

        // Verify messages were distributed to the correct workers, waiting
        // (bounded) for the fire-and-forget tells to be processed.
        let c1 = wait_for_count(&w1_check, 10).await;
        let c2 = wait_for_count(&w2_check, 20).await;
        assert_eq!(c1, 10, "w1 should have received Increment(10)");
        assert_eq!(c2, 20, "w2 should have received Increment(20)");
    }

    #[tokio::test]
    async fn distributed_pool_ask_round_robin() {
        let rt = TestRuntime::new();
        let w1 = rt.spawn::<Counter>("ask-c1", 100).await.unwrap();
        let w2 = rt.spawn::<Counter>("ask-c2", 200).await.unwrap();

        let workers = vec![
            WorkerRef::Local(w1),
            WorkerRef::Local(w2),
        ];
        let pool = PoolRef::new(workers, PoolRouting::RoundRobin);

        // First ask goes to w1 (100)
        let count1 = pool.ask(GetCount, None).unwrap().await.unwrap();
        // Second ask goes to w2 (200)
        let count2 = pool.ask(GetCount, None).unwrap().await.unwrap();

        assert_eq!(count1, 100);
        assert_eq!(count2, 200);
    }

    #[tokio::test]
    async fn distributed_pool_mixed_local_remote_creation() {
        let rt = TestRuntime::new();
        let local = rt.spawn::<Counter>("local-w", 0).await.unwrap();
        let local_check = local.clone();
        let remote = make_remote_ref();

        let workers = vec![
            WorkerRef::Local(local),
            WorkerRef::Remote(remote),
        ];
        let pool = PoolRef::new(workers, PoolRouting::RoundRobin);

        assert!(pool.is_alive());

        // First tell goes to local worker (index 0) — should succeed
        pool.tell(Increment(42)).unwrap();

        let count = wait_for_count(&local_check, 42).await;
        assert_eq!(count, 42, "local worker should have received the tell");
    }

    #[tokio::test]
    async fn worker_ref_stop_delegates() {
        let rt = TestRuntime::new();
        let w = rt.spawn::<Counter>("stop-w", 0).await.unwrap();
        let worker = WorkerRef::<Counter, _>::Local(w);
        assert!(worker.is_alive());
        worker.stop();
        assert!(
            eventually(|| !worker.is_alive()).await,
            "worker should report not alive after stop()"
        );
    }

    #[test]
    fn worker_ref_pending_messages_delegates() {
        let remote = make_remote_ref();
        let worker: WorkerRef<Counter, RemoteActorRef<Counter>> =
            WorkerRef::Remote(remote);
        // RemoteActorRef returns 0 by default
        assert_eq!(worker.pending_messages(), 0);
    }
}