ant_quic/connection/streams/
send.rs1use bytes::Bytes;
9use thiserror::Error;
10
11use crate::{VarInt, connection::send_buffer::SendBuffer, frame};
12
13#[derive(Debug)]
14pub(super) struct Send {
15 pub(super) max_data: u64,
16 pub(super) state: SendState,
17 pub(super) pending: SendBuffer,
18 pub(super) priority: i32,
19 pub(super) fin_pending: bool,
21 pub(super) connection_blocked: bool,
23 pub(super) stop_reason: Option<VarInt>,
25}
26
27impl Send {
28 pub(super) fn new(max_data: VarInt) -> Box<Self> {
29 Box::new(Self {
30 max_data: max_data.into(),
31 state: SendState::Ready,
32 pending: SendBuffer::new(),
33 priority: 0,
34 fin_pending: false,
35 connection_blocked: false,
36 stop_reason: None,
37 })
38 }
39
40 pub(super) fn is_reset(&self) -> bool {
42 matches!(self.state, SendState::ResetSent)
43 }
44
45 pub(super) fn finish(&mut self) -> Result<(), FinishError> {
46 if let Some(error_code) = self.stop_reason {
47 Err(FinishError::Stopped(error_code))
48 } else if self.state == SendState::Ready {
49 self.state = SendState::DataSent {
50 finish_acked: false,
51 };
52 self.fin_pending = true;
53 Ok(())
54 } else {
55 Err(FinishError::ClosedStream)
56 }
57 }
58
59 pub(super) fn write<S: BytesSource>(
60 &mut self,
61 source: &mut S,
62 limit: u64,
63 ) -> Result<Written, WriteError> {
64 if !self.is_writable() {
65 return Err(WriteError::ClosedStream);
66 }
67 if let Some(error_code) = self.stop_reason {
68 return Err(WriteError::Stopped(error_code));
69 }
70 let budget = self.max_data - self.pending.offset();
71 if budget == 0 {
72 return Err(WriteError::Blocked);
73 }
74 let mut limit = limit.min(budget) as usize;
75
76 let mut result = Written::default();
77 loop {
78 let (chunk, chunks_consumed) = source.pop_chunk(limit);
79 result.chunks += chunks_consumed;
80 result.bytes += chunk.len();
81
82 if chunk.is_empty() {
83 break;
84 }
85
86 limit -= chunk.len();
87 self.pending.write(chunk);
88 }
89
90 Ok(result)
91 }
92
93 pub(super) fn reset(&mut self) {
95 use SendState::*;
96 if let DataSent { .. } | Ready = self.state {
97 self.state = ResetSent;
98 }
99 }
100
101 pub(super) fn try_stop(&mut self, error_code: VarInt) -> bool {
106 if self.stop_reason.is_none() {
107 self.stop_reason = Some(error_code);
108 true
109 } else {
110 false
111 }
112 }
113
114 pub(super) fn ack(&mut self, frame: frame::StreamMeta) -> bool {
116 self.pending.ack(frame.offsets);
117 match self.state {
118 SendState::DataSent {
119 ref mut finish_acked,
120 } => {
121 *finish_acked |= frame.fin;
122 *finish_acked && self.pending.is_fully_acked()
123 }
124 _ => false,
125 }
126 }
127
128 pub(super) fn increase_max_data(&mut self, offset: u64) -> bool {
132 if offset <= self.max_data || self.state != SendState::Ready {
133 return false;
134 }
135 let was_blocked = self.pending.offset() == self.max_data;
136 self.max_data = offset;
137 was_blocked
138 }
139
140 pub(super) fn offset(&self) -> u64 {
141 self.pending.offset()
142 }
143
144 pub(super) fn is_pending(&self) -> bool {
145 self.pending.has_unsent_data() || self.fin_pending
146 }
147
148 pub(super) fn is_writable(&self) -> bool {
149 matches!(self.state, SendState::Ready)
150 }
151}
152
153pub(crate) struct BytesArray<'a> {
158 chunks: &'a mut [Bytes],
160 consumed: usize,
162}
163
164impl<'a> BytesArray<'a> {
165 pub(crate) fn from_chunks(chunks: &'a mut [Bytes]) -> Self {
166 Self {
167 chunks,
168 consumed: 0,
169 }
170 }
171}
172
173impl BytesSource for BytesArray<'_> {
174 fn pop_chunk(&mut self, limit: usize) -> (Bytes, usize) {
175 let mut chunks_consumed = 0;
178
179 while self.consumed < self.chunks.len() {
180 let chunk = &mut self.chunks[self.consumed];
181
182 if chunk.len() <= limit {
183 let chunk = std::mem::take(chunk);
184 self.consumed += 1;
185 chunks_consumed += 1;
186 if chunk.is_empty() {
187 continue;
188 }
189 return (chunk, chunks_consumed);
190 } else if limit > 0 {
191 let chunk = chunk.split_to(limit);
192 return (chunk, chunks_consumed);
193 } else {
194 break;
195 }
196 }
197
198 (Bytes::new(), chunks_consumed)
199 }
200}
201
/// A `BytesSource` that copies data out of a borrowed byte slice.
///
/// The whole slice counts as a single chunk, reported as consumed once the
/// last byte has been copied out.
pub(crate) struct ByteSlice<'a> {
    /// Bytes not yet handed out.
    data: &'a [u8],
}
211
212impl<'a> ByteSlice<'a> {
213 pub(crate) fn from_slice(data: &'a [u8]) -> Self {
214 Self { data }
215 }
216}
217
218impl BytesSource for ByteSlice<'_> {
219 fn pop_chunk(&mut self, limit: usize) -> (Bytes, usize) {
220 let limit = limit.min(self.data.len());
221 if limit == 0 {
222 return (Bytes::new(), 0);
223 }
224
225 let chunk = Bytes::from(self.data[..limit].to_owned());
226 self.data = &self.data[chunk.len()..];
227
228 let chunks_consumed = usize::from(self.data.is_empty());
229 (chunk, chunks_consumed)
230 }
231}
232
233pub(super) trait BytesSource {
238 fn pop_chunk(&mut self, limit: usize) -> (Bytes, usize);
247}
248
/// Amount of data accepted by a successful stream write.
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub struct Written {
    /// Number of bytes accepted into the send buffer.
    pub bytes: usize,
    /// Number of source chunks that were fully consumed.
    pub chunks: usize,
}
259
260#[derive(Debug, Error, Clone, Eq, PartialEq, Ord, PartialOrd, Hash)]
262pub enum WriteError {
263 #[error("unable to accept further writes")]
270 Blocked,
271 #[error("stopped by peer: code {0}")]
278 Stopped(VarInt),
279 #[error("closed stream")]
281 ClosedStream,
282 #[error("connection closed")]
284 ConnectionClosed,
285}
286
287#[derive(Debug, Copy, Clone, Eq, PartialEq)]
289pub(super) enum SendState {
290 Ready,
292 DataSent { finish_acked: bool },
294 ResetSent,
296}
297
298#[derive(Debug, Error, Clone, PartialEq, Eq)]
300pub enum FinishError {
301 #[error("stopped by peer: code {0}")]
308 Stopped(VarInt),
309 #[error("closed stream")]
311 ClosedStream,
312 #[error("connection closed")]
314 ConnectionClosed,
315}
316
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn bytes_array() {
        let full = b"Hello World 123456789 ABCDEFGHJIJKLMNOPQRSTUVWXYZ".to_owned();
        // Inclusive range so `limit == full.len()` is exercised; with an
        // exclusive range the full-consumption branch below never ran.
        for limit in 0..=full.len() {
            let mut chunks = [
                Bytes::from_static(b""),
                Bytes::from_static(b"Hello "),
                Bytes::from_static(b"Wo"),
                Bytes::from_static(b""),
                Bytes::from_static(b"r"),
                Bytes::from_static(b"ld"),
                Bytes::from_static(b""),
                Bytes::from_static(b" 12345678"),
                Bytes::from_static(b"9 ABCDE"),
                Bytes::from_static(b"F"),
                Bytes::from_static(b"GHJIJKLMNOPQRSTUVWXYZ"),
            ];
            let num_chunks = chunks.len();
            let last_chunk_len = chunks[chunks.len() - 1].len();

            let mut array = BytesArray::from_chunks(&mut chunks);

            let mut buf = Vec::new();
            let mut chunks_popped = 0;
            let mut chunks_consumed = 0;
            let mut remaining = limit;
            loop {
                let (chunk, consumed) = array.pop_chunk(remaining);
                chunks_consumed += consumed;

                if !chunk.is_empty() {
                    buf.extend_from_slice(&chunk);
                    remaining -= chunk.len();
                    chunks_popped += 1;
                } else {
                    break;
                }
            }

            assert_eq!(&buf[..], &full[..limit]);

            if limit == full.len() {
                // Full consumption of the last chunk.
                assert_eq!(chunks_consumed, num_chunks);
                // The three empty chunks are consumed without ever being popped.
                assert_eq!(chunks_consumed, chunks_popped + 3);
            } else if limit > full.len() - last_chunk_len {
                // Partial consumption of the last chunk.
                assert_eq!(chunks_consumed, num_chunks - 1);
                // Three empty chunks consumed unpopped, minus the last chunk,
                // which was popped (partially) but not consumed.
                assert_eq!(chunks_consumed, chunks_popped + 2);
            }
        }
    }

    #[test]
    fn byte_slice() {
        let full = b"Hello World 123456789 ABCDEFGHJIJKLMNOPQRSTUVWXYZ".to_owned();
        // Inclusive range so the `limit == full.len()` case is actually checked.
        for limit in 0..=full.len() {
            let mut array = ByteSlice::from_slice(&full[..]);

            let mut buf = Vec::new();
            let mut chunks_popped = 0;
            let mut chunks_consumed = 0;
            let mut remaining = limit;
            loop {
                let (chunk, consumed) = array.pop_chunk(remaining);
                chunks_consumed += consumed;

                if !chunk.is_empty() {
                    buf.extend_from_slice(&chunk);
                    remaining -= chunk.len();
                    chunks_popped += 1;
                } else {
                    break;
                }
            }

            assert_eq!(&buf[..], &full[..limit]);
            if limit != 0 {
                assert_eq!(chunks_popped, 1);
            } else {
                assert_eq!(chunks_popped, 0);
            }

            // The slice is one chunk, consumed only when fully copied out.
            if limit == full.len() {
                assert_eq!(chunks_consumed, 1);
            } else {
                assert_eq!(chunks_consumed, 0);
            }
        }
    }
}