libp2prs_mplex/
connection.rs

1// Copyright 2020 Netwarps Ltd.
2//
3// Permission is hereby granted, free of charge, to any person obtaining a
4// copy of this software and associated documentation files (the "Software"),
5// to deal in the Software without restriction, including without limitation
6// the rights to use, copy, modify, merge, publish, distribute, sublicense,
7// and/or sell copies of the Software, and to permit persons to whom the
8// Software is furnished to do so, subject to the following conditions:
9//
10// The above copyright notice and this permission notice shall be included in
11// all copies or substantial portions of the Software.
12//
13// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
14// OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
15// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
16// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
17// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
18// FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
19// DEALINGS IN THE SOFTWARE.
20
21// This module contains the `Connection` type and associated helpers.
22// A `Connection` wraps an underlying (async) I/O resource and multiplexes
23// `Stream`s over it.
24//
25// The overall idea is as follows: The `Connection` makes progress via calls
26// to its `next_stream` method which polls several futures, one that decodes
27// `Frame`s from the I/O resource, one that consumes `ControlCommand`s
28// from an MPSC channel and another one that consumes `StreamCommand`s from
29// yet another MPSC channel. The latter channel is shared with every `Stream`
30// created and whenever a `Stream` wishes to send a `Frame` to the remote end,
31// it enqueues it into this channel (waiting if the channel is full). The
32// former is shared with every `Control` clone and used to open new outbound
33// streams or to trigger a connection close.
34//
35// The `Connection` updates the `Stream` state based on incoming frames, e.g.
36// it pushes incoming data to the `Stream` via channel or increases the sending
37// credit if the remote has sent us a corresponding `Frame::<WindowUpdate>`.
38//
39// Closing a `Connection`
40// ----------------------
41//
42// Every `Control` may send a `ControlCommand::Close` at any time and then
43// waits on a `oneshot::Receiver` for confirmation that the connection is
44// closed. The closing proceeds as follows:
45//
46// 1. As soon as we receive the close command we close the MPSC receiver
47//    of `StreamCommand`s. We want to process any stream commands which are
48//    already enqueued at this point but no more.
49// 2. We change the internal shutdown state to `Shutdown::InProgress` which
50//    contains the `oneshot::Sender` of the `Control` which triggered the
51//    closure and which we need to notify eventually.
52// 3. Crucially -- while closing -- we no longer process further control
53//    commands, because opening new streams should no longer be allowed
54//    and further close commands would mean we need to save those
55//    `oneshot::Sender`s for later. On the other hand we also do not simply
56//    close the control channel as this would signal to `Control`s that
57//    try to send close commands, that the connection is already closed,
58//    which it is not. So we just pause processing control commands which
59//    means such `Control`s will wait.
60// 4. We keep processing I/O and stream commands until the remaining stream
61//    commands have all been consumed, at which point we transition the
62//    shutdown state to `Shutdown::Complete`, which entails sending the
63//    final termination frame to the remote, informing the `Control` and
64//    now also closing the control channel.
65// 5. Now that we are closed we go through all pending control commands
66//    and tell the `Control`s that we are closed and we are finally done.
67//
68// While all of this may look complicated, it ensures that `Control`s are
69// only informed about a closed connection when it really is closed.
70//
// Specifics
// ----------------------
// - All stream state is managed by the connection; stream state is obtained
//   through channels, since a shared lock would not be efficient.
// - The connection pushes incoming data to the `Stream` via a channel, not a buffer.
// - A stream must be closed explicitly, since garbage collection is not
//   implemented: simply dropping a stream does nothing.
78//
79// Potential improvements
80// ----------------------
81//
82// There is always more work that can be done to make this a better crate,
83// for example:
// - The loop in `handle_coming()` is a performance bottleneck. More seriously,
//   it can block when two peers echo large amounts of data over many streams,
//   since both sides block on writing data and neither of them can read.
//   One solution is to spawn separate tasks for the reader and the writer, but
//   depending on a particular async runtime is not attractive. See the
//   concurrent tests for details.
89
90pub mod control;
91pub mod stream;
92
93use futures::{
94    channel::{mpsc, oneshot},
95    future::{select, Either},
96    prelude::*,
97    select,
98    stream::FusedStream,
99};
100use futures_timer::Delay;
101
102use crate::{
103    error::ConnectionError,
104    frame::{io, Frame, FrameDecodeError, StreamID, Tag},
105    pause::Pausable,
106};
107use control::Control;
108use futures::io::WriteHalf;
109use nohash_hasher::IntMap;
110use std::collections::VecDeque;
111use std::fmt;
112use std::pin::Pin;
113use std::time::Duration;
114use stream::{State, Stream};
115
/// `Control` to `Connection` commands.
///
/// Sent over the control channel shared by all `Control` clones. Each variant
/// carries the `oneshot::Sender` on which the connection posts its answer.
#[derive(Debug)]
pub enum ControlCommand {
    /// Open a new stream to the remote end.
    ///
    /// Answered with the new outbound `Stream`, or with
    /// `ConnectionError::Closed` once the connection is closed.
    OpenStream(oneshot::Sender<Result<Stream>>),
    /// Accept a new stream from the remote end.
    ///
    /// Answered with the next inbound `Stream`, either right away from the
    /// queue of not-yet-accepted streams or as soon as one arrives. Answered
    /// with `ConnectionError::Closed` if another accept is already waiting or
    /// the connection is closed.
    AcceptStream(oneshot::Sender<Result<Stream>>),
    /// Close the whole connection.
    ///
    /// Answered once the shutdown has completed, or immediately if a shutdown
    /// had already been started.
    CloseConnection(oneshot::Sender<()>),
}
126
/// `Stream` to `Connection` commands.
///
/// Every `Stream` holds a clone of the connection's command sender and uses it
/// to ask the connection to write frames on its behalf.
#[derive(Debug)]
pub(crate) enum StreamCommand {
    /// A new frame should be sent to the remote.
    ///
    /// Only written if the stream's state still allows writing.
    SendFrame(Frame),
    /// Close a stream.
    ///
    /// The carried frame is sent to the remote and the stream's sending side
    /// is marked as closed.
    CloseStream(Frame),
    /// Reset a stream.
    ///
    /// The carried frame is sent to the remote, the stream is forgotten, and
    /// the `oneshot::Sender` is signalled afterwards.
    ResetStream(Frame, oneshot::Sender<()>),
}
137
/// The connection identifier.
///
/// Randomly generated, this is mainly intended to improve log output.
/// It is rendered as 8 zero-padded lowercase hex digits.
#[derive(Clone, Copy)]
pub struct Id(u32);
143
144impl Id {
145    /// Create a random connection ID.
146    pub(crate) fn random() -> Self {
147        Id(rand::random())
148    }
149}
150
151impl fmt::Debug for Id {
152    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
153        write!(f, "{:08x}", self.0)
154    }
155}
156
157impl fmt::Display for Id {
158    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
159        write!(f, "{:08x}", self.0)
160    }
161}
162
/// This enum captures the various stages of shutting down the connection.
///
/// `NotStarted` becomes `InProgress` when a `ControlCommand::CloseConnection`
/// is processed, and `InProgress` becomes `Complete` once the control channel
/// reports that no further commands will arrive.
#[derive(Debug)]
enum Shutdown {
    /// We are open for business.
    NotStarted,
    /// We have received a `ControlCommand::CloseConnection` and are shutting
    /// down operations. The `Sender` will be informed once we are done.
    InProgress(oneshot::Sender<()>),
    /// The shutdown is complete and we are closed for good.
    Complete,
}
174
175impl Shutdown {
176    fn has_not_started(&self) -> bool {
177        matches!(self, Shutdown::NotStarted)
178    }
179
180    fn is_in_progress(&self) -> bool {
181        matches!(self, Shutdown::InProgress(_))
182    }
183
184    fn is_complete(&self) -> bool {
185        matches!(self, Shutdown::Complete)
186    }
187}
188
/// Arbitrary limit of our internal command channels.
///
/// Since each `mpsc::Sender` gets a guaranteed slot in a channel the
/// actual upper bound is this value + number of clones.
///
/// The same value is used as the capacity of each stream's incoming data channel.
const MAX_COMMAND_BACKLOG: usize = 32;
/// How long delivering an incoming `Message` frame to a stream's data channel
/// may take before that stream is reset.
const RECEIVE_TIMEOUT: Duration = Duration::from_secs(5);

