1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
use super::{
    builder::ShardBuilder,
    config::Config,
    event::Events,
    json,
    processor::{ConnectingError, Latency, Session, ShardProcessor},
    sink::ShardSink,
    stage::Stage,
};
use crate::{listener::Listeners, EventTypeFlags, Intents};
use async_tungstenite::tungstenite::{
    protocol::{frame::coding::CloseCode, CloseFrame},
    Error as TungsteniteError, Message,
};
use futures_channel::mpsc::TrySendError;
use futures_util::{
    future::{self, AbortHandle},
    stream::StreamExt,
};
use once_cell::sync::OnceCell;
use serde::{Deserialize, Serialize};
use std::{
    borrow::Cow,
    error::Error,
    fmt::{Display, Formatter, Result as FmtResult},
    sync::{atomic::Ordering, Arc},
};
use tokio::sync::watch::Receiver as WatchReceiver;
use twilight_http::Error as HttpError;
use twilight_model::gateway::event::Event;
use url::ParseError as UrlParseError;

#[cfg(not(feature = "simd-json"))]
use serde_json::Error as JsonError;
#[cfg(feature = "simd-json")]
use simd_json::Error as JsonError;

/// Sending a command failed.
///
/// Returned by [`Shard::command`] and [`Shard::command_raw`]. The underlying
/// cause of each variant is stored in its `source` field.
#[derive(Debug)]
#[non_exhaustive]
pub enum CommandError {
    /// Sending the payload over the WebSocket failed. This is indicative of a
    /// shutdown shard.
    Sending {
        /// Reason for the error.
        source: TrySendError<Message>,
    },
    /// Serializing the payload as JSON failed.
    Serializing {
        /// Reason for the error.
        source: JsonError,
    },
    /// Shard's session is inactive because the shard hasn't been started.
    SessionInactive {
        /// Reason for the error.
        source: SessionInactiveError,
    },
}

impl Display for CommandError {
    fn fmt(&self, f: &mut Formatter<'_>) -> FmtResult {
        f.write_str("the shard session is inactive and has not been started")
    }
}

impl Error for CommandError {
    /// Return the lower-level error stored in the `source` field of the
    /// current variant.
    fn source(&self) -> Option<&(dyn Error + 'static)> {
        let cause: &(dyn Error + 'static) = match self {
            Self::Sending { source } => source,
            Self::Serializing { source } => source,
            Self::SessionInactive { source } => source,
        };

        Some(cause)
    }
}

/// Error indicating that a shard's session is inactive.
///
/// A session is inactive when the shard has not yet been started.
#[derive(Clone, Debug, Eq, PartialEq)]
#[non_exhaustive]
pub struct SessionInactiveError;

impl Display for SessionInactiveError {
    /// Write the fixed, human-readable description of the error.
    fn fmt(&self, f: &mut Formatter<'_>) -> FmtResult {
        write!(f, "the shard session is inactive and was not started")
    }
}

// There is no lower-level cause, so the default `source` implementation is
// used as-is.
impl Error for SessionInactiveError {}

/// Starting a shard and connecting to the gateway failed.
///
/// Returned by [`Shard::start`]. The underlying cause of each variant is
/// stored in its `source` field.
#[derive(Debug)]
#[non_exhaustive]
pub enum ShardStartError {
    /// Establishing a connection to the gateway failed.
    Establishing {
        /// Reason for the error.
        source: TungsteniteError,
    },
    /// Parsing the gateway URL provided by Discord to connect to the gateway
    /// failed due to an invalid URL.
    ParsingGatewayUrl {
        /// Reason for the error.
        source: UrlParseError,
        /// URL that couldn't be parsed.
        url: String,
    },
    /// Retrieving the gateway URL via the Twilight HTTP client failed.
    RetrievingGatewayUrl {
        /// Reason for the error.
        source: HttpError,
    },
}

