use hibitset::{BitIter, BitSetLike};
use shred::{Fetch, FetchMut, Read, ReadExpect, Resource, Write, WriteExpect};
use std::ops::{Deref, DerefMut};
use crate::world::Index;
mod bit_and;
mod lend_join;
mod maybe;
#[cfg(feature = "parallel")]
mod par_join;
pub use bit_and::BitAnd;
#[nougat::gat(Type)]
pub use lend_join::LendJoin;
pub use lend_join::{JoinLendIter, LendJoinType, RepeatableLendGet};
pub use maybe::MaybeJoin;
#[cfg(feature = "parallel")]
pub use par_join::{JoinParIter, ParJoin};
/// A value that can be "opened" into a bit mask plus some state, and then
/// queried index by index.
///
/// [`JoinIter`] drives this trait: it calls [`open`](Join::open) once, walks
/// the indices set in the returned mask, and calls [`get`](Join::get) for each
/// of them.
///
/// # Safety
///
/// The trait is `unsafe` to implement, but the precise contract is not spelled
/// out in this file.
/// NOTE(review): presumably implementors must guarantee that `get` is sound
/// for every index contained in the mask returned by `open` — confirm against
/// the crate documentation.
pub unsafe trait Join {
    /// Item produced for each index; this is the `Item` type of [`JoinIter`].
    type Type;
    /// State returned by [`open`](Join::open) and handed to [`get`](Join::get).
    type Value;
    /// Bit set returned by [`open`](Join::open); [`JoinIter`] iterates over it.
    type Mask: BitSetLike;

    /// Consumes `self` and wraps it in a [`JoinIter`].
    fn join(self) -> JoinIter<Self>
    where
        Self: Sized,
    {
        JoinIter::<Self>::new(self)
    }

    /// Splits `self` into its mask and its value state.
    ///
    /// # Safety
    ///
    /// NOTE(review): the caller-side obligations are not visible in this file;
    /// `JoinIter::new` keeps both halves private after calling this — confirm
    /// whether that is the required invariant.
    unsafe fn open(self) -> (Self::Mask, Self::Value);

    /// Produces the item associated with `id` from the opened `value`.
    ///
    /// # Safety
    ///
    /// NOTE(review): presumably `id` must be contained in the mask returned by
    /// `open`, as is the case for every call made by `JoinIter::next` —
    /// confirm, including whether repeated calls with one `id` are allowed.
    unsafe fn get(value: &mut Self::Value, id: Index) -> Self::Type;

