// ringbuff/ringbuf.rs

1//! RingBuf implementation for synchronous/asynchronous programming.
2
3use std::{
4    cmp,
5    fmt::Debug,
6    io::Result,
7    ops::{Range, RangeTo},
8    task::Poll,
9};
10
11use futures::io::{AsyncBufRead, AsyncRead, AsyncWrite};
12
/// An in-memory ring buffer implementation.
///
/// The backing storage has a fixed capacity chosen at construction time.
/// Two monotonically increasing `u64` cursors track the read and write
/// positions; the actual index into the memory block is always
/// `cursor % capacity`, so data wraps around the end of the block.
pub struct RingBuf {
    /// Inner memory block (fixed capacity).
    memory_block: Box<[u8]>,
    /// Cursor for read position (monotonically increasing; not an index).
    read_pos: u64,
    /// Cursor for write position (monotonically increasing; not an index).
    write_pos: u64,
}

impl Debug for RingBuf {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("RingBuf")
            .field("memory", &self.memory_block.len())
            .field("read", &self.read_pos)
            .field("write", &self.write_pos)
            .finish()
    }
}

impl RingBuf {
    /// Create a ringbuf with the specified capacity.
    ///
    /// # Panics
    ///
    /// Panics if `len` is zero.
    pub fn with_capacity(len: usize) -> Self {
        assert!(len > 0, "capacity is zero.");
        Self {
            memory_block: vec![0; len].into_boxed_slice(),
            read_pos: 0,
            write_pos: 0,
        }
    }

    /// Returns the length of readable data.
    pub fn readable(&self) -> usize {
        (self.write_pos - self.read_pos) as usize
    }

    /// Returns the capacity of writable data.
    pub fn writable(&self) -> usize {
        (self.read_pos + self.memory_block.len() as u64 - self.write_pos) as usize
    }

    /// Read data from the ringbuf into `buf`.
    ///
    /// Copies at most `min(self.readable(), buf.len())` bytes and advances
    /// the read cursor. Returns the number of bytes copied (`0` when the
    /// buffer is empty). Never returns `Err`; the `Result` exists to match
    /// the `std::io::Read` signature.
    pub fn read(&mut self, buf: &mut [u8]) -> Result<usize> {
        let read_size = cmp::min(self.readable(), buf.len());

        if read_size == 0 {
            return Ok(0);
        }

        match self.readable_ranges(read_size) {
            // Readable region wraps past the end of the memory block:
            // copy the tail segment first, then the head segment.
            (first, Some(second)) => {
                let first_part = &self.memory_block[first];
                let second_part = &self.memory_block[second];
                buf[..first_part.len()].copy_from_slice(first_part);

                buf[first_part.len()..first_part.len() + second_part.len()]
                    .copy_from_slice(second_part);
            }
            // Contiguous readable region.
            (first, None) => {
                let first_part = &self.memory_block[first];
                buf[..first_part.len()].copy_from_slice(first_part);
            }
        }

        self.read_pos += read_size as u64;

        Ok(read_size)
    }

    /// Write data into ringbuf.
    ///
    /// Copies at most `min(self.writable(), buf.len())` bytes and advances
    /// the write cursor. Returns the number of bytes copied (`0` when the
    /// buffer is full). Never returns `Err`; the `Result` exists to match
    /// the `std::io::Write` signature.
    pub fn write(&mut self, buf: &[u8]) -> Result<usize> {
        let write_size = cmp::min(self.writable(), buf.len());

        if write_size == 0 {
            return Ok(0);
        }

        match self.writable_ranges(write_size) {
            // Writable region wraps past the end of the memory block:
            // fill the tail segment first, then the head segment.
            (first, Some(second)) => {
                let first_source = &buf[..first.len()];
                self.memory_block[first].copy_from_slice(first_source);

                let second_part = &mut self.memory_block[second];
                second_part.copy_from_slice(
                    &buf[first_source.len()..first_source.len() + second_part.len()],
                );
            }
            // Contiguous writable region.
            (first, None) => {
                let first_source = &buf[..first.len()];
                self.memory_block[first].copy_from_slice(first_source);
            }
        }

        self.write_pos += write_size as u64;

        Ok(write_size)
    }

