discord_presence/models/
message.rs

1use crate::{DiscordError, Result};
2use byteorder::{LittleEndian, ReadBytesExt, WriteBytesExt};
3use num_derive::FromPrimitive;
4use num_traits::FromPrimitive;
5use serde::Serialize;
6use std::io::{Read, Write};
7
8pub(crate) const MAX_RPC_FRAME_SIZE: usize = 64 * 1024;
9pub(crate) const MAX_RPC_MESSAGE_SIZE: usize =
10    MAX_RPC_FRAME_SIZE - std::mem::size_of::<FrameHeader>();
11
12/// Codes for payload types
13#[derive(Debug, Copy, Clone, PartialEq, Eq, FromPrimitive)]
14#[repr(u32)]
15pub enum OpCode {
16    /// Handshake payload
17    Handshake = 0,
18    /// Frame payload
19    Frame = 1,
20    /// Close payload
21    Close = 2,
22    /// Ping payload
23    Ping = 3,
24    /// Pong payload
25    Pong = 4,
26}
27
28#[derive(Debug, Copy, Clone, PartialEq, Eq)]
29#[repr(C)]
30/// Header for the payload
31///
32/// Determines the length of the payload, and the type of payload
33pub struct FrameHeader {
34    /// The opcode for the payload
35    opcode: OpCode,
36    /// The length of the payload
37    length: u32,
38}
39
40impl FrameHeader {
41    #[must_use]
42    /// Convert an array of bytes to a [`FrameHeader`]
43    ///
44    /// # Safety
45    /// This reinterprets the bytes as a [`FrameHeader`]. It is up to the caller to ensure that the
46    /// bytes are valid.
47    pub unsafe fn from_bytes(bytes: &[u8]) -> Option<Self> {
48        if bytes.len() != std::mem::size_of::<FrameHeader>() {
49            return None;
50        }
51
52        let header: Self = unsafe { std::ptr::read_unaligned(bytes.as_ptr().cast()) };
53
54        if header.message_length() > MAX_RPC_MESSAGE_SIZE {
55            return None;
56        }
57
58        Some(header)
59    }
60
61    #[must_use]
62    /// Get the expected message length
63    pub fn message_length(&self) -> usize {
64        self.length as usize
65    }
66
67    #[must_use]
68    /// Get the opcode
69    pub fn opcode(&self) -> OpCode {
70        self.opcode
71    }
72}
73
74// NOTE: Currently unused
75// Probably remove in future
76// #[derive(Debug, Copy, Clone, PartialEq, Eq)]
77// #[repr(C)]
78// /// Frame passed over the socket
79// ///
80// /// Contains the header and the payload
81// pub struct Frame {
82//     /// The header for the payload
83//     header: FrameHeader,
84//     /// The actual payload
85//     message: [std::os::raw::c_char; MAX_RPC_FRAME_SIZE - std::mem::size_of::<FrameHeader>()],
86// }
87
88// impl From<Frame> for Message {
89//     fn from(header: Frame) -> Self {
90//         Self {
91//             opcode: header.header.opcode,
92//             payload: unsafe { CStr::from_ptr(header.message.as_ptr()) }
93//                 .to_string_lossy()
94//                 .into_owned(),
95//         }
96//     }
97// }
98
99/// Message struct for the Discord RPC
100#[derive(Debug, PartialEq, Eq, Clone)]
101pub struct Message {
102    /// The payload type for this `Message`
103    pub opcode: OpCode,
104    /// The actual payload
105    pub payload: String,
106}
107
108impl Message {
109    /// Create a new `Message`
110    ///
111    /// # Errors
112    /// - Could not serialize the payload
113    pub fn new<T>(opcode: OpCode, payload: T) -> Result<Self>
114    where
115        T: Serialize,
116    {
117        Ok(Self {
118            opcode,
119            payload: serde_json::to_string(&payload)?,
120        })
121    }
122
123    /// Encode message
124    ///
125    /// # Errors
126    /// - Failed to write to the buffer
127    ///
128    /// # Panics
129    /// - The payload length is not a 32 bit number
130    pub fn encode(&self) -> Result<Vec<u8>> {
131        let mut bytes: Vec<u8> = vec![];
132
133        let payload_length = u32::try_from(self.payload.len()).expect("32-bit payload length");
134
135        bytes.write_u32::<LittleEndian>(self.opcode as u32)?;
136        bytes.write_u32::<LittleEndian>(payload_length)?;
137        bytes.write_all(self.payload.as_bytes())?;
138
139        Ok(bytes)
140    }
141
142    /// Decode message
143    ///
144    /// # Errors
145    /// - Failed to read from buffer
146    pub fn decode(mut bytes: &[u8]) -> Result<Self> {
147        let opcode =
148            OpCode::from_u32(bytes.read_u32::<LittleEndian>()?).ok_or(DiscordError::Conversion)?;
149        let len = bytes.read_u32::<LittleEndian>()? as usize;
150        let mut payload = String::with_capacity(len);
151        bytes.read_to_string(&mut payload)?;
152
153        Ok(Self { opcode, payload })
154    }
155}
156
#[cfg(test)]
mod tests {
    use super::*;
    // The parent module only imports `Serialize`, so `Deserialize` must be brought into scope
    // here for the derive below to resolve.
    use serde::Deserialize;

    #[derive(Debug, PartialEq, Serialize, Deserialize)]
    struct Something {
        empty: bool,
    }

    #[test]
    fn test_encoder() {
        let msg = Message::new(OpCode::Frame, Something { empty: true })
            .expect("Failed to serialize message");
        let encoded = msg.encode().expect("Failed to encode message");
        let decoded = Message::decode(&encoded).expect("Failed to decode message");
        assert_eq!(msg, decoded);
    }

    #[test]
    fn test_encoded_layout() {
        let msg = Message::new(OpCode::Frame, Something { empty: true })
            .expect("Failed to serialize message");
        let encoded = msg.encode().expect("Failed to encode message");
        let payload = msg.payload.as_bytes();
        let payload_len = u32::try_from(payload.len()).expect("payload fits in a u32");

        // Wire layout: LE opcode, LE payload length, then the payload bytes.
        assert_eq!(&encoded[..4], &(OpCode::Frame as u32).to_le_bytes());
        assert_eq!(&encoded[4..8], &payload_len.to_le_bytes());
        assert_eq!(&encoded[8..], payload);
    }

    #[test]
    fn test_payload_is_json() {
        let original = Something { empty: true };
        let msg = Message::new(OpCode::Frame, &original).expect("Failed to serialize message");
        let parsed: Something =
            serde_json::from_str(&msg.payload).expect("Payload is not valid JSON");
        assert_eq!(parsed, original);
    }

    #[test]
    fn test_opcode() {
        assert_eq!(OpCode::from_u32(0), Some(OpCode::Handshake));
        assert_eq!(OpCode::from_u32(4), Some(OpCode::Pong));
        assert_eq!(OpCode::from_u32(5), None);
    }
}