use std::{
any::TypeId,
cell::{RefCell, UnsafeCell},
fmt::Debug,
marker::PhantomData,
ops::{Deref, DerefMut},
rc::Rc,
};
use ahash::AHashMap;
use ustr::Ustr;
use super::Actor;
/// Typed handle to an actor stored in the thread-local actor registry.
///
/// The handle owns a strong reference (`Rc`) to the same type-erased cell the registry
/// holds. The actor therefore stays alive as long as the handle exists, even after it is
/// removed from or replaced in the registry. `T` is the concrete type the actor was
/// checked against when the handle was created by [`get_actor_unchecked`] or
/// [`try_get_actor_unchecked`].
///
/// NOTE(review): access goes through an `UnsafeCell` with no runtime borrow tracking.
/// Nothing prevents two handles to the same actor from producing overlapping `&mut T`
/// via `DerefMut`, so callers must ensure they never do.
pub struct ActorRef<T: Actor> {
    // Type-erased shared cell; a clone of the `Rc` stored in the registry.
    actor_rc: Rc<UnsafeCell<dyn Actor>>,
    // Zero-sized marker recording the concrete actor type `T`.
    _marker: PhantomData<T>,
}
impl<T: Actor> Debug for ActorRef<T> {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct(stringify!(ActorRef))
.field("actor_id", &self.deref().id())
.finish()
}
}
impl<T: Actor> Deref for ActorRef<T> {
    type Target = T;

    /// Borrows the actor as its concrete type `T`.
    fn deref(&self) -> &Self::Target {
        // Drop the trait-object metadata and keep the thin data pointer to the actor.
        let typed: *const T = self.actor_rc.get().cast::<T>();
        // SAFETY: in this file, `ActorRef` values are only built by `get_actor_unchecked`
        // and `try_get_actor_unchecked`. Both first check that the actor's `as_any()`
        // type id equals `TypeId::of::<T>()`, and the cell is kept alive by `actor_rc`.
        // NOTE(review): this assumes `Actor::as_any` returns `self`. Aliasing is not
        // checked; no overlapping `&mut T` from another handle may be live.
        unsafe { &*typed }
    }
}
impl<T: Actor> DerefMut for ActorRef<T> {
    /// Mutably borrows the actor as its concrete type `T`.
    fn deref_mut(&mut self) -> &mut Self::Target {
        // Fat-to-thin pointer cast: discards the `dyn Actor` vtable.
        let typed = self.actor_rc.get() as *mut T;
        // SAFETY: the handle was only created after its `as_any()` type id matched
        // `TypeId::of::<T>()`, and `actor_rc` keeps the cell alive. Uniqueness is NOT
        // enforced: the caller must ensure no other reference to this actor, from this
        // or any other handle, is live while the returned `&mut T` is used.
        unsafe { &mut *typed }
    }
}
thread_local! {
    /// The actor registry for the current thread. Every thread lazily gets its own
    /// independent, initially empty instance; actors registered on one thread are not
    /// visible from any other.
    static ACTOR_REGISTRY: ActorRegistry = ActorRegistry::new();
}
/// Map from actor id to the type-erased, shared cell holding that actor.
///
/// The map sits behind a `RefCell` so every method can take `&self`. That is what lets
/// the registry be used through the shared reference returned by `get_actor_registry`.
/// Because it contains `RefCell` and `Rc`, the type is neither `Send` nor `Sync`, so an
/// instance stays confined to the thread that created it.
pub struct ActorRegistry {
    actors: RefCell<AHashMap<Ustr, Rc<UnsafeCell<dyn Actor>>>>,
}
impl Debug for ActorRegistry {
    /// Formats the registry as `ActorRegistry { actors: [..] }`, listing only the
    /// registered actor ids, in the map's unspecified iteration order.
    ///
    /// Never panics because of the inner `RefCell`. If the map is currently mutably
    /// borrowed (for example, formatting is triggered re-entrantly while the registry is
    /// being modified), the ids are shown as `<borrowed>`, mirroring how
    /// `RefCell`'s own `Debug` impl behaves.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut out = f.debug_struct(stringify!(ActorRegistry));
        match self.actors.try_borrow() {
            Ok(actors_ref) => {
                let keys: Vec<&Ustr> = actors_ref.keys().collect();
                out.field("actors", &keys);
            }
            Err(_) => {
                out.field("actors", &format_args!("<borrowed>"));
            }
        }
        out.finish()
    }
}
impl Default for ActorRegistry {
fn default() -> Self {
Self::new()
}
}
impl ActorRegistry {
    /// Creates an empty registry.
    pub fn new() -> Self {
        Self {
            actors: RefCell::new(AHashMap::new()),
        }
    }

