use std::{
borrow::{Borrow, BorrowMut},
marker::PhantomData,
ops::{Deref, DerefMut},
};
use hibitset::BitSet;
use shred::Fetch;
#[nougat::gat(Type)]
use crate::join::LendJoin;
use crate::join::{Join, RepeatableLendGet};
#[cfg(feature = "parallel")]
use crate::join::ParJoin;
use crate::{
storage::{
AccessMutReturn, DistinctStorage, MaskedStorage, SharedGetMutStorage, Storage,
UnprotectedStorage,
},
world::{Component, EntitiesRes, Entity, Index},
};
/// A restricted view over a component storage: a borrowed mask, the component
/// data `S`, and a borrowed handle to the entities resource.
///
/// In this file it is only constructed by `Storage::restrict` (with
/// `S = &C::Storage`) and `Storage::restrict_mut` (with
/// `S = &mut C::Storage`). All fields are private; the join implementations
/// in this file hand the data out wrapped in the `PairedStorage*` types.
pub struct RestrictedStorage<'rf, C, S> {
// Mask borrowed from the source storage; returned as the join mask.
bitset: &'rf BitSet,
// Component data; the join impls borrow `C::Storage` out of it.
data: S,
// Entities resource; the `get_other*` accessors use it for `is_alive`.
entities: &'rf Fetch<'rf, EntitiesRes>,
// Ties the otherwise unused component type `C` to this struct.
phantom: PhantomData<C>,
}
impl<T, D> Storage<'_, T, D>
where
    T: Component,
    D: Deref<Target = MaskedStorage<T>>,
{
    /// Builds a read-only [`RestrictedStorage`] that borrows this storage.
    ///
    /// The returned value holds shared references to the storage's mask, its
    /// inner component storage and its entities handle, all for the lifetime
    /// of the `&self` borrow.
    pub fn restrict<'rf>(&'rf self) -> RestrictedStorage<'rf, T, &T::Storage> {
        // Go through `Deref` once so both fields come from the same borrow.
        let masked: &MaskedStorage<T> = &self.data;
        let entities = &self.entities;

        RestrictedStorage {
            bitset: &masked.mask,
            data: &masked.inner,
            entities,
            phantom: PhantomData,
        }
    }
}
impl<T, D> Storage<'_, T, D>
where
    T: Component,
    D: DerefMut<Target = MaskedStorage<T>>,
{
    /// Builds a [`RestrictedStorage`] with exclusive access to the component
    /// data.
    ///
    /// The mask is handed out as a shared reference only, while the component
    /// storage is borrowed mutably; both come from a single `open_mut` call on
    /// the underlying masked storage.
    pub fn restrict_mut<'rf>(&'rf mut self) -> RestrictedStorage<'rf, T, &mut T::Storage> {
        // `entities` and `data` are distinct fields, so the shared and the
        // exclusive borrow below do not overlap.
        let entities = &self.entities;
        let (bitset, storage) = self.data.open_mut();

        RestrictedStorage {
            bitset,
            data: storage,
            entities,
            phantom: PhantomData,
        }
    }
}
// SAFETY: NOTE(review): the `LendJoin` safety contract is declared in
// `crate::join`, not in this file — confirm this impl against it. What the
// code below does guarantee: `open` returns the mask and the storage that
// `restrict`/`restrict_mut` borrowed from one `MaskedStorage`, so the two
// belong together.
#[nougat::gat]
unsafe impl<'rf, C, S> LendJoin for &'rf RestrictedStorage<'rf, C, S>
where
C: Component,
S: Borrow<C::Storage>,
{
type Mask = &'rf BitSet;
// The yielded item carries `'rf`, not `'next`: it only holds shared
// references copied out of `Value`.
type Type<'next> = PairedStorageRead<'rf, C>;
type Value = (&'rf C::Storage, &'rf Fetch<'rf, EntitiesRes>, &'rf BitSet);
// Returns the bitset as the join mask, plus the (storage, entities, bitset)
// tuple that `get` reads from.
unsafe fn open(self) -> (Self::Mask, Self::Value) {
(
self.bitset,
(self.data.borrow(), self.entities, self.bitset),
)
}
// Pairs `id` with the references stored in `value`; no component is accessed
// here. `PairedStorageRead::get` later passes this `id` to the storage
// without any further check, so it relies on the caller of this method.
unsafe fn get<'next>(value: &'next mut Self::Value, id: Index) -> Self::Type<'next> {
PairedStorageRead {
index: id,
storage: value.0,
entities: value.1,
bitset: value.2,
}
}
}
// SAFETY: NOTE(review): presumably this marker asserts that `LendJoin::get`
// may be called more than once with the same id — confirm against the trait
// definition. The `LendJoin::get` impl for this type only copies shared
// references and the index into a new `PairedStorageRead`.
unsafe impl<'rf, C, S> RepeatableLendGet for &'rf RestrictedStorage<'rf, C, S>
where
C: Component,
S: Borrow<C::Storage>,
{
}
// SAFETY: NOTE(review): the `LendJoin` safety contract is declared in
// `crate::join`, not in this file — confirm this impl against it. What the
// code below does guarantee: `open` returns the mask and the storage that
// `restrict`/`restrict_mut` borrowed from one `MaskedStorage`, so the two
// belong together.
#[nougat::gat]
unsafe impl<'rf, C, S> LendJoin for &'rf mut RestrictedStorage<'rf, C, S>
where
C: Component,
S: BorrowMut<C::Storage>,
{
type Mask = &'rf BitSet;
// The yielded item borrows from `Value` for `'next`: it holds a mutable
// reborrow of the storage, so it cannot outlive the `&'next mut` passed to
// `get`.
type Type<'next> = PairedStorageWriteExclusive<'next, C>;
type Value = (
&'rf mut C::Storage,
&'rf Fetch<'rf, EntitiesRes>,
&'rf BitSet,
);
// Returns the bitset as the join mask, plus the (storage, entities, bitset)
// tuple that `get` reads from. The storage is borrowed exclusively.
unsafe fn open(self) -> (Self::Mask, Self::Value) {
(
self.bitset,
(self.data.borrow_mut(), self.entities, self.bitset),
)
}
// Pairs `id` with a `'next` reborrow of the exclusive storage reference; no
// component is accessed here. The accessors on
// `PairedStorageWriteExclusive` later pass this `id` to the storage without
// any further check, so they rely on the caller of this method.
unsafe fn get<'next>(value: &'next mut Self::Value, id: Index) -> Self::Type<'next> {
PairedStorageWriteExclusive {
index: id,
storage: value.0,
entities: value.1,
bitset: value.2,
}
}
}
// SAFETY: NOTE(review): presumably this marker asserts that `LendJoin::get`
// may be called more than once with the same id — confirm against the trait
// definition. The `LendJoin::get` impl for this type returns a value that
// mutably borrows `Value` for `'next`, so a previously returned item must be
// gone before `get` can be called again on the same `Value`.
unsafe impl<'rf, C, S> RepeatableLendGet for &'rf mut RestrictedStorage<'rf, C, S>
where
C: Component,
S: BorrowMut<C::Storage>,
{
}
// SAFETY: NOTE(review): the `Join` safety contract is declared in
// `crate::join`, not in this file — confirm this impl against it. `open`
// returns the mask and the storage that `restrict`/`restrict_mut` borrowed
// from one `MaskedStorage`, so the two belong together.
unsafe impl<'rf, C, S> Join for &'rf RestrictedStorage<'rf, C, S>
where
    C: Component,
    S: Borrow<C::Storage>,
{
    type Mask = &'rf BitSet;
    type Type = PairedStorageRead<'rf, C>;
    type Value = (&'rf C::Storage, &'rf Fetch<'rf, EntitiesRes>, &'rf BitSet);

    // Returns the bitset as the join mask, plus the (storage, entities,
    // bitset) tuple that `get` reads from.
    unsafe fn open(self) -> (Self::Mask, Self::Value) {
        let mask = self.bitset;
        let storage: &'rf C::Storage = self.data.borrow();
        (mask, (storage, self.entities, mask))
    }

    // Pairs `id` with the shared references held in `value`; no component is
    // accessed here. `PairedStorageRead::get` relies on the caller of this
    // method having passed a valid `id`.
    unsafe fn get(value: &mut Self::Value, id: Index) -> Self::Type {
        // All three elements are shared references, so the tuple is `Copy`.
        let (storage, entities, bitset) = *value;
        PairedStorageRead {
            index: id,
            storage,
            bitset,
            entities,
        }
    }
}
/// Private module so that the reference wrapped by `SharedGetOnly` cannot be
/// reached from the rest of the crate except through the functions below.
mod shared_get_only {
use super::{DistinctStorage, Index, SharedGetMutStorage, UnprotectedStorage};
use core::marker::PhantomData;
/// Handle to a storage `S` that only exposes `get` and `get_mut`.
///
/// Internally this is a shared reference, but `new` demands an exclusive
/// borrow for the whole lifetime `'a`, so while any handle exists the
/// storage can only be reached through handles created from that borrow.
pub struct SharedGetOnly<'a, T, S>(&'a S, PhantomData<T>);
// SAFETY: both fields are required to be `Send` by the where clause.
// NOTE(review): the additional `S: DistinctStorage` bound presumably exists
// because handles can be duplicated and `get_mut` works through a shared
// reference — confirm against the `DistinctStorage` documentation.
unsafe impl<'a, T, S> Send for SharedGetOnly<'a, T, S>
where
for<'b> &'b S: Send,
PhantomData<T>: Send,
S: DistinctStorage,
{
}
// SAFETY: both fields are required to be `Sync` by the where clause; see the
// `Send` impl above for the `DistinctStorage` bound.
unsafe impl<'a, T, S> Sync for SharedGetOnly<'a, T, S>
where
for<'b> &'b S: Sync,
PhantomData<T>: Sync,
S: DistinctStorage,
{
}
impl<'a, T, S> SharedGetOnly<'a, T, S> {
/// Wraps `storage`, consuming the exclusive borrow for `'a`.
pub(super) fn new(storage: &'a mut S) -> Self {
Self(storage, PhantomData)
}
/// Creates another handle to the same storage by copying the inner shared
/// reference. Takes `this` explicitly instead of `self`.
pub(crate) fn duplicate(this: &Self) -> Self {
Self(this.0, this.1)
}
/// Forwards to `SharedGetMutStorage::shared_get_mut` for `id`.
///
/// # Safety
///
/// The caller must uphold the safety contract of
/// `SharedGetMutStorage::shared_get_mut` for `id` (declared outside this
/// file). NOTE(review): since handles can be duplicated, this presumably
/// includes not holding another reference to the same component obtained
/// through `get`/`get_mut` — confirm.
pub(super) unsafe fn get_mut(
this: &Self,
id: Index,
) -> <S as UnprotectedStorage<T>>::AccessMut<'a>
where
S: SharedGetMutStorage<T>,
{
// SAFETY: requirements are passed on to the caller (see `# Safety`).
unsafe { this.0.shared_get_mut(id) }
}
/// Forwards to `UnprotectedStorage::get` for `id`.
///
/// # Safety
///
/// The caller must uphold the safety contract of `UnprotectedStorage::get`
/// for `id` (declared outside this file). NOTE(review): presumably no
/// mutable access to the same component obtained through `get_mut` may be
/// alive — confirm.
pub(super) unsafe fn get(this: &Self, id: Index) -> &'a T
where
S: UnprotectedStorage<T>,
{
// SAFETY: requirements are passed on to the caller (see `# Safety`).
unsafe { this.0.get(id) }
}
}
}
pub use shared_get_only::SharedGetOnly;
// SAFETY: NOTE(review): the `Join` safety contract is declared in
// `crate::join`, not in this file — confirm this impl against it. `open`
// returns the mask and the storage that `restrict`/`restrict_mut` borrowed
// from one `MaskedStorage`, so the two belong together.
unsafe impl<'rf, C, S> Join for &'rf mut RestrictedStorage<'rf, C, S>
where
    C: Component,
    S: BorrowMut<C::Storage>,
    C::Storage: SharedGetMutStorage<C>,
{
    type Mask = &'rf BitSet;
    type Type = PairedStorageWriteShared<'rf, C>;
    type Value = SharedGetOnly<'rf, C, C::Storage>;

    // Returns the bitset as the join mask and wraps the exclusively borrowed
    // storage in a `SharedGetOnly` handle.
    unsafe fn open(self) -> (Self::Mask, Self::Value) {
        let shared = SharedGetOnly::new(self.data.borrow_mut());
        (self.bitset, shared)
    }

    // Pairs `id` with a duplicate of the storage handle; no component is
    // accessed here. The accessors on `PairedStorageWriteShared` rely on the
    // caller of this method having passed a valid `id`.
    unsafe fn get(value: &mut Self::Value, id: Index) -> Self::Type {
        let storage = SharedGetOnly::duplicate(value);
        PairedStorageWriteShared { index: id, storage }
    }
}
// SAFETY: NOTE(review): the `ParJoin` safety contract is declared in
// `crate::join`, not in this file — confirm this impl against it. `get` only
// copies shared references out of `value`, and `C::Storage: Sync` is required
// by the where clause. `open` returns the mask and the storage that
// `restrict`/`restrict_mut` borrowed from one `MaskedStorage`.
#[cfg(feature = "parallel")]
unsafe impl<'rf, C, S> ParJoin for &'rf RestrictedStorage<'rf, C, S>
where
    C: Component,
    S: Borrow<C::Storage>,
    C::Storage: Sync,
{
    type Mask = &'rf BitSet;
    type Type = PairedStorageRead<'rf, C>;
    type Value = (&'rf C::Storage, &'rf Fetch<'rf, EntitiesRes>, &'rf BitSet);

    // Returns the bitset as the join mask, plus the (storage, entities,
    // bitset) tuple that `get` reads from.
    unsafe fn open(self) -> (Self::Mask, Self::Value) {
        let mask = self.bitset;
        let storage: &'rf C::Storage = self.data.borrow();
        (mask, (storage, self.entities, mask))
    }

    // Pairs `id` with the shared references held in `value`; no component is
    // accessed here. `PairedStorageRead::get` relies on the caller of this
    // method having passed a valid `id`.
    unsafe fn get(value: &Self::Value, id: Index) -> Self::Type {
        // All three elements are shared references, so the tuple is `Copy`.
        let (storage, entities, bitset) = *value;
        PairedStorageRead {
            index: id,
            storage,
            bitset,
            entities,
        }
    }
}
// SAFETY: NOTE(review): the `ParJoin` safety contract is declared in
// `crate::join`, not in this file — confirm this impl against it. The where
// clause requires `C::Storage: Sync + SharedGetMutStorage<C> +
// DistinctStorage`; `get` only duplicates the `SharedGetOnly` handle. `open`
// returns the mask and the storage that `restrict`/`restrict_mut` borrowed
// from one `MaskedStorage`.
#[cfg(feature = "parallel")]
unsafe impl<'rf, C, S> ParJoin for &'rf mut RestrictedStorage<'rf, C, S>
where
    C: Component,
    S: BorrowMut<C::Storage>,
    C::Storage: Sync + SharedGetMutStorage<C> + DistinctStorage,
{
    type Mask = &'rf BitSet;
    type Type = PairedStorageWriteShared<'rf, C>;
    type Value = SharedGetOnly<'rf, C, C::Storage>;

    // Returns the bitset as the join mask and wraps the exclusively borrowed
    // storage in a `SharedGetOnly` handle.
    unsafe fn open(self) -> (Self::Mask, Self::Value) {
        let shared = SharedGetOnly::new(self.data.borrow_mut());
        (self.bitset, shared)
    }

    // Pairs `id` with a duplicate of the storage handle; no component is
    // accessed here. The accessors on `PairedStorageWriteShared` rely on the
    // caller of this method having passed a valid `id`.
    unsafe fn get(value: &Self::Value, id: Index) -> Self::Type {
        let storage = SharedGetOnly::duplicate(value);
        PairedStorageWriteShared { index: id, storage }
    }
}
/// Read-only pairing of a component storage with one index.
///
/// Produced by the `LendJoin`/`Join`/`ParJoin` impls for
/// `&RestrictedStorage`; the component itself is only read when `get` or
/// `get_other` is called.
pub struct PairedStorageRead<'rf, C: Component> {
// Index handed to the join `get` that created this value.
index: Index,
// Shared reference to the component storage.
storage: &'rf C::Storage,
// Mask of the storage; checked by `get_other`.
bitset: &'rf BitSet,
// Entities resource; used by `get_other` for `is_alive`.
entities: &'rf Fetch<'rf, EntitiesRes>,
}
/// Pairing of a `SharedGetOnly` storage handle with one index, giving read
/// and write access to that single component.
///
/// Produced by the `Join`/`ParJoin` impls for `&mut RestrictedStorage`. It
/// holds neither the mask nor the entities resource, and offers no
/// `get_other*` accessors.
pub struct PairedStorageWriteShared<'rf, C: Component> {
// Index handed to the join `get` that created this value.
index: Index,
// Duplicate of the handle created in `open`.
storage: SharedGetOnly<'rf, C, C::Storage>,
}
// SAFETY: both fields (`Index` and the `SharedGetOnly` handle) are required
// to be `Send` by the where clause. NOTE(review): the extra
// `C::Storage: DistinctStorage` bound presumably keeps this type `!Send` for
// storages where calling `get_mut` from several threads would be unsound —
// confirm against the `DistinctStorage` documentation.
unsafe impl<C> Send for PairedStorageWriteShared<'_, C>
where
C: Component,
Index: Send,
for<'a> SharedGetOnly<'a, C, C::Storage>: Send,
C::Storage: DistinctStorage,
{
}
/// Intentionally empty private function; nothing in the visible source calls
/// it.
///
/// NOTE(review): presumably kept as an anchor item for documentation or
/// compile tests — confirm before removing.
fn _dummy() {
    // No-op by design.
}
/// Pairing of an exclusively borrowed component storage with one index.
///
/// Produced by the `LendJoin` impl for `&mut RestrictedStorage`. Because the
/// storage is held by `&mut`, this type can also offer `get_other` and
/// `get_other_mut` for components at other indices.
pub struct PairedStorageWriteExclusive<'rf, C: Component> {
// Index handed to the join `get` that created this value.
index: Index,
// Exclusive reference to the component storage.
storage: &'rf mut C::Storage,
// Mask of the storage; checked by `get_other` and `get_other_mut`.
bitset: &'rf BitSet,
// Entities resource; used by `get_other*` for `is_alive`.
entities: &'rf Fetch<'rf, EntitiesRes>,
}
impl<'rf, C> PairedStorageRead<'rf, C>
where
    C: Component,
{
    /// Gets the component at the index this value was created for.
    ///
    /// Unlike [`Self::get_other`], no mask or liveness check is performed.
    pub fn get(&self) -> &C {
        let id = self.index;
        // SAFETY: `id` is the index that was passed to the join `get` method
        // which built this value; this relies on that method's caller having
        // upheld its safety contract (declared in `crate::join`).
        unsafe { self.storage.get(id) }
    }

    /// Attempts to get the component of an arbitrary `entity`.
    ///
    /// Returns `None` when the entity's id is not set in the mask or the
    /// entity is not alive.
    pub fn get_other(&self, entity: Entity) -> Option<&C> {
        let id = entity.id();
        if !self.bitset.contains(id) || !self.entities.is_alive(entity) {
            return None;
        }
        // SAFETY: the mask was checked for `id` just above.
        Some(unsafe { self.storage.get(id) })
    }
}
impl<'rf, C> PairedStorageWriteShared<'rf, C>
where
    C: Component,
    C::Storage: SharedGetMutStorage<C>,
{
    /// Gets the component at the index this value was created for.
    pub fn get(&self) -> &C {
        let (handle, id) = (&self.storage, self.index);
        // SAFETY: `id` is the index that was passed to the join `get` method
        // which built this value; this relies on that method's caller having
        // upheld its safety contract. The returned reference borrows `self`,
        // so `get_mut` cannot be called on this value while it is alive.
        unsafe { SharedGetOnly::get(handle, id) }
    }

    /// Gets mutable access to the component at the index this value was
    /// created for.
    pub fn get_mut(&mut self) -> AccessMutReturn<'_, C> {
        let (handle, id) = (&self.storage, self.index);
        // SAFETY: `id` is the index that was passed to the join `get` method
        // which built this value; this relies on that method's caller having
        // upheld its safety contract. `&mut self` prevents any other access
        // obtained through this value from being alive at the same time.
        // NOTE(review): uniqueness across different values with the same
        // index depends on the join contract — confirm.
        unsafe { SharedGetOnly::get_mut(handle, id) }
    }
}
impl<'rf, C> PairedStorageWriteExclusive<'rf, C>
where
    C: Component,
{
    /// Gets the component at the index this value was created for.
    ///
    /// Unlike [`Self::get_other`], no mask or liveness check is performed.
    pub fn get(&self) -> &C {
        let id = self.index;
        // SAFETY: `id` is the index that was passed to `LendJoin::get` when
        // this value was built; this relies on that method's caller having
        // upheld its safety contract (declared in `crate::join`).
        unsafe { self.storage.get(id) }
    }

    /// Gets mutable access to the component at the index this value was
    /// created for.
    ///
    /// Unlike [`Self::get_other_mut`], no mask or liveness check is
    /// performed.
    pub fn get_mut(&mut self) -> AccessMutReturn<'_, C> {
        let id = self.index;
        // SAFETY: same reasoning as in `Self::get`; the storage is held by
        // exclusive reference and borrowed through `&mut self`.
        unsafe { self.storage.get_mut(id) }
    }

    /// Attempts to get the component of an arbitrary `entity`.
    ///
    /// Returns `None` when the entity's id is not set in the mask or the
    /// entity is not alive.
    pub fn get_other(&self, entity: Entity) -> Option<&C> {
        let id = entity.id();
        if !self.bitset.contains(id) || !self.entities.is_alive(entity) {
            return None;
        }
        // SAFETY: the mask was checked for `id` just above.
        Some(unsafe { self.storage.get(id) })
    }

    /// Attempts to get mutable access to the component of an arbitrary
    /// `entity`.
    ///
    /// Returns `None` when the entity's id is not set in the mask or the
    /// entity is not alive.
    pub fn get_other_mut(&mut self, entity: Entity) -> Option<AccessMutReturn<'_, C>> {
        let id = entity.id();
        if !self.bitset.contains(id) || !self.entities.is_alive(entity) {
            return None;
        }
        // SAFETY: the mask was checked for `id` just above, and the storage
        // is borrowed through `&mut self`.
        Some(unsafe { self.storage.get_mut(id) })
    }
}