1use std::{
2 ops::{Deref, DerefMut},
3 sync::{
4 atomic::{AtomicBool, Ordering},
5 Arc,
6 },
7 task::Waker,
8 thread::Thread,
9};
10
11use backing_buffer::BackingBuffer;
12use futures::BufGuardFuture;
13
14mod backing_buffer;
15mod futures;
16
17#[cfg(feature = "memmap")]
18pub use memmap2::MmapMut;
19
/// A party waiting for a block to become free: either an async task
/// (woken through its `Waker`) or an OS thread (woken via `Thread::unpark`).
enum InnerWaker {
    /// An async task registered while polling the acquire future.
    Waker(Waker),
    /// A thread parked inside `blocking_acquire_block`.
    Thread(Thread),
}
24
/// A cloneable, thread-safe pool of `BLOCKS` fixed-size blocks (`SIZE` bytes
/// each) carved out of one shared backing buffer. Cloning yields a cheap
/// handle to the same pool (three `Arc` refcount bumps).
pub struct Buffer<Inner, const BLOCKS: usize, const SIZE: usize> {
    /// Shared backing memory the blocks are carved from.
    storage: Arc<BackingBuffer<Inner, BLOCKS, SIZE>>,
    /// One flag per block: `true` = free, `false` = handed out.
    registry: Arc<[AtomicBool; BLOCKS]>,
    /// Tasks/threads waiting for a block; `free_block` pops and wakes one
    /// entry per freed block.
    waiters: Arc<crossbeam_queue::SegQueue<InnerWaker>>,
}
31
32impl<Inner, const BLOCKS: usize, const SIZE: usize> Clone for Buffer<Inner, BLOCKS, SIZE> {
33 fn clone(&self) -> Self {
34 Self {
35 storage: self.storage.clone(),
36 registry: self.registry.clone(),
37 waiters: self.waiters.clone(),
38 }
39 }
40}
41
42impl<Inner, const BLOCKS: usize, const SIZE: usize> Buffer<Inner, BLOCKS, SIZE> {
43 #[cfg(feature = "unsafe_new")]
44 pub unsafe fn new(inner: Inner) -> Self {
45 let mut registry = Vec::with_capacity(BLOCKS);
46 for _ in 0..BLOCKS {
47 registry.push(AtomicBool::new(true));
48 }
49 Self {
50 storage: unsafe { Arc::new(BackingBuffer::new(inner)) },
51 registry: Arc::new(registry.try_into().unwrap()),
52 waiters: Default::default(),
53 }
54 }
55
56 fn free_block(&self, guard: &mut BufGuard<Inner, BLOCKS, SIZE>) {
57 let idx = guard.idx;
58 self.registry[idx].store(true, Ordering::SeqCst);
59 if let Some(waker) = self.waiters.pop() {
60 match waker {
61 InnerWaker::Waker(waker) => waker.wake(),
62 InnerWaker::Thread(thread) => thread.unpark(),
63 }
64 }
65 }
66
67 fn push_waker(&self, waker: Waker) {
68 self.waiters.push(InnerWaker::Waker(waker));
69 }
70
71 fn push_thread(&self, thread: Thread) {
72 self.waiters.push(InnerWaker::Thread(thread));
73 }
74}
75
76impl<Inner, const BLOCKS: usize, const SIZE: usize> Buffer<Inner, BLOCKS, SIZE>
77where
78 Inner: Deref<Target = [u8]> + DerefMut,
79{
80 pub async fn acquire_block(&self) -> BufGuard<Inner, BLOCKS, SIZE> {
81 let block = self.try_acquire_block();
82
83 if let Some(block) = block {
84 block
85 } else {
86 BufGuardFuture::new(self).await
87 }
88 }
89
90 pub fn blocking_acquire_block(&self) -> BufGuard<Inner, BLOCKS, SIZE> {
91 let block = self.try_acquire_block();
92
93 if let Some(block) = block {
94 block
95 } else {
96 loop {
97 let current = std::thread::current();
98 self.waiters.push(InnerWaker::Thread(current.clone()));
99
100 std::thread::park();
101 if let Some(block) = self.try_acquire_block() {
102 return block;
103 }
104 }
105 }
106 }
107
108 fn try_acquire_block<'a>(&'a self) -> Option<BufGuard<Inner, BLOCKS, SIZE>> {
109 let mut idx = None;
110 for (i, e) in self.registry.iter().enumerate() {
111 if e.compare_exchange(true, false, Ordering::SeqCst, Ordering::SeqCst)
112 .is_ok()
113 {
114 idx = Some(i);
115 break;
116 }
117 }
118 let idx = idx?;
119
120 let block = unsafe { self.storage.get_block(idx) };
121 BufGuard {
122 block,
123 buffer: self.clone(),
124 idx,
125 }
126 .into()
127 }
128}
129
130impl<const BLOCKS: usize, const SIZE: usize> Buffer<Box<[u8]>, BLOCKS, SIZE> {
131 pub fn new_vec() -> Self {
132 let mut registry = Vec::with_capacity(BLOCKS);
133 for _ in 0..BLOCKS {
134 registry.push(AtomicBool::new(true));
135 }
136 Self {
137 storage: Arc::new(BackingBuffer::new_vec()),
138 registry: Arc::new(registry.try_into().unwrap()),
139 waiters: Default::default(),
140 }
141 }
142}
143
#[cfg(feature = "memmap")]
impl<const BLOCKS: usize, const SIZE: usize> Buffer<memmap2::MmapMut, BLOCKS, SIZE> {
    /// Creates a pool backed by a memory-mapped file at `backing_file`,
    /// with every block initially free.
    ///
    /// # Safety
    ///
    /// Delegates to `BackingBuffer::new_memmap`; the caller must uphold that
    /// constructor's invariants. NOTE(review): mapping a file that other
    /// processes may mutate is the usual hazard here — confirm the exact
    /// contract in `backing_buffer`.
    ///
    /// # Errors
    ///
    /// Propagates any I/O error from creating or mapping the file.
    pub unsafe fn new_memmap(backing_file: &std::path::Path) -> std::io::Result<Self> {
        let free_flags: Vec<AtomicBool> = (0..BLOCKS).map(|_| AtomicBool::new(true)).collect();
        Ok(Self {
            storage: Arc::new(BackingBuffer::new_memmap(backing_file)?),
            registry: Arc::new(free_flags.try_into().unwrap()),
            waiters: Default::default(),
        })
    }
}
158
#[cfg(feature = "bytes")]
impl<const BLOCKS: usize, const SIZE: usize> Buffer<bytes::BytesMut, BLOCKS, SIZE> {
    /// Creates a pool backed by a `bytes::BytesMut`, with every block
    /// initially free.
    ///
    /// Always returns `Ok`; the `Result` mirrors the other constructors'
    /// signatures. NOTE(review): nothing visible here is fallible or
    /// unsafe — confirm `BackingBuffer::new_bytes`'s contract before
    /// relying on that.
    pub unsafe fn new_bytes() -> std::io::Result<Self> {
        let free_flags: Vec<AtomicBool> = (0..BLOCKS).map(|_| AtomicBool::new(true)).collect();
        Ok(Self {
            storage: Arc::new(BackingBuffer::new_bytes()),
            registry: Arc::new(free_flags.try_into().unwrap()),
            waiters: Default::default(),
        })
    }
}
173
/// Exclusive handle to one block of the pool. Dereferences to the block's
/// bytes; zeroes and releases the block when dropped.
pub struct BufGuard<Inner, const BLOCKS: usize, const SIZE: usize> {
    /// Handle back to the owning pool; keeps the backing storage alive for
    /// as long as this guard exists.
    buffer: Buffer<Inner, BLOCKS, SIZE>,
    /// The guarded bytes. NOTE(review): `'static` is a claim the type
    /// system cannot check here — soundness relies on `buffer`'s `Arc`
    /// outliving this reference; confirm `BackingBuffer::get_block`'s
    /// contract.
    block: &'static mut [u8],
    /// Index of this block in the registry; used to mark it free on drop.
    idx: usize,
}
179
180impl<Inner, const BLOCKS: usize, const SIZE: usize> Deref for BufGuard<Inner, BLOCKS, SIZE> {
181 type Target = [u8];
182 fn deref(&self) -> &Self::Target {
183 self.block
184 }
185}
186
187impl<Inner, const BLOCKS: usize, const SIZE: usize> DerefMut for BufGuard<Inner, BLOCKS, SIZE> {
188 fn deref_mut(&mut self) -> &mut Self::Target {
189 self.block
190 }
191}
192
193impl<Inner, const BLOCKS: usize, const SIZE: usize> Drop for BufGuard<Inner, BLOCKS, SIZE> {
194 fn drop(&mut self) {
195 self.fill(0);
196 let buffer = self.buffer.clone();
197 buffer.free_block(self);
198 }
199}
200
#[cfg(test)]
mod tests {

