1use crate::actor::{
27 Actor, ActorRef, AskReply, ExpandHandler, Handler, ReduceHandler,
28 TransformHandler,
29};
30use crate::errors::ActorSendError;
31use crate::message::Message;
32use crate::node::ActorId;
33use crate::remote_ref::RemoteActorRef;
34use crate::stream::{BatchConfig, BoxStream};
35use tokio_util::sync::CancellationToken;
36
37pub enum WorkerRef<A: Actor, L: ActorRef<A>> {
53 Local(L),
55 Remote(RemoteActorRef<A>),
57}
58
59impl<A: Actor, L: ActorRef<A>> Clone for WorkerRef<A, L> {
62 fn clone(&self) -> Self {
63 match self {
64 WorkerRef::Local(r) => WorkerRef::Local(r.clone()),
65 WorkerRef::Remote(r) => WorkerRef::Remote(r.clone()),
66 }
67 }
68}
69
70impl<A: Actor + Sync, L: ActorRef<A>> ActorRef<A> for WorkerRef<A, L> {
71 fn id(&self) -> ActorId {
72 match self {
73 WorkerRef::Local(r) => r.id(),
74 WorkerRef::Remote(r) => r.id(),
75 }
76 }
77
78 fn name(&self) -> String {
79 match self {
80 WorkerRef::Local(r) => r.name(),
81 WorkerRef::Remote(r) => r.name(),
82 }
83 }
84
85 fn is_alive(&self) -> bool {
86 match self {
87 WorkerRef::Local(r) => r.is_alive(),
88 WorkerRef::Remote(r) => r.is_alive(),
89 }
90 }
91
92 fn pending_messages(&self) -> usize {
93 match self {
94 WorkerRef::Local(r) => r.pending_messages(),
95 WorkerRef::Remote(r) => r.pending_messages(),
96 }
97 }
98
99 fn stop(&self) {
100 match self {
101 WorkerRef::Local(r) => r.stop(),
102 WorkerRef::Remote(r) => r.stop(),
103 }
104 }
105
106 fn tell<M>(&self, msg: M) -> Result<(), ActorSendError>
107 where
108 A: Handler<M>,
109 M: Message<Reply = ()>,
110 {
111 match self {
112 WorkerRef::Local(r) => r.tell(msg),
113 WorkerRef::Remote(r) => r.tell(msg),
114 }
115 }
116
117 fn ask<M>(
118 &self,
119 msg: M,
120 cancel: Option<CancellationToken>,
121 ) -> Result<AskReply<M::Reply>, ActorSendError>
122 where
123 A: Handler<M>,
124 M: Message,
125 {
126 match self {
127 WorkerRef::Local(r) => r.ask(msg, cancel),
128 WorkerRef::Remote(r) => r.ask(msg, cancel),
129 }
130 }
131
132 fn expand<M, OutputItem>(
133 &self,
134 msg: M,
135 buffer: usize,
136 batch_config: Option<BatchConfig>,
137 cancel: Option<CancellationToken>,
138 ) -> Result<BoxStream<OutputItem>, ActorSendError>
139 where
140 A: ExpandHandler<M, OutputItem>,
141 M: Send + 'static,
142 OutputItem: Send + 'static,
143 {
144 match self {
145 WorkerRef::Local(r) => r.expand(msg, buffer, batch_config, cancel),
146 WorkerRef::Remote(r) => r.expand(msg, buffer, batch_config, cancel),
147 }
148 }
149
150 fn reduce<InputItem, Reply>(
151 &self,
152 input: BoxStream<InputItem>,
153 buffer: usize,
154 batch_config: Option<BatchConfig>,
155 cancel: Option<CancellationToken>,
156 ) -> Result<AskReply<Reply>, ActorSendError>
157 where
158 A: ReduceHandler<InputItem, Reply>,
159 InputItem: Send + 'static,
160 Reply: Send + 'static,
161 {
162 match self {
163 WorkerRef::Local(r) => r.reduce(input, buffer, batch_config, cancel),
164 WorkerRef::Remote(r) => r.reduce(input, buffer, batch_config, cancel),
165 }
166 }
167
168 fn transform<InputItem, OutputItem>(
169 &self,
170 input: BoxStream<InputItem>,
171 buffer: usize,
172 batch_config: Option<BatchConfig>,
173 cancel: Option<CancellationToken>,
174 ) -> Result<BoxStream<OutputItem>, ActorSendError>
175 where
176 A: TransformHandler<InputItem, OutputItem>,
177 InputItem: Send + 'static,
178 OutputItem: Send + 'static,
179 {
180 match self {
181 WorkerRef::Local(r) => r.transform(input, buffer, batch_config, cancel),
182 WorkerRef::Remote(r) => r.transform(input, buffer, batch_config, cancel),
183 }
184 }
185}
186
187impl<A: Actor, L: ActorRef<A>> WorkerRef<A, L> {
188 #[must_use]
190 pub fn is_local(&self) -> bool {
191 matches!(self, WorkerRef::Local(_))
192 }
193
194 #[must_use]
196 pub fn is_remote(&self) -> bool {
197 matches!(self, WorkerRef::Remote(_))
198 }
199}
200
#[cfg(test)]
mod tests {
    // Unit tests for `WorkerRef` delegation, plus pool-level smoke tests that
    // route messages through `PoolRef` over `WorkerRef` workers.
    use super::*;
    use crate::actor::ActorContext;
    use crate::node::{ActorId, NodeId};
    use crate::pool::{PoolRef, PoolRouting};
    use crate::remote_ref::RemoteActorRefBuilder;
    use crate::test_support::test_runtime::TestRuntime;
    use crate::transport::InMemoryTransport;
    use std::sync::Arc;
    use std::time::Duration;

