1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
use super::super::ShardStream;
use futures_util::{
    future::{self, Either, FutureExt},
    sink::SinkExt,
    stream::StreamExt,
};
use std::time::Duration;
use tokio::{
    sync::mpsc::{self, UnboundedReceiver, UnboundedSender},
    time::sleep,
};
use tokio_tungstenite::tungstenite::Message;

/// Shuttles [`Message`]s between a [`ShardStream`] and a pair of unbounded
/// channels.
///
/// Created by [`SocketForwarder::new`], which also hands back the channel ends
/// meant for the other side, and driven to completion by
/// [`SocketForwarder::run`].
pub struct SocketForwarder {
    /// Receiving end of the channel carrying messages that `run` writes to
    /// `stream`.
    rx: UnboundedReceiver<Message>,
    /// The underlying stream: `run` reads messages from it and writes
    /// messages to it.
    pub stream: ShardStream,
    /// Sending end of the channel onto which `run` pushes every message read
    /// from `stream`.
    tx: UnboundedSender<Message>,
}

impl SocketForwarder {
    /// How long one pass of the driving loop waits for a message in either
    /// direction before the forwarder gives up.
    const TIMEOUT: Duration = Duration::from_secs(90);

    /// Wraps `stream` in a forwarder and creates the two unbounded channels
    /// connecting it to its user.
    ///
    /// Returns, in order:
    ///
    /// 1. the forwarder itself;
    /// 2. the receiver on which messages read from `stream` are delivered;
    /// 3. the sender used to queue messages to be written to `stream`.
    pub fn new(
        stream: ShardStream,
    ) -> (Self, UnboundedReceiver<Message>, UnboundedSender<Message>) {
        // Stream -> user direction.
        let (to_user, from_forwarder) = mpsc::unbounded_channel();
        // User -> stream direction.
        let (to_forwarder, from_user) = mpsc::unbounded_channel();

        let forwarder = Self {
            rx: from_user,
            stream,
            tx: to_user,
        };

        (forwarder, from_forwarder, to_forwarder)
    }

    /// Drives the forwarder until one side goes away, an error occurs, or
    /// nothing happens for [`Self::TIMEOUT`].
    ///
    /// Each pass of the loop races three futures: the next queued outgoing
    /// message, the next message from the stream, and a fresh timeout. The
    /// loop ends when:
    ///
    /// - the outgoing channel is closed (the stream is closed first);
    /// - writing to the stream fails;
    /// - the stream yields an error or ends;
    /// - the receiver for incoming messages has been dropped;
    /// - the timeout fires before either message future completes.
    pub async fn run(mut self) {
        tracing::debug!("starting driving loop");

        loop {
            // A new timer is armed on every pass, so any completed message in
            // either direction restarts the countdown.
            let idle_deadline = sleep(Self::TIMEOUT).fuse();
            tokio::pin!(idle_deadline);

            let outbound = Box::pin(self.rx.recv().fuse());
            let inbound = Box::pin(self.stream.next().fuse());

            match future::select(future::select(outbound, inbound), idle_deadline).await {
                // Neither direction produced anything in time.
                Either::Right(_) => {
                    tracing::warn!("socket timed out");
                    break;
                }
                // Outgoing channel closed: no more messages will ever be queued.
                Either::Left((Either::Left((None, _)), _)) => {
                    tracing::debug!("rx stream ended, closing socket");
                    // The result of closing is deliberately ignored.
                    let _close_result = self.stream.close(None).await;

                    break;
                }
                // A message was queued for sending: write it to the stream.
                Either::Left((Either::Left((Some(msg), _)), _)) => {
                    tracing::trace!("sending message: {}", msg);

                    match self.stream.send(msg).await {
                        Ok(()) => {}
                        Err(err) => {
                            tracing::warn!("sending failed: {}", err);
                            break;
                        }
                    }
                }
                // The stream produced a message: hand it over to the user.
                Either::Left((Either::Right((Some(Ok(msg)), _)), _)) => {
                    // A failed send means the receiving half is gone.
                    if self.tx.send(msg).is_err() {
                        break;
                    }
                }
                // The stream reported an error.
                Either::Left((Either::Right((Some(Err(err)), _)), _)) => {
                    tracing::warn!("socket errored: {}", err);
                    break;
                }
                // The stream is exhausted.
                Either::Left((Either::Right((None, _)), _)) => {
                    tracing::debug!("socket ended");
                    break;
                }
            }
        }

        tracing::debug!("Leaving loop");
    }
}