impl Display for ShardStartError {
    /// Write a description of why the shard failed to start.
    ///
    /// `Establishing` forwards to the display of its source, while the other
    /// variants write their own message.
    fn fmt(&self, f: &mut Formatter<'_>) -> FmtResult {
        match self {
            Self::Establishing { source } => Display::fmt(source, f),
            Self::ParsingGatewayUrl { source, url } => {
                write!(f, "the gateway url `{}` is invalid: {}", url, source)
            }
            Self::RetrievingGatewayUrl { .. } => {
                write!(f, "retrieving the gateway URL via HTTP failed")
            }
        }
    }
}

impl Error for ShardStartError {
    /// Return the lower-level error stored in the `source` field of the
    /// current variant.
    fn source(&self) -> Option<&(dyn Error + 'static)> {
        let cause: &(dyn Error + 'static) = match self {
            Self::Establishing { source } => source,
            Self::ParsingGatewayUrl { source, .. } => source,
            Self::RetrievingGatewayUrl { source } => source,
        };

        Some(cause)
    }
}

impl From<ConnectingError> for ShardStartError {
    fn from(error: ConnectingError) -> Self {
        match error {
            ConnectingError::Establishing { source } => Self::Establishing { source },
            ConnectingError::ParsingUrl { source, url } => Self::ParsingGatewayUrl { source, url },
        }
    }
}

/// Information about a shard, including its latency, current session sequence,
/// and connection stage.
///
/// Created by [`Shard::info`]; the values are read from the shard's session
/// at the time of that call.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct Information {
    /// ID of the shard.
    id: u64,
    /// Latency information of the shard's heartbeats.
    latency: Latency,
    /// ID of the session, if there is one.
    session_id: Option<Box<str>>,
    /// Sequence of the session.
    seq: u64,
    /// Connection stage of the shard.
    stage: Stage,
}

impl Information {
    /// ID of the shard this information belongs to.
    pub fn id(&self) -> u64 {
        self.id
    }

    /// Immutable reference to the shard's latency information.
    ///
    /// This includes the average latency over all time, and the latency
    /// information for the 5 most recent heartbeats.
    pub fn latency(&self) -> &Latency {
        &self.latency
    }

    /// Session ID of the shard, or `None` if no session ID is stored.
    pub fn session_id(&self) -> Option<&str> {
        self.session_id.as_ref().map(|id| &**id)
    }

    /// Current sequence of the connection.
    ///
    /// This is the number of the event that was received this session (without
    /// reconnecting). A larger number typically means that the shard has been
    /// connected for a longer time, while a smaller number typically means
    /// that it has been connected for a shorter time.
    pub fn seq(&self) -> u64 {
        self.seq
    }

    /// Current connection stage of the shard.
    ///
    /// For example, once a shard is fully booted then it will be [`Connected`].
    ///
    /// [`Connected`]: Stage::Connected
    pub fn stage(&self) -> Stage {
        self.stage
    }
}
/// Details to resume a gateway session.
///
/// Returned by [`Shard::shutdown_resumable`] when the session had an ID at the
/// time of shutdown.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct ResumeSession {
    /// ID of the session being resumed.
    pub session_id: String,
    /// Last received event sequence number.
    pub sequence: u64,
}

/// Shared state behind a [`Shard`] handle.
///
/// A `Shard` wraps this in an `Arc`, so every clone of a shard refers to the
/// same instance.
#[derive(Debug)]
struct ShardRef {
    /// Configuration the shard was created with.
    config: Arc<Config>,
    /// Listeners that event streams are registered with.
    listeners: Listeners<Event>,
    /// Handle used to abort the spawned processor future; empty until the
    /// shard is started.
    processor_handle: OnceCell<AbortHandle>,
    /// Receiver from which the current session is read; empty until the shard
    /// is started.
    session: OnceCell<WatchReceiver<Arc<Session>>>,
}

