// msquic_async/buffer.rs

1use crate::stream::StreamInstance;
2
3#[cfg(feature = "msquic-2-5")]
4use msquic_v2_5 as msquic;
5#[cfg(feature = "msquic-seera")]
6use seera_msquic as msquic;
7
8use std::io::IoSlice;
9use std::ops::Range;
10use std::slice;
11use std::sync::Arc;
12
13use bytes::{Buf, Bytes};
14use libc::c_void;
15use tracing::trace;
16
17/// A buffer for receiving data from a stream.
18///
19/// It implements [`bytes::Buf`] and is backed by a list of [`msquic::ffi::QUIC_BUFFER`].
20pub struct StreamRecvBuffer {
21    stream: Option<Arc<StreamInstance>>,
22    buffers: Vec<msquic::ffi::QUIC_BUFFER>,
23    offset: usize,
24    len: usize,
25    read_cursor: usize,
26    read_cursor_in_buffer: usize,
27    fin: bool,
28}
29
30impl StreamRecvBuffer {
31    pub(crate) fn new<T: AsRef<[msquic::BufferRef]> + ?Sized>(
32        offset: usize,
33        buffers: &T,
34        fin: bool,
35    ) -> Self {
36        let buf = Self {
37            stream: None,
38            buffers: buffers.as_ref().iter().map(|v| v.0).collect(),
39            offset,
40            len: buffers.as_ref().iter().map(|x| x.0.Length).sum::<u32>() as usize,
41            read_cursor: 0,
42            read_cursor_in_buffer: 0,
43            fin,
44        };
45        trace!(
46            "StreamRecvBuffer({:p}) created offset={} len={} fin={}",
47            buf.buffers
48                .first()
49                .map(|x| x.Buffer)
50                .unwrap_or(std::ptr::null_mut()),
51            buf.offset,
52            buf.len(),
53            buf.fin,
54        );
55        buf
56    }
57
58    pub(crate) fn set_stream(&mut self, stream: Arc<StreamInstance>) {
59        trace!(
60            "StreamRecvBuffer({:p}) set StreamInstance({:p})",
61            self.buffers
62                .first()
63                .map(|x| x.Buffer)
64                .unwrap_or(std::ptr::null_mut()),
65            stream
66        );
67
68        self.stream = Some(stream);
69    }
70
71    /// Returns the length of the buffer.
72    pub fn len(&self) -> usize {
73        if self.buffers.len() <= self.read_cursor {
74            return 0;
75        }
76        self.len
77            - self.buffers[..self.read_cursor]
78                .iter()
79                .map(|x| x.Length)
80                .sum::<u32>() as usize
81            - self.read_cursor_in_buffer
82    }
83
84    /// Returns `true` if the buffer is empty.
85    pub fn is_empty(&self) -> bool {
86        self.len() == 0
87    }
88
89    /// Returns the buffer as a slice.
90    pub fn as_slice_upto_size(&self, size: usize) -> &[u8] {
91        if self.buffers.len() <= self.read_cursor {
92            return &[];
93        }
94        assert!(self.buffers.len() >= self.read_cursor);
95        let buffer = &self.buffers[self.read_cursor];
96        assert!(buffer.Length as usize >= self.read_cursor_in_buffer);
97        let len = std::cmp::min(buffer.Length as usize - self.read_cursor_in_buffer, size);
98        unsafe { slice::from_raw_parts(buffer.Buffer.add(self.read_cursor_in_buffer), len) }
99    }
100
101    /// Consumes and returns the buffer as a slice.
102    pub fn get_bytes_upto_size<'a>(&mut self, size: usize) -> Option<&'a [u8]> {
103        if self.buffers.len() <= self.read_cursor {
104            return None;
105        }
106        assert!(self.buffers.len() >= self.read_cursor);
107        let buffer = &self.buffers[self.read_cursor];
108
109        assert!(buffer.Length as usize >= self.read_cursor_in_buffer);
110        let len = std::cmp::min(buffer.Length as usize - self.read_cursor_in_buffer, size);
111
112        let slice =
113            unsafe { slice::from_raw_parts(buffer.Buffer.add(self.read_cursor_in_buffer), len) };
114        self.read_cursor_in_buffer += len;
115        if self.read_cursor_in_buffer >= buffer.Length as usize {
116            self.read_cursor += 1;
117            self.read_cursor_in_buffer = 0;
118        }
119        Some(slice)
120    }
121
122    /// Return the offset in the stream.
123    pub fn offset(&self) -> usize {
124        self.offset
125    }
126
127    /// Return the range in the stream.
128    pub fn range(&self) -> Range<usize> {
129        self.offset..self.offset + self.len
130    }
131
132    /// Return `true` if the buffer is the end of the stream.
133    pub fn fin(&self) -> bool {
134        self.fin
135    }
136}
137
138unsafe impl Sync for StreamRecvBuffer {}
139unsafe impl Send for StreamRecvBuffer {}
140
141impl Buf for StreamRecvBuffer {
142    fn advance(&mut self, mut count: usize) {
143        assert!(count == 0 || count <= self.remaining());
144        for buffer in &self.buffers[self.read_cursor..] {
145            if count == 0 {
146                break;
147            }
148            let remaining = buffer.Length as usize - self.read_cursor_in_buffer;
149            if count < remaining {
150                self.read_cursor_in_buffer += count;
151                break;
152            } else {
153                self.read_cursor += 1;
154                self.read_cursor_in_buffer = 0;
155                count -= remaining;
156            }
157        }
158    }
159
160    fn chunk(&self) -> &[u8] {
161        self.as_slice_upto_size(self.len())
162    }
163
164    fn remaining(&self) -> usize {
165        self.len()
166    }
167
168    fn chunks_vectored<'a>(&'a self, dst: &mut [IoSlice<'a>]) -> usize {
169        let mut count = 0;
170        let mut read_cursor_in_buffer = Some(self.read_cursor_in_buffer);
171        for buffer in &self.buffers[self.read_cursor..] {
172            if let Some(slice) = dst.get_mut(count) {
173                count += 1;
174                let skip = read_cursor_in_buffer.take().unwrap_or(0);
175                *slice = IoSlice::new(unsafe {
176                    slice::from_raw_parts(buffer.Buffer.add(skip), buffer.Length as usize - skip)
177                });
178            } else {
179                break;
180            }
181        }
182        count
183    }
184}
185
186impl Drop for StreamRecvBuffer {
187    fn drop(&mut self) {
188        trace!(
189            "StreamRecvBuffer({:p}) dropping",
190            self.buffers
191                .first()
192                .map(|x| x.Buffer)
193                .unwrap_or(std::ptr::null_mut())
194        );
195        if let Some(stream) = self.stream.take() {
196            stream.read_complete(self);
197        }
198    }
199}
200
201pub(crate) struct WriteBuffer(Box<WriteBufferInner>);
202
203struct WriteBufferInner {
204    internal: Vec<u8>,
205    zerocopy: Vec<Bytes>,
206    buffers: Vec<msquic::BufferRef>,
207}
208unsafe impl Sync for WriteBufferInner {}
209unsafe impl Send for WriteBufferInner {}
210
211impl WriteBuffer {
212    pub(crate) fn new() -> Self {
213        Self(Box::new(WriteBufferInner {
214            internal: Vec::new(),
215            zerocopy: Vec::new(),
216            buffers: Vec::new(),
217        }))
218    }
219
220    pub(crate) unsafe fn from_raw(inner: *const c_void) -> Self {
221        Self(unsafe { Box::from_raw(inner as *mut WriteBufferInner) })
222    }
223
224    pub(crate) fn put_zerocopy(&mut self, buf: &Bytes) -> usize {
225        self.0.zerocopy.push(buf.clone());
226        buf.len()
227    }
228
229    pub(crate) fn put_slice(&mut self, slice: &[u8]) -> usize {
230        self.0.internal.extend_from_slice(slice);
231        slice.len()
232    }
233
234    pub(crate) fn get_buffers(&mut self) -> (*const msquic::BufferRef, usize) {
235        if !self.0.zerocopy.is_empty() {
236            for buf in &self.0.zerocopy {
237                self.0.buffers.push(buf[..].into());
238            }
239        } else {
240            self.0.buffers.push((&self.0.internal[..]).into());
241        }
242        let ptr = self.0.buffers.as_ptr();
243        let len = self.0.buffers.len();
244        (ptr, len)
245    }
246
247    pub(crate) fn into_raw(self) -> *mut c_void {
248        Box::into_raw(self.0) as *mut c_void
249    }
250
251    pub(crate) fn reset(&mut self) {
252        self.0.internal.clear();
253        self.0.zerocopy.clear();
254        self.0.buffers.clear();
255    }
256}