    /// Reports whether this join is "unconstrained"; defaults to `false`.
    ///
    /// `JoinIter::new` logs a warning when this returns `true`.
    #[inline]
    fn is_unconstrained() -> bool {
        false
    }
}
/// Iterator over an opened [`Join`]: yields one `J::Type` per index set in the
/// join's mask.
#[must_use]
pub struct JoinIter<J>
where
    J: Join,
{
    /// Iterator over the indices set in the mask returned by `Join::open`.
    keys: BitIter<<J as Join>::Mask>,
    /// Value state returned by `Join::open`; passed to `Join::get` per index.
    values: <J as Join>::Value,
}
impl<J> JoinIter<J>
where
    J: Join,
{
    /// Opens `j` and builds an iterator over the indices of its mask.
    ///
    /// Emits a `log` warning first when `J::is_unconstrained()` is `true`.
    pub fn new(j: J) -> Self {
        if J::is_unconstrained() {
            log::warn!(
                "`Join` possibly iterating through all indices, \
                 you might've made a join with all `MaybeJoin`s, \
                 which is unbounded in length."
            );
        }

        // The mask and the values stay private to this iterator, so nothing
        // outside can replace either of them after `open`.
        let (mask, values) = unsafe { j.open() };
        Self {
            keys: mask.iter(),
            values,
        }
    }
}
impl<J> Iterator for JoinIter<J>
where
    J: Join,
{
    type Item = <J as Join>::Type;

    /// Advances the mask iterator and fetches the item for the next index;
    /// returns `None` once the mask is exhausted.
    fn next(&mut self) -> Option<Self::Item> {
        let idx = self.keys.next()?;
        // `idx` was just yielded by the iterator over the mask obtained from
        // `open`, and the key iterator advances once per `get` call.
        Some(unsafe { J::get(&mut self.values, idx) })
    }
}
/// Implements `LendJoin`, `RepeatableLendGet`, `Join` and (with the
/// `parallel` feature) `ParJoin` for the tuple whose element types are the
/// given identifiers.
///
/// Each identifier is used twice: as the generic type parameter of the tuple
/// element and as the local variable bound to that element, which is why the
/// generated functions carry `#[allow(non_snake_case)]`.
macro_rules! define_open {
    ($($from:ident),*) => {
        // Lending join over a tuple: every element is opened, the element
        // masks are combined through `BitAnd::and`, and the element values are
        // kept side by side in a tuple.
        #[nougat::gat]
        unsafe impl<$($from,)*> LendJoin for ($($from),*,)
        where
            $($from: LendJoin),*,
            ($(<$from as LendJoin>::Mask,)*): BitAnd,
        {
            type Type<'next> = ($(<$from as LendJoin>::Type<'next>),*,);
            type Value = ($($from::Value),*,);
            type Mask = <($($from::Mask,)*) as BitAnd>::Value;

            #[allow(non_snake_case)]
            unsafe fn open(self) -> (Self::Mask, Self::Value) {
                let ($($from,)*) = self;
                // Open the elements one after another, in tuple order; each
                // binding becomes that element's `(mask, value)` pair.
                $( let $from = unsafe { $from.open() }; )*
                let mask = ($($from.0,)*).and();
                let values = ($($from.1,)*);
                (mask, values)
            }

            #[allow(non_snake_case)]
            unsafe fn get<'next>(v: &'next mut Self::Value, i: Index) -> Self::Type<'next> {
                // Destructuring through the reference yields one mutable
                // borrow per element value.
                let ($($from,)*) = v;
                unsafe { ($($from::get($from, i),)*) }
            }

            #[inline]
            fn is_unconstrained() -> bool {
                // Unconstrained only if every element is; stop at the first
                // element that is not.
                $( if !$from::is_unconstrained() { return false; } )*
                true
            }
        }

        unsafe impl<$($from,)*> RepeatableLendGet for ($($from),*,)
        where
            $($from: RepeatableLendGet),*,
            ($(<$from as LendJoin>::Mask,)*): BitAnd,
        {
        }

        // Plain join over a tuple; mirrors the lending implementation above.
        unsafe impl<$($from,)*> Join for ($($from),*,)
        where
            $($from: Join),*,
            ($(<$from as Join>::Mask,)*): BitAnd,
        {
            type Type = ($($from::Type),*,);
            type Value = ($($from::Value),*,);
            type Mask = <($($from::Mask,)*) as BitAnd>::Value;

            #[allow(non_snake_case)]
            unsafe fn open(self) -> (Self::Mask, Self::Value) {
                let ($($from,)*) = self;
                // Open the elements one after another, in tuple order.
                $( let $from = unsafe { $from.open() }; )*
                let mask = ($($from.0,)*).and();
                let values = ($($from.1,)*);
                (mask, values)
            }

            #[allow(non_snake_case)]
            unsafe fn get(v: &mut Self::Value, i: Index) -> Self::Type {
                let ($($from,)*) = v;
                unsafe { ($($from::get($from, i),)*) }
            }

            #[inline]
            fn is_unconstrained() -> bool {
                $( if !$from::is_unconstrained() { return false; } )*
                true
            }
        }

        // Parallel join over a tuple; `get` takes the values by shared
        // reference here.
        #[cfg(feature = "parallel")]
        unsafe impl<$($from,)*> ParJoin for ($($from),*,)
        where
            $($from: ParJoin),*,
            ($(<$from as ParJoin>::Mask,)*): BitAnd,
        {
            type Type = ($($from::Type),*,);
            type Value = ($($from::Value),*,);
            type Mask = <($($from::Mask,)*) as BitAnd>::Value;

            #[allow(non_snake_case)]
            unsafe fn open(self) -> (Self::Mask, Self::Value) {
                let ($($from,)*) = self;
                // Open the elements one after another, in tuple order.
                $( let $from = unsafe { $from.open() }; )*
                let mask = ($($from.0,)*).and();
                let values = ($($from.1,)*);
                (mask, values)
            }

            #[allow(non_snake_case)]
            unsafe fn get(v: &Self::Value, i: Index) -> Self::Type {
                let ($($from,)*) = v;
                unsafe { ($($from::get($from, i),)*) }
            }

            #[inline]
            fn is_unconstrained() -> bool {
                $( if !$from::is_unconstrained() { return false; } )*
                true
            }
        }
    };
}
// Generate the tuple implementations for every arity from 1 through 18.
define_open!(A);
define_open!(A, B);
define_open!(A, B, C);
define_open!(A, B, C, D);
define_open!(A, B, C, D, E);
define_open!(A, B, C, D, E, F);
define_open!(A, B, C, D, E, F, G);
define_open!(A, B, C, D, E, F, G, H);
define_open!(A, B, C, D, E, F, G, H, I);
define_open!(A, B, C, D, E, F, G, H, I, J);
define_open!(A, B, C, D, E, F, G, H, I, J, K);
define_open!(A, B, C, D, E, F, G, H, I, J, K, L);
define_open!(A, B, C, D, E, F, G, H, I, J, K, L, M);
define_open!(A, B, C, D, E, F, G, H, I, J, K, L, M, N);
define_open!(A, B, C, D, E, F, G, H, I, J, K, L, M, N, O);
define_open!(A, B, C, D, E, F, G, H, I, J, K, L, M, N, O, P);
define_open!(A, B, C, D, E, F, G, H, I, J, K, L, M, N, O, P, Q);
define_open!(A, B, C, D, E, F, G, H, I, J, K, L, M, N, O, P, Q, R);
/// Implements the join traits for `&'a $ty`, where `$ty` is a wrapper that
/// dereferences to a resource `T`, by forwarding everything to the
/// implementation for `&'a T`.
///
/// The `$ty` arguments are expected to mention the `'b` lifetime and the `T`
/// type parameter that the generated `impl` headers declare.
macro_rules! immutable_resource_join {
    ($($ty:ty),*) => {
        $(
            // Lending join: dereference the wrapper, then defer to `&'a T`.
            #[nougat::gat]
            unsafe impl<'a, 'b, T> LendJoin for &'a $ty
            where
                &'a T: LendJoin,
                T: Resource,
            {
                type Type<'next> = <&'a T as LendJoin>::Type<'next>;
                type Value = <&'a T as LendJoin>::Value;
                type Mask = <&'a T as LendJoin>::Mask;

                unsafe fn open(self) -> (Self::Mask, Self::Value) {
                    let inner: &'a T = Deref::deref(self);
                    unsafe { <&'a T as LendJoin>::open(inner) }
                }

                unsafe fn get<'next>(
                    value: &'next mut Self::Value,
                    id: Index,
                ) -> Self::Type<'next> {
                    unsafe { <&'a T as LendJoin>::get(value, id) }
                }

                #[inline]
                fn is_unconstrained() -> bool {
                    <&'a T as LendJoin>::is_unconstrained()
                }
            }

            unsafe impl<'a, 'b, T> RepeatableLendGet for &'a $ty
            where
                &'a T: RepeatableLendGet,
                T: Resource,
            {
            }

            // Plain join: same forwarding as above.
            unsafe impl<'a, 'b, T> Join for &'a $ty
            where
                &'a T: Join,
                T: Resource,
            {
                type Type = <&'a T as Join>::Type;
                type Value = <&'a T as Join>::Value;
                type Mask = <&'a T as Join>::Mask;

                unsafe fn open(self) -> (Self::Mask, Self::Value) {
                    let inner: &'a T = Deref::deref(self);
                    unsafe { <&'a T as Join>::open(inner) }
                }

                unsafe fn get(value: &mut Self::Value, id: Index) -> Self::Type {
                    unsafe { <&'a T as Join>::get(value, id) }
                }

                #[inline]
                fn is_unconstrained() -> bool {
                    <&'a T as Join>::is_unconstrained()
                }
            }

            // Parallel join: same forwarding, only with the `parallel` feature.
            #[cfg(feature = "parallel")]
            unsafe impl<'a, 'b, T> ParJoin for &'a $ty
            where
                &'a T: ParJoin,
                T: Resource,
            {
                type Type = <&'a T as ParJoin>::Type;
                type Value = <&'a T as ParJoin>::Value;
                type Mask = <&'a T as ParJoin>::Mask;

                unsafe fn open(self) -> (Self::Mask, Self::Value) {
                    let inner: &'a T = Deref::deref(self);
                    unsafe { <&'a T as ParJoin>::open(inner) }
                }

                unsafe fn get(value: &Self::Value, id: Index) -> Self::Type {
                    unsafe { <&'a T as ParJoin>::get(value, id) }
                }

                #[inline]
                fn is_unconstrained() -> bool {
                    <&'a T as ParJoin>::is_unconstrained()
                }
            }
        )*
    };
}
/// Implements the join traits for `&'a mut $ty`, where `$ty` is a wrapper that
/// mutably dereferences to a resource `T`, by forwarding everything to the
/// implementation for `&'a mut T`.
///
/// The `$ty` arguments are expected to mention the `'b` lifetime and the `T`
/// type parameter that the generated `impl` headers declare.
macro_rules! mutable_resource_join {
    ($($ty:ty),*) => {
        $(
            // Lending join: mutably dereference the wrapper, then defer to
            // `&'a mut T`.
            #[nougat::gat]
            unsafe impl<'a, 'b, T> LendJoin for &'a mut $ty
            where
                &'a mut T: LendJoin,
                T: Resource,
            {
                type Type<'next> = <&'a mut T as LendJoin>::Type<'next>;
                type Value = <&'a mut T as LendJoin>::Value;
                type Mask = <&'a mut T as LendJoin>::Mask;

                unsafe fn open(self) -> (Self::Mask, Self::Value) {
                    let inner: &'a mut T = DerefMut::deref_mut(self);
                    unsafe { <&'a mut T as LendJoin>::open(inner) }
                }

                unsafe fn get<'next>(
                    value: &'next mut Self::Value,
                    id: Index,
                ) -> Self::Type<'next> {
                    unsafe { <&'a mut T as LendJoin>::get(value, id) }
                }

                #[inline]
                fn is_unconstrained() -> bool {
                    <&'a mut T as LendJoin>::is_unconstrained()
                }
            }

            unsafe impl<'a, 'b, T> RepeatableLendGet for &'a mut $ty
            where
                &'a mut T: RepeatableLendGet,
                T: Resource,
            {
            }

            // Plain join: same forwarding as above.
            unsafe impl<'a, 'b, T> Join for &'a mut $ty
            where
                &'a mut T: Join,
                T: Resource,
            {
                type Type = <&'a mut T as Join>::Type;
                type Value = <&'a mut T as Join>::Value;
                type Mask = <&'a mut T as Join>::Mask;

                unsafe fn open(self) -> (Self::Mask, Self::Value) {
                    let inner: &'a mut T = DerefMut::deref_mut(self);
                    unsafe { <&'a mut T as Join>::open(inner) }
                }

                unsafe fn get(value: &mut Self::Value, id: Index) -> Self::Type {
                    unsafe { <&'a mut T as Join>::get(value, id) }
                }

                #[inline]
                fn is_unconstrained() -> bool {
                    <&'a mut T as Join>::is_unconstrained()
                }
            }

            // Parallel join: same forwarding, only with the `parallel` feature.
            #[cfg(feature = "parallel")]
            unsafe impl<'a, 'b, T> ParJoin for &'a mut $ty
            where
                &'a mut T: ParJoin,
                T: Resource,
            {
                type Type = <&'a mut T as ParJoin>::Type;
                type Value = <&'a mut T as ParJoin>::Value;
                type Mask = <&'a mut T as ParJoin>::Mask;

                unsafe fn open(self) -> (Self::Mask, Self::Value) {
                    let inner: &'a mut T = DerefMut::deref_mut(self);
                    unsafe { <&'a mut T as ParJoin>::open(inner) }
                }

                unsafe fn get(value: &Self::Value, id: Index) -> Self::Type {
                    unsafe { <&'a mut T as ParJoin>::get(value, id) }
                }

                #[inline]
                fn is_unconstrained() -> bool {
                    <&'a mut T as ParJoin>::is_unconstrained()
                }
            }
        )*
    };
}
// Shared-reference joins for the read-only resource wrappers. `'b` and `T`
// name the generics declared by the `impl` headers inside the macro.
immutable_resource_join! {
    Fetch<'b, T>,
    Read<'b, T>,
    ReadExpect<'b, T>
}

// Mutable-reference joins for the writable resource wrappers.
mutable_resource_join! {
    FetchMut<'b, T>,
    Write<'b, T>,
    WriteExpect<'b, T>
}