1use std::{
2 ops::{Deref, DerefMut},
3 sync::{
4 atomic::{AtomicBool, Ordering},
5 Arc,
6 },
7 task::Waker,
8};
9
10use backing_buffer::BackingBuffer;
11use futures::BufGuardFuture;
12
13mod backing_buffer;
14mod futures;
15
16#[cfg(feature = "memmap")]
17pub use memmap2::MmapMut;
18
/// Fixed pool of `BLOCKS` blocks carved out of a shared backing store `Inner`.
///
/// Every field sits behind an `Arc`, so clones are cheap handles onto the same
/// pool. Blocks are handed out as [`BufGuard`]s by `acquire_block`.
pub struct Buffer<Inner, const BLOCKS: usize, const SIZE: usize> {
    /// Backing memory shared by all handles; individual blocks are obtained from
    /// it with `get_block`.
    storage: Arc<BackingBuffer<Inner, BLOCKS, SIZE>>,
    /// One flag per block: `true` = free, `false` = currently held by a guard.
    registry: Arc<[AtomicBool; BLOCKS]>,
    /// Wakers of tasks waiting for a block; `free_block` pops and wakes one of
    /// them each time a block is released.
    waiters: Arc<crossbeam_queue::SegQueue<Waker>>,
}
25
26impl<Inner, const BLOCKS: usize, const SIZE: usize> Clone for Buffer<Inner, BLOCKS, SIZE> {
27 fn clone(&self) -> Self {
28 Self {
29 storage: self.storage.clone(),
30 registry: self.registry.clone(),
31 waiters: self.waiters.clone(),
32 }
33 }
34}
35
36impl<Inner, const BLOCKS: usize, const SIZE: usize> Buffer<Inner, BLOCKS, SIZE> {
37 #[cfg(feature = "unsafe_new")]
38 pub unsafe fn new(inner: Inner) -> Self {
39 let mut registry = Vec::with_capacity(BLOCKS);
40 for _ in 0..BLOCKS {
41 registry.push(AtomicBool::new(true));
42 }
43 Self {
44 storage: unsafe { Arc::new(BackingBuffer::new(inner)) },
45 registry: Arc::new(registry.try_into().unwrap()),
46 waiters: Default::default(),
47 }
48 }
49
50 fn free_block(&self, guard: &mut BufGuard<Inner, BLOCKS, SIZE>) {
51 let idx = guard.idx;
52 self.registry[idx].store(true, Ordering::SeqCst);
53 if let Some(waker) = self.waiters.pop() {
54 waker.wake();
55 }
56 }
57
58 fn push_waker(&self, waker: Waker) {
59 self.waiters.push(waker);
60 }
61}
62
63impl<Inner, const BLOCKS: usize, const SIZE: usize> Buffer<Inner, BLOCKS, SIZE>
64where
65 Inner: Deref<Target = [u8]> + DerefMut,
66{
67 pub async fn acquire_block(&self) -> BufGuard<Inner, BLOCKS, SIZE> {
68 let block = self.try_acquire_block();
69
70 if let Some(block) = block {
71 block
72 } else {
73 BufGuardFuture::new(self).await
74 }
75 }
76
77 fn try_acquire_block<'a>(&'a self) -> Option<BufGuard<Inner, BLOCKS, SIZE>> {
78 let mut idx = None;
79 for (i, e) in self.registry.iter().enumerate() {
80 if e.compare_exchange(true, false, Ordering::SeqCst, Ordering::SeqCst)
81 .is_ok()
82 {
83 idx = Some(i);
84 break;
85 }
86 }
87 let idx = idx?;
88
89 let block = unsafe { self.storage.get_block(idx) };
90 BufGuard {
91 block,
92 buffer: self.clone(),
93 idx,
94 }
95 .into()
96 }
97}
98
99impl<const BLOCKS: usize, const SIZE: usize> Buffer<Box<[u8]>, BLOCKS, SIZE> {
100 pub fn new_vec() -> Self {
101 let mut registry = Vec::with_capacity(BLOCKS);
102 for _ in 0..BLOCKS {
103 registry.push(AtomicBool::new(true));
104 }
105 Self {
106 storage: Arc::new(BackingBuffer::new_vec()),
107 registry: Arc::new(registry.try_into().unwrap()),
108 waiters: Default::default(),
109 }
110 }
111}
112
113#[cfg(feature = "memmap")]
114impl<const BLOCKS: usize, const SIZE: usize> Buffer<memmap2::MmapMut, BLOCKS, SIZE> {
115 pub unsafe fn new_memmap(backing_file: &std::path::Path) -> std::io::Result<Self> {
116 let mut registry = Vec::with_capacity(BLOCKS);
117 for _ in 0..BLOCKS {
118 registry.push(AtomicBool::new(true));
119 }
120 Ok(Self {
121 storage: Arc::new(BackingBuffer::new_memmap(backing_file)?),
122 registry: Arc::new(registry.try_into().unwrap()),
123 waiters: Default::default(),
124 })
125 }
126}
127
128#[cfg(feature = "bytes")]
129impl<const BLOCKS: usize, const SIZE: usize> Buffer<bytes::BytesMut, BLOCKS, SIZE> {
130 pub unsafe fn new_bytes() -> std::io::Result<Self> {
131 let mut registry = Vec::with_capacity(BLOCKS);
132 for _ in 0..BLOCKS {
133 registry.push(AtomicBool::new(true));
134 }
135 Ok(Self {
136 storage: Arc::new(BackingBuffer::new_bytes()),
137 registry: Arc::new(registry.try_into().unwrap()),
138 waiters: Default::default(),
139 })
140 }
141}
142
/// Exclusive handle to one block of a [`Buffer`], returned by `acquire_block`.
///
/// Dereferences to the block's bytes. When dropped, the block is zero-filled and
/// its registry slot is released back to the pool.
pub struct BufGuard<Inner, const BLOCKS: usize, const SIZE: usize> {
    /// Handle to the owning pool; keeps the `Arc`-held storage alive while this
    /// guard exists.
    buffer: Buffer<Inner, BLOCKS, SIZE>,
    /// The claimed block's bytes, as returned by `get_block`.
    /// NOTE(review): the `'static` lifetime is not real — the slice is only valid
    /// while `buffer` keeps the storage alive, so it must never escape the guard.
    block: &'static mut [u8],
    /// Index of this block's flag in the pool's registry.
    idx: usize,
}
148
149impl<Inner, const BLOCKS: usize, const SIZE: usize> Deref for BufGuard<Inner, BLOCKS, SIZE> {
150 type Target = [u8];
151 fn deref(&self) -> &Self::Target {
152 &self.block
153 }
154}
155
156impl<'a, Inner, const BLOCKS: usize, const SIZE: usize> DerefMut for BufGuard<Inner, BLOCKS, SIZE> {
157 fn deref_mut(&mut self) -> &mut Self::Target {
158 &mut self.block
159 }
160}
161
162impl<'a, Inner, const BLOCKS: usize, const SIZE: usize> Drop for BufGuard<Inner, BLOCKS, SIZE> {
163 fn drop(&mut self) {
164 self.fill(0);
165 let buffer = self.buffer.clone();
166 buffer.free_block(self);
167 }
168}
169
#[cfg(test)]
mod tests {
    use std::{thread, time::Duration};

