1use crate::stream::StreamInstance;
2
3#[cfg(feature = "msquic-2-5")]
4use msquic_v2_5 as msquic;
5#[cfg(feature = "msquic-seera")]
6use seera_msquic as msquic;
7
8use std::io::IoSlice;
9use std::ops::Range;
10use std::slice;
11use std::sync::Arc;
12
13use bytes::{Buf, Bytes};
14use libc::c_void;
15use tracing::trace;
16
/// A readable view over a list of msquic `QUIC_BUFFER` descriptors, consumed through the
/// [`bytes::Buf`] implementation or the `*_upto_size` helpers.
///
/// The descriptors are copied, but the bytes they point to are not owned by this type.
pub struct StreamRecvBuffer {
    // Stream notified via `read_complete` when this buffer is dropped; attached by `set_stream`.
    stream: Option<Arc<StreamInstance>>,
    // Descriptors copied from the `BufferRef`s passed to `new`, read in order.
    buffers: Vec<msquic::ffi::QUIC_BUFFER>,
    // Value passed to `new`; start of `range()` (presumably the stream offset of the first
    // byte — TODO confirm with the caller).
    offset: usize,
    // Total length in bytes of all `buffers`, fixed at construction (not reduced by reads).
    len: usize,
    // Index into `buffers` of the buffer currently being read.
    read_cursor: usize,
    // Number of bytes already consumed from `buffers[read_cursor]`.
    read_cursor_in_buffer: usize,
    // Flag passed to `new` and returned by `fin()` (presumably marks end of stream — confirm).
    fin: bool,
}
29
30impl StreamRecvBuffer {
31 pub(crate) fn new<T: AsRef<[msquic::BufferRef]> + ?Sized>(
32 offset: usize,
33 buffers: &T,
34 fin: bool,
35 ) -> Self {
36 let buf = Self {
37 stream: None,
38 buffers: buffers.as_ref().iter().map(|v| v.0).collect(),
39 offset,
40 len: buffers.as_ref().iter().map(|x| x.0.Length).sum::<u32>() as usize,
41 read_cursor: 0,
42 read_cursor_in_buffer: 0,
43 fin,
44 };
45 trace!(
46 "StreamRecvBuffer({:p}) created offset={} len={} fin={}",
47 buf.buffers
48 .first()
49 .map(|x| x.Buffer)
50 .unwrap_or(std::ptr::null_mut()),
51 buf.offset,
52 buf.len(),
53 buf.fin,
54 );
55 buf
56 }
57
58 pub(crate) fn set_stream(&mut self, stream: Arc<StreamInstance>) {
59 trace!(
60 "StreamRecvBuffer({:p}) set StreamInstance({:p})",
61 self.buffers
62 .first()
63 .map(|x| x.Buffer)
64 .unwrap_or(std::ptr::null_mut()),
65 stream
66 );
67
68 self.stream = Some(stream);
69 }
70
71 pub fn len(&self) -> usize {
73 if self.buffers.len() <= self.read_cursor {
74 return 0;
75 }
76 self.len
77 - self.buffers[..self.read_cursor]
78 .iter()
79 .map(|x| x.Length)
80 .sum::<u32>() as usize
81 - self.read_cursor_in_buffer
82 }
83
84 pub fn is_empty(&self) -> bool {
86 self.len() == 0
87 }
88
89 pub fn as_slice_upto_size(&self, size: usize) -> &[u8] {
91 if self.buffers.len() <= self.read_cursor {
92 return &[];
93 }
94 assert!(self.buffers.len() >= self.read_cursor);
95 let buffer = &self.buffers[self.read_cursor];
96 assert!(buffer.Length as usize >= self.read_cursor_in_buffer);
97 let len = std::cmp::min(buffer.Length as usize - self.read_cursor_in_buffer, size);
98 unsafe { slice::from_raw_parts(buffer.Buffer.add(self.read_cursor_in_buffer), len) }
99 }
100
101 pub fn get_bytes_upto_size<'a>(&mut self, size: usize) -> Option<&'a [u8]> {
103 if self.buffers.len() <= self.read_cursor {
104 return None;
105 }
106 assert!(self.buffers.len() >= self.read_cursor);
107 let buffer = &self.buffers[self.read_cursor];
108
109 assert!(buffer.Length as usize >= self.read_cursor_in_buffer);
110 let len = std::cmp::min(buffer.Length as usize - self.read_cursor_in_buffer, size);
111
112 let slice =
113 unsafe { slice::from_raw_parts(buffer.Buffer.add(self.read_cursor_in_buffer), len) };
114 self.read_cursor_in_buffer += len;
115 if self.read_cursor_in_buffer >= buffer.Length as usize {
116 self.read_cursor += 1;
117 self.read_cursor_in_buffer = 0;
118 }
119 Some(slice)
120 }
121
122 pub fn offset(&self) -> usize {
124 self.offset
125 }
126
127 pub fn range(&self) -> Range<usize> {
129 self.offset..self.offset + self.len
130 }
131
132 pub fn fin(&self) -> bool {
134 self.fin
135 }
136}
137
// SAFETY: NOTE(review) — these manual impls are needed because the `QUIC_BUFFER` descriptors
// hold raw pointers, which are neither `Send` nor `Sync`. Soundness assumes the referenced bytes
// stay valid and are not mutated elsewhere until this buffer is dropped (and handed back via
// `read_complete`), and that `StreamInstance` is itself safe to share across threads — confirm
// against the msquic receive contract.
unsafe impl Sync for StreamRecvBuffer {}
unsafe impl Send for StreamRecvBuffer {}
140
141impl Buf for StreamRecvBuffer {
142 fn advance(&mut self, mut count: usize) {
143 assert!(count == 0 || count <= self.remaining());
144 for buffer in &self.buffers[self.read_cursor..] {
145 if count == 0 {
146 break;
147 }
148 let remaining = buffer.Length as usize - self.read_cursor_in_buffer;
149 if count < remaining {
150 self.read_cursor_in_buffer += count;
151 break;
152 } else {
153 self.read_cursor += 1;
154 self.read_cursor_in_buffer = 0;
155 count -= remaining;
156 }
157 }
158 }
159
160 fn chunk(&self) -> &[u8] {
161 self.as_slice_upto_size(self.len())
162 }
163
164 fn remaining(&self) -> usize {
165 self.len()
166 }
167
168 fn chunks_vectored<'a>(&'a self, dst: &mut [IoSlice<'a>]) -> usize {
169 let mut count = 0;
170 let mut read_cursor_in_buffer = Some(self.read_cursor_in_buffer);
171 for buffer in &self.buffers[self.read_cursor..] {
172 if let Some(slice) = dst.get_mut(count) {
173 count += 1;
174 let skip = read_cursor_in_buffer.take().unwrap_or(0);
175 *slice = IoSlice::new(unsafe {
176 slice::from_raw_parts(buffer.Buffer.add(skip), buffer.Length as usize - skip)
177 });
178 } else {
179 break;
180 }
181 }
182 count
183 }
184}
185
186impl Drop for StreamRecvBuffer {
187 fn drop(&mut self) {
188 trace!(
189 "StreamRecvBuffer({:p}) dropping",
190 self.buffers
191 .first()
192 .map(|x| x.Buffer)
193 .unwrap_or(std::ptr::null_mut())
194 );
195 if let Some(stream) = self.stream.take() {
196 stream.read_complete(self);
197 }
198 }
199}
200
/// Staging area for outgoing stream data.
///
/// The state is boxed so that it can be handed out as an opaque `*mut c_void` with
/// `into_raw` and reclaimed with `from_raw`.
pub(crate) struct WriteBuffer(Box<WriteBufferInner>);