    /// Delay the tests wait after a `tell` or `stop` before observing actor
    /// state, giving the spawned actor time to process it.
    const SETTLE_DELAY: Duration = Duration::from_millis(50);

    /// Minimal test actor holding a running total.
    struct Counter {
        count: i64,
    }

    impl Actor for Counter {
        type Args = i64;
        type Deps = ();
        fn create(args: i64, _deps: ()) -> Self {
            Counter { count: args }
        }
    }

    /// Fire-and-forget message adding its payload to the counter.
    struct Increment(i64);
    impl Message for Increment {
        type Reply = ();
    }

    #[async_trait::async_trait]
    impl Handler<Increment> for Counter {
        async fn handle(&mut self, msg: Increment, _ctx: &mut ActorContext) {
            self.count += msg.0;
        }
    }

    /// Request/reply message returning the current total.
    struct GetCount;
    impl Message for GetCount {
        type Reply = i64;
    }

    #[async_trait::async_trait]
    impl Handler<GetCount> for Counter {
        async fn handle(&mut self, _msg: GetCount, _ctx: &mut ActorContext) -> i64 {
            self.count
        }
    }

    /// Builds a `RemoteActorRef<Counter>` for a fictitious actor on
    /// "remote-node", backed by an in-memory transport owned by "test-node".
    fn make_remote_ref() -> RemoteActorRef<Counter> {
        let transport = Arc::new(InMemoryTransport::new(NodeId("test-node".into())));
        RemoteActorRefBuilder::<Counter>::new(
            ActorId {
                node: NodeId("remote-node".into()),
                local: 99,
            },
            "remote-counter",
            transport,
        )
        .build()
    }

    #[test]
    fn worker_ref_is_local_and_is_remote() {
        let remote = make_remote_ref();
        let worker: WorkerRef<Counter, RemoteActorRef<Counter>> =
            WorkerRef::Remote(remote);
        assert!(worker.is_remote());
        assert!(!worker.is_local());
    }

