Skip to main content

fyrox_core/pool/
multiborrow.rs

1// Copyright (c) 2019-present Dmitry Stepanov and Fyrox Engine contributors.
2//
3// Permission is hereby granted, free of charge, to any person obtaining a copy
4// of this software and associated documentation files (the "Software"), to deal
5// in the Software without restriction, including without limitation the rights
6// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
7// copies of the Software, and to permit persons to whom the Software is
8// furnished to do so, subject to the following conditions:
9//
10// The above copyright notice and this permission notice shall be included in all
11// copies or substantial portions of the Software.
12//
13// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
14// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
15// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
16// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
17// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
18// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
19// SOFTWARE.
20
21use super::{Handle, ObjectOrVariant, PayloadContainer, Pool, PoolError, RefCounter};
22use crate::ComponentProvider;
23use std::{
24    any::TypeId,
25    cell::RefCell,
26    cmp::Ordering,
27    fmt::{Debug, Formatter},
28    marker::PhantomData,
29    ops::{Deref, DerefMut},
30};
31
32pub struct Ref<'a, 'b, T>
33where
34    T: ?Sized,
35{
36    data: &'a T,
37    ref_counter: &'a RefCounter,
38    phantom: PhantomData<&'b ()>,
39}
40
41impl<T> Debug for Ref<'_, '_, T>
42where
43    T: ?Sized + Debug,
44{
45    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
46        Debug::fmt(&self.data, f)
47    }
48}
49
50impl<T> Deref for Ref<'_, '_, T>
51where
52    T: ?Sized,
53{
54    type Target = T;
55
56    fn deref(&self) -> &Self::Target {
57        self.data
58    }
59}
60
61impl<T> Drop for Ref<'_, '_, T>
62where
63    T: ?Sized,
64{
65    fn drop(&mut self) {
66        // SAFETY: This is safe, because this ref lifetime is managed by the borrow checker,
67        // so it cannot outlive the pool record.
68        unsafe {
69            self.ref_counter.decrement();
70        }
71    }
72}
73
74pub struct RefMut<'a, 'b, T>
75where
76    T: ?Sized,
77{
78    data: &'a mut T,
79    ref_counter: &'a RefCounter,
80    phantom: PhantomData<&'b ()>,
81}
82
83impl<T> Debug for RefMut<'_, '_, T>
84where
85    T: ?Sized + Debug,
86{
87    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
88        Debug::fmt(&self.data, f)
89    }
90}
91
92impl<T> Deref for RefMut<'_, '_, T>
93where
94    T: ?Sized,
95{
96    type Target = T;
97
98    fn deref(&self) -> &Self::Target {
99        self.data
100    }
101}
102
103impl<T> DerefMut for RefMut<'_, '_, T>
104where
105    T: ?Sized,
106{
107    fn deref_mut(&mut self) -> &mut Self::Target {
108        self.data
109    }
110}
111
112impl<T> Drop for RefMut<'_, '_, T>
113where
114    T: ?Sized,
115{
116    fn drop(&mut self) {
117        // SAFETY: This is safe, because this ref lifetime is managed by the borrow checker,
118        // so it cannot outlive the pool record.
119        unsafe {
120            self.ref_counter.increment();
121        }
122    }
123}
124
125/// Multi-borrow context allows you to get as many **unique** references to elements in
126/// a pool as you want.
127pub struct MultiBorrowContext<'a, T, P = Option<T>>
128where
129    T: Sized,
130    P: PayloadContainer<Element = T> + 'static,
131{
132    pool: &'a mut Pool<T, P>,
133    free_indices: RefCell<Vec<u32>>,
134}
135
136impl<T, P> Drop for MultiBorrowContext<'_, T, P>
137where
138    T: Sized,
139    P: PayloadContainer<Element = T> + 'static,
140{
141    fn drop(&mut self) {
142        self.pool
143            .free_stack
144            .extend_from_slice(&self.free_indices.borrow())
145    }
146}
147
148impl<'a, T, P> MultiBorrowContext<'a, T, P>
149where
150    T: Sized,
151    P: PayloadContainer<Element = T> + 'static,
152{
153    #[inline]
154    pub fn new(pool: &'a mut Pool<T, P>) -> Self {
155        Self {
156            pool,
157            free_indices: Default::default(),
158        }
159    }
160
161    #[inline]
162    fn try_get_internal<'b: 'a, C, F>(
163        &'b self,
164        handle: Handle<T>,
165        func: F,
166    ) -> Result<Ref<'a, 'b, C>, PoolError>
167    where
168        C: ?Sized,
169        F: FnOnce(&T) -> Result<&C, PoolError>,
170    {
171        let record = self.pool.records_get(handle.index)?;
172
173        if handle.generation != record.generation {
174            return Err(PoolError::InvalidGeneration(handle.generation));
175        }
176
177        let current_ref_count = unsafe { record.ref_counter.get() };
178        if current_ref_count < 0 {
179            return Err(PoolError::MutablyBorrowed(handle.into()));
180        }
181
182        // SAFETY: We've enforced borrowing rules by the previous check.
183        let payload_container = unsafe { &*record.payload.0.get() };
184
185        let Some(payload) = payload_container.as_ref() else {
186            return Err(PoolError::Empty(handle.into()));
187        };
188
189        unsafe {
190            record.ref_counter.increment();
191        }
192
193        Ok(Ref {
194            data: func(payload)?,
195            ref_counter: &record.ref_counter,
196            phantom: PhantomData,
197        })
198    }
199
200    /// Tries to get a mutable reference to a pool element located at the given handle. The method could
201    /// fail in the two main reasons:
202    ///
203    /// 1) A reference to an element is already taken - returning multiple mutable references to the
204    /// same element is forbidden by Rust safety rules.
205    /// 2) A given handle is invalid.
206    #[inline]
207    pub fn try_get<'b, U>(&'b self, handle: Handle<U>) -> Result<Ref<'a, 'b, U>, PoolError>
208    where
209        'b: 'a,
210        U: ObjectOrVariant<T>,
211    {
212        self.try_get_internal(handle.to_base(), |obj| {
213            U::convert_to_dest_type(obj).ok_or(PoolError::InvalidType(handle.into()))
214        })
215    }
216
217    #[inline]
218    pub fn get<'b, U>(&'b self, handle: Handle<U>) -> Ref<'a, 'b, U>
219    where
220        'b: 'a,
221        U: ObjectOrVariant<T>,
222    {
223        self.try_get(handle).unwrap()
224    }
225
226    #[inline]
227    fn try_get_mut_internal<'b: 'a, C, F>(
228        &'b self,
229        handle: Handle<T>,
230        func: F,
231    ) -> Result<RefMut<'a, 'b, C>, PoolError>
232    where
233        C: ?Sized,
234        F: FnOnce(&mut T) -> Result<&mut C, PoolError>,
235    {
236        let record = self.pool.records_get(handle.index)?;
237
238        if handle.generation != record.generation {
239            return Err(PoolError::InvalidGeneration(handle.generation));
240        }
241
242        // SAFETY: It is safe to access the counter because of borrow checker guarantees that
243        // the record is alive.
244        let current_ref_count = unsafe { record.ref_counter.get() };
245        match current_ref_count.cmp(&0) {
246            Ordering::Less => {
247                return Err(PoolError::MutablyBorrowed(handle.into()));
248            }
249            Ordering::Greater => {
250                return Err(PoolError::ImmutablyBorrowed(handle.into()));
251            }
252            _ => (),
253        }
254
255        // SAFETY: We've enforced borrowing rules by the previous check.
256        let payload_container = unsafe { &mut *record.payload.0.get() };
257
258        let Some(payload) = payload_container.as_mut() else {
259            return Err(PoolError::Empty(handle.into()));
260        };
261
262        // SAFETY: It is safe to access the counter because of borrow checker guarantees that
263        // the record is alive.
264        unsafe {
265            record.ref_counter.decrement();
266        }
267
268        Ok(RefMut {
269            data: func(payload)?,
270            ref_counter: &record.ref_counter,
271            phantom: PhantomData,
272        })
273    }
274
275    #[inline]
276    pub fn try_get_mut<'b, U>(&'b self, handle: Handle<U>) -> Result<RefMut<'a, 'b, U>, PoolError>
277    where
278        'b: 'a,
279        U: ObjectOrVariant<T>,
280    {
281        self.try_get_mut_internal(handle.to_base(), |obj| {
282            U::convert_to_dest_type_mut(obj).ok_or(PoolError::InvalidType(handle.into()))
283        })
284    }
285
286    #[inline]
287    pub fn get_mut<'b, U>(&'b self, handle: Handle<U>) -> RefMut<'a, 'b, U>
288    where
289        'b: 'a,
290        U: ObjectOrVariant<T>,
291    {
292        self.try_get_mut(handle).unwrap()
293    }
294
295    #[inline]
296    pub fn free(&self, handle: Handle<T>) -> Result<T, PoolError> {
297        let record = self.pool.records_get(handle.index)?;
298
299        if handle.generation != record.generation {
300            return Err(PoolError::InvalidGeneration(handle.generation));
301        }
302
303        // The record must be non-borrowed to be freed.
304        // SAFETY: It is safe to access the counter because of borrow checker guarantees that
305        // the record is alive.
306        let current_ref_count = unsafe { record.ref_counter.get() };
307        match current_ref_count.cmp(&0) {
308            Ordering::Less => {
309                return Err(PoolError::MutablyBorrowed(handle.into()));
310            }
311            Ordering::Greater => {
312                return Err(PoolError::ImmutablyBorrowed(handle.into()));
313            }
314            _ => (),
315        }
316
317        // SAFETY: We've enforced borrowing rules by the previous check.
318        let payload_container = unsafe { &mut *record.payload.0.get() };
319
320        let Some(payload) = payload_container.take() else {
321            return Err(PoolError::Empty(handle.into()));
322        };
323
324        self.free_indices.borrow_mut().push(handle.index);
325
326        Ok(payload)
327    }
328}
329
330impl<'a, T, P> MultiBorrowContext<'a, T, P>
331where
332    T: Sized + ComponentProvider,
333    P: PayloadContainer<Element = T> + 'static,
334{
335    /// Tries to mutably borrow an object and fetch its component of specified type.
336    #[inline]
337    pub fn try_get_component_of_type<'b: 'a, C>(
338        &'b self,
339        handle: Handle<T>,
340    ) -> Result<Ref<'a, 'b, C>, PoolError>
341    where
342        C: 'static,
343    {
344        self.try_get_internal(handle, move |obj| {
345            obj.query_component_ref(TypeId::of::<C>())
346                .and_then(|c| c.downcast_ref())
347                .ok_or(PoolError::NoSuchComponent(handle.into()))
348        })
349    }
350
351    /// Tries to mutably borrow an object and fetch its component of specified type.
352    #[inline]
353    pub fn try_get_component_of_type_mut<'b: 'a, C>(
354        &'b self,
355        handle: Handle<T>,
356    ) -> Result<RefMut<'a, 'b, C>, PoolError>
357    where
358        C: 'static,
359    {
360        self.try_get_mut_internal(handle, move |obj| {
361            obj.query_component_mut(TypeId::of::<C>())
362                .and_then(|c| c.downcast_mut())
363                .ok_or(PoolError::NoSuchComponent(handle.into()))
364        })
365    }
366}
367
#[cfg(test)]
mod test {
    use super::PoolError;
    use crate::pool::Pool;

