// filthy_rich/runner.rs

use std::time::Duration;
use tokio::{sync::{mpsc, oneshot}, task::JoinHandle, time::sleep};

use crate::{
    PresenceClient,
    errors::{DisconnectReason, DiscordSockError, PresenceRunnerError},
    socket::{DiscordSock, Opcode},
    types::{
        ActivitySpec, DynamicRPCFrame, IPCCommand, ReadyRPCFrame,
        data::{ActivityResponseData, ReadyData},
    },
    utils::get_current_timestamp,
};
14
// Expands to the storage type of an optional, thread-safe event callback that
// receives one argument of type `$arg` (boxed so differently-typed closures fit
// in the same field).
macro_rules! callback {
    ($arg:ty) => {
        std::option::Option<
            std::boxed::Box<
                dyn std::ops::Fn($arg) + std::marker::Send + std::marker::Sync + 'static,
            >,
        >
    };
}
20
// Generates a consuming builder-style setter named `$field` that boxes the given
// closure and stores it in the field of the same name. `$docs` becomes the
// setter's rustdoc.
macro_rules! impl_callback {
    ($field:ident, $param:ty, $docs:expr) => {
        #[doc = $docs]
        pub fn $field<Cb>(mut self, callback: Cb) -> Self
        where
            Cb: Fn($param) + Send + Sync + 'static,
        {
            let boxed: Box<dyn Fn($param) + Send + Sync + 'static> = Box::new(callback);
            self.$field = Some(boxed);
            self
        }
    };
}
30
31/// A runner that manages the Discord RPC background task.
32/// Create a runner, configure it, run it to get a client handle, then clone the handle for sharing.
33pub struct PresenceRunner {
34    rx: Option<tokio::sync::mpsc::Receiver<IPCCommand>>,
35    client: PresenceClient,
36    join_handle: Option<JoinHandle<()>>,
37    on_ready: callback!(ReadyData),
38    on_activity_send: callback!(ActivityResponseData),
39    on_disconnect: callback!(DisconnectReason),
40    on_retry: callback!(usize),
41    show_errors: bool,
42    max_retries: usize,
43}
44
45impl PresenceRunner {
46    #[must_use]
47    /// Create a new [`PresenceRunner`] instance. Requires the client ID of your chosen app from the
48    /// [Discord Developer Portal](https://discord.com/developers/applications).
49    pub fn new(client_id: impl Into<String>) -> Self {
50        let (tx, rx) = mpsc::channel(32);
51        let client = PresenceClient {
52            tx,
53            client_id: client_id.into(),
54        };
55
56        Self {
57            rx: Some(rx),
58            client,
59            join_handle: None,
60            on_ready: None,
61            on_activity_send: None,
62            on_disconnect: None,
63            on_retry: None,
64            show_errors: false,
65            max_retries: 0,
66        }
67    }
68
69    impl_callback!(
70        on_ready,
71        ReadyData,
72        "Runs a particular closure after receiving a READY event.
73
74This can fire multiple times depending on how many times the client
75needs to disconnect and reconnect."
76    );
77
78    impl_callback!(
79        on_activity_send,
80        ActivityResponseData,
81        "Run a particular closure after ensuring that a [`PresenceClient::set_activity`] has successfully managed to pass its data through the IPC channel.
82
83This can fire multiple times."
84    );
85
86    impl_callback!(
87        on_disconnect,
88        DisconnectReason,
89        "Runs a particular closure after the RPC connection is lost.
90
91This can fire multiple times depending on how many times the client disconnects and reconnects again."
92    );
93
94    impl_callback!(
95        on_retry,
96        usize,
97        "Runs a particular closure when retrying for socket creation or handshake.
98
99This can fire multiple times, or for a limited amount depending on the amount of maximum
100retries that has been set (through [`PresenceRunner::set_max_retries`]).
101
102The closure parameter is the count of retries done at the time of its execution."
103    );
104
105    /// Enable verbose error logging over [`std::io::stderr`] writes for RPC and code events.
106    #[must_use]
107    pub fn show_errors(mut self) -> Self {
108        self.show_errors = true;
109        self
110    }
111
112    /// Sets the amount of maximum retries to do on socket creation and handshakes before the runner should give up.
113    ///
114    /// By default this is set to `0` internally, which means the inner loop would retry indefinitely.
115    #[must_use]
116    pub fn set_max_retries(mut self, count: usize) -> Self {
117        self.max_retries = count;
118        self
119    }
120
121    /// Run the runner.
122    /// Must be called before any client handle operations.
123    pub async fn run(
124        &mut self,
125        wait_for_ready: bool,
126    ) -> Result<&PresenceClient, PresenceRunnerError> {
127        if self.join_handle.is_some() {
128            return Err(PresenceRunnerError::MultipleRun);
129        }
130
131        let client_id = self.client.client_id.clone();
132        let show_errors = self.show_errors;
133        let max_retries = self.max_retries;
134
135        let mut rx = self
136            .rx
137            .take()
138            .ok_or_else(|| PresenceRunnerError::ReceiverError)?;
139
140        // oneshot channel to signal when READY is received the first time
141        let (ready_tx, ready_rx) = tokio::sync::oneshot::channel::<()>();
142
143        // executable closers (executed within the loop)
144        let on_ready = self.on_ready.take();
145        let on_activity_send = self.on_activity_send.take();
146        let on_disconnect = self.on_disconnect.take();
147        let on_retry = self.on_retry.take();
148
149        let join_handle = tokio::spawn(async move {
150            let mut backoff = 1;
151            let mut last_activity: Option<ActivitySpec> = None;
152            let mut ready_tx = Some(ready_tx);
153            let mut connected = false;
154            let mut retries = 0;
155
156            let mut session_start: Option<u64> = None;
157
158            'outer: loop {
159                if max_retries != 0 && retries == max_retries {
160                    break;
161                }
162
163                // initial connect
164                let mut socket = match DiscordSock::new().await {
165                    Ok(s) => s,
166                    Err(_) => {
167                        sleep(Duration::from_secs(backoff)).await;
168
169                        retries += 1;
170                        if let Some(f) = &on_retry {
171                            f(retries)
172                        }
173
174                        continue;
175                    }
176                };
177
178                // initial handshake
179                if socket.do_handshake(&client_id).await.is_err() {
180                    sleep(Duration::from_secs(backoff)).await;
181
182                    retries += 1;
183                    if let Some(f) = &on_retry {
184                        f(retries)
185                    }
186
187                    continue;
188                }
189
190                // ready loop
191                loop {
192                    let frame = match socket.read_frame().await {
193                        Ok(f) => f,
194                        Err(_) => {
195                            break;
196                        }
197                    };
198
199                    if frame.opcode != 1 {
200                        continue;
201                    }
202
203                    if let Ok(json) = serde_json::from_slice::<ReadyRPCFrame>(&frame.body) {
204                        if json.cmd.as_deref() == Some("DISPATCH")
205                            && json.evt.as_deref() == Some("READY")
206                        {
207                            if let Some(tx) = ready_tx.take() {
208                                let _ = tx.send(());
209                            }
210                            if let Some(f) = &on_ready {
211                                if let Some(data) = json.data {
212                                    f(data);
213                                }
214                            }
215                            connected = true;
216                            break;
217                        }
218
219                        if json.evt.as_deref() == Some("ERROR") && show_errors {
220                            eprintln!("Discord RPC ready event receiver error: {:?}", json.data);
221                        }
222                    }
223                }
224
225                // restore last activity (if any)
226                if let Some(activity) = &last_activity {
227                    if let Some(t) = session_start {
228                        if let Err(e) = socket.send_activity(activity.clone(), t).await {
229                            if show_errors {
230                                eprintln!("Discord RPC last activity restore error: {e}")
231                            }
232                        }
233                    }
234                }
235
236                backoff = 1;
237                retries = 0;
238
239                // generic loop for receiving commands and responding to pings from Discord itself
240                let disconnect_reason = loop {
241                    tokio::select! {
242                        biased;
243
244                        cmd = rx.recv() => {
245                            match cmd {
246                                Some(cmd) => {
247                                    match cmd {
248                                        IPCCommand::SetActivity { activity } => {
249                                            let session_start_unpacked = if let Some(s) = session_start {
250                                                s
251                                            } else {
252                                                match get_current_timestamp() {
253                                                    Ok(t) => {
254                                                        session_start = Some(t);
255                                                        t
256                                                    },
257                                                    Err(e) => {
258                                                        if show_errors {
259                                                            eprintln!("Discord RPC pre-send_activity time parsing error: {e}")
260                                                        }
261                                                        break Some(DisconnectReason::OldRelicComputer(e.to_string()));
262                                                    }
263                                                }
264                                            };
265
266                                            let activity = *activity;
267                                            last_activity = Some(activity.clone());
268
269                                            if let Err(e) = socket.send_activity(activity, session_start_unpacked).await {
270                                                if show_errors {
271                                                    eprintln!("Discord RPC send_activity error: {e}");
272                                                }
273                                                break Some(DisconnectReason::SendActivityError(e.to_string()));
274                                            }
275                                        },
276                                        IPCCommand::ClearActivity => {
277                                            last_activity = None;
278                                            session_start = None;
279
280                                            if let Err(e) = socket.clear_activity().await {
281                                                if show_errors {
282                                                    eprintln!("Discord RPC clear_activity error: {e}");
283                                                }
284                                                break Some(DisconnectReason::ClearActivityError(e.to_string()));
285                                            }
286                                        },
287                                        IPCCommand::Close { done_tx }=> {
288                                            let _ = socket.close().await;
289                                            let _ = done_tx.send(());
290                                            break 'outer;
291                                        }
292                                    }
293                                },
294                                None => break Some(DisconnectReason::ClientChannelClosed),
295                            }
296                        }
297
298                        frame = socket.read_frame() => {
299                            match frame {
300                                Ok(frame) => {
301                                    if let Ok(o) = Opcode::try_from(frame.opcode) { match o {
302                                        Opcode::Frame => {
303                                            if let Ok(json) = serde_json::from_slice::<DynamicRPCFrame>(&frame.body) {
304                                                if json.evt.as_deref() == Some("ERROR") && show_errors {
305                                                    eprintln!("Discord RPC DynamicRPCFrame error: {:?}", json.data);
306                                                } else if json.cmd.as_deref() == Some("SET_ACTIVITY") {
307                                                    if let Some(f) = &on_activity_send {
308                                                        if let Some(data) = json.data {
309                                                            let data = serde_json::from_value::<ActivityResponseData>(data);
310
311                                                            if let Ok(d) = data {
312                                                                f(d)
313                                                            } else if let Err(e) = data{
314                                                                println!("{e}")
315                                                            }
316                                                        }
317                                                    }
318                                                }
319                                            }
320                                        },
321                                        Opcode::Close => break Some(DisconnectReason::ServerClosed),
322                                        Opcode::Ping => {
323                                            if let Err(e) = socket.send_frame(Opcode::Pong, frame.body).await {
324                                                if show_errors {
325                                                    eprintln!("Discord RPC send_frame error: {e}");
326                                                }
327                                                break Some(DisconnectReason::SendFrameError(e.to_string()));
328                                            }
329                                        },
330                                        _ => {}
331                                    } }
332                                },
333                                Err(e) => {
334                                    if show_errors {
335                                        eprintln!("Discord RPC generic frame read error: {e}")
336                                    }
337                                    if let DiscordSockError::IoError(error) = &e {
338                                        if error.kind() == std::io::ErrorKind::UnexpectedEof {
339                                            break Some(DisconnectReason::PeerClosed);
340                                        }
341                                    }
342                                    break Some(DisconnectReason::ReadFrameError(e.to_string()));
343                                },
344                            }
345                        }
346                    }
347                };
348
349                if connected {
350                    if let Some(f) = &on_disconnect {
351                        f(disconnect_reason.unwrap_or(DisconnectReason::Unknown));
352                    }
353                    connected = false;
354                }
355
356                sleep(Duration::from_secs(backoff)).await;
357                backoff = (backoff * 2).min(4);
358            }
359        });
360
361        self.join_handle = Some(join_handle);
362
363        if wait_for_ready {
364            match ready_rx.await {
365                Ok(()) => (),
366                Err(_) => return Err(PresenceRunnerError::ExitBeforeReady),
367            }
368        }
369
370        Ok(&self.client)
371    }
372
373    /// Returns a clone of the client handle for sharing.
374    #[must_use]
375    pub fn clone_handle(&self) -> PresenceClient {
376        self.client.clone()
377    }
378
379    /// Waits for the IPC task to finish.
380    ///
381    /// NOTE: If there's no `join_handle` present, the function will do nothing and
382    /// just return blank.
383    pub async fn wait(&mut self) -> Result<(), PresenceRunnerError> {
384        if let Some(handle) = self.join_handle.take() {
385            handle.await?;
386        }
387
388        Ok(())
389    }
390}