    /// Directly returns the first (contiguous) part of the writable memory
    /// block. When the writable region wraps, the second part is not
    /// exposed by this call.
    ///
    /// # Safety
    ///
    /// Bytes placed in the returned slice are not accounted for until the
    /// caller advances the cursor with [`RingBuf::writable_consume`]; the
    /// caller is responsible for keeping slice use and cursor updates
    /// consistent.
    pub unsafe fn writable_buf(&mut self) -> &mut [u8] {
        let (range, _) = self.writable_ranges(self.writable());
        &mut self.memory_block[range]
    }

    /// Directly returns the first (contiguous) part of the readable memory
    /// block. When the readable region wraps, the second part is not
    /// exposed by this call.
    ///
    /// # Safety
    ///
    /// Bytes taken from the returned slice are not accounted for until the
    /// caller advances the cursor with [`RingBuf::readable_consume`]; the
    /// caller is responsible for keeping slice use and cursor updates
    /// consistent.
    pub unsafe fn readable_buf(&mut self) -> &[u8] {
        let (range, _) = self.readable_ranges(self.readable());
        &self.memory_block[range]
    }

    /// Directly advance the readable cursor by `amt` bytes.
    ///
    /// # Safety
    ///
    /// `amt` must not exceed [`RingBuf::readable`].
    ///
    /// # Panics
    ///
    /// Panics if advancing would move the read cursor past the write cursor.
    pub unsafe fn readable_consume(&mut self, amt: usize) {
        // Validate before mutating so the cursors remain consistent even if
        // the panic is caught (e.g. by catch_unwind).
        let new_read_pos = self.read_pos + amt as u64;
        assert!(new_read_pos <= self.write_pos, "readable_consume: overflow");
        self.read_pos = new_read_pos;
    }

    /// Directly advance the writable cursor by `amt` bytes.
    ///
    /// # Safety
    ///
    /// `amt` must not exceed [`RingBuf::writable`].
    ///
    /// # Panics
    ///
    /// Panics if advancing would move the write cursor more than one full
    /// capacity ahead of the read cursor.
    pub unsafe fn writable_consume(&mut self, amt: usize) {
        // Validate before mutating so the cursors remain consistent even if
        // the panic is caught (e.g. by catch_unwind).
        let new_write_pos = self.write_pos + amt as u64;
        assert!(
            new_write_pos <= self.read_pos + self.memory_block.len() as u64,
            "writable_consume: overflow"
        );
        self.write_pos = new_write_pos;
    }

    /// Compute the index range(s) covering the next `write_size` writable
    /// bytes: a contiguous first range plus an optional second range when
    /// the region wraps past the end of the memory block.
    fn writable_ranges(&self, write_size: usize) -> (Range<usize>, Option<RangeTo<usize>>) {
        let write_end_pos = self.write_pos + write_size as u64;

        assert!(
            write_end_pos <= self.read_pos + self.memory_block.len() as u64,
            "write_size is overflow."
        );

        let start = (self.write_pos % self.memory_block.len() as u64) as usize;
        let end = (write_end_pos % self.memory_block.len() as u64) as usize;

        if write_size == 0 {
            return (start..end, None);
        }

        // `start == end` with a non-zero size means the region spans the
        // whole buffer, which is handled by the wrapping branch.
        if start < end {
            (start..end, None)
        } else {
            (start..self.memory_block.len(), Some(..end))
        }
    }

