use std::any;
use std::fmt;
use std::ops::Deref;
use std::sync::{Arc, Weak};
extern crate stable_deref_trait;
use stable_deref_trait::CloneStableDeref;
use dashmap::DashMap;
use downcast_rs::{Downcast, DowncastSync};
use parking_lot::{Once, RwLock, RwLockReadGuard, RwLockWriteGuard};
use super::{ComponentManager, Entity};
use crate::snowflake::Snowflake;
use crate::util::Result;
// Self-referential "owner + lock guard" pairs generated by the `rental` crate.
//
// Each struct keeps the owning pointer `head` (in this module always an
// `Arc<RwLock<T>>`, see the aliases below) together with a lock guard that
// borrows from it. A guard can therefore be returned from a function without
// the caller having to keep the `Arc` alive separately. The
// `CloneStableDeref` bound (from `stable_deref_trait`) is what allows the
// guard to borrow from `head`'s target while `head` itself is moved.
rental! {
pub mod handle_ref {
use super::*;
// Owner + shared read guard; `deref_suffix` makes it deref to the guarded `T`.
#[rental(debug, clone, deref_suffix, covariant, map_suffix = "T")]
pub struct HandleReadRef<H: CloneStableDeref + Deref + 'static, T: 'static> {
head: H,
suffix: RwLockReadGuard<'head, T>,
}
// Owner + exclusive write guard; `deref_mut_suffix` additionally gives `&mut T`.
#[rental(debug, clone, deref_mut_suffix, covariant, map_suffix = "T")]
pub struct HandleWriteRef<H: CloneStableDeref + Deref + 'static, T: 'static> {
head: H,
suffix: RwLockWriteGuard<'head, T>,
}
}
}
pub use handle_ref::{HandleReadRef, HandleWriteRef};
pub type StoreReference<T> = Arc<RwLock<T>>;
type WeakStoreReference<T> = Weak<RwLock<T>>;
pub type ReadReference<T> = HandleReadRef<StoreReference<T>, T>;
pub type WriteReference<T> = HandleWriteRef<StoreReference<T>, T>;
/// Takes a shared read lock on `head` and returns the guard bundled with `head`.
///
/// Blocks while a writer holds the lock. The returned [`ReadReference`] derefs
/// to `T` and releases the lock when it is dropped.
pub fn read_store_reference<T>(head: StoreReference<T>) -> ReadReference<T>
where
    T: 'static,
{
    HandleReadRef::new(head, |lock| lock.read())
}
/// Takes the exclusive write lock on `head` and returns the guard bundled with `head`.
///
/// Blocks until no other reader or writer holds the lock. The returned
/// [`WriteReference`] derefs (mutably) to `T` and releases the lock when dropped.
pub fn write_store_reference<T>(head: StoreReference<T>) -> WriteReference<T>
where
    T: 'static,
{
    HandleWriteRef::new(head, |lock| lock.write())
}
/// Implemented by types that expose a [`Store`] by reference.
pub trait SharedStore<T, U>
where
    T: Entity + 'static,
    U: EntityBackend<T> + Sync + Send + 'static,
{
    /// Returns the store. The borrow is tied to `self` by lifetime elision,
    /// exactly as the previous explicit `'a` spelled out.
    fn get_store(&self) -> &Store<T, U>;
}
/// A cached, possibly-absent entity together with the backend it persists to.
///
/// `object` is `None` when the handle has not been initialized yet, when the
/// backend had no record for `id`, or after [`StoreHandle::delete`].
pub struct StoreHandle<T: Entity + 'static> {
    // Type-erased so handles do not carry the concrete backend type.
    backend: Arc<dyn EntityBackend<T> + Sync + Send + 'static>,
    id: Snowflake,
    object: Option<T>,
}
impl<T> StoreHandle<T>
where
    T: Entity + 'static,
{
    /// Creates a handle for `id` that persists through `backend` and holds `object`.
    fn new<U>(backend: Arc<U>, id: Snowflake, object: Option<T>) -> StoreHandle<T>
    where
        U: EntityBackend<T> + Sync + Send + 'static,
    {
        // `Arc<U>` unsizes to the field's `Arc<dyn EntityBackend<T> + ...>` here.
        StoreHandle {
            backend,
            id,
            object,
        }
    }

    /// Returns the cached entity, or `None` if the handle is empty.
    pub fn get(&self) -> Option<&T> {
        self.object.as_ref()
    }

    /// Returns the cached entity mutably, or `None` if the handle is empty.
    pub fn get_mut(&mut self) -> Option<&mut T> {
        self.object.as_mut()
    }

    /// Puts `object` into the handle and returns the previously cached entity.
    ///
    /// Only the in-memory copy changes; call `store` to persist it.
    pub fn replace(&mut self, object: T) -> Option<T> {
        self.object.replace(object)
    }

    /// The id this handle is bound to.
    pub fn id(&self) -> Snowflake {
        self.id
    }

    /// Whether the handle currently holds an entity.
    pub fn exists(&self) -> bool {
        self.object.is_some()
    }

    /// Writes the handle's state to the backend.
    ///
    /// A held entity is stored under `id`. An empty handle deletes the backend
    /// record for `id` instead.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by the backend.
    pub fn store(&self) -> Result<()> {
        match &self.object {
            None => self.backend.delete(self.id),
            // `obj` is already `&T`; pass it straight through (no extra borrow).
            Some(obj) => self.backend.store(self.id, obj),
        }
    }

    /// Deletes the entity.
    ///
    /// Clears its components, empties the handle, then removes the backend record.
    ///
    /// # Errors
    ///
    /// If `clear_components` fails, that error is returned. In that case the
    /// handle keeps its entity and the backend record is not deleted.
    /// Otherwise any backend error from the delete is returned.
    pub fn delete(&mut self) -> Result<()> {
        if let Some(obj) = self.object.as_mut() {
            obj.clear_components()?;
        }
        self.object = None;
        self.backend.delete(self.id)
    }

    /// Overwrites the cached entity without touching the backend.
    fn set_object(&mut self, object: Option<T>) {
        self.object = object;
    }
}
impl<T> Drop for StoreHandle<T>
where
    T: Entity + 'static,
{
    /// Writes the entity back to the backend if it reports itself dirty.
    ///
    /// This is best-effort. `drop` cannot return an error, so a failed backend
    /// write is discarded.
    fn drop(&mut self) {
        if let Some(entity) = self.object.as_ref().filter(|entity| entity.dirty()) {
            // NOTE(review): the backend error is silently ignored here; consider
            // logging it once a logging facility is available in this crate.
            let _ = self.backend.store(self.id, entity);
        }
    }
}
/// Per-id entry kept in the store's cache map.
///
/// It holds the initializer that guards the one-time load and a *weak* pointer
/// to the live handle, so the map by itself never keeps a handle alive.
#[doc(hidden)]
#[derive(Clone)]
pub struct StoredHandleData<T: Entity + 'static> {
    initializer: Arc<Once>,
    handle: WeakStoreReference<StoreHandle<T>>,
}
/// Strong view of a cache entry handed out to callers.
///
/// It holds the shared initializer plus a strong reference that keeps the
/// handle alive while it is in use.
#[doc(hidden)]
#[derive(Clone)]
pub struct HandleData<T: Entity + 'static> {
    initializer: Arc<Once>,
    handle: StoreReference<StoreHandle<T>>,
}
/// A cache of live entity handles in front of an [`EntityBackend`].
///
/// `refs` maps each id to a weak pointer to its handle. While some caller holds
/// a strong reference, every lookup of that id shares the same handle.
pub struct Store<T: Entity + 'static, U: EntityBackend<T> + Sync + Send + 'static> {
    backend: Arc<U>,
    refs: DashMap<Snowflake, StoredHandleData<T>>,
}
impl<T, U> Store<T, U>
where
    T: Entity + 'static,
    U: EntityBackend<T> + Sync + Send + 'static,
{
    /// Creates an empty store that loads and persists entities through `backend`.
    pub fn new(backend: Arc<U>) -> Store<T, U> {
        Store {
            backend,
            refs: DashMap::new(),
        }
    }

    /// Returns the live handle for `id`, or installs a fresh, uninitialized one.
    ///
    /// While a handle is strongly referenced, every caller gets the same `Arc`
    /// and the same `Once`, so the backend load runs at most once per live
    /// handle. `refs` only stores a `Weak`. The handle is therefore dropped,
    /// running its `Drop` impl, once the last strong reference goes away.
    /// Dead entries are overwritten here but never removed from `refs`.
    ///
    /// NOTE(review): `upgrade()` fails as soon as the strong count reaches zero,
    /// which is before the old handle's `Drop` has finished writing a dirty
    /// entity back. A caller arriving in that window installs a new handle and
    /// may load the pre-write state from the backend.
    fn get_handle(&self, id: Snowflake) -> HandleData<T> {
        // The entry guard is held until the end of this function, so racing
        // callers for the same id cannot both install a new handle.
        let mut entry = self.refs.entry(id).or_insert_with(|| StoredHandleData {
            initializer: Arc::new(Once::new()),
            handle: Weak::new(),
        });
        if let Some(strong) = entry.handle.upgrade() {
            HandleData {
                initializer: entry.initializer.clone(),
                handle: strong,
            }
        } else {
            // Either the entry was just inserted (a `Weak::new()` never upgrades)
            // or its previous handle is gone. Start over with a new handle and a
            // new, not-yet-run initializer.
            let handle: StoreHandle<T> = StoreHandle::new(self.backend.clone(), id, None);
            let initializer = Arc::new(Once::new());
            let strong = Arc::new(RwLock::new(handle));
            entry.handle = Arc::downgrade(&strong);
            entry.initializer = initializer.clone();
            HandleData {
                initializer,
                handle: strong,
            }
        }
    }

    /// Runs the backend load for `handle_data` if its initializer has not run yet.
    ///
    /// The load happens under the handle's write lock. Concurrent callers block
    /// on the `Once` until it completes.
    ///
    /// # Errors
    ///
    /// Returns the backend's load error, but only to the caller whose closure ran.
    ///
    /// NOTE(review): the `Once` is marked complete even when the load fails.
    /// Other callers sharing this handle then see it as initialized with no
    /// entity (as if the id did not exist) instead of getting the error.
    fn initialize_handle(
        &self,
        id: Snowflake,
        cm: Arc<ComponentManager<T>>,
        handle_data: HandleData<T>,
    ) -> Result<HandleData<T>> {
        let mut res: Result<()> = Ok(());
        handle_data.initializer.call_once(|| {
            let mut write_handle = handle_data.handle.write();
            match self.backend.load(id, cm) {
                Err(e) => {
                    res = Err(e);
                    write_handle.set_object(None);
                }
                Ok(data) => write_handle.set_object(data),
            }
        });
        res.map(|()| handle_data)
    }

    /// Loads `id` (at most once per live handle) and returns the shared handle
    /// itself, without taking its lock.
    ///
    /// # Errors
    ///
    /// Propagates a backend load error.
    pub fn load_handle(
        &self,
        id: Snowflake,
        cm: Arc<ComponentManager<T>>,
    ) -> Result<StoreReference<StoreHandle<T>>> {
        let handle_data = self.initialize_handle(id, cm, self.get_handle(id))?;
        // `handle_data` is owned, so the `Arc` can be moved out rather than cloned.
        Ok(handle_data.handle)
    }

    /// Loads `id` and returns a read guard on its handle.
    ///
    /// # Errors
    ///
    /// Propagates a backend load error.
    pub fn load(
        &self,
        id: Snowflake,
        cm: Arc<ComponentManager<T>>,
    ) -> Result<ReadReference<StoreHandle<T>>> {
        let handle_data = self.initialize_handle(id, cm, self.get_handle(id))?;
        Ok(read_store_reference(handle_data.handle))
    }

    /// Loads `id` and returns a write guard on its handle.
    ///
    /// # Errors
    ///
    /// Propagates a backend load error.
    pub fn load_mut(
        &self,
        id: Snowflake,
        cm: Arc<ComponentManager<T>>,
    ) -> Result<WriteReference<StoreHandle<T>>> {
        let handle_data = self.initialize_handle(id, cm, self.get_handle(id))?;
        Ok(write_store_reference(handle_data.handle))
    }

    /// Makes `object` the cached entity for `object.id()` and persists it immediately.
    ///
    /// If the handle was not yet initialized, `object` becomes its initial value
    /// and no backend load is performed.
    ///
    /// # Errors
    ///
    /// Propagates the backend error from the write.
    pub fn store(&self, object: T) -> Result<()> {
        let id = object.id();
        let handle_data = self.get_handle(id);
        let mut object: Option<T> = Some(object);
        let mut initializer_result: Option<Result<()>> = None;
        handle_data.initializer.call_once(|| {
            let mut handle = handle_data.handle.write();
            handle.set_object(object.take());
            initializer_result = Some(handle.store());
        });
        match object {
            // The initializer had already run, so `object` was not consumed:
            // overwrite the cached value and persist it.
            Some(obj) => {
                let mut handle = handle_data.handle.write();
                handle.set_object(Some(obj));
                handle.store()
            }
            None => initializer_result
                .expect("`object` is only taken inside the initializer, which also sets the result"),
        }
    }

    /// Makes `object` the cached entity for its id and returns a write guard on the handle.
    ///
    /// Nothing is written to the backend here. The entity is persisted only by
    /// an explicit `store()` on the handle, or by the handle's `Drop` if the
    /// entity is dirty at that point.
    pub fn insert(&self, object: T) -> WriteReference<StoreHandle<T>> {
        let id = object.id();
        let handle_data = self.get_handle(id);
        let mut object: Option<T> = Some(object);
        let mut initializer_result: Option<WriteReference<StoreHandle<T>>> = None;
        handle_data.initializer.call_once(|| {
            // Initialize the handle with `object` instead of loading from the backend.
            let mut handle = write_store_reference(handle_data.handle.clone());
            handle.set_object(object.take());
            initializer_result = Some(handle);
        });
        match object {
            Some(obj) => {
                let mut ret = write_store_reference(handle_data.handle);
                ret.set_object(Some(obj));
                ret
            }
            None => initializer_result
                .expect("`object` is only taken inside the initializer, which also sets the result"),
        }
    }

    /// Loads `id` for writing and deletes it from the handle and the backend.
    ///
    /// # Errors
    ///
    /// Propagates load, component-clearing, or backend delete errors.
    pub fn delete(&self, id: Snowflake, cm: Arc<ComponentManager<T>>) -> Result<()> {
        let mut handle = self.load_mut(id, cm)?;
        handle.delete()
    }

    /// Reports whether `id` exists.
    ///
    /// An initialized live handle is authoritative. Otherwise the backend is asked.
    ///
    /// This only peeks at the cache. Probing an id does not allocate a handle or
    /// add an entry to `refs`.
    ///
    /// # Errors
    ///
    /// Propagates a backend error when the backend is consulted.
    pub fn exists(&self, id: Snowflake) -> Result<bool> {
        // The map guard lives only inside this closure and is released before
        // any handle lock is taken. That avoids holding a shard lock while
        // waiting on a handle.
        let live = self.refs.get(&id).and_then(|stored| {
            let initializer = stored.initializer.clone();
            stored.handle.upgrade().map(|handle| (initializer, handle))
        });
        if let Some((initializer, handle)) = live {
            if initializer.state().done() {
                let read_lock = handle.read();
                return Ok(read_lock.exists());
            }
        }
        self.backend.exists(id)
    }

    /// Lists one page of ids from the backend.
    ///
    /// Only the backend is queried; cached handles are not consulted.
    pub fn keys(&self, page: u64, limit: u64) -> Result<Vec<Snowflake>> {
        self.backend.keys(page, limit)
    }
}
/// Wraps a boxed `EntityStore<T>` in a type that implements the
/// entity-agnostic [`EntityStoreDowncast`].
///
/// Presumably this lets stores for different entity types be kept side by side
/// and downcast back later. TODO confirm against callers.
#[doc(hidden)]
pub struct EntityStoreDowncastHelper<T>(pub Box<dyn EntityStore<T> + 'static>)
where
    T: Entity + 'static;

/// Marker trait with no entity type parameter, made downcastable by `downcast_rs`.
#[doc(hidden)]
pub trait EntityStoreDowncast: Downcast + Send + Sync + 'static {}
downcast_rs::impl_downcast!(EntityStoreDowncast);

impl<T> EntityStoreDowncast for EntityStoreDowncastHelper<T> where T: Entity + 'static {}
/// Object-safe view of a [`Store`] with the backend type erased.
///
/// Callers can hold `dyn EntityStore<T>` without naming the backend. The
/// `impl_downcast!` below lets them recover the concrete store type.
pub trait EntityStore<T: Entity + 'static>: DowncastSync {
    /// Loads `id` and returns a read guard on its handle.
    fn load(
        &self,
        id: Snowflake,
        cm: Arc<ComponentManager<T>>,
    ) -> Result<ReadReference<StoreHandle<T>>>;

    /// Loads `id` and returns a write guard on its handle.
    fn load_mut(
        &self,
        id: Snowflake,
        cm: Arc<ComponentManager<T>>,
    ) -> Result<WriteReference<StoreHandle<T>>>;

    /// Caches `object` under its id and persists it.
    fn store(&self, object: T) -> Result<()>;

    /// Caches `object` under its id and returns a write guard on its handle.
    fn insert(&self, object: T) -> WriteReference<StoreHandle<T>>;

    /// Deletes `id` from the cache and the backend.
    fn delete(&self, id: Snowflake, cm: Arc<ComponentManager<T>>) -> Result<()>;

    /// Reports whether `id` exists.
    fn exists(&self, id: Snowflake) -> Result<bool>;

    /// Lists one page of ids.
    fn keys(&self, page: u64, limit: u64) -> Result<Vec<Snowflake>>;
}
downcast_rs::impl_downcast!(sync EntityStore<T> where T: Entity + 'static);
impl<T, U> EntityStore<T> for Store<T, U>
where
T: Entity + 'static,
U: EntityBackend<T> + Sync + Send + 'static,
{
fn load(
&self,
id: Snowflake,
cm: Arc<ComponentManager<T>>,
) -> Result<ReadReference<StoreHandle<T>>> {
self.load(id, cm)
}
fn load_mut(
&self,
id: Snowflake,
cm: Arc<ComponentManager<T>>,
) -> Result<WriteReference<StoreHandle<T>>> {
self.load_mut(id, cm)
}
fn store(&self, object: T) -> Result<()> {
self.store(object)
}
fn insert(&self, object: T) -> WriteReference<StoreHandle<T>> {
self.insert(object)
}
fn delete(&self, id: Snowflake, cm: Arc<ComponentManager<T>>) -> Result<()> {
self.delete(id, cm)
}
fn exists(&self, id: Snowflake) -> Result<bool> {
self.exists(id)
}
fn keys(&self, page: u64, limit: u64) -> Result<Vec<Snowflake>> {
self.keys(page, limit)
}
}
impl<T, U> fmt::Debug for Store<T, U>
where
    T: Entity + 'static,
    U: EntityBackend<T> + Sync + Send + 'static,
{
    /// Prints the entity and backend type names and the number of cache entries.
    ///
    /// The "keys" count is `refs.len()`, which includes entries whose handle
    /// has already been dropped.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let entity_type = any::type_name::<T>();
        let backend_type = any::type_name::<U>();
        let cached = self.refs.len();
        write!(
            f,
            "Store<{}, {}> {{ {} keys }}",
            entity_type, backend_type, cached
        )
    }
}
/// Persistence layer that a [`Store`] loads entities from and writes them to.
pub trait EntityBackend<T>
where
    T: Entity + 'static,
{
    /// Loads the entity stored under `id`, or `Ok(None)` if there is none.
    fn load(&self, id: Snowflake, cm: Arc<ComponentManager<T>>) -> Result<Option<T>>;

    /// Reports whether a record exists for `id`.
    fn exists(&self, id: Snowflake) -> Result<bool>;

    /// Writes `object` under `id`.
    fn store(&self, id: Snowflake, object: &T) -> Result<()>;

    /// Removes the record for `id`.
    fn delete(&self, id: Snowflake) -> Result<()>;

    /// Returns one page of ids, at most `limit` long. Presumably `page` is
    /// zero-based; the test backend skips `page * limit` entries.
    fn keys(&self, page: u64, limit: u64) -> Result<Vec<Snowflake>>;
}
// Unit tests for `Store`, driven by an in-memory mock entity and mock backend.
#[cfg(test)]
mod tests {
use super::*;
use std::any::TypeId;
use std::collections::{HashMap, HashSet};
// This explicit import brings in `std::sync::RwLock`, which shadows the
// parking_lot `RwLock` pulled in by `use super::*`. That is why the mock
// backend calls `.unwrap()` on its lock results.
use std::sync::{mpsc, Arc, Barrier, RwLock};
use std::thread;
use dashmap::DashMap;
use crate::ecs::Component;
use crate::snowflake::SnowflakeGenerator;
// Minimal `Entity` with two payload fields (`field_a`, `field_b`). The tests
// compare these to check that data round-trips through the store.
struct MockStoredData {
id: Snowflake,
field_a: String,
field_b: u64,
cm: Arc<ComponentManager<MockStoredData>>,
components_attached: HashSet<TypeId>,
component_preloads: DashMap<TypeId, Box<dyn Component<Self> + Send + Sync + 'static>>,
dirty: bool,
}
impl MockStoredData {
// Builds a clean (`dirty: false`) entity with no components attached.
fn new(
id: Snowflake,
field_a: String,
field_b: u64,
cm: Arc<ComponentManager<MockStoredData>>,
) -> MockStoredData {
MockStoredData {
id,
field_a,
field_b,
cm,
components_attached: HashSet::new(),
component_preloads: DashMap::new(),
dirty: false,
}
}
// Inherent accessor returning a reference. Inherent methods take precedence
// over `Entity::id`, so `*data.id()` in the tests below resolves to this one.
fn id<'a>(&'a self) -> &'a Snowflake {
&self.id
}
}
// Field-backed `Entity` implementation; every accessor maps to a struct field.
impl Entity for MockStoredData {
fn new(
id: Snowflake,
cm: Arc<ComponentManager<Self>>,
_components: HashSet<TypeId>,
) -> Self {
MockStoredData::new(id, String::from(""), 0, cm)
}
fn id(&self) -> Snowflake {
self.id
}
fn component_manager(&self) -> &ComponentManager<MockStoredData> {
&self.cm
}
fn components_attached(&self) -> &HashSet<TypeId> {
&self.components_attached
}
fn components_attached_mut(&mut self) -> &mut HashSet<TypeId> {
&mut self.components_attached
}
fn preloaded_components(
&self,
) -> &DashMap<TypeId, Box<dyn Component<Self> + Send + Sync + 'static>> {
&self.component_preloads
}
fn dirty(&self) -> bool {
self.dirty
}
fn dirty_mut(&mut self) -> &mut bool {
&mut self.dirty
}
}
// Manual `Clone`: every field is copied except `component_preloads`, which
// starts empty in the clone.
impl Clone for MockStoredData {
fn clone(&self) -> Self {
Self {
id: self.id,
dirty: self.dirty,
cm: self.cm.clone(),
components_attached: self.components_attached.clone(),
component_preloads: DashMap::new(),
field_a: self.field_a.clone(),
field_b: self.field_b,
}
}
}
// In-memory backend keyed by id. When `remove_on_load` is set, `load` removes
// the record after reading it, so only the first backend load sees the data.
struct MockEntityBackend {
data: RwLock<HashMap<Snowflake, MockStoredData>>,
remove_on_load: bool,
}
impl MockEntityBackend {
fn new() -> MockEntityBackend {
MockEntityBackend {
data: RwLock::new(HashMap::new()),
remove_on_load: false,
}
}
fn set_remove_on_load(&mut self, flag: bool) {
self.remove_on_load = flag;
}
}
impl EntityBackend<MockStoredData> for MockEntityBackend {
fn exists(&self, id: Snowflake) -> Result<bool> {
let map = self.data.read().unwrap();
Ok(map.contains_key(&id))
}
fn load(
&self,
id: Snowflake,
_cm: Arc<ComponentManager<MockStoredData>>,
) -> Result<Option<MockStoredData>> {
if !self.remove_on_load {
let map = self.data.read().unwrap();
Ok(map.get(&id).map(|pl| pl.clone()))
} else {
let mut map = self.data.write().unwrap();
let res = Ok(map.get(&id).map(|pl| pl.clone()));
map.remove(&id);
res
}
}
fn store(&self, id: Snowflake, data: &MockStoredData) -> Result<()> {
let mut map = self.data.write().unwrap();
map.insert(id, data.clone());
Ok(())
}
fn delete(&self, id: Snowflake) -> Result<()> {
let mut map = self.data.write().unwrap();
map.remove(&id);
Ok(())
}
// Pages over `HashMap` keys, skipping `page * limit` entries. HashMap
// iteration order is unspecified, so pages are not in any sorted order.
fn keys(&self, page: u64, limit: u64) -> Result<Vec<Snowflake>> {
let ids: Vec<Snowflake>;
let start_index = page * limit;
let data = self.data.read().unwrap();
ids = data
.keys()
.skip(start_index as usize)
.take(limit as usize)
.map(|x| *x)
.collect();
Ok(ids)
}
}
type MockStore = Store<MockStoredData, MockEntityBackend>;
// `exists` with no live handle falls through to the backend.
#[test]
fn test_exists() {
let mut snowflake_gen = SnowflakeGenerator::new(0, 0);
let backend = Arc::new(MockEntityBackend::new());
let data = MockStoredData::new(
snowflake_gen.generate(),
"foo".to_owned(),
1,
Arc::new(ComponentManager::new()),
);
backend.store(*data.id(), &data).unwrap();
let store = MockStore::new(backend);
let id2 = snowflake_gen.generate();
assert!(store.exists(*data.id()).unwrap());
assert!(!store.exists(id2).unwrap());
}
// Loading an unknown id succeeds but yields an empty handle.
#[test]
fn test_load_nonexistent() {
let mut snowflake_gen = SnowflakeGenerator::new(0, 0);
let backend = Arc::new(MockEntityBackend::new());
let store = MockStore::new(backend);
let handle = store
.load(snowflake_gen.generate(), Arc::new(ComponentManager::new()))
.unwrap();
assert!(!handle.exists());
}
// A stored record is returned intact through `load`.
#[test]
fn test_load() {
let mut snowflake_gen = SnowflakeGenerator::new(0, 0);
let backend = Arc::new(MockEntityBackend::new());
let data = MockStoredData::new(
snowflake_gen.generate(),
"foo".to_owned(),
1,
Arc::new(ComponentManager::new()),
);
backend.store(*data.id(), &data).unwrap();
let store = MockStore::new(backend);
let handle = store
.load(*data.id(), Arc::new(ComponentManager::new()))
.unwrap();
assert!(handle.exists());
let data_copy = handle.get().unwrap();
assert_eq!(*data.id(), *data_copy.id());
assert_eq!(data.field_a, data_copy.field_a);
assert_eq!(data.field_b, data_copy.field_b);
}
// Ten threads (9 spawned + main) race on `get_handle` for the same id after a
// barrier. All must receive the same `Arc`. The main thread's wrapper keeps the
// handle alive while the others are joined.
#[test]
fn test_concurrent_load() {
let mut snowflake_gen = SnowflakeGenerator::new(0, 0);
let backend = Arc::new(MockEntityBackend::new());
let id = snowflake_gen.generate();
let data = MockStoredData::new(id, "foo".to_owned(), 1, Arc::new(ComponentManager::new()));
backend.store(id, &data).unwrap();
let store = Arc::new(MockStore::new(backend));
let mut threads = Vec::with_capacity(9);
let barrier = Arc::new(Barrier::new(10));
for _ in 0..9 {
let b_clone = barrier.clone();
let s_clone = store.clone();
let handle = thread::spawn(move || {
b_clone.wait();
let wrapper = s_clone.get_handle(id);
wrapper
});
threads.push(handle);
}
barrier.wait();
let our_wrapper = store.get_handle(id);
for thread in threads {
let their_wrapper = thread.join().unwrap();
assert!(Arc::ptr_eq(&our_wrapper.handle, &their_wrapper.handle));
}
}
// With `remove_on_load`, only the first backend load can see the record. Every
// thread observing the same data therefore shows the load ran once and was shared.
#[test]
fn test_concurrent_access() {
let mut snowflake_gen = SnowflakeGenerator::new(0, 0);
let mut backend = MockEntityBackend::new();
let id = snowflake_gen.generate();
let data = MockStoredData::new(id, "foo".to_owned(), 1, Arc::new(ComponentManager::new()));
backend.set_remove_on_load(true);
backend.store(id, &data).unwrap();
let store = Arc::new(MockStore::new(Arc::new(backend)));
let barrier = Arc::new(Barrier::new(10));
let mut threads = Vec::with_capacity(9);
let (tx, rx) = mpsc::channel();
for _ in 0..9 {
let b_clone = barrier.clone();
let s_clone = store.clone();
let tx_clone = tx.clone();
let handle = thread::spawn(move || {
b_clone.wait();
let handle = s_clone.load(id, Arc::new(ComponentManager::new())).unwrap();
let data = handle.get().unwrap();
tx_clone
.send((data.id, data.field_a.clone(), data.field_b))
.unwrap();
// Second rendezvous: each worker keeps its read guard until the main
// thread has finished comparing.
b_clone.wait();
assert_eq!(data.id, id);
});
threads.push(handle);
}
{
barrier.wait();
let handle = store.load(id, Arc::new(ComponentManager::new())).unwrap();
let data = handle.get().unwrap();
for _ in 0..9 {
let their_data = rx.recv().unwrap();
assert_eq!(data.id, their_data.0);
assert_eq!(data.field_a, their_data.1);
assert_eq!(data.field_b, their_data.2);
}
}
barrier.wait();
for thread in threads {
thread.join().unwrap();
}
// All handles are dropped now, so the cached handle is gone. A fresh load
// hits the backend again, which already removed the record.
let handle = store.load(id, Arc::new(ComponentManager::new())).unwrap();
let data = handle.get();
assert!(data.is_none());
}
// Two read guards on the same handle can coexist on one thread.
#[test]
fn test_multiple_single_thread_access() {
let mut snowflake_gen = SnowflakeGenerator::new(0, 0);
let backend = MockEntityBackend::new();
let id = snowflake_gen.generate();
let data = MockStoredData::new(id, "foo".to_owned(), 1, Arc::new(ComponentManager::new()));
backend.store(id, &data).unwrap();
let store = MockStore::new(Arc::new(backend));
let handle_1 = store.load(id, Arc::new(ComponentManager::new())).unwrap();
let data_1 = handle_1.get().unwrap();
let handle_2 = store.load(id, Arc::new(ComponentManager::new())).unwrap();
let data_2 = handle_2.get().unwrap();
assert_eq!(data_1.id, data_2.id);
assert_eq!(data_1.field_a, data_2.field_a);
assert_eq!(data_1.field_b, data_2.field_b);
assert_eq!(data_1.id, data.id);
assert_eq!(data_1.field_a, data.field_a);
assert_eq!(data_1.field_b, data.field_b);
}
// `load_mut` on a missing id yields an empty handle; `replace` + `store`
// persists the new entity so a later `load` sees it.
#[test]
fn test_store() {
let mut snowflake_gen = SnowflakeGenerator::new(0, 0);
let id = snowflake_gen.generate();
let cm = Arc::new(ComponentManager::new());
let data = MockStoredData::new(id, "foo".to_owned(), 1, cm.clone());
let backend = Arc::new(MockEntityBackend::new());
let store = MockStore::new(backend);
{
let mut handle = store.load_mut(*data.id(), cm.clone()).unwrap();
assert!(!handle.exists());
handle.replace(data.clone());
handle.store().unwrap();
}
let handle = store.load(*data.id(), cm).unwrap();
let data_copy = handle.get().unwrap();
assert_eq!(*data_copy.id(), id);
assert_eq!(data.field_a, data_copy.field_a);
assert_eq!(data.field_b, data_copy.field_b);
}
// `insert` does not persist by itself. This test relies on `set_component`
// marking the entity dirty (assumed; defined outside this file), so the
// handle's `Drop` writes it back when the scope ends.
#[test]
fn test_handle_drop() {
use super::ComponentManager;
use crate::local_storage::LocalComponentStorage;
use crate::Component;
#[derive(Clone)]
struct TestComponent(u64);
impl Component<MockStoredData> for TestComponent {};
let backend = Arc::new(MockEntityBackend::new());
let store = MockStore::new(backend);
let mut cm: ComponentManager<MockStoredData> = ComponentManager::new();
cm.register_component(
"TestComponent",
LocalComponentStorage::<MockStoredData, TestComponent>::new(),
)
.unwrap();
let cm = Arc::new(cm);
let mut snowflake_gen = SnowflakeGenerator::new(0, 0);
let id = snowflake_gen.generate();
{
let data = MockStoredData::new(id, "foo".to_owned(), 1, cm.clone());
let mut handle = store.insert(data);
let data = handle.get_mut().unwrap();
data.set_component(TestComponent(50)).unwrap();
}
let handle = store.load(id, cm.clone()).unwrap();
let data = handle.get().unwrap();
let component: TestComponent = data.get_component().unwrap().unwrap();
assert_eq!(component.0, 50);
}
}