1use std::{
2 ops::{Deref, DerefMut},
3 sync::{
4 atomic::{AtomicBool, Ordering},
5 Arc,
6 },
7 task::Waker,
8};
9
10use backing_buffer::BackingBuffer;
11use futures::BufGuardFuture;
12
13mod backing_buffer;
14mod futures;
15
16#[cfg(feature = "memmap")]
17pub use memmap2::MmapMut;
18
19pub struct Buffer<Inner, const BLOCKS: usize, const SIZE: usize> {
20 storage: Arc<BackingBuffer<Inner, BLOCKS, SIZE>>,
21 registry: Arc<[AtomicBool; BLOCKS]>,
23 waiters: Arc<crossbeam_queue::SegQueue<Waker>>,
24}
25
26impl<Inner, const BLOCKS: usize, const SIZE: usize> Clone for Buffer<Inner, BLOCKS, SIZE> {
27 fn clone(&self) -> Self {
28 Self {
29 storage: self.storage.clone(),
30 registry: self.registry.clone(),
31 waiters: self.waiters.clone(),
32 }
33 }
34}
35
36impl<Inner, const BLOCKS: usize, const SIZE: usize> Buffer<Inner, BLOCKS, SIZE> {
37 pub unsafe fn new(inner: Inner) -> Self {
38 let mut registry = Vec::with_capacity(BLOCKS);
39 for _ in 0..BLOCKS {
40 registry.push(AtomicBool::new(true));
41 }
42 Self {
43 storage: unsafe { Arc::new(BackingBuffer::new(inner)) },
44 registry: Arc::new(registry.try_into().unwrap()),
45 waiters: Default::default(),
46 }
47 }
48
49 fn free_block(&self, guard: &BufGuard<'_, Inner, BLOCKS, SIZE>) {
50 let idx = guard.idx;
51 self.registry[idx].store(true, Ordering::SeqCst);
52 if let Some(waker) = self.waiters.pop() {
53 waker.wake();
54 }
55 }
56
57 fn push_waker(&self, waker: Waker) {
58 self.waiters.push(waker);
59 }
60}
61
62impl<Inner, const BLOCKS: usize, const SIZE: usize> Buffer<Inner, BLOCKS, SIZE>
63where
64 Inner: Deref<Target = [u8]> + DerefMut,
65{
66 pub async fn acquire_block<'a>(&'a self) -> BufGuard<'a, Inner, BLOCKS, SIZE> {
67 let block = self.try_acquire_block();
68
69 if let Some(block) = block {
70 block
71 } else {
72 BufGuardFuture::new(self).await
73 }
74 }
75
76 fn try_acquire_block<'a>(&'a self) -> Option<BufGuard<'a, Inner, BLOCKS, SIZE>> {
77 let mut idx = None;
78 for (i, e) in self.registry.iter().enumerate() {
79 if e.compare_exchange(true, false, Ordering::SeqCst, Ordering::SeqCst)
80 .is_ok()
81 {
82 idx = Some(i);
83 break;
84 }
85 }
86 let idx = idx?;
87
88 let block = unsafe { self.storage.get_block(idx) };
89 BufGuard {
90 block,
91 buffer: self,
92 idx,
93 }
94 .into()
95 }
96}
97
98impl<const BLOCKS: usize, const SIZE: usize> Buffer<Box<[u8]>, BLOCKS, SIZE> {
99 pub fn new_vec() -> Self {
100 let mut registry = Vec::with_capacity(BLOCKS);
101 for _ in 0..BLOCKS {
102 registry.push(AtomicBool::new(true));
103 }
104 Self {
105 storage: Arc::new(BackingBuffer::new_vec()),
106 registry: Arc::new(registry.try_into().unwrap()),
107 waiters: Default::default(),
108 }
109 }
110}
111
#[cfg(feature = "memmap")]
impl<const BLOCKS: usize, const SIZE: usize> Buffer<memmap2::MmapMut, BLOCKS, SIZE> {
    /// Creates a pool backed by a read/write memory map of the file at `backing_file`.
    ///
    /// The file is created if it does not exist. Existing contents are not truncated. The file
    /// is then resized to exactly `BLOCKS * SIZE` bytes before being mapped.
    ///
    /// # Errors
    ///
    /// Returns any I/O error from opening, resizing, flushing or mapping the file. Returns an
    /// `InvalidInput` error if `BLOCKS * SIZE` overflows or does not fit in a file length.
    ///
    /// # Safety
    ///
    /// The caller must uphold the contract of [`Buffer::new`] for the resulting map. The caller
    /// must also ensure the file is not modified, resized or truncated by anything else (another
    /// process or mapping) while the buffer is alive, as required by `memmap2::MmapMut::map_mut`.
    pub unsafe fn new_memmap(backing_file: &std::path::Path) -> std::io::Result<Self> {
        use std::io::Write;

        // Unchecked `BLOCKS * SIZE` would panic in debug and silently wrap in release,
        // producing a map smaller than the pool expects. Check it explicitly.
        let len = BLOCKS
            .checked_mul(SIZE)
            .and_then(|bytes| u64::try_from(bytes).ok())
            .ok_or_else(|| {
                std::io::Error::new(
                    std::io::ErrorKind::InvalidInput,
                    "BLOCKS * SIZE does not fit in a file length",
                )
            })?;

        let mut backing_file = std::fs::OpenOptions::new()
            .create(true)
            // Explicit: keep any existing contents. `set_len` below fixes the size.
            .truncate(false)
            .read(true)
            .write(true)
            .open(backing_file)?;

        backing_file.set_len(len)?;
        backing_file.flush()?;
        // SAFETY: external modification of the file is excluded by the caller (see `# Safety`).
        let map = unsafe { memmap2::MmapMut::map_mut(&backing_file)? };
        // SAFETY: the map spans exactly `BLOCKS * SIZE` bytes (`set_len` above). The remaining
        // obligations of `Buffer::new` are forwarded to the caller.
        unsafe { Ok(Self::new(map)) }
    }
}
130
131pub struct BufGuard<'a, Inner, const BLOCKS: usize, const SIZE: usize> {
132 buffer: &'a Buffer<Inner, BLOCKS, SIZE>,
133 block: &'a mut [u8],
134 idx: usize,
135}
136
137impl<'a, Inner, const BLOCKS: usize, const SIZE: usize> Deref
138 for BufGuard<'a, Inner, BLOCKS, SIZE>
139{
140 type Target = [u8];
141 fn deref(&self) -> &Self::Target {
142 &self.block
143 }
144}
145
146impl<'a, Inner, const BLOCKS: usize, const SIZE: usize> DerefMut
147 for BufGuard<'a, Inner, BLOCKS, SIZE>
148{
149 fn deref_mut(&mut self) -> &mut Self::Target {
150 &mut self.block
151 }
152}
153
154impl<'a, Inner, const BLOCKS: usize, const SIZE: usize> Drop for BufGuard<'a, Inner, BLOCKS, SIZE> {
155 fn drop(&mut self) {
156 self.buffer.free_block(self);
157 }
158}
159
#[cfg(test)]
mod tests {
    use std::{thread, time::Duration};

