1use crate::stream::StreamInstance;
2
3use std::io::IoSlice;
4use std::ops::Range;
5use std::slice;
6use std::sync::Arc;
7
8use bytes::{Buf, Bytes};
9use libc::c_void;
10use tracing::trace;
11
12pub struct StreamRecvBuffer {
16 stream: Option<Arc<StreamInstance>>,
17 buffers: Vec<msquic::ffi::QUIC_BUFFER>,
18 offset: usize,
19 len: usize,
20 read_cursor: usize,
21 read_cursor_in_buffer: usize,
22 fin: bool,
23}
24
25impl StreamRecvBuffer {
26 pub(crate) fn new<T: AsRef<[msquic::BufferRef]> + ?Sized>(
27 offset: usize,
28 buffers: &T,
29 fin: bool,
30 ) -> Self {
31 let buf = Self {
32 stream: None,
33 buffers: buffers.as_ref().iter().map(|v| v.0).collect(),
34 offset,
35 len: buffers.as_ref().iter().map(|x| x.0.Length).sum::<u32>() as usize,
36 read_cursor: 0,
37 read_cursor_in_buffer: 0,
38 fin,
39 };
40 trace!(
41 "StreamRecvBuffer({:p}) created offset={} len={} fin={}",
42 buf.buffers
43 .first()
44 .map(|x| x.Buffer)
45 .unwrap_or(std::ptr::null_mut()),
46 buf.offset,
47 buf.len(),
48 buf.fin,
49 );
50 buf
51 }
52
53 pub(crate) fn set_stream(&mut self, stream: Arc<StreamInstance>) {
54 trace!(
55 "StreamRecvBuffer({:p}) set StreamInstance({:p})",
56 self.buffers
57 .first()
58 .map(|x| x.Buffer)
59 .unwrap_or(std::ptr::null_mut()),
60 stream
61 );
62
63 self.stream = Some(stream);
64 }
65
66 pub fn len(&self) -> usize {
68 if self.buffers.len() <= self.read_cursor {
69 return 0;
70 }
71 self.len
72 - self.buffers[..self.read_cursor]
73 .iter()
74 .map(|x| x.Length)
75 .sum::<u32>() as usize
76 - self.read_cursor_in_buffer
77 }
78
79 pub fn is_empty(&self) -> bool {
81 self.len() == 0
82 }
83
84 pub fn as_slice_upto_size(&self, size: usize) -> &[u8] {
86 if self.buffers.len() <= self.read_cursor {
87 return &[];
88 }
89 assert!(self.buffers.len() >= self.read_cursor);
90 let buffer = &self.buffers[self.read_cursor];
91 assert!(buffer.Length as usize >= self.read_cursor_in_buffer);
92 let len = std::cmp::min(buffer.Length as usize - self.read_cursor_in_buffer, size);
93 unsafe { slice::from_raw_parts(buffer.Buffer.add(self.read_cursor_in_buffer), len) }
94 }
95
96 pub fn get_bytes_upto_size<'a>(&mut self, size: usize) -> Option<&'a [u8]> {
98 if self.buffers.len() <= self.read_cursor {
99 return None;
100 }
101 assert!(self.buffers.len() >= self.read_cursor);
102 let buffer = &self.buffers[self.read_cursor];
103
104 assert!(buffer.Length as usize >= self.read_cursor_in_buffer);
105 let len = std::cmp::min(buffer.Length as usize - self.read_cursor_in_buffer, size);
106
107 let slice =
108 unsafe { slice::from_raw_parts(buffer.Buffer.add(self.read_cursor_in_buffer), len) };
109 self.read_cursor_in_buffer += len;
110 if self.read_cursor_in_buffer >= buffer.Length as usize {
111 self.read_cursor += 1;
112 self.read_cursor_in_buffer = 0;
113 }
114 Some(slice)
115 }
116
117 pub fn offset(&self) -> usize {
119 self.offset
120 }
121
122 pub fn range(&self) -> Range<usize> {
124 self.offset..self.offset + self.len
125 }
126
127 pub fn fin(&self) -> bool {
129 self.fin
130 }
131}
132
133unsafe impl Sync for StreamRecvBuffer {}
134unsafe impl Send for StreamRecvBuffer {}
135
136impl Buf for StreamRecvBuffer {
137 fn advance(&mut self, mut count: usize) {
138 assert!(count == 0 || count <= self.remaining());
139 for buffer in &self.buffers[self.read_cursor..] {
140 if count == 0 {
141 break;
142 }
143 let remaining = buffer.Length as usize - self.read_cursor_in_buffer;
144 if count < remaining {
145 self.read_cursor_in_buffer += count;
146 break;
147 } else {
148 self.read_cursor += 1;
149 self.read_cursor_in_buffer = 0;
150 count -= remaining;
151 }
152 }
153 }
154
155 fn chunk(&self) -> &[u8] {
156 self.as_slice_upto_size(self.len())
157 }
158
159 fn remaining(&self) -> usize {
160 self.len()
161 }
162
163 fn chunks_vectored<'a>(&'a self, dst: &mut [IoSlice<'a>]) -> usize {
164 let mut count = 0;
165 let mut read_cursor_in_buffer = Some(self.read_cursor_in_buffer);
166 for buffer in &self.buffers[self.read_cursor..] {
167 if let Some(slice) = dst.get_mut(count) {
168 count += 1;
169 let skip = read_cursor_in_buffer.take().unwrap_or(0);
170 *slice = IoSlice::new(unsafe {
171 slice::from_raw_parts(buffer.Buffer.add(skip), buffer.Length as usize - skip)
172 });
173 } else {
174 break;
175 }
176 }
177 count
178 }
179}
180
181impl Drop for StreamRecvBuffer {
182 fn drop(&mut self) {
183 trace!(
184 "StreamRecvBuffer({:p}) dropping",
185 self.buffers
186 .first()
187 .map(|x| x.Buffer)
188 .unwrap_or(std::ptr::null_mut())
189 );
190 if let Some(stream) = self.stream.take() {
191 stream.read_complete(self);
192 }
193 }
194}
195
196pub(crate) struct WriteBuffer(Box<WriteBufferInner>);
197
198struct WriteBufferInner {
199 internal: Vec<u8>,
200 zerocopy: Vec<Bytes>,
201 buffers: Vec<msquic::BufferRef>,
202}
203unsafe impl Sync for WriteBufferInner {}
204unsafe impl Send for WriteBufferInner {}
205
206impl WriteBuffer {
207 pub(crate) fn new() -> Self {
208 Self(Box::new(WriteBufferInner {
209 internal: Vec::new(),
210 zerocopy: Vec::new(),
211 buffers: Vec::new(),
212 }))
213 }
214
215 pub(crate) unsafe fn from_raw(inner: *const c_void) -> Self {
216 Self(unsafe { Box::from_raw(inner as *mut WriteBufferInner) })
217 }
218
219 pub(crate) fn put_zerocopy(&mut self, buf: &Bytes) -> usize {
220 self.0.zerocopy.push(buf.clone());
221 buf.len()
222 }
223
224 pub(crate) fn put_slice(&mut self, slice: &[u8]) -> usize {
225 self.0.internal.extend_from_slice(slice);
226 slice.len()
227 }
228
229 pub(crate) fn get_buffers(&mut self) -> (*const msquic::BufferRef, usize) {
230 if !self.0.zerocopy.is_empty() {
231 for buf in &self.0.zerocopy {
232 self.0.buffers.push(buf[..].into());
233 }
234 } else {
235 self.0.buffers.push((&self.0.internal[..]).into());
236 }
237 let ptr = self.0.buffers.as_ptr();
238 let len = self.0.buffers.len();
239 (ptr, len)
240 }
241
242 pub(crate) fn into_raw(self) -> *mut c_void {
243 Box::into_raw(self.0) as *mut c_void
244 }
245
246 pub(crate) fn reset(&mut self) {
247 self.0.internal.clear();
248 self.0.zerocopy.clear();
249 self.0.buffers.clear();
250 }
251}