fixed_size_buf/
lib.rs

1use std::{
2    ops::{Deref, DerefMut},
3    sync::{
4        atomic::{AtomicBool, Ordering},
5        Arc,
6    },
7    task::Waker,
8};
9
10use backing_buffer::BackingBuffer;
11use futures::BufGuardFuture;
12
13mod backing_buffer;
14mod futures;
15
16#[cfg(feature = "memmap")]
17pub use memmap2::MmapMut;
18
19pub struct Buffer<Inner, const BLOCKS: usize, const SIZE: usize> {
20    storage: Arc<BackingBuffer<Inner, BLOCKS, SIZE>>,
21    /// TRUE available FALSE occupied
22    registry: Arc<[AtomicBool; BLOCKS]>,
23    waiters: Arc<crossbeam_queue::SegQueue<Waker>>,
24}
25
26impl<Inner, const BLOCKS: usize, const SIZE: usize> Clone for Buffer<Inner, BLOCKS, SIZE> {
27    fn clone(&self) -> Self {
28        Self {
29            storage: self.storage.clone(),
30            registry: self.registry.clone(),
31            waiters: self.waiters.clone(),
32        }
33    }
34}
35
36impl<Inner, const BLOCKS: usize, const SIZE: usize> Buffer<Inner, BLOCKS, SIZE> {
37    pub unsafe fn new(inner: Inner) -> Self {
38        let mut registry = Vec::with_capacity(BLOCKS);
39        for _ in 0..BLOCKS {
40            registry.push(AtomicBool::new(true));
41        }
42        Self {
43            storage: unsafe { Arc::new(BackingBuffer::new(inner)) },
44            registry: Arc::new(registry.try_into().unwrap()),
45            waiters: Default::default(),
46        }
47    }
48
49    fn free_block(&self, guard: &BufGuard<'_, Inner, BLOCKS, SIZE>) {
50        let idx = guard.idx;
51        self.registry[idx].store(true, Ordering::SeqCst);
52        if let Some(waker) = self.waiters.pop() {
53            waker.wake();
54        }
55    }
56
57    fn push_waker(&self, waker: Waker) {
58        self.waiters.push(waker);
59    }
60}
61
62impl<Inner, const BLOCKS: usize, const SIZE: usize> Buffer<Inner, BLOCKS, SIZE>
63where
64    Inner: Deref<Target = [u8]> + DerefMut,
65{
66    pub async fn acquire_block<'a>(&'a self) -> BufGuard<'a, Inner, BLOCKS, SIZE> {
67        let block = self.try_acquire_block();
68
69        if let Some(block) = block {
70            block
71        } else {
72            BufGuardFuture::new(self).await
73        }
74    }
75
76    fn try_acquire_block<'a>(&'a self) -> Option<BufGuard<'a, Inner, BLOCKS, SIZE>> {
77        let mut idx = None;
78        for (i, e) in self.registry.iter().enumerate() {
79            if e.compare_exchange(true, false, Ordering::SeqCst, Ordering::SeqCst)
80                .is_ok()
81            {
82                idx = Some(i);
83                break;
84            }
85        }
86        let idx = idx?;
87
88        let block = unsafe { self.storage.get_block(idx) };
89        BufGuard {
90            block,
91            buffer: self,
92            idx,
93        }
94        .into()
95    }
96}
97
98impl<const BLOCKS: usize, const SIZE: usize> Buffer<Box<[u8]>, BLOCKS, SIZE> {
99    pub fn new_vec() -> Self {
100        let mut registry = Vec::with_capacity(BLOCKS);
101        for _ in 0..BLOCKS {
102            registry.push(AtomicBool::new(true));
103        }
104        Self {
105            storage: Arc::new(BackingBuffer::new_vec()),
106            registry: Arc::new(registry.try_into().unwrap()),
107            waiters: Default::default(),
108        }
109    }
110}
111
#[cfg(feature = "memmap")]
impl<const BLOCKS: usize, const SIZE: usize> Buffer<memmap2::MmapMut, BLOCKS, SIZE> {
    /// Creates a pool backed by a memory map of the file at `backing_file`.
    ///
    /// The file is created if it does not exist and is resized to exactly
    /// `BLOCKS * SIZE` bytes before being mapped.
    ///
    /// # Errors
    ///
    /// Returns an error if the file cannot be opened, resized or mapped, or if
    /// `BLOCKS * SIZE` cannot be represented as a file length.
    ///
    /// # Safety
    ///
    /// The file is mapped with the unsafe `memmap2::MmapMut::map_mut` and the map is passed
    /// to the unsafe `Buffer::new`; the caller must uphold the requirements of both.
    pub unsafe fn new_memmap(backing_file: &std::path::Path) -> std::io::Result<Self> {
        use std::io::Write;

        // Report an unrepresentable length as an error instead of overflowing or panicking.
        let len = BLOCKS
            .checked_mul(SIZE)
            .and_then(|bytes| u64::try_from(bytes).ok())
            .ok_or_else(|| {
                std::io::Error::new(
                    std::io::ErrorKind::InvalidInput,
                    "BLOCKS * SIZE does not fit in a file length",
                )
            })?;

        // Propagate open failures to the caller; this function already returns `io::Result`.
        let mut file = std::fs::OpenOptions::new()
            .create(true)
            // Existing contents are kept on open; `set_len` below fixes the size.
            .truncate(false)
            .read(true)
            .write(true)
            .open(backing_file)?;

        file.set_len(len)?;
        file.flush()?;
        // SAFETY: the caller upholds the requirements of `MmapMut::map_mut` (see `# Safety`).
        let map = unsafe { memmap2::MmapMut::map_mut(&file)? };
        // SAFETY: the caller upholds the requirements of `Buffer::new` (see `# Safety`).
        unsafe { Ok(Self::new(map)) }
    }
}
130
131pub struct BufGuard<'a, Inner, const BLOCKS: usize, const SIZE: usize> {
132    buffer: &'a Buffer<Inner, BLOCKS, SIZE>,
133    block: &'a mut [u8],
134    idx: usize,
135}
136
137impl<'a, Inner, const BLOCKS: usize, const SIZE: usize> Deref
138    for BufGuard<'a, Inner, BLOCKS, SIZE>
139{
140    type Target = [u8];
141    fn deref(&self) -> &Self::Target {
142        &self.block
143    }
144}
145
146impl<'a, Inner, const BLOCKS: usize, const SIZE: usize> DerefMut
147    for BufGuard<'a, Inner, BLOCKS, SIZE>
148{
149    fn deref_mut(&mut self) -> &mut Self::Target {
150        &mut self.block
151    }
152}
153
154impl<'a, Inner, const BLOCKS: usize, const SIZE: usize> Drop for BufGuard<'a, Inner, BLOCKS, SIZE> {
155    fn drop(&mut self) {
156        self.buffer.free_block(self);
157    }
158}
159
#[cfg(test)]
mod tests {