    /// Complements `worker_ref_is_local_and_is_remote`, which only covers the
    /// `Remote` variant.
    #[tokio::test]
    async fn worker_ref_local_variant_reports_local() {
        let rt = TestRuntime::new();
        let w = rt.spawn::<Counter>("local-flag-w", 0).await.unwrap();
        let worker = WorkerRef::<Counter, _>::Local(w);
        assert!(worker.is_local());
        assert!(!worker.is_remote());
    }

    #[test]
    fn worker_ref_delegates_id_and_name() {
        let remote = make_remote_ref();
        let worker: WorkerRef<Counter, RemoteActorRef<Counter>> =
            WorkerRef::Remote(remote.clone());
        assert_eq!(worker.id(), remote.id());
        assert_eq!(worker.name(), remote.name());
    }

    #[tokio::test]
    async fn distributed_pool_with_local_workers() {
        let rt = TestRuntime::new();
        let w1 = rt.spawn::<Counter>("c1", 0).await.unwrap();
        let w2 = rt.spawn::<Counter>("c2", 0).await.unwrap();

        // Keep direct handles so each worker can be queried after routing.
        let w1_check = w1.clone();
        let w2_check = w2.clone();

        let workers = vec![WorkerRef::Local(w1), WorkerRef::Local(w2)];
        let pool = PoolRef::new(workers, PoolRouting::RoundRobin);

        pool.tell(Increment(10)).unwrap();
        pool.tell(Increment(20)).unwrap();

        tokio::time::sleep(SETTLE_DELAY).await;

        let c1 = w1_check.ask(GetCount, None).unwrap().await.unwrap();
        let c2 = w2_check.ask(GetCount, None).unwrap().await.unwrap();
        assert_eq!(c1, 10, "w1 should have received Increment(10)");
        assert_eq!(c2, 20, "w2 should have received Increment(20)");
    }

    #[tokio::test]
    async fn distributed_pool_ask_round_robin() {
        let rt = TestRuntime::new();
        let w1 = rt.spawn::<Counter>("ask-c1", 100).await.unwrap();
        let w2 = rt.spawn::<Counter>("ask-c2", 200).await.unwrap();

        let workers = vec![WorkerRef::Local(w1), WorkerRef::Local(w2)];
        let pool = PoolRef::new(workers, PoolRouting::RoundRobin);

        let count1 = pool.ask(GetCount, None).unwrap().await.unwrap();
        let count2 = pool.ask(GetCount, None).unwrap().await.unwrap();

        assert_eq!(count1, 100);
        assert_eq!(count2, 200);
    }

    #[tokio::test]
    async fn distributed_pool_mixed_local_remote_creation() {
        let rt = TestRuntime::new();
        let local = rt.spawn::<Counter>("local-w", 0).await.unwrap();
        let local_check = local.clone();
        let remote = make_remote_ref();

        let workers = vec![WorkerRef::Local(local), WorkerRef::Remote(remote)];
        let pool = PoolRef::new(workers, PoolRouting::RoundRobin);

        assert!(pool.is_alive());

        pool.tell(Increment(42)).unwrap();
        tokio::time::sleep(SETTLE_DELAY).await;

        let count = local_check.ask(GetCount, None).unwrap().await.unwrap();
        assert_eq!(count, 42, "local worker should have received the tell");
    }

    #[tokio::test]
    async fn worker_ref_stop_delegates() {
        let rt = TestRuntime::new();
        let w = rt.spawn::<Counter>("stop-w", 0).await.unwrap();
        let worker = WorkerRef::<Counter, _>::Local(w);
        assert!(worker.is_alive());
        worker.stop();
        tokio::time::sleep(SETTLE_DELAY).await;
        assert!(!worker.is_alive());
    }

    #[test]
    fn worker_ref_pending_messages_delegates() {
        let remote = make_remote_ref();
        let worker: WorkerRef<Counter, RemoteActorRef<Counter>> =
            WorkerRef::Remote(remote);
        assert_eq!(worker.pending_messages(), 0);
    }
}
377}