latticeon 0.1.0

A math and ECS library designed to make academic results in animation, physics simulation, and AI easy to reproduce.
Documentation
//! Query: iterate over entities matching a component set.

use crate::ecs::archetype::{Archetype, ArchetypeId};
use crate::ecs::component::{Component, ComponentIdRegistry};
use crate::ecs::entity::Entity;
use crate::ecs::universe::Universe;

/// Describes which components a query requires.
///
/// The mask returned by [`QueryFilter::query_mask`] is compared against each
/// archetype with `archetype_id.contains_all(query_mask)` (a bitvector AND):
/// an archetype qualifies when its id has every bit of the mask set.
pub trait QueryFilter {
    /// Builds the required-component mask, resolving each component type to
    /// its id through `registry`.
    fn query_mask(registry: &ComponentIdRegistry) -> ArchetypeId;
}

/// The empty query requires nothing, so its mask is empty and every
/// archetype satisfies `contains_all` against it.
impl QueryFilter for () {
    fn query_mask(_: &ComponentIdRegistry) -> ArchetypeId {
        ArchetypeId::empty()
    }
}

/// Implements [`QueryFilter`] for tuples of component references.
///
/// Each parenthesised group of type parameters produces two impls, one for the
/// all-`&T` tuple and one for the all-`&mut T` tuple. Mutability does not
/// influence matching, so both build the same mask.
macro_rules! impl_query_filter {
    // Internal rule: fold the ids of the listed component types into a new mask.
    (@mask $registry:ident; $($T:ident),+) => {{
        let mut mask = ArchetypeId::empty();
        $( mask.set($registry.id_for::<$T>()); )+
        mask
    }};
    ($( ($($T:ident),+) ),+) => {
        $(
            #[allow(unused_parens)]
            impl<$($T: Component),+> QueryFilter for ($(&$T),+) {
                fn query_mask(registry: &ComponentIdRegistry) -> ArchetypeId {
                    impl_query_filter!(@mask registry; $($T),+)
                }
            }

            #[allow(unused_parens)]
            impl<$($T: Component),+> QueryFilter for ($(&mut $T),+) {
                fn query_mask(registry: &ComponentIdRegistry) -> ArchetypeId {
                    impl_query_filter!(@mask registry; $($T),+)
                }
            }
        )+
    };
}

// One group per arity, from 1 to 16 components; every group yields both the
// all-`&T` and the all-`&mut T` filter impls.
impl_query_filter!(
    (A),
    (A, B),
    (A, B, C),
    (A, B, C, D),
    (A, B, C, D, E),
    (A, B, C, D, E, F),
    (A, B, C, D, E, F, G),
    (A, B, C, D, E, F, G, H),
    (A, B, C, D, E, F, G, H, I),
    (A, B, C, D, E, F, G, H, I, J),
    (A, B, C, D, E, F, G, H, I, J, K),
    (A, B, C, D, E, F, G, H, I, J, K, L),
    (A, B, C, D, E, F, G, H, I, J, K, L, M),
    (A, B, C, D, E, F, G, H, I, J, K, L, M, N),
    (A, B, C, D, E, F, G, H, I, J, K, L, M, N, O),
    (A, B, C, D, E, F, G, H, I, J, K, L, M, N, O, P)
);

/// Mixed `(&A, &mut B)` query: it requires the same components as `(&A, &B)`,
/// so it reuses that mask unchanged.
impl<A: Component, B: Component> QueryFilter for (&A, &mut B) {
    fn query_mask(registry: &ComponentIdRegistry) -> ArchetypeId {
        <(&A, &B) as QueryFilter>::query_mask(registry)
    }
}

/// Mixed `(&mut A, &B)` query: it requires the same components as `(&A, &B)`,
/// so it reuses that mask unchanged.
impl<A: Component, B: Component> QueryFilter for (&mut A, &B) {
    fn query_mask(registry: &ComponentIdRegistry) -> ArchetypeId {
        <(&A, &B) as QueryFilter>::query_mask(registry)
    }
}

/// Reads one row of an archetype through shared borrows only.
pub trait QueryFetch {
    /// What one fetched row yields; it borrows from the archetype for `'a`.
    type Item<'a>;

