// spectacles_gateway/shard.rs

1use std::{
2    io::{Error as IoError, ErrorKind},
3    str::FromStr,
4    sync::Arc,
5    time::{Duration, Instant}
6};
7
8use futures::{
9    future::Future,
10    Sink,
11    stream::{SplitStream, Stream},
12    sync::mpsc::{self, UnboundedSender}
13};
14use native_tls::TlsConnector;
15use parking_lot::Mutex;
16use tokio::net::TcpStream as TokioTcpStream;
17use tokio::timer::Interval;
18use tokio_dns::TcpStream;
19use tokio_tls::TlsStream;
20use tokio_tungstenite::{
21    tungstenite::{
22        Error as TungsteniteError,
23        handshake::client::Request,
24        protocol::{Message as WebsocketMessage, WebSocketConfig},
25    },
26    WebSocketStream
27};
28use tokio_tungstenite::stream::Stream as TungsteniteStream;
29use url::Url;
30
31use spectacles_model::{
32    gateway::{
33        GatewayEvent,
34        HeartbeatPacket,
35        HelloPacket,
36        IdentifyPacket,
37        IdentifyProperties,
38        Opcodes,
39        ReadyPacket,
40        ReceivePacket,
41        ResumeSessionPacket,
42        SendablePacket,
43    },
44    presence::{ClientActivity, ClientPresence, Status}
45};
46
47use crate::{
48    constants::GATEWAY_VERSION,
49    errors::{Error, Result}
50};
51
52pub type ShardSplitStream = SplitStream<WebSocketStream<TungsteniteStream<TokioTcpStream, TlsStream<TokioTcpStream>>>>;
53
54/// A Spectacles Gateway shard.
55#[derive(Clone)]
56pub struct Shard {
57    /// The bot token that this shard will use.
58    pub token: String,
59    /// The shard's info. Includes the shard's ID and the total amount of shards.
60    pub info: [usize; 2],
61    /// The currently active presence for this shard.
62    pub presence: ClientPresence,
63    /// The session ID of this shard, if applicable.
64    pub session_id: Option<String>,
65    /// The interval at which a heartbeat is made.
66    pub interval: Option<u64>,
67    /// The channel which is used to send websocket messages.
68    pub sender: Arc<Mutex<UnboundedSender<WebsocketMessage>>>,
69    /// The shard's message stream, which is used to receive messages.
70    pub stream: Arc<Mutex<Option<ShardSplitStream>>>,
71    /// Used to determine whether or not the shard is currently in a state of connecting.
72    current_state: Arc<Mutex<String>>,
73    /// This shard's current heartbeat.
74    pub heartbeat: Arc<Mutex<Heartbeat>>,
75    /// The URL of the Discord Gateway.
76    ws_uri: String
77}
78
/// Various actions that a shard can perform.
///
/// Returned by `Shard::fulfill_gateway` to tell the caller how to react to a gateway packet.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ShardAction {
    /// Nothing needs to be done for this packet.
    NoneAction,
    /// The shard should reconnect, resuming when possible (presumably via `Shard::autoreconnect`).
    Autoreconnect,
    /// The gateway sent a Reconnect opcode.
    Reconnect,
    /// The shard should send an identify payload.
    Identify,
    /// The shard should resume its previous session.
    Resume,
}
/// A shard's heartbeat information.
#[derive(Debug, Copy, Clone, Default, PartialEq, Eq)]
pub struct Heartbeat {
    /// Set to `true` when the gateway sends a HeartbeatAck, and when the shard is reset.
    pub acknowledged: bool,
    /// Sequence number included in heartbeat and resume payloads.
    pub seq: u64,
}