    const SIZE: usize = 1048576 * 3;
    use std::{thread, time::Duration};

    #[cfg(feature = "memmap")]
    use tempdir::TempDir;

    use super::*;

    /// Fills the pool, then checks that a pending `acquire_block` future
    /// completes once a guard is dropped on a sibling future.
    #[test]
    fn basic_test() {
        let (tx, rx) = oneshot::channel();
        thread::spawn(move || {
            std::thread::sleep(Duration::from_millis(300));
            tx.send(())
        });
        ::futures::executor::block_on(async move {
            let buffer = Buffer::<_, 10, SIZE>::new_vec();
            let buf = buffer.acquire_block().await;
            assert_eq!(buf.len(), SIZE);
            // Freshly acquired blocks must be zeroed.
            assert!(buf.iter().all(|&a| a == 0));

            // Exhaust the remaining 9 blocks so the next acquire must wait.
            let mut blocks = Vec::new();
            for _ in 0..9 {
                let b = buffer.acquire_block().await;
                blocks.push(b);
            }
            // The pending acquire completes only after `buf` is dropped by
            // the second future (triggered by the delayed oneshot send).
            let (_, _) = ::futures::join!(buffer.acquire_block(), async move {
                _ = rx.await;
                drop(buf);
            });
        });
    }