    /// Fetches row `index` of `arch`, using `registry` to resolve component
    /// ids. Returns `None` when the row cannot be fetched.
    fn fetch<'a>(
        arch: &'a Archetype,
        index: usize,
        registry: &ComponentIdRegistry,
    ) -> Option<Self::Item<'a>>;
}

/// Implements [`QueryFetch`] for a tuple of shared component references.
///
/// The yielded item is the row's entity followed by one `&T` per component.
/// Fetching stops with `None` at the first missing piece (entity first, then
/// the components in declaration order).
macro_rules! impl_query_fetch {
    ($($T:ident),+) => {
        #[allow(unused_parens, non_snake_case)]
        impl<$($T: Component),+> QueryFetch for ($(&$T),+) {
            type Item<'a> = (Entity, $(&'a $T),+);

            fn fetch<'a>(
                arch: &'a Archetype,
                index: usize,
                registry: &ComponentIdRegistry,
            ) -> Option<Self::Item<'a>> {
                let entity = arch.entity_at(index)?;
                // Tuple fields evaluate left to right, so `?` short-circuits
                // on the first component that is absent.
                Some((entity, $(arch.get_comp::<$T>(index, registry)?),+))
            }
        }
    };
}

impl_query_fetch!(A);
impl_query_fetch!(A, B);
impl_query_fetch!(A, B, C);
impl_query_fetch!(A, B, C, D);
impl_query_fetch!(A, B, C, D, E);
impl_query_fetch!(A, B, C, D, E, F);
impl_query_fetch!(A, B, C, D, E, F, G);
impl_query_fetch!(A, B, C, D, E, F, G, H);
impl_query_fetch!(A, B, C, D, E, F, G, H, I);
impl_query_fetch!(A, B, C, D, E, F, G, H, I, J);
impl_query_fetch!(A, B, C, D, E, F, G, H, I, J, K);
impl_query_fetch!(A, B, C, D, E, F, G, H, I, J, K, L);
impl_query_fetch!(A, B, C, D, E, F, G, H, I, J, K, L, M);
impl_query_fetch!(A, B, C, D, E, F, G, H, I, J, K, L, M, N);
impl_query_fetch!(A, B, C, D, E, F, G, H, I, J, K, L, M, N, O);
impl_query_fetch!(A, B, C, D, E, F, G, H, I, J, K, L, M, N, O, P);

/// Reads one row of an archetype, possibly handing out mutable borrows; the
/// archetype must therefore be borrowed exclusively.
pub trait QueryFetchMut {
    /// What one fetched row yields; it borrows from the archetype for `'a`.
    type Item<'a>;

    /// Fetches row `index` of `arch`, using `registry` to resolve component
    /// ids. Returns `None` when the row cannot be fetched.
    fn fetch_mut<'a>(
        arch: &'a mut Archetype,
        index: usize,
        registry: &ComponentIdRegistry,
    ) -> Option<Self::Item<'a>>;
}

