1use std::{
4 cmp,
5 fmt::Debug,
6 io::Result,
7 ops::{Range, RangeTo},
8 task::Poll,
9};
10
11use futures::io::{AsyncBufRead, AsyncRead, AsyncWrite};
12
/// A fixed-capacity, in-memory byte ring buffer.
///
/// `read_pos` and `write_pos` are logical offsets that only ever grow; the
/// physical index of a logical offset is `offset % capacity`. Every mutating
/// method maintains the invariant
/// `read_pos <= write_pos <= read_pos + capacity`, so
/// `write_pos - read_pos` is the number of readable bytes and
/// `read_pos + capacity - write_pos` the number of writable bytes.
pub struct RingBuf {
    /// Backing storage; its length is the buffer's capacity.
    memory_block: Box<[u8]>,
    /// Logical offset of the next byte to be read.
    read_pos: u64,
    /// Logical offset of the next byte to be written.
    write_pos: u64,
}

impl Debug for RingBuf {
    /// Shows the capacity and the two logical positions (not the contents).
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("RingBuf")
            .field("memory", &self.memory_block.len())
            .field("read", &self.read_pos)
            .field("write", &self.write_pos)
            .finish()
    }
}

impl RingBuf {
    /// Creates a zero-filled ring buffer that can hold `len` bytes.
    ///
    /// # Panics
    ///
    /// Panics if `len` is zero.
    pub fn with_capacity(len: usize) -> Self {
        assert!(len > 0, "capacity is zero.");
        Self {
            memory_block: vec![0; len].into_boxed_slice(),
            read_pos: 0,
            write_pos: 0,
        }
    }

    /// Number of bytes currently available to read.
    pub fn readable(&self) -> usize {
        (self.write_pos - self.read_pos) as usize
    }

    /// Number of bytes that can be written before the buffer is full.
    pub fn writable(&self) -> usize {
        (self.read_pos + self.capacity_u64() - self.write_pos) as usize
    }

    /// Copies up to `buf.len()` readable bytes into `buf` and consumes them.
    ///
    /// Returns the number of bytes copied, which is `0` when either the ring
    /// buffer or `buf` is empty. Never returns an error.
    pub fn read(&mut self, buf: &mut [u8]) -> Result<usize> {
        let read_size = cmp::min(self.readable(), buf.len());
        if read_size == 0 {
            return Ok(0);
        }

        // The readable region may wrap past the end of the storage, in which
        // case it is split into a tail segment and a head segment.
        let (first, second) = self.readable_ranges(read_size);
        let first_len = first.len();
        buf[..first_len].copy_from_slice(&self.memory_block[first]);
        if let Some(second) = second {
            buf[first_len..read_size].copy_from_slice(&self.memory_block[second]);
        }

        self.read_pos += read_size as u64;
        Ok(read_size)
    }

    /// Copies as much of `buf` as fits into the free space and commits it.
    ///
    /// Returns the number of bytes stored, which is `0` when the ring buffer
    /// is full or `buf` is empty. Never returns an error.
    pub fn write(&mut self, buf: &[u8]) -> Result<usize> {
        let write_size = cmp::min(self.writable(), buf.len());
        if write_size == 0 {
            return Ok(0);
        }

        // The free region may wrap past the end of the storage.
        let (first, second) = self.writable_ranges(write_size);
        let first_len = first.len();
        self.memory_block[first].copy_from_slice(&buf[..first_len]);
        if let Some(second) = second {
            self.memory_block[second].copy_from_slice(&buf[first_len..write_size]);
        }

        self.write_pos += write_size as u64;
        Ok(write_size)
    }

    /// Returns the first contiguous segment of free space.
    ///
    /// If the free space wraps around the end of the storage, only the part
    /// up to the end of the storage is returned.
    ///
    /// # Safety
    ///
    /// The slice is always in bounds and initialized, so calling this cannot
    /// by itself cause undefined behaviour. Bytes written into it are not
    /// readable until [`Self::writable_consume`] commits them.
    /// NOTE(review): the `unsafe` marker appears to express that pairing
    /// contract rather than a memory-safety requirement — confirm.
    pub unsafe fn writable_buf(&mut self) -> &mut [u8] {
        let (range, _) = self.writable_ranges(self.writable());
        &mut self.memory_block[range]
    }