    const SIZE: usize = 1048576 * 2;
    use std::{io::Write, thread, time::Duration};

    use tempdir::TempDir;

    use super::*;

    /// Takes all ten blocks, then checks that an eleventh acquisition completes once
    /// the first guard is dropped after the helper thread signals.
    #[test]
    fn basic_test() {
        let (tx, rx) = oneshot::channel();
        thread::spawn(move || {
            std::thread::sleep(Duration::from_millis(300));
            tx.send(())
        });
        ::futures::executor::block_on(async move {
            // `Vec::with_capacity(n)` followed by `fill(0)` leaves the vector empty, because
            // `fill` only touches existing elements. Allocate `n` initialised zero bytes so
            // the pool is backed by real data.
            let v = vec![0u8; 10 * SIZE];

            let buffer = unsafe { Buffer::<_, 10, SIZE>::new(v) };
            let buf = buffer.acquire_block().await;
            assert_eq!(buf.len(), SIZE);
            assert!(buf.iter().all(|&a| a == 0));

            let mut blocks = Vec::new();
            for _ in 0..9 {
                let b = buffer.acquire_block().await;
                blocks.push(b);
            }
            let (_, _) = ::futures::join!(buffer.acquire_block(), async move {
                _ = rx.await;
                drop(buf);
            });
        });
    }

    /// Same scenario as `basic_test`, on a memory-mapped pool that is also shared
    /// with a second thread through a cloned handle.
    #[cfg(feature = "memmap")]
    #[test]
    fn test_memmap() {
        let dir = TempDir::new("dir").unwrap();

        let (tx, rx) = oneshot::channel();

        let buffer =
            unsafe { Buffer::<_, 10, SIZE>::new_memmap(&dir.path().join("file.txt")).unwrap() };
        let sent_buffer = buffer.clone();
        thread::spawn(move || {
            std::thread::sleep(Duration::from_millis(300));
            _ = tx.send(());
            ::futures::executor::block_on(sent_buffer.acquire_block());
        });
        ::futures::executor::block_on(async move {
            let buf = buffer.acquire_block().await;
            assert_eq!(buf.len(), SIZE);
            assert!(buf.iter().all(|&a| a == 0));

            let mut blocks = Vec::new();
            for _ in 0..9 {
                let b = buffer.acquire_block().await;
                blocks.push(b);
            }
            let (_, _) = ::futures::join!(buffer.acquire_block(), async move {
                _ = rx.await;
                drop(buf);
            });

            buffer.acquire_block().await;
        });
    }

    /// Hammers one pool from 50 threads, each acquiring and immediately releasing
    /// a block 500 times.
    // NOTE(review): this test names `memmap2` directly but is not gated on the `memmap`
    // feature like `test_memmap` is; confirm `memmap2` is available to test builds
    // when the feature is off.
    #[test]
    fn test_race() {
        let dir = TempDir::new("dir").unwrap();
        let mut f = std::fs::OpenOptions::new()
            .create(true)
            .read(true)
            .write(true)
            .open(dir.path().join("file.txt"))
            .unwrap();

        f.set_len((10 * SIZE) as _).unwrap();
        f.flush().unwrap();
        let map = unsafe { memmap2::MmapMut::map_mut(&f).unwrap() };
        let buffer = unsafe { Buffer::<_, 10, SIZE>::new(map) };
        let mut hs = Vec::new();
        for _ in 0..50 {
            let buffer = buffer.clone();
            let h = std::thread::spawn(move || {
                ::futures::executor::block_on(async move {
                    for _ in 0..500 {
                        buffer.acquire_block().await;
                    }
                });
            });
            hs.push(h);
        }

        for h in hs {
            h.join().unwrap();
        }
    }
}