/// Implements [`QueryFetchMut`] for a tuple of `&mut T` component references.
///
/// Several `&'a mut` borrows are handed out from a single `&'a mut Archetype`
/// by going through a raw `*mut Archetype`. That is only sound while the
/// fetched components are disjoint. Before any reference is created, the
/// component spans are therefore checked pairwise for overlap.
///
/// # Panics
/// If two requested components occupy overlapping memory, as happens when the
/// same component type appears twice (e.g. `(&mut Pos, &mut Pos)`). Without
/// the check this would produce two aliasing `&mut`, which is undefined behavior.
macro_rules! impl_query_fetch_mut_all {
    ($($T:ident),+) => {
        #[allow(unused_parens, non_snake_case)]
        impl<$($T: Component),+> QueryFetchMut for ($(&mut $T),+) {
            type Item<'a> = (Entity, $(&'a mut $T),+);
            fn fetch_mut<'a>(
                arch: &'a mut Archetype, index: usize, registry: &ComponentIdRegistry,
            ) -> Option<Self::Item<'a>> {
                let entity = arch.entity_at(index)?;
                let arch_ptr = arch as *mut Archetype;
                // SAFETY: `arch_ptr` comes from the exclusive `arch` borrow, which is not
                // used directly again. NOTE(review): this relies on `get_comp_mut` for
                // one component not invalidating pointers previously obtained for
                // another; confirm against the Archetype storage layout.
                $(let $T = unsafe { (*arch_ptr).get_comp_mut::<$T>(index, registry)? as *mut $T };)+
                // (start address, size in bytes) of every fetched component.
                // Zero-sized components occupy no bytes and can never conflict.
                let spans = [$(($T as usize, ::std::mem::size_of::<$T>())),+];
                for (i, &(start_a, len_a)) in spans.iter().enumerate() {
                    for &(start_b, len_b) in &spans[i + 1..] {
                        let disjoint = len_a == 0
                            || len_b == 0
                            || start_a + len_a <= start_b
                            || start_b + len_b <= start_a;
                        assert!(
                            disjoint,
                            "QueryFetchMut: a query requests the same component mutably more than once"
                        );
                    }
                }
                // SAFETY: the pointers were just verified to be pairwise non-overlapping,
                // and they derive from `arch: &'a mut Archetype`.
                Some((entity, $(unsafe { &mut *$T }),+))
            }
        }
    };
}

impl_query_fetch_mut_all!(A);
impl_query_fetch_mut_all!(A, B);
impl_query_fetch_mut_all!(A, B, C);
impl_query_fetch_mut_all!(A, B, C, D);
impl_query_fetch_mut_all!(A, B, C, D, E);
impl_query_fetch_mut_all!(A, B, C, D, E, F);
impl_query_fetch_mut_all!(A, B, C, D, E, F, G);
impl_query_fetch_mut_all!(A, B, C, D, E, F, G, H);
impl_query_fetch_mut_all!(A, B, C, D, E, F, G, H, I);
impl_query_fetch_mut_all!(A, B, C, D, E, F, G, H, I, J);
impl_query_fetch_mut_all!(A, B, C, D, E, F, G, H, I, J, K);
impl_query_fetch_mut_all!(A, B, C, D, E, F, G, H, I, J, K, L);
impl_query_fetch_mut_all!(A, B, C, D, E, F, G, H, I, J, K, L, M);
impl_query_fetch_mut_all!(A, B, C, D, E, F, G, H, I, J, K, L, M, N);
impl_query_fetch_mut_all!(A, B, C, D, E, F, G, H, I, J, K, L, M, N, O);
impl_query_fetch_mut_all!(A, B, C, D, E, F, G, H, I, J, K, L, M, N, O, P);

// Mixed mutability impls for 2-tuple (the most common case).
impl<A: Component, B: Component> QueryFetchMut for (&A, &mut B) {
    type Item<'a> = (Entity, &'a A, &'a mut B);

    /// Fetches `&A` and `&mut B` for the entity at row `index`.
    ///
    /// # Panics
    /// If the two components occupy overlapping memory, such as when `A` and `B`
    /// name the same component type. Without the check this would hand out a
    /// `&mut` aliasing a `&`, which is undefined behavior.
    fn fetch_mut<'a>(arch: &'a mut Archetype, index: usize, registry: &ComponentIdRegistry) -> Option<Self::Item<'a>> {
        let entity = arch.entity_at(index)?;
        let arch_ptr = arch as *mut Archetype;
        // SAFETY: `arch_ptr` comes from the exclusive `arch` borrow, which is not used
        // directly again; both pointers are turned into references only after the
        // overlap check below.
        let a = unsafe { (*arch_ptr).get_comp::<A>(index, registry)? as *const A };
        let b = unsafe { (*arch_ptr).get_comp_mut::<B>(index, registry)? as *mut B };
        // Zero-sized components occupy no bytes and can never conflict.
        let (a_start, a_len) = (a as usize, std::mem::size_of::<A>());
        let (b_start, b_len) = (b as usize, std::mem::size_of::<B>());
        assert!(
            a_len == 0 || b_len == 0 || a_start + a_len <= b_start || b_start + b_len <= a_start,
            "QueryFetchMut: a component cannot be borrowed both shared and mutably in one query"
        );
        Some((entity, unsafe { &*a }, unsafe { &mut *b }))
    }
}

