1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
//! Client to manage nodes and players.

use crate::{
    model::{IncomingEvent, OutgoingEvent, VoiceUpdate},
    node::{Node, NodeConfig, NodeError, Resume},
    player::{Player, PlayerManager},
};
use dashmap::{
    mapref::{entry::Entry, one::Ref},
    DashMap,
};
use futures_channel::mpsc::{TrySendError, UnboundedReceiver};
use std::{
    error::Error,
    fmt::{Display, Formatter, Result as FmtResult},
    net::SocketAddr,
    sync::Arc,
};
use twilight_model::{
    gateway::{
        event::Event,
        payload::{VoiceServerUpdate, VoiceStateUpdate},
    },
    id::{GuildId, UserId},
};

/// Error returned when an operation on the Lavalink client fails.
#[derive(Clone, Debug, PartialEq)]
#[non_exhaustive]
pub enum ClientError {
    /// The operation needs a node, but no node has been added to the client.
    NodesUnconfigured,
    /// The voice update could not be handed to the node because the node's
    /// connection has been shut down.
    SendingVoiceUpdate {
        /// The underlying channel send error.
        source: TrySendError<OutgoingEvent>,
    },
}

impl Display for ClientError {
    /// Writes a short, lowercase description of the error.
    fn fmt(&self, f: &mut Formatter<'_>) -> FmtResult {
        let description = match self {
            Self::NodesUnconfigured => "no node has been configured",
            Self::SendingVoiceUpdate { .. } => "couldn't send voice update to node",
        };

        f.write_str(description)
    }
}

impl Error for ClientError {
    /// Exposes the channel error behind a failed voice update; the other
    /// variant has no underlying cause.
    fn source(&self) -> Option<&(dyn Error + 'static)> {
        match self {
            Self::SendingVoiceUpdate { source } => {
                let cause: &(dyn Error + 'static) = source;

                Some(cause)
            }
            Self::NodesUnconfigured => None,
        }
    }
}

/// One of the two gateway payloads that together make up a [`VoiceUpdate`]
/// for a guild.
#[derive(Debug)]
enum VoiceStateHalf {
    /// A voice server update received from the gateway.
    Server(VoiceServerUpdate),
    /// A voice state update received from the gateway.
    State(Box<VoiceStateUpdate>),
}

/// State shared by every clone of a [`Lavalink`] client.
#[derive(Debug, Default)]
struct LavalinkRef {
    /// Map from guild ID to a node address.
    guilds: DashMap<GuildId, SocketAddr>,
    /// Managed nodes, keyed by the address they were added with.
    nodes: DashMap<SocketAddr, Node>,
    /// Player manager; a clone of it is handed to every node on connect.
    players: PlayerManager,
    /// Session resume configuration copied into each new node, if any.
    resume: Option<Resume>,
    /// Total shard count, passed to nodes and used to map guilds to shards.
    shard_count: u64,
    /// The client's own user ID; voice states of other users are ignored.
    user_id: UserId,
    /// Voice halves still waiting for their counterpart, keyed by guild.
    waiting: DashMap<GuildId, VoiceStateHalf>,
}

/// Lavalink client that manages nodes and players and ties them to the voice
/// events received from Discord.
///
/// **Note**: Every Voice State Update and Voice Server Update event received
/// from Discord must be passed to [`process`], which forwards them to
/// Lavalink automatically. Refer to its documentation for details.
///
/// Players are obtained through [`player`]. A player holds a guild's active
/// playback information and lets you send events, such as [`Play`], to the
/// node it is connected to.
///
/// # Cloning
///
/// All client data lives behind an [`Arc`], so cloning the client is cheap
/// and clones can be passed freely between tasks and threads.
///
/// [`Play`]: crate::model::outgoing::Play
/// [`player`]: Self::player
/// [`process`]: Self::process
#[derive(Clone, Debug)]
pub struct Lavalink(Arc<LavalinkRef>);

impl Lavalink {
    /// Create a new Lavalink client instance.
    ///
    /// The user ID and number of shards provided may not be modified during
    /// runtime, and the client must be re-created. These parameters are
    /// automatically passed to new nodes created via [`add`].
    ///
    /// See also [`new_with_resume`], which allows you to specify session resume
    /// capability.
    ///
    /// [`add`]: Self::add
    /// [`new_with_resume`]: Self::new_with_resume
    pub fn new(user_id: UserId, shard_count: u64) -> Self {
        Self::_new_with_resume(user_id, shard_count, None)
    }