    /// Same fill-and-release flow as `basic_test`, but with a file-backed
    /// (memmap) pool and an extra contending thread.
    #[cfg(feature = "memmap")]
    #[test]
    fn test_memmap() {
        let dir = TempDir::new("dir").unwrap();

        let (tx, rx) = oneshot::channel();

        let buffer =
            unsafe { Buffer::<_, 10, SIZE>::new_memmap(&dir.path().join("file.txt")).unwrap() };
        let sent_buffer = buffer.clone();
        thread::spawn(move || {
            std::thread::sleep(Duration::from_millis(300));
            _ = tx.send(());
            ::futures::executor::block_on(sent_buffer.acquire_block());
        });
        ::futures::executor::block_on(async move {
            let buf = buffer.acquire_block().await;
            assert_eq!(buf.len(), SIZE);
            assert!(buf.iter().all(|&a| a == 0));

            let mut blocks = Vec::new();
            for _ in 0..9 {
                let b = buffer.acquire_block().await;
                blocks.push(b);
            }
            let (_, _) = ::futures::join!(buffer.acquire_block(), async move {
                _ = rx.await;
                drop(buf);
            });

            buffer.acquire_block().await;
        });
    }

    /// Hammers the pool from 50 threads x 500 acquisitions each to shake
    /// out races in the registry CAS and the waiter queue.
    #[cfg(feature = "memmap")]
    #[test]
    fn test_race() {
        let dir = TempDir::new("dir").unwrap();

        let buffer =
            unsafe { Buffer::<_, 10, SIZE>::new_memmap(&dir.path().join("file")).unwrap() };
        let mut hs = Vec::new();
        for _ in 0..50 {
            let buffer = buffer.clone();
            let h = std::thread::spawn(move || {
                ::futures::executor::block_on(async move {
                    for _ in 0..500 {
                        // Guard drops immediately, freeing the block for
                        // the other threads.
                        buffer.acquire_block().await;
                    }
                });
            });
            hs.push(h);
        }

        for h in hs {
            h.join().unwrap();
        }
    }
}