1use std::{
19 collections::VecDeque,
20 mem::ManuallyDrop,
21 ops::{Deref, DerefMut},
22 sync::{Arc, Mutex},
23};
24
25#[derive(Debug)]
27pub struct ObjectPool<T> {
28 queue: Mutex<VecDeque<T>>,
30
31 capacity: Option<usize>,
34}
35
36impl<T> ObjectPool<T> {
37 pub fn new<A>(args: A, initial_size: usize, capacity: Option<usize>) -> Self
44 where
45 T: AsPooled<A>,
46 A: Clone,
47 {
48 let queue = (0..initial_size).map(|_| T::create(args.clone())).collect();
49
50 Self {
51 queue: Mutex::new(queue),
52 capacity,
53 }
54 }
55
56 pub fn try_new<A>(
63 args: A,
64 initial_size: usize,
65 capacity: Option<usize>,
66 ) -> Result<Self, T::Error>
67 where
68 T: TryAsPooled<A>,
69 A: Clone,
70 {
71 let queue = (0..initial_size)
72 .map(|_| T::try_create(args.clone()))
73 .collect::<Result<_, _>>()?;
74 Ok(Self {
75 queue: Mutex::new(queue),
76 capacity,
77 })
78 }
79
80 pub fn try_get_ref<A>(&self, args: A) -> Result<PooledRef<'_, T>, T::Error>
87 where
88 T: TryAsPooled<A>,
89 {
90 let item = self.try_get_or_create(args)?;
91 Ok(PooledRef {
92 item: ManuallyDrop::new(item),
93 parent: self,
94 })
95 }
96
97 pub fn get_ref<A>(&self, args: A) -> PooledRef<'_, T>
106 where
107 T: AsPooled<A>,
108 {
109 let item = self.get_or_create(args);
110 PooledRef {
111 item: ManuallyDrop::new(item),
112 parent: self,
113 }
114 }
115
116 pub fn try_get<A>(self: &Arc<Self>, args: A) -> Result<PooledArc<T>, T::Error>
123 where
124 T: TryAsPooled<A>,
125 {
126 let item = self.try_get_or_create(args)?;
127 Ok(PooledArc {
128 item: ManuallyDrop::new(item),
129 parent: self.clone(),
130 })
131 }
132
133 pub fn get<A>(self: &Arc<Self>, args: A) -> PooledArc<T>
142 where
143 T: AsPooled<A>,
144 {
145 let item = self.get_or_create(args);
146 PooledArc {
147 item: ManuallyDrop::new(item),
148 parent: self.clone(),
149 }
150 }
151
152 pub fn len(&self) -> usize {
154 self.lock().len()
155 }
156
157 pub fn is_empty(&self) -> bool {
159 self.len() == 0
160 }
161
162 fn try_get_or_create<A>(&self, args: A) -> Result<T, T::Error>
167 where
168 T: TryAsPooled<A>,
169 {
170 let maybe = self.lock().pop_front();
174 if let Some(mut item) = maybe {
175 item.try_modify(args)?;
176 Ok(item)
177 } else {
178 T::try_create(args)
179 }
180 }
181
182 fn get_or_create<A>(&self, args: A) -> T
183 where
184 T: AsPooled<A>,
185 {
186 let maybe = self.lock().pop_front();
190 if let Some(mut item) = maybe {
191 item.modify(args);
192 item
193 } else {
194 T::create(args)
195 }
196 }
197
198 fn lock(&self) -> std::sync::MutexGuard<'_, VecDeque<T>> {
199 match self.queue.lock() {
200 Ok(guard) => guard,
201 Err(poisoned) => {
202 self.queue.clear_poison();
212 poisoned.into_inner()
213 }
214 }
215 }
216}
217
/// Fallible construction and in-place re-initialisation of a poolable value from arguments of
/// type `A`.
///
/// This is the counterpart of [`AsPooled`] for types whose setup can fail.
pub trait TryAsPooled<A>: Sized {
    /// Error reported when creating or modifying a value fails.
    type Error;

    /// Attempts to build a brand-new value from `args`.
    fn try_create(args: A) -> Result<Self, Self::Error>;

