1use crate::codec::UserError;
2use crate::frame::{Reason, StreamId};
3use crate::{client, server};
4
5use crate::ext::PseudoType;
6use crate::frame::{Priority, StreamDependency, DEFAULT_INITIAL_WINDOW_SIZE};
7use crate::proto::*;
8
9use bytes::Bytes;
10use futures_core::Stream;
11use std::io;
12use std::marker::PhantomData;
13use std::pin::Pin;
14use std::task::{Context, Poll};
15use std::time::Duration;
16use tokio::io::AsyncRead;
17
18#[derive(Debug)]
20pub(crate) struct Connection<T, P, B: Buf = Bytes>
21where
22 P: Peer,
23{
24 codec: Codec<T, Prioritized<B>>,
26
27 inner: ConnectionInner<P, B>,
28}
29
30#[derive(Debug)]
33struct ConnectionInner<P, B: Buf = Bytes>
34where
35 P: Peer,
36{
37 state: State,
39
40 error: Option<frame::GoAway>,
45
46 go_away: GoAway,
48
49 ping_pong: PingPong,
51
52 settings: Settings,
54
55 streams: Streams<B, P>,
57
58 span: tracing::Span,
60
61 _phantom: PhantomData<P>,
63}
64
65struct DynConnection<'a, B: Buf = Bytes> {
66 state: &'a mut State,
67
68 go_away: &'a mut GoAway,
69
70 streams: DynStreams<'a, B>,
71
72 error: &'a mut Option<frame::GoAway>,
73
74 ping_pong: &'a mut PingPong,
75}
76
77#[derive(Debug, Clone)]
78pub(crate) struct Config {
79 pub next_stream_id: StreamId,
80 pub initial_max_send_streams: usize,
81 pub max_send_buffer_size: usize,
82 pub reset_stream_duration: Duration,
83 pub reset_stream_max: usize,
84 pub remote_reset_stream_max: usize,
85 pub local_error_reset_streams_max: Option<usize>,
86 pub settings: frame::Settings,
87
88 pub headers_frame_pseudo_order: Option<&'static [PseudoType; 4]>,
90 pub headers_frame_priority: Option<StreamDependency>,
91 pub virtual_streams_priorities: Option<&'static [Priority]>,
92}
93
94#[derive(Debug)]
95enum State {
96 Open,
98
99 Closing(Reason, Initiator),
101
102 Closed(Reason, Initiator),
104}
105
106impl<T, P, B> Connection<T, P, B>
107where
108 T: AsyncRead + AsyncWrite + Unpin,
109 P: Peer,
110 B: Buf,
111{
112 pub fn new(codec: Codec<T, Prioritized<B>>, config: Config) -> Connection<T, P, B> {
113 fn streams_config(config: &Config) -> streams::Config {
114 streams::Config {
115 initial_max_send_streams: config.initial_max_send_streams,
116 local_max_buffer_size: config.max_send_buffer_size,
117 local_next_stream_id: config.next_stream_id,
118 local_push_enabled: config.settings.is_push_enabled().unwrap_or(true),
119 extended_connect_protocol_enabled: config
120 .settings
121 .is_extended_connect_protocol_enabled()
122 .unwrap_or(false),
123 local_reset_duration: config.reset_stream_duration,
124 local_reset_max: config.reset_stream_max,
125 remote_reset_max: config.remote_reset_stream_max,
126 remote_init_window_sz: DEFAULT_INITIAL_WINDOW_SIZE,
127 remote_max_initiated: config
128 .settings
129 .max_concurrent_streams()
130 .map(|max| max as usize),
131 local_max_error_reset_streams: config.local_error_reset_streams_max,
132 headers_frame_pseudo_order: config.headers_frame_pseudo_order,
133 headers_frame_priority: config.headers_frame_priority,
134 virtual_streams_priorities: config.virtual_streams_priorities,
135 }
136 }
137 let streams = Streams::new(streams_config(&config));
138 Connection {
139 codec,
140 inner: ConnectionInner {
141 state: State::Open,
142 error: None,
143 go_away: GoAway::new(),
144 ping_pong: PingPong::new(),
145 settings: Settings::new(config.settings),
146 streams,
147 span: tracing::debug_span!("Connection", peer = %P::NAME),
148 _phantom: PhantomData,
149 },
150 }
151 }
152
153 pub(crate) fn set_target_window_size(&mut self, size: WindowSize) {
155 let _res = self.inner.streams.set_target_connection_window_size(size);
156 debug_assert!(_res.is_ok());
158 }
159
160 pub(crate) fn set_initial_window_size(&mut self, size: WindowSize) -> Result<(), UserError> {
162 let mut settings = frame::Settings::default();
163 settings.set_initial_window_size(Some(size));
164 self.inner.settings.send_settings(settings)
165 }
166
167 pub(crate) fn set_enable_connect_protocol(&mut self) -> Result<(), UserError> {
169 let mut settings = frame::Settings::default();
170 settings.set_enable_connect_protocol(Some(1));
171 self.inner.settings.send_settings(settings)
172 }
173
174 pub(crate) fn max_send_streams(&self) -> usize {
177 self.inner.streams.max_send_streams()
178 }
179
180 pub(crate) fn max_recv_streams(&self) -> usize {
183 self.inner.streams.max_recv_streams()
184 }
185
186 #[cfg(feature = "unstable")]
187 pub fn num_wired_streams(&self) -> usize {
188 self.inner.streams.num_wired_streams()
189 }
190
191 fn poll_ready(&mut self, cx: &mut Context) -> Poll<Result<(), Error>> {
196 let _e = self.inner.span.enter();
197 let span = tracing::trace_span!("poll_ready");
198 let _e = span.enter();
199 ready!(self.inner.ping_pong.send_pending_pong(cx, &mut self.codec))?;
201 ready!(self.inner.ping_pong.send_pending_ping(cx, &mut self.codec))?;
202 ready!(self
203 .inner
204 .settings
205 .poll_send(cx, &mut self.codec, &mut self.inner.streams))?;
206 ready!(self.inner.streams.send_pending_refusal(cx, &mut self.codec))?;
207
208 Poll::Ready(Ok(()))
209 }
210
211 fn poll_go_away(&mut self, cx: &mut Context) -> Poll<Option<io::Result<Reason>>> {
216 self.inner.go_away.send_pending_go_away(cx, &mut self.codec)
217 }
218
219 pub fn go_away_from_user(&mut self, e: Reason) {
220 self.inner.as_dyn().go_away_from_user(e)
221 }
222
223 fn take_error(&mut self, ours: Reason, initiator: Initiator) -> Result<(), Error> {
224 let (debug_data, theirs) = self
225 .inner
226 .error
227 .take()
228 .as_ref()
229 .map_or((Bytes::new(), Reason::NO_ERROR), |frame| {
230 (frame.debug_data().clone(), frame.reason())
231 });
232
233 match (ours, theirs) {
234 (Reason::NO_ERROR, Reason::NO_ERROR) => Ok(()),
235 (ours, Reason::NO_ERROR) => Err(Error::GoAway(Bytes::new(), ours, initiator)),
236 (_, theirs) => Err(Error::remote_go_away(debug_data, theirs)),
241 }
242 }
243
244 pub fn maybe_close_connection_if_no_streams(&mut self) {
247 if !self.inner.streams.has_streams_or_other_references() {
250 self.inner.as_dyn().go_away_now(Reason::NO_ERROR);
251 }
252 }
253
254 pub fn has_streams_or_other_references(&self) -> bool {
256 self.inner.streams.has_streams_or_other_references()
259 }
260
261 pub(crate) fn take_user_pings(&mut self) -> Option<UserPings> {
262 self.inner.ping_pong.take_user_pings()
263 }
264
265 pub fn poll(&mut self, cx: &mut Context) -> Poll<Result<(), Error>> {
267 let span = self.inner.span.clone();
272 let _e = span.enter();
273 let span = tracing::trace_span!("poll");
274 let _e = span.enter();
275
276 loop {
277 tracing::trace!(connection.state = ?self.inner.state);
278 match self.inner.state {
280 State::Open => {
282 let result = match self.poll2(cx) {
283 Poll::Ready(result) => result,
284 Poll::Pending => {
286 ready!(self.inner.streams.poll_complete(cx, &mut self.codec))?;
290
291 if (self.inner.error.is_some()
292 || self.inner.go_away.should_close_on_idle())
293 && !self.inner.streams.has_streams()
294 {
295 self.inner.as_dyn().go_away_now(Reason::NO_ERROR);
296 continue;
297 }
298
299 return Poll::Pending;
300 }
301 };
302
303 self.inner.as_dyn().handle_poll2_result(result)?
304 }
305 State::Closing(reason, initiator) => {
306 tracing::trace!("connection closing after flush");
307 ready!(self.codec.shutdown(cx))?;
309
310 self.inner.state = State::Closed(reason, initiator);
312 }
313 State::Closed(reason, initiator) => {
314 return Poll::Ready(self.take_error(reason, initiator));
315 }
316 }
317 }
318 }
319
320 fn poll2(&mut self, cx: &mut Context) -> Poll<Result<(), Error>> {
321 self.clear_expired_reset_streams();
325
326 loop {
327 if let Some(reason) = ready!(self.poll_go_away(cx)?) {
333 if self.inner.go_away.should_close_now() {
334 if self.inner.go_away.is_user_initiated() {
335 return Poll::Ready(Ok(()));
338 } else {
339 return Poll::Ready(Err(Error::library_go_away(reason)));
340 }
341 }
342 debug_assert_eq!(
344 reason,
345 Reason::NO_ERROR,
346 "graceful GOAWAY should be NO_ERROR"
347 );
348 }
349 ready!(self.poll_ready(cx))?;
350
351 match self
352 .inner
353 .as_dyn()
354 .recv_frame(ready!(Pin::new(&mut self.codec).poll_next(cx)?))?
355 {
356 ReceivedFrame::Settings(frame) => {
357 self.inner.settings.recv_settings(
358 frame,
359 &mut self.codec,
360 &mut self.inner.streams,
361 )?;
362 }
363 ReceivedFrame::Continue => (),
364 ReceivedFrame::Done => {
365 return Poll::Ready(Ok(()));
366 }
367 }
368 }
369 }
370
371 fn clear_expired_reset_streams(&mut self) {
372 self.inner.streams.clear_expired_reset_streams();
373 }
374}
375
376impl<P, B> ConnectionInner<P, B>
377where
378 P: Peer,
379 B: Buf,
380{
381 fn as_dyn(&mut self) -> DynConnection<'_, B> {
382 let ConnectionInner {
383 state,
384 go_away,
385 streams,
386 error,
387 ping_pong,
388 ..
389 } = self;
390 let streams = streams.as_dyn();
391 DynConnection {
392 state,
393 go_away,
394 streams,
395 error,
396 ping_pong,
397 }
398 }
399}
400
401impl<B> DynConnection<'_, B>
402where
403 B: Buf,
404{
405 fn go_away(&mut self, id: StreamId, e: Reason) {
406 let frame = frame::GoAway::new(id, e);
407 self.streams.send_go_away(id);
408 self.go_away.go_away(frame);
409 }
410
411 fn go_away_now(&mut self, e: Reason) {
412 let last_processed_id = self.streams.last_processed_id();
413 let frame = frame::GoAway::new(last_processed_id, e);
414 self.go_away.go_away_now(frame);
415 }
416
417 fn go_away_now_data(&mut self, e: Reason, data: Bytes) {
418 let last_processed_id = self.streams.last_processed_id();
419 let frame = frame::GoAway::with_debug_data(last_processed_id, e, data);
420 self.go_away.go_away_now(frame);
421 }
422
423 fn go_away_from_user(&mut self, e: Reason) {
424 let last_processed_id = self.streams.last_processed_id();
425 let frame = frame::GoAway::new(last_processed_id, e);
426 self.go_away.go_away_from_user(frame);
427
428 self.streams.handle_error(Error::user_go_away(e));
430 }
431
432 fn handle_poll2_result(&mut self, result: Result<(), Error>) -> Result<(), Error> {
433 match result {
434 Ok(()) => {
436 *self.state = State::Closing(Reason::NO_ERROR, Initiator::Library);
437 Ok(())
438 }
439 Err(Error::GoAway(debug_data, reason, initiator)) => {
443 let e = Error::GoAway(debug_data.clone(), reason, initiator);
444 tracing::debug!(error = ?e, "Connection::poll; connection error");
445
446 if self
449 .go_away
450 .going_away()
451 .map_or(false, |frame| frame.reason() == reason)
452 {
453 tracing::trace!(" -> already going away");
454 *self.state = State::Closing(reason, initiator);
455 return Ok(());
456 }
457
458 self.streams.handle_error(e);
460 self.go_away_now_data(reason, debug_data);
461 Ok(())
462 }
463 Err(Error::Reset(id, reason, initiator)) => {
467 debug_assert_eq!(initiator, Initiator::Library);
468 tracing::trace!(?id, ?reason, "stream error");
469 self.streams.send_reset(id, reason);
470 Ok(())
471 }
472 Err(Error::Io(kind, inner)) => {
477 tracing::debug!(error = ?kind, "Connection::poll; IO error");
478 let e = Error::Io(kind, inner);
479
480 self.streams.handle_error(e.clone());
482
483 if self.streams.is_server()
490 && self.streams.is_buffer_empty()
491 && matches!(kind, io::ErrorKind::UnexpectedEof)
492 {
493 *self.state = State::Closed(Reason::NO_ERROR, Initiator::Library);
494 return Ok(());
495 }
496
497 Err(e)
499 }
500 }
501 }
502
503 fn recv_frame(&mut self, frame: Option<Frame>) -> Result<ReceivedFrame, Error> {
504 use crate::frame::Frame::*;
505 match frame {
506 Some(Headers(frame)) => {
507 tracing::trace!(?frame, "recv HEADERS");
508 self.streams.recv_headers(frame)?;
509 }
510 Some(Data(frame)) => {
511 tracing::trace!(?frame, "recv DATA");
512 self.streams.recv_data(frame)?;
513 }
514 Some(Reset(frame)) => {
515 tracing::trace!(?frame, "recv RST_STREAM");
516 self.streams.recv_reset(frame)?;
517 }
518 Some(PushPromise(frame)) => {
519 tracing::trace!(?frame, "recv PUSH_PROMISE");
520 self.streams.recv_push_promise(frame)?;
521 }
522 Some(Settings(frame)) => {
523 tracing::trace!(?frame, "recv SETTINGS");
524 return Ok(ReceivedFrame::Settings(frame));
525 }
526 Some(GoAway(frame)) => {
527 tracing::trace!(?frame, "recv GOAWAY");
528 self.streams.recv_go_away(&frame)?;
533 *self.error = Some(frame);
534 }
535 Some(Ping(frame)) => {
536 tracing::trace!(?frame, "recv PING");
537 let status = self.ping_pong.recv_ping(frame);
538 if status.is_shutdown() {
539 assert!(
540 self.go_away.is_going_away(),
541 "received unexpected shutdown ping"
542 );
543
544 let last_processed_id = self.streams.last_processed_id();
545 self.go_away(last_processed_id, Reason::NO_ERROR);
546 }
547 }
548 Some(WindowUpdate(frame)) => {
549 tracing::trace!(?frame, "recv WINDOW_UPDATE");
550 self.streams.recv_window_update(frame)?;
551 }
552 Some(Priority(frame)) => {
553 tracing::trace!(?frame, "recv PRIORITY");
554 }
556 None => {
557 tracing::trace!("codec closed");
558 self.streams.recv_eof(false).expect("mutex poisoned");
559 return Ok(ReceivedFrame::Done);
560 }
561 }
562 Ok(ReceivedFrame::Continue)
563 }
564}
565
566enum ReceivedFrame {
567 Settings(frame::Settings),
568 Continue,
569 Done,
570}
571
572impl<T, B> Connection<T, client::Peer, B>
573where
574 T: AsyncRead + AsyncWrite,
575 B: Buf,
576{
577 pub(crate) fn streams(&self) -> &Streams<B, client::Peer> {
578 &self.inner.streams
579 }
580}
581
582impl<T, B> Connection<T, server::Peer, B>
583where
584 T: AsyncRead + AsyncWrite + Unpin,
585 B: Buf,
586{
587 pub fn next_incoming(&mut self) -> Option<StreamRef<B>> {
588 self.inner.streams.next_incoming()
589 }
590
591 pub fn go_away_gracefully(&mut self) {
593 if self.inner.go_away.is_going_away() {
594 return;
596 }
597
598 self.inner.as_dyn().go_away(StreamId::MAX, Reason::NO_ERROR);
610
611 self.inner.ping_pong.ping_shutdown();
614 }
615}
616
617impl<T, P, B> Drop for Connection<T, P, B>
618where
619 P: Peer,
620 B: Buf,
621{
622 fn drop(&mut self) {
623 let _ = self.inner.streams.recv_eof(true);
625 }
626}