#![allow(clippy::collapsible_if)]
use bytes::{BufMut, BytesMut};
use serde_json::{json, Value};
use std::collections::VecDeque;
use std::process;
use std::time::{Duration, Instant};
use super::traits::ipc_utils::read_u32_le;
use super::traits::{read_exact, write_all, AsyncRead, AsyncWrite};
use crate::activity::Activity;
use crate::debug_println;
use crate::error::{DiscordIpcError, Result};
use crate::ipc::{constants, Command, HandshakePayload, IpcMessage, Opcode};
use crate::nonce::generate_nonce;
/// Asynchronous client for Discord's local IPC protocol.
///
/// Generic over the transport `T`, which must support async reads and writes.
/// Frames are encoded into and decoded from reusable internal buffers.
pub struct AsyncDiscordIpcClient<T: AsyncRead + AsyncWrite + Unpin> {
    /// Transport every frame is written to and read from.
    connection: T,
    /// Application id sent in the handshake payload.
    client_id: String,
    /// Scratch buffer reused for each incoming payload.
    read_buf: BytesMut,
    /// Scratch buffer reused for each outgoing frame.
    write_buf: BytesMut,
    /// Messages read while waiting for a different nonce, kept (oldest first)
    /// until a caller asks for them.
    pending_messages: VecDeque<PendingMessage>,
}
impl<T> AsyncDiscordIpcClient<T>
where
    T: AsyncRead + AsyncWrite + Unpin,
{
    /// Starting capacity, in bytes, of the reusable read and write buffers.
    const INITIAL_BUFFER_CAPACITY: usize = 4096;

    /// Creates a client for `client_id` that talks over an already-open `connection`.
    ///
    /// No I/O is performed here; call [`Self::connect`] to perform the handshake.
    pub fn new(client_id: impl Into<String>, connection: T) -> Self {
        Self {
            connection,
            client_id: client_id.into(),
            read_buf: BytesMut::with_capacity(Self::INITIAL_BUFFER_CAPACITY),
            write_buf: BytesMut::with_capacity(Self::INITIAL_BUFFER_CAPACITY),
            pending_messages: VecDeque::new(),
        }
    }

    /// Performs the IPC handshake and returns Discord's handshake response payload.
    ///
    /// Any messages buffered from a previous session are discarded first.
    ///
    /// # Errors
    ///
    /// - [`DiscordIpcError::SerializationFailed`] if the handshake payload cannot be
    ///   serialized.
    /// - A Discord error (built by `DiscordIpcError::discord_error`) if the response
    ///   carries an `error` object with an integer `code` and a string `message`.
    /// - [`DiscordIpcError::HandshakeFailed`] if that `error` object is malformed, or if
    ///   the response opcode is not a handshake response.
    /// - Any error raised while writing the request or reading the reply.
    pub async fn connect(&mut self) -> Result<Value> {
        self.pending_messages.clear();
        let handshake = HandshakePayload {
            v: constants::IPC_VERSION,
            client_id: self.client_id.clone(),
        };
        let payload =
            serde_json::to_value(handshake).map_err(DiscordIpcError::SerializationFailed)?;
        self.send_message(Opcode::Handshake, &payload).await?;
        // The handshake reply is read straight off the wire: nothing else can be
        // in flight before the handshake completes.
        let (opcode, response) = self.recv_from_connection().await?;
        debug_println!("Handshake response: {}", response);
        // Check the error payload before the opcode so that a Discord-reported
        // failure is surfaced instead of a generic opcode mismatch.
        Self::check_response_error(&response, |err| {
            DiscordIpcError::HandshakeFailed(format!("Invalid error format: {}", err))
        })?;
        if !opcode.is_handshake_response() {
            return Err(DiscordIpcError::HandshakeFailed(format!(
                "Expected handshake response opcode, got {:?}",
                opcode
            )));
        }
        Ok(response)
    }

    /// Validates `activity`, sends it with `SET_ACTIVITY` for this process, and waits
    /// for the reply carrying the request's nonce.
    ///
    /// # Errors
    ///
    /// - [`DiscordIpcError::InvalidActivity`] if `activity.validate()` fails (nothing is
    ///   sent in that case).
    /// - [`DiscordIpcError::InvalidResponse`] if the reply opcode is not a frame response
    ///   or its `error` object is malformed.
    /// - A Discord error if the reply carries a well-formed `error` object.
    /// - Any serialization or connection error.
    pub async fn set_activity(&mut self, activity: &Activity) -> Result<()> {
        if let Err(reason) = activity.validate() {
            return Err(DiscordIpcError::InvalidActivity(reason));
        }
        let nonce = generate_nonce("set-activity");
        let message = IpcMessage {
            cmd: Command::SetActivity,
            args: json!({
                "pid": process::id(),
                "activity": activity
            }),
            nonce: nonce.clone(),
        };
        let payload = serde_json::to_value(message)?;
        self.send_message(Opcode::Frame, &payload).await?;
        // `recv_for_nonce` only returns a message whose nonce matches, so no further
        // nonce comparison is needed here.
        let (opcode, response) = self.recv_for_nonce(&nonce).await?;
        Self::check_frame_response(opcode, &response)
    }

    /// Clears this process's activity by sending `SET_ACTIVITY` with a `null` activity.
    ///
    /// Returns the raw reply payload.
    ///
    /// # Errors
    ///
    /// Same as [`Self::set_activity`], except that no activity validation is performed.
    pub async fn clear_activity(&mut self) -> Result<Value> {
        let nonce = generate_nonce("clear-activity");
        let message = IpcMessage {
            cmd: Command::SetActivity,
            args: json!({
                "pid": process::id(),
                "activity": Value::Null
            }),
            nonce: nonce.clone(),
        };
        let payload = serde_json::to_value(message)?;
        self.send_message(Opcode::Frame, &payload).await?;
        // The nonce is guaranteed to match by `recv_for_nonce`.
        let (opcode, response) = self.recv_for_nonce(&nonce).await?;
        debug_println!("Clear Activity response: {}", response);
        Self::check_frame_response(opcode, &response)?;
        Ok(response)
    }

    /// Frames `payload` as `[opcode: u32 LE][length: u32 LE][JSON bytes]` and writes
    /// the whole frame to the connection.
    ///
    /// # Errors
    ///
    /// Returns an error if `payload` cannot be serialized or the write fails.
    pub async fn send_message(&mut self, opcode: Opcode, payload: &Value) -> Result<()> {
        let raw = serde_json::to_vec(payload)?;
        self.write_buf.clear();
        self.write_buf.reserve(8 + raw.len());
        self.write_buf.put_u32_le(opcode.into());
        // NOTE(review): `as u32` truncates a body larger than u32::MAX bytes; outgoing
        // payloads are assumed to stay far below that.
        self.write_buf.put_u32_le(raw.len() as u32);
        self.write_buf.extend_from_slice(&raw);
        write_all(&mut self.connection, &self.write_buf).await?;
        Ok(())
    }

    /// Returns the next message.
    ///
    /// Messages buffered while waiting for a specific nonce are returned first, oldest
    /// first. Only when the buffer is empty is a new frame read from the connection.
    pub async fn recv_message(&mut self) -> Result<(Opcode, Value)> {
        self.next_message().await
    }

    /// Drops buffered messages older than `max_age` and returns how many were removed.
    ///
    /// A zero `max_age` drops every buffered message regardless of age.
    pub fn cleanup_pending(&mut self, max_age: Duration) -> usize {
        if max_age.is_zero() {
            let dropped = self.pending_messages.len();
            self.pending_messages.clear();
            return dropped;
        }
        let now = Instant::now();
        let original_len = self.pending_messages.len();
        self.pending_messages
            .retain(|message| now.saturating_duration_since(message.received_at) <= max_age);
        original_len - self.pending_messages.len()
    }

    /// Pops the oldest buffered message, or reads a new one if the buffer is empty.
    async fn next_message(&mut self) -> Result<(Opcode, Value)> {
        match self.pending_messages.pop_front() {
            Some(PendingMessage {
                opcode, payload, ..
            }) => Ok((opcode, payload)),
            None => self.recv_from_connection().await,
        }
    }

    /// Returns the first message whose `nonce` equals `expected_nonce`.
    ///
    /// The buffer is searched first. Otherwise frames are read from the connection until
    /// one matches; non-matching frames are appended to the buffer for later callers.
    /// The buffer has no size limit of its own, so call [`Self::cleanup_pending`] to
    /// bound it.
    async fn recv_for_nonce(&mut self, expected_nonce: &str) -> Result<(Opcode, Value)> {
        if let Some(message) = self.take_pending_by_nonce(expected_nonce) {
            return Ok(message);
        }
        loop {
            let (opcode, response) = self.recv_from_connection().await?;
            if Self::value_has_nonce(&response, expected_nonce) {
                return Ok((opcode, response));
            }
            self.pending_messages
                .push_back(PendingMessage::new(opcode, response));
        }
    }

    /// Reads exactly one frame from the connection and decodes its JSON body.
    ///
    /// # Errors
    ///
    /// - [`DiscordIpcError::InvalidResponse`] if the declared length exceeds
    ///   `MAX_PAYLOAD_SIZE`. The body is not consumed in this case.
    /// - [`DiscordIpcError::SocketClosed`] if the body cannot be read in full.
    /// - An opcode conversion error if the header opcode is unknown. The body has
    ///   already been consumed at that point, so the stream stays aligned.
    /// - A JSON error if the body is not valid JSON.
    async fn recv_from_connection(&mut self) -> Result<(Opcode, Value)> {
        let opcode_raw = read_u32_le(&mut self.connection).await?;
        let length = read_u32_le(&mut self.connection).await?;
        if length > crate::ipc::protocol::constants::MAX_PAYLOAD_SIZE {
            return Err(DiscordIpcError::InvalidResponse(format!(
                "Payload size {} exceeds maximum allowed size of {} bytes",
                length,
                crate::ipc::protocol::constants::MAX_PAYLOAD_SIZE
            )));
        }
        self.read_buf.clear();
        self.read_buf.resize(length as usize, 0);
        read_exact(&mut self.connection, &mut self.read_buf[..])
            .await
            .map_err(|_| DiscordIpcError::SocketClosed)?;
        // Validate the opcode only after the body has been consumed. Rejecting it
        // earlier would leave `length` unread bytes in the stream, and the next read
        // would misinterpret them as a frame header.
        let opcode = Opcode::try_from(opcode_raw)?;
        let value: Value = serde_json::from_slice(&self.read_buf)?;
        Ok((opcode, value))
    }

    /// Removes and returns the oldest buffered message whose nonce is `expected_nonce`.
    fn take_pending_by_nonce(&mut self, expected_nonce: &str) -> Option<(Opcode, Value)> {
        let index = self
            .pending_messages
            .iter()
            .position(|message| Self::value_has_nonce(&message.payload, expected_nonce))?;
        self.pending_messages
            .remove(index)
            .map(|PendingMessage { opcode, payload, .. }| (opcode, payload))
    }

    /// True when `value["nonce"]` is a string equal to `expected_nonce`.
    fn value_has_nonce(value: &Value, expected_nonce: &str) -> bool {
        value.get("nonce").and_then(Value::as_str) == Some(expected_nonce)
    }

    /// Checks a command reply: the opcode must be a frame response and the body must not
    /// carry an `error` object.
    fn check_frame_response(opcode: Opcode, response: &Value) -> Result<()> {
        if !opcode.is_frame_response() {
            return Err(DiscordIpcError::InvalidResponse(format!(
                "Expected frame response, got {:?}",
                opcode
            )));
        }
        Self::check_response_error(response, |err| {
            DiscordIpcError::InvalidResponse(format!("Invalid error format in response: {}", err))
        })
    }

    /// Turns an `error` object in `response` into an `Err`.
    ///
    /// Returns `Ok(())` when there is no `error` key. If the object has an integer `code`
    /// that fits in `i32` and a string `message`, a Discord error is returned.
    /// Otherwise `malformed` builds the error from the raw `error` value.
    ///
    /// NOTE(review): only the top-level `error` key is inspected; Discord RPC
    /// replies that signal failure via `"evt": "ERROR"` with `data.code` and
    /// `data.message` are not detected here — confirm whether that is intended.
    fn check_response_error(
        response: &Value,
        malformed: impl FnOnce(&Value) -> DiscordIpcError,
    ) -> Result<()> {
        let err = match response.get("error") {
            Some(err) => err,
            None => return Ok(()),
        };
        // Out-of-range codes are treated as malformed rather than silently truncated.
        let code = err
            .get("code")
            .and_then(Value::as_i64)
            .and_then(|code| i32::try_from(code).ok());
        let message = err.get("message").and_then(Value::as_str);
        match (code, message) {
            (Some(code), Some(message)) => Err(DiscordIpcError::discord_error(code, message)),
            _ => Err(malformed(err)),
        }
    }
}
#[derive(Debug)]
struct PendingMessage {
opcode: Opcode,
payload: Value,
received_at: Instant,
}
impl PendingMessage {
fn new(opcode: Opcode, payload: Value) -> Self {
Self {
opcode,
payload,
received_at: Instant::now(),
}
}
}