1use super::{
2 store, Buffer, Codec, Config, Counts, Frame, Prioritize, Prioritized, Store, Stream, StreamId,
3 StreamIdOverflow, WindowSize,
4};
5use crate::codec::UserError;
6use crate::frame::{self, Reason};
7use crate::proto::{self, Error, Initiator};
8
9use bytes::Buf;
10use tokio::io::AsyncWrite;
11
12use std::borrow::Cow;
13use std::cmp::Ordering;
14use std::io;
15use std::task::{Context, Poll, Waker};
16
/// Send-side (outbound) state of a connection.
///
/// Tracks allocation of locally initiated stream ids and the settings last
/// received from the peer that affect sending. Frames are queued through the
/// `prioritize` layer.
#[derive(Debug)]
pub(super) struct Send {
    /// Stream id to hand out for the next locally initiated stream, or an
    /// overflow marker once the id space has been exhausted.
    next_stream_id: Result<StreamId, StreamIdOverflow>,

    /// Upper bound accepted for the `last_stream_id` of a received GOAWAY.
    ///
    /// Starts at `StreamId::MAX` and is lowered to the `last_stream_id` of
    /// each accepted GOAWAY, so a later GOAWAY may never raise it again.
    max_stream_id: StreamId,

    /// Initial stream window size advertised by the peer. Seeded from
    /// `config.remote_init_window_sz` and updated from the peer's
    /// `SETTINGS_INITIAL_WINDOW_SIZE`.
    init_window_sz: WindowSize,

    /// Prioritization layer used to queue and schedule outbound frames.
    prioritize: Prioritize,

    /// Whether the peer permits server push; `send_push_promise` is rejected
    /// when this is `false`.
    is_push_enabled: bool,

