1use std::{
2 ops::{Deref, DerefMut},
3 sync::{
4 atomic::{AtomicBool, Ordering},
5 Arc,
6 },
7 task::Waker,
8};
9
10use backing_buffer::BackingBuffer;
11use futures::BufGuardFuture;
12
13mod backing_buffer;
14mod futures;
15
16#[cfg(feature = "memmap")]
17pub use memmap2::MmapMut;
18
19pub struct Buffer<Inner, const BLOCKS: usize, const SIZE: usize> {
20 storage: Arc<BackingBuffer<Inner, BLOCKS, SIZE>>,
21 registry: Arc<[AtomicBool; BLOCKS]>,
23 waiters: Arc<crossbeam_queue::SegQueue<Waker>>,
24}
25
26impl<Inner, const BLOCKS: usize, const SIZE: usize> Clone for Buffer<Inner, BLOCKS, SIZE> {
27 fn clone(&self) -> Self {
28 Self {
29 storage: self.storage.clone(),
30 registry: self.registry.clone(),
31 waiters: self.waiters.clone(),
32 }
33 }
34}
35
36impl<Inner, const BLOCKS: usize, const SIZE: usize> Buffer<Inner, BLOCKS, SIZE> {
37 #[cfg(feature = "unsafe_new")]
38 pub unsafe fn new(inner: Inner) -> Self {
39 let mut registry = Vec::with_capacity(BLOCKS);
40 for _ in 0..BLOCKS {
41 registry.push(AtomicBool::new(true));
42 }
43 Self {
44 storage: unsafe { Arc::new(BackingBuffer::new(inner)) },
45 registry: Arc::new(registry.try_into().unwrap()),
46 waiters: Default::default(),
47 }
48 }
49
50 fn free_block(&self, guard: &BufGuard<'_, Inner, BLOCKS, SIZE>) {
51 let idx = guard.idx;
52 self.registry[idx].store(true, Ordering::SeqCst);
53 if let Some(waker) = self.waiters.pop() {
54 waker.wake();
55 }
56 }
57
58 fn push_waker(&self, waker: Waker) {
59 self.waiters.push(waker);
60 }
61}
62
63impl<Inner, const BLOCKS: usize, const SIZE: usize> Buffer<Inner, BLOCKS, SIZE>
64where
65 Inner: Deref<Target = [u8]> + DerefMut,
66{
67 pub async fn acquire_block<'a>(&'a self) -> BufGuard<'a, Inner, BLOCKS, SIZE> {
68 let block = self.try_acquire_block();
69
70 if let Some(block) = block {
71 block
72 } else {
73 BufGuardFuture::new(self).await
74 }
75 }
76
77 fn try_acquire_block<'a>(&'a self) -> Option<BufGuard<'a, Inner, BLOCKS, SIZE>> {
78 let mut idx = None;
79 for (i, e) in self.registry.iter().enumerate() {
80 if e.compare_exchange(true, false, Ordering::SeqCst, Ordering::SeqCst)
81 .is_ok()
82 {
83 idx = Some(i);
84 break;
85 }
86 }
87 let idx = idx?;
88
89 let block = unsafe { self.storage.get_block(idx) };
90 BufGuard {
91 block,
92 buffer: self,
93 idx,
94 }
95 .into()
96 }
97}
98
99impl<const BLOCKS: usize, const SIZE: usize> Buffer<Box<[u8]>, BLOCKS, SIZE> {
100 pub fn new_vec() -> Self {
101 let mut registry = Vec::with_capacity(BLOCKS);
102 for _ in 0..BLOCKS {
103 registry.push(AtomicBool::new(true));
104 }
105 Self {
106 storage: Arc::new(BackingBuffer::new_vec()),
107 registry: Arc::new(registry.try_into().unwrap()),
108 waiters: Default::default(),
109 }
110 }
111}
112
#[cfg(feature = "memmap")]
impl<const BLOCKS: usize, const SIZE: usize> Buffer<memmap2::MmapMut, BLOCKS, SIZE> {
    /// Creates a pool whose blocks live in a memory-mapped file at `backing_file`,
    /// with every block initially free.
    ///
    /// # Errors
    ///
    /// Returns any I/O error produced by `BackingBuffer::new_memmap`.
    ///
    /// # Safety
    ///
    /// The contract is inherited from `BackingBuffer::new_memmap`.
    /// NOTE(review): presumably, as with any `memmap2` mapping, the file must not
    /// be modified or truncated by anything else while the mapping is alive —
    /// confirm.
    pub unsafe fn new_memmap(backing_file: &std::path::Path) -> std::io::Result<Self> {
        let storage = Arc::new(BackingBuffer::new_memmap(backing_file)?);
        let flags: Vec<AtomicBool> = (0..BLOCKS).map(|_| AtomicBool::new(true)).collect();
        Ok(Self {
            storage,
            registry: Arc::new(flags.try_into().unwrap()),
            waiters: Default::default(),
        })
    }
}
127
128pub struct BufGuard<'a, Inner, const BLOCKS: usize, const SIZE: usize> {
129 buffer: &'a Buffer<Inner, BLOCKS, SIZE>,
130 block: &'a mut [u8],
131 idx: usize,
132}
133
134impl<'a, Inner, const BLOCKS: usize, const SIZE: usize> Deref
135 for BufGuard<'a, Inner, BLOCKS, SIZE>
136{
137 type Target = [u8];
138 fn deref(&self) -> &Self::Target {
139 &self.block
140 }
141}
142
143impl<'a, Inner, const BLOCKS: usize, const SIZE: usize> DerefMut
144 for BufGuard<'a, Inner, BLOCKS, SIZE>
145{
146 fn deref_mut(&mut self) -> &mut Self::Target {
147 &mut self.block
148 }
149}
150
151impl<'a, Inner, const BLOCKS: usize, const SIZE: usize> Drop for BufGuard<'a, Inner, BLOCKS, SIZE> {
152 fn drop(&mut self) {
153 self.fill(0);
154 self.buffer.free_block(self);
155 }
156}
157
#[cfg(test)]
mod tests {
    use std::{thread, time::Duration};

