fixed_size_buf/
lib.rs

1use std::{
2    ops::{Deref, DerefMut},
3    sync::{
4        atomic::{AtomicBool, Ordering},
5        Arc,
6    },
7    task::Waker,
8};
9
10use backing_buffer::BackingBuffer;
11use futures::BufGuardFuture;
12
13mod backing_buffer;
14mod futures;
15
16#[cfg(feature = "memmap")]
17pub use memmap2::MmapMut;
18
19pub struct Buffer<Inner, const BLOCKS: usize, const SIZE: usize> {
20    storage: Arc<BackingBuffer<Inner, BLOCKS, SIZE>>,
21    /// TRUE available FALSE occupied
22    registry: Arc<[AtomicBool; BLOCKS]>,
23    waiters: Arc<crossbeam_queue::SegQueue<Waker>>,
24}
25
26impl<Inner, const BLOCKS: usize, const SIZE: usize> Clone for Buffer<Inner, BLOCKS, SIZE> {
27    fn clone(&self) -> Self {
28        Self {
29            storage: self.storage.clone(),
30            registry: self.registry.clone(),
31            waiters: self.waiters.clone(),
32        }
33    }
34}
35
36impl<Inner, const BLOCKS: usize, const SIZE: usize> Buffer<Inner, BLOCKS, SIZE> {
37    #[cfg(feature = "unsafe_new")]
38    pub unsafe fn new(inner: Inner) -> Self {
39        let mut registry = Vec::with_capacity(BLOCKS);
40        for _ in 0..BLOCKS {
41            registry.push(AtomicBool::new(true));
42        }
43        Self {
44            storage: unsafe { Arc::new(BackingBuffer::new(inner)) },
45            registry: Arc::new(registry.try_into().unwrap()),
46            waiters: Default::default(),
47        }
48    }
49
50    fn free_block(&self, guard: &mut BufGuard<Inner, BLOCKS, SIZE>) {
51        let idx = guard.idx;
52        self.registry[idx].store(true, Ordering::SeqCst);
53        if let Some(waker) = self.waiters.pop() {
54            waker.wake();
55        }
56    }
57
58    fn push_waker(&self, waker: Waker) {
59        self.waiters.push(waker);
60    }
61}
62
63impl<Inner, const BLOCKS: usize, const SIZE: usize> Buffer<Inner, BLOCKS, SIZE>
64where
65    Inner: Deref<Target = [u8]> + DerefMut,
66{
67    pub async fn acquire_block(&self) -> BufGuard<Inner, BLOCKS, SIZE> {
68        let block = self.try_acquire_block();
69
70        if let Some(block) = block {
71            block
72        } else {
73            BufGuardFuture::new(self).await
74        }
75    }
76
77    fn try_acquire_block<'a>(&'a self) -> Option<BufGuard<Inner, BLOCKS, SIZE>> {
78        let mut idx = None;
79        for (i, e) in self.registry.iter().enumerate() {
80            if e.compare_exchange(true, false, Ordering::SeqCst, Ordering::SeqCst)
81                .is_ok()
82            {
83                idx = Some(i);
84                break;
85            }
86        }
87        let idx = idx?;
88
89        let block = unsafe { self.storage.get_block(idx) };
90        BufGuard {
91            block,
92            buffer: self.clone(),
93            idx,
94        }
95        .into()
96    }
97}
98
99impl<const BLOCKS: usize, const SIZE: usize> Buffer<Box<[u8]>, BLOCKS, SIZE> {
100    pub fn new_vec() -> Self {
101        let mut registry = Vec::with_capacity(BLOCKS);
102        for _ in 0..BLOCKS {
103            registry.push(AtomicBool::new(true));
104        }
105        Self {
106            storage: Arc::new(BackingBuffer::new_vec()),
107            registry: Arc::new(registry.try_into().unwrap()),
108            waiters: Default::default(),
109        }
110    }
111}
112
#[cfg(feature = "memmap")]
impl<const BLOCKS: usize, const SIZE: usize> Buffer<memmap2::MmapMut, BLOCKS, SIZE> {
    /// Creates a pool whose backing store is a memory map of `backing_file`.
    /// Every block starts out available.
    ///
    /// # Errors
    ///
    /// Propagates any I/O error returned by `BackingBuffer::new_memmap`.
    ///
    /// # Safety
    ///
    /// NOTE(review): presumably this inherits the usual `memmap2` mapping contract:
    /// nothing else may modify or truncate the file while the pool is alive.
    /// Confirm against `BackingBuffer::new_memmap`.
    pub unsafe fn new_memmap(backing_file: &std::path::Path) -> std::io::Result<Self> {
        Ok(Self {
            storage: Arc::new(BackingBuffer::new_memmap(backing_file)?),
            // Build the flag array in place rather than collecting into a `Vec`
            // and converting it with an `unwrap` that can never fail.
            registry: Arc::new(std::array::from_fn(|_| AtomicBool::new(true))),
            waiters: Default::default(),
        })
    }
}
127
#[cfg(feature = "bytes")]
impl<const BLOCKS: usize, const SIZE: usize> Buffer<bytes::BytesMut, BLOCKS, SIZE> {
    /// Creates a pool whose backing store is a `bytes::BytesMut`.
    /// Every block starts out available.
    ///
    /// # Errors
    ///
    /// The visible body has no error path and always returns `Ok`.
    ///
    /// # Safety
    ///
    /// NOTE(review): the only potentially unsafe operation here is the call to
    /// `BackingBuffer::new_bytes`. Confirm what contract, if any, callers must
    /// uphold.
    pub unsafe fn new_bytes() -> std::io::Result<Self> {
        Ok(Self {
            storage: Arc::new(BackingBuffer::new_bytes()),
            // Build the flag array in place rather than collecting into a `Vec`
            // and converting it with an `unwrap` that can never fail.
            registry: Arc::new(std::array::from_fn(|_| AtomicBool::new(true))),
            waiters: Default::default(),
        })
    }
}
142
143pub struct BufGuard<Inner, const BLOCKS: usize, const SIZE: usize> {
144    buffer: Buffer<Inner, BLOCKS, SIZE>,
145    block: &'static mut [u8],
146    idx: usize,
147}
148
149impl<Inner, const BLOCKS: usize, const SIZE: usize> Deref for BufGuard<Inner, BLOCKS, SIZE> {
150    type Target = [u8];
151    fn deref(&self) -> &Self::Target {
152        &self.block
153    }
154}
155
156impl<'a, Inner, const BLOCKS: usize, const SIZE: usize> DerefMut for BufGuard<Inner, BLOCKS, SIZE> {
157    fn deref_mut(&mut self) -> &mut Self::Target {
158        &mut self.block
159    }
160}
161
162impl<'a, Inner, const BLOCKS: usize, const SIZE: usize> Drop for BufGuard<Inner, BLOCKS, SIZE> {
163    fn drop(&mut self) {
164        self.fill(0);
165        let buffer = self.buffer.clone();
166        buffer.free_block(self);
167    }
168}
169
#[cfg(test)]
mod tests {
    use std::{thread, time::Duration};

