1use std::{
2 sync::{atomic::Ordering, Arc},
3 thread::{JoinHandle, Thread},
4 time::Duration,
5};
6
7use crate::{
8 connection::Manager as ConnectionManager,
9 event_handler::{Context as EventContext, EventCallbackHandle, HandlerRegistry},
10 models::{
11 commands::{Subscription, SubscriptionArgs},
12 message::Message,
13 payload::Payload,
14 rich_presence::{
15 Activity, CloseActivityRequestArgs, SendActivityJoinInviteArgs, SetActivityArgs,
16 },
17 Command, Event, OpCode,
18 },
19 DiscordError, Result,
20};
21use crossbeam_channel::Sender;
22use serde::{de::DeserializeOwned, Serialize};
23use serde_json::Value;
24
/// Generates `on_<event>` convenience methods that forward to `on_event`.
///
/// Accepts a comma-separated list of `method_name, Event::Variant` pairs
/// (trailing comma allowed), e.g.
/// `event_handler_function!(on_ready, Event::Ready, on_error, Event::Error);`.
///
/// Each generated method has the signature
/// `pub fn name<F>(&self, handler: F) -> EventCallbackHandle` with
/// `F: Fn(EventContext) + 'static + Send + Sync`, so the surrounding type must
/// provide an `on_event(&self, event, handler)` method with that shape.
macro_rules! event_handler_function {
    // Internal rule: one bracketed `[name, event]` group per generated method.
    (@gen $( [ $name:ident, $event:expr ] ),* $(,)?) => {
        $(
            #[doc = concat!("Listens for the `", stringify!($event), "` event")]
            pub fn $name<F>(&self, handler: F) -> EventCallbackHandle
            where
                F: Fn(EventContext) + 'static + Send + Sync,
            {
                self.on_event($event, handler)
            }
        )*
    };

    // Public entry point: a flat list of `name, event` pairs.
    ( $( $name:ident, $event:expr ),* $(,)? ) => {
        // The groups must be re-emitted comma-separated so they match the
        // `@gen` pattern above; without the separator only single-pair
        // invocations would expand.
        event_handler_function!{@gen $( [ $name, $event ] ),*}
    };
}
41
/// Handle to the background thread started by `Client::start`, paired with
/// the sending half of that thread's stop channel.
// The type name repeats the module name (presumably this module is
// `client`), which the pedantic `module_name_repetitions` lint would flag.
#[allow(clippy::module_name_repetitions)]
pub struct ClientThread(
    /// Join handle returned by the connection manager's `start`.
    JoinHandle<()>,
    /// Stop-channel sender; `stop` and `Client::shutdown` send `()` here to
    /// signal the background thread.
    Sender<()>
);
45
46impl ClientThread {
47 #[allow(clippy::missing_errors_doc)]
49 pub fn join(self) -> std::thread::Result<()> {
51 self.0.join()
52 }
53
54 #[allow(clippy::missing_errors_doc)]
56 #[must_use]
57 pub fn is_finished(&self) -> bool {
59 self.0.is_finished()
60 }
61 #[allow(clippy::missing_errors_doc)]
63 #[must_use]
64 pub fn thread(&self) -> &Thread {
66 self.0.thread()
67 }
68
69 pub fn stop(self) -> Result<()> {
75 self.1.send(())?;
77
78 self.join().map_err(|_| DiscordError::EventLoopError)?;
79
80 Ok(())
81 }
82
83 pub fn persist(self) {
85 std::mem::forget(self);
86 }
87}
88
/// Rich-presence client: sends commands through a connection manager and
/// lets callers register event handlers in a registry shared with that
/// manager.
///
/// Cloning shares the handler registry and the background-thread handle,
/// both of which live behind an `Arc`. While more than one clone holds the
/// thread handle, `block_on` and `shutdown` fail with
/// `DiscordError::ThreadInUse`, because they need sole ownership of it.
#[derive(Clone)]
pub struct Client {
    /// Connection used to send command frames and receive their replies.
    connection_manager: ConnectionManager,
    /// Handlers registered via `on_event` and the generated `on_*` helpers;
    /// the same `Arc` is handed to the connection manager on construction.
    event_handler_registry: Arc<HandlerRegistry>,
    /// Background thread plus its stop channel; `None` until `start` runs.
    thread: Option<Arc<ClientThread>>,
}
96
97impl Client {
98 #[must_use]
100 pub fn new(client_id: u64) -> Self {
101 Self::with_error_config(client_id, Duration::from_secs(5), None)
102 }
103
104 #[must_use]
106 pub fn with_error_config(
107 client_id: u64,
108 sleep_duration: Duration,
109 attempts: Option<usize>,
110 ) -> Self {
111 let event_handler_registry = Arc::new(HandlerRegistry::new());
112
113 let connection_manager = ConnectionManager::new(
114 client_id,
115 event_handler_registry.clone(),
116 sleep_duration,
117 attempts,
118 );
119
120 Self {
121 connection_manager,
122 event_handler_registry,
123 thread: None,
124 }
125 }
126
127 pub fn start(&mut self) {
134 let (tx, rx) = crossbeam_channel::bounded::<()>(1);
136
137 let thread = self.connection_manager.start(rx);
138
139 self.thread = Some(Arc::new(ClientThread(thread, tx)));
140 }
141
142 pub fn shutdown(self) -> Result<()> {
148 if let Some(thread) = self.thread.as_ref() {
149 thread.1.send(())?;
150
151 crate::READY.store(false, Ordering::Relaxed);
152
153 self.block_on()
154 } else {
155 Err(DiscordError::NotStarted)
156 }
157 }
158
159 pub fn block_on(mut self) -> Result<()> {
169 let thread = self.unwrap_thread()?;
170
171 thread.join().map_err(|_| DiscordError::ThreadError)?;
174
175 Ok(())
176 }
177
178 fn unwrap_thread(&mut self) -> Result<ClientThread> {
179 if let Some(thread) = self.thread.take() {
180 let thread = Arc::try_unwrap(thread).map_err(|_| DiscordError::ThreadInUse)?;
181
182 Ok(thread)
183 } else {
184 Err(DiscordError::NotStarted)
185 }
186 }
187
188 #[must_use]
189 pub fn is_ready() -> bool {
191 crate::READY.load(Ordering::Relaxed)
192 }
193
194 fn execute<A, E>(&mut self, cmd: Command, args: A, evt: Option<Event>) -> Result<Payload<E>>
195 where
196 A: Serialize + Send + Sync,
197 E: Serialize + DeserializeOwned + Send + Sync,
198 {
199 if !crate::READY.load(Ordering::Relaxed) {
200 return Err(DiscordError::NotStarted);
201 }
202
203 trace!("Executing command: {cmd:?}");
204
205 let message = Message::new(
206 OpCode::Frame,
207 Payload::with_nonce(cmd, Some(args), None, evt),
208 );
209 self.connection_manager.send(message?)?;
210 let Message { payload, .. } = self.connection_manager.recv()?;
211 trace!("Received response payload: {payload}");
212 let response: Payload<E> = serde_json::from_str(&payload)?;
213 trace!("Parsed response payload.");
214
215 match response.evt {
216 Some(Event::Error) => Err(DiscordError::SubscriptionFailed),
217 _ => Ok(response),
218 }
219 }
220
221 pub fn set_activity<F>(&mut self, f: F) -> Result<Payload<Activity>>
226 where
227 F: FnOnce(Activity) -> Activity,
228 {
229 let args = SetActivityArgs::new(f);
230 self.update_activity(args)
231 }
232
233 pub fn clear_activity(&mut self) -> Result<Payload<Activity>> {
238 self.update_activity(SetActivityArgs::default())
239 }
240
241 fn update_activity(&mut self, args: SetActivityArgs) -> Result<Payload<Activity>> {
243 let result = self.execute(Command::SetActivity, args, None);
244
245 if result.is_ok() {
247 self.connection_manager.rate_limiter.mark_sent();
248 }
249
250 result
251 }
252
253 pub fn queue_activity<F>(&mut self, f: F)
259 where
260 F: FnOnce(Activity) -> Activity,
261 {
262 self.connection_manager
263 .rate_limiter
264 .queue(SetActivityArgs::new(f));
265 }
266
267 pub fn send_activity_join_invite(&mut self, user_id: u64) -> Result<Payload<Value>> {
275 self.execute(
276 Command::SendActivityJoinInvite,
277 SendActivityJoinInviteArgs::new(user_id),
278 None,
279 )
280 }
281
282 pub fn close_activity_request(&mut self, user_id: u64) -> Result<Payload<Value>> {
287 self.execute(
288 Command::CloseActivityRequest,
289 CloseActivityRequestArgs::new(user_id),
290 None,
291 )
292 }
293
294 pub fn subscribe<F>(&mut self, evt: Event, f: F) -> Result<Payload<Subscription>>
299 where
300 F: FnOnce(SubscriptionArgs) -> SubscriptionArgs,
301 {
302 self.execute(Command::Subscribe, f(SubscriptionArgs::new()), Some(evt))
303 }
304
305 pub fn unsubscribe<F>(&mut self, evt: Event, f: F) -> Result<Payload<Subscription>>
310 where
311 F: FnOnce(SubscriptionArgs) -> SubscriptionArgs,
312 {
313 self.execute(Command::Unsubscribe, f(SubscriptionArgs::new()), Some(evt))
314 }
315
316 pub fn on_event<F>(&self, event: Event, handler: F) -> EventCallbackHandle
377 where
378 F: Fn(EventContext) + 'static + Send + Sync,
379 {
380 self.event_handler_registry.register(event, handler)
381 }
382
383 pub fn block_until_event(&mut self, event: Event) -> Result<crate::event_handler::Context> {
395 let (tx, rx) = crossbeam_channel::unbounded::<crate::event_handler::Context>();
397
398 let handler = move |info| {
399 if let Err(e) = tx.send(info) {
401 error!("{e}");
402 }
403 };
404
405 let cb_handle = self.on_event(event, handler);
407
408 let response = rx.recv()?;
409
410 drop(cb_handle);
411
412 Ok(response)
413 }
414
415 event_handler_function!(on_ready, Event::Ready);
416
417 event_handler_function!(on_error, Event::Error);
418
419 event_handler_function!(on_activity_join, Event::ActivityJoin);
420
421 event_handler_function!(on_activity_join_request, Event::ActivityJoinRequest);
422
423 event_handler_function!(on_activity_spectate, Event::ActivitySpectate);
424
425 event_handler_function!(on_connected, Event::Connected);
426
427 event_handler_function!(on_disconnected, Event::Disconnected);
428}
429
#[cfg(test)]
mod tests {
    use super::*;

    /// Clears the crate-wide ready flag when dropped. A test that sets the
    /// flag therefore leaves it cleared even if one of its assertions panics.
    struct ReadyReset;

    impl Drop for ReadyReset {
        fn drop(&mut self) {
            crate::READY.store(false, Ordering::Relaxed);
        }
    }

    #[test]
    fn test_client_send_sync() {
        // Compile-time check: this stops building if `Client` loses Send + Sync.
        fn assert_send_sync<T: Send + Sync>() {}
        assert_send_sync::<Client>();
    }

    #[test]
    fn test_is_ready() {
        let _reset = ReadyReset;

        // Start from a known state instead of assuming nothing else has
        // touched the global flag.
        crate::READY.store(false, Ordering::Relaxed);
        assert!(!Client::is_ready());

        crate::READY.store(true, Ordering::Relaxed);
        assert!(Client::is_ready());
    }
}