1use crate::codec::UserError;
2use crate::frame::{Priorities, PseudoOrder, Reason, StreamDependency, StreamId};
3use crate::{client, server, tracing};
4
5use crate::frame::DEFAULT_INITIAL_WINDOW_SIZE;
6use crate::proto::*;
7
8use bytes::Bytes;
9use futures_core::Stream;
10use std::io;
11use std::marker::PhantomData;
12use std::pin::Pin;
13use std::task::{Context, Poll};
14use std::time::Duration;
15use tokio::io::AsyncRead;
16
17#[derive(Debug)]
19pub(crate) struct Connection<T, P, B: Buf = Bytes>
20where
21 P: Peer,
22{
23 codec: Codec<T, Prioritized<B>>,
25
26 inner: ConnectionInner<P, B>,
27}
28
29#[derive(Debug)]
32struct ConnectionInner<P, B: Buf = Bytes>
33where
34 P: Peer,
35{
36 state: State,
38
39 error: Option<frame::GoAway>,
44
45 go_away: GoAway,
47
48 ping_pong: PingPong,
50
51 settings: Settings,
53
54 streams: Streams<B, P>,
56
57 #[cfg(feature = "tracing")]
59 span: ::tracing::Span,
60
61 _phantom: PhantomData<P>,
63}
64
65struct DynConnection<'a, B: Buf = Bytes> {
66 state: &'a mut State,
67
68 go_away: &'a mut GoAway,
69
70 streams: DynStreams<'a, B>,
71
72 error: &'a mut Option<frame::GoAway>,
73
74 ping_pong: &'a mut PingPong,
75}
76
77#[derive(Debug, Clone)]
78pub(crate) struct Config {
79 pub next_stream_id: StreamId,
80 pub initial_max_send_streams: usize,
81 pub max_send_buffer_size: usize,
82 pub reset_stream_duration: Duration,
83 pub reset_stream_max: usize,
84 pub remote_reset_stream_max: usize,
85 pub local_error_reset_streams_max: Option<usize>,
86 pub settings: frame::Settings,
87 pub headers_pseudo_order: Option<PseudoOrder>,
88 pub headers_stream_dependency: Option<StreamDependency>,
89 pub priorities: Option<Priorities>,
90}
91
92#[derive(Debug)]
93enum State {
94 Open,
96
97 Closing(Reason, Initiator),
99
100 Closed(Reason, Initiator),
102}
103
104impl<T, P, B> Connection<T, P, B>
105where
106 T: AsyncRead + AsyncWrite + Unpin,
107 P: Peer,
108 B: Buf,
109{
110 pub fn new(codec: Codec<T, Prioritized<B>>, config: Config) -> Connection<T, P, B> {
111 fn streams_config(config: &Config) -> streams::Config {
112 streams::Config {
113 initial_max_send_streams: config.initial_max_send_streams,
114 local_max_buffer_size: config.max_send_buffer_size,
115 local_next_stream_id: config.next_stream_id,
116 local_push_enabled: config.settings.is_push_enabled().unwrap_or(true),
117 extended_connect_protocol_enabled: config
118 .settings
119 .is_extended_connect_protocol_enabled()
120 .unwrap_or(false),
121 local_reset_duration: config.reset_stream_duration,
122 local_reset_max: config.reset_stream_max,
123 remote_reset_max: config.remote_reset_stream_max,
124 remote_init_window_sz: DEFAULT_INITIAL_WINDOW_SIZE,
125 remote_max_initiated: config
126 .settings
127 .max_concurrent_streams()
128 .map(|max| max as usize),
129 local_max_error_reset_streams: config.local_error_reset_streams_max,
130 headers_stream_dependency: config.headers_stream_dependency,
131 headers_pseudo_order: config.headers_pseudo_order.clone(),
132 priorities: config.priorities.clone(),
133 }
134 }
135 let streams = Streams::new(streams_config(&config));
136 Connection {
137 codec,
138 inner: ConnectionInner {
139 state: State::Open,
140 error: None,
141 go_away: GoAway::new(),
142 ping_pong: PingPong::new(),
143 settings: Settings::new(config.settings),
144 streams,
145 #[cfg(feature = "tracing")]
146 span: ::tracing::debug_span!("Connection", peer = %P::NAME),
147 _phantom: PhantomData,
148 },
149 }
150 }
151
152 pub(crate) fn set_target_window_size(&mut self, size: WindowSize) {
154 let _res = self.inner.streams.set_target_connection_window_size(size);
155 debug_assert!(_res.is_ok());
157 }
158
159 pub(crate) fn set_initial_window_size(&mut self, size: WindowSize) -> Result<(), UserError> {
161 let mut settings = frame::Settings::default();
162 settings.set_initial_window_size(Some(size));
163 self.inner.settings.send_settings(settings)
164 }
165
166 pub(crate) fn set_enable_connect_protocol(&mut self) -> Result<(), UserError> {
168 let mut settings = frame::Settings::default();
169 settings.set_enable_connect_protocol(Some(1));
170 self.inner.settings.send_settings(settings)
171 }
172
173 pub(crate) fn max_send_streams(&self) -> usize {
176 self.inner.streams.max_send_streams()
177 }
178
179 pub(crate) fn max_recv_streams(&self) -> usize {
182 self.inner.streams.max_recv_streams()
183 }
184
185 #[cfg(feature = "unstable")]
186 pub fn num_wired_streams(&self) -> usize {
187 self.inner.streams.num_wired_streams()
188 }
189
190 fn poll_ready(&mut self, cx: &mut Context) -> Poll<Result<(), Error>> {
195 #[cfg(feature = "tracing")]
196 let _e = self.inner.span.enter();
197 let _span = tracing::trace_span!("poll_ready");
198 ready!(self.inner.ping_pong.send_pending_pong(cx, &mut self.codec))?;
200 ready!(self.inner.ping_pong.send_pending_ping(cx, &mut self.codec))?;
201 ready!(self
202 .inner
203 .settings
204 .poll_send(cx, &mut self.codec, &mut self.inner.streams))?;
205 ready!(self.inner.streams.send_pending_refusal(cx, &mut self.codec))?;
206
207 Poll::Ready(Ok(()))
208 }
209
210 fn poll_go_away(&mut self, cx: &mut Context) -> Poll<Option<io::Result<Reason>>> {
215 self.inner.go_away.send_pending_go_away(cx, &mut self.codec)
216 }
217
218 pub fn go_away_from_user(&mut self, e: Reason) {
219 self.inner.as_dyn().go_away_from_user(e)
220 }
221
222 fn take_error(&mut self, ours: Reason, initiator: Initiator) -> Result<(), Error> {
223 let (debug_data, theirs) = self
224 .inner
225 .error
226 .take()
227 .as_ref()
228 .map_or((Bytes::new(), Reason::NO_ERROR), |frame| {
229 (frame.debug_data().clone(), frame.reason())
230 });
231
232 match (ours, theirs) {
233 (Reason::NO_ERROR, Reason::NO_ERROR) => Ok(()),
234 (ours, Reason::NO_ERROR) => Err(Error::GoAway(Bytes::new(), ours, initiator)),
235 (_, theirs) => Err(Error::remote_go_away(debug_data, theirs)),
240 }
241 }
242
243 pub fn maybe_close_connection_if_no_streams(&mut self) {
246 if !self.inner.streams.has_streams_or_other_references() {
249 self.inner.as_dyn().go_away_now(Reason::NO_ERROR);
250 }
251 }
252
253 pub fn has_streams(&self) -> bool {
255 self.inner.streams.has_streams()
256 }
257
258 pub fn has_streams_or_other_references(&self) -> bool {
260 self.inner.streams.has_streams_or_other_references()
263 }
264
265 pub(crate) fn take_user_pings(&mut self) -> Option<UserPings> {
266 self.inner.ping_pong.take_user_pings()
267 }
268
269 pub fn poll(&mut self, cx: &mut Context) -> Poll<Result<(), Error>> {
271 #[cfg(feature = "tracing")]
276 let _span1 = self.inner.span.clone().entered();
277 let _span2 = tracing::trace_span!("poll");
278
279 loop {
280 tracing::trace!(connection.state = ?self.inner.state);
281 match self.inner.state {
283 State::Open => {
285 let result = match self.poll2(cx) {
286 Poll::Ready(result) => result,
287 Poll::Pending => {
289 ready!(self.inner.streams.poll_complete(cx, &mut self.codec))?;
293
294 if (self.inner.error.is_some()
295 || self.inner.go_away.should_close_on_idle())
296 && !self.inner.streams.has_streams()
297 {
298 self.inner.as_dyn().go_away_now(Reason::NO_ERROR);
299 continue;
300 }
301
302 return Poll::Pending;
303 }
304 };
305
306 self.inner.as_dyn().handle_poll2_result(result)?
307 }
308 State::Closing(reason, initiator) => {
309 tracing::trace!("connection closing after flush");
310 ready!(self.codec.shutdown(cx))?;
312
313 self.inner.state = State::Closed(reason, initiator);
315 }
316 State::Closed(reason, initiator) => {
317 return Poll::Ready(self.take_error(reason, initiator));
318 }
319 }
320 }
321 }
322
323 fn poll2(&mut self, cx: &mut Context) -> Poll<Result<(), Error>> {
324 self.clear_expired_reset_streams();
328
329 loop {
330 if let Some(reason) = ready!(self.poll_go_away(cx)?) {
336 if self.inner.go_away.should_close_now() {
337 if self.inner.go_away.is_user_initiated() {
338 return Poll::Ready(Ok(()));
341 } else {
342 return Poll::Ready(Err(Error::library_go_away(reason)));
343 }
344 }
345 debug_assert_eq!(
347 reason,
348 Reason::NO_ERROR,
349 "graceful GOAWAY should be NO_ERROR"
350 );
351 }
352 ready!(self.poll_ready(cx))?;
353
354 match self
355 .inner
356 .as_dyn()
357 .recv_frame(ready!(Pin::new(&mut self.codec).poll_next(cx)?))?
358 {
359 ReceivedFrame::Settings(frame) => {
360 self.inner.settings.recv_settings(
361 frame,
362 &mut self.codec,
363 &mut self.inner.streams,
364 )?;
365 }
366 ReceivedFrame::Continue => (),
367 ReceivedFrame::Done => {
368 return Poll::Ready(Ok(()));
369 }
370 }
371 }
372 }
373
374 fn clear_expired_reset_streams(&mut self) {
375 self.inner.streams.clear_expired_reset_streams();
376 }
377}
378
379impl<P, B> ConnectionInner<P, B>
380where
381 P: Peer,
382 B: Buf,
383{
384 fn as_dyn(&mut self) -> DynConnection<'_, B> {
385 let ConnectionInner {
386 state,
387 go_away,
388 streams,
389 error,
390 ping_pong,
391 ..
392 } = self;
393 let streams = streams.as_dyn();
394 DynConnection {
395 state,
396 go_away,
397 streams,
398 error,
399 ping_pong,
400 }
401 }
402}
403
404impl<B> DynConnection<'_, B>
405where
406 B: Buf,
407{
408 fn go_away(&mut self, id: StreamId, e: Reason) {
409 let frame = frame::GoAway::new(id, e);
410 self.streams.send_go_away(id);
411 self.go_away.go_away(frame);
412 }
413
414 fn go_away_now(&mut self, e: Reason) {
415 let last_processed_id = self.streams.last_processed_id();
416 let frame = frame::GoAway::new(last_processed_id, e);
417 self.go_away.go_away_now(frame);
418 }
419
420 fn go_away_now_data(&mut self, e: Reason, data: Bytes) {
421 let last_processed_id = self.streams.last_processed_id();
422 let frame = frame::GoAway::with_debug_data(last_processed_id, e, data);
423 self.go_away.go_away_now(frame);
424 }
425
426 fn go_away_from_user(&mut self, e: Reason) {
427 let last_processed_id = self.streams.last_processed_id();
428 let frame = frame::GoAway::new(last_processed_id, e);
429 self.go_away.go_away_from_user(frame);
430
431 self.streams.handle_error(Error::user_go_away(e));
433 }
434
435 fn handle_poll2_result(&mut self, result: Result<(), Error>) -> Result<(), Error> {
436 match result {
437 Ok(()) => {
439 *self.state = State::Closing(Reason::NO_ERROR, Initiator::Library);
440 Ok(())
441 }
442 Err(Error::GoAway(debug_data, reason, initiator)) => {
446 self.handle_go_away(reason, debug_data, initiator);
447 Ok(())
448 }
449 Err(Error::Reset(id, reason, initiator)) => {
453 debug_assert_eq!(initiator, Initiator::Library);
454 tracing::trace!(?id, ?reason, "stream error");
455 match self.streams.send_reset(id, reason) {
456 Ok(()) => (),
457 Err(crate::proto::error::GoAway { debug_data, reason }) => {
458 self.handle_go_away(reason, debug_data, Initiator::Library);
459 }
460 }
461 Ok(())
462 }
463 Err(Error::Io(kind, inner)) => {
468 tracing::debug!(error = ?kind, "Connection::poll; IO error");
469 let e = Error::Io(kind, inner);
470
471 self.streams.handle_error(e.clone());
473
474 if self.streams.is_buffer_empty()
481 && matches!(kind, io::ErrorKind::UnexpectedEof)
482 && (self.streams.is_server()
483 || self.error.as_ref().map(|f| f.reason() == Reason::NO_ERROR)
484 == Some(true))
485 {
486 *self.state = State::Closed(Reason::NO_ERROR, Initiator::Library);
487 return Ok(());
488 }
489
490 Err(e)
492 }
493 }
494 }
495
496 fn handle_go_away(&mut self, reason: Reason, debug_data: Bytes, initiator: Initiator) {
497 let e = Error::GoAway(debug_data.clone(), reason, initiator);
498 tracing::debug!(error = ?e, "Connection::poll; connection error");
499
500 if self
503 .go_away
504 .going_away()
505 .map_or(false, |frame| frame.reason() == reason)
506 {
507 tracing::trace!(" -> already going away");
508 *self.state = State::Closing(reason, initiator);
509 return;
510 }
511
512 self.streams.handle_error(e);
514 self.go_away_now_data(reason, debug_data);
515 }
516
517 fn recv_frame(&mut self, frame: Option<Frame>) -> Result<ReceivedFrame, Error> {
518 use crate::frame::Frame::*;
519 match frame {
520 Some(Headers(frame)) => {
521 tracing::trace!(?frame, "recv HEADERS");
522 self.streams.recv_headers(frame)?;
523 }
524 Some(Data(frame)) => {
525 tracing::trace!(?frame, "recv DATA");
526 self.streams.recv_data(frame)?;
527 }
528 Some(Reset(frame)) => {
529 tracing::trace!(?frame, "recv RST_STREAM");
530 self.streams.recv_reset(frame)?;
531 }
532 Some(PushPromise(frame)) => {
533 tracing::trace!(?frame, "recv PUSH_PROMISE");
534 self.streams.recv_push_promise(frame)?;
535 }
536 Some(Settings(frame)) => {
537 tracing::trace!(?frame, "recv SETTINGS");
538 return Ok(ReceivedFrame::Settings(frame));
539 }
540 Some(GoAway(frame)) => {
541 tracing::trace!(?frame, "recv GOAWAY");
542 self.streams.recv_go_away(&frame)?;
547 *self.error = Some(frame);
548 }
549 Some(Ping(frame)) => {
550 tracing::trace!(?frame, "recv PING");
551 let status = self.ping_pong.recv_ping(frame);
552 if status.is_shutdown() {
553 assert!(
554 self.go_away.is_going_away(),
555 "received unexpected shutdown ping"
556 );
557
558 let last_processed_id = self.streams.last_processed_id();
559 self.go_away(last_processed_id, Reason::NO_ERROR);
560 }
561 }
562 Some(WindowUpdate(frame)) => {
563 tracing::trace!(?frame, "recv WINDOW_UPDATE");
564 self.streams.recv_window_update(frame)?;
565 }
566 Some(Priority(_frame)) => {
567 tracing::trace!(?_frame, "recv PRIORITY");
568 }
570 None => {
571 tracing::trace!("codec closed");
572 self.streams.recv_eof(false).expect("mutex poisoned");
573 return Ok(ReceivedFrame::Done);
574 }
575 }
576 Ok(ReceivedFrame::Continue)
577 }
578}
579
580enum ReceivedFrame {
581 Settings(frame::Settings),
582 Continue,
583 Done,
584}
585
586impl<T, B> Connection<T, client::Peer, B>
587where
588 T: AsyncRead + AsyncWrite,
589 B: Buf,
590{
591 pub(crate) fn streams(&self) -> &Streams<B, client::Peer> {
592 &self.inner.streams
593 }
594}
595
596impl<T, B> Connection<T, server::Peer, B>
597where
598 T: AsyncRead + AsyncWrite + Unpin,
599 B: Buf,
600{
601 pub fn next_incoming(&mut self) -> Option<StreamRef<B>> {
602 self.inner.streams.next_incoming()
603 }
604
605 pub fn go_away_gracefully(&mut self) {
607 if self.inner.go_away.is_going_away() {
608 return;
610 }
611
612 self.inner.as_dyn().go_away(StreamId::MAX, Reason::NO_ERROR);
624
625 self.inner.ping_pong.ping_shutdown();
628 }
629}
630
631impl<T, P, B> Drop for Connection<T, P, B>
632where
633 P: Peer,
634 B: Buf,
635{
636 fn drop(&mut self) {
637 let _ = self.inner.streams.recv_eof(true);
639 }
640}