    /// Whether the peer enabled the extended CONNECT protocol in its settings.
    is_extended_connect_protocol_enabled: bool,
}
43
/// Mode passed to `Send::poll_reset`, which forwards it unchanged to
/// `stream.state.ensure_reason`.
///
/// NOTE(review): judging by the variant names, this distinguishes polling
/// while still waiting for (response) headers from polling an
/// already-streaming body — confirm against `State::ensure_reason`.
#[derive(Debug)]
pub(crate) enum PollReset {
    AwaitingHeaders,
    Streaming,
}
50
51impl Send {
52 pub fn new(config: &Config) -> Self {
54 Send {
55 init_window_sz: config.remote_init_window_sz,
56 max_stream_id: StreamId::MAX,
57 next_stream_id: Ok(config.local_next_stream_id),
58 prioritize: Prioritize::new(config),
59 is_push_enabled: true,
60 is_extended_connect_protocol_enabled: false,
61 }
62 }
63
64 pub fn init_window_sz(&self) -> WindowSize {
66 self.init_window_sz
67 }
68
69 pub fn open(&mut self) -> Result<StreamId, UserError> {
70 let stream_id = self.ensure_next_stream_id()?;
71 self.next_stream_id = stream_id.next_id();
72 Ok(stream_id)
73 }
74
75 pub fn reserve_local(&mut self) -> Result<StreamId, UserError> {
76 let stream_id = self.ensure_next_stream_id()?;
77 self.next_stream_id = stream_id.next_id();
78 Ok(stream_id)
79 }
80
81 fn check_headers(fields: &http::HeaderMap) -> Result<(), UserError> {
82 if fields.contains_key(http::header::CONNECTION)
84 || fields.contains_key(http::header::TRANSFER_ENCODING)
85 || fields.contains_key(http::header::UPGRADE)
86 || fields.contains_key("keep-alive")
87 || fields.contains_key("proxy-connection")
88 {
89 tracing::debug!("illegal connection-specific headers found");
90 return Err(UserError::MalformedHeaders);
91 } else if let Some(te) = fields.get(http::header::TE) {
92 if te != "trailers" {
93 tracing::debug!("illegal connection-specific headers found");
94 return Err(UserError::MalformedHeaders);
95 }
96 }
97 Ok(())
98 }
99
100 pub fn send_push_promise<B>(
101 &mut self,
102 frame: frame::PushPromise,
103 buffer: &mut Buffer<Frame<B>>,
104 stream: &mut store::Ptr,
105 task: &mut Option<Waker>,
106 ) -> Result<(), UserError> {
107 if !self.is_push_enabled {
108 return Err(UserError::PeerDisabledServerPush);
109 }
110
111 tracing::trace!(
112 "send_push_promise; frame={:?}; init_window={:?}",
113 frame,
114 self.init_window_sz
115 );
116
117 Self::check_headers(frame.fields())?;
118
119 self.prioritize
121 .queue_frame(frame.into(), buffer, stream, task);
122
123 Ok(())
124 }
125
126 pub fn send_headers<B>(
127 &mut self,
128 frame: frame::Headers,
129 buffer: &mut Buffer<Frame<B>>,
130 stream: &mut store::Ptr,
131 counts: &mut Counts,
132 task: &mut Option<Waker>,
133 ) -> Result<(), UserError> {
134 self.send_headers_with_priority(None, frame, buffer, stream, counts, task)
135 }
136
137 pub fn send_headers_with_priority<B>(
138 &mut self,
139 priority_frame: Option<Cow<'static, [frame::Priority]>>,
140 headers_frame: frame::Headers,
141 buffer: &mut Buffer<Frame<B>>,
142 stream: &mut store::Ptr,
143 counts: &mut Counts,
144 task: &mut Option<Waker>,
145 ) -> Result<(), UserError> {
146 tracing::trace!(
147 "send_headers; frame={:?}; init_window={:?}",
148 headers_frame,
149 self.init_window_sz
150 );
151
152 Self::check_headers(headers_frame.fields())?;
153
154 let end_stream = headers_frame.is_end_stream();
155
156 stream.state.send_open(end_stream)?;
158
159 let mut pending_open = false;
160 if counts.peer().is_local_init(headers_frame.stream_id()) && !stream.is_pending_push {
161 self.prioritize.queue_open(stream);
162 pending_open = true;
163 }
164
165 if let Some(priority_frames) = priority_frame {
167 for priority_frame in priority_frames.into_owned() {
168 tracing::trace!(
169 "send_priority; frame={:?}; init_window={:?}",
170 priority_frame,
171 self.init_window_sz
172 );
173 self.prioritize
174 .queue_frame(priority_frame.into(), buffer, stream, task);
175 }
176 }
177
178 self.prioritize
183 .queue_frame(headers_frame.into(), buffer, stream, task);
184
185 if pending_open {
188 if let Some(task) = task.take() {
189 task.wake();
190 }
191 }
192
193 Ok(())
194 }
195
196 pub fn send_reset<B>(
198 &mut self,
199 reason: Reason,
200 initiator: Initiator,
201 buffer: &mut Buffer<Frame<B>>,
202 stream: &mut store::Ptr,
203 counts: &mut Counts,
204 task: &mut Option<Waker>,
205 ) {
206 let is_reset = stream.state.is_reset();
207 let is_closed = stream.state.is_closed();
208 let is_empty = stream.pending_send.is_empty();
209 let stream_id = stream.id;
210
211 tracing::trace!(
212 "send_reset(..., reason={:?}, initiator={:?}, stream={:?}, ..., \
213 is_reset={:?}; is_closed={:?}; pending_send.is_empty={:?}; \
214 state={:?} \
215 ",
216 reason,
217 initiator,
218 stream_id,
219 is_reset,
220 is_closed,
221 is_empty,
222 stream.state
223 );
224
225 if is_reset {
226 tracing::trace!(
228 " -> not sending RST_STREAM ({:?} is already reset)",
229 stream_id
230 );
231 return;
232 }
233
234 stream.set_reset(reason, initiator);
236
237 if is_closed && is_empty {
240 tracing::trace!(
241 " -> not sending explicit RST_STREAM ({:?} was closed \
242 and send queue was flushed)",
243 stream_id
244 );
245 return;
246 }
247
248 self.prioritize.clear_queue(buffer, stream);
253
254 let frame = frame::Reset::new(stream.id, reason);
255
256 tracing::trace!("send_reset -- queueing; frame={:?}", frame);
257 self.prioritize
258 .queue_frame(frame.into(), buffer, stream, task);
259 self.prioritize.reclaim_all_capacity(stream, counts);
260 }
261
262 pub fn schedule_implicit_reset(
263 &mut self,
264 stream: &mut store::Ptr,
265 reason: Reason,
266 counts: &mut Counts,
267 task: &mut Option<Waker>,
268 ) {
269 if stream.state.is_closed() {
270 return;
272 }
273
274 stream.state.set_scheduled_reset(reason);
275
276 self.prioritize.reclaim_reserved_capacity(stream, counts);
277 self.prioritize.schedule_send(stream, task);
278 }
279
280 pub fn send_data<B>(
281 &mut self,
282 frame: frame::Data<B>,
283 buffer: &mut Buffer<Frame<B>>,
284 stream: &mut store::Ptr,
285 counts: &mut Counts,
286 task: &mut Option<Waker>,
287 ) -> Result<(), UserError>
288 where
289 B: Buf,
290 {
291 self.prioritize
292 .send_data(frame, buffer, stream, counts, task)
293 }
294
295 pub fn send_trailers<B>(
296 &mut self,
297 frame: frame::Headers,
298 buffer: &mut Buffer<Frame<B>>,
299 stream: &mut store::Ptr,
300 counts: &mut Counts,
301 task: &mut Option<Waker>,
302 ) -> Result<(), UserError> {
303 if !stream.state.is_send_streaming() {
305 return Err(UserError::UnexpectedFrameType);
306 }
307
308 stream.state.send_close();
309
310 tracing::trace!("send_trailers -- queuing; frame={:?}", frame);
311 self.prioritize
312 .queue_frame(frame.into(), buffer, stream, task);
313
314 self.prioritize.reserve_capacity(0, stream, counts);
316
317 Ok(())
318 }
319
320 pub fn poll_complete<T, B>(
321 &mut self,
322 cx: &mut Context,
323 buffer: &mut Buffer<Frame<B>>,
324 store: &mut Store,
325 counts: &mut Counts,
326 dst: &mut Codec<T, Prioritized<B>>,
327 ) -> Poll<io::Result<()>>
328 where
329 T: AsyncWrite + Unpin,
330 B: Buf,
331 {
332 self.prioritize
333 .poll_complete(cx, buffer, store, counts, dst)
334 }
335
336 pub fn reserve_capacity(
338 &mut self,
339 capacity: WindowSize,
340 stream: &mut store::Ptr,
341 counts: &mut Counts,
342 ) {
343 self.prioritize.reserve_capacity(capacity, stream, counts)
344 }
345
346 pub fn poll_capacity(
347 &mut self,
348 cx: &Context,
349 stream: &mut store::Ptr,
350 ) -> Poll<Option<Result<WindowSize, UserError>>> {
351 if !stream.state.is_send_streaming() {
352 return Poll::Ready(None);
353 }
354
355 if !stream.send_capacity_inc {
356 stream.wait_send(cx);
357 return Poll::Pending;
358 }
359
360 stream.send_capacity_inc = false;
361
362 Poll::Ready(Some(Ok(self.capacity(stream))))
363 }
364
365 pub fn capacity(&self, stream: &mut store::Ptr) -> WindowSize {
367 stream.capacity(self.prioritize.max_buffer_size())
368 }
369
370 pub fn poll_reset(
371 &self,
372 cx: &Context,
373 stream: &mut Stream,
374 mode: PollReset,
375 ) -> Poll<Result<Reason, crate::Error>> {
376 match stream.state.ensure_reason(mode)? {
377 Some(reason) => Poll::Ready(Ok(reason)),
378 None => {
379 stream.wait_send(cx);
380 Poll::Pending
381 }
382 }
383 }
384
385 pub fn recv_connection_window_update(
386 &mut self,
387 frame: frame::WindowUpdate,
388 store: &mut Store,
389 counts: &mut Counts,
390 ) -> Result<(), Reason> {
391 self.prioritize
392 .recv_connection_window_update(frame.size_increment(), store, counts)
393 }
394
395 pub fn recv_stream_window_update<B>(
396 &mut self,
397 sz: WindowSize,
398 buffer: &mut Buffer<Frame<B>>,
399 stream: &mut store::Ptr,
400 counts: &mut Counts,
401 task: &mut Option<Waker>,
402 ) -> Result<(), Reason> {
403 if let Err(e) = self.prioritize.recv_stream_window_update(sz, stream) {
404 tracing::debug!("recv_stream_window_update !!; err={:?}", e);
405
406 self.send_reset(
407 Reason::FLOW_CONTROL_ERROR,
408 Initiator::Library,
409 buffer,
410 stream,
411 counts,
412 task,
413 );
414
415 return Err(e);
416 }
417
418 Ok(())
419 }
420
421 pub(super) fn recv_go_away(&mut self, last_stream_id: StreamId) -> Result<(), Error> {
422 if last_stream_id > self.max_stream_id {
423 proto_err!(conn:
431 "recv_go_away: last_stream_id ({:?}) > max_stream_id ({:?})",
432 last_stream_id, self.max_stream_id,
433 );
434 return Err(Error::library_go_away(Reason::PROTOCOL_ERROR));
435 }
436
437 self.max_stream_id = last_stream_id;
438 Ok(())
439 }
440
441 pub fn handle_error<B>(
442 &mut self,
443 buffer: &mut Buffer<Frame<B>>,
444 stream: &mut store::Ptr,
445 counts: &mut Counts,
446 ) {
447 self.prioritize.clear_queue(buffer, stream);
449 self.prioritize.reclaim_all_capacity(stream, counts);
450 }
451
452 pub fn apply_remote_settings<B>(
453 &mut self,
454 settings: &frame::Settings,
455 buffer: &mut Buffer<Frame<B>>,
456 store: &mut Store,
457 counts: &mut Counts,
458 task: &mut Option<Waker>,
459 ) -> Result<(), Error> {
460 if let Some(val) = settings.is_extended_connect_protocol_enabled() {
461 self.is_extended_connect_protocol_enabled = val;
462 }
463
464 if let Some(val) = settings.initial_window_size() {
482 let old_val = self.init_window_sz;
483 self.init_window_sz = val;
484
485 match val.cmp(&old_val) {
486 Ordering::Less => {
487 let dec = old_val - val;
489 tracing::trace!("decrementing all windows; dec={}", dec);
490
491 let mut total_reclaimed = 0;
492 store.try_for_each(|mut stream| {
493 let stream = &mut *stream;
494
495 tracing::trace!(
496 "decrementing stream window; id={:?}; decr={}; flow={:?}",
497 stream.id,
498 dec,
499 stream.send_flow
500 );
501
502 stream
504 .send_flow
505 .dec_send_window(dec)
506 .map_err(proto::Error::library_go_away)?;
507
508 let window_size = stream.send_flow.window_size();
515 let available = stream.send_flow.available().as_size();
516 let reclaimed = if available > window_size {
517 let reclaim = available - window_size;
519 stream
520 .send_flow
521 .claim_capacity(reclaim)
522 .map_err(proto::Error::library_go_away)?;
523 total_reclaimed += reclaim;
524 reclaim
525 } else {
526 0
527 };
528
529 tracing::trace!(
530 "decremented stream window; id={:?}; decr={}; reclaimed={}; flow={:?}",
531 stream.id,
532 dec,
533 reclaimed,
534 stream.send_flow
535 );
536
537 Ok::<_, proto::Error>(())
542 })?;
543
544 self.prioritize
545 .assign_connection_capacity(total_reclaimed, store, counts);
546 }
547 Ordering::Greater => {
548 let inc = val - old_val;
549
550 store.try_for_each(|mut stream| {
551 self.recv_stream_window_update(inc, buffer, &mut stream, counts, task)
552 .map_err(Error::library_go_away)
553 })?;
554 }
555 Ordering::Equal => (),
556 }
557 }
558
559 if let Some(val) = settings.is_push_enabled() {
560 self.is_push_enabled = val
561 }
562
563 Ok(())
564 }
565
566 pub fn clear_queues(&mut self, store: &mut Store, counts: &mut Counts) {
567 self.prioritize.clear_pending_capacity(store, counts);
568 self.prioritize.clear_pending_send(store, counts);
569 self.prioritize.clear_pending_open(store, counts);
570 }
571
572 pub fn ensure_not_idle(&self, id: StreamId) -> Result<(), Reason> {
573 if let Ok(next) = self.next_stream_id {
574 if id >= next {
575 return Err(Reason::PROTOCOL_ERROR);
576 }
577 }
578 Ok(())
581 }
582
583 pub fn ensure_next_stream_id(&self) -> Result<StreamId, UserError> {
584 self.next_stream_id
585 .map_err(|_| UserError::OverflowedStreamId)
586 }
587
588 pub fn may_have_created_stream(&self, id: StreamId) -> bool {
589 if let Ok(next_id) = self.next_stream_id {
590 debug_assert_eq!(id.is_server_initiated(), next_id.is_server_initiated(),);
592 id < next_id
593 } else {
594 true
595 }
596 }
597
598 pub(super) fn maybe_reset_next_stream_id(&mut self, id: StreamId) {
599 if let Ok(next_id) = self.next_stream_id {
600 debug_assert_eq!(id.is_server_initiated(), next_id.is_server_initiated());
602 if id >= next_id {
603 self.next_stream_id = id.next_id();
604 }
605 }
606 }
607
608 pub(crate) fn is_extended_connect_protocol_enabled(&self) -> bool {
609 self.is_extended_connect_protocol_enabled
610 }
611}