    /// Like [`new`], but allows you to specify resume capability (if any).
    ///
    /// Provide `None` for the `resume` parameter to disable session resume
    /// capability. See the [`Resume`] documentation for defaults.
    ///
    /// [`Resume`]: crate::node::Resume
    /// [`new`]: Self::new
    pub fn new_with_resume(
        user_id: UserId,
        shard_count: u64,
        resume: impl Into<Option<Resume>>,
    ) -> Self {
        Self::_new_with_resume(user_id, shard_count, resume.into())
    }

    fn _new_with_resume(user_id: UserId, shard_count: u64, resume: Option<Resume>) -> Self {
        Self(Arc::new(LavalinkRef {
            guilds: DashMap::new(),
            nodes: DashMap::new(),
            players: PlayerManager::new(),
            resume,
            shard_count,
            user_id,
            waiting: DashMap::new(),
        }))
    }

    /// Process an event into the Lavalink client.
    ///
    /// **Note**: calling this method in your event loop is required. See the
    /// [crate documentation] for an example.
    ///
    /// This requires the `VoiceServerUpdate` and `VoiceStateUpdate` events that
    /// you receive from Discord over the gateway to send voice updates to
    /// nodes. For simplicity in some applications' event loops, any event can
    /// be provided, but they will just be ignored.
    ///
    /// The Ready event can optionally be provided to do some cleaning of
    /// stalled voice states that never received their voice server update half
    /// or vice versa. It is recommended that you process Ready events.
    ///
    /// # Errors
    ///
    /// Returns [`ClientError::NodesUnconfigured`] if no nodes have been added
    /// to the client when attempting to retrieve a guild's player.
    ///
    /// [crate documentation]: crate#examples
    pub async fn process(&self, event: &Event) -> Result<(), ClientError> {
        tracing::trace!("processing event: {:?}", event);

        let (guild_id, half) = match event {
            Event::Ready(e) => {
                let shard_id = e.shard.map_or(0, |[id, _]| id);

                self.clear_shard_states(shard_id);

                return Ok(());
            }
            Event::VoiceServerUpdate(e) => (e.guild_id, VoiceStateHalf::Server(e.clone())),
            Event::VoiceStateUpdate(e) => {
                if e.0.user_id != self.0.user_id {
                    tracing::trace!("got voice state update from another user");

                    return Ok(());
                }

                (e.0.guild_id, VoiceStateHalf::State(e.clone()))
            }
            _ => return Ok(()),
        };

        tracing::debug!(
            "got voice server/state update for {:?}: {:?}",
            guild_id,
            half
        );

        let guild_id = match guild_id {
            Some(guild_id) => guild_id,
            None => {
                tracing::trace!("event has no guild ID: {:?}", event);

                return Ok(());
            }
        };

        let update = {
            let existing_half = match self.0.waiting.get(&guild_id) {
                Some(existing_half) => existing_half,
                None => {
                    tracing::debug!(
                        "guild {} is now waiting for other half; got: {:?}",
                        guild_id,
                        half
                    );
                    self.0.waiting.insert(guild_id, half);

                    return Ok(());
                }
            };
            tracing::debug!(
                "got both halves for {}: {:?}; {:?}",
                guild_id,
                half,
                existing_half.value()
            );

            match (existing_half.value(), half) {
                (VoiceStateHalf::Server(_), VoiceStateHalf::Server(server)) => {
                    // We got the same half twice... weird, but let's just replace
                    // the existing one.
                    tracing::debug!(
                        "got the same server half twice for guild {}: {:?}",
                        guild_id,
                        server
                    );
                    self.0
                        .waiting
                        .insert(guild_id, VoiceStateHalf::Server(server));

                    return Ok(());
                }
                (VoiceStateHalf::Server(ref server), VoiceStateHalf::State(ref state)) => {
                    VoiceUpdate::new(guild_id, &state.0.session_id, From::from(server.clone()))
                }
                (VoiceStateHalf::State(_), VoiceStateHalf::State(state)) => {
                    // Just like above, we got the same half twice...
                    tracing::debug!(
                        "got the same state half twice for guild {}: {:?}",
                        guild_id,
                        state
                    );
                    self.0
                        .waiting
                        .insert(guild_id, VoiceStateHalf::State(state));

                    return Ok(());
                }
                (VoiceStateHalf::State(ref state), VoiceStateHalf::Server(ref server)) => {
                    VoiceUpdate::new(guild_id, &state.0.session_id, From::from(server.clone()))
                }
            }
        };

        tracing::debug!("removing guild {} from waiting list", guild_id);
        self.0.waiting.remove(&guild_id);

        tracing::debug!("getting player for guild {}", guild_id);
        let player = self.player(guild_id).await?;
        tracing::debug!("sending voice update for guild {}: {:?}", guild_id, update);
        player
            .send(update)
            .map_err(|source| ClientError::SendingVoiceUpdate { source })?;
        tracing::debug!("sent voice update for guild {}", guild_id);

        Ok(())
    }