/// Shard to run and manage a session with the gateway.
///
/// Shards are responsible for handling incoming events, processing events
/// relevant to the operation of shards - such as requests from the gateway to
/// re-connect or invalidate a session - and then passing the events on to the
/// user via an [event stream][`events`].
///
/// Shards will [go through a queue][`queue`] to initialize new sessions, as
/// starting sessions is ratelimited. Refer to Discord's
/// [documentation][docs:shards] on shards to have a better understanding of
/// what they are.
///
/// # Cloning
///
/// The shard internally wraps its data within an Arc. This means that the shard
/// can be cloned and passed around tasks and threads cheaply.
///
/// # Examples
///
/// Create and start a shard and print new and deleted messages:
///
/// ```no_run
/// use futures::stream::StreamExt;
/// use std::env;
/// use twilight_gateway::{EventTypeFlags, Event, Intents, Shard};
///
/// # #[tokio::main] async fn main() -> Result<(), Box<dyn std::error::Error>> {
/// // Use the value of the "DISCORD_TOKEN" environment variable as the bot's
/// // token. Of course, you may pass this into your program however you want.
/// let token = env::var("DISCORD_TOKEN")?;
/// let mut shard = Shard::new(token, Intents::GUILD_MESSAGES);
///
/// // Start the shard.
/// shard.start().await?;
///
/// // Create a loop of only new message and deleted message events.
/// let event_types = EventTypeFlags::MESSAGE_CREATE | EventTypeFlags::MESSAGE_DELETE;
/// let mut events = shard.some_events(event_types);
///
/// while let Some(event) = events.next().await {
///     match event {
///         Event::MessageCreate(message) => {
///             println!("message received with content: {}", message.content);
///         },
///         Event::MessageDelete(message) => {
///             println!("message with ID {} deleted", message.id);
///         },
///         _ => {},
///     }
/// }
/// # Ok(()) }
/// ```
///
/// [`events`]: Self::events
/// [`queue`]: crate::queue
/// [docs:shards]: https://discord.com/developers/docs/topics/gateway#sharding
#[derive(Clone, Debug)]
pub struct Shard(Arc<ShardRef>);

impl Shard {
    /// Create a new unconfingured shard.
    ///
    /// Use [`start`] to initiate the gateway session.
    ///
    /// # Examples
    ///
    /// Create a new shard and start it, wait a second, and then print its
    /// current connection stage:
    ///
    /// ```no_run
    /// use twilight_gateway::{Intents, Shard};
    /// use std::{env, time::Duration};
    /// use tokio::time as tokio_time;
    ///
    /// # #[tokio::main]
    /// # async fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
    /// let token = env::var("DISCORD_TOKEN")?;
    ///
    /// let intents = Intents::GUILD_MESSAGES | Intents::GUILD_MESSAGE_TYPING;
    /// let mut shard = Shard::new(token, intents);
    /// shard.start().await?;
    ///
    /// tokio_time::sleep(Duration::from_secs(1)).await;
    ///
    /// let info = shard.info()?;
    /// println!("Shard stage: {}", info.stage());
    /// # Ok(()) }
    /// ```
    ///
    /// [`start`]: Self::start
    pub fn new(token: impl Into<String>, intents: Intents) -> Self {
        Self::builder(token, intents).build()
    }

    /// Create a shard from an already constructed configuration.
    ///
    /// The shard begins with default listeners and with neither a processor
    /// handle nor a session set.
    pub(crate) fn new_with_config(config: Config) -> Self {
        let inner = ShardRef {
            config: Arc::new(config),
            listeners: Listeners::default(),
            processor_handle: OnceCell::new(),
            session: OnceCell::new(),
        };

        Self(Arc::new(inner))
    }

    /// Create a builder to configure and construct a shard.
    ///
    /// The token and intents are passed on to [`ShardBuilder::new`]. Refer to
    /// the builder for more information.
    pub fn builder(token: impl Into<String>, intents: Intents) -> ShardBuilder {
        ShardBuilder::new(token, intents)
    }