    /// Registers `actor` under `id`. Any actor previously registered under the same id
    /// is replaced, and a warning is logged.
    ///
    /// Existing handles to a replaced actor keep it alive; the registry simply stops
    /// referring to it. If the registry held the last strong reference, the replaced
    /// actor is dropped after the internal borrow has been released. Its `Drop` may
    /// therefore call back into the registry.
    ///
    /// # Panics
    ///
    /// Panics if the map is already borrowed, i.e. on re-entrant use while another
    /// registry operation holds the borrow.
    pub fn insert(&self, id: Ustr, actor: Rc<UnsafeCell<dyn Actor>>) {
        // The `RefMut` guard is a temporary of this statement, so the mutable borrow
        // ends here, before the displaced entry (if any) is dropped below.
        let replaced = self.actors.borrow_mut().insert(id, actor);
        if replaced.is_some() {
            log::warn!("Replacing existing actor with id: {id}");
        }
        // Dropping the previous entry may run the old actor's `Drop`. Doing it outside
        // the borrow prevents a `BorrowMutError` panic if that `Drop` touches the registry.
        drop(replaced);
    }

    /// Returns a new strong reference to the actor registered under `id`, if any.
    pub fn get(&self, id: &Ustr) -> Option<Rc<UnsafeCell<dyn Actor>>> {
        self.actors.borrow().get(id).cloned()
    }

    /// Returns the number of registered actors.
    pub fn len(&self) -> usize {
        self.actors.borrow().len()
    }

    /// Returns `true` if no actors are registered.
    pub fn is_empty(&self) -> bool {
        self.actors.borrow().is_empty()
    }

    /// Removes and returns the actor registered under `id`, if any.
    ///
    /// The internal borrow is released before this returns, so the caller may drop the
    /// returned value (possibly running the actor's `Drop`) without conflicting with it.
    pub fn remove(&self, id: &Ustr) -> Option<Rc<UnsafeCell<dyn Actor>>> {
        self.actors.borrow_mut().remove(id)
    }

