1#![cfg(all(feature = "std"))]
2#![cfg_attr(docsrs, doc(cfg(all(feature = "std"))))]
3use crate::{
8 error::GrantOverflow,
9 splittable::{SplittableViewImpl, SplittableViewImplMut},
10 View, ViewMut,
11};
12use futures::task::AtomicWaker;
13use num_integer::{div_ceil, lcm};
14use std::{
15 convert::TryInto,
16 mem::{size_of, MaybeUninit},
17 pin::Pin,
18 sync::{
19 atomic::{AtomicBool, AtomicU64, Ordering},
20 Arc, Mutex,
21 },
22 task::{Context, Poll, Waker},
23};
24
/// A ring-mapped allocation of `size` initialised `T`s starting at `ptr`.
///
/// The memory comes from `vmap::os::map_ring`. `range`/`range_mut` return slices
/// that may extend past the end of the first lap. NOTE(review): this presumably
/// relies on the ring mapping mirroring the region directly after itself — confirm
/// against the `vmap` documentation.
struct UnsafeCircularBuffer<T> {
    /// Start of the mapping; the first `size` elements are initialised.
    ptr: *mut T,
    /// Number of `T` elements in one lap of the ring.
    size: usize,
}

// SAFETY: the buffer owns its elements, so sending it to another thread sends
// the `T`s along with it, which requires `T: Send`.
unsafe impl<T> Send for UnsafeCircularBuffer<T> where T: Send {}