    /// Return an immutable reference to the configuration used for this shard.
    pub fn config(&self) -> &Config {
        self.0.config.as_ref()
    }

    /// Start the shard, connecting it to the gateway and starting the process
    /// of receiving and processing events.
    ///
    /// The gateway URL from the configuration is used when one is set;
    /// otherwise it is requested through the configured HTTP client. The
    /// processor is then spawned as a background task.
    ///
    /// # Errors
    ///
    /// Returns [`ShardStartError::Establishing`] if establishing a connection
    /// to the gateway failed.
    ///
    /// Returns [`ShardStartError::ParsingGatewayUrl`] if the gateway URL
    /// couldn't be parsed.
    ///
    /// Returns [`ShardStartError::RetrievingGatewayUrl`] if the gateway URL
    /// couldn't be retrieved from the HTTP API.
    pub async fn start(&mut self) -> Result<(), ShardStartError> {
        let url = match self.0.config.gateway_url.clone() {
            Some(configured) => configured.into_string(),
            None => {
                let gateway = self
                    .0
                    .config
                    .http_client()
                    .gateway()
                    .authed()
                    .await
                    .map_err(|source| ShardStartError::RetrievingGatewayUrl { source })?;

                gateway.url
            }
        };

        // A `ConnectingError` is converted through the `From` implementation
        // on `ShardStartError`.
        let (processor, session_rx) = ShardProcessor::new(
            Arc::clone(&self.0.config),
            url,
            self.0.listeners.clone(),
        )
        .await?;
        let (run, abort_handle) = future::abortable(processor.run());

        tokio::spawn(async move {
            let _ = run.await;

            tracing::debug!("shard processor future ended");
        });

        // `OnceCell::set` fails and keeps the existing value when the cell is
        // already filled (for example by an earlier call to `start`); the
        // result is intentionally ignored.
        let _ = self.0.processor_handle.set(abort_handle);
        let _ = self.0.session.set(session_rx);

        Ok(())
    }

    /// Create a new stream of events from the shard.
    ///
    /// Any number of event streams may exist at once. All events will be
    /// broadcast to all streams of events.
    ///
    /// **Note** that we *highly* recommend specifying only the events that you
    /// need via [`some_events`], especially if performance is a concern. This
    /// will ensure that events you don't care about aren't deserialized from
    /// received websocket messages. Gateway intents only allow specifying
    /// categories of events, but using [`some_events`] will filter it further
    /// on the client side.
    ///
    /// The returned event stream implements [`futures::stream::Stream`].
    ///
    /// The stream uses the default event type flags: all event types except
    /// for [`EventType::ShardPayload`] are enabled. If you need to enable it,
    /// consider calling [`some_events`] instead.
    ///
    /// [`EventType::ShardPayload`]: ::twilight_model::gateway::event::EventType::ShardPayload
    /// [`futures::stream::Stream`]: https://docs.rs/futures/*/futures/stream/trait.Stream.html
    /// [`some_events`]: Self::some_events
    pub fn events(&self) -> Events {
        let default_types = EventTypeFlags::default();

        self.some_events(default_types)
    }

