1use super::{Handle, ObjectOrVariant, PayloadContainer, Pool, PoolError, RefCounter};
22use crate::ComponentProvider;
23use std::{
24 any::TypeId,
25 cell::RefCell,
26 cmp::Ordering,
27 fmt::{Debug, Formatter},
28 marker::PhantomData,
29 ops::{Deref, DerefMut},
30};
31
32pub struct Ref<'a, 'b, T>
33where
34 T: ?Sized,
35{
36 data: &'a T,
37 ref_counter: &'a RefCounter,
38 phantom: PhantomData<&'b ()>,
39}
40
41impl<T> Debug for Ref<'_, '_, T>
42where
43 T: ?Sized + Debug,
44{
45 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
46 Debug::fmt(&self.data, f)
47 }
48}
49
50impl<T> Deref for Ref<'_, '_, T>
51where
52 T: ?Sized,
53{
54 type Target = T;
55
56 fn deref(&self) -> &Self::Target {
57 self.data
58 }
59}
60
61impl<T> Drop for Ref<'_, '_, T>
62where
63 T: ?Sized,
64{
65 fn drop(&mut self) {
66 unsafe {
69 self.ref_counter.decrement();
70 }
71 }
72}
73
74pub struct RefMut<'a, 'b, T>
75where
76 T: ?Sized,
77{
78 data: &'a mut T,
79 ref_counter: &'a RefCounter,
80 phantom: PhantomData<&'b ()>,
81}
82
83impl<T> Debug for RefMut<'_, '_, T>
84where
85 T: ?Sized + Debug,
86{
87 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
88 Debug::fmt(&self.data, f)
89 }
90}
91
92impl<T> Deref for RefMut<'_, '_, T>
93where
94 T: ?Sized,
95{
96 type Target = T;
97
98 fn deref(&self) -> &Self::Target {
99 self.data
100 }
101}
102
103impl<T> DerefMut for RefMut<'_, '_, T>
104where
105 T: ?Sized,
106{
107 fn deref_mut(&mut self) -> &mut Self::Target {
108 self.data
109 }
110}
111
112impl<T> Drop for RefMut<'_, '_, T>
113where
114 T: ?Sized,
115{
116 fn drop(&mut self) {
117 unsafe {
120 self.ref_counter.increment();
121 }
122 }
123}
124
125pub struct MultiBorrowContext<'a, T, P = Option<T>>
128where
129 T: Sized,
130 P: PayloadContainer<Element = T> + 'static,
131{
132 pool: &'a mut Pool<T, P>,
133 free_indices: RefCell<Vec<u32>>,
134}
135
136impl<T, P> Drop for MultiBorrowContext<'_, T, P>
137where
138 T: Sized,
139 P: PayloadContainer<Element = T> + 'static,
140{
141 fn drop(&mut self) {
142 self.pool
143 .free_stack
144 .extend_from_slice(&self.free_indices.borrow())
145 }
146}
147
148impl<'a, T, P> MultiBorrowContext<'a, T, P>
149where
150 T: Sized,
151 P: PayloadContainer<Element = T> + 'static,
152{
153 #[inline]
154 pub fn new(pool: &'a mut Pool<T, P>) -> Self {
155 Self {
156 pool,
157 free_indices: Default::default(),
158 }
159 }
160
161 #[inline]
162 fn try_get_internal<'b: 'a, C, F>(
163 &'b self,
164 handle: Handle<T>,
165 func: F,
166 ) -> Result<Ref<'a, 'b, C>, PoolError>
167 where
168 C: ?Sized,
169 F: FnOnce(&T) -> Result<&C, PoolError>,
170 {
171 let record = self.pool.records_get(handle.index)?;
172
173 if handle.generation != record.generation {
174 return Err(PoolError::InvalidGeneration(handle.generation));
175 }
176
177 let current_ref_count = unsafe { record.ref_counter.get() };
178 if current_ref_count < 0 {
179 return Err(PoolError::MutablyBorrowed(handle.into()));
180 }
181
182 let payload_container = unsafe { &*record.payload.0.get() };
184
185 let Some(payload) = payload_container.as_ref() else {
186 return Err(PoolError::Empty(handle.into()));
187 };
188
189 unsafe {
190 record.ref_counter.increment();
191 }
192
193 Ok(Ref {
194 data: func(payload)?,
195 ref_counter: &record.ref_counter,
196 phantom: PhantomData,
197 })
198 }
199
200 #[inline]
207 pub fn try_get<'b, U>(&'b self, handle: Handle<U>) -> Result<Ref<'a, 'b, U>, PoolError>
208 where
209 'b: 'a,
210 U: ObjectOrVariant<T>,
211 {
212 self.try_get_internal(handle.to_base(), |obj| {
213 U::convert_to_dest_type(obj).ok_or(PoolError::InvalidType(handle.into()))
214 })
215 }
216
217 #[inline]
218 pub fn get<'b, U>(&'b self, handle: Handle<U>) -> Ref<'a, 'b, U>
219 where
220 'b: 'a,
221 U: ObjectOrVariant<T>,
222 {
223 self.try_get(handle).unwrap()
224 }
225
226 #[inline]
227 fn try_get_mut_internal<'b: 'a, C, F>(
228 &'b self,
229 handle: Handle<T>,
230 func: F,
231 ) -> Result<RefMut<'a, 'b, C>, PoolError>
232 where
233 C: ?Sized,
234 F: FnOnce(&mut T) -> Result<&mut C, PoolError>,
235 {
236 let record = self.pool.records_get(handle.index)?;
237
238 if handle.generation != record.generation {
239 return Err(PoolError::InvalidGeneration(handle.generation));
240 }
241
242 let current_ref_count = unsafe { record.ref_counter.get() };
245 match current_ref_count.cmp(&0) {
246 Ordering::Less => {
247 return Err(PoolError::MutablyBorrowed(handle.into()));
248 }
249 Ordering::Greater => {
250 return Err(PoolError::ImmutablyBorrowed(handle.into()));
251 }
252 _ => (),
253 }
254
255 let payload_container = unsafe { &mut *record.payload.0.get() };
257
258 let Some(payload) = payload_container.as_mut() else {
259 return Err(PoolError::Empty(handle.into()));
260 };
261
262 unsafe {
265 record.ref_counter.decrement();
266 }
267
268 Ok(RefMut {
269 data: func(payload)?,
270 ref_counter: &record.ref_counter,
271 phantom: PhantomData,
272 })
273 }
274
275 #[inline]
276 pub fn try_get_mut<'b, U>(&'b self, handle: Handle<U>) -> Result<RefMut<'a, 'b, U>, PoolError>
277 where
278 'b: 'a,
279 U: ObjectOrVariant<T>,
280 {
281 self.try_get_mut_internal(handle.to_base(), |obj| {
282 U::convert_to_dest_type_mut(obj).ok_or(PoolError::InvalidType(handle.into()))
283 })
284 }
285
286 #[inline]
287 pub fn get_mut<'b, U>(&'b self, handle: Handle<U>) -> RefMut<'a, 'b, U>
288 where
289 'b: 'a,
290 U: ObjectOrVariant<T>,
291 {
292 self.try_get_mut(handle).unwrap()
293 }
294
295 #[inline]
296 pub fn free(&self, handle: Handle<T>) -> Result<T, PoolError> {
297 let record = self.pool.records_get(handle.index)?;
298
299 if handle.generation != record.generation {
300 return Err(PoolError::InvalidGeneration(handle.generation));
301 }
302
303 let current_ref_count = unsafe { record.ref_counter.get() };
307 match current_ref_count.cmp(&0) {
308 Ordering::Less => {
309 return Err(PoolError::MutablyBorrowed(handle.into()));
310 }
311 Ordering::Greater => {
312 return Err(PoolError::ImmutablyBorrowed(handle.into()));
313 }
314 _ => (),
315 }
316
317 let payload_container = unsafe { &mut *record.payload.0.get() };
319
320 let Some(payload) = payload_container.take() else {
321 return Err(PoolError::Empty(handle.into()));
322 };
323
324 self.free_indices.borrow_mut().push(handle.index);
325
326 Ok(payload)
327 }
328}
329
330impl<'a, T, P> MultiBorrowContext<'a, T, P>
331where
332 T: Sized + ComponentProvider,
333 P: PayloadContainer<Element = T> + 'static,
334{
335 #[inline]
337 pub fn try_get_component_of_type<'b: 'a, C>(
338 &'b self,
339 handle: Handle<T>,
340 ) -> Result<Ref<'a, 'b, C>, PoolError>
341 where
342 C: 'static,
343 {
344 self.try_get_internal(handle, move |obj| {
345 obj.query_component_ref(TypeId::of::<C>())
346 .and_then(|c| c.downcast_ref())
347 .ok_or(PoolError::NoSuchComponent(handle.into()))
348 })
349 }
350
351 #[inline]
353 pub fn try_get_component_of_type_mut<'b: 'a, C>(
354 &'b self,
355 handle: Handle<T>,
356 ) -> Result<RefMut<'a, 'b, C>, PoolError>
357 where
358 C: 'static,
359 {
360 self.try_get_mut_internal(handle, move |obj| {
361 obj.query_component_mut(TypeId::of::<C>())
362 .and_then(|c| c.downcast_mut())
363 .ok_or(PoolError::NoSuchComponent(handle.into()))
364 })
365 }
366}
367
#[cfg(test)]
mod test {
    use super::PoolError;
    use crate::pool::Pool;