    /// Returns `true` if an actor is registered under `id`.
    pub fn contains(&self, id: &Ustr) -> bool {
        self.actors.borrow().contains_key(id)
    }
}
/// Returns the calling thread's actor registry.
///
/// The `'static` lifetime is manufactured by transmuting the thread-local borrow handed
/// to the `with` closure. The reference cannot leave the thread: `ActorRegistry` is
/// `!Sync`, so `&ActorRegistry` is `!Send`.
///
/// # Panics
///
/// Per the `LocalKey::with` documentation, panics if called while this thread's
/// `ACTOR_REGISTRY` is being destroyed, and may panic if called after it was destroyed.
///
/// NOTE(review): the `'static` lifetime is only truly valid until this thread's
/// thread-local destructors run. A reference obtained earlier and used later would
/// dangle, for example from another thread-local's destructor, or from an actor that
/// cached it and uses it in its own `Drop` during registry teardown. Do not cache the
/// returned reference across thread shutdown.
pub fn get_actor_registry() -> &'static ActorRegistry {
    ACTOR_REGISTRY.with(|registry| unsafe {
        // SAFETY: sound only as long as the extended reference is not used after this
        // thread's `ACTOR_REGISTRY` has been destroyed (see the note above).
        std::mem::transmute::<&ActorRegistry, &'static ActorRegistry>(registry)
    })
}
/// Stores `actor` in the calling thread's registry under `actor.id()`. Returns a typed
/// strong reference to the same shared cell the registry now holds.
///
/// If an actor with the same id was already registered, it is replaced and a warning
/// is logged by the registry.
pub fn register_actor<T>(actor: T) -> Rc<UnsafeCell<T>>
where
    T: Actor + 'static,
{
    // The id must be read before `actor` is moved into its cell.
    let key = actor.id();
    let typed = Rc::new(UnsafeCell::new(actor));

    // Unsizing cast: the registry keeps a type-erased view of the very same allocation.
    let registry = get_actor_registry();
    registry.insert(key, typed.clone() as Rc<UnsafeCell<dyn Actor>>);

    typed
}
pub fn get_actor(id: &Ustr) -> Option<Rc<UnsafeCell<dyn Actor>>> {
get_actor_registry().get(id)
}
/// Returns a typed handle to the actor registered under `id`.
///
/// Despite the name, the concrete type *is* verified: the actor's `as_any()` type id
/// must equal `TypeId::of::<T>()`. What stays unchecked is borrowing. The handle derefs
/// to `&T` / `&mut T` through an `UnsafeCell` with no runtime aliasing checks.
///
/// # Panics
///
/// - If no actor is registered under `id`.
/// - If the registered actor's type id differs from `T`'s.
#[must_use]
pub fn get_actor_unchecked<T: Actor>(id: &Ustr) -> ActorRef<T> {
    let Some(actor_rc) = get_actor_registry().get(id) else {
        panic!("Actor for {id} not found");
    };

    let expected_type = TypeId::of::<T>();
    // SAFETY: creates only a short-lived shared borrow of the actor, confined to this
    // expression; the cell is kept alive by `actor_rc`.
    let actual_type = unsafe { (*actor_rc.get()).as_any().type_id() };
    if actual_type != expected_type {
        panic!("Actor type mismatch for '{id}': expected {expected_type:?}, found {actual_type:?}");
    }

    ActorRef {
        actor_rc,
        _marker: PhantomData,
    }
}
/// Non-panicking counterpart of `get_actor_unchecked`.
///
/// Returns `None` when no actor is registered under `id`, or when the registered
/// actor's `as_any()` type id differs from `TypeId::of::<T>()`. The returned handle has
/// the same lack of aliasing checks as `get_actor_unchecked`.
#[must_use]
pub fn try_get_actor_unchecked<T: Actor>(id: &Ustr) -> Option<ActorRef<T>> {
    let actor_rc = get_actor_registry().get(id)?;
    // SAFETY: short-lived shared borrow confined to this expression; `actor_rc` keeps
    // the cell alive.
    let type_matches = unsafe { (*actor_rc.get()).as_any().type_id() } == TypeId::of::<T>();
    type_matches.then_some(ActorRef {
        actor_rc,
        _marker: PhantomData,
    })
}
pub fn actor_exists(id: &Ustr) -> bool {
get_actor_registry().contains(id)
}
pub fn actor_count() -> usize {
get_actor_registry().len()
}
/// Removes every actor from the calling thread's registry (test helper).
///
/// Entries are moved out while the map is mutably borrowed. They are dropped only after
/// that borrow has been released, so an actor whose `Drop` calls back into the registry
/// does not trigger a `BorrowMutError` panic.
#[cfg(test)]
pub fn clear_actor_registry() {
    let registry = get_actor_registry();
    // Must be a separate `let`: the `RefMut` temporary lives to the end of this
    // statement, so the drained actors are only dropped after the borrow ends.
    let removed: Vec<_> = registry.actors.borrow_mut().drain().collect();
    drop(removed);
}
#[cfg(test)]
mod tests {
    // Tests for the thread-local actor registry and the typed `ActorRef` handles.
    // Every test starts with `clear_actor_registry()`. Because the registry is
    // thread-local, that resets only the registry of the thread running the test.
    use std::any::Any;

    use rstest::rstest;

    use super::*;

    // Minimal actor used by most tests: an id, plus a mutable payload for observing
    // reads and writes made through `ActorRef`.
    #[derive(Debug)]
    struct TestActor {
        id: Ustr,
        value: i32,
    }

    impl Actor for TestActor {
        fn id(&self) -> Ustr {
            self.id
        }

        // Messages are ignored; these tests only exercise the registry.
        fn handle(&mut self, _msg: &dyn Any) {}

        // Returns `self`, so the registry's `TypeId` check sees `TestActor`.
        fn as_any(&self) -> &dyn Any {
            self
        }
    }

    // A registered actor can be fetched back as its concrete type.
    #[rstest]
    fn test_register_and_get_actor() {
        clear_actor_registry();
        let id = Ustr::from("test-actor");
        let actor = TestActor { id, value: 42 };
        register_actor(actor);
        let actor_ref = get_actor_unchecked::<TestActor>(&id);
        assert_eq!(actor_ref.value, 42);
    }