    // Only the memmap-backed tests create temporary directories.
    #[cfg(feature = "memmap")]
    use tempdir::TempDir;

    use super::*;

    /// Per-block size used by every test (3 MiB).
    const SIZE: usize = 1048576 * 3;

    /// Exhausts a 10-block heap pool, then checks that a pending `acquire_block`
    /// completes once another future releases a block. That future waits for a
    /// signal sent by a helper thread.
    #[test]
    fn basic_test() {
        let (tx, rx) = oneshot::channel();
        thread::spawn(move || {
            std::thread::sleep(Duration::from_millis(300));
            tx.send(())
        });
        ::futures::executor::block_on(async move {
            let buffer = Buffer::<_, 10, SIZE>::new_vec();
            let buf = buffer.acquire_block().await;
            assert_eq!(buf.len(), SIZE);
            assert!(buf.iter().all(|&a| a == 0));

            let mut blocks = Vec::new();
            for _ in 0..9 {
                let b = buffer.acquire_block().await;
                blocks.push(b);
            }
            let (_, _) = ::futures::join!(buffer.acquire_block(), async move {
                _ = rx.await;
                drop(buf);
            });
        });
    }

    /// Same scenario as `basic_test` on memmap-backed storage, with a second
    /// thread also contending for a block after it sends the release signal.
    #[cfg(feature = "memmap")]
    #[test]
    fn test_memmap() {
        let dir = TempDir::new("dir").unwrap();

        let (tx, rx) = oneshot::channel();

        let buffer =
            unsafe { Buffer::<_, 10, SIZE>::new_memmap(&dir.path().join("file.txt")).unwrap() };
        let sent_buffer = buffer.clone();
        thread::spawn(move || {
            std::thread::sleep(Duration::from_millis(300));
            _ = tx.send(());
            ::futures::executor::block_on(sent_buffer.acquire_block());
        });
        ::futures::executor::block_on(async move {
            let buf = buffer.acquire_block().await;
            assert_eq!(buf.len(), SIZE);
            assert!(buf.iter().all(|&a| a == 0));

            let mut blocks = Vec::new();
            for _ in 0..9 {
                let b = buffer.acquire_block().await;
                blocks.push(b);
            }
            let (_, _) = ::futures::join!(buffer.acquire_block(), async move {
                _ = rx.await;
                drop(buf);
            });

            buffer.acquire_block().await;
        });
    }

    /// Stress test: 50 threads each acquire a block and immediately release it,
    /// 500 times, against a 10-block memmap pool.
    #[cfg(feature = "memmap")]
    #[test]
    fn test_race() {
        let dir = TempDir::new("dir").unwrap();

        let buffer =
            unsafe { Buffer::<_, 10, SIZE>::new_memmap(&dir.path().join("file")).unwrap() };
        let mut hs = Vec::new();
        for _ in 0..50 {
            let buffer = buffer.clone();
            let h = std::thread::spawn(move || {
                ::futures::executor::block_on(async move {
                    for _ in 0..500 {
                        buffer.acquire_block().await;
                    }
                });
            });
            hs.push(h);
        }

        for h in hs {
            h.join().unwrap();
        }
    }
}