/// Result type used throughout the connection, failing with `ConnectionError`.
type Result<T> = std::result::Result<T, ConnectionError>;
197
/// A multiplexed connection over the I/O resource `T`.
///
/// The connection makes progress only while `next_stream` is being polled;
/// `Control` handles obtained via `control()` send it commands.
pub struct Connection<T> {
    /// Random identifier, used in log output.
    id: Id,
    /// Stream of frames (or decode errors) read from the read half of the socket.
    reader: Pin<Box<dyn FusedStream<Item = std::result::Result<Frame, FrameDecodeError>> + Send>>,
    /// Frame writer over the write half of the socket.
    writer: io::IO<WriteHalf<T>>,
    /// Set once `next_stream` has terminated and performed its cleanup.
    is_closed: bool,
    /// Current stage of a `CloseConnection` request.
    shutdown: Shutdown,
    /// Counter used to allocate ids for locally opened streams.
    next_stream_id: u32,
    /// Per-stream senders through which incoming message bodies are pushed.
    streams: IntMap<StreamID, mpsc::Sender<Vec<u8>>>,
    /// Per-stream state. Tracked separately from `streams`: an entry may
    /// remain here after the stream's data sender was removed (e.g. after the
    /// remote closed its side).
    streams_stat: IntMap<StreamID, State>,
    /// Cloned into every `Stream` so it can send `StreamCommand`s.
    stream_sender: mpsc::Sender<StreamCommand>,
    /// Receiving end of the `StreamCommand` channel.
    stream_receiver: mpsc::Receiver<StreamCommand>,
    /// Cloned into every `Control` so it can send `ControlCommand`s.
    control_sender: mpsc::Sender<ControlCommand>,
    /// Receiving end of the `ControlCommand` channel; paused while a shutdown
    /// is in progress.
    control_receiver: Pausable<mpsc::Receiver<ControlCommand>>,
    /// Reply channel of an `AcceptStream` waiting for the next inbound stream.
    waiting_stream_sender: Option<oneshot::Sender<Result<stream::Stream>>>,
    /// Inbound streams that have not been accepted yet.
    pending_streams: VecDeque<stream::Stream>,
}
214
215impl<T: AsyncRead + AsyncWrite + Send + Unpin + 'static> Connection<T> {
216    /// Create a new `Connection` from the given I/O resource.
217    pub fn new(socket: T) -> Self {
218        let id = Id::random();
219        log::debug!("new connection: {}", id);
220
221        let (reader, writer) = socket.split();
222        let reader = io::IO::new(id, reader);
223        let reader = futures::stream::unfold(reader, |mut io| async { Some((io.recv_frame().await, io)) });
224        let reader = Box::pin(reader);
225
226        let writer = io::IO::new(id, writer);
227        let (stream_sender, stream_receiver) = mpsc::channel(MAX_COMMAND_BACKLOG);
228        let (control_sender, control_receiver) = mpsc::channel(MAX_COMMAND_BACKLOG);
229
230        Connection {
231            id,
232            reader,
233            writer,
234            is_closed: false,
235            next_stream_id: 0,
236            shutdown: Shutdown::NotStarted,
237            streams: IntMap::default(),
238            streams_stat: IntMap::default(),
239            stream_sender,
240            stream_receiver,
241            control_sender,
242            control_receiver: Pausable::new(control_receiver),
243            waiting_stream_sender: None,
244            pending_streams: VecDeque::default(),
245        }
246    }
247    /// Returns the id of the connection
248    pub fn id(&self) -> Id {
249        self.id
250    }
251
252    /// Get a controller for this connection.
253    pub fn control(&self) -> Control {
254        Control::new(self.control_sender.clone())
255    }
256
257    /// Get the next incoming stream, opened by the remote.
258    ///
259    /// This must be called repeatedly in order to make progress.
260    /// Once `Ok(()))` or `Err(_)` is returned the connection is
261    /// considered closed and no further invocation of this method
262    /// must be attempted.
263    pub async fn next_stream(&mut self) -> Result<()> {
264        if self.is_closed {
265            log::debug!("{}: connection is closed", self.id);
266            return Ok(());
267        }
268
269        let result = self.handle_coming().await;
270        log::info!("{}: error exit, {:?}", self.id, result);
271
272        self.is_closed = true;
273
274        if let Some(sender) = self.waiting_stream_sender.take() {
275            sender.send(Err(ConnectionError::Closed)).expect("send err");
276        }
277
278        // Close and drain the control command receiver.
279        if !self.control_receiver.stream().is_terminated() {
280            if self.control_receiver.is_paused() {
281                self.control_receiver.unpause();
282            }
283            self.control_receiver.stream().close();
284
285            while let Some(cmd) = self.control_receiver.next().await {
286                match cmd {
287                    ControlCommand::OpenStream(reply) => {
288                        let _ = reply.send(Err(ConnectionError::Closed));
289                    }
290                    ControlCommand::AcceptStream(reply) => {
291                        let _ = reply.send(Err(ConnectionError::Closed));
292                    }
293                    ControlCommand::CloseConnection(reply) => {
294                        let _ = reply.send(());
295                    }
296                }
297            }
298        }
299
300        self.drop_all_streams().await;
301        // Close and drain the stream command receiver.
302        if !self.stream_receiver.is_terminated() {
303            self.stream_receiver.close();
304            while self.stream_receiver.next().await.is_some() {
305                while let Some(cmd) = self.stream_receiver.next().await {
306                    if let StreamCommand::ResetStream(_, reply) = cmd {
307                        let _ = reply.send(());
308                    }
309                    // drop it
310                    log::debug!("drop stream receiver frame");
311                }
312            }
313        }
314
315        result
316    }
317
318    /// This is called from `Connection::next_stream` instead of being a
319    /// public method itself in order to guarantee proper closing in
320    /// case of an error or at EOF.
321    pub async fn handle_coming(&mut self) -> Result<()> {
322        loop {
323            select! {
324                // handle incoming
325                frame = self.reader.next() => {
326                    if let Some(f) = frame {
327                        let frame = f?;
328                        self.on_frame(frame).await?;
329                    }
330                }
331                // handle outcoming
332                scmd = self.stream_receiver.next() => {
333                    self.on_stream_command(scmd).await?;
334                }
335                ccmd = self.control_receiver.next() => {
336                    self.on_control_command(ccmd).await?;
337                }
338            }
339        }
340    }
341
342    /// Process the result of reading from the socket.
343    ///
344    /// Unless `frame` is `Ok(()))` we will assume the connection got closed
345    /// and return a corresponding error, which terminates the connection.
346    /// Otherwise we process the frame
347    async fn on_frame(&mut self, frame: Frame) -> Result<()> {
348        log::trace!("{}: received: {}", self.id, frame.header());
349        match frame.header().tag() {
350            Tag::NewStream => {
351                let stream_id = frame.header().stream_id();
352                if self.streams_stat.contains_key(&stream_id) {
353                    log::error!("received NewStream message for existing stream: {}", stream_id);
354                    return Err(ConnectionError::Io(std::io::ErrorKind::InvalidData.into()));
355                }
356
357                let (stream_sender, stream_receiver) = mpsc::channel(MAX_COMMAND_BACKLOG);
358                self.streams.insert(stream_id, stream_sender);
359                self.streams_stat.insert(stream_id, State::Open);
360
361                let stream = Stream::new(stream_id, self.id, self.stream_sender.clone(), stream_receiver);
362
363                log::debug!("{}: new inbound {} of {}", self.id, stream, self);
364                if let Some(sender) = self.waiting_stream_sender.take() {
365                    sender.send(Ok(stream)).expect("send err");
366                } else {
367                    self.pending_streams.push_back(stream);
368                }
369            }
370            Tag::Message => {
371                let stream_id = frame.header().stream_id();
372                if let Some(stat) = self.streams_stat.get(&stream_id) {
373                    // if remote had close stream, ingore this stream's frame
374                    if *stat == State::RecvClosed {
375                        return Ok(());
376                    }
377                } else {
378                    return Ok(());
379                }
380
381                let mut reset = false;
382                let mut dropped = false;
383                // If stream is closed, ignore frame
384                if let Some(sender) = self.streams.get_mut(&stream_id) {
385                    if !sender.is_closed() {
386                        let sender = sender.send(frame.body());
387                        if send_channel_timeout(sender, RECEIVE_TIMEOUT).await.is_err() {
388                            // reset stream
389                            log::debug!("stream {} send timeout, Reset it", stream_id);
390                            reset = true;
391                            // info.sender.close().await;
392                            let frame = Frame::reset_frame(stream_id);
393                            self.writer.send_frame(&frame).await.or(Err(ConnectionError::Closed))?;
394                        }
395                    } else {
396                        dropped = true;
397                    }
398                }
399                // If the stream is dropped, remove sender from streams
400                if dropped {
401                    self.streams.remove(&stream_id);
402                }
403                if reset {
404                    self.streams.remove(&stream_id);
405                    self.streams_stat.remove(&stream_id);
406                }
407            }
408            Tag::Close => {
409                let stream_id = frame.header().stream_id();
410                log::debug!("{}: remote close stream {} of {}", self.id, stream_id, self);
411                self.streams.remove(&stream_id);
412                // flag to remove stat from streams_stat
413                let mut rm = false;
414                if let Some(stat) = self.streams_stat.get_mut(&stream_id) {
415                    if *stat == State::SendClosed {
416                        rm = true;
417                    } else {
418                        *stat = State::RecvClosed;
419                    }
420                }
421                // If stream is completely closed, remove it
422                if rm {
423                    self.streams_stat.remove(&stream_id);
424                }
425            }
426            Tag::Reset => {
427                let stream_id = frame.header().stream_id();
428                log::trace!("{}: remote reset stream {} of {}", self.id, stream_id, self);
429                self.streams_stat.remove(&stream_id);
430                self.streams.remove(&stream_id);
431            }
432        };
433
434        Ok(())
435    }
436
437    /// Process a command from one of our `Stream`s.
438    async fn on_stream_command(&mut self, cmd: Option<StreamCommand>) -> Result<()> {
439        match cmd {
440            Some(StreamCommand::SendFrame(frame)) => {
441                let stream_id = frame.stream_id();
442                if let Some(stat) = self.streams_stat.get(&stream_id) {
443                    if stat.can_write() {
444                        log::trace!("{}: sending: {}", self.id, frame.header());
445                        self.writer.send_frame(&frame).await.or(Err(ConnectionError::Closed))?;
446                    } else {
447                        log::trace!("{}: stream {} have been removed", self.id, stream_id);
448                    }
449                }
450            }
451            Some(StreamCommand::CloseStream(frame)) => {
452                let stream_id = frame.stream_id();
453                log::debug!("{}: closing stream {} of {}", self.id, stream_id, self);
454                // flag to remove stat from streams_stat
455                let mut rm = false;
456                if let Some(stat) = self.streams_stat.get_mut(&stream_id) {
457                    if stat.can_write() {
458                        // send close frame
459                        self.writer.send_frame(&frame).await.or(Err(ConnectionError::Closed))?;
460
461                        if *stat == State::RecvClosed {
462                            rm = true;
463                        } else {
464                            *stat = State::SendClosed;
465                        }
466                    }
467                }
468                // If stream is completely closed, remove it
469                if rm {
470                    self.streams_stat.remove(&stream_id);
471                }
472            }
473            Some(StreamCommand::ResetStream(frame, reply)) => {
474                let stream_id = frame.stream_id();
475                log::debug!("{}: reset stream {} of {}", self.id, stream_id, self);
476                if self.streams_stat.contains_key(&stream_id) {
477                    // step1: send close frame
478                    self.writer.send_frame(&frame).await.or(Err(ConnectionError::Closed))?;
479
480                    // step2: remove stream
481                    self.streams_stat.remove(&stream_id);
482                    self.streams.remove(&stream_id);
483                }
484                let _ = reply.send(());
485            }
486            None => {
487                // We only get to this point when `self.stream_receiver`
488                // was closed which only happens in response to a close control
489                // command. Now that we are at the end of the stream command queue,
490                // we send the final term frame to the remote and complete the
491                // closure.
492                debug_assert!(self.control_receiver.is_paused());
493                self.control_receiver.unpause();
494                self.control_receiver.stream().close();
495            }
496        }
497        Ok(())
498    }
499
500    /// Process a command from a `Control`.
501    ///
502    /// We only process control commands if we are not in the process of closing
503    /// the connection. Only once we finished closing will we drain the remaining
504    /// commands and reply back that we are closed.
505    async fn on_control_command(&mut self, cmd: Option<ControlCommand>) -> Result<()> {
506        match cmd {
507            Some(ControlCommand::OpenStream(reply)) => {
508                if self.shutdown.is_complete() {
509                    // We are already closed so just inform the control.
510                    let _ = reply.send(Err(ConnectionError::Closed));
511                    return Ok(());
512                }
513
514                let stream_id = self.next_stream_id()?;
515                let (stream_sender, stream_receiver) = mpsc::channel(MAX_COMMAND_BACKLOG);
516                self.streams.insert(stream_id, stream_sender);
517                self.streams_stat.insert(stream_id, State::Open);
518
519                log::debug!("{}: new outbound {} of {}", self.id, stream_id, self);
520
521                // send to peer with new stream frame
522                let body = format!("{}", stream_id.val());
523                let frame = Frame::new_stream_frame(stream_id, body.as_bytes());
524                self.writer.send_frame(&frame).await.or(Err(ConnectionError::Closed))?;
525
526                let stream = Stream::new(stream_id, self.id, self.stream_sender.clone(), stream_receiver);
527                reply.send(Ok(stream)).expect("send err");
528            }
529            Some(ControlCommand::AcceptStream(reply)) => {
530                if self.waiting_stream_sender.is_some() {
531                    reply.send(Err(ConnectionError::Closed)).expect("send err");
532                    return Ok(());
533                }
534
535                if let Some(stream) = self.pending_streams.pop_front() {
536                    reply.send(Ok(stream)).expect("send err");
537                } else {
538                    self.waiting_stream_sender = Some(reply);
539                }
540            }
541            Some(ControlCommand::CloseConnection(reply)) => {
542                if !self.shutdown.has_not_started() {
543                    log::debug!("shutdown had started, ingore this request");
544                    let _ = reply.send(());
545                    return Ok(());
546                }
547                self.shutdown = Shutdown::InProgress(reply);
548                log::debug!("closing connection {}", self);
549                self.stream_receiver.close();
550                self.control_receiver.pause();
551            }
552            None => {
553                // We only get here after the whole connection shutdown is complete.
554                // No further processing of commands of any kind or incoming frames
555                // will happen.
556                debug_assert!(self.shutdown.is_in_progress());
557                log::debug!("{}: closing {}", self.id, self);
558
559                let shutdown = std::mem::replace(&mut self.shutdown, Shutdown::Complete);
560                if let Shutdown::InProgress(tx) = shutdown {
561                    // Inform the `Control` that initiated the shutdown.
562                    let _ = tx.send(());
563                }
564                self.writer.close().await.or(Err(ConnectionError::Closed))?;
565
566                return Err(ConnectionError::Closed);
567            }
568        }
569        Ok(())
570    }
571}
572
573async fn send_channel_timeout<F>(future: F, timeout: Duration) -> std::io::Result<()>
574where
575    F: Future + Unpin,
576{
577    let output = select(future, Delay::new(timeout)).await;
578    match output {
579        Either::Left((_, _)) => Ok(()),
580        Either::Right(_) => Err(std::io::ErrorKind::TimedOut.into()),
581    }
582}
583
584impl<T> fmt::Display for Connection<T> {
585    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
586        write!(f, "(Connection {} (streams {}))", self.id, self.streams.len())
587    }
588}
589
590impl<T> Connection<T> {
591    // next_stream_id is only used to get stream id when open stream
592    fn next_stream_id(&mut self) -> Result<StreamID> {
593        let proposed = StreamID::new(self.next_stream_id, true);
594        self.next_stream_id = self.next_stream_id.checked_add(1).ok_or(ConnectionError::NoMoreStreamIds)?;
595
596        Ok(proposed)
597    }
598
599    pub fn streams_length(&self) -> usize {
600        self.streams_stat.len()
601    }
602
603    /// Close and drop all `Stream`s sender and stat.
604    async fn drop_all_streams(&mut self) {
605        log::trace!("{}: Drop all Streams sender count={}", self.id, self.streams.len());
606        for (id, _sender) in self.streams.drain().take(1) {
607            // drop it
608            log::trace!("{}: drop stream sender {:?}", self.id, id);
609        }
610
611        log::trace!("{}: Drop all Streams stat count={}", self.id, self.streams.len());
612        for (id, _stat) in self.streams_stat.drain().take(1) {
613            // drop it
614            log::trace!("{}: drop stream stat {:?}", self.id, id);
615        }
616    }
617}