    /// Attempts to re-initialise `self` in place from `args`.
    fn try_modify(&mut self, args: A) -> Result<(), Self::Error>;
}
244
/// Infallible construction and in-place re-initialisation of a poolable value from arguments
/// of type `A`.
///
/// [`ObjectPool`] calls `create` when it has no idle object to hand out and `modify` on an
/// idle object it is about to hand out again.
pub trait AsPooled<A> {
    /// Builds a brand-new value from `args`.
    fn create(args: A) -> Self;

    /// Re-initialises `self` in place from `args`.
    fn modify(&mut self, args: A);
}
264
/// Argument type that carries nothing but a requested length.
#[derive(Debug, Clone, Copy)]
pub struct Undef {
    /// Requested number of elements.
    pub len: usize,
}

impl Undef {
    /// Wraps `len` in an [`Undef`].
    pub fn new(len: usize) -> Self {
        Undef { len }
    }
}
283
284impl<T> AsPooled<Undef> for Vec<T>
285where
286 T: Default + Clone,
287{
288 fn create(undef: Undef) -> Self {
289 vec![T::default(); undef.len]
290 }
291
292 fn modify(&mut self, undef: Undef) {
293 self.resize(undef.len, T::default())
294 }
295}
296
297#[derive(Debug)]
301pub struct PooledRef<'a, T> {
302 item: ManuallyDrop<T>,
303 parent: &'a ObjectPool<T>,
304}
305
306impl<T> Drop for PooledRef<'_, T> {
307 fn drop(&mut self) {
308 let mut guard = self.parent.lock();
309 if guard.len() < self.parent.capacity.unwrap_or(usize::MAX) {
310 guard.push_back(unsafe { ManuallyDrop::take(&mut self.item) });
312 } else {
313 std::mem::drop(guard);
315
316 unsafe { ManuallyDrop::drop(&mut self.item) };
318 }
319 }
320}
321
322impl<T> Deref for PooledRef<'_, T> {
323 type Target = T;
324
325 fn deref(&self) -> &Self::Target {
326 &self.item
327 }
328}
329
330impl<T> DerefMut for PooledRef<'_, T> {
331 fn deref_mut(&mut self) -> &mut Self::Target {
332 &mut self.item
333 }
334}
335
336#[derive(Debug)]
342pub struct PooledArc<T> {
343 item: ManuallyDrop<T>,
344 parent: Arc<ObjectPool<T>>,
345}
346
347impl<T> Drop for PooledArc<T> {
348 fn drop(&mut self) {
349 let mut guard = self.parent.lock();
350 if guard.len() < self.parent.capacity.unwrap_or(usize::MAX) {
351 guard.push_back(unsafe { ManuallyDrop::take(&mut self.item) });
353 } else {
354 std::mem::drop(guard);
356
357 unsafe { ManuallyDrop::drop(&mut self.item) };
359 }
360 }
361}
362
363impl<T> Deref for PooledArc<T> {
364 type Target = T;
365
366 fn deref(&self) -> &Self::Target {
367 &self.item
368 }
369}
370
371impl<T> DerefMut for PooledArc<T> {
372 fn deref_mut(&mut self) -> &mut Self::Target {
373 &mut self.item
374 }
375}
376
377#[derive(Debug)]
380pub enum PoolOption<T> {
381 NonPooled(T),
382 Pooled(PooledArc<T>),
383}
384
385impl<T> PoolOption<T> {
386 pub fn non_pooled(item: T) -> Self {
387 PoolOption::NonPooled(item)
388 }
389
390 pub fn try_non_pooled_create<A>(args: A) -> Result<Self, T::Error>
391 where
392 T: TryAsPooled<A>,
393 {
394 Ok(PoolOption::NonPooled(T::try_create(args)?))
395 }
396
397 pub fn non_pooled_create<A>(args: A) -> Self
398 where
399 T: AsPooled<A>,
400 {
401 PoolOption::NonPooled(T::create(args))
402 }
403
404 pub fn pooled<A>(pool: &Arc<ObjectPool<T>>, args: A) -> Self
405 where
406 T: AsPooled<A>,
407 {
408 PoolOption::Pooled(pool.get(args))
409 }
410
411 pub fn try_pooled<A>(pool: &Arc<ObjectPool<T>>, args: A) -> Result<Self, T::Error>
412 where
413 T: TryAsPooled<A>,
414 {
415 Ok(PoolOption::Pooled(pool.try_get(args)?))
416 }
417
418 pub fn is_pooled(&self) -> bool {
419 matches!(self, PoolOption::Pooled(_))
420 }
421
422 pub fn is_non_pooled(&self) -> bool {
423 matches!(self, PoolOption::NonPooled(_))
424 }
425}
426
427impl<T> Deref for PoolOption<T> {
428 type Target = T;
429
430 fn deref(&self) -> &Self::Target {
431 match self {
432 PoolOption::NonPooled(item) => item,
433 PoolOption::Pooled(item) => item,
434 }
435 }
436}
437
438impl<T> DerefMut for PoolOption<T> {
439 fn deref_mut(&mut self) -> &mut Self::Target {
440 match self {
441 PoolOption::NonPooled(item) => item,
442 PoolOption::Pooled(item) => item,
443 }
444 }
445}
446
447#[cfg(test)]
452mod tests {
453 use super::*;
454
455 #[derive(Debug)]
456 struct TestItem {
457 value: Box<u32>,
458 panic_on_drop: bool,
459 }
460
461 impl TestItem {
462 fn new(value: u32) -> Self {
463 Self {
464 value: Box::new(value),
465 panic_on_drop: false,
466 }
467 }
468 }
469
470 impl AsPooled<u32> for TestItem {
471 fn create(value: u32) -> Self {
472 TestItem::new(value)
473 }
474
475 fn modify(&mut self, value: u32) {
476 *self.value = value;
477 self.panic_on_drop = false;
478 }
479 }
480
481 impl TryAsPooled<i32> for TestItem {
482 type Error = ();
483
484 fn try_create(value: i32) -> Result<Self, Self::Error> {
485 match value.try_into() {
486 Ok(v) => Ok(TestItem::new(v)),
487 Err(_) => Err(()),
488 }
489 }
490
491 fn try_modify(&mut self, value: i32) -> Result<(), Self::Error> {
492 match value.try_into() {
493 Ok(v) => {
494 *self.value = v;
495 self.panic_on_drop = false;
496 Ok(())
497 }
498 Err(_) => Err(()),
499 }
500 }
501 }
502
503 impl Drop for TestItem {
504 fn drop(&mut self) {
505 if self.panic_on_drop {
506 panic!("panicking on drop");
507 }
508 }
509 }
510
511 struct TestPanic;
513
514 impl AsPooled<TestPanic> for TestItem {
515 fn create(_: TestPanic) -> Self {
516 panic!("panicking on create")
517 }
518
519 fn modify(&mut self, _: TestPanic) {
520 panic!("panicking on modify")
521 }
522 }
523
524 impl TryAsPooled<TestPanic> for TestItem {
525 type Error = ();
526
527 fn try_create(_: TestPanic) -> Result<Self, Self::Error> {
528 panic!("panicking on try_create")
529 }
530
531 fn try_modify(&mut self, _: TestPanic) -> Result<(), Self::Error> {
532 panic!("panicking on try_modify")
533 }
534 }
535
/// Checkout, on-demand creation and hand-back through `PooledRef` handles.
#[test]
fn test_pool_basic_tests() {
    let pool = ObjectPool::<TestItem>::new(42, 2, None);
    assert_eq!(pool.len(), 2);

    // The two pre-built objects are handed out first.
    let first = pool.get_ref(100);
    assert_eq!(*first.value, 100);
    assert_eq!(pool.len(), 1);

    let second = pool.get_ref(200);
    assert_eq!(*second.value, 200);
    assert_eq!(pool.len(), 0);

    // With nothing idle, a new object is created on demand.
    let third = pool.get_ref(300);
    assert_eq!(*third.value, 300);
    assert_eq!(pool.len(), 0);

    // Dropping a handle parks its object in the pool.
    let short_lived = pool.get_ref(400);
    assert_eq!(*short_lived.value, 400);
    assert_eq!(pool.len(), 0);
    drop(short_lived);
    assert_eq!(pool.len(), 1);

    // One recycled object plus one new object; both come back afterwards.
    let recycled = pool.get_ref(500);
    assert_eq!(*recycled.value, 500);
    assert_eq!(pool.len(), 0);
    let created = pool.get_ref(600);
    assert_eq!(*created.value, 600);
    assert_eq!(pool.len(), 0);
    drop(created);
    drop(recycled);
    assert_eq!(pool.len(), 2);

    // A handle may still be alive when the test (and its pool) goes out of scope.
    let pool = ObjectPool::<TestItem>::new(42, 1, None);
    let item = pool.get_ref(100);
    assert_eq!(*item.value, 100);
}
572
/// Same scenario as the basic test, driven through the fallible `try_*` API.
#[test]
fn test_pool_basic_tests_with_try() {
    // A negative argument cannot be converted, so construction fails.
    let failed = ObjectPool::<TestItem>::try_new(-1, 2, Some(100));
    assert!(
        failed.is_err(),
        "Pool creation should fail with negative args"
    );

    let pool = ObjectPool::<TestItem>::try_new(42, 2, None).unwrap();
    assert_eq!(pool.len(), 2);

    let first = pool.try_get_ref(100).unwrap();
    assert_eq!(*first.value, 100);
    assert_eq!(pool.len(), 1);

    let second = pool.try_get_ref(200).unwrap();
    assert_eq!(*second.value, 200);
    assert_eq!(pool.len(), 0);

    // With nothing idle, a new object is created on demand.
    let third = pool.try_get_ref(300).unwrap();
    assert_eq!(*third.value, 300);
    assert_eq!(pool.len(), 0);

    // Dropping a handle parks its object in the pool.
    let short_lived = pool.try_get_ref(400).unwrap();
    assert_eq!(*short_lived.value, 400);
    assert_eq!(pool.len(), 0);
    drop(short_lived);
    assert_eq!(pool.len(), 1);

    // One recycled object plus one new object; both come back afterwards.
    let recycled = pool.try_get_ref(500).unwrap();
    assert_eq!(*recycled.value, 500);
    assert_eq!(pool.len(), 0);
    let created = pool.try_get_ref(600).unwrap();
    assert_eq!(*created.value, 600);
    assert_eq!(pool.len(), 0);
    drop(created);
    drop(recycled);
    assert_eq!(pool.len(), 2);

    // A handle may still be alive when the test (and its pool) goes out of scope.
    let pool = ObjectPool::<TestItem>::try_new(42, 1, Some(100)).unwrap();
    let item = pool.try_get_ref(100).unwrap();
    assert_eq!(*item.value, 100);
}
613
/// Exercises the `Arc`-based handles, mixed with a borrowed handle on the same pool.
#[test]
fn test_pool_with_arc() {
    let shared = Arc::new(ObjectPool::<TestItem>::new(42, 1, None));

    // Takes the single pre-built object.
    let first = shared.get(100);
    assert_eq!(*first.value, 100);
    assert_eq!(shared.len(), 0);

    // Pool is empty, so this one is created on demand.
    let second = shared.get(200);
    assert_eq!(*second.value, 200);
    assert_eq!(shared.len(), 0);

    // Dropping an `Arc` handle parks its object in the pool.
    let short_lived = shared.get(400);
    assert_eq!(*short_lived.value, 400);
    assert_eq!(shared.len(), 0);
    drop(short_lived);
    assert_eq!(shared.len(), 1);

    // A borrowed handle can be taken from a pool that lives in an `Arc`.
    let borrowed = shared.try_get_ref(400).unwrap();
    assert_eq!(*borrowed.value, 400);
    assert_eq!(shared.len(), 0);

    let fallible = shared.try_get(500).unwrap();
    assert_eq!(*fallible.value, 500);
}
636
/// With `capacity == 1`, surplus objects returned through `PooledRef` are discarded.
#[test]
fn test_pool_max_capacity_ref() {
    let pool = ObjectPool::<TestItem>::new(42, 1, Some(1));
    let limit = pool.capacity.unwrap();

    assert_eq!(pool.len(), 1);
    assert!(!pool.is_empty());
    assert_eq!(pool.len(), limit);

    // Single checkout and return: the pool is back at capacity.
    let only = pool.get_ref(100);
    assert_eq!(pool.len(), 0);
    assert!(pool.is_empty());
    assert!(pool.len() < limit);
    assert_eq!(*only.value, 100);
    drop(only);

    assert_eq!(pool.len(), 1);
    assert_eq!(pool.len(), limit);

    // Three objects out at once; only one of them fits back in.
    let first = pool.get_ref(100);
    assert_eq!(pool.len(), 0);
    let second = pool.get_ref(200);
    assert_eq!(pool.len(), 0);
    let third = pool.get_ref(300);
    assert_eq!(pool.len(), 0);
    assert!(*first.value == 100 && *second.value == 200 && *third.value == 300);
    drop(third);
    drop(second);
    drop(first);

    assert_eq!(pool.len(), limit);
    assert_eq!(pool.len(), 1);
}
665
/// With `capacity == 1`, surplus objects returned through `PooledArc` are discarded.
#[test]
fn test_pool_max_capacity_pooled_item() {
    let shared = Arc::new(ObjectPool::<TestItem>::new(42, 1, Some(1)));
    let limit = shared.capacity.unwrap();

    assert_eq!(shared.len(), 1);
    assert_eq!(shared.len(), limit);

    // Single checkout and return: the pool is back at capacity.
    let only = shared.get(100);
    assert_eq!(shared.len(), 0);
    assert!(shared.len() < limit);
    assert_eq!(*only.value, 100);
    drop(only);

    assert_eq!(shared.len(), 1);
    assert_eq!(shared.len(), limit);

    // Three objects out at once; only one of them fits back in.
    let first = shared.get(100);
    assert_eq!(shared.len(), 0);
    let second = shared.get(200);
    assert_eq!(shared.len(), 0);
    let third = shared.get(300);
    assert_eq!(shared.len(), 0);
    assert!(*first.value == 100 && *second.value == 200 && *third.value == 300);
    drop(third);
    drop(second);
    drop(first);

    assert_eq!(shared.len(), limit);
    assert_eq!(shared.len(), 1);
}
692
/// Covers every `PoolOption` constructor and both variant predicates.
#[test]
fn test_pool_options() {
    // Wrapping an already-built value.
    let wrapped = PoolOption::non_pooled(TestItem::new(42));
    assert_eq!(*wrapped.value, 42);
    assert!(wrapped.is_non_pooled());
    assert!(!wrapped.is_pooled());

    // Building a non-pooled value from arguments.
    let created = PoolOption::<TestItem>::non_pooled_create(100);
    assert_eq!(*created.value, 100);
    assert!(created.is_non_pooled());

    // Checking values out of a shared pool.
    let pool = Arc::new(ObjectPool::<TestItem>::new(42, 1, None));
    let pooled = PoolOption::pooled(&pool, 100);
    assert_eq!(*pooled.value, 100);
    assert!(pooled.is_pooled());
    assert!(!pooled.is_non_pooled());

    let try_pooled = PoolOption::try_pooled(&pool, 100).unwrap();
    assert_eq!(*try_pooled.value, 100);
    assert!(try_pooled.is_pooled());

    // Fallible non-pooled construction: success, then failure.
    let try_created = PoolOption::<TestItem>::try_non_pooled_create(200).unwrap();
    assert_eq!(*try_created.value, 200);
    assert!(try_created.is_non_pooled());
    assert!(!try_created.is_pooled());

    let rejected = PoolOption::<TestItem>::try_non_pooled_create(-200);
    assert!(
        rejected.is_err(),
        "Creating non-pooled item with negative args should fail"
    );
}
726
/// Reads and writes through `Deref`/`DerefMut` on every handle type.
#[test]
fn test_pool_ref_deref_mut() {
    // `PooledRef` obtained from a plain pool.
    let pool = ObjectPool::<TestItem>::new(42, 1, None);
    let read_only = pool.get_ref(100);
    assert_eq!(*read_only.value, 100);

    let mut handle = pool.get_ref(100);
    *handle.value = 200;
    assert_eq!(*handle.value, 200);

    // Explicit coercions to plain references.
    let as_shared: &TestItem = &handle;
    assert_eq!(*as_shared.value, 200);

    let as_exclusive: &mut TestItem = &mut handle;
    assert_eq!(*as_exclusive.value, 200);

    *as_exclusive.value = 300;
    assert_eq!(*as_exclusive.value, 300);

    // `PooledRef` obtained from a pool that lives in an `Arc`.
    let shared_pool = Arc::new(ObjectPool::<TestItem>::new(42, 1, None));
    let mut borrowed = shared_pool.get_ref(100);
    assert_eq!(*borrowed.value, 100);

    *borrowed.value = 200;
    assert_eq!(*borrowed.value, 200);

    // `PoolOption::Pooled`.
    let option_pool = Arc::new(ObjectPool::<TestItem>::new(42, 1, None));
    let mut pooled = PoolOption::pooled(&option_pool, 100);
    assert_eq!(*pooled.value, 100);

    *pooled.value = 200;
    assert_eq!(*pooled.value, 200);

    // `PoolOption::NonPooled`.
    let mut plain = PoolOption::non_pooled(TestItem::new(42));
    assert_eq!(*plain.value, 42);

    *plain.value = 100;
    assert_eq!(*plain.value, 100);
}
769
/// Asserts that a caught panic payload is a message containing `contains`.
///
/// `err` is the payload returned by `std::panic::catch_unwind`. A panic raised with a plain
/// string literal carries a `&'static str`, while a panic raised with format arguments carries
/// a `String`; both are accepted.
///
/// # Panics
///
/// Panics if the payload is neither a `&'static str` nor a `String`, or if the message does
/// not contain `contains`.
fn check_error(err: &dyn std::any::Any, contains: &str) {
    let msg: &str = if let Some(literal) = err.downcast_ref::<&'static str>() {
        *literal
    } else if let Some(formatted) = err.downcast_ref::<String>() {
        formatted.as_str()
    } else {
        panic!("incorrect downcast type")
    };

    assert!(
        msg.contains(contains),
        "failed: message \"{}\" does not contain \"{}\"",
        msg,
        contains
    );
}
799
/// A panic inside `AsPooled::create` must not poison the pool's mutex.
#[test]
fn test_panic_during_create() {
    let pool = ObjectPool::<TestItem>::new(0u32, 0, Some(1));

    // The pool is empty, so `get_ref` goes straight to `create`, which panics.
    let outcome = std::panic::catch_unwind(|| {
        drop(pool.get_ref(TestPanic));
    });
    let payload = outcome.unwrap_err();

    check_error(&*payload, "panicking on create");

    assert!(
        !pool.queue.is_poisoned(),
        "lock should be released while calling trait implementations"
    );

    assert_eq!(pool.len(), 0);
}
819
/// A panic inside `TryAsPooled::try_create` must not poison the pool's mutex.
#[test]
fn test_panic_during_try_create() {
    let pool = ObjectPool::<TestItem>::new(0u32, 0, Some(1));

    // The pool is empty, so `try_get_ref` goes straight to `try_create`, which panics.
    let outcome = std::panic::catch_unwind(|| {
        let _ = pool.try_get_ref(TestPanic);
    });
    let payload = outcome.unwrap_err();

    check_error(&*payload, "panicking on try_create");

    assert!(
        !pool.queue.is_poisoned(),
        "lock should be released while calling trait implementations"
    );

    assert_eq!(pool.len(), 0);
}
839
/// A panic inside `AsPooled::modify` must neither poison the mutex nor re-pool the object.
#[test]
fn test_panic_during_modify() {
    let pool = ObjectPool::<TestItem>::new(0u32, 0, Some(1));

    // Create an object and drop its handle immediately so the pool holds one idle object.
    drop(pool.get_ref(0u32));
    assert_eq!(pool.len(), 1);

    // The idle object is popped and `modify` panics on it.
    let outcome = std::panic::catch_unwind(|| {
        drop(pool.get_ref(TestPanic));
    });
    let payload = outcome.unwrap_err();

    check_error(&*payload, "panicking on modify");

    assert!(
        !pool.queue.is_poisoned(),
        "lock should be released while calling trait implementations"
    );

    assert_eq!(
        pool.len(),
        0,
        "we should not return a potentially torn object to the pool"
    );
}
866
/// A panic inside `TryAsPooled::try_modify` must neither poison the mutex nor re-pool the
/// object.
#[test]
fn test_panic_during_try_modify() {
    let pool = ObjectPool::<TestItem>::new(0u32, 0, Some(1));

    // Create an object and drop its handle immediately so the pool holds one idle object.
    drop(pool.get_ref(0u32));
    assert_eq!(pool.len(), 1);

    // The idle object is popped and `try_modify` panics on it.
    let outcome = std::panic::catch_unwind(|| {
        let _ = pool.try_get_ref(TestPanic);
    });
    let payload = outcome.unwrap_err();

    check_error(&*payload, "panicking on try_modify");

    assert!(
        !pool.queue.is_poisoned(),
        "lock should be released while calling trait implementations"
    );

    assert_eq!(
        pool.len(),
        0,
        "we should not return a potentially torn object to the pool"
    );
}
893
/// When the pool is full, a `PooledRef` destroys its object outside the lock, so a panicking
/// destructor must not poison the mutex.
#[test]
fn test_panic_during_drop_ref() {
    let pool = ObjectPool::<TestItem>::new(0u32, 0, Some(1));

    // Keep one object checked out, then fill the pool to capacity with a second one.
    let mut armed = pool.get_ref(0u32);
    drop(pool.get_ref(1u32));
    assert_eq!(pool.len(), 1);

    // The pool is full, so dropping `armed` runs its (panicking) destructor.
    armed.panic_on_drop = true;
    let outcome = std::panic::catch_unwind(move || std::mem::drop(armed));
    let payload = outcome.unwrap_err();
    check_error(&*payload, "panicking on drop");

    assert!(
        !pool.queue.is_poisoned(),
        "lock should be released while calling object drop"
    );
    assert_eq!(pool.len(), 1);
}
915
/// When the pool is full, a `PooledArc` destroys its object outside the lock, so a panicking
/// destructor must not poison the mutex.
#[test]
fn test_panic_during_drop_arc() {
    let pool = Arc::new(ObjectPool::<TestItem>::new(0u32, 0, Some(1)));

    // Keep one object checked out, then fill the pool to capacity with a second one.
    let mut armed = pool.get(0u32);
    drop(pool.get(1u32));
    assert_eq!(pool.len(), 1);

    // The pool is full, so dropping `armed` runs its (panicking) destructor.
    armed.panic_on_drop = true;
    let outcome = std::panic::catch_unwind(move || std::mem::drop(armed));
    let payload = outcome.unwrap_err();
    check_error(&*payload, "panicking on drop");

    assert!(
        !pool.queue.is_poisoned(),
        "lock should be released while calling object drop"
    );
    assert_eq!(pool.len(), 1);
}
937
/// A mutex poisoned behind the pool's back is un-poisoned by the next pool operation.
#[test]
fn test_panic_recovery() {
    let pool = ObjectPool::<TestItem>::new(0u32, 1, Some(1));

    // Poison the mutex on purpose by panicking while its guard is alive.
    let outcome = std::panic::catch_unwind(|| {
        let _guard = pool.queue.lock();
        panic!("yeet");
    });
    let payload = outcome.unwrap_err();

    check_error(&*payload, "yeet");

    assert!(pool.queue.is_poisoned());

    // Any operation that takes the lock clears the poison flag.
    drop(pool.get_ref(1u32));
    assert!(!pool.queue.is_poisoned(), "poison should be cleared");
}
956
/// `Vec<T>` pooled via `Undef` always ends up with exactly the requested length.
#[test]
fn test_undef() {
    let mut buffer: Vec<f32> = Vec::<f32>::create(Undef::new(10));
    assert_eq!(buffer.len(), 10);

    // Shrink to empty, then grow past the original size.
    for requested in [0usize, 20] {
        buffer.modify(Undef::new(requested));
        assert_eq!(buffer.len(), requested);
    }
}
972}