    /// Compute the index range(s) covering the next `read_size` readable
    /// bytes: a contiguous first range plus an optional second range when
    /// the region wraps past the end of the memory block.
    fn readable_ranges(&self, read_size: usize) -> (Range<usize>, Option<RangeTo<usize>>) {
        let read_end_pos = self.read_pos + read_size as u64;

        assert!(read_end_pos <= self.write_pos, "read_size is overflow.");

        let start = (self.read_pos % self.memory_block.len() as u64) as usize;
        let end = (read_end_pos % self.memory_block.len() as u64) as usize;

        if read_size == 0 {
            return (start..end, None);
        }

        // `start == end` with a non-zero size means the region spans the
        // whole buffer, which is handled by the wrapping branch.
        if start < end {
            (start..end, None)
        } else {
            (start..self.memory_block.len(), Some(..end))
        }
    }
}
182
183impl std::io::Read for RingBuf {
184    fn read(&mut self, buf: &mut [u8]) -> Result<usize> {
185        Self::read(self, buf)
186    }
187}
188
189impl std::io::Write for RingBuf {
190    fn write(&mut self, buf: &[u8]) -> Result<usize> {
191        Self::write(self, buf)
192    }
193
194    fn flush(&mut self) -> Result<()> {
195        Ok(())
196    }
197}
198
#[cfg(feature = "async")]
#[cfg_attr(docsrs, doc(cfg(feature = "async")))]
impl AsyncRead for RingBuf {
    /// Always completes immediately; note that an empty buffer yields
    /// `Ready(Ok(0))` rather than `Pending`.
    fn poll_read(
        self: std::pin::Pin<&mut Self>,
        _cx: &mut std::task::Context<'_>,
        buf: &mut [u8],
    ) -> Poll<Result<usize>> {
        let this = self.get_mut();
        Poll::Ready(this.read(buf))
    }
}
210
#[cfg(feature = "async")]
#[cfg_attr(docsrs, doc(cfg(feature = "async")))]
impl AsyncBufRead for RingBuf {
    /// Always completes immediately with the first contiguous readable
    /// chunk (possibly empty).
    fn poll_fill_buf(
        self: std::pin::Pin<&mut Self>,
        _cx: &mut std::task::Context<'_>,
    ) -> Poll<Result<&[u8]>> {
        let this = self.get_mut();
        // SAFETY: the returned slice mutably borrows `this`, so the ring
        // buffer cannot be modified while the caller holds it.
        let chunk = unsafe { this.readable_buf() };
        Poll::Ready(Ok(chunk))
    }

    fn consume(self: std::pin::Pin<&mut Self>, amt: usize) {
        // SAFETY: per the `AsyncBufRead` contract, `amt` must not exceed
        // the length of the slice returned by `poll_fill_buf`, which is
        // bounded by the readable length.
        unsafe { self.get_mut().readable_consume(amt) }
    }
}
227
#[cfg(feature = "async")]
#[cfg_attr(docsrs, doc(cfg(feature = "async")))]
impl AsyncWrite for RingBuf {
    /// Always completes immediately; note that a full buffer yields
    /// `Ready(Ok(0))` rather than `Pending`.
    fn poll_write(
        self: std::pin::Pin<&mut Self>,
        _cx: &mut std::task::Context<'_>,
        buf: &[u8],
    ) -> Poll<Result<usize>> {
        let this = self.get_mut();
        Poll::Ready(this.write(buf))
    }

    /// Writes land directly in memory, so flushing is a no-op.
    fn poll_flush(
        self: std::pin::Pin<&mut Self>,
        _cx: &mut std::task::Context<'_>,
    ) -> Poll<Result<()>> {
        Poll::Ready(Ok(()))
    }

    /// No underlying resource to close; succeeds immediately.
    fn poll_close(
        self: std::pin::Pin<&mut Self>,
        _cx: &mut std::task::Context<'_>,
    ) -> Poll<Result<()>> {
        Poll::Ready(Ok(()))
    }
}
253
#[cfg(test)]
mod tests {

    use super::*;

    // Fuzz harness (via the `test_fuzz` crate) for the unsafe cursor API:
    // for any offset within capacity, the writable/readable slice lengths
    // must stay consistent with the amount consumed.
    #[cfg(not(target_os = "windows"))]
    #[test_fuzz::test_fuzz]
    fn fuzz_test_unsafe_fn(offset: usize) {
        let offset = offset % 11;
        let mut ringbuf = RingBuf::with_capacity(11);

        unsafe {
            assert_eq!(ringbuf.writable_buf().len(), 11);
            ringbuf.writable_consume(offset);

            assert_eq!(ringbuf.writable_buf().len(), 11 - offset);

            assert_eq!(ringbuf.readable_buf().len(), offset);
        }
    }