    // A write through one handle is visible through a second handle, because both
    // share the same underlying cell.
    #[rstest]
    fn test_mutation_through_reference() {
        clear_actor_registry();
        let id = Ustr::from("test-actor-mut");
        let actor = TestActor { id, value: 0 };
        register_actor(actor);
        let mut actor_ref = get_actor_unchecked::<TestActor>(&id);
        actor_ref.value = 999;
        let actor_ref2 = get_actor_unchecked::<TestActor>(&id);
        assert_eq!(actor_ref2.value, 999);
    }

    // Looking up an unknown id yields `None` rather than panicking.
    #[rstest]
    fn test_try_get_returns_none_for_missing() {
        clear_actor_registry();
        let id = Ustr::from("nonexistent");
        let result = try_get_actor_unchecked::<TestActor>(&id);
        assert!(result.is_none());
    }

    // An actor of a different concrete type is rejected by the `TypeId` check.
    #[rstest]
    fn test_try_get_returns_none_for_wrong_type() {
        // A second actor type, distinct from `TestActor`.
        #[derive(Debug)]
        struct OtherActor {
            id: Ustr,
        }

        impl Actor for OtherActor {
            fn id(&self) -> Ustr {
                self.id
            }

            fn handle(&mut self, _msg: &dyn Any) {}

            fn as_any(&self) -> &dyn Any {
                self
            }
        }

        clear_actor_registry();
        let id = Ustr::from("other-actor");
        let actor = OtherActor { id };
        register_actor(actor);
        let result = try_get_actor_unchecked::<TestActor>(&id);
        assert!(result.is_none());
    }

    // An actor registered on this thread is invisible to a freshly spawned thread,
    // which sees its own, empty registry.
    #[rstest]
    fn test_registry_is_thread_local() {
        clear_actor_registry();
        let id = Ustr::from("thread-local-actor");
        let actor = TestActor { id, value: 42 };
        register_actor(actor);
        assert!(actor_exists(&id));
        assert_eq!(actor_count(), 1);
        let visible_on_other_thread = std::thread::spawn(move || {
            (actor_exists(&id), actor_count())
        })
        .join()
        .unwrap();
        assert!(!visible_on_other_thread.0);
        assert_eq!(visible_on_other_thread.1, 0);
    }

    // A handle keeps its actor alive and usable after the actor is removed from the
    // registry, because the handle owns its own `Rc`.
    #[rstest]
    fn test_actor_ref_survives_registry_removal() {
        clear_actor_registry();
        let id = Ustr::from("removable-actor");
        let actor = TestActor { id, value: 7 };
        register_actor(actor);
        assert_eq!(actor_count(), 1);
        let mut guard = get_actor_unchecked::<TestActor>(&id);
        get_actor_registry().remove(&id);
        assert!(!actor_exists(&id));
        assert_eq!(actor_count(), 0);
        assert_eq!(guard.value, 7);
        guard.value = 99;
        assert_eq!(guard.value, 99);
    }

    // Re-registering an id replaces the registry entry, but a handle taken earlier
    // still points at the original actor.
    #[rstest]
    fn test_actor_ref_survives_same_id_replacement() {
        clear_actor_registry();
        let id = Ustr::from("replaceable-actor");
        let actor_a = TestActor { id, value: 1 };
        register_actor(actor_a);
        let guard_a = get_actor_unchecked::<TestActor>(&id);
        assert_eq!(guard_a.value, 1);
        let actor_b = TestActor { id, value: 2 };
        register_actor(actor_b);
        assert_eq!(guard_a.value, 1);
        let guard_b = get_actor_unchecked::<TestActor>(&id);
        assert_eq!(guard_b.value, 2);
        assert_eq!(actor_count(), 1);
    }

    // `get_actor_unchecked` panics with a descriptive message on a type mismatch.
    #[should_panic(expected = "Actor type mismatch")]
    #[rstest]
    fn test_get_actor_unchecked_panics_on_type_mismatch() {
        // A second actor type, distinct from `TestActor`.
        #[derive(Debug)]
        struct OtherActor {
            id: Ustr,
        }

        impl Actor for OtherActor {
            fn id(&self) -> Ustr {
                self.id
            }

            fn handle(&mut self, _msg: &dyn Any) {}

            fn as_any(&self) -> &dyn Any {
                self
            }
        }

        clear_actor_registry();
        let id = Ustr::from("typed-actor");
        let actor = OtherActor { id };
        register_actor(actor);
        let _guard = get_actor_unchecked::<TestActor>(&id);
    }
}