1use super::{Handle, PayloadContainer, Pool, RefCounter};
22use crate::ComponentProvider;
23use std::cell::RefCell;
24use std::cmp::Ordering;
25use std::{
26 any::TypeId,
27 fmt::{Debug, Display, Formatter},
28 marker::PhantomData,
29 ops::{Deref, DerefMut},
30 sync::atomic,
31};
32
/// Shared-borrow guard handed out by the [`MultiBorrowContext`] accessors
/// (`try_get`, `get`, `try_get_component_of_type`).
///
/// Dereferences to the borrowed value. Dropping it decrements the reference counter of the
/// pool record the borrow was registered on, which releases the shared borrow.
pub struct Ref<'a, 'b, T>
where
    T: ?Sized,
{
    /// The borrowed value (the whole object or one of its components).
    data: &'a T,
    /// Reference counter of the pool record this shared borrow was registered on.
    ref_counter: &'a RefCounter,
    /// Zero-sized marker that ties the guard to the `'b` lifetime.
    phantom: PhantomData<&'b ()>,
}
41
42impl<T> Debug for Ref<'_, '_, T>
43where
44 T: ?Sized + Debug,
45{
46 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
47 Debug::fmt(&self.data, f)
48 }
49}
50
51impl<T> Deref for Ref<'_, '_, T>
52where
53 T: ?Sized,
54{
55 type Target = T;
56
57 fn deref(&self) -> &Self::Target {
58 self.data
59 }
60}
61
62impl<T> Drop for Ref<'_, '_, T>
63where
64 T: ?Sized,
65{
66 fn drop(&mut self) {
67 self.ref_counter.decrement();
68 }
69}
70
/// Exclusive-borrow guard handed out by the [`MultiBorrowContext`] accessors
/// (`try_get_mut`, `get_mut`, `try_get_component_of_type_mut`).
///
/// Dereferences (mutably) to the borrowed value. While it is alive, the record's reference
/// counter holds -1. Dropping it increments the counter back, which releases the exclusive
/// borrow.
pub struct RefMut<'a, 'b, T>
where
    T: ?Sized,
{
    /// The mutably borrowed value (the whole object or one of its components).
    data: &'a mut T,
    /// Reference counter of the pool record this exclusive borrow was registered on.
    ref_counter: &'a RefCounter,
    /// Zero-sized marker that ties the guard to the `'b` lifetime.
    phantom: PhantomData<&'b ()>,
}
79
80impl<T> Debug for RefMut<'_, '_, T>
81where
82 T: ?Sized + Debug,
83{
84 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
85 Debug::fmt(&self.data, f)
86 }
87}
88
89impl<T> Deref for RefMut<'_, '_, T>
90where
91 T: ?Sized,
92{
93 type Target = T;
94
95 fn deref(&self) -> &Self::Target {
96 self.data
97 }
98}
99
100impl<T> DerefMut for RefMut<'_, '_, T>
101where
102 T: ?Sized,
103{
104 fn deref_mut(&mut self) -> &mut Self::Target {
105 self.data
106 }
107}
108
109impl<T> Drop for RefMut<'_, '_, T>
110where
111 T: ?Sized,
112{
113 fn drop(&mut self) {
114 self.ref_counter.increment();
115 }
116}
117
/// Allows borrowing several objects of a [`Pool`] at the same time, immutably or mutably.
/// The borrow rules are checked at runtime through each pool record's reference counter
/// (positive = number of shared borrows, -1 = mutably borrowed, 0 = not borrowed).
///
/// Objects can also be freed through the context. Their indices are collected and pushed
/// onto the pool's free stack when the context is dropped.
pub struct MultiBorrowContext<'a, T, P = Option<T>>
where
    T: Sized,
    P: PayloadContainer<Element = T> + 'static,
{
    /// The pool whose records are being borrowed.
    pool: &'a mut Pool<T, P>,
    /// Indices freed through the context; appended to `pool.free_stack` on drop.
    free_indices: RefCell<Vec<u32>>,
}
128
129#[derive(PartialEq)]
130pub enum MultiBorrowError<T> {
131 Empty(Handle<T>),
132 NoSuchComponent(Handle<T>),
133 MutablyBorrowed(Handle<T>),
134 ImmutablyBorrowed(Handle<T>),
135 InvalidHandleIndex(Handle<T>),
136 InvalidHandleGeneration(Handle<T>),
137}
138
139impl<T> Debug for MultiBorrowError<T> {
140 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
141 Display::fmt(self, f)
142 }
143}
144
145impl<T> Display for MultiBorrowError<T> {
146 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
147 match self {
148 Self::Empty(handle) => {
149 write!(f, "There's no object at {handle} handle.")
150 }
151 Self::NoSuchComponent(handle) => write!(
152 f,
153 "An object at {handle} handle does not have such component.",
154 ),
155 Self::MutablyBorrowed(handle) => {
156 write!(
157 f,
158 "An object at {handle} handle cannot be borrowed immutably, because it is \
159 already borrowed mutably."
160 )
161 }
162 Self::ImmutablyBorrowed(handle) => {
163 write!(
164 f,
165 "An object at {handle} handle cannot be borrowed mutably, because it is \
166 already borrowed immutably."
167 )
168 }
169 Self::InvalidHandleIndex(handle) => {
170 write!(
171 f,
172 "The index {} in {handle} handle is out of bounds.",
173 handle.index
174 )
175 }
176 Self::InvalidHandleGeneration(handle) => {
177 write!(
178 f,
179 "The generation {} in {handle} handle does not match the record's generation. \
180 It means that the object at the handle was freed and it position was taken \
181 by some other object.",
182 handle.generation
183 )
184 }
185 }
186 }
187}
188
189impl<T, P> Drop for MultiBorrowContext<'_, T, P>
190where
191 T: Sized,
192 P: PayloadContainer<Element = T> + 'static,
193{
194 fn drop(&mut self) {
195 self.pool
196 .free_stack
197 .extend_from_slice(&self.free_indices.borrow())
198 }
199}
200
201impl<'a, T, P> MultiBorrowContext<'a, T, P>
202where
203 T: Sized,
204 P: PayloadContainer<Element = T> + 'static,
205{
206 #[inline]
207 pub fn new(pool: &'a mut Pool<T, P>) -> Self {
208 Self {
209 pool,
210 free_indices: Default::default(),
211 }
212 }
213
214 #[inline]
215 fn try_get_internal<'b: 'a, C, F>(
216 &'b self,
217 handle: Handle<T>,
218 func: F,
219 ) -> Result<Ref<'a, 'b, C>, MultiBorrowError<T>>
220 where
221 C: ?Sized,
222 F: FnOnce(&T) -> Result<&C, MultiBorrowError<T>>,
223 {
224 let Some(record) = self.pool.records_get(handle.index) else {
225 return Err(MultiBorrowError::InvalidHandleIndex(handle));
226 };
227
228 if handle.generation != record.generation {
229 return Err(MultiBorrowError::InvalidHandleGeneration(handle));
230 }
231
232 let current_ref_count = record.ref_counter.0.load(atomic::Ordering::Relaxed);
233 if current_ref_count < 0 {
234 return Err(MultiBorrowError::MutablyBorrowed(handle));
235 }
236
237 let payload_container = unsafe { &*record.payload.0.get() };
239
240 let Some(payload) = payload_container.as_ref() else {
241 return Err(MultiBorrowError::Empty(handle));
242 };
243
244 if let Err(ref_count) = record.ref_counter.0.compare_exchange(
245 current_ref_count,
246 current_ref_count + 1,
247 atomic::Ordering::Acquire,
248 atomic::Ordering::Relaxed,
249 ) {
250 if ref_count < 0 {
252 return Err(MultiBorrowError::MutablyBorrowed(handle));
253 }
254 }
255
256 Ok(Ref {
257 data: func(payload)?,
258 ref_counter: &record.ref_counter,
259 phantom: PhantomData,
260 })
261 }
262
263 #[inline]
272 pub fn try_get<'b: 'a>(
273 &'b self,
274 handle: Handle<T>,
275 ) -> Result<Ref<'a, 'b, T>, MultiBorrowError<T>> {
276 self.try_get_internal(handle, |obj| Ok(obj))
277 }
278
279 #[inline]
280 pub fn get<'b: 'a>(&'b self, handle: Handle<T>) -> Ref<'a, 'b, T> {
281 self.try_get(handle).unwrap()
282 }
283
284 #[inline]
285 fn try_get_mut_internal<'b: 'a, C, F>(
286 &'b self,
287 handle: Handle<T>,
288 func: F,
289 ) -> Result<RefMut<'a, 'b, C>, MultiBorrowError<T>>
290 where
291 C: ?Sized,
292 F: FnOnce(&mut T) -> Result<&mut C, MultiBorrowError<T>>,
293 {
294 let Some(record) = self.pool.records_get(handle.index) else {
295 return Err(MultiBorrowError::InvalidHandleIndex(handle));
296 };
297
298 if handle.generation != record.generation {
299 return Err(MultiBorrowError::InvalidHandleGeneration(handle));
300 }
301
302 let current_ref_count = record.ref_counter.0.load(atomic::Ordering::Relaxed);
303 match current_ref_count.cmp(&0) {
304 Ordering::Less => {
305 return Err(MultiBorrowError::MutablyBorrowed(handle));
306 }
307 Ordering::Greater => {
308 return Err(MultiBorrowError::ImmutablyBorrowed(handle));
309 }
310 _ => (),
311 }
312
313 let payload_container = unsafe { &mut *record.payload.0.get() };
315
316 let Some(payload) = payload_container.as_mut() else {
317 return Err(MultiBorrowError::Empty(handle));
318 };
319
320 if let Err(ref_count) = record.ref_counter.0.compare_exchange(
321 0,
322 -1,
323 atomic::Ordering::Acquire,
324 atomic::Ordering::Relaxed,
325 ) {
326 match ref_count.cmp(&0) {
327 Ordering::Less => {
328 return Err(MultiBorrowError::MutablyBorrowed(handle));
329 }
330 Ordering::Greater => {
331 return Err(MultiBorrowError::ImmutablyBorrowed(handle));
332 }
333 _ => (),
334 }
335 }
336
337 Ok(RefMut {
338 data: func(payload)?,
339 ref_counter: &record.ref_counter,
340 phantom: PhantomData,
341 })
342 }
343
344 #[inline]
345 pub fn try_get_mut<'b: 'a>(
346 &'b self,
347 handle: Handle<T>,
348 ) -> Result<RefMut<'a, 'b, T>, MultiBorrowError<T>> {
349 self.try_get_mut_internal(handle, |obj| Ok(obj))
350 }
351
352 #[inline]
353 pub fn get_mut<'b: 'a>(&'b self, handle: Handle<T>) -> RefMut<'a, 'b, T> {
354 self.try_get_mut(handle).unwrap()
355 }
356
357 #[inline]
358 pub fn free(&self, handle: Handle<T>) -> Result<T, MultiBorrowError<T>> {
359 let Some(record) = self.pool.records_get(handle.index) else {
360 return Err(MultiBorrowError::InvalidHandleIndex(handle));
361 };
362
363 if handle.generation != record.generation {
364 return Err(MultiBorrowError::InvalidHandleGeneration(handle));
365 }
366
367 if let Err(ref_count) = record.ref_counter.0.compare_exchange(
369 0,
370 -1,
371 atomic::Ordering::Acquire,
372 atomic::Ordering::Relaxed,
373 ) {
374 match ref_count.cmp(&0) {
375 Ordering::Less => {
376 return Err(MultiBorrowError::MutablyBorrowed(handle));
377 }
378 Ordering::Greater => {
379 return Err(MultiBorrowError::ImmutablyBorrowed(handle));
380 }
381 _ => (),
382 }
383 }
384
385 let payload_container = unsafe { &mut *record.payload.0.get() };
387
388 let Some(payload) = payload_container.take() else {
389 return Err(MultiBorrowError::Empty(handle));
390 };
391
392 self.free_indices.borrow_mut().push(handle.index);
393
394 record.ref_counter.increment();
395
396 Ok(payload)
397 }
398}
399
400impl<'a, T, P> MultiBorrowContext<'a, T, P>
401where
402 T: Sized + ComponentProvider,
403 P: PayloadContainer<Element = T> + 'static,
404{
405 #[inline]
407 pub fn try_get_component_of_type<'b: 'a, C>(
408 &'b self,
409 handle: Handle<T>,
410 ) -> Result<Ref<'a, 'b, C>, MultiBorrowError<T>>
411 where
412 C: 'static,
413 {
414 self.try_get_internal(handle, move |obj| {
415 obj.query_component_ref(TypeId::of::<C>())
416 .and_then(|c| c.downcast_ref())
417 .ok_or(MultiBorrowError::NoSuchComponent(handle))
418 })
419 }
420
421 #[inline]
423 pub fn try_get_component_of_type_mut<'b: 'a, C>(
424 &'b self,
425 handle: Handle<T>,
426 ) -> Result<RefMut<'a, 'b, C>, MultiBorrowError<T>>
427 where
428 C: 'static,
429 {
430 self.try_get_mut_internal(handle, move |obj| {
431 obj.query_component_mut(TypeId::of::<C>())
432 .and_then(|c| c.downcast_mut())
433 .ok_or(MultiBorrowError::NoSuchComponent(handle))
434 })
435 }
436}
437
// Unit tests for the multi-borrow context obtained via `Pool::begin_multi_borrow`.
#[cfg(test)]
mod test {
    use super::MultiBorrowError;
    use crate::pool::Pool;
    use std::sync::atomic;