    // Drives the unsafe slice/consume API through a full fill, a partial
    // drain, and a wrap-around of the buffer.
    #[test]
    fn test_unsafe_fns() {
        let mut ringbuf = RingBuf::with_capacity(11);
        unsafe {
            assert_eq!(ringbuf.writable_buf().len(), 11);
            ringbuf.writable_buf().copy_from_slice(b"12345678901");
            ringbuf.writable_consume(5);
            assert_eq!(ringbuf.writable_buf().len(), 6);
            assert_eq!(ringbuf.readable_buf().len(), 5);
            ringbuf.writable_consume(6);
            assert_eq!(ringbuf.writable(), 0);
            assert_eq!(ringbuf.writable_buf().len(), 0);
            assert_eq!(ringbuf.readable_buf().len(), 11);
            ringbuf.readable_consume(3);
            // After consuming 3 bytes, readable data starts at index 3 and
            // the freed prefix (still holding the old bytes) is writable.
            assert_eq!(ringbuf.readable_buf(), b"45678901");
            assert_eq!(ringbuf.writable_buf(), b"123");
        }
    }

    // Checks that readable() and writable() stay complementary (summing to
    // capacity) as the two cursors advance.
    #[test]
    fn test_pos() {
        let mut ringbuf = RingBuf::with_capacity(11);

        unsafe {
            ringbuf.writable_consume(11);
            assert_eq!(ringbuf.readable(), 11);
            assert_eq!(ringbuf.writable(), 0);
            ringbuf.readable_consume(11);
            assert_eq!(ringbuf.readable(), 0);
            assert_eq!(ringbuf.writable(), 11);
            ringbuf.writable_consume(10);
            ringbuf.readable_consume(9);
            assert_eq!(ringbuf.readable(), 1);
            assert_eq!(ringbuf.writable(), 10);
        }
    }

    // End-to-end read/write round trip, including a write that wraps
    // around the end of the memory block.
    #[test]
    fn test_io() {
        let mut ringbuf = RingBuf::with_capacity(11);

        // Writes are truncated to the available capacity.
        assert_eq!(ringbuf.write(b"12345678901234").unwrap(), 11);
        assert_eq!(ringbuf.writable(), 0);
        assert_eq!(ringbuf.readable(), 11);

        let mut buf = vec![0; 12];

        assert_eq!(ringbuf.read(&mut buf).unwrap(), 11);
        assert_eq!(&buf[..11], b"12345678901");
        assert_eq!(ringbuf.writable(), 11);
        assert_eq!(ringbuf.readable(), 0);

        unsafe {
            ringbuf.writable_consume(5);
            ringbuf.readable_consume(5);
            assert_eq!(ringbuf.writable(), 11);
            assert_eq!(ringbuf.readable(), 0);
        }

        // This write starts mid-block and wraps to the beginning.
        assert_eq!(ringbuf.write(b"67890123412345").unwrap(), 11);

        let mut buf = vec![0; 7];

        assert_eq!(ringbuf.read(&mut buf).unwrap(), 7);

        assert_eq!(&buf, b"6789012");

        assert_eq!(ringbuf.writable(), 7);
        assert_eq!(ringbuf.readable(), 4);
    }

    // Capacity-edge checks: consuming the whole buffer, and stopping one
    // byte short of full.
    #[test]
    fn test_boundary_condition() {
        let mut ringbuf = RingBuf::with_capacity(1024 * 3);

        unsafe {
            ringbuf.writable_consume(1024 * 3);

            assert_eq!(ringbuf.readable(), 1024 * 3);

            assert_eq!(ringbuf.writable(), 0);

            ringbuf.readable_consume(3);

            assert_eq!(ringbuf.writable(), 3);
        }

        let mut ringbuf = RingBuf::with_capacity(1024 * 3);

        unsafe {
            ringbuf.writable_consume(1024 * 3 - 1);

            assert_eq!(ringbuf.readable(), 1024 * 3 - 1);

            assert_eq!(ringbuf.writable(), 1);
        }
    }
}