    #[derive(PartialEq, Clone, Copy, Debug)]
    struct MyPayload(u32);

    #[test]
    fn test_multi_borrow_context() {
        let mut pool = Pool::<MyPayload>::new();

        let mut val_a = MyPayload(123);
        let mut val_b = MyPayload(321);
        let mut val_c = MyPayload(42);
        let val_d = MyPayload(666);

        let a = pool.spawn(val_a);
        let b = pool.spawn(val_b);
        let c = pool.spawn(val_c);
        let d = pool.spawn(val_d);

        // `d` is freed before the context is created, so its slot is empty.
        pool.free(d);

        let ctx = pool.begin_multi_borrow();

        // Borrowing an empty slot fails both immutably and mutably.
        {
            assert_eq!(
                ctx.try_get(d).as_deref(),
                Err(PoolError::Empty(d.into())).as_ref()
            );
            assert_eq!(
                ctx.try_get_mut(d).as_deref_mut(),
                Err(PoolError::Empty(d.into())).as_mut()
            );
        }

        // Several shared borrows of the same element may coexist.
        {
            let ref_a_1 = ctx.try_get(a);
            let ref_a_2 = ctx.try_get(a);
            assert_eq!(ref_a_1.as_deref(), Ok(&val_a));
            assert_eq!(ref_a_2.as_deref(), Ok(&val_a));
        }

        // Each shared borrow bumps the counter; a mutable borrow is refused while any shared
        // borrow is alive and succeeds (counter becomes -1) once all of them are dropped.
        {
            let ref_a_1 = ctx.try_get(a);
            assert_eq!(unsafe { ref_a_1.as_ref().unwrap().ref_counter.get() }, 1);
            let ref_a_2 = ctx.try_get(a);
            assert_eq!(unsafe { ref_a_2.as_ref().unwrap().ref_counter.get() }, 2);

            assert_eq!(ref_a_1.as_deref(), Ok(&val_a));
            assert_eq!(ref_a_2.as_deref(), Ok(&val_a));
            assert_eq!(
                ctx.try_get_mut(a).as_deref(),
                Err(PoolError::ImmutablyBorrowed(a.into())).as_ref()
            );

            drop(ref_a_1);
            drop(ref_a_2);

            let mut mut_ref_a_1 = ctx.try_get_mut(a);
            assert_eq!(mut_ref_a_1.as_deref_mut(), Ok(&mut val_a));

            assert_eq!(
                unsafe { mut_ref_a_1.as_ref().unwrap().ref_counter.get() },
                -1
            );
        }

        // Shared borrows of one element coexist with mutable borrows of other elements.
        {
            // Borrow two immutable refs to the same element.
            let ref_a_1 = ctx.try_get(a);
            let ref_a_2 = ctx.try_get(a);
            assert_eq!(ref_a_1.as_deref(), Ok(&val_a));
            assert_eq!(ref_a_2.as_deref(), Ok(&val_a));

            // Mutably borrow another element twice: only the first mutable borrow succeeds.
            let mut ref_b_1 = ctx.try_get_mut(b);
            let mut ref_b_2 = ctx.try_get_mut(b);
            assert_eq!(ref_b_1.as_deref_mut(), Ok(&mut val_b));
            assert_eq!(
                ref_b_2.as_deref_mut(),
                Err(PoolError::MutablyBorrowed(b.into())).as_mut()
            );

            // Same for a third element.
            let mut ref_c_1 = ctx.try_get_mut(c);
            let mut ref_c_2 = ctx.try_get_mut(c);
            assert_eq!(ref_c_1.as_deref_mut(), Ok(&mut val_c));
            assert_eq!(
                ref_c_2.as_deref_mut(),
                Err(PoolError::MutablyBorrowed(c.into())).as_mut()
            );
        }
    }
}