byte_pool/
pool.rs

1use std::fmt;
2use std::mem;
3use std::ops::{Deref, DerefMut};
4use std::ptr;
5
6use crossbeam_queue::SegQueue;
7use stable_deref_trait::StableDeref;
8
9use crate::poolable::{Poolable, Realloc};
10
11/// A pool of byte slices, that reuses memory.
12#[derive(Debug)]
13pub struct BytePool<T = Vec<u8>>
14where
15    T: Poolable,
16{
17    list_large: SegQueue<T>,
18    list_small: SegQueue<T>,
19}
20
/// Capacity threshold that separates the two free lists of a `BytePool`.
///
/// Values whose capacity is strictly below this are kept in the small list;
/// values at or above it are kept in the large list.
const SPLIT_SIZE: usize = 4 * 1024;
24
25/// The value returned by an allocation of the pool.
26/// When it is dropped the memory gets returned into the pool, and is not zeroed.
27/// If that is a concern, you must clear the data yourself.
28pub struct Block<'a, T: Poolable = Vec<u8>> {
29    data: mem::ManuallyDrop<T>,
30    pool: &'a BytePool<T>,
31}
32
33impl<T: Poolable + fmt::Debug> fmt::Debug for Block<'_, T> {
34    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
35        f.debug_struct("Block").field("data", &self.data).finish()
36    }
37}
38
39impl<T: Poolable> Default for BytePool<T> {
40    fn default() -> Self {
41        BytePool::<T> {
42            list_large: SegQueue::new(),
43            list_small: SegQueue::new(),
44        }
45    }
46}
47
48impl<T: Poolable> BytePool<T> {
49    /// Constructs a new pool.
50    pub fn new() -> Self {
51        BytePool::default()
52    }
53
54    /// Allocates a new `Block`, which represents a fixed sice byte slice.
55    /// If `Block` is dropped, the memory is _not_ freed, but rather it is returned into the pool.
56    /// The returned `Block` contains arbitrary data, and must be zeroed or overwritten,
57    /// in cases this is needed.
58    pub fn alloc(&self, size: usize) -> Block<'_, T> {
59        assert!(size > 0, "Can not allocate empty blocks");
60
61        // check the last 4 blocks
62        let list = if size < SPLIT_SIZE {
63            &self.list_small
64        } else {
65            &self.list_large
66        };
67        if let Some(el) = list.pop() {
68            if el.capacity() == size {
69                // found one, reuse it
70                return Block::new(el, self);
71            } else {
72                // put it back
73                list.push(el);
74            }
75        }
76
77        // allocate a new block
78        let data = T::alloc(size);
79        Block::new(data, self)
80    }
81
82    fn push_raw_block(&self, block: T) {
83        if block.capacity() < SPLIT_SIZE {
84            self.list_small.push(block);
85        } else {
86            self.list_large.push(block);
87        }
88    }
89}
90
91impl<'a, T: Poolable> Drop for Block<'a, T> {
92    fn drop(&mut self) {
93        let data = mem::ManuallyDrop::into_inner(unsafe { ptr::read(&self.data) });
94        self.pool.push_raw_block(data);
95    }
96}
97
98impl<'a, T: Poolable> Block<'a, T> {
99    fn new(data: T, pool: &'a BytePool<T>) -> Self {
100        Block {
101            data: mem::ManuallyDrop::new(data),
102            pool,
103        }
104    }
105
106    /// Returns the amount of bytes this block has.
107    pub fn size(&self) -> usize {
108        self.data.capacity()
109    }
110}
111
112impl<'a, T: Poolable + Realloc> Block<'a, T> {
113    /// Resizes a block to a new size.
114    pub fn realloc(&mut self, new_size: usize) {
115        self.data.realloc(new_size);
116    }
117}
118
119impl<'a, T: Poolable> Deref for Block<'a, T> {
120    type Target = T;
121
122    #[inline]
123    fn deref(&self) -> &Self::Target {
124        self.data.deref()
125    }
126}
127
128impl<'a, T: Poolable> DerefMut for Block<'a, T> {
129    #[inline]
130    fn deref_mut(&mut self) -> &mut Self::Target {
131        self.data.deref_mut()
132    }
133}
134
135// Safe because Block is just a wrapper around `T`.
136unsafe impl<'a, T: StableDeref + Poolable> StableDeref for Block<'a, T> {}
137
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn basics_vec_u8() {
        let pool: BytePool<Vec<u8>> = BytePool::new();

        for i in 0..100 {
            // 1 KiB is below `SPLIT_SIZE` and 4 KiB is exactly `SPLIT_SIZE`,
            // so both the small and the large free list are exercised.
            let mut block_1k = pool.alloc(1024);
            let mut block_4k = pool.alloc(4 * 1024);

            for el in block_1k.deref_mut() {
                *el = i as u8;
            }

            for el in block_4k.deref_mut() {
                *el = i as u8;
            }

            for el in block_1k.deref() {
                assert_eq!(*el, i as u8);
            }

            for el in block_4k.deref() {
                assert_eq!(*el, i as u8);
            }
        }
    }

    #[test]
    fn realloc() {
        let pool: BytePool<Vec<u8>> = BytePool::new();

        let mut buf = pool.alloc(10);

        // The block must coerce to a plain byte slice through `Deref`.
        let _slice: &[u8] = &buf;

        assert_eq!(buf.capacity(), 10);
        for el in &mut buf[..10] {
            *el = 1;
        }

        // Growing must preserve the existing prefix.
        buf.realloc(512);
        assert_eq!(buf.capacity(), 512);
        for el in buf.iter().take(10) {
            assert_eq!(*el, 1);
        }

        // Shrinking must keep the surviving elements.
        buf.realloc(5);
        assert_eq!(buf.capacity(), 5);
        for el in buf.iter() {
            assert_eq!(*el, 1);
        }
    }

    #[test]
    fn multi_thread() {
        let pool = std::sync::Arc::new(BytePool::<Vec<u8>>::new());

        let pool1 = std::sync::Arc::clone(&pool);
        let h1 = std::thread::spawn(move || {
            for _ in 0..100 {
                let mut buf = pool1.alloc(64);
                buf[10] = 10;
            }
        });

        let pool2 = std::sync::Arc::clone(&pool);
        let h2 = std::thread::spawn(move || {
            for _ in 0..100 {
                let mut buf = pool2.alloc(64);
                buf[10] = 10;
            }
        });

        h1.join().unwrap();
        h2.join().unwrap();

        // Two threads allocating in parallel need at most 2 buffers.
        assert!(pool.list_small.len() <= 2);
    }

    #[test]
    fn basics_vec_usize() {
        let pool: BytePool<Vec<usize>> = BytePool::new();

        for i in 0..100 {
            let mut block_1k = pool.alloc(1024);
            let mut block_4k = pool.alloc(4 * 1024);

            for el in block_1k.deref_mut() {
                *el = i;
            }

            for el in block_4k.deref_mut() {
                *el = i;
            }

            for el in block_1k.deref() {
                assert_eq!(*el, i);
            }

            for el in block_4k.deref() {
                assert_eq!(*el, i);
            }
        }
    }

    #[test]
    fn basics_hash_map() {
        use std::collections::HashMap;
        let pool: BytePool<HashMap<String, String>> = BytePool::new();

        let mut map = pool.alloc(4);
        for i in 0..4 {
            map.insert(format!("hello_{}", i), "world".into());
        }
        for i in 0..4 {
            assert_eq!(
                map.get(&format!("hello_{}", i)).unwrap(),
                &"world".to_string()
            );
        }
        drop(map);

        // NOTE(review): freshly allocated maps are presumably empty, so the
        // loops below may never execute their bodies; they mainly check that
        // allocation and return of map-backed blocks does not panic.
        for i in 0..100 {
            let mut block_1k = pool.alloc(1024);
            let mut block_4k = pool.alloc(4 * 1024);

            for el in block_1k.deref_mut() {
                *el.1 = i.to_string();
            }

            for el in block_4k.deref_mut() {
                *el.1 = i.to_string();
            }

            for el in block_1k.deref() {
                assert_eq!(*el.0, i.to_string());
                assert_eq!(*el.1, i.to_string());
            }

            for el in block_4k.deref() {
                assert_eq!(*el.0, i.to_string());
                assert_eq!(*el.1, i.to_string());
            }
        }
    }
}