    /// Minimal payload used to populate the pool under test.
    #[derive(PartialEq, Clone, Copy, Debug)]
    struct MyPayload(u32);

    /// Exercises the runtime borrow rules of `MultiBorrowContext`:
    /// - access to an emptied slot;
    /// - coexisting shared borrows;
    /// - shared/mutable exclusion, including the exact counter values;
    /// - mutable/mutable exclusion.
    #[test]
    fn test_multi_borrow_context() {
        let mut pool = Pool::<MyPayload>::new();

        // Local copies used as expected values. They are `mut` because the mutable-borrow
        // assertions compare against `&mut` references.
        let mut val_a = MyPayload(123);
        let mut val_b = MyPayload(321);
        let mut val_c = MyPayload(42);
        let val_d = MyPayload(666);

        let a = pool.spawn(val_a);
        let b = pool.spawn(val_b);
        let c = pool.spawn(val_c);
        let d = pool.spawn(val_d);

        // Free `d` before the context exists so its slot is empty inside the context.
        pool.free(d);

        let ctx = pool.begin_multi_borrow();

        {
            // Borrowing an emptied slot reports `Empty`, for both shared and mutable access.
            assert_eq!(
                ctx.try_get(d).as_deref(),
                Err(PoolError::Empty(d.into())).as_ref()
            );
            assert_eq!(
                ctx.try_get_mut(d).as_deref_mut(),
                Err(PoolError::Empty(d.into())).as_mut()
            );
        }

        {
            // Two shared borrows of the same object may be alive at once.
            let ref_a_1 = ctx.try_get(a);
            let ref_a_2 = ctx.try_get(a);
            assert_eq!(ref_a_1.as_deref(), Ok(&val_a));
            assert_eq!(ref_a_2.as_deref(), Ok(&val_a));
        }

        {
            // Each live shared borrow raises the record's counter: 1, then 2.
            let ref_a_1 = ctx.try_get(a);
            assert_eq!(unsafe { ref_a_1.as_ref().unwrap().ref_counter.get() }, 1);
            let ref_a_2 = ctx.try_get(a);
            assert_eq!(unsafe { ref_a_2.as_ref().unwrap().ref_counter.get() }, 2);

            assert_eq!(ref_a_1.as_deref(), Ok(&val_a));
            assert_eq!(ref_a_2.as_deref(), Ok(&val_a));
            // A mutable borrow is refused while shared borrows are alive.
            assert_eq!(
                ctx.try_get_mut(a).as_deref(),
                Err(PoolError::ImmutablyBorrowed(a.into())).as_ref()
            );

            // Release both shared borrows explicitly.
            drop(ref_a_1);
            drop(ref_a_2);

            // A mutable borrow now succeeds, and the counter reads -1 while it is alive.
            let mut mut_ref_a_1 = ctx.try_get_mut(a);
            assert_eq!(mut_ref_a_1.as_deref_mut(), Ok(&mut val_a));

            assert_eq!(
                unsafe { mut_ref_a_1.as_ref().unwrap().ref_counter.get() },
                -1
            );
        }

        {
            // `a` can be shared-borrowed again now that the previous scope's mutable borrow
            // has been dropped.
            let ref_a_1 = ctx.try_get(a);
            let ref_a_2 = ctx.try_get(a);
            assert_eq!(ref_a_1.as_deref(), Ok(&val_a));
            assert_eq!(ref_a_2.as_deref(), Ok(&val_a));

            // A second mutable borrow of `b` is refused while the first one is alive.
            let mut ref_b_1 = ctx.try_get_mut(b);
            let mut ref_b_2 = ctx.try_get_mut(b);
            assert_eq!(ref_b_1.as_deref_mut(), Ok(&mut val_b));
            assert_eq!(
                ref_b_2.as_deref_mut(),
                Err(PoolError::MutablyBorrowed(b.into())).as_mut()
            );

            // Same check for `c`, while borrows of `a` and `b` are still held.
            let mut ref_c_1 = ctx.try_get_mut(c);
            let mut ref_c_2 = ctx.try_get_mut(c);
            assert_eq!(ref_c_1.as_deref_mut(), Ok(&mut val_c));
            assert_eq!(
                ref_c_2.as_deref_mut(),
                Err(PoolError::MutablyBorrowed(c.into())).as_mut()
            );
        }
    }
}