impl<A: Component, B: Component> QueryFetchMut for (&mut A, &B) {
    type Item<'a> = (Entity, &'a mut A, &'a B);

    /// Fetches `&mut A` and `&B` for the entity at row `index`.
    ///
    /// # Panics
    /// If the two components occupy overlapping memory, such as when `A` and `B`
    /// name the same component type. Without the check this would hand out a
    /// `&mut` aliasing a `&`, which is undefined behavior.
    fn fetch_mut<'a>(arch: &'a mut Archetype, index: usize, registry: &ComponentIdRegistry) -> Option<Self::Item<'a>> {
        let entity = arch.entity_at(index)?;
        let arch_ptr = arch as *mut Archetype;
        // SAFETY: `arch_ptr` comes from the exclusive `arch` borrow, which is not used
        // directly again; both pointers are turned into references only after the
        // overlap check below.
        let a = unsafe { (*arch_ptr).get_comp_mut::<A>(index, registry)? as *mut A };
        let b = unsafe { (*arch_ptr).get_comp::<B>(index, registry)? as *const B };
        // Zero-sized components occupy no bytes and can never conflict.
        let (a_start, a_len) = (a as usize, std::mem::size_of::<A>());
        let (b_start, b_len) = (b as usize, std::mem::size_of::<B>());
        assert!(
            a_len == 0 || b_len == 0 || a_start + a_len <= b_start || b_start + b_len <= a_start,
            "QueryFetchMut: a component cannot be borrowed both mutably and shared in one query"
        );
        Some((entity, unsafe { &mut *a }, unsafe { &*b }))
    }
}

/// Read-only iterator over the rows of every archetype that matches `Q`,
/// yielding `Q::Item<'q>` per row.
pub struct QueryIter<'q, Q: QueryFilter + QueryFetch> {
    /// Position in `archetypes` of the archetype currently being walked.
    arch_idx: usize,
    /// Next row to fetch inside the current archetype.
    row: usize,
    /// Matching archetypes, collected when the iterator is built.
    archetypes: Vec<&'q Archetype>,
    /// Registry used to resolve component ids while fetching.
    registry: &'q ComponentIdRegistry,
    _marker: std::marker::PhantomData<Q>,
}

impl<'q, Q: QueryFilter + QueryFetch> QueryIter<'q, Q> {
    /// Collects every archetype of `universe` whose id contains all components
    /// required by `Q`, and positions the iterator at the first row of the first one.
    pub fn new(universe: &'q Universe) -> Self {
        let registry = universe.registry();
        let required = Q::query_mask(registry);

        let mut archetypes = Vec::new();
        for archetype in universe.archetypes() {
            if archetype.id().contains_all(&required) {
                archetypes.push(archetype);
            }
        }

        Self {
            arch_idx: 0,
            row: 0,
            archetypes,
            registry,
            _marker: std::marker::PhantomData,
        }
    }
}

impl<'q, Q: QueryFilter + QueryFetch> Iterator for QueryIter<'q, Q> {
    type Item = Q::Item<'q>;

    fn next(&mut self) -> Option<Self::Item> {
        loop {
            let arch = self.archetypes.get(self.arch_idx)?;
            if self.row < arch.len() {
                let item = Q::fetch(arch, self.row, self.registry);
                self.row += 1;
                return item;
            }
            self.arch_idx += 1;
            self.row = 0;
        }
    }
}

/// Mutable query over every archetype that matches `Q`; drive it with `for_each`.
pub struct QueryIterMut<'q, Q: QueryFilter + QueryFetchMut> {
    /// Position in `arch_ids` of the archetype currently being walked.
    arch_idx: usize,
    /// Next row to visit inside that archetype.
    row: usize,
    /// Ids of the matching archetypes, captured when the query is built.
    arch_ids: Vec<ArchetypeId>,
    /// Universe borrowed exclusively for as long as the query lives.
    universe: &'q mut Universe,
    _marker: std::marker::PhantomData<Q>,
}

