// db_pool/sync/object_pool.rs
//
// adapted from https://github.com/CJP10/object-pool and https://github.com/EVaillant/lockfree-object-pool

use std::ops::{Deref, DerefMut};

use parking_lot::Mutex;

/// Storage for idle objects, used as a LIFO stack (`push`/`pop` at the end).
type Stack<T> = Vec<T>;
/// Factory invoked when the pool has no idle object to hand out.
type Init<T> = Box<dyn Fn() -> T + Send + Sync + 'static>;
/// Hook applied to a recycled object before it is handed out again.
type Reset<T> = Box<dyn Fn(&mut T) + Send + Sync + 'static>;

10/// Object pool
11pub struct ObjectPool<T> {
12    objects: Mutex<Stack<T>>,
13    init: Init<T>,
14    reset: Reset<T>,
15}
16
17impl<T> ObjectPool<T> {
18    pub(crate) fn new(
19        init: impl Fn() -> T + Send + Sync + 'static,
20        reset: impl Fn(&mut T) + Send + Sync + 'static,
21    ) -> ObjectPool<T> {
22        ObjectPool {
23            objects: Mutex::new(Vec::new()),
24            init: Box::new(init),
25            reset: Box::new(reset),
26        }
27    }
28
29    pub(crate) fn pull(&self) -> Reusable<T> {
30        self.objects.lock().pop().map_or_else(
31            || Reusable::new(self, (self.init)()),
32            |mut data| {
33                (self.reset)(&mut data);
34                Reusable::new(self, data)
35            },
36        )
37    }
38
39    fn attach(&self, t: T) {
40        self.objects.lock().push(t);
41    }
42}
43
44/// Reusable object wrapper
45pub struct Reusable<'a, T> {
46    pool: &'a ObjectPool<T>,
47    data: Option<T>,
48}
49
50impl<'a, T> Reusable<'a, T> {
51    fn new(pool: &'a ObjectPool<T>, t: T) -> Self {
52        Self {
53            pool,
54            data: Some(t),
55        }
56    }
57}
58
/// Panic message for a `Reusable` whose value has gone missing.
///
/// The value is only taken out while the guard is being dropped, so hitting this
/// message means an internal invariant was broken.
const DATA_MUST_CONTAIN_SOME: &str = "data must always contain a [Some] value";

61impl<T> Deref for Reusable<'_, T> {
62    type Target = T;
63
64    fn deref(&self) -> &Self::Target {
65        self.data.as_ref().expect(DATA_MUST_CONTAIN_SOME)
66    }
67}
68
69impl<T> DerefMut for Reusable<'_, T> {
70    fn deref_mut(&mut self) -> &mut Self::Target {
71        self.data.as_mut().expect(DATA_MUST_CONTAIN_SOME)
72    }
73}
74
75impl<T> Drop for Reusable<'_, T> {
76    fn drop(&mut self) {
77        self.pool
78            .attach(self.data.take().expect(DATA_MUST_CONTAIN_SOME));
79    }
80}
81
#[cfg(test)]
mod tests {
    use super::ObjectPool;

    impl<T> ObjectPool<T> {
        /// Test-only helper: how many idle objects are currently parked in the pool.
        fn len(&self) -> usize {
            let idle = self.objects.lock();
            idle.len()
        }
    }

    #[test]
    fn len() {
        // Sequential use: the second pull recycles the first object, so only one exists.
        let sequential = ObjectPool::<Vec<u8>>::new(Vec::new, |_| {});
        for _ in 0..2 {
            let borrowed = sequential.pull();
            drop(borrowed);
        }
        assert_eq!(sequential.len(), 1);

        // Overlapping use: both objects are live at once, so two are created and both return.
        let overlapping = ObjectPool::<Vec<u8>>::new(Vec::new, |_| {});
        let first = overlapping.pull();
        let second = overlapping.pull();
        drop(first);
        drop(second);
        assert_eq!(overlapping.len(), 2);
    }

    #[test]
    fn e2e() {
        let pool = ObjectPool::new(Vec::new, |_| {});

        // Check out ten objects at once, tagging each with its index.
        let checked_out: Vec<_> = (0..10)
            .map(|tag| {
                let mut object = pool.pull();
                object.push(tag);
                object
            })
            .collect();

        // Returning them in index order leaves the highest tag on top of the stack.
        drop(checked_out);

        for expected in (0..10).rev() {
            let mut recycled = pool.objects.lock().pop().expect("pool must have objects");
            assert_eq!(recycled.pop(), Some(expected));
        }
    }

    #[test]
    fn reset() {
        let pool = ObjectPool::new(Vec::new, Vec::clear);
        {
            let mut dirty = pool.pull();
            dirty.push(1);
        }
        assert_eq!(pool.pull().len(), 0);
    }

    #[test]
    fn no_reset() {
        let pool = ObjectPool::new(Vec::new, |_| {});
        {
            let mut dirty = pool.pull();
            dirty.push(1);
        }
        assert_eq!(pool.pull().len(), 1);
    }
}