    /// Returns the first contiguous segment of readable bytes.
    ///
    /// If the readable data wraps around the end of the storage, only the part
    /// up to the end of the storage is returned.
    ///
    /// # Safety
    ///
    /// The slice is always in bounds and initialized. Bytes stay readable
    /// until [`Self::readable_consume`] releases them.
    /// NOTE(review): as with `writable_buf`, `unsafe` looks like an API
    /// contract marker rather than a memory-safety requirement — confirm.
    pub unsafe fn readable_buf(&mut self) -> &[u8] {
        let (range, _) = self.readable_ranges(self.readable());
        &self.memory_block[range]
    }

    /// Marks `amt` readable bytes as consumed, freeing their space.
    ///
    /// # Safety
    ///
    /// See [`Self::readable_buf`]; no memory-unsafe operation is performed.
    ///
    /// # Panics
    ///
    /// Panics if `amt` exceeds [`Self::readable`]. The check happens before
    /// any state changes, so a caught panic leaves the buffer consistent.
    pub unsafe fn readable_consume(&mut self, amt: usize) {
        assert!(amt <= self.readable(), "advance_readable_pos: overflow");
        self.read_pos += amt as u64;
    }

    /// Commits `amt` bytes previously written through [`Self::writable_buf`].
    ///
    /// # Safety
    ///
    /// See [`Self::writable_buf`]; no memory-unsafe operation is performed.
    ///
    /// # Panics
    ///
    /// Panics if `amt` exceeds [`Self::writable`]. The check happens before
    /// any state changes, so a caught panic leaves the buffer consistent.
    pub unsafe fn writable_consume(&mut self, amt: usize) {
        assert!(amt <= self.writable(), "advance_writable_pos: overflow");
        self.write_pos += amt as u64;
    }

    /// Capacity of the backing storage as a `u64`, for position arithmetic.
    fn capacity_u64(&self) -> u64 {
        self.memory_block.len() as u64
    }

    /// Physical ranges covering the next `write_size` writable bytes.
    ///
    /// Panics if `write_size` exceeds [`Self::writable`].
    fn writable_ranges(&self, write_size: usize) -> (Range<usize>, Option<RangeTo<usize>>) {
        assert!(write_size <= self.writable(), "write_size is overflow.");
        self.physical_ranges(self.write_pos, write_size)
    }

    /// Physical ranges covering the next `read_size` readable bytes.
    ///
    /// Panics if `read_size` exceeds [`Self::readable`].
    fn readable_ranges(&self, read_size: usize) -> (Range<usize>, Option<RangeTo<usize>>) {
        assert!(read_size <= self.readable(), "read_size is overflow.");
        self.physical_ranges(self.read_pos, read_size)
    }