impl<'q, Q: QueryFilter + QueryFetchMut> QueryIterMut<'q, Q> {
    /// Records the ids of every archetype whose id contains all components
    /// required by `Q`. Rows are resolved lazily by `for_each`.
    pub fn new(universe: &'q mut Universe) -> Self {
        let required = Q::query_mask(universe.registry());

        let mut arch_ids = Vec::<ArchetypeId>::new();
        for archetype in universe.archetypes() {
            let id = archetype.id();
            if id.contains_all(&required) {
                arch_ids.push(id.clone());
            }
        }

        Self {
            arch_idx: 0,
            row: 0,
            arch_ids,
            universe,
            _marker: std::marker::PhantomData,
        }
    }

    /// Calls `f(universe, entity)` once for each row of each matching archetype.
    ///
    /// Only `Q`'s component mask is used here. `f` receives the whole universe
    /// mutably and accesses components through it. The archetype is looked up
    /// again, and its length re-read, before every row, so changes `f` makes
    /// are observed as iteration proceeds.
    ///
    /// NOTE(review): if `f` removes entities or moves them between archetypes,
    /// rows may be skipped or visited twice, depending on how `Universe`
    /// reorders rows. Confirm this before relying on structural changes inside `f`.
    ///
    /// # Panics
    /// If an archetype reports a row below its `len()` for which `entity_at`
    /// returns `None`.
    pub fn for_each<F>(mut self, mut f: F)
    where
        F: FnMut(&mut Universe, Entity),
    {
        while let Some(arch_id) = self.arch_ids.get(self.arch_idx) {
            let row = self.row;
            let next_entity = self
                .universe
                .get_archetype_mut(arch_id)
                .filter(|arch| row < arch.len())
                .map(|arch| arch.entity_at(row).unwrap());

            match next_entity {
                Some(entity) => {
                    self.row = row + 1;
                    f(&mut *self.universe, entity);
                }
                None => {
                    // Archetype missing or exhausted: continue with the next one.
                    self.arch_idx += 1;
                    self.row = 0;
                }
            }
        }
    }
}

/// Query entry points on [`Universe`].
impl Universe {
    /// Iterates over every entity whose archetype contains all components in
    /// `Q`, yielding `Q::Item` per row (see [`QueryIter`]).
    pub fn query<Q: QueryFilter + QueryFetch>(&self) -> QueryIter<'_, Q> {
        QueryIter::<'_, Q>::new(self)
    }

    /// Builds a [`QueryIterMut`] over every entity whose archetype contains all
    /// components in `Q`; drive it with `QueryIterMut::for_each`.
    pub fn query_mut<Q: QueryFilter + QueryFetchMut>(&mut self) -> QueryIterMut<'_, Q> {
        QueryIterMut::<'_, Q>::new(self)
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::ecs::component::Component;

    #[derive(Debug, Clone, PartialEq)]
    struct Pos(i32, i32);
    impl Component for Pos {}

    #[derive(Debug, Clone, PartialEq)]
    struct Vel(i32, i32);
    impl Component for Vel {}

    /// Both spawned entities carry `Pos` and `Vel`, so the shared query sees both.
    #[test]
    fn query_iter() {
        let mut universe = Universe::new();
        for (pos, vel) in [(Pos(1, 2), Vel(3, 4)), (Pos(5, 6), Vel(7, 8))] {
            universe.spawn((pos, vel));
        }
        let matched = universe.query::<(&Pos, &Vel)>().count();
        assert_eq!(matched, 2);
    }

    /// A mutation made through the universe inside `for_each` must persist.
    #[test]
    fn query_mut_for_each() {
        let mut universe = Universe::new();
        universe.spawn((Pos(1, 2), Vel(3, 4)));

        universe
            .query_mut::<(&Pos, &mut Vel)>()
            .for_each(|world, entity| {
                if let Some(vel) = world.get_mut::<Vel>(entity) {
                    vel.0 += 1;
                }
            });

        let first_vel_x = universe
            .query::<(&Pos, &Vel)>()
            .next()
            .map(|(_, _, vel)| vel.0);
        assert_eq!(first_vel_x, Some(4));
    }
}