    // Only the memmap-backed tests need a temporary directory; gating the import
    // avoids an unused-import warning when the feature is off.
    #[cfg(feature = "memmap")]
    use tempdir::TempDir;

    use super::*;

    /// Size of each block in these tests: 3 MiB.
    const SIZE: usize = 3 * 1024 * 1024;

    /// Exhausts a heap-backed pool, then checks that a pending `acquire_block`
    /// completes once another future releases a block.
    #[test]
    fn basic_test() {
        let (tx, rx) = oneshot::channel();
        thread::spawn(move || {
            thread::sleep(Duration::from_millis(300));
            tx.send(())
        });
        ::futures::executor::block_on(async move {
            let buffer = Buffer::<_, 10, SIZE>::new_vec();
            let buf = buffer.acquire_block().await;
            assert_eq!(buf.len(), SIZE);
            assert!(buf.iter().all(|&a| a == 0));

            let mut blocks = Vec::new();
            for _ in 0..9 {
                blocks.push(buffer.acquire_block().await);
            }
            let (_, _) = ::futures::join!(buffer.acquire_block(), async move {
                _ = rx.await;
                drop(buf);
            });
        });
    }

    /// Same scenario as `basic_test` on a memory-mapped pool, with a second
    /// thread also competing for a block.
    #[cfg(feature = "memmap")]
    #[test]
    fn test_memmap() {
        let dir = TempDir::new("dir").unwrap();

        let (tx, rx) = oneshot::channel();

        let buffer =
            unsafe { Buffer::<_, 10, SIZE>::new_memmap(&dir.path().join("file.txt")).unwrap() };
        let sent_buffer = buffer.clone();
        let helper = thread::spawn(move || {
            thread::sleep(Duration::from_millis(300));
            _ = tx.send(());
            drop(::futures::executor::block_on(sent_buffer.acquire_block()));
        });
        ::futures::executor::block_on(async move {
            let buf = buffer.acquire_block().await;
            assert_eq!(buf.len(), SIZE);
            assert!(buf.iter().all(|&a| a == 0));

            let mut blocks = Vec::new();
            for _ in 0..9 {
                blocks.push(buffer.acquire_block().await);
            }
            let (_, _) = ::futures::join!(buffer.acquire_block(), async move {
                _ = rx.await;
                drop(buf);
            });

            drop(buffer.acquire_block().await);
        });

        // Surface any panic from the helper thread, and make sure it is done with
        // the mapping before `dir` is removed.
        helper.join().unwrap();
    }

    /// Many threads repeatedly acquire and release blocks from a small pool to
    /// shake out races in the acquire/release path.
    #[cfg(feature = "memmap")]
    #[test]
    fn test_race() {
        let dir = TempDir::new("dir").unwrap();

        let buffer =
            unsafe { Buffer::<_, 10, SIZE>::new_memmap(&dir.path().join("file")).unwrap() };
        let mut handles = Vec::new();
        for _ in 0..50 {
            let buffer = buffer.clone();
            handles.push(thread::spawn(move || {
                ::futures::executor::block_on(async move {
                    for _ in 0..500 {
                        // Release immediately to keep contention on the pool high.
                        drop(buffer.acquire_block().await);
                    }
                });
            }));
        }

        for handle in handles {
            handle.join().unwrap();
        }
    }
}