1use std::{
2 io::{Error as IoError, ErrorKind},
3 str::FromStr,
4 sync::Arc,
5 time::{Duration, Instant}
6};
7
8use futures::{
9 future::Future,
10 Sink,
11 stream::{SplitStream, Stream},
12 sync::mpsc::{self, UnboundedSender}
13};
14use native_tls::TlsConnector;
15use parking_lot::Mutex;
16use tokio::net::TcpStream as TokioTcpStream;
17use tokio::timer::Interval;
18use tokio_dns::TcpStream;
19use tokio_tls::TlsStream;
20use tokio_tungstenite::{
21 tungstenite::{
22 Error as TungsteniteError,
23 handshake::client::Request,
24 protocol::{Message as WebsocketMessage, WebSocketConfig},
25 },
26 WebSocketStream
27};
28use tokio_tungstenite::stream::Stream as TungsteniteStream;
29use url::Url;
30
31use spectacles_model::{
32 gateway::{
33 GatewayEvent,
34 HeartbeatPacket,
35 HelloPacket,
36 IdentifyPacket,
37 IdentifyProperties,
38 Opcodes,
39 ReadyPacket,
40 ReceivePacket,
41 ResumeSessionPacket,
42 SendablePacket,
43 },
44 presence::{ClientActivity, ClientPresence, Status}
45};
46
47use crate::{
48 constants::GATEWAY_VERSION,
49 errors::{Error, Result}
50};
51
52pub type ShardSplitStream = SplitStream<WebSocketStream<TungsteniteStream<TokioTcpStream, TlsStream<TokioTcpStream>>>>;
53
/// A single gateway shard: the handles for its websocket connection plus the session
/// state needed to identify, heartbeat and resume.
///
/// The sender, stream, connection state and heartbeat live behind `Arc<Mutex<_>>`, so
/// clones of a `Shard` share them. The remaining fields are cloned independently.
#[derive(Clone)]
pub struct Shard {
    /// Bot token sent in Identify and Resume payloads.
    pub token: String,
    /// Shard info. `info[0]` is the shard ID used in log messages, and the whole array is
    /// sent as `shard` in the Identify payload (presumably `[shard_id, shard_count]` — confirm).
    pub info: [usize; 2],
    /// Presence sent when identifying and updated by the `change_*` methods.
    pub presence: ClientPresence,
    /// Session ID taken from the READY dispatch; cleared when the shard is reset to reconnect.
    pub session_id: Option<String>,
    /// Heartbeat interval in milliseconds, taken from a Hello packet when it is non-zero.
    pub interval: Option<u64>,
    /// Sending half of the channel that feeds the websocket sink; replaced on each new connection.
    pub sender: Arc<Mutex<UnboundedSender<WebsocketMessage>>>,
    /// Read half of the current websocket connection; replaced on each new connection.
    // NOTE(review): nothing in this file takes the stream out of the `Option`; presumably the
    // caller does — confirm.
    pub stream: Arc<Mutex<Option<ShardSplitStream>>>,
    /// Connection state label (e.g. "handshake", "connected", "resuming", "disconnected").
    current_state: Arc<Mutex<String>>,
    /// Heartbeat bookkeeping, shared with the heartbeat interval task.
    pub heartbeat: Arc<Mutex<Heartbeat>>,
    /// Gateway websocket URL, reused whenever the shard dials the gateway again.
    ws_uri: String
}
78
/// Follow-up work requested by `Shard::fulfill_gateway` after it handles a gateway packet.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ShardAction {
    /// Nothing further needs to be done.
    NoneAction,
    /// A Hello arrived outside the initial handshake. Presumably the caller responds by
    /// invoking `Shard::autoreconnect` — confirm with the caller.
    Autoreconnect,
    /// The gateway sent a Reconnect opcode.
    Reconnect,
    /// The shard should send an Identify payload (see `Shard::identify`).
    Identify,
    /// The gateway reported an invalid session whose payload was `true`, so a resume
    /// should be attempted.
    Resume,
}
/// Heartbeat bookkeeping shared between a shard and its heartbeat task.
///
/// `Default` yields the same value as a freshly created heartbeat: not acknowledged,
/// sequence `0`.
#[derive(Debug, Copy, Clone, Default, PartialEq, Eq)]
pub struct Heartbeat {
    /// Intended to record whether the gateway has acknowledged the most recent heartbeat.
    pub acknowledged: bool,
    /// Sequence number included in outgoing heartbeat and resume payloads.
    pub seq: u64,
}
93
94impl Heartbeat {
95 fn new() -> Heartbeat {
96 Self {
97 acknowledged: false,
98 seq: 0
99 }
100 }
101}
102
103impl Shard {
104 pub fn new(token: String, info: [usize; 2], ws_uri: String) -> impl Future<Item=Shard, Error=Error> {
106 Shard::begin_connection(&ws_uri, info[0])
107 .map(move |(sender, stream)| {
108 Shard {
109 token,
110 session_id: None,
111 presence: ClientPresence {
112 status: String::from("online"),
113 ..Default::default()
114 },
115 info,
116 interval: None,
117 sender: Arc::new(Mutex::new(sender)),
118 current_state: Arc::new(Mutex::new(String::from("handshake"))),
119 stream: Arc::new(Mutex::new(Some(stream))),
120 heartbeat: Arc::new(Mutex::new(Heartbeat::new())),
121 ws_uri
122 }
123 })
124 }
125
126 pub fn fulfill_gateway(&mut self, packet: ReceivePacket) -> Result<ShardAction> {
127 let info = self.info.clone();
128 let current_state = self.current_state.lock().clone();
129 match packet.op {
130 Opcodes::Dispatch => {
131 if let Some(GatewayEvent::READY) = packet.t {
132 let ready: ReadyPacket = serde_json::from_str(packet.d.get()).unwrap();
133 *self.current_state.lock() = "connected".to_string();
134 self.session_id = Some(ready.session_id.clone());
135 trace!("[Shard {}] Received ready, set session ID as {}", &info[0], ready.session_id)
136 };
137 Ok(ShardAction::NoneAction)
138 }
139 Opcodes::Hello => {
140 if self.current_state.lock().clone() == "resume".to_string() {
141 return Ok(ShardAction::NoneAction)
142 };
143 let hello: HelloPacket = serde_json::from_str(packet.d.get()).unwrap();
144 if hello.heartbeat_interval > 0 {
145 self.interval = Some(hello.heartbeat_interval);
146 }
147 if current_state == "handshake".to_string() {
148 let dur = Duration::from_millis(hello.heartbeat_interval);
149 tokio::spawn(Shard::begin_interval(self.clone(), dur));
150 return Ok(ShardAction::Identify);
151 }
152 Ok(ShardAction::Autoreconnect)
153 },
154 Opcodes::HeartbeatAck => {
155 let mut hb = self.heartbeat.lock().clone();
156 hb.acknowledged = true;
157 Ok(ShardAction::NoneAction)
158 },
159 Opcodes::Reconnect => Ok(ShardAction::Reconnect),
160 Opcodes::InvalidSession => {
161 let invalid: bool = serde_json::from_str(packet.d.get()).unwrap();
162 if !invalid {
163 Ok(ShardAction::Identify)
164 } else { Ok(ShardAction::Resume) }
165 },
166 _ => Ok(ShardAction::NoneAction)
167 }
168 }
169
170 pub fn identify(&mut self) -> Result<()> {
172 let token = self.token.clone();
173 let shard = self.info.clone();
174 let presence = self.presence.clone();
175 self.send_payload(IdentifyPacket {
176 large_threshold: 250,
177 token,
178 shard,
179 compress: false,
180 presence: Some(presence),
181 version: GATEWAY_VERSION,
182 properties: IdentifyProperties {
183 os: std::env::consts::OS.to_string(),
184 browser: String::from("spectacles-rs"),
185 device: String::from("spectacles-rs")
186 }
187 })
188 }
189
190 pub fn autoreconnect(&mut self) -> Box<Future<Item = (), Error = Error> + Send>{
192 if self.session_id.is_some() && self.heartbeat.lock().seq > 0 {
193 Box::new(self.resume())
194 } else {
195 Box::new(self.reconnect())
196 }
197 }
198
199 pub fn reconnect(&mut self) -> impl Future<Item = (), Error = Error> + Send {
201 debug!("[Shard {}] Attempting to reconnect to gateway.", &self.info[0]);
202 self.reset_values().expect("[Shard] Failed to reset this shard for autoreconnecting.");
203 self.dial_gateway()
204 }
205
206 pub fn resume(&mut self) -> impl Future<Item = (), Error = Error> + Send {
208 debug!("[Shard {}] Attempting to resume gateway connection.", &self.info[0]);
209 let seq = self.heartbeat.lock().seq;
210 let token = self.token.clone();
211 let state = self.current_state.clone();
212 let session = self.session_id.clone();
213 let sender = self.sender.clone();
214
215 self.dial_gateway().then(move |result|{
216 if result.is_err() { return result };
217 *state.lock() = "resuming".to_string();
218 let payload = ResumeSessionPacket {
219 session_id: session.unwrap(),
220 seq,
221 token
222 };
223
224 send(&sender, WebsocketMessage::text(payload.to_json()?))
225 })
226 }
227 pub fn resolve_packet(&self, mess: &WebsocketMessage) -> Result<ReceivePacket> {
229 match mess {
230 WebsocketMessage::Binary(v) => serde_json::from_slice(v),
231 WebsocketMessage::Text(v) => serde_json::from_str(v),
232 _ => unreachable!("Invalid type detected."),
233 }.map_err(Error::from)
234 }
235
236 pub fn send_payload<T: SendablePacket>(&self, payload: T) -> Result<()> {
238 let json = payload.to_json()?;
239 send(&self.sender, WebsocketMessage::text(json))
240 }
241
242
243 pub fn change_status(&mut self, status: Status) -> Result<()> {
245 self.presence.status = status.to_string();
246 let oldpresence = self.presence.clone();
247 self.change_presence(oldpresence)
248 }
249
250 pub fn change_activity(&mut self, activity: ClientActivity) -> Result<()> {
252 self.presence.game = Some(activity);
253 let oldpresence = self.presence.clone();
254 self.change_presence(oldpresence)
255 }
256
257 pub fn change_presence(&mut self, presence: ClientPresence) -> Result<()> {
259 debug!("[Shard {}] Sending a presence change payload. {:?}", self.info[0], presence.clone());
260 self.send_payload(presence.clone())?;
261 self.presence = presence;
262 Ok(())
263 }
264
265 fn reset_values(&mut self) -> Result<()> {
266 self.session_id = None;
267 *self.current_state.lock() = "disconnected".to_string();
268
269 let mut hb = self.heartbeat.lock();
270 hb.acknowledged = true;
271 hb.seq = 0;
272
273 Ok(())
274 }
275
276 fn heartbeat(&mut self) -> Result<()> {
277 debug!("[Shard {}] Sending heartbeat.", self.info[0]);
278 let seq = self.heartbeat.lock().seq;
279
280 self.send_payload(HeartbeatPacket { seq })
281 }
282
283 fn dial_gateway(&mut self) -> impl Future<Item = (), Error = Error> + Send {
284 let info = self.info.clone();
285 *self.current_state.lock() = String::from("connected");
286 let state = self.current_state.clone();
287 let orig_sender = self.sender.clone();
288 let orig_stream = self.stream.clone();
289 let heartbeat = self.heartbeat.clone();
290
291 Shard::begin_connection(&self.ws_uri, info[0])
292 .map(move |(sender, stream)| {
293 *orig_sender.lock() = sender;
294 *heartbeat.lock() = Heartbeat::new();
295 *state.lock() = String::from("handshake");
296 *orig_stream.lock() = Some(stream);
297 })
298 }
299
300
301 fn begin_interval(mut shard: Shard, duration: Duration) -> impl Future<Item = (), Error = ()> {
302 let info = shard.info.clone();
303 Interval::new(Instant::now(), duration)
304 .map_err(move |err| {
305 warn!("[Shard {}] Failed to begin heartbeat interval. {:?}", info[0], err);
306 })
307 .for_each(move |_| {
308 if let Err(r) = shard.heartbeat() {
309 warn!("[Shard {}] Failed to perform heartbeat. {:?}", info[0], r);
310 return Err(());
311 }
312 Ok(())
313 })
314 }
315
316 fn begin_connection(ws: &str, shard_id: usize) -> impl Future<Item = (UnboundedSender<WebsocketMessage>, ShardSplitStream), Error = Error> {
317 let url = Url::from_str(ws).expect("Invalid Websocket URL has been provided.");
318 let req = Request::from(url);
319 let (host, port) = Shard::get_addr_info(&req);
320 let tlsconn = TlsConnector::new().unwrap();
321 let tlsconn = tokio_tls::TlsConnector::from(tlsconn);
322
323 let socket = TcpStream::connect((host.as_ref(), port));
324 let handshake = socket.and_then(move |socket| {
325 debug!("[Shard {}] Beginning handshake with gateway.", shard_id);
326 tlsconn.connect(host.as_ref(), socket)
327 .map(|s| TungsteniteStream::Tls(s))
328 .map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e))
329 });
330 let stream = handshake.and_then(|mut stream| {
331 tokio_tungstenite::stream::NoDelay::set_nodelay(&mut stream, true)
332 .map(move |()| stream)
333 });
334 let stream = stream.and_then(move |stream| {
335 tokio_tungstenite::client_async_with_config(req, stream, Some(WebSocketConfig {
336 max_message_size: Some(usize::max_value()),
337 max_frame_size: Some(usize::max_value()),
338 ..Default::default()
339 })).map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e))
340 });
341
342 stream.map(move |(wstream, _)| {
343 let (tx, rx) = mpsc::unbounded();
344 let (sink, stream) = wstream.split();
345 tokio::spawn(rx.map_err(|err| {
346 error!("Failed to select sink. {:?}", err);
347 TungsteniteError::Io(IoError::new(ErrorKind::Other, "Error whilst attempting to select sink."))
348 }).forward(sink).map(|_| ()).map_err(|_| ()));
349
350 (tx, stream)
351 }).from_err()
352 }
353
354 fn get_addr_info(req: &Request) -> (String, u16) {
355 let host = req.url.host_str().expect("Could Not parse the Websocket Host.");
356 let port = req.url.port_or_known_default().expect("Could not parse the websocket port.");
357
358 (host.to_string(), port)
359 }
360}
361
362fn send(sender: &Arc<Mutex<UnboundedSender<WebsocketMessage>>>, mess: WebsocketMessage) -> Result<()> {
363 sender.lock().start_send(mess)
364 .map(|_| ())
365 .map_err(From::from)
366}