    /// Maps `size` bytes starting at logical offset `pos` onto storage indices.
    ///
    /// Returns one range when the span is contiguous (or empty). Otherwise it
    /// returns the tail `start..capacity` plus the head `..end`; the head may
    /// be empty when the span ends exactly at the end of the storage.
    /// Callers must ensure `size <= capacity`.
    fn physical_ranges(&self, pos: u64, size: usize) -> (Range<usize>, Option<RangeTo<usize>>) {
        let cap = self.capacity_u64();
        let start = (pos % cap) as usize;
        let end = ((pos + size as u64) % cap) as usize;

        // With size == 0, start == end and the wrap branch would wrongly
        // describe the whole buffer, so an empty span is handled here.
        if size == 0 || start < end {
            (start..end, None)
        } else {
            (start..self.memory_block.len(), Some(..end))
        }
    }
}
182
183impl std::io::Read for RingBuf {
184 fn read(&mut self, buf: &mut [u8]) -> Result<usize> {
185 Self::read(self, buf)
186 }
187}
188
189impl std::io::Write for RingBuf {
190 fn write(&mut self, buf: &[u8]) -> Result<usize> {
191 Self::write(self, buf)
192 }
193
194 fn flush(&mut self) -> Result<()> {
195 Ok(())
196 }
197}
198
#[cfg(feature = "async")]
#[cfg_attr(docsrs, doc(cfg(feature = "async")))]
impl AsyncRead for RingBuf {
    /// Reads synchronously from the buffer; always returns `Poll::Ready`.
    ///
    /// The waker in `_cx` is never registered, and an empty buffer yields
    /// `Ready(Ok(0))`, which async readers interpret as end-of-stream.
    fn poll_read(
        self: std::pin::Pin<&mut Self>,
        _cx: &mut std::task::Context<'_>,
        buf: &mut [u8],
    ) -> std::task::Poll<Result<usize>> {
        // `RingBuf` is `Unpin`, so it can be unpinned freely.
        let inner = self.get_mut();
        let outcome = RingBuf::read(inner, buf);
        Poll::Ready(outcome)
    }
}
210
#[cfg(feature = "async")]
#[cfg_attr(docsrs, doc(cfg(feature = "async")))]
impl AsyncBufRead for RingBuf {
    /// Returns the first contiguous readable segment without consuming it.
    ///
    /// Always `Poll::Ready`; the segment may be shorter than
    /// [`RingBuf::readable`] when the data wraps around the storage.
    fn poll_fill_buf(
        self: std::pin::Pin<&mut Self>,
        _cx: &mut std::task::Context<'_>,
    ) -> Poll<Result<&[u8]>> {
        let inner = self.get_mut();
        // SAFETY: `readable_buf` only slices the initialized backing storage
        // with bounds-checked indexing; bytes are released solely through
        // `consume` below.
        let segment = unsafe { inner.readable_buf() };
        Poll::Ready(Ok(segment))
    }

    /// Releases `amt` bytes previously returned by `poll_fill_buf`.
    ///
    /// Panics if `amt` exceeds the number of readable bytes.
    fn consume(self: std::pin::Pin<&mut Self>, amt: usize) {
        let inner = self.get_mut();
        // SAFETY: `readable_consume` performs no memory-unsafe operation; it
        // asserts that `amt` does not exceed the readable byte count.
        unsafe { inner.readable_consume(amt) }
    }
}
227
#[cfg(feature = "async")]
#[cfg_attr(docsrs, doc(cfg(feature = "async")))]
impl AsyncWrite for RingBuf {
    /// Writes synchronously into the buffer; always returns `Poll::Ready`.
    ///
    /// A full buffer yields `Ready(Ok(0))`; the waker is never registered.
    fn poll_write(
        self: std::pin::Pin<&mut Self>,
        _cx: &mut std::task::Context<'_>,
        buf: &[u8],
    ) -> Poll<Result<usize>> {
        let inner = self.get_mut();
        let outcome = RingBuf::write(inner, buf);
        Poll::Ready(outcome)
    }

    /// Nothing is buffered outside the ring itself, so flushing is a no-op.
    fn poll_flush(
        self: std::pin::Pin<&mut Self>,
        _cx: &mut std::task::Context<'_>,
    ) -> Poll<Result<()>> {
        let flushed: Result<()> = Ok(());
        Poll::Ready(flushed)
    }

    /// Completes immediately without recording any closed state; later
    /// writes are still accepted.
    fn poll_close(
        self: std::pin::Pin<&mut Self>,
        _cx: &mut std::task::Context<'_>,
    ) -> Poll<Result<()>> {
        let closed: Result<()> = Ok(());
        Poll::Ready(closed)
    }
}
253
// Unit tests for `RingBuf`. Capacities of 11 and 3072 are used so that
// positions cross the end of the storage and exercise wrap-around.
#[cfg(test)]
mod tests {

    use super::*;

    // Fuzz-driven property test (via the `test_fuzz` crate, not compiled on
    // Windows): after committing `offset` bytes into an empty 11-byte buffer,
    // the contiguous writable segment is `11 - offset` long and the readable
    // segment is `offset` long.
    #[cfg(not(target_os = "windows"))]
    #[test_fuzz::test_fuzz]
    fn fuzz_test_unsafe_fn(offset: usize) {
        // Reduce arbitrary input to a valid commit size in 0..11.
        let offset = offset % 11;
        let mut ringbuf = RingBuf::with_capacity(11);

        unsafe {
            assert_eq!(ringbuf.writable_buf().len(), 11);
            ringbuf.writable_consume(offset);

            assert_eq!(ringbuf.writable_buf().len(), 11 - offset);

            assert_eq!(ringbuf.readable_buf().len(), offset);
        }
    }

    // Fill the whole buffer through `writable_buf`, commit it in two steps,
    // then consume 3 bytes. The readable segment is then the tail "45678901"
    // and the writable segment has wrapped to the front ("123" still holds
    // the old, already-consumed bytes).
    #[test]
    fn test_unsafe_fns() {
        let mut ringbuf = RingBuf::with_capacity(11);
        unsafe {
            assert_eq!(ringbuf.writable_buf().len(), 11);
            ringbuf.writable_buf().copy_from_slice(b"12345678901");
            ringbuf.writable_consume(5);
            assert_eq!(ringbuf.writable_buf().len(), 6);
            assert_eq!(ringbuf.readable_buf().len(), 5);
            ringbuf.writable_consume(6);
            assert_eq!(ringbuf.writable(), 0);
            assert_eq!(ringbuf.writable_buf().len(), 0);
            assert_eq!(ringbuf.readable_buf().len(), 11);
            ringbuf.readable_consume(3);
            assert_eq!(ringbuf.readable_buf(), b"45678901");
            assert_eq!(ringbuf.writable_buf(), b"123");
        }
    }

    // Position bookkeeping only: a full fill/drain cycle, then a partial
    // cycle whose positions pass the end of the storage.
    #[test]
    fn test_pos() {
        let mut ringbuf = RingBuf::with_capacity(11);

        unsafe {
            ringbuf.writable_consume(11);
            assert_eq!(ringbuf.readable(), 11);
            assert_eq!(ringbuf.writable(), 0);
            ringbuf.readable_consume(11);
            assert_eq!(ringbuf.readable(), 0);
            assert_eq!(ringbuf.writable(), 11);
            ringbuf.writable_consume(10);
            ringbuf.readable_consume(9);
            assert_eq!(ringbuf.readable(), 1);
            assert_eq!(ringbuf.writable(), 10);
        }
    }

    // Safe `read`/`write` API: oversized writes are truncated to capacity,
    // reads stop at the readable count, and after both positions are moved
    // to 16 (physical index 5) a full write and partial read wrap correctly.
    #[test]
    fn test_io() {
        let mut ringbuf = RingBuf::with_capacity(11);

        assert_eq!(ringbuf.write(b"12345678901234").unwrap(), 11);
        assert_eq!(ringbuf.writable(), 0);
        assert_eq!(ringbuf.readable(), 11);

        let mut buf = vec![0; 12];

        assert_eq!(ringbuf.read(&mut buf).unwrap(), 11);
        assert_eq!(&buf[..11], b"12345678901");
        assert_eq!(ringbuf.writable(), 11);
        assert_eq!(ringbuf.readable(), 0);

        unsafe {
            ringbuf.writable_consume(5);
            ringbuf.readable_consume(5);
            assert_eq!(ringbuf.writable(), 11);
            assert_eq!(ringbuf.readable(), 0);
        }

        assert_eq!(ringbuf.write(b"67890123412345").unwrap(), 11);

        let mut buf = vec![0; 7];

        assert_eq!(ringbuf.read(&mut buf).unwrap(), 7);

        assert_eq!(&buf, b"6789012");

        assert_eq!(ringbuf.writable(), 7);
        assert_eq!(ringbuf.readable(), 4);
    }

    // Exact-capacity edges: a completely full buffer frees exactly the bytes
    // consumed, and committing capacity - 1 bytes leaves exactly one free.
    #[test]
    fn test_boundary_condition() {
        let mut ringbuf = RingBuf::with_capacity(1024 * 3);

        unsafe {
            ringbuf.writable_consume(1024 * 3);

            assert_eq!(ringbuf.readable(), 1024 * 3);

            assert_eq!(ringbuf.writable(), 0);

            ringbuf.readable_consume(3);

            assert_eq!(ringbuf.writable(), 3);
        }

        let mut ringbuf = RingBuf::with_capacity(1024 * 3);

        unsafe {
            ringbuf.writable_consume(1024 * 3 - 1);

            assert_eq!(ringbuf.readable(), 1024 * 3 - 1);

            assert_eq!(ringbuf.writable(), 1);
        }
    }
}