1use super::{Handle, PayloadContainer, Pool, RefCounter};
22use crate::ComponentProvider;
23use std::{
24 any::TypeId,
25 cell::RefCell,
26 cmp::Ordering,
27 fmt::{Debug, Display, Formatter},
28 marker::PhantomData,
29 ops::{Deref, DerefMut},
30};
31
/// Shared borrow of a pool object (or of a part of it, such as a component) handed out by
/// [`MultiBorrowContext`]. Dropping it gives back its share of the record's reference counter.
pub struct Ref<'a, 'b, T>
where
    T: ?Sized,
{
    /// The borrowed value.
    data: &'a T,
    /// Reference counter of the record this borrow was taken from; decremented on drop.
    ref_counter: &'a RefCounter,
    /// Ties this borrow to `'b`, the lifetime of the `&'b MultiBorrowContext` that produced it.
    phantom: PhantomData<&'b ()>,
}
40
41impl<T> Debug for Ref<'_, '_, T>
42where
43 T: ?Sized + Debug,
44{
45 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
46 Debug::fmt(&self.data, f)
47 }
48}
49
50impl<T> Deref for Ref<'_, '_, T>
51where
52 T: ?Sized,
53{
54 type Target = T;
55
56 fn deref(&self) -> &Self::Target {
57 self.data
58 }
59}
60
61impl<T> Drop for Ref<'_, '_, T>
62where
63 T: ?Sized,
64{
65 fn drop(&mut self) {
66 unsafe {
69 self.ref_counter.decrement();
70 }
71 }
72}
73
/// Unique (mutable) borrow of a pool object (or of a part of it, such as a component) handed
/// out by [`MultiBorrowContext`]. While it is alive the record's reference counter is negative;
/// dropping it restores the counter.
pub struct RefMut<'a, 'b, T>
where
    T: ?Sized,
{
    /// The mutably borrowed value.
    data: &'a mut T,
    /// Reference counter of the record this borrow was taken from; incremented on drop to undo
    /// the decrement performed when the borrow was created.
    ref_counter: &'a RefCounter,
    /// Ties this borrow to `'b`, the lifetime of the `&'b MultiBorrowContext` that produced it.
    phantom: PhantomData<&'b ()>,
}
82
83impl<T> Debug for RefMut<'_, '_, T>
84where
85 T: ?Sized + Debug,
86{
87 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
88 Debug::fmt(&self.data, f)
89 }
90}
91
92impl<T> Deref for RefMut<'_, '_, T>
93where
94 T: ?Sized,
95{
96 type Target = T;
97
98 fn deref(&self) -> &Self::Target {
99 self.data
100 }
101}
102
103impl<T> DerefMut for RefMut<'_, '_, T>
104where
105 T: ?Sized,
106{
107 fn deref_mut(&mut self) -> &mut Self::Target {
108 self.data
109 }
110}
111
112impl<T> Drop for RefMut<'_, '_, T>
113where
114 T: ?Sized,
115{
116 fn drop(&mut self) {
117 unsafe {
120 self.ref_counter.increment();
121 }
122 }
123}
124
/// Context that holds a pool's unique borrow and, through `&self`, hands out any number of
/// [`Ref`] / [`RefMut`] borrows of pool objects. Aliasing rules are enforced at runtime using
/// each record's reference counter instead of by the borrow checker.
pub struct MultiBorrowContext<'a, T, P = Option<T>>
where
    T: Sized,
    P: PayloadContainer<Element = T> + 'static,
{
    /// The pool, borrowed uniquely for as long as the context lives.
    pool: &'a mut Pool<T, P>,
    /// Indices of records whose objects were taken out via `free`; they are appended to the
    /// pool's `free_stack` when the context is dropped.
    free_indices: RefCell<Vec<u32>>,
}
135
/// Reasons why [`MultiBorrowContext`] refuses to borrow or free an object. Every variant
/// carries the handle that was passed in.
#[derive(PartialEq)]
pub enum MultiBorrowError<T> {
    /// The handle is valid, but its record currently holds no object.
    Empty(Handle<T>),
    /// The object exists, but it did not provide a component of the requested type.
    NoSuchComponent(Handle<T>),
    /// The object is currently borrowed mutably (its reference counter is negative).
    MutablyBorrowed(Handle<T>),
    /// The object is currently borrowed immutably (its reference counter is positive), so it
    /// can be neither borrowed mutably nor freed.
    ImmutablyBorrowed(Handle<T>),
    /// The handle's index does not refer to any record of the pool.
    InvalidHandleIndex(Handle<T>),
    /// The handle's generation differs from the generation stored in the record.
    InvalidHandleGeneration(Handle<T>),
}
145
146impl<T> Debug for MultiBorrowError<T> {
147 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
148 Display::fmt(self, f)
149 }
150}
151
152impl<T> Display for MultiBorrowError<T> {
153 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
154 match self {
155 Self::Empty(handle) => {
156 write!(f, "There's no object at {handle} handle.")
157 }
158 Self::NoSuchComponent(handle) => write!(
159 f,
160 "An object at {handle} handle does not have such component.",
161 ),
162 Self::MutablyBorrowed(handle) => {
163 write!(
164 f,
165 "An object at {handle} handle cannot be borrowed immutably, because it is \
166 already borrowed mutably."
167 )
168 }
169 Self::ImmutablyBorrowed(handle) => {
170 write!(
171 f,
172 "An object at {handle} handle cannot be borrowed mutably, because it is \
173 already borrowed immutably."
174 )
175 }
176 Self::InvalidHandleIndex(handle) => {
177 write!(
178 f,
179 "The index {} in {handle} handle is out of bounds.",
180 handle.index
181 )
182 }
183 Self::InvalidHandleGeneration(handle) => {
184 write!(
185 f,
186 "The generation {} in {handle} handle does not match the record's generation. \
187 It means that the object at the handle was freed and it position was taken \
188 by some other object.",
189 handle.generation
190 )
191 }
192 }
193 }
194}
195
196impl<T, P> Drop for MultiBorrowContext<'_, T, P>
197where
198 T: Sized,
199 P: PayloadContainer<Element = T> + 'static,
200{
201 fn drop(&mut self) {
202 self.pool
203 .free_stack
204 .extend_from_slice(&self.free_indices.borrow())
205 }
206}
207
208impl<'a, T, P> MultiBorrowContext<'a, T, P>
209where
210 T: Sized,
211 P: PayloadContainer<Element = T> + 'static,
212{
213 #[inline]
214 pub fn new(pool: &'a mut Pool<T, P>) -> Self {
215 Self {
216 pool,
217 free_indices: Default::default(),
218 }
219 }
220
221 #[inline]
222 fn try_get_internal<'b: 'a, C, F>(
223 &'b self,
224 handle: Handle<T>,
225 func: F,
226 ) -> Result<Ref<'a, 'b, C>, MultiBorrowError<T>>
227 where
228 C: ?Sized,
229 F: FnOnce(&T) -> Result<&C, MultiBorrowError<T>>,
230 {
231 let Some(record) = self.pool.records_get(handle.index) else {
232 return Err(MultiBorrowError::InvalidHandleIndex(handle));
233 };
234
235 if handle.generation != record.generation {
236 return Err(MultiBorrowError::InvalidHandleGeneration(handle));
237 }
238
239 let current_ref_count = unsafe { record.ref_counter.get() };
240 if current_ref_count < 0 {
241 return Err(MultiBorrowError::MutablyBorrowed(handle));
242 }
243
244 let payload_container = unsafe { &*record.payload.0.get() };
246
247 let Some(payload) = payload_container.as_ref() else {
248 return Err(MultiBorrowError::Empty(handle));
249 };
250
251 unsafe {
252 record.ref_counter.increment();
253 }
254
255 Ok(Ref {
256 data: func(payload)?,
257 ref_counter: &record.ref_counter,
258 phantom: PhantomData,
259 })
260 }
261
262 #[inline]
271 pub fn try_get<'b: 'a>(
272 &'b self,
273 handle: Handle<T>,
274 ) -> Result<Ref<'a, 'b, T>, MultiBorrowError<T>> {
275 self.try_get_internal(handle, |obj| Ok(obj))
276 }
277
278 #[inline]
279 pub fn get<'b: 'a>(&'b self, handle: Handle<T>) -> Ref<'a, 'b, T> {
280 self.try_get(handle).unwrap()
281 }
282
283 #[inline]
284 fn try_get_mut_internal<'b: 'a, C, F>(
285 &'b self,
286 handle: Handle<T>,
287 func: F,
288 ) -> Result<RefMut<'a, 'b, C>, MultiBorrowError<T>>
289 where
290 C: ?Sized,
291 F: FnOnce(&mut T) -> Result<&mut C, MultiBorrowError<T>>,
292 {
293 let Some(record) = self.pool.records_get(handle.index) else {
294 return Err(MultiBorrowError::InvalidHandleIndex(handle));
295 };
296
297 if handle.generation != record.generation {
298 return Err(MultiBorrowError::InvalidHandleGeneration(handle));
299 }
300
301 let current_ref_count = unsafe { record.ref_counter.get() };
304 match current_ref_count.cmp(&0) {
305 Ordering::Less => {
306 return Err(MultiBorrowError::MutablyBorrowed(handle));
307 }
308 Ordering::Greater => {
309 return Err(MultiBorrowError::ImmutablyBorrowed(handle));
310 }
311 _ => (),
312 }
313
314 let payload_container = unsafe { &mut *record.payload.0.get() };
316
317 let Some(payload) = payload_container.as_mut() else {
318 return Err(MultiBorrowError::Empty(handle));
319 };
320
321 unsafe {
324 record.ref_counter.decrement();
325 }
326
327 Ok(RefMut {
328 data: func(payload)?,
329 ref_counter: &record.ref_counter,
330 phantom: PhantomData,
331 })
332 }
333
334 #[inline]
335 pub fn try_get_mut<'b: 'a>(
336 &'b self,
337 handle: Handle<T>,
338 ) -> Result<RefMut<'a, 'b, T>, MultiBorrowError<T>> {
339 self.try_get_mut_internal(handle, |obj| Ok(obj))
340 }
341
342 #[inline]
343 pub fn get_mut<'b: 'a>(&'b self, handle: Handle<T>) -> RefMut<'a, 'b, T> {
344 self.try_get_mut(handle).unwrap()
345 }
346
347 #[inline]
348 pub fn free(&self, handle: Handle<T>) -> Result<T, MultiBorrowError<T>> {
349 let Some(record) = self.pool.records_get(handle.index) else {
350 return Err(MultiBorrowError::InvalidHandleIndex(handle));
351 };
352
353 if handle.generation != record.generation {
354 return Err(MultiBorrowError::InvalidHandleGeneration(handle));
355 }
356
357 let current_ref_count = unsafe { record.ref_counter.get() };
361 match current_ref_count.cmp(&0) {
362 Ordering::Less => {
363 return Err(MultiBorrowError::MutablyBorrowed(handle));
364 }
365 Ordering::Greater => {
366 return Err(MultiBorrowError::ImmutablyBorrowed(handle));
367 }
368 _ => (),
369 }
370
371 let payload_container = unsafe { &mut *record.payload.0.get() };
373
374 let Some(payload) = payload_container.take() else {
375 return Err(MultiBorrowError::Empty(handle));
376 };
377
378 self.free_indices.borrow_mut().push(handle.index);
379
380 Ok(payload)
381 }
382}
383
384impl<'a, T, P> MultiBorrowContext<'a, T, P>
385where
386 T: Sized + ComponentProvider,
387 P: PayloadContainer<Element = T> + 'static,
388{
389 #[inline]
391 pub fn try_get_component_of_type<'b: 'a, C>(
392 &'b self,
393 handle: Handle<T>,
394 ) -> Result<Ref<'a, 'b, C>, MultiBorrowError<T>>
395 where
396 C: 'static,
397 {
398 self.try_get_internal(handle, move |obj| {
399 obj.query_component_ref(TypeId::of::<C>())
400 .and_then(|c| c.downcast_ref())
401 .ok_or(MultiBorrowError::NoSuchComponent(handle))
402 })
403 }
404
405 #[inline]
407 pub fn try_get_component_of_type_mut<'b: 'a, C>(
408 &'b self,
409 handle: Handle<T>,
410 ) -> Result<RefMut<'a, 'b, C>, MultiBorrowError<T>>
411 where
412 C: 'static,
413 {
414 self.try_get_mut_internal(handle, move |obj| {
415 obj.query_component_mut(TypeId::of::<C>())
416 .and_then(|c| c.downcast_mut())
417 .ok_or(MultiBorrowError::NoSuchComponent(handle))
418 })
419 }
420}
421
#[cfg(test)]
mod test {
    use super::MultiBorrowError;
    use crate::pool::Pool;