/// Heap-allocated state behind a [`WriteBuffer`].
struct WriteBufferInner {
    // Bytes copied in by `put_slice`.
    internal: Vec<u8>,
    // `Bytes` handles queued by `put_zerocopy` (cloned handles, not copied data).
    zerocopy: Vec<Bytes>,
    // Buffer descriptors built by `get_buffers`, pointing into `internal` or `zerocopy`.
    buffers: Vec<msquic::BufferRef>,
}
// SAFETY: NOTE(review) — required because `BufferRef` holds raw pointers. Those pointers only
// reference `internal` / `zerocopy`, which this struct owns; confirm that no other alias to that
// memory is mutated concurrently.
unsafe impl Sync for WriteBufferInner {}
unsafe impl Send for WriteBufferInner {}
210
211impl WriteBuffer {
212 pub(crate) fn new() -> Self {
213 Self(Box::new(WriteBufferInner {
214 internal: Vec::new(),
215 zerocopy: Vec::new(),
216 buffers: Vec::new(),
217 }))
218 }
219
220 pub(crate) unsafe fn from_raw(inner: *const c_void) -> Self {
221 Self(unsafe { Box::from_raw(inner as *mut WriteBufferInner) })
222 }
223
224 pub(crate) fn put_zerocopy(&mut self, buf: &Bytes) -> usize {
225 self.0.zerocopy.push(buf.clone());
226 buf.len()
227 }
228
229 pub(crate) fn put_slice(&mut self, slice: &[u8]) -> usize {
230 self.0.internal.extend_from_slice(slice);
231 slice.len()
232 }
233
234 pub(crate) fn get_buffers(&mut self) -> (*const msquic::BufferRef, usize) {
235 if !self.0.zerocopy.is_empty() {
236 for buf in &self.0.zerocopy {
237 self.0.buffers.push(buf[..].into());
238 }
239 } else {
240 self.0.buffers.push((&self.0.internal[..]).into());
241 }
242 let ptr = self.0.buffers.as_ptr();
243 let len = self.0.buffers.len();
244 (ptr, len)
245 }
246
247 pub(crate) fn into_raw(self) -> *mut c_void {
248 Box::into_raw(self.0) as *mut c_void
249 }
250
251 pub(crate) fn reset(&mut self) {
252 self.0.internal.clear();
253 self.0.zerocopy.clear();
254 self.0.buffers.clear();
255 }
256}