1use std::{
86 any::Any,
87 boxed,
88 collections::HashMap,
89 future::Future,
90 io,
91 pin::Pin,
92 sync::{
93 Arc, Mutex, RwLock,
94 atomic::{AtomicU64, Ordering},
95 },
96};
97
98#[allow(unused_imports)]
99use futures_util::{
100 FutureExt,
101 future::{self, Either},
102 pin_mut,
103};
104
105pub type Sender<T> = flume::Sender<T>;
107
108pub type Receiver<T> = flume::Receiver<T>;
112
113pub use futures_executor::block_on;
115
/// Error types that can be constructed from a plain text message.
///
/// The handle uses this to build errors for conditions it detects itself,
/// such as a stopped actor or a panic caught while running an action.
pub trait ActorError: Sized + Send + 'static {
    /// Builds an error value carrying `msg`.
    fn from_actor_message(msg: String) -> Self;
}
131
132impl ActorError for io::Error {
134 fn from_actor_message(msg: String) -> Self {
135 io::Error::other(msg)
136 }
137}
138
/// Lifecycle state of an actor as observed through its handle.
#[derive(Debug, Clone, PartialEq, Eq, Default)]
pub enum ActorState {
    /// The actor is accepting and executing actions; this is the default.
    #[default]
    Running,
    /// The actor loop has finished.
    Stopped,
}
146
/// Lets `anyhow::Error` serve as the actor error type (requires the `anyhow` feature).
#[cfg(feature = "anyhow")]
impl ActorError for anyhow::Error {
    /// Creates an ad-hoc `anyhow` error whose message is `msg`.
    fn from_actor_message(msg: String) -> Self {
        anyhow::Error::msg(msg)
    }
}
153
154impl ActorError for String {
155 fn from_actor_message(msg: String) -> Self {
156 msg
157 }
158}
159
160impl ActorError for Box<dyn std::error::Error + Send + Sync> {
161 fn from_actor_message(msg: String) -> Self {
162 Box::new(io::Error::other(msg))
163 }
164}
165
/// Unboxed form of an actor future: a `Send` future yielding `T` that may
/// borrow data for the lifetime `'a`.
pub type PreBoxActorFut<'a, T> = dyn Future<Output = T> + Send + 'a;

/// Pinned, heap-allocated actor future.
pub type ActorFut<'a, T> = Pin<boxed::Box<PreBoxActorFut<'a, T>>>;