    /// Create a new filtered stream of events from the shard.
    ///
    /// A listener is registered for the given event types, so only the events
    /// specified in the bitflags will be sent over the stream.
    ///
    /// The returned event stream implements [`futures::stream::Stream`].
    ///
    /// # Examples
    ///
    /// Filter the events so that you only receive the [`Event::ShardConnected`]
    /// and [`Event::ShardDisconnected`] events:
    ///
    /// ```no_run
    /// use twilight_gateway::{EventTypeFlags, Event, Intents, Shard};
    /// use futures::StreamExt;
    /// use std::env;
    ///
    /// # #[tokio::main]
    /// # async fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
    /// let mut shard = Shard::new(env::var("DISCORD_TOKEN")?, Intents::empty());
    /// shard.start().await?;
    ///
    /// let event_types = EventTypeFlags::SHARD_CONNECTED | EventTypeFlags::SHARD_DISCONNECTED;
    /// let mut events = shard.some_events(event_types);
    ///
    /// while let Some(event) = events.next().await {
    ///     match event {
    ///         Event::ShardConnected(_) => println!("Shard is now connected"),
    ///         Event::ShardDisconnected(_) => println!("Shard is now disconnected"),
    ///         // No other event will come in through the stream.
    ///         _ => {},
    ///     }
    /// }
    /// # Ok(()) }
    /// ```
    ///
    /// [`futures::stream::Stream`]: https://docs.rs/futures/*/futures/stream/trait.Stream.html
    pub fn some_events(&self, event_types: EventTypeFlags) -> Events {
        Events::new(event_types, self.0.listeners.add(event_types))
    }

    /// Retrieve information about the running of the shard, such as the current
    /// connection stage.
    ///
    /// The shard ID is taken from the configuration; all other values are read
    /// from the current session.
    ///
    /// # Errors
    ///
    /// Returns a [`SessionInactiveError`] if the shard's session is inactive.
    pub fn info(&self) -> Result<Information, SessionInactiveError> {
        self.session().map(|session| Information {
            id: self.config().shard()[0],
            latency: session.heartbeats.latency(),
            session_id: session.id(),
            seq: session.seq(),
            stage: session.stage(),
        })
    }

    /// Retrieve an interface implementing the `Sink` trait which can be used to
    /// send messages.
    ///
    /// The sink wraps a clone of the current session's sender, so it is only
    /// valid for the current websocket connection. If the shard's session is
    /// invalidated, or network connectivity is lost, or anything else happens
    /// that causes a need to create a new connection, then the sink will be
    /// invalidated.
    ///
    /// # Errors
    ///
    /// Returns a [`SessionInactiveError`] if the shard's session is inactive.
    pub fn sink(&self) -> Result<ShardSink, SessionInactiveError> {
        self.session().map(|session| ShardSink(session.tx.clone()))
    }

    /// Send a command over the gateway.
    ///
    /// The value is serialized to JSON and then passed to [`command_raw`].
    ///
    /// # Errors
    ///
    /// Returns [`CommandError::Sending`] if the message could not be sent
    /// over the websocket. This indicates the shard is currently restarting.
    ///
    /// Returns [`CommandError::Serializing`] if the provided value failed to
    /// serialize into JSON.
    ///
    /// Returns [`CommandError::SessionInactive`] if the shard has not been
    /// started.
    ///
    /// [`command_raw`]: Self::command_raw
    pub async fn command(&self, value: &impl serde::Serialize) -> Result<(), CommandError> {
        let payload = match json::to_vec(value) {
            Ok(bytes) => bytes,
            Err(source) => return Err(CommandError::Serializing { source }),
        };

        self.command_raw(payload).await
    }

    /// Send a raw command over the gateway.
    ///
    /// The bytes are sent as-is in a binary websocket message. This method
    /// should be used with caution, [`command`] should be preferred.
    ///
    /// # Errors
    ///
    /// Returns [`CommandError::Sending`] if the message could not be sent
    /// over the websocket. This indicates the shard is currently restarting.
    ///
    /// Returns [`CommandError::SessionInactive`] if the shard has not been
    /// started.
    ///
    /// [`command`]: Self::command
    pub async fn command_raw(&self, value: Vec<u8>) -> Result<(), CommandError> {
        let session = match self.session() {
            Ok(session) => session,
            Err(source) => return Err(CommandError::SessionInactive { source }),
        };

        // Tick the ratelimiter before handing the message to the session's
        // sender.
        session.ratelimit.lock().await.next().await;

        session
            .tx
            .unbounded_send(Message::Binary(value))
            .map_err(|source| CommandError::Sending { source })
    }