impl Heartbeat {
    /// Creates an unacknowledged heartbeat with a sequence number of zero.
    fn new() -> Heartbeat {
        Self::default()
    }
}
102
103impl Shard {
104    /// Creates a new Discord Shard, with the provided token.
105    pub fn new(token: String, info: [usize; 2], ws_uri: String) -> impl Future<Item=Shard, Error=Error> {
106        Shard::begin_connection(&ws_uri, info[0])
107            .map(move |(sender, stream)| {
108                Shard {
109                    token,
110                    session_id: None,
111                    presence: ClientPresence {
112                        status: String::from("online"),
113                        ..Default::default()
114                    },
115                    info,
116                    interval: None,
117                    sender: Arc::new(Mutex::new(sender)),
118                    current_state: Arc::new(Mutex::new(String::from("handshake"))),
119                    stream: Arc::new(Mutex::new(Some(stream))),
120                    heartbeat: Arc::new(Mutex::new(Heartbeat::new())),
121                    ws_uri
122                }
123            })
124    }
125
126    pub fn fulfill_gateway(&mut self, packet: ReceivePacket) -> Result<ShardAction> {
127        let info = self.info.clone();
128        let current_state = self.current_state.lock().clone();
129        match packet.op {
130            Opcodes::Dispatch => {
131                if let Some(GatewayEvent::READY) = packet.t {
132                    let ready: ReadyPacket = serde_json::from_str(packet.d.get()).unwrap();
133                    *self.current_state.lock() = "connected".to_string();
134                    self.session_id = Some(ready.session_id.clone());
135                    trace!("[Shard {}] Received ready, set session ID as {}", &info[0], ready.session_id)
136                };
137                Ok(ShardAction::NoneAction)
138            }
139            Opcodes::Hello => {
140                if self.current_state.lock().clone() == "resume".to_string() {
141                    return Ok(ShardAction::NoneAction)
142                };
143                let hello: HelloPacket = serde_json::from_str(packet.d.get()).unwrap();
144                if hello.heartbeat_interval > 0 {
145                    self.interval = Some(hello.heartbeat_interval);
146                }
147                if current_state == "handshake".to_string() {
148                    let dur = Duration::from_millis(hello.heartbeat_interval);
149                    tokio::spawn(Shard::begin_interval(self.clone(), dur));
150                    return Ok(ShardAction::Identify);
151                }
152                Ok(ShardAction::Autoreconnect)
153            },
154            Opcodes::HeartbeatAck => {
155                let mut hb = self.heartbeat.lock().clone();
156                hb.acknowledged = true;
157                Ok(ShardAction::NoneAction)
158            },
159            Opcodes::Reconnect => Ok(ShardAction::Reconnect),
160            Opcodes::InvalidSession => {
161                let invalid: bool = serde_json::from_str(packet.d.get()).unwrap();
162                if !invalid {
163                    Ok(ShardAction::Identify)
164                } else { Ok(ShardAction::Resume) }
165            },
166            _ => Ok(ShardAction::NoneAction)
167        }
168    }
169
170    /// Identifies a shard with Discord.
171    pub fn identify(&mut self) -> Result<()> {
172        let token = self.token.clone();
173        let shard = self.info.clone();
174        let presence = self.presence.clone();
175        self.send_payload(IdentifyPacket {
176            large_threshold: 250,
177            token,
178            shard,
179            compress: false,
180            presence: Some(presence),
181            version: GATEWAY_VERSION,
182            properties: IdentifyProperties {
183                os: std::env::consts::OS.to_string(),
184                browser: String::from("spectacles-rs"),
185                device: String::from("spectacles-rs")
186            }
187        })
188    }
189
190    /// Attempts to automatically reconnect the shard to Discord.
191    pub fn autoreconnect(&mut self) -> Box<Future<Item = (), Error = Error> + Send>{
192        if self.session_id.is_some() && self.heartbeat.lock().seq > 0 {
193            Box::new(self.resume())
194        } else {
195            Box::new(self.reconnect())
196        }
197    }
198
199    /// Makes a request to reconnect the shard.
200    pub fn reconnect(&mut self) -> impl Future<Item = (), Error = Error> + Send {
201        debug!("[Shard {}] Attempting to reconnect to gateway.", &self.info[0]);
202        self.reset_values().expect("[Shard] Failed to reset this shard for autoreconnecting.");
203        self.dial_gateway()
204    }
205
206    /// Resumes a shard's past session.
207    pub fn resume(&mut self) -> impl Future<Item = (), Error = Error> + Send {
208        debug!("[Shard {}] Attempting to resume gateway connection.", &self.info[0]);
209        let seq = self.heartbeat.lock().seq;
210        let token = self.token.clone();
211        let state = self.current_state.clone();
212        let session = self.session_id.clone();
213        let sender = self.sender.clone();
214
215        self.dial_gateway().then(move |result|{
216            if result.is_err() { return result };
217            *state.lock() = "resuming".to_string();
218            let payload = ResumeSessionPacket {
219                session_id: session.unwrap(),
220                seq,
221                token
222            };
223
224            send(&sender, WebsocketMessage::text(payload.to_json()?))
225        })
226    }
227    /// Resolves a Websocket message into a ReceivePacket struct.
228    pub fn resolve_packet(&self, mess: &WebsocketMessage) -> Result<ReceivePacket> {
229        match mess {
230            WebsocketMessage::Binary(v) => serde_json::from_slice(v),
231            WebsocketMessage::Text(v) => serde_json::from_str(v),
232            _ => unreachable!("Invalid type detected."),
233        }.map_err(Error::from)
234    }
235
236    /// Sends a payload to the Discord Gateway.
237    pub fn send_payload<T: SendablePacket>(&self, payload: T) -> Result<()> {
238        let json = payload.to_json()?;
239        send(&self.sender, WebsocketMessage::text(json))
240    }
241
242
243    /// Change the status of the current shard.
244    pub fn change_status(&mut self, status: Status) -> Result<()> {
245        self.presence.status = status.to_string();
246        let oldpresence = self.presence.clone();
247        self.change_presence(oldpresence)
248    }
249
250    /// Change the activity of the current shard.
251    pub fn change_activity(&mut self, activity: ClientActivity) -> Result<()> {
252        self.presence.game = Some(activity);
253        let oldpresence = self.presence.clone();
254        self.change_presence(oldpresence)
255    }
256
257    /// Change the presence of the current shard.
258    pub fn change_presence(&mut self, presence: ClientPresence) -> Result<()> {
259        debug!("[Shard {}] Sending a presence change payload. {:?}", self.info[0], presence.clone());
260        self.send_payload(presence.clone())?;
261        self.presence = presence;
262        Ok(())
263    }
264
265    fn reset_values(&mut self) -> Result<()> {
266        self.session_id = None;
267        *self.current_state.lock() = "disconnected".to_string();
268
269        let mut hb = self.heartbeat.lock();
270        hb.acknowledged = true;
271        hb.seq = 0;
272
273        Ok(())
274    }
275
276    fn heartbeat(&mut self) -> Result<()> {
277        debug!("[Shard {}] Sending heartbeat.", self.info[0]);
278        let seq = self.heartbeat.lock().seq;
279
280        self.send_payload(HeartbeatPacket { seq })
281    }
282
283    fn dial_gateway(&mut self) -> impl Future<Item = (), Error = Error> + Send {
284        let info = self.info.clone();
285        *self.current_state.lock() = String::from("connected");
286        let state = self.current_state.clone();
287        let orig_sender = self.sender.clone();
288        let orig_stream = self.stream.clone();
289        let heartbeat = self.heartbeat.clone();
290
291        Shard::begin_connection(&self.ws_uri, info[0])
292            .map(move |(sender, stream)| {
293                *orig_sender.lock() = sender;
294                *heartbeat.lock() = Heartbeat::new();
295                *state.lock() = String::from("handshake");
296                *orig_stream.lock() = Some(stream);
297            })
298    }
299
300
301    fn begin_interval(mut shard: Shard, duration: Duration) -> impl Future<Item = (), Error = ()> {
302        let info = shard.info.clone();
303        Interval::new(Instant::now(), duration)
304            .map_err(move |err| {
305                warn!("[Shard {}] Failed to begin heartbeat interval. {:?}", info[0], err);
306            })
307            .for_each(move |_| {
308                if let Err(r) = shard.heartbeat() {
309                    warn!("[Shard {}] Failed to perform heartbeat. {:?}", info[0], r);
310                    return Err(());
311                }
312                Ok(())
313            })
314    }
315
316    fn begin_connection(ws: &str, shard_id: usize) -> impl Future<Item = (UnboundedSender<WebsocketMessage>, ShardSplitStream), Error = Error> {
317        let url = Url::from_str(ws).expect("Invalid Websocket URL has been provided.");
318        let req = Request::from(url);
319        let (host, port) = Shard::get_addr_info(&req);
320        let tlsconn = TlsConnector::new().unwrap();
321        let tlsconn = tokio_tls::TlsConnector::from(tlsconn);
322
323        let socket = TcpStream::connect((host.as_ref(), port));
324        let handshake = socket.and_then(move |socket| {
325            debug!("[Shard {}] Beginning handshake with gateway.", shard_id);
326            tlsconn.connect(host.as_ref(), socket)
327                .map(|s| TungsteniteStream::Tls(s))
328                .map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e))
329        });
330        let stream = handshake.and_then(|mut stream| {
331            tokio_tungstenite::stream::NoDelay::set_nodelay(&mut stream, true)
332                .map(move |()| stream)
333        });
334        let stream = stream.and_then(move |stream| {
335            tokio_tungstenite::client_async_with_config(req, stream, Some(WebSocketConfig {
336                max_message_size: Some(usize::max_value()),
337                max_frame_size: Some(usize::max_value()),
338                ..Default::default()
339            })).map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e))
340        });
341
342        stream.map(move |(wstream, _)| {
343            let (tx, rx) = mpsc::unbounded();
344            let (sink, stream) = wstream.split();
345            tokio::spawn(rx.map_err(|err| {
346                error!("Failed to select sink. {:?}", err);
347                TungsteniteError::Io(IoError::new(ErrorKind::Other, "Error whilst attempting to select sink."))
348            }).forward(sink).map(|_| ()).map_err(|_| ()));
349
350            (tx, stream)
351        }).from_err()
352    }
353
354    fn get_addr_info(req: &Request) -> (String, u16) {
355        let host = req.url.host_str().expect("Could Not parse the Websocket Host.");
356        let port = req.url.port_or_known_default().expect("Could not parse the websocket port.");
357
358        (host.to_string(), port)
359    }
360}
361
362fn send(sender: &Arc<Mutex<UnboundedSender<WebsocketMessage>>>, mess: WebsocketMessage) -> Result<()> {
363    sender.lock().start_send(mess)
364        .map(|_| ())
365        .map_err(From::from)
366}