/// A queued unit of work: a one-shot closure that receives exclusive access
/// to the actor and returns a future borrowing it for the duration of the work.
pub type Action<A> = Box<dyn for<'a> FnOnce(&'a mut A) -> ActorFut<'a, ()> + Send + 'static>;
176
177type BaseCallResult<R, E> = Result<
179 (
180 Receiver<Result<R, E>>,
181 Receiver<()>,
182 u64,
183 &'static std::panic::Location<'static>,
184 ),
185 E,
186>;
187
188type PendingCancelMap = Arc<Mutex<HashMap<u64, Sender<()>>>>;
189
190fn fail_pending_calls(pending: &PendingCancelMap) {
191 if let Ok(mut pending) = pending.lock() {
192 for (_, cancel_tx) in pending.drain() {
193 let _ = cancel_tx.send(());
194 }
195 }
196}
197
198#[doc(hidden)]
200pub fn into_actor_fut_res<'a, Fut, T, E>(fut: Fut) -> ActorFut<'a, Result<T, E>>
201where
202 Fut: Future<Output = Result<T, E>> + Send + 'a,
203 T: Send + 'a,
204{
205 Box::pin(fut)
206}
207
208#[doc(hidden)]
210pub fn into_actor_fut_ok<'a, Fut, T, E>(fut: Fut) -> ActorFut<'a, Result<T, E>>
211where
212 Fut: Future<Output = T> + Send + 'a,
213 T: Send + 'a,
214 E: ActorError,
215{
216 Box::pin(async move { Ok(fut.await) })
217}
218
/// Builds a `move` closure over `$actor` whose body is passed through
/// `into_actor_fut_res`, i.e. the body must evaluate to a future yielding a
/// `Result`.
///
/// NOTE(review): a block is itself an expression, so the first arm appears to
/// accept block input as well; the second arm is kept for explicitness.
#[macro_export]
macro_rules! act {
    ($actor:ident => $fut:expr) => {{
        move |$actor| $crate::into_actor_fut_res(($fut))
    }};
    ($actor:ident => $block:block) => {{
        move |$actor| $crate::into_actor_fut_res($block)
    }};
}
236
/// Builds a `move` closure over `$actor` whose body is passed through
/// `into_actor_fut_ok`, i.e. the body evaluates to a future yielding a plain
/// value that is then wrapped in `Ok`.
///
/// NOTE(review): a block is itself an expression, so the first arm appears to
/// accept block input as well; the second arm is kept for explicitness.
#[macro_export]
macro_rules! act_ok {
    ($actor:ident => $fut:expr) => {{
        move |$actor| $crate::into_actor_fut_ok(($fut))
    }};
    ($actor:ident => $block:block) => {{
        move |$actor| $crate::into_actor_fut_ok($block)
    }};
}
251
252
253
/// Extracts a human-readable message from a caught panic payload.
///
/// Payloads of type `&str` and `String` are returned as text; any other
/// payload type yields the fixed text `"unknown panic"`.
fn panic_payload_message(panic_payload: Box<dyn Any + Send>) -> String {
    panic_payload
        .downcast_ref::<&str>()
        .map(|text| (*text).to_owned())
        .or_else(|| panic_payload.downcast_ref::<String>().cloned())
        .unwrap_or_else(|| "unknown panic".to_owned())
}
263
264fn actor_loop_panic<E: ActorError>(panic_payload: Box<dyn Any + Send>) -> E {
265 E::from_actor_message(format!(
266 "panic in actor loop: {}",
267 panic_payload_message(panic_payload)
268 ))
269}
270
271
272#[derive(Debug)]
276pub struct Handle<A, E>
277where
278 A: Send + 'static,
279 E: ActorError,
280{
281 tx: Arc<Mutex<Option<Sender<Action<A>>>>>,
282 state: Arc<RwLock<ActorState>>,
283 pending: PendingCancelMap,
284 next_call_id: Arc<AtomicU64>,
285 stopped_rx: Receiver<()>,
286 _phantom: std::marker::PhantomData<E>,
287}
288
289impl<A, E> Clone for Handle<A, E>
290where
291 A: Send + 'static,
292 E: ActorError,
293{
294 fn clone(&self) -> Self {
295 Self {
296 tx: Arc::clone(&self.tx),
297 state: Arc::clone(&self.state),
298 pending: Arc::clone(&self.pending),
299 next_call_id: Arc::clone(&self.next_call_id),
300 stopped_rx: self.stopped_rx.clone(),
301 _phantom: std::marker::PhantomData,
302 }
303 }
304}
305
306impl<A, E> PartialEq for Handle<A, E>
307where
308 A: Send + 'static,
309 E: ActorError,
310{
311 fn eq(&self, other: &Self) -> bool {
312 Arc::ptr_eq(&self.state, &other.state)
313 }
314}
315
316impl<A, E> Eq for Handle<A, E>
317where
318 A: Send + 'static,
319 E: ActorError,
320{
321}
322
323impl<A, E> Handle<A, E>
324where
325 A: Send + 'static,
326 E: ActorError,
327{
328 pub fn state(&self) -> ActorState {
330 self.state.read().expect("poisned lock").clone()
331 }
332
/// Spawns `actor` as a Tokio task running the default action loop.
///
/// The loop takes queued actions one at a time and awaits each against the
/// actor until the action channel disconnects. Afterwards the state is set
/// to [`ActorState::Stopped`] and every pending call is signalled.
///
/// Returns the handle together with the task's join handle. The task yields
/// an error with a "panic in actor loop" message if the loop itself panicked.
#[cfg(all(feature = "tokio", not(feature = "async-std")))]
pub fn spawn(actor: A) -> (Self, tokio::task::JoinHandle<Result<(), E>>) {
    let (action_tx, action_rx) = flume::unbounded::<Action<A>>();
    let (stopped_tx, stopped_rx) = flume::bounded::<()>(1);

    let handle = Self {
        tx: Arc::new(Mutex::new(Some(action_tx))),
        state: Arc::new(RwLock::new(ActorState::default())),
        pending: PendingCancelMap::default(),
        next_call_id: Arc::new(AtomicU64::new(0)),
        stopped_rx,
        _phantom: std::marker::PhantomData,
    };

    let task_state = Arc::clone(&handle.state);
    let task_pending = Arc::clone(&handle.pending);

    let join_handle = tokio::task::spawn(async move {
        // Held for the whole task; dropping it is what wakes `wait_stopped*`.
        let _stopped_signal = stopped_tx;
        let mut actor = actor;

        let outcome = std::panic::AssertUnwindSafe(async {
            while let Ok(action) = action_rx.recv_async().await {
                action(&mut actor).await;
            }
            Ok::<(), E>(())
        })
        .catch_unwind()
        .await;

        if let Ok(mut current) = task_state.write() {
            *current = ActorState::Stopped;
        }
        fail_pending_calls(&task_pending);

        outcome.unwrap_or_else(|panic_payload| Err(actor_loop_panic(panic_payload)))
    });

    (handle, join_handle)
}
385
/// Spawns `actor` as a Tokio task driven by a caller-supplied loop.
///
/// `run` receives the actor and the receiving end of the action channel and
/// decides how queued actions are executed. When the future returned by
/// `run` finishes, the state is set to [`ActorState::Stopped`] and every
/// pending call is signalled.
///
/// Returns the handle together with the task's join handle. The task yields
/// `run`'s result, or an error with a "panic in actor loop" message if the
/// future panicked while being polled.
#[cfg(all(feature = "tokio", not(feature = "async-std")))]
pub fn spawn_with<F, Fut>(actor: A, run: F) -> (Self, tokio::task::JoinHandle<Result<(), E>>)
where
    F: FnOnce(A, Receiver<Action<A>>) -> Fut + Send + 'static,
    Fut: Future<Output = Result<(), E>> + Send,
{
    let (action_tx, action_rx) = flume::unbounded();
    let (stopped_tx, stopped_rx) = flume::bounded::<()>(1);

    let handle = Self {
        tx: Arc::new(Mutex::new(Some(action_tx))),
        state: Arc::new(RwLock::new(ActorState::default())),
        pending: PendingCancelMap::default(),
        next_call_id: Arc::new(AtomicU64::new(0)),
        stopped_rx,
        _phantom: std::marker::PhantomData,
    };

    let task_state = Arc::clone(&handle.state);
    let task_pending = Arc::clone(&handle.pending);

    let join_handle = tokio::task::spawn(async move {
        // Held for the whole task; dropping it is what wakes `wait_stopped*`.
        let _stopped_signal = stopped_tx;

        let outcome = std::panic::AssertUnwindSafe(run(actor, action_rx))
            .catch_unwind()
            .await;

        if let Ok(mut current) = task_state.write() {
            *current = ActorState::Stopped;
        }
        fail_pending_calls(&task_pending);

        outcome.unwrap_or_else(|panic_payload| Err(actor_loop_panic(panic_payload)))
    });

    (handle, join_handle)
}
435
/// Spawns `actor` as an async-std task running the default action loop.
///
/// The loop takes queued actions one at a time and awaits each against the
/// actor until the action channel disconnects. Afterwards the state is set
/// to [`ActorState::Stopped`] and every pending call is signalled.
///
/// Returns the handle together with the task's join handle. The task yields
/// an error with a "panic in actor loop" message if the loop itself panicked.
#[cfg(all(feature = "async-std", not(feature = "tokio")))]
pub fn spawn(actor: A) -> (Self, async_std::task::JoinHandle<Result<(), E>>) {
    let (action_tx, action_rx) = flume::unbounded::<Action<A>>();
    let (stopped_tx, stopped_rx) = flume::bounded::<()>(1);

    let handle = Self {
        tx: Arc::new(Mutex::new(Some(action_tx))),
        state: Arc::new(RwLock::new(ActorState::default())),
        pending: PendingCancelMap::default(),
        next_call_id: Arc::new(AtomicU64::new(0)),
        stopped_rx,
        _phantom: std::marker::PhantomData,
    };

    let task_state = Arc::clone(&handle.state);
    let task_pending = Arc::clone(&handle.pending);

    let join_handle = async_std::task::spawn(async move {
        // Held for the whole task; dropping it is what wakes `wait_stopped*`.
        let _stopped_signal = stopped_tx;
        let mut actor = actor;

        let outcome = std::panic::AssertUnwindSafe(async {
            while let Ok(action) = action_rx.recv_async().await {
                action(&mut actor).await;
            }
            Ok::<(), E>(())
        })
        .catch_unwind()
        .await;

        if let Ok(mut current) = task_state.write() {
            *current = ActorState::Stopped;
        }
        fail_pending_calls(&task_pending);

        outcome.unwrap_or_else(|panic_payload| Err(actor_loop_panic(panic_payload)))
    });

    (handle, join_handle)
}
488
/// Spawns `actor` as an async-std task driven by a caller-supplied loop.
///
/// `run` receives the actor and the receiving end of the action channel and
/// decides how queued actions are executed. When the future returned by
/// `run` finishes, the state is set to [`ActorState::Stopped`] and every
/// pending call is signalled.
///
/// Returns the handle together with the task's join handle. The task yields
/// `run`'s result, or an error with a "panic in actor loop" message if the
/// future panicked while being polled.
#[cfg(all(feature = "async-std", not(feature = "tokio")))]
pub fn spawn_with<F, Fut>(actor: A, run: F) -> (Self, async_std::task::JoinHandle<Result<(), E>>)
where
    F: FnOnce(A, Receiver<Action<A>>) -> Fut + Send + 'static,
    Fut: Future<Output = Result<(), E>> + Send,
{
    let (action_tx, action_rx) = flume::unbounded();
    let (stopped_tx, stopped_rx) = flume::bounded::<()>(1);

    let handle = Self {
        tx: Arc::new(Mutex::new(Some(action_tx))),
        state: Arc::new(RwLock::new(ActorState::default())),
        pending: PendingCancelMap::default(),
        next_call_id: Arc::new(AtomicU64::new(0)),
        stopped_rx,
        _phantom: std::marker::PhantomData,
    };

    let task_state = Arc::clone(&handle.state);
    let task_pending = Arc::clone(&handle.pending);

    let join_handle = async_std::task::spawn(async move {
        // Held for the whole task; dropping it is what wakes `wait_stopped*`.
        let _stopped_signal = stopped_tx;

        let outcome = std::panic::AssertUnwindSafe(run(actor, action_rx))
            .catch_unwind()
            .await;

        if let Ok(mut current) = task_state.write() {
            *current = ActorState::Stopped;
        }
        fail_pending_calls(&task_pending);

        outcome.unwrap_or_else(|panic_payload| Err(actor_loop_panic(panic_payload)))
    });

    (handle, join_handle)
}
538
539 pub fn spawn_blocking(actor: A) -> (Self, std::thread::JoinHandle<Result<(), E>>)
544 {
545 let (tx, rx) = flume::unbounded::<Action<A>>();
546 let state = Arc::new(RwLock::new(ActorState::default()));
547 let pending = Arc::new(Mutex::new(HashMap::new()));
548 let next_call_id = Arc::new(AtomicU64::new(0));
549 let (stopped_tx, stopped_rx) = flume::bounded::<()>(1);
550
551 let join_handle = {
552 let state = Arc::clone(&state);
553 let pending = Arc::clone(&pending);
554 std::thread::spawn(move || {
555 let _stopped_signal = stopped_tx;
556 let mut actor = actor;
557
558 let res = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
559 while let Ok(action) = rx.recv() {
560 block_on(action(&mut actor));
561 }
562 Ok::<(), E>(())
563 }));
564
565 if let Ok(mut st) = state.write() {
566 *st = ActorState::Stopped;
567 }
568 fail_pending_calls(&pending);
569 match res {
570 Ok(result) => result,
571 Err(panic_payload) => Err(actor_loop_panic(panic_payload)),
572 }
573 })
574 };
575
576 (
577 Self {
578 tx: Arc::new(Mutex::new(Some(tx))),
579 state,
580 pending,
581 next_call_id,
582 stopped_rx,
583 _phantom: std::marker::PhantomData,
584 },
585 join_handle,
586 )
587 }
588
589 pub fn spawn_blocking_with<F>(actor: A, run: F) -> (Self, std::thread::JoinHandle<Result<(), E>>)
594 where
595 F: FnOnce(A, Receiver<Action<A>>) -> Result<(), E> + Send + 'static,
596 {
597 let (tx, rx) = flume::unbounded();
598 let state = Arc::new(RwLock::new(ActorState::default()));
599 let pending = Arc::new(Mutex::new(HashMap::new()));
600 let next_call_id = Arc::new(AtomicU64::new(0));
601 let (stopped_tx, stopped_rx) = flume::bounded::<()>(1);
602
603 let join_handle = {
604 let state = Arc::clone(&state);
605 let pending = Arc::clone(&pending);
606 std::thread::spawn(move || {
607 let _stopped_signal = stopped_tx;
608
609 let res = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
610 run(actor, rx)
611 }));
612
613 if let Ok(mut st) = state.write() {
614 *st = ActorState::Stopped;
615 }
616 fail_pending_calls(&pending);
617 match res {
618 Ok(result) => result,
619 Err(panic_payload) => Err(actor_loop_panic(panic_payload)),
620 }
621 })
622 };
623
624 (
625 Self {
626 tx: Arc::new(Mutex::new(Some(tx))),
627 state,
628 pending,
629 next_call_id,
630 stopped_rx,
631 _phantom: std::marker::PhantomData,
632 },
633 join_handle,
634 )
635 }
636
637 fn base_call<R, F>(&self, f: F) -> BaseCallResult<R, E>
639 where
640 F: for<'a> FnOnce(&'a mut A) -> ActorFut<'a, Result<R, E>> + Send + 'static,
641 R: Send + 'static,
642 {
643 if self.state() != ActorState::Running {
644 return Err(E::from_actor_message(
645 "actor stopped (call attempted while actor state is not running)".to_string(),
646 ));
647 }
648
649 let (rtx, rrx) = flume::unbounded();
650 let (cancel_tx, cancel_rx) = flume::bounded::<()>(1);
651 let loc = std::panic::Location::caller();
652 let call_id = self.next_call_id.fetch_add(1, Ordering::Relaxed);
653 self.pending
654 .lock()
655 .expect("poisoned lock")
656 .insert(call_id, cancel_tx);
657
658 let action: Action<A> = Box::new(move |actor: &mut A| {
659 Box::pin(async move {
660 let panic_result = std::panic::AssertUnwindSafe(async move { f(actor).await })
662 .catch_unwind()
663 .await;
664
665 let res = match panic_result {
666 Ok(action_result) => action_result,
667 Err(panic_payload) => {
668 let msg = if let Some(s) = panic_payload.downcast_ref::<&str>() {
670 (*s).to_string()
671 } else if let Some(s) = panic_payload.downcast_ref::<String>() {
672 s.clone()
673 } else {
674 "unknown panic".to_string()
675 };
676 Err(E::from_actor_message(format!(
677 "panic in actor call at {}:{}: {}",
678 loc.file(),
679 loc.line(),
680 msg
681 )))
682 }
683 };
684
685 let _ = rtx.send(res);
687 })
688 });
689
690 let sent = {
691 let tx_guard = self.tx.lock().expect("poisoned lock");
692 tx_guard
693 .as_ref()
694 .map_or(false, |tx| tx.send(action).is_ok())
695 };
696
697 if !sent {
698 if let Ok(mut pending) = self.pending.lock() {
699 pending.remove(&call_id);
700 }
701 return Err(E::from_actor_message(format!(
702 "actor stopped (call send at {}:{})",
703 loc.file(),
704 loc.line()
705 )));
706 }
707
708 Ok((rrx, cancel_rx, call_id, loc))
709 }
710
711 pub fn call_blocking<R, F>(&self, f: F) -> Result<R, E>
721 where
722 F: for<'a> FnOnce(&'a mut A) -> ActorFut<'a, Result<R, E>> + Send + 'static,
723 R: Send + 'static,
724 {
725 enum BlockingWaitResult<T, E> {
726 Result(Result<Result<T, E>, flume::RecvError>),
727 Canceled(Result<(), flume::RecvError>),
728 }
729
730 let (rrx, cancel_rx, call_id, loc) = self.base_call(f)?;
731 let out = match flume::Selector::new()
732 .recv(&rrx, BlockingWaitResult::Result)
733 .recv(&cancel_rx, BlockingWaitResult::Canceled)
734 .wait()
735 {
736 BlockingWaitResult::Result(msg) => msg.map_err(|_| {
737 E::from_actor_message(format!(
738 "actor stopped (call recv at {}:{})",
739 loc.file(),
740 loc.line()
741 ))
742 })?,
743 BlockingWaitResult::Canceled(Ok(())) => Err(E::from_actor_message(format!(
744 "actor stopped (call canceled at {}:{})",
745 loc.file(),
746 loc.line()
747 ))),
748 BlockingWaitResult::Canceled(Err(_)) => Err(E::from_actor_message(format!(
749 "actor stopped (call recv at {}:{})",
750 loc.file(),
751 loc.line()
752 ))),
753 };
754
755 if let Ok(mut pending) = self.pending.lock() {
756 pending.remove(&call_id);
757 }
758
759 out
760 }
761
/// Runs `f` on the actor and asynchronously waits for it to finish.
///
/// Waits on the result channel and the cancellation channel at the same
/// time and returns whichever outcome arrives first.
///
/// The call's entry in the pending registry is removed on every exit path:
/// normal completion, a result channel that disconnects without a reply,
/// and the returned future being dropped before it completes.
///
/// # Errors
///
/// Returns the error produced by `f`, or an "actor stopped" error if the
/// call could not be queued, was cancelled, or never received a reply.
#[cfg(any(feature = "tokio", feature = "async-std"))]
pub async fn call<R, F>(&self, f: F) -> Result<R, E>
where
    F: for<'a> FnOnce(&'a mut A) -> ActorFut<'a, Result<R, E>> + Send + 'static,
    R: Send + 'static,
{
    /// Removes the call's registration when dropped.
    struct PendingEntry<'a> {
        pending: &'a PendingCancelMap,
        call_id: u64,
    }

    impl Drop for PendingEntry<'_> {
        fn drop(&mut self) {
            if let Ok(mut pending) = self.pending.lock() {
                pending.remove(&self.call_id);
            }
        }
    }

    let (rrx, cancel_rx, call_id, loc) = self.base_call(f)?;
    let pending_entry = PendingEntry {
        pending: &self.pending,
        call_id,
    };

    let stopped = |phase: &str| {
        E::from_actor_message(format!(
            "actor stopped (call {} at {}:{})",
            phase,
            loc.file(),
            loc.line()
        ))
    };

    let recv_fut = rrx.recv_async();
    let cancel_fut = cancel_rx.recv_async();
    pin_mut!(recv_fut, cancel_fut);

    let out = match future::select(recv_fut, cancel_fut).await {
        Either::Left((Ok(reply), _)) => reply,
        Either::Left((Err(_), _)) | Either::Right((Err(_), _)) => Err(stopped("recv")),
        Either::Right((Ok(()), _)) => Err(stopped("canceled")),
    };

    // Deregister before handing the outcome back to the caller.
    drop(pending_entry);
    out
}
809
810 pub fn shutdown(&self) {
816 if let Ok(mut tx) = self.tx.lock() {
817 tx.take();
818 }
819 }
820
821 pub fn wait_stopped_blocking(&self) {
823 if self.state() == ActorState::Stopped {
824 return;
825 }
826 let _ = self.stopped_rx.recv();
827 }
828
/// Asynchronously waits until the actor has stopped.
///
/// Returns immediately if the state is already [`ActorState::Stopped`];
/// otherwise awaits the stopped channel and ignores the receive result.
#[cfg(any(feature = "tokio", feature = "async-std"))]
pub async fn wait_stopped(&self) {
    if self.state() != ActorState::Stopped {
        let _ = self.stopped_rx.recv_async().await;
    }
}
837}