    // Only the memory-mapped tests need a temporary directory.
    #[cfg(feature = "memmap")]
    use tempdir::TempDir;

    use super::*;

    /// Block size used by every test: 3 MiB.
    const SIZE: usize = 3 * 1024 * 1024;

    /// Exhausts a heap-backed pool, then checks that a pending `acquire_block`
    /// completes once another block is released.
    #[test]
    fn basic_test() {
        let (tx, rx) = oneshot::channel();
        thread::spawn(move || {
            thread::sleep(Duration::from_millis(300));
            tx.send(())
        });
        ::futures::executor::block_on(async move {
            let buffer = Buffer::<_, 10, SIZE>::new_vec();
            let buf = buffer.acquire_block().await;
            assert_eq!(buf.len(), SIZE);
            assert!(buf.iter().all(|&a| a == 0));

            // Hold the remaining nine blocks so the pool is fully exhausted.
            let mut blocks = Vec::new();
            for _ in 0..9 {
                blocks.push(buffer.acquire_block().await);
            }
            // The first future can only finish after the second releases `buf`.
            let (_, _) = ::futures::join!(buffer.acquire_block(), async move {
                _ = rx.await;
                drop(buf);
            });
        });
    }

    /// Same exhaustion scenario on a memory-mapped pool, with a second thread
    /// competing for the block that gets released.
    #[cfg(feature = "memmap")]
    #[test]
    fn test_memmap() {
        let dir = TempDir::new("dir").unwrap();

        let (tx, rx) = oneshot::channel();

        let buffer =
            unsafe { Buffer::<_, 10, SIZE>::new_memmap(&dir.path().join("file.txt")).unwrap() };
        let sent_buffer = buffer.clone();
        // NOTE(review): this thread is never joined, so a panic inside it goes unnoticed.
        thread::spawn(move || {
            thread::sleep(Duration::from_millis(300));
            _ = tx.send(());
            ::futures::executor::block_on(sent_buffer.acquire_block());
        });
        ::futures::executor::block_on(async move {
            let buf = buffer.acquire_block().await;
            assert_eq!(buf.len(), SIZE);
            assert!(buf.iter().all(|&a| a == 0));

            // Hold the remaining nine blocks so the pool is fully exhausted.
            let mut blocks = Vec::new();
            for _ in 0..9 {
                blocks.push(buffer.acquire_block().await);
            }
            let (_, _) = ::futures::join!(buffer.acquire_block(), async move {
                _ = rx.await;
                drop(buf);
            });

            buffer.acquire_block().await;
        });
    }

    /// Hammers a ten-block pool from 50 threads, each acquiring and
    /// immediately releasing a block 500 times.
    #[cfg(feature = "memmap")]
    #[test]
    fn test_race() {
        let dir = TempDir::new("dir").unwrap();

        let buffer =
            unsafe { Buffer::<_, 10, SIZE>::new_memmap(&dir.path().join("file")).unwrap() };
        let mut hs = Vec::new();
        for _ in 0..50 {
            let buffer = buffer.clone();
            let h = thread::spawn(move || {
                ::futures::executor::block_on(async move {
                    for _ in 0..500 {
                        buffer.acquire_block().await;
                    }
                });
            });
            hs.push(h);
        }

        for h in hs {
            h.join().unwrap();
        }
    }
}