    // Simple payload type; `PartialEq`/`Debug` are needed by `assert_eq!` on borrowed values.
    #[derive(PartialEq, Clone, Copy, Debug)]
    struct MyPayload(u32);

    #[test]
    fn test_multi_borrow_context() {
        let mut pool = Pool::<MyPayload>::new();

        // `val_a`..`val_c` are `mut` only so that `&mut val_x` can be compared against the
        // results of mutable borrows below.
        let mut val_a = MyPayload(123);
        let mut val_b = MyPayload(321);
        let mut val_c = MyPayload(42);
        let val_d = MyPayload(666);

        let a = pool.spawn(val_a);
        let b = pool.spawn(val_b);
        let c = pool.spawn(val_c);
        let d = pool.spawn(val_d);

        // `d` is freed before the context is created.
        pool.free(d);

        let ctx = pool.begin_multi_borrow();

        // Borrowing the freed object `d`, immutably or mutably, reports `Empty`.
        {
            assert_eq!(
                ctx.try_get(d).as_deref(),
                Err(MultiBorrowError::Empty(d)).as_ref()
            );
            assert_eq!(
                ctx.try_get_mut(d).as_deref_mut(),
                Err(MultiBorrowError::Empty(d)).as_mut()
            );
        }

        // Two simultaneous immutable borrows of the same object both succeed.
        {
            let ref_a_1 = ctx.try_get(a);
            let ref_a_2 = ctx.try_get(a);
            assert_eq!(ref_a_1.as_deref(), Ok(&val_a));
            assert_eq!(ref_a_2.as_deref(), Ok(&val_a));
        }

        // The record's reference counter counts live immutable borrows (1, then 2), a mutable
        // borrow is refused while they are alive, and after they are dropped a mutable borrow
        // succeeds and sets the counter to -1.
        {
            let ref_a_1 = ctx.try_get(a);
            assert_eq!(
                ref_a_1
                    .as_ref()
                    .unwrap()
                    .ref_counter
                    .0
                    .load(atomic::Ordering::Relaxed),
                1
            );
            let ref_a_2 = ctx.try_get(a);
            assert_eq!(
                ref_a_2
                    .as_ref()
                    .unwrap()
                    .ref_counter
                    .0
                    .load(atomic::Ordering::Relaxed),
                2
            );

            assert_eq!(ref_a_1.as_deref(), Ok(&val_a));
            assert_eq!(ref_a_2.as_deref(), Ok(&val_a));
            assert_eq!(
                ctx.try_get_mut(a).as_deref(),
                Err(MultiBorrowError::ImmutablyBorrowed(a)).as_ref()
            );

            drop(ref_a_1);
            drop(ref_a_2);

            let mut mut_ref_a_1 = ctx.try_get_mut(a);
            assert_eq!(mut_ref_a_1.as_deref_mut(), Ok(&mut val_a));

            assert_eq!(
                mut_ref_a_1
                    .as_ref()
                    .unwrap()
                    .ref_counter
                    .0
                    .load(atomic::Ordering::Relaxed),
                -1
            );
        }

        // Borrows of different objects are independent: `a` is borrowed immutably twice while
        // `b` and `c` are each borrowed mutably. A second mutable borrow of the same object is
        // rejected with `MutablyBorrowed`.
        {
            let ref_a_1 = ctx.try_get(a);
            let ref_a_2 = ctx.try_get(a);
            assert_eq!(ref_a_1.as_deref(), Ok(&val_a));
            assert_eq!(ref_a_2.as_deref(), Ok(&val_a));

            let mut ref_b_1 = ctx.try_get_mut(b);
            let mut ref_b_2 = ctx.try_get_mut(b);
            assert_eq!(ref_b_1.as_deref_mut(), Ok(&mut val_b));
            assert_eq!(
                ref_b_2.as_deref_mut(),
                Err(MultiBorrowError::MutablyBorrowed(b)).as_mut()
            );

            let mut ref_c_1 = ctx.try_get_mut(c);
            let mut ref_c_2 = ctx.try_get_mut(c);
            assert_eq!(ref_c_1.as_deref_mut(), Ok(&mut val_c));
            assert_eq!(
                ref_c_2.as_deref_mut(),
                Err(MultiBorrowError::MutablyBorrowed(c)).as_mut()
            );
        }
    }
}