// SAFETY: `range` hands out `&[T]` through `&self`, so sharing the buffer shares
// `&T` across threads (requires `T: Sync`). `range_mut` hands out `&mut [T]`
// through `&self`, letting another thread mutate or replace values (requires
// `T: Send`).
unsafe impl<T> Sync for UnsafeCircularBuffer<T> where T: Send + Sync {}
32
33impl<T> Drop for UnsafeCircularBuffer<T> {
34 fn drop(&mut self) {
35 unsafe {
38 for i in 0..self.size {
39 std::ptr::drop_in_place(self.ptr.add(i));
40 }
41 vmap::os::unmap_ring(self.ptr as *mut u8, self.size * size_of::<T>()).unwrap();
42 }
43 }
44}
45
46impl<T: Default> UnsafeCircularBuffer<T> {
47 pub fn new(minimum_size: usize) -> Self {
48 let size_bytes = {
51 let granularity = lcm(vmap::allocation_size(), size_of::<T>());
52 div_ceil(minimum_size * size_of::<T>(), granularity)
53 .checked_mul(granularity)
54 .unwrap()
55 };
56 let size = size_bytes / size_of::<T>();
57
58 let ptr = unsafe {
61 let ptr = vmap::os::map_ring(size_bytes).unwrap() as *mut T;
62 for v in std::slice::from_raw_parts_mut(ptr as *mut MaybeUninit<T>, size) {
63 v.as_mut_ptr().write(T::default());
64 }
65 ptr
66 };
67
68 Self { ptr, size }
69 }
70}
71
72impl<T> UnsafeCircularBuffer<T> {
73 pub fn len(&self) -> usize {
74 self.size
75 }
76
77 pub unsafe fn range(&self, index: u64, len: usize) -> &[T] {
79 debug_assert!(len <= self.len());
80 let buf_len: u64 = self.len().try_into().unwrap();
81 let offset = index % buf_len;
82 std::slice::from_raw_parts(self.ptr.add(offset.try_into().unwrap()), len)
83 }
84
85 #[allow(clippy::mut_from_ref)]
87 pub unsafe fn range_mut(&self, index: u64, len: usize) -> &mut [T] {
88 debug_assert!(len <= self.len());
89 let buf_len: u64 = self.len().try_into().unwrap();
90 let offset = index % buf_len;
91 std::slice::from_raw_parts_mut(self.ptr.add(offset.try_into().unwrap()), len)
92 }
93}
94
95struct State<T> {
97 buffer: UnsafeCircularBuffer<T>,
98 closed: AtomicBool, head: AtomicU64, tail: AtomicU64, write_waker: AtomicWaker, read_waker: Mutex<Option<Box<dyn Fn() + Send + Sync>>>, }
104
105impl<T: Default> State<T> {
106 fn new(minimum_size: usize) -> Self {
107 Self {
110 buffer: UnsafeCircularBuffer::new(minimum_size + 1),
111 closed: AtomicBool::new(false),
112 head: AtomicU64::new(0),
113 tail: AtomicU64::new(0),
114 write_waker: AtomicWaker::new(),
115 read_waker: Mutex::new(None),
116 }
117 }
118}
119
120impl<T> State<T> {
121 fn readable_len(&self, start: u64) -> usize {
122 (self.tail.load(Ordering::Relaxed) - start)
123 .try_into()
124 .unwrap()
125 }
126
127 fn writeable_len(&self) -> usize {
128 self.buffer.len() - self.readable_len(self.head.load(Ordering::Relaxed))
129 }
130}
131
132pub struct Sink<T> {
136 state: Arc<State<T>>,
137 tail: u64,
138 available: usize,
139 read_waker: Option<Box<dyn Fn() + Send + Sync>>,
140}
141
142impl<T> Sink<T> {
143 fn new(state: Arc<State<T>>) -> Self {
144 Self {
145 state,
146 tail: 0,
147 available: 0,
148 read_waker: None,
149 }
150 }
151
152 fn wake_readers(&mut self) {
153 if self.read_waker.is_none() {
154 let mut lock = self
155 .state
156 .read_waker
157 .lock()
158 .expect("another thread panicked");
159 std::mem::swap(&mut *lock, &mut self.read_waker);
160 }
161 if let Some(read_waker) = self.read_waker.as_ref() {
162 read_waker()
163 }
164 }
165}
166
167impl<T> Drop for Sink<T> {
168 fn drop(&mut self) {
169 self.state.closed.store(true, Ordering::Relaxed);
170 self.wake_readers(); }
172}
173
174impl<T> View for Sink<T> {
175 type Item = T;
176 type Error = GrantOverflow;
177
178 fn view(&self) -> &[T] {
179 unsafe { self.state.buffer.range(self.tail, self.available) }
181 }
182
183 fn poll_grant(
184 mut self: Pin<&mut Self>,
185 cx: &mut Context,
186 count: usize,
187 ) -> Poll<Result<(), GrantOverflow>> {
188 if count > self.state.buffer.len() {
189 return Poll::Ready(Err(GrantOverflow(self.state.buffer.len())));
190 }
191
192 if self.available >= count {
193 return Poll::Ready(Ok(()));
194 }
195
196 let available = self.state.writeable_len();
200 if available >= count {
201 self.available = available;
202 Poll::Ready(Ok(()))
203 } else {
204 self.state.write_waker.register(cx.waker());
205 let available = self.state.writeable_len();
206 if available >= count || self.state.closed.load(Ordering::Relaxed) {
207 self.available = available;
208 Poll::Ready(Ok(()))
209 } else {
210 Poll::Pending
211 }
212 }
213 }
214
215 fn release(&mut self, count: usize) {
216 if count == 0 {
217 return;
218 }
219
220 assert!(
221 count <= self.available,
222 "attempted to release more than current grant"
223 );
224
225 self.available -= count;
227 let count: u64 = count.try_into().unwrap();
228 self.tail += count;
229 self.state.tail.store(self.tail, Ordering::Relaxed);
230 self.wake_readers();
231 }
232}
233
234impl<T> ViewMut for Sink<T> {
235 fn view_mut(&mut self) -> &mut [T] {
236 unsafe { self.state.buffer.range_mut(self.tail, self.available) }
238 }
239}
240
241pub struct Source<T> {
245 state: Arc<State<T>>,
246}
247
248impl<T> Source<T> {
249 fn new(state: Arc<State<T>>) -> Self {
250 Self { state }
251 }
252}
253
254impl<T> Drop for Source<T> {
255 fn drop(&mut self) {
256 self.state.closed.store(true, Ordering::Relaxed);
257 self.state.write_waker.wake();
258 }
259}
260
261unsafe impl<T> SplittableViewImpl for Source<T> {
262 type Item = T;
263 type Error = GrantOverflow;
264
265 unsafe fn set_reader_waker(&self, waker: impl Fn() + Send + Sync + 'static) {
266 let mut lock = self
267 .state
268 .read_waker
269 .lock()
270 .expect("another thread panicked");
271 assert!(lock.is_none(), "reader waker already set!");
272 *lock = Some(Box::new(waker));
273 }
274
275 unsafe fn set_head(&self, index: u64) {
276 self.state.head.store(index, Ordering::Relaxed);
277 self.state.write_waker.wake();
278 }
279
280 unsafe fn compare_set_head(&self, index: u64) {
281 let mut current = self.state.head.load(Ordering::Relaxed);
283 if index > current {
284 while let Err(previous) = self.state.head.compare_exchange_weak(
285 current,
286 index,
287 Ordering::Relaxed,
288 Ordering::Relaxed,
289 ) {
290 if index > previous {
291 current = previous
292 } else {
293 break;
294 }
295 }
296 }
297 self.state.write_waker.wake();
298 }
299
300 fn poll_available(
301 self: Pin<&Self>,
302 cx: &mut Context,
303 register_wakeup: impl Fn(&Waker),
304 index: u64,
305 len: usize,
306 ) -> Poll<Result<usize, Self::Error>> {
307 let max_len = self.state.buffer.len();
308 if len > max_len {
309 return Poll::Ready(Err(GrantOverflow(max_len)));
310 }
311
312 let available = self.state.readable_len(index);
316 if available >= len {
317 Poll::Ready(Ok(available))
318 } else {
319 register_wakeup(cx.waker());
320 let available = self.state.readable_len(index);
321 if available >= len || self.state.closed.load(Ordering::Relaxed) {
322 Poll::Ready(Ok(available))
323 } else {
324 Poll::Pending
325 }
326 }
327 }
328
329 unsafe fn view(&self, index: u64, len: usize) -> &[Self::Item] {
330 self.state.buffer.range(index, len)
331 }
332}
333
334unsafe impl<T> SplittableViewImplMut for Source<T> {
335 unsafe fn view_mut(&self, index: u64, len: usize) -> &mut [Self::Item] {
336 self.state.buffer.range_mut(index, len)
337 }
338}
339
340pub fn circular_buffer<T: Send + Sync + Default + 'static>(
345 min_size: usize,
346) -> (Sink<T>, Source<T>) {
347 assert!(min_size > 0, "`min_size` must be greater than 0");
348
349 let state = Arc::new(State::new(min_size));
350
351 (Sink::new(state.clone()), Source::new(state))
352}