    #[derive(PartialEq, Clone, Copy, Debug)]
    struct MyPayload(u32);

    /// Exercises the runtime borrow rules of `MultiBorrowContext`: empty slots, multiple shared
    /// borrows, shared-vs-unique conflicts, the raw reference-counter values, and repeated
    /// unique borrows of the same object.
    #[test]
    fn test_multi_borrow_context() {
        let mut pool = Pool::<MyPayload>::new();

        let mut val_a = MyPayload(123);
        let mut val_b = MyPayload(321);
        let mut val_c = MyPayload(42);
        let val_d = MyPayload(666);

        let a = pool.spawn(val_a);
        let b = pool.spawn(val_b);
        let c = pool.spawn(val_c);
        let d = pool.spawn(val_d);

        // `d` is freed before the context is created, so its slot should be empty.
        pool.free(d);

        let ctx = pool.begin_multi_borrow();

        {
            // Both the shared and the unique getter report the freed slot as `Empty`.
            assert_eq!(
                ctx.try_get(d).as_deref(),
                Err(MultiBorrowError::Empty(d)).as_ref()
            );
            assert_eq!(
                ctx.try_get_mut(d).as_deref_mut(),
                Err(MultiBorrowError::Empty(d)).as_mut()
            );
        }

        {
            // Two shared borrows of the same object may coexist.
            let ref_a_1 = ctx.try_get(a);
            let ref_a_2 = ctx.try_get(a);
            assert_eq!(ref_a_1.as_deref(), Ok(&val_a));
            assert_eq!(ref_a_2.as_deref(), Ok(&val_a));
        }

        {
            // Each live shared borrow adds one to the record's counter.
            let ref_a_1 = ctx.try_get(a);
            assert_eq!(unsafe { ref_a_1.as_ref().unwrap().ref_counter.get() }, 1);
            let ref_a_2 = ctx.try_get(a);
            assert_eq!(unsafe { ref_a_2.as_ref().unwrap().ref_counter.get() }, 2);

            assert_eq!(ref_a_1.as_deref(), Ok(&val_a));
            assert_eq!(ref_a_2.as_deref(), Ok(&val_a));
            // A unique borrow is refused while shared borrows are alive.
            assert_eq!(
                ctx.try_get_mut(a).as_deref(),
                Err(MultiBorrowError::ImmutablyBorrowed(a)).as_ref()
            );

            // Releasing both shared borrows makes a unique borrow possible again.
            drop(ref_a_1);
            drop(ref_a_2);

            let mut mut_ref_a_1 = ctx.try_get_mut(a);
            assert_eq!(mut_ref_a_1.as_deref_mut(), Ok(&mut val_a));

            // A live unique borrow is represented by a counter of -1.
            assert_eq!(
                unsafe { mut_ref_a_1.as_ref().unwrap().ref_counter.get() },
                -1
            );
        }

        {
            // The unique borrow from the previous scope was dropped, so `a` can be shared again.
            let ref_a_1 = ctx.try_get(a);
            let ref_a_2 = ctx.try_get(a);
            assert_eq!(ref_a_1.as_deref(), Ok(&val_a));
            assert_eq!(ref_a_2.as_deref(), Ok(&val_a));

            // Only the first of two unique borrows of `b` succeeds.
            let mut ref_b_1 = ctx.try_get_mut(b);
            let mut ref_b_2 = ctx.try_get_mut(b);
            assert_eq!(ref_b_1.as_deref_mut(), Ok(&mut val_b));
            assert_eq!(
                ref_b_2.as_deref_mut(),
                Err(MultiBorrowError::MutablyBorrowed(b)).as_mut()
            );

            // Same for `c`, while `a` (shared) and `b` (unique) are still borrowed.
            let mut ref_c_1 = ctx.try_get_mut(c);
            let mut ref_c_2 = ctx.try_get_mut(c);
            assert_eq!(ref_c_1.as_deref_mut(), Ok(&mut val_c));
            assert_eq!(
                ref_c_2.as_deref_mut(),
                Err(MultiBorrowError::MutablyBorrowed(c)).as_mut()
            );
        }
    }
}