    /// Shut down the shard.
    ///
    /// The shard will cleanly close the connection by sending a normal close
    /// code, causing Discord to show the bot as being offline. The session will
    /// not be resumable.
    pub fn shutdown(&self) {
        self.0.listeners.remove_all();

        if let Some(processor_handle) = self.0.processor_handle.get() {
            processor_handle.abort();
        }

        if let Ok(session) = self.session() {
            // Since we're shutting down now, we don't care if it sends or not.
            let _ = session.close(Some(CloseFrame {
                code: CloseCode::Normal,
                reason: "".into(),
            }));
            session.stop_heartbeater();
        }
    }

    /// Shut down the shard in a resumable fashion.
    ///
    /// The shard will cleanly close the connection by sending a restart close
    /// code, causing Discord to keep the bot as showing online. The connection
    /// will be resumable by using the provided session resume information
    /// to [`ClusterBuilder::resume_sessions`].
    ///
    /// Returns the shard's ID together with the resume information. The resume
    /// information is `None` when the shard's session is inactive or when the
    /// session has no ID.
    ///
    /// [`ClusterBuilder::resume_sessions`]: crate::cluster::ClusterBuilder::resume_sessions
    pub fn shutdown_resumable(&self) -> (u64, Option<ResumeSession>) {
        self.0.listeners.remove_all();

        if let Some(handle) = self.0.processor_handle.get() {
            handle.abort();
        }

        let shard_id = self.config().shard()[0];

        let resume = self.session().ok().and_then(|session| {
            // Whether the close frame is actually sent is not checked.
            let _ = session.close(Some(CloseFrame {
                code: CloseCode::Restart,
                reason: Cow::from("Closing in a resumable way"),
            }));

            // Read the resume data before stopping the heartbeater.
            let session_id = session.id();
            let sequence = session.seq.load(Ordering::Relaxed);

            session.stop_heartbeater();

            session_id.map(|id| ResumeSession {
                session_id: id.into_string(),
                sequence,
            })
        });

        (shard_id, resume)
    }

    /// Return a handle to the current session.
    ///
    /// The session is cloned out of the watch receiver, so the returned handle
    /// refers to the session that was current at the time of the call.
    ///
    /// # Errors
    ///
    /// Returns a [`SessionInactiveError`] if the shard's session is inactive.
    fn session(&self) -> Result<Arc<Session>, SessionInactiveError> {
        self.0
            .session
            .get()
            .map(|receiver| Arc::clone(&receiver.borrow()))
            .ok_or(SessionInactiveError)
    }
}

// Compile-time assertions about the module's types: the fields each error
// variant exposes and the traits each type implements.
#[cfg(test)]
mod tests {
    use super::{
        CommandError, ConnectingError, Information, ResumeSession, SessionInactiveError, Shard,
        ShardStartError,
    };
    use static_assertions::{assert_fields, assert_impl_all};
    use std::{error::Error, fmt::Debug};

    assert_fields!(CommandError::Sending: source);
    assert_fields!(CommandError::Serializing: source);
    assert_fields!(CommandError::SessionInactive: source);
    assert_impl_all!(CommandError: Debug, Error, Send, Sync);
    assert_impl_all!(Information: Clone, Debug, Send, Sync);
    assert_impl_all!(ResumeSession: Clone, Debug, Send, Sync);
    assert_impl_all!(
        SessionInactiveError: Clone,
        Debug,
        Error,
        Eq,
        PartialEq,
        Send,
        Sync
    );
    assert_fields!(ShardStartError::Establishing: source);
    assert_fields!(ShardStartError::ParsingGatewayUrl: source, url);
    assert_fields!(ShardStartError::RetrievingGatewayUrl: source);
    assert_impl_all!(
        ShardStartError: Debug,
        Error,
        From<ConnectingError>,
        Send,
        Sync
    );
    assert_impl_all!(Shard: Clone, Debug, Send, Sync);
}