Skip to main content

goggles/
masked.rs

1use std::mem;
2
3use hibitset::{BitIter, BitSet, BitSetLike};
4
5use crate::{
6    join::{Index, Join},
7    storage::{DenseStorage, RawStorage},
8    tracked::{ModifiedBitSet, TrackedStorage},
9};
10
11/// Wraps a `RawStorage` for some component with a `BitSet` mask to provide a safe, `Join`-able
12/// interface for component storage.
13pub struct MaskedStorage<S: RawStorage> {
14    mask: BitSet,
15    storage: S,
16}
17
18impl<S: RawStorage + Default> Default for MaskedStorage<S> {
19    fn default() -> Self {
20        Self {
21            mask: Default::default(),
22            storage: Default::default(),
23        }
24    }
25}
26
27impl<S: RawStorage> MaskedStorage<S> {
28    pub fn mask(&self) -> &BitSet {
29        &self.mask
30    }
31
32    pub fn raw_storage(&self) -> &S {
33        &self.storage
34    }
35
36    pub fn raw_storage_mut(&mut self) -> &mut S {
37        &mut self.storage
38    }
39
40    pub fn contains(&self, index: Index) -> bool {
41        self.mask.contains(index)
42    }
43
44    pub fn get(&self, index: Index) -> Option<&S::Item> {
45        if self.mask.contains(index) {
46            Some(unsafe { self.storage.get(index) })
47        } else {
48            None
49        }
50    }
51
52    pub fn get_mut(&mut self, index: Index) -> Option<&mut S::Item> {
53        if self.mask.contains(index) {
54            Some(unsafe { self.storage.get_mut(index) })
55        } else {
56            None
57        }
58    }
59
60    pub fn get_or_insert_with(
61        &mut self,
62        index: Index,
63        f: impl FnOnce() -> S::Item,
64    ) -> &mut S::Item {
65        if !self.mask.contains(index) {
66            self.mask.add(index);
67            unsafe { self.storage.insert(index, f()) };
68        }
69        unsafe { self.storage.get_mut(index) }
70    }
71
72    pub fn insert(&mut self, index: Index, mut v: S::Item) -> Option<S::Item> {
73        if self.mask.contains(index) {
74            mem::swap(&mut v, unsafe { self.storage.get_mut(index) });
75            Some(v)
76        } else {
77            self.mask.add(index);
78            unsafe { self.storage.insert(index, v) };
79            None
80        }
81    }
82
83    /// Update the value at this index only if it has changed.
84    ///
85    /// This is useful when combined with `FlaggedStorage`, which keeps track of modified
86    /// components.  By using this method, you can avoid flagging changes unnecessarily when the new
87    /// value of the component is equal to the old one.
88    pub fn update(&mut self, index: Index, mut v: S::Item) -> Option<S::Item>
89    where
90        S::Item: PartialEq,
91    {
92        if self.mask.contains(index) {
93            unsafe {
94                if &v != self.storage.get(index) {
95                    mem::swap(&mut v, self.storage.get_mut(index));
96                }
97            }
98            Some(v)
99        } else {
100            None
101        }
102    }
103
104    pub fn remove(&mut self, index: Index) -> Option<S::Item> {
105        if self.mask.remove(index) {
106            Some(unsafe { self.storage.remove(index) })
107        } else {
108            None
109        }
110    }
111
112    /// Returns an `IntoJoin` type whose values are `GuardedJoin` wrappers.
113    ///
114    /// A `GuardedJoin` wrapper does not automatically call `RawStorage::get_mut`, so it can be
115    /// useful to avoid flagging modifications with a `FlaggedStorage`.
116    pub fn guard(&mut self) -> GuardedJoin<S> {
117        GuardedJoin(self)
118    }
119}
120
121impl<S: DenseStorage> MaskedStorage<S> {
122    pub fn as_slice(&self) -> &[S::Item] {
123        self.storage.as_slice()
124    }
125
126    pub fn as_mut_slice(&mut self) -> &mut [S::Item] {
127        self.storage.as_mut_slice()
128    }
129}
130
131impl<S: TrackedStorage> MaskedStorage<S> {
132    pub fn tracking_modified(&self) -> bool {
133        self.storage.tracking_modified()
134    }
135
136    pub fn modified_indexes(&self) -> &ModifiedBitSet {
137        self.storage.modified_indexes()
138    }
139
140    pub fn set_track_modified(&mut self, flag: bool) {
141        self.storage.set_track_modified(flag);
142    }
143
144    pub fn mark_modified(&self, index: Index) {
145        self.storage.mark_modified(index);
146    }
147
148    pub fn clear_modified(&mut self) {
149        self.storage.clear_modified();
150    }
151
152    /// Returns an `IntoJoin` type which joins over all the modified elements.
153    ///
154    /// The items on the returned join are all `Option<&S::Item>`, removed elements will show up as
155    /// None.
156    pub fn modified(&self) -> ModifiedJoin<S> {
157        ModifiedJoin(self)
158    }
159
160    /// Returns an `IntoJoin` type which joins over all the modified elements mutably.
161    ///
162    /// This is similar to `MaskedStorage::modified`, but returns mutable access to each item.
163    pub fn modified_mut(&mut self) -> ModifiedJoinMut<S> {
164        ModifiedJoinMut(self)
165    }
166}
167
168impl<'a, S: RawStorage> Join for &'a MaskedStorage<S> {
169    type Item = &'a S::Item;
170    type Access = &'a S;
171    type Mask = &'a BitSet;
172
173    fn open(self) -> (Self::Mask, Self::Access) {
174        (&self.mask, &self.storage)
175    }
176
177    unsafe fn get(access: &Self::Access, index: Index) -> Self::Item {
178        access.get(index)
179    }
180}
181
182impl<'a, S: RawStorage> Join for &'a mut MaskedStorage<S> {
183    type Item = &'a mut S::Item;
184    type Access = &'a S;
185    type Mask = &'a BitSet;
186
187    fn open(self) -> (Self::Mask, Self::Access) {
188        (&self.mask, &self.storage)
189    }
190
191    unsafe fn get(access: &Self::Access, index: Index) -> Self::Item {
192        access.get_mut(index)
193    }
194}
195
/// Removes every value still recorded in the mask when the `MaskedStorage` is dropped.
///
/// Each index set in the mask is passed to `RawStorage::remove`, and the returned item is
/// dropped immediately.
impl<S: RawStorage> Drop for MaskedStorage<S> {
    fn drop(&mut self) {
        // Helper whose own `Drop` performs the removal loop.  Field 0 is the iterator over the
        // indexes that still have to be removed (`None` once there is nothing left to do),
        // field 1 is the storage to remove them from.
        struct DropGuard<'a, 'b, S: RawStorage>(Option<&'b mut BitIter<&'a BitSet>>, &'b mut S);

        impl<'a, 'b, S: RawStorage> Drop for DropGuard<'a, 'b, S> {
            fn drop(&mut self) {
                // Taking the iterator leaves `None` behind, so this body runs at most once per
                // guard.
                if let Some(iter) = self.0.take() {
                    // Re-wrap the same iterator and storage in a nested guard *before* removing
                    // anything.  If a removal below panics, unwinding drops `guard`, and its
                    // `drop` resumes with whatever indexes the iterator has not yielded yet.
                    let mut guard: DropGuard<S> = DropGuard(Some(&mut *iter), &mut *self.1);
                    while let Some(index) = guard.0.as_mut().unwrap().next() {
                        // The value returned by `remove` is dropped at the end of this
                        // statement.
                        unsafe { S::remove(&mut guard.1, index) };
                    }
                    // Everything was removed normally: disarm the nested guard so its `drop`
                    // has nothing to do.
                    guard.0 = None;
                }
            }
        }

        let mut iter = (&self.mask).iter();
        // The guard is a temporary that is dropped at the end of this statement, which is what
        // actually runs the removal loop above.
        DropGuard::<S>(Some(&mut iter), &mut self.storage);
    }
}
216
217pub struct GuardedJoin<'a, S: RawStorage>(&'a mut MaskedStorage<S>);
218
219impl<'a, S: RawStorage> Join for GuardedJoin<'a, S> {
220    type Item = ElementGuard<'a, S>;
221    type Access = &'a S;
222    type Mask = &'a BitSet;
223
224    fn open(self) -> (Self::Mask, Self::Access) {
225        (&self.0.mask, &self.0.storage)
226    }
227
228    unsafe fn get(access: &Self::Access, index: Index) -> Self::Item {
229        ElementGuard {
230            storage: *access,
231            index,
232        }
233    }
234}
235
236pub struct ElementGuard<'a, S> {
237    storage: &'a S,
238    index: Index,
239}
240
241impl<'a, S: RawStorage> ElementGuard<'a, S> {
242    pub fn get(&self) -> &'a S::Item {
243        unsafe { self.storage.get(self.index) }
244    }
245
246    pub fn get_mut(&mut self) -> &'a mut S::Item {
247        unsafe { self.storage.get_mut(self.index) }
248    }
249
250    pub fn update(&mut self, mut v: S::Item) -> S::Item
251    where
252        S::Item: PartialEq,
253    {
254        unsafe {
255            if &v != self.storage.get(self.index) {
256                mem::swap(&mut v, self.storage.get_mut(self.index));
257            }
258            v
259        }
260    }
261}
262
263impl<'a, S: TrackedStorage> ElementGuard<'a, S> {
264    pub fn mark_modified(&self) {
265        self.storage.mark_modified(self.index);
266    }
267}
268
269pub struct ModifiedJoin<'a, S: RawStorage>(&'a MaskedStorage<S>);
270
271impl<'a, S: TrackedStorage> Join for ModifiedJoin<'a, S> {
272    type Item = Option<&'a S::Item>;
273    type Access = (&'a BitSet, &'a S);
274    type Mask = &'a ModifiedBitSet;
275
276    fn open(self) -> (Self::Mask, Self::Access) {
277        (
278            &self.0.storage.modified_indexes(),
279            (&self.0.mask, &self.0.storage),
280        )
281    }
282
283    unsafe fn get((mask, storage): &Self::Access, index: Index) -> Self::Item {
284        if mask.contains(index) {
285            Some(storage.get(index))
286        } else {
287            None
288        }
289    }
290}
291
292pub struct ModifiedJoinMut<'a, S: RawStorage>(&'a mut MaskedStorage<S>);
293
294impl<'a, S: TrackedStorage> Join for ModifiedJoinMut<'a, S> {
295    type Item = Option<&'a mut S::Item>;
296    type Access = (&'a BitSet, &'a S);
297    type Mask = &'a ModifiedBitSet;
298
299    fn open(self) -> (Self::Mask, Self::Access) {
300        (
301            &self.0.storage.modified_indexes(),
302            (&self.0.mask, &self.0.storage),
303        )
304    }
305
306    unsafe fn get((mask, storage): &Self::Access, index: Index) -> Self::Item {
307        if mask.contains(index) {
308            Some(storage.get_mut(index))
309        } else {
310            None
311        }
312    }
313}