1use crate::stream::StreamInner;
2
3use std::io::IoSlice;
4use std::ops::Range;
5use std::slice;
6use std::sync::Arc;
7
8use bytes::{Buf, Bytes};
9use libc::c_void;
10use tracing::trace;
11
12pub struct StreamRecvBuffer {
16 stream: Option<Arc<StreamInner>>,
17 buffers: Vec<msquic::Buffer>,
18 offset: usize,
19 len: usize,
20 read_cursor: usize,
21 read_cursor_in_buffer: usize,
22 fin: bool,
23}
24
25impl StreamRecvBuffer {
26 pub(crate) fn new<T: AsRef<[msquic::Buffer]>>(offset: usize, buffers: &T, fin: bool) -> Self {
27 let buf = Self {
28 stream: None,
29 buffers: buffers.as_ref().to_vec(),
30 offset,
31 len: buffers.as_ref().iter().map(|x| x.length).sum::<u32>() as usize,
32 read_cursor: 0,
33 read_cursor_in_buffer: 0,
34 fin,
35 };
36 trace!(
37 "StreamRecvBuffer({:p}) created offset={} len={} fin={}",
38 buf.buffers
39 .first()
40 .map(|x| x.buffer)
41 .unwrap_or(std::ptr::null_mut()),
42 buf.offset,
43 buf.len(),
44 buf.fin,
45 );
46 buf
47 }
48
49 pub(crate) fn set_stream(&mut self, stream: Arc<StreamInner>) {
50 self.stream = Some(stream);
51 }
52
53 pub fn len(&self) -> usize {
55 if self.buffers.len() <= self.read_cursor {
56 return 0;
57 }
58 self.len
59 - self.buffers[..self.read_cursor]
60 .iter()
61 .map(|x| x.length)
62 .sum::<u32>() as usize
63 - self.read_cursor_in_buffer
64 }
65
66 pub fn is_empty(&self) -> bool {
68 self.len() == 0
69 }
70
71 pub fn as_slice_upto_size(&self, size: usize) -> &[u8] {
73 if self.buffers.len() <= self.read_cursor {
74 return &[];
75 }
76 assert!(self.buffers.len() >= self.read_cursor);
77 let buffer = &self.buffers[self.read_cursor];
78 assert!(buffer.length as usize >= self.read_cursor_in_buffer);
79 let len = std::cmp::min(buffer.length as usize - self.read_cursor_in_buffer, size);
80 unsafe { slice::from_raw_parts(buffer.buffer.add(self.read_cursor_in_buffer), len) }
81 }
82
83 pub fn get_bytes_upto_size<'a>(&mut self, size: usize) -> Option<&'a [u8]> {
85 if self.buffers.len() <= self.read_cursor {
86 return None;
87 }
88 assert!(self.buffers.len() >= self.read_cursor);
89 let buffer = &self.buffers[self.read_cursor];
90
91 assert!(buffer.length as usize >= self.read_cursor_in_buffer);
92 let len = std::cmp::min(buffer.length as usize - self.read_cursor_in_buffer, size);
93
94 let slice =
95 unsafe { slice::from_raw_parts(buffer.buffer.add(self.read_cursor_in_buffer), len) };
96 self.read_cursor_in_buffer += len;
97 if self.read_cursor_in_buffer >= buffer.length as usize {
98 self.read_cursor += 1;
99 self.read_cursor_in_buffer = 0;
100 }
101 Some(slice)
102 }
103
104 pub fn offset(&self) -> usize {
106 self.offset
107 }
108
109 pub fn range(&self) -> Range<usize> {
111 self.offset..self.offset + self.len
112 }
113
114 pub fn fin(&self) -> bool {
116 self.fin
117 }
118}
119
120unsafe impl Sync for StreamRecvBuffer {}
121unsafe impl Send for StreamRecvBuffer {}
122
123impl Buf for StreamRecvBuffer {
124 fn advance(&mut self, mut count: usize) {
125 assert!(count == 0 || count <= self.remaining());
126 for buffer in &self.buffers[self.read_cursor..] {
127 if count == 0 {
128 break;
129 }
130 let remaining = buffer.length as usize - self.read_cursor_in_buffer;
131 if count < remaining {
132 self.read_cursor_in_buffer += count;
133 break;
134 } else {
135 self.read_cursor += 1;
136 self.read_cursor_in_buffer = 0;
137 count -= remaining;
138 }
139 }
140 }
141
142 fn chunk(&self) -> &[u8] {
143 self.as_slice_upto_size(self.len())
144 }
145
146 fn remaining(&self) -> usize {
147 self.len()
148 }
149
150 fn chunks_vectored<'a>(&'a self, dst: &mut [IoSlice<'a>]) -> usize {
151 let mut count = 0;
152 let mut read_cursor_in_buffer = Some(self.read_cursor_in_buffer);
153 for buffer in &self.buffers[self.read_cursor..] {
154 if let Some(slice) = dst.get_mut(count) {
155 count += 1;
156 let skip = read_cursor_in_buffer.take().unwrap_or(0);
157 *slice = IoSlice::new(unsafe {
158 slice::from_raw_parts(buffer.buffer.add(skip), buffer.length as usize - skip)
159 });
160 } else {
161 break;
162 }
163 }
164 count
165 }
166}
167
168impl Drop for StreamRecvBuffer {
169 fn drop(&mut self) {
170 trace!(
171 "StreamRecvBuffer({:p}) dropping",
172 self.buffers
173 .first()
174 .map(|x| x.buffer)
175 .unwrap_or(std::ptr::null_mut())
176 );
177 if let Some(stream) = self.stream.take() {
178 stream.read_complete(self);
179 }
180 }
181}
182
183pub(crate) struct WriteBuffer(Box<WriteBufferInner>);
184
185struct WriteBufferInner {
186 internal: Vec<u8>,
187 zerocopy: Vec<Bytes>,
188 msquic_buffer: Vec<msquic::Buffer>,
189}
190unsafe impl Sync for WriteBufferInner {}
191unsafe impl Send for WriteBufferInner {}
192
193impl WriteBuffer {
194 pub(crate) fn new() -> Self {
195 Self(Box::new(WriteBufferInner {
196 internal: Vec::new(),
197 zerocopy: Vec::new(),
198 msquic_buffer: Vec::new(),
199 }))
200 }
201
202 pub(crate) unsafe fn from_raw(inner: *const c_void) -> Self {
203 Self(unsafe { Box::from_raw(inner as *mut WriteBufferInner) })
204 }
205
206 pub(crate) fn put_zerocopy(&mut self, buf: &Bytes) -> usize {
207 self.0.zerocopy.push(buf.clone());
208 buf.len()
209 }
210
211 pub(crate) fn put_slice(&mut self, slice: &[u8]) -> usize {
212 self.0.internal.extend_from_slice(slice);
213 slice.len()
214 }
215
216 pub(crate) fn get_buffer(&mut self) -> (*const msquic::Buffer, u32) {
217 if !self.0.zerocopy.is_empty() {
218 for buf in &self.0.zerocopy {
219 self.0.msquic_buffer.push(buf[..].into());
220 }
221 } else {
222 self.0.msquic_buffer.push((&self.0.internal).into());
223 }
224 let ptr = self.0.msquic_buffer.as_ptr();
225 let len = self.0.msquic_buffer.len() as u32;
226 (ptr, len)
227 }
228
229 pub(crate) fn into_raw(self) -> *mut c_void {
230 Box::into_raw(self.0) as *mut c_void
231 }
232
233 pub(crate) fn reset(&mut self) {
234 self.0.internal.clear();
235 self.0.zerocopy.clear();
236 self.0.msquic_buffer.clear();
237 }
238}