    /// Add a new node to be managed by the Lavalink client.
    ///
    /// If a node already exists with the provided address, then it will be
    /// replaced.
    pub async fn add(
        &self,
        address: SocketAddr,
        authorization: impl Into<String>,
    ) -> Result<(Node, UnboundedReceiver<IncomingEvent>), NodeError> {
        let config = NodeConfig {
            address,
            authorization: authorization.into(),
            resume: self.0.resume.clone(),
            shard_count: self.0.shard_count,
            user_id: self.0.user_id,
        };

        let (node, rx) = Node::connect(config, self.0.players.clone()).await?;
        self.0.nodes.insert(address, node.clone());

        Ok((node, rx))
    }

    /// Remove a node from the list of nodes being managed by the Lavalink
    /// client.
    ///
    /// The node is returned if it existed.
    pub async fn remove(&self, address: SocketAddr) -> Option<(SocketAddr, Node)> {
        self.0.nodes.remove(&address)
    }

    /// Determine the "best" node for new players according to available nodes'
    /// penalty scores.
    ///
    /// Refer to [`Node::penalty`] for how this is calculated.
    ///
    /// # Errors
    ///
    /// Returns [`ClientError::NodesUnconfigured`] if there are no configured
    /// nodes available in the client.
    ///
    /// [`Node::penalty`]: crate::node::Node::penalty
    pub async fn best(&self) -> Result<Node, ClientError> {
        let mut lowest = i32::MAX;
        let mut best = None;

        for node in self.0.nodes.iter() {
            let penalty = node.value().penalty().await;

            if penalty < lowest {
                lowest = penalty;
                best.replace(node.clone());
            }
        }

        best.ok_or(ClientError::NodesUnconfigured)
    }

    /// Retrieve an immutable reference to the player manager.
    pub fn players(&self) -> &PlayerManager {
        &self.0.players
    }

    /// Retrieve a player for the guild.
    ///
    /// Creates a player configured to use the best available node if a player
    /// for the guild doesn't already exist. Use [`PlayerManager::get`] to only
    /// retrieve and not create.
    ///
    /// # Errors
    ///
    /// Returns [`ClientError::NodesUnconfigured`] if no node has been
    /// configured via [`add`].
    ///
    /// [`PlayerManager::get`]: crate::player::PlayerManager::get
    /// [`add`]: Self::add
    pub async fn player(&self, guild_id: GuildId) -> Result<Ref<'_, GuildId, Player>, ClientError> {
        if let Some(player) = self.players().get(&guild_id) {
            return Ok(player);
        }

        let node = self.best().await?;

        Ok(self.players().get_or_insert(guild_id, node).downgrade())
    }

    /// Clear out the map of guild states/updates for a shard that are waiting
    /// for their other half.
    ///
    /// We can do this by iterating over the map and removing the ones that we
    /// can calculate came from a shard.
    ///
    /// This map should be small or empty, and if it isn't, then it needs to be
    /// cleared out anyway.
    fn clear_shard_states(&self, shard_id: u64) {
        let shard_count = self.0.shard_count;

        for r in self.0.waiting.iter() {
            let guild_id = r.key();

            if (guild_id.0 >> 22) % shard_count == shard_id {
                self.0.waiting.remove(guild_id);
            }
        }
    }
}

#[cfg(test)]
mod tests {
    use super::{ClientError, Lavalink, VoiceStateHalf};
    use static_assertions::{assert_fields, assert_impl_all};
    use std::{error::Error, fmt::Debug};

    // Compile-time guarantees about the types exposed by this module.
    assert_impl_all!(Lavalink: Clone, Debug, Send, Sync);
    assert_impl_all!(ClientError: Clone, Debug, Error, PartialEq, Send, Sync);
    assert_fields!(ClientError::SendingVoiceUpdate: source);
    assert_impl_all!(VoiceStateHalf: Debug, Send, Sync);
}