1use std::mem;
2
3use hibitset::{BitIter, BitSet, BitSetLike};
4
5use crate::{
6 join::{Index, Join},
7 storage::{DenseStorage, RawStorage},
8 tracked::{ModifiedBitSet, TrackedStorage},
9};
10
11pub struct MaskedStorage<S: RawStorage> {
14 mask: BitSet,
15 storage: S,
16}
17
18impl<S: RawStorage + Default> Default for MaskedStorage<S> {
19 fn default() -> Self {
20 Self {
21 mask: Default::default(),
22 storage: Default::default(),
23 }
24 }
25}
26
27impl<S: RawStorage> MaskedStorage<S> {
28 pub fn mask(&self) -> &BitSet {
29 &self.mask
30 }
31
32 pub fn raw_storage(&self) -> &S {
33 &self.storage
34 }
35
36 pub fn raw_storage_mut(&mut self) -> &mut S {
37 &mut self.storage
38 }
39
40 pub fn contains(&self, index: Index) -> bool {
41 self.mask.contains(index)
42 }
43
44 pub fn get(&self, index: Index) -> Option<&S::Item> {
45 if self.mask.contains(index) {
46 Some(unsafe { self.storage.get(index) })
47 } else {
48 None
49 }
50 }
51
52 pub fn get_mut(&mut self, index: Index) -> Option<&mut S::Item> {
53 if self.mask.contains(index) {
54 Some(unsafe { self.storage.get_mut(index) })
55 } else {
56 None
57 }
58 }
59
60 pub fn get_or_insert_with(
61 &mut self,
62 index: Index,
63 f: impl FnOnce() -> S::Item,
64 ) -> &mut S::Item {
65 if !self.mask.contains(index) {
66 self.mask.add(index);
67 unsafe { self.storage.insert(index, f()) };
68 }
69 unsafe { self.storage.get_mut(index) }
70 }
71
72 pub fn insert(&mut self, index: Index, mut v: S::Item) -> Option<S::Item> {
73 if self.mask.contains(index) {
74 mem::swap(&mut v, unsafe { self.storage.get_mut(index) });
75 Some(v)
76 } else {
77 self.mask.add(index);
78 unsafe { self.storage.insert(index, v) };
79 None
80 }
81 }
82
83 pub fn update(&mut self, index: Index, mut v: S::Item) -> Option<S::Item>
89 where
90 S::Item: PartialEq,
91 {
92 if self.mask.contains(index) {
93 unsafe {
94 if &v != self.storage.get(index) {
95 mem::swap(&mut v, self.storage.get_mut(index));
96 }
97 }
98 Some(v)
99 } else {
100 None
101 }
102 }
103
104 pub fn remove(&mut self, index: Index) -> Option<S::Item> {
105 if self.mask.remove(index) {
106 Some(unsafe { self.storage.remove(index) })
107 } else {
108 None
109 }
110 }
111
112 pub fn guard(&mut self) -> GuardedJoin<S> {
117 GuardedJoin(self)
118 }
119}
120
121impl<S: DenseStorage> MaskedStorage<S> {
122 pub fn as_slice(&self) -> &[S::Item] {
123 self.storage.as_slice()
124 }
125
126 pub fn as_mut_slice(&mut self) -> &mut [S::Item] {
127 self.storage.as_mut_slice()
128 }
129}
130
131impl<S: TrackedStorage> MaskedStorage<S> {
132 pub fn tracking_modified(&self) -> bool {
133 self.storage.tracking_modified()
134 }
135
136 pub fn modified_indexes(&self) -> &ModifiedBitSet {
137 self.storage.modified_indexes()
138 }
139
140 pub fn set_track_modified(&mut self, flag: bool) {
141 self.storage.set_track_modified(flag);
142 }
143
144 pub fn mark_modified(&self, index: Index) {
145 self.storage.mark_modified(index);
146 }
147
148 pub fn clear_modified(&mut self) {
149 self.storage.clear_modified();
150 }
151
152 pub fn modified(&self) -> ModifiedJoin<S> {
157 ModifiedJoin(self)
158 }
159
160 pub fn modified_mut(&mut self) -> ModifiedJoinMut<S> {
164 ModifiedJoinMut(self)
165 }
166}
167
168impl<'a, S: RawStorage> Join for &'a MaskedStorage<S> {
169 type Item = &'a S::Item;
170 type Access = &'a S;
171 type Mask = &'a BitSet;
172
173 fn open(self) -> (Self::Mask, Self::Access) {
174 (&self.mask, &self.storage)
175 }
176
177 unsafe fn get(access: &Self::Access, index: Index) -> Self::Item {
178 access.get(index)
179 }
180}
181
182impl<'a, S: RawStorage> Join for &'a mut MaskedStorage<S> {
183 type Item = &'a mut S::Item;
184 type Access = &'a S;
185 type Mask = &'a BitSet;
186
187 fn open(self) -> (Self::Mask, Self::Access) {
188 (&self.mask, &self.storage)
189 }
190
191 unsafe fn get(access: &Self::Access, index: Index) -> Self::Item {
192 access.get_mut(index)
193 }
194}
195
/// Drops every element still recorded in the mask.
///
/// Elements are removed from the raw storage one index at a time. If removing (or dropping)
/// one element panics, the remaining indexes are still removed while unwinding, so a single
/// panicking destructor does not leak the rest of the storage. A second panic during that
/// unwinding aborts the process, as any panic-while-panicking does in Rust.
impl<S: RawStorage> Drop for MaskedStorage<S> {
    fn drop(&mut self) {
        // Holds the not-yet-visited part of the mask iterator plus the storage to remove from.
        // `None` in the first field means there is nothing left to do (the guard is disarmed).
        struct DropGuard<'a, 'b, S: RawStorage>(Option<&'b mut BitIter<&'a BitSet>>, &'b mut S);

        impl<'a, 'b, S: RawStorage> Drop for DropGuard<'a, 'b, S> {
            fn drop(&mut self) {
                if let Some(iter) = self.0.take() {
                    // Re-arm a fresh guard over the same iterator. If a `remove` below panics,
                    // unwinding drops `guard`, whose own `drop` continues from the next index
                    // (the iterator has already advanced past the one that panicked).
                    let mut guard: DropGuard<S> = DropGuard(Some(&mut *iter), &mut *self.1);
                    while let Some(index) = guard.0.as_mut().unwrap().next() {
                        // The removed element is discarded, i.e. dropped, at the end of this
                        // statement. SAFETY: `index` comes from the mask, so the slot is occupied.
                        unsafe { S::remove(&mut guard.1, index) };
                    }
                    // Finished normally: disarm the guard so its `drop` does nothing.
                    guard.0 = None;
                }
            }
        }

        let mut iter = (&self.mask).iter();
        // The guard is a temporary, so it is dropped at the end of this statement, which is
        // what actually runs the removal loop above.
        DropGuard::<S>(Some(&mut iter), &mut self.storage);
    }
}
216
217pub struct GuardedJoin<'a, S: RawStorage>(&'a mut MaskedStorage<S>);
218
219impl<'a, S: RawStorage> Join for GuardedJoin<'a, S> {
220 type Item = ElementGuard<'a, S>;
221 type Access = &'a S;
222 type Mask = &'a BitSet;
223
224 fn open(self) -> (Self::Mask, Self::Access) {
225 (&self.0.mask, &self.0.storage)
226 }
227
228 unsafe fn get(access: &Self::Access, index: Index) -> Self::Item {
229 ElementGuard {
230 storage: *access,
231 index,
232 }
233 }
234}
235
236pub struct ElementGuard<'a, S> {
237 storage: &'a S,
238 index: Index,
239}
240
241impl<'a, S: RawStorage> ElementGuard<'a, S> {
242 pub fn get(&self) -> &'a S::Item {
243 unsafe { self.storage.get(self.index) }
244 }
245
246 pub fn get_mut(&mut self) -> &'a mut S::Item {
247 unsafe { self.storage.get_mut(self.index) }
248 }
249
250 pub fn update(&mut self, mut v: S::Item) -> S::Item
251 where
252 S::Item: PartialEq,
253 {
254 unsafe {
255 if &v != self.storage.get(self.index) {
256 mem::swap(&mut v, self.storage.get_mut(self.index));
257 }
258 v
259 }
260 }
261}
262
263impl<'a, S: TrackedStorage> ElementGuard<'a, S> {
264 pub fn mark_modified(&self) {
265 self.storage.mark_modified(self.index);
266 }
267}
268
269pub struct ModifiedJoin<'a, S: RawStorage>(&'a MaskedStorage<S>);
270
271impl<'a, S: TrackedStorage> Join for ModifiedJoin<'a, S> {
272 type Item = Option<&'a S::Item>;
273 type Access = (&'a BitSet, &'a S);
274 type Mask = &'a ModifiedBitSet;
275
276 fn open(self) -> (Self::Mask, Self::Access) {
277 (
278 &self.0.storage.modified_indexes(),
279 (&self.0.mask, &self.0.storage),
280 )
281 }
282
283 unsafe fn get((mask, storage): &Self::Access, index: Index) -> Self::Item {
284 if mask.contains(index) {
285 Some(storage.get(index))
286 } else {
287 None
288 }
289 }
290}
291
292pub struct ModifiedJoinMut<'a, S: RawStorage>(&'a mut MaskedStorage<S>);
293
294impl<'a, S: TrackedStorage> Join for ModifiedJoinMut<'a, S> {
295 type Item = Option<&'a mut S::Item>;
296 type Access = (&'a BitSet, &'a S);
297 type Mask = &'a ModifiedBitSet;
298
299 fn open(self) -> (Self::Mask, Self::Access) {
300 (
301 &self.0.storage.modified_indexes(),
302 (&self.0.mask, &self.0.storage),
303 )
304 }
305
306 unsafe fn get((mask, storage): &Self::Access, index: Index) -> Self::Item {
307 if mask.contains(index) {
308 Some(storage.get_mut(index))
309 } else {
310 None
311 }
312 }
313}