    // Only the file-backed tests below need these.
    #[cfg(feature = "memmap")]
    use std::io::Write;

    #[cfg(feature = "memmap")]
    use tempdir::TempDir;

    use super::*;

    /// Size in bytes of every block in the test pools (2 MiB).
    const SIZE: usize = 1048576 * 2;

    /// Drains a `Vec`-backed pool, then checks that one more `acquire_block` completes once a
    /// signal from another thread causes a held block to be released.
    #[test]
    fn basic_test() {
        let (tx, rx) = oneshot::channel();
        thread::spawn(move || {
            thread::sleep(Duration::from_millis(300));
            tx.send(())
        });
        ::futures::executor::block_on(async move {
            // `Vec::with_capacity(n)` followed by `fill(0)` would leave the vector EMPTY,
            // because `fill` only overwrites existing elements. Allocate the zeroed bytes.
            let v = vec![0u8; 10 * SIZE];

            // SAFETY: `v` owns 10 * SIZE initialised bytes, one SIZE-byte region per block.
            // NOTE(review): assumes that is what `BackingBuffer::new` requires. Confirm.
            let buffer = unsafe { Buffer::<_, 10, SIZE>::new(v) };
            let buf = buffer.acquire_block().await;
            assert_eq!(buf.len(), SIZE);
            assert!(buf.iter().all(|&a| a == 0));

            // Hold the remaining nine blocks so the pool is exhausted.
            let mut blocks = Vec::new();
            for _ in 0..9 {
                let b = buffer.acquire_block().await;
                blocks.push(b);
            }
            // The first future can only finish after the second one drops `buf`.
            let (_, _) = ::futures::join!(buffer.acquire_block(), async move {
                _ = rx.await;
                drop(buf);
            });
        });
    }

    /// Same scenario as `basic_test` on a file-backed pool, with a second thread competing
    /// for released blocks through a cloned handle.
    #[cfg(feature = "memmap")]
    #[test]
    fn test_memmap() {
        let dir = TempDir::new("dir").unwrap();

        let (tx, rx) = oneshot::channel();

        // SAFETY: the file lives in a fresh temporary directory that nothing else touches.
        let buffer =
            unsafe { Buffer::<_, 10, SIZE>::new_memmap(&dir.path().join("file.txt")).unwrap() };
        let sent_buffer = buffer.clone();
        thread::spawn(move || {
            thread::sleep(Duration::from_millis(300));
            _ = tx.send(());
            ::futures::executor::block_on(sent_buffer.acquire_block());
        });
        ::futures::executor::block_on(async move {
            let buf = buffer.acquire_block().await;
            assert_eq!(buf.len(), SIZE);
            assert!(buf.iter().all(|&a| a == 0));

            let mut blocks = Vec::new();
            for _ in 0..9 {
                let b = buffer.acquire_block().await;
                blocks.push(b);
            }
            let (_, _) = ::futures::join!(buffer.acquire_block(), async move {
                _ = rx.await;
                drop(buf);
            });

            buffer.acquire_block().await;
        });
    }

    /// Hammers a file-backed pool from 50 threads. Each thread acquires a block and drops it
    /// at once, 500 times. The test passes if every thread finishes.
    #[cfg(feature = "memmap")]
    #[test]
    fn test_race() {
        let dir = TempDir::new("dir").unwrap();
        let mut f = std::fs::OpenOptions::new()
            .create(true)
            .read(true)
            .write(true)
            .open(dir.path().join("file.txt"))
            .unwrap();

        f.set_len((10 * SIZE) as _).unwrap();
        f.flush().unwrap();
        // SAFETY: the file lives in a fresh temporary directory that nothing else touches.
        let map = unsafe { memmap2::MmapMut::map_mut(&f).unwrap() };
        // SAFETY: the map spans exactly 10 * SIZE bytes (`set_len` above).
        let buffer = unsafe { Buffer::<_, 10, SIZE>::new(map) };
        let mut hs = Vec::new();
        for _ in 0..50 {
            let buffer = buffer.clone();
            let h = std::thread::spawn(move || {
                ::futures::executor::block_on(async move {
                    for _ in 0..500 {
                        buffer.acquire_block().await;
                    }
                });
            });
            hs.push(h);
        }

        for h in hs {
            h.join().unwrap();
        }
    }
}