msquic_async/
buffer.rs

1use crate::stream::StreamInner;
2
3use std::io::IoSlice;
4use std::ops::Range;
5use std::slice;
6use std::sync::Arc;
7
8use bytes::{Buf, Bytes};
9use libc::c_void;
10use tracing::trace;
11
12/// A buffer for receiving data from a stream.
13///
14/// It implements [`bytes::Buf`] and is backed by a list of [`msquic::Buffer`].
15pub struct StreamRecvBuffer {
16    stream: Option<Arc<StreamInner>>,
17    buffers: Vec<msquic::Buffer>,
18    offset: usize,
19    len: usize,
20    read_cursor: usize,
21    read_cursor_in_buffer: usize,
22    fin: bool,
23}
24
25impl StreamRecvBuffer {
26    pub(crate) fn new<T: AsRef<[msquic::Buffer]>>(offset: usize, buffers: &T, fin: bool) -> Self {
27        let buf = Self {
28            stream: None,
29            buffers: buffers.as_ref().to_vec(),
30            offset,
31            len: buffers.as_ref().iter().map(|x| x.length).sum::<u32>() as usize,
32            read_cursor: 0,
33            read_cursor_in_buffer: 0,
34            fin,
35        };
36        trace!(
37            "StreamRecvBuffer({:p}) created offset={} len={} fin={}",
38            buf.buffers
39                .first()
40                .map(|x| x.buffer)
41                .unwrap_or(std::ptr::null_mut()),
42            buf.offset,
43            buf.len(),
44            buf.fin,
45        );
46        buf
47    }
48
49    pub(crate) fn set_stream(&mut self, stream: Arc<StreamInner>) {
50        self.stream = Some(stream);
51    }
52
53    /// Returns the length of the buffer.
54    pub fn len(&self) -> usize {
55        if self.buffers.len() <= self.read_cursor {
56            return 0;
57        }
58        self.len
59            - self.buffers[..self.read_cursor]
60                .iter()
61                .map(|x| x.length)
62                .sum::<u32>() as usize
63            - self.read_cursor_in_buffer
64    }
65
66    /// Returns `true` if the buffer is empty.
67    pub fn is_empty(&self) -> bool {
68        self.len() == 0
69    }
70
71    /// Returns the buffer as a slice.
72    pub fn as_slice_upto_size(&self, size: usize) -> &[u8] {
73        if self.buffers.len() <= self.read_cursor {
74            return &[];
75        }
76        assert!(self.buffers.len() >= self.read_cursor);
77        let buffer = &self.buffers[self.read_cursor];
78        assert!(buffer.length as usize >= self.read_cursor_in_buffer);
79        let len = std::cmp::min(buffer.length as usize - self.read_cursor_in_buffer, size);
80        unsafe { slice::from_raw_parts(buffer.buffer.add(self.read_cursor_in_buffer), len) }
81    }
82
83    /// Consumes and returns the buffer as a slice.
84    pub fn get_bytes_upto_size<'a>(&mut self, size: usize) -> Option<&'a [u8]> {
85        if self.buffers.len() <= self.read_cursor {
86            return None;
87        }
88        assert!(self.buffers.len() >= self.read_cursor);
89        let buffer = &self.buffers[self.read_cursor];
90
91        assert!(buffer.length as usize >= self.read_cursor_in_buffer);
92        let len = std::cmp::min(buffer.length as usize - self.read_cursor_in_buffer, size);
93
94        let slice =
95            unsafe { slice::from_raw_parts(buffer.buffer.add(self.read_cursor_in_buffer), len) };
96        self.read_cursor_in_buffer += len;
97        if self.read_cursor_in_buffer >= buffer.length as usize {
98            self.read_cursor += 1;
99            self.read_cursor_in_buffer = 0;
100        }
101        Some(slice)
102    }
103
104    /// Return the offset in the stream.
105    pub fn offset(&self) -> usize {
106        self.offset
107    }
108
109    /// Return the range in the stream.
110    pub fn range(&self) -> Range<usize> {
111        self.offset..self.offset + self.len
112    }
113
114    /// Return `true` if the buffer is the end of the stream.
115    pub fn fin(&self) -> bool {
116        self.fin
117    }
118}
119
120unsafe impl Sync for StreamRecvBuffer {}
121unsafe impl Send for StreamRecvBuffer {}
122
123impl Buf for StreamRecvBuffer {
124    fn advance(&mut self, mut count: usize) {
125        assert!(count == 0 || count <= self.remaining());
126        for buffer in &self.buffers[self.read_cursor..] {
127            if count == 0 {
128                break;
129            }
130            let remaining = buffer.length as usize - self.read_cursor_in_buffer;
131            if count < remaining {
132                self.read_cursor_in_buffer += count;
133                break;
134            } else {
135                self.read_cursor += 1;
136                self.read_cursor_in_buffer = 0;
137                count -= remaining;
138            }
139        }
140    }
141
142    fn chunk(&self) -> &[u8] {
143        self.as_slice_upto_size(self.len())
144    }
145
146    fn remaining(&self) -> usize {
147        self.len()
148    }
149
150    fn chunks_vectored<'a>(&'a self, dst: &mut [IoSlice<'a>]) -> usize {
151        let mut count = 0;
152        let mut read_cursor_in_buffer = Some(self.read_cursor_in_buffer);
153        for buffer in &self.buffers[self.read_cursor..] {
154            if let Some(slice) = dst.get_mut(count) {
155                count += 1;
156                let skip = read_cursor_in_buffer.take().unwrap_or(0);
157                *slice = IoSlice::new(unsafe {
158                    slice::from_raw_parts(buffer.buffer.add(skip), buffer.length as usize - skip)
159                });
160            } else {
161                break;
162            }
163        }
164        count
165    }
166}
167
168impl Drop for StreamRecvBuffer {
169    fn drop(&mut self) {
170        trace!(
171            "StreamRecvBuffer({:p}) dropping",
172            self.buffers
173                .first()
174                .map(|x| x.buffer)
175                .unwrap_or(std::ptr::null_mut())
176        );
177        if let Some(stream) = self.stream.take() {
178            stream.read_complete(self);
179        }
180    }
181}
182
183pub(crate) struct WriteBuffer(Box<WriteBufferInner>);
184
185struct WriteBufferInner {
186    internal: Vec<u8>,
187    zerocopy: Vec<Bytes>,
188    msquic_buffer: Vec<msquic::Buffer>,
189}
190unsafe impl Sync for WriteBufferInner {}
191unsafe impl Send for WriteBufferInner {}
192
193impl WriteBuffer {
194    pub(crate) fn new() -> Self {
195        Self(Box::new(WriteBufferInner {
196            internal: Vec::new(),
197            zerocopy: Vec::new(),
198            msquic_buffer: Vec::new(),
199        }))
200    }
201
202    pub(crate) unsafe fn from_raw(inner: *const c_void) -> Self {
203        Self(unsafe { Box::from_raw(inner as *mut WriteBufferInner) })
204    }
205
206    pub(crate) fn put_zerocopy(&mut self, buf: &Bytes) -> usize {
207        self.0.zerocopy.push(buf.clone());
208        buf.len()
209    }
210
211    pub(crate) fn put_slice(&mut self, slice: &[u8]) -> usize {
212        self.0.internal.extend_from_slice(slice);
213        slice.len()
214    }
215
216    pub(crate) fn get_buffer(&mut self) -> (*const msquic::Buffer, u32) {
217        if !self.0.zerocopy.is_empty() {
218            for buf in &self.0.zerocopy {
219                self.0.msquic_buffer.push(buf[..].into());
220            }
221        } else {
222            self.0.msquic_buffer.push((&self.0.internal).into());
223        }
224        let ptr = self.0.msquic_buffer.as_ptr();
225        let len = self.0.msquic_buffer.len() as u32;
226        (ptr, len)
227    }
228
229    pub(crate) fn into_raw(self) -> *mut c_void {
230        Box::into_raw(self.0) as *mut c_void
231    }
232
233    pub(crate) fn reset(&mut self) {
234        self.0.internal.clear();
235        self.0.zerocopy.clear();
236        self.0.msquic_buffer.clear();
237    }
238}