msquic_async/
buffer.rs

1use crate::stream::StreamInstance;
2
3use std::io::IoSlice;
4use std::ops::Range;
5use std::slice;
6use std::sync::Arc;
7
8use bytes::{Buf, Bytes};
9use libc::c_void;
10use tracing::trace;
11
12/// A buffer for receiving data from a stream.
13///
14/// It implements [`bytes::Buf`] and is backed by a list of [`msquic::ffi::QUIC_BUFFER`].
15pub struct StreamRecvBuffer {
16    stream: Option<Arc<StreamInstance>>,
17    buffers: Vec<msquic::ffi::QUIC_BUFFER>,
18    offset: usize,
19    len: usize,
20    read_cursor: usize,
21    read_cursor_in_buffer: usize,
22    fin: bool,
23}
24
25impl StreamRecvBuffer {
26    pub(crate) fn new<T: AsRef<[msquic::BufferRef]> + ?Sized>(
27        offset: usize,
28        buffers: &T,
29        fin: bool,
30    ) -> Self {
31        let buf = Self {
32            stream: None,
33            buffers: buffers.as_ref().iter().map(|v| v.0).collect(),
34            offset,
35            len: buffers.as_ref().iter().map(|x| x.0.Length).sum::<u32>() as usize,
36            read_cursor: 0,
37            read_cursor_in_buffer: 0,
38            fin,
39        };
40        trace!(
41            "StreamRecvBuffer({:p}) created offset={} len={} fin={}",
42            buf.buffers
43                .first()
44                .map(|x| x.Buffer)
45                .unwrap_or(std::ptr::null_mut()),
46            buf.offset,
47            buf.len(),
48            buf.fin,
49        );
50        buf
51    }
52
53    pub(crate) fn set_stream(&mut self, stream: Arc<StreamInstance>) {
54        trace!(
55            "StreamRecvBuffer({:p}) set StreamInstance({:p})",
56            self.buffers
57                .first()
58                .map(|x| x.Buffer)
59                .unwrap_or(std::ptr::null_mut()),
60            stream
61        );
62
63        self.stream = Some(stream);
64    }
65
66    /// Returns the length of the buffer.
67    pub fn len(&self) -> usize {
68        if self.buffers.len() <= self.read_cursor {
69            return 0;
70        }
71        self.len
72            - self.buffers[..self.read_cursor]
73                .iter()
74                .map(|x| x.Length)
75                .sum::<u32>() as usize
76            - self.read_cursor_in_buffer
77    }
78
79    /// Returns `true` if the buffer is empty.
80    pub fn is_empty(&self) -> bool {
81        self.len() == 0
82    }
83
84    /// Returns the buffer as a slice.
85    pub fn as_slice_upto_size(&self, size: usize) -> &[u8] {
86        if self.buffers.len() <= self.read_cursor {
87            return &[];
88        }
89        assert!(self.buffers.len() >= self.read_cursor);
90        let buffer = &self.buffers[self.read_cursor];
91        assert!(buffer.Length as usize >= self.read_cursor_in_buffer);
92        let len = std::cmp::min(buffer.Length as usize - self.read_cursor_in_buffer, size);
93        unsafe { slice::from_raw_parts(buffer.Buffer.add(self.read_cursor_in_buffer), len) }
94    }
95
96    /// Consumes and returns the buffer as a slice.
97    pub fn get_bytes_upto_size<'a>(&mut self, size: usize) -> Option<&'a [u8]> {
98        if self.buffers.len() <= self.read_cursor {
99            return None;
100        }
101        assert!(self.buffers.len() >= self.read_cursor);
102        let buffer = &self.buffers[self.read_cursor];
103
104        assert!(buffer.Length as usize >= self.read_cursor_in_buffer);
105        let len = std::cmp::min(buffer.Length as usize - self.read_cursor_in_buffer, size);
106
107        let slice =
108            unsafe { slice::from_raw_parts(buffer.Buffer.add(self.read_cursor_in_buffer), len) };
109        self.read_cursor_in_buffer += len;
110        if self.read_cursor_in_buffer >= buffer.Length as usize {
111            self.read_cursor += 1;
112            self.read_cursor_in_buffer = 0;
113        }
114        Some(slice)
115    }
116
117    /// Return the offset in the stream.
118    pub fn offset(&self) -> usize {
119        self.offset
120    }
121
122    /// Return the range in the stream.
123    pub fn range(&self) -> Range<usize> {
124        self.offset..self.offset + self.len
125    }
126
127    /// Return `true` if the buffer is the end of the stream.
128    pub fn fin(&self) -> bool {
129        self.fin
130    }
131}
132
133unsafe impl Sync for StreamRecvBuffer {}
134unsafe impl Send for StreamRecvBuffer {}
135
136impl Buf for StreamRecvBuffer {
137    fn advance(&mut self, mut count: usize) {
138        assert!(count == 0 || count <= self.remaining());
139        for buffer in &self.buffers[self.read_cursor..] {
140            if count == 0 {
141                break;
142            }
143            let remaining = buffer.Length as usize - self.read_cursor_in_buffer;
144            if count < remaining {
145                self.read_cursor_in_buffer += count;
146                break;
147            } else {
148                self.read_cursor += 1;
149                self.read_cursor_in_buffer = 0;
150                count -= remaining;
151            }
152        }
153    }
154
155    fn chunk(&self) -> &[u8] {
156        self.as_slice_upto_size(self.len())
157    }
158
159    fn remaining(&self) -> usize {
160        self.len()
161    }
162
163    fn chunks_vectored<'a>(&'a self, dst: &mut [IoSlice<'a>]) -> usize {
164        let mut count = 0;
165        let mut read_cursor_in_buffer = Some(self.read_cursor_in_buffer);
166        for buffer in &self.buffers[self.read_cursor..] {
167            if let Some(slice) = dst.get_mut(count) {
168                count += 1;
169                let skip = read_cursor_in_buffer.take().unwrap_or(0);
170                *slice = IoSlice::new(unsafe {
171                    slice::from_raw_parts(buffer.Buffer.add(skip), buffer.Length as usize - skip)
172                });
173            } else {
174                break;
175            }
176        }
177        count
178    }
179}
180
181impl Drop for StreamRecvBuffer {
182    fn drop(&mut self) {
183        trace!(
184            "StreamRecvBuffer({:p}) dropping",
185            self.buffers
186                .first()
187                .map(|x| x.Buffer)
188                .unwrap_or(std::ptr::null_mut())
189        );
190        if let Some(stream) = self.stream.take() {
191            stream.read_complete(self);
192        }
193    }
194}
195
196pub(crate) struct WriteBuffer(Box<WriteBufferInner>);
197
198struct WriteBufferInner {
199    internal: Vec<u8>,
200    zerocopy: Vec<Bytes>,
201    buffers: Vec<msquic::BufferRef>,
202}
203unsafe impl Sync for WriteBufferInner {}
204unsafe impl Send for WriteBufferInner {}
205
206impl WriteBuffer {
207    pub(crate) fn new() -> Self {
208        Self(Box::new(WriteBufferInner {
209            internal: Vec::new(),
210            zerocopy: Vec::new(),
211            buffers: Vec::new(),
212        }))
213    }
214
215    pub(crate) unsafe fn from_raw(inner: *const c_void) -> Self {
216        Self(unsafe { Box::from_raw(inner as *mut WriteBufferInner) })
217    }
218
219    pub(crate) fn put_zerocopy(&mut self, buf: &Bytes) -> usize {
220        self.0.zerocopy.push(buf.clone());
221        buf.len()
222    }
223
224    pub(crate) fn put_slice(&mut self, slice: &[u8]) -> usize {
225        self.0.internal.extend_from_slice(slice);
226        slice.len()
227    }
228
229    pub(crate) fn get_buffers(&mut self) -> (*const msquic::BufferRef, usize) {
230        if !self.0.zerocopy.is_empty() {
231            for buf in &self.0.zerocopy {
232                self.0.buffers.push(buf[..].into());
233            }
234        } else {
235            self.0.buffers.push((&self.0.internal[..]).into());
236        }
237        let ptr = self.0.buffers.as_ptr();
238        let len = self.0.buffers.len();
239        (ptr, len)
240    }
241
242    pub(crate) fn into_raw(self) -> *mut c_void {
243        Box::into_raw(self.0) as *mut c_void
244    }
245
246    pub(crate) fn reset(&mut self) {
247        self.0.internal.clear();
248        self.0.zerocopy.clear();
249        self.0.buffers.clear();
250    }
251}