use std::fmt::{self, Display};
use std::io;
use thiserror::Error;
/// Optional diagnostic details attached to a protocol-violation error.
///
/// Every field is optional so a caller records only what it actually knows at the
/// point the violation is detected. The all-`None` value is available as both
/// [`ProtocolContext::new`] and [`Default::default`].
#[derive(Debug, Clone, Default, PartialEq, Eq, Hash)]
pub struct ProtocolContext {
    /// Opcode the reader expected to receive, if known.
    pub expected_opcode: Option<u32>,
    /// Opcode that was actually received, if known.
    pub received_opcode: Option<u32>,
    /// Size of the received payload, if known (presumably in bytes — confirm against the
    /// framing code).
    pub payload_size: Option<usize>,
}

impl ProtocolContext {
    /// Creates a context with every field unset; identical to `ProtocolContext::default()`.
    #[must_use]
    pub const fn new() -> Self {
        Self {
            expected_opcode: None,
            received_opcode: None,
            payload_size: None,
        }
    }

    /// Creates a context recording an opcode mismatch: `expected` versus `received`.
    ///
    /// `payload_size` is left unset.
    #[must_use]
    pub const fn with_opcodes(expected: u32, received: u32) -> Self {
        Self {
            expected_opcode: Some(expected),
            received_opcode: Some(received),
            payload_size: None,
        }
    }

    /// Creates a context recording the received opcode and the payload size that came with it.
    ///
    /// `expected_opcode` is left unset.
    #[must_use]
    pub const fn with_payload(received_opcode: u32, payload_size: usize) -> Self {
        Self {
            expected_opcode: None,
            received_opcode: Some(received_opcode),
            payload_size: Some(payload_size),
        }
    }
}
/// Coarse classification of a [`DiscordIpcError`], as returned by `DiscordIpcError::category`.
///
/// The `Display` form is the lowercase variant name, e.g. `"connection"`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ErrorCategory {
    /// Discovering, establishing, or keeping the IPC socket connection failed.
    Connection,
    /// A response, handshake, or opcode did not follow the expected protocol.
    Protocol,
    /// JSON encoding or decoding failed.
    Serialization,
    /// Discord itself reported an error.
    Application,
    /// Everything else, such as invalid activities or system-time errors.
    Other,
}

impl Display for ErrorCategory {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let label = match *self {
            ErrorCategory::Connection => "connection",
            ErrorCategory::Protocol => "protocol",
            ErrorCategory::Serialization => "serialization",
            ErrorCategory::Application => "application",
            ErrorCategory::Other => "other",
        };
        // Raw write: like a bare `write!` with a literal, this ignores width/fill flags.
        f.write_str(label)
    }
}
/// Error type for this crate's Discord IPC operations.
///
/// Each variant's `#[error(...)]` attribute defines its `Display` text (via `thiserror`).
/// Use `DiscordIpcError::category` for a coarse grouping and
/// `DiscordIpcError::is_recoverable` to decide whether a retry makes sense.
///
/// NOTE(review): `ConnectionFailed`, `SerializationFailed` and `DeserializationFailed` both
/// print their inner error in `Display` (`{0}`) and expose it via `Error::source`
/// (`#[source]`). Reporters that walk the source chain will therefore show that message
/// twice. The `std::error::Error` docs recommend doing one or the other. Changing this
/// alters user-visible messages, so confirm with callers first.
#[derive(Error, Debug)]
pub enum DiscordIpcError {
    /// An I/O error, reported as a failure to connect to the IPC socket.
    ///
    /// The `From<io::Error>` impl in this module also produces this variant. It can
    /// therefore wrap an I/O error from any `?` applied to an `io::Result`, not only one
    /// raised while connecting.
    #[error("Failed to connect to Discord IPC socket: {0}")]
    ConnectionFailed(#[source] io::Error),
    /// Socket discovery failed; the message lists every path in `attempted_paths`.
    #[error("Failed to discover Discord socket. Attempted paths: {}", attempted_paths.join(", "))]
    SocketDiscoveryFailed {
        /// Underlying I/O error, exposed through `Error::source` (not printed in `Display`).
        #[source]
        source: io::Error,
        /// Paths that were tried, joined with `", "` in the `Display` text.
        attempted_paths: Vec<String>,
    },
    /// Connecting did not finish within `timeout_ms` milliseconds.
    #[error("Connection timeout after {timeout_ms}ms")]
    ConnectionTimeout {
        /// The timeout that elapsed, in milliseconds (printed with an `ms` suffix).
        timeout_ms: u64,
        /// Optional description of an earlier failure — presumably the last attempt's
        /// error; TODO confirm. Not included in the `Display` text.
        last_error: Option<String>,
    },
    /// No usable Discord IPC socket was found; the message asks whether Discord is running.
    #[error("No Discord IPC socket found. Is Discord running?")]
    NoValidSocket,
    /// A `serde_json` error; the message describes serializing an outgoing JSON payload.
    #[error("Failed to serialize JSON payload: {0}")]
    SerializationFailed(#[source] serde_json::Error),
    /// A `serde_json` error; the message describes deserializing a response from Discord.
    #[error("Failed to deserialize response from Discord: {0}")]
    DeserializationFailed(#[source] serde_json::Error),
    /// A response from Discord could not be interpreted; carries a textual description.
    #[error("Invalid response from Discord: {0}")]
    InvalidResponse(String),
    /// The IPC handshake failed; carries a textual description.
    #[error("Handshake failed: {0}")]
    HandshakeFailed(String),
    /// The socket connection was closed unexpectedly.
    #[error("Socket connection was closed unexpectedly")]
    SocketClosed,
    /// An opcode value that was not accepted; carries the raw value.
    #[error("Invalid opcode: {0}")]
    InvalidOpcode(u32),
    /// A protocol rule was violated.
    #[error("Protocol violation: {message}")]
    ProtocolViolation {
        /// Human-readable description; this is the only field shown in `Display`.
        message: String,
        /// Opcode/payload details for diagnostics; not shown in `Display`.
        context: ProtocolContext,
    },
    /// An error reported by Discord, rendered as `"<code> - <message>"`.
    #[error("Discord error: {code} - {message}")]
    DiscordError {
        /// Numeric error code, printed before the message.
        code: i32,
        /// Error message text.
        message: String,
    },
    /// An activity was considered invalid; carries a textual description.
    #[error("Invalid activity: {0}")]
    InvalidActivity(String),
    /// A system-time error stored as text (presumably from `SystemTime` arithmetic —
    /// confirm at the call sites).
    #[error("System time error: {0}")]
    SystemTimeError(String),
}
impl DiscordIpcError {
    /// Returns the coarse [`ErrorCategory`] this error belongs to.
    ///
    /// The match is exhaustive, so adding a variant forces a classification decision.
    pub fn category(&self) -> ErrorCategory {
        match self {
            Self::DiscordError { .. } => ErrorCategory::Application,
            Self::InvalidActivity(..) | Self::SystemTimeError(..) => ErrorCategory::Other,
            Self::SerializationFailed(..) | Self::DeserializationFailed(..) => {
                ErrorCategory::Serialization
            }
            Self::InvalidResponse(..)
            | Self::HandshakeFailed(..)
            | Self::InvalidOpcode(..)
            | Self::ProtocolViolation { .. } => ErrorCategory::Protocol,
            Self::ConnectionFailed(..)
            | Self::SocketDiscoveryFailed { .. }
            | Self::ConnectionTimeout { .. }
            | Self::NoValidSocket
            | Self::SocketClosed => ErrorCategory::Connection,
        }
    }

    /// Returns `true` when [`Self::category`] is [`ErrorCategory::Connection`].
    pub fn is_connection_error(&self) -> bool {
        self.category() == ErrorCategory::Connection
    }

    /// Returns `true` for errors treated as transient and therefore worth retrying.
    ///
    /// Only `ConnectionTimeout`, `SocketClosed`, `InvalidResponse` and
    /// `SocketDiscoveryFailed` qualify. `ConnectionFailed` and `NoValidSocket` do not.
    pub fn is_recoverable(&self) -> bool {
        match self {
            Self::SocketDiscoveryFailed { .. }
            | Self::ConnectionTimeout { .. }
            | Self::SocketClosed
            | Self::InvalidResponse(..) => true,
            Self::ConnectionFailed(..)
            | Self::NoValidSocket
            | Self::SerializationFailed(..)
            | Self::DeserializationFailed(..)
            | Self::HandshakeFailed(..)
            | Self::InvalidOpcode(..)
            | Self::ProtocolViolation { .. }
            | Self::DiscordError { .. }
            | Self::InvalidActivity(..)
            | Self::SystemTimeError(..) => false,
        }
    }

    /// Builds a [`DiscordIpcError::DiscordError`] from a code and any string-like message.
    pub fn discord_error(code: i32, message: impl Into<String>) -> Self {
        let message: String = message.into();
        Self::DiscordError { code, message }
    }

    /// Builds a [`DiscordIpcError::SocketDiscoveryFailed`] from the underlying I/O error and
    /// the list of paths that were tried.
    pub fn socket_discovery_failed(source: io::Error, attempted_paths: Vec<String>) -> Self {
        Self::SocketDiscoveryFailed {
            attempted_paths,
            source,
        }
    }

    /// Builds a [`DiscordIpcError::ConnectionTimeout`] for a timeout of `timeout_ms`
    /// milliseconds, optionally recording a description of an earlier failure.
    pub fn connection_timeout(timeout_ms: u64, last_error: Option<String>) -> Self {
        Self::ConnectionTimeout {
            last_error,
            timeout_ms,
        }
    }

    /// Builds a [`DiscordIpcError::ProtocolViolation`] with a message and diagnostic context.
    pub fn protocol_violation(message: impl Into<String>, context: ProtocolContext) -> Self {
        let message: String = message.into();
        Self::ProtocolViolation { context, message }
    }
}
impl From<io::Error> for DiscordIpcError {
fn from(error: io::Error) -> Self {
Self::ConnectionFailed(error)
}
}
impl From<serde_json::Error> for DiscordIpcError {
    /// Lets `?` lift a `serde_json` error into the matching serialization variant.
    ///
    /// Syntax and premature-EOF errors can only come from parsing input. They become
    /// [`DiscordIpcError::DeserializationFailed`], so a failed `from_slice`/`from_str`
    /// is no longer reported as "Failed to serialize JSON payload".
    ///
    /// `Data` and `Io` errors can arise on either side, so they keep the original mapping
    /// to [`DiscordIpcError::SerializationFailed`]. Decoding call sites that hit semantic
    /// (`Data`) errors should still map explicitly to `DeserializationFailed`.
    ///
    /// Both variants belong to the same error category and neither is recoverable, so
    /// classification is unaffected.
    fn from(error: serde_json::Error) -> Self {
        if error.is_syntax() || error.is_eof() {
            Self::DeserializationFailed(error)
        } else {
            Self::SerializationFailed(error)
        }
    }
}
/// Crate-wide result alias: the success type defaults to `()`, and the error is always
/// [`DiscordIpcError`].
pub type Result<T = ()> = core::result::Result<T, DiscordIpcError>;
/// Helpers for attaching caller context to [`DiscordIpcError`]s and for one-shot retries.
pub mod utils {
    use super::DiscordIpcError;
    use std::error::Error;
    use std::fmt::{self, Display};

    /// A [`DiscordIpcError`] plus an optional note describing what the caller was doing.
    ///
    /// `Display` renders `"<context>: <error>"` when a note is present and just the error
    /// otherwise. `Error::source` always returns the wrapped error.
    #[derive(Debug)]
    pub struct AppError {
        source: DiscordIpcError,
        context: Option<String>,
    }

    impl AppError {
        /// Wraps `source` together with the context note `context`.
        pub fn new(source: DiscordIpcError, context: impl Into<String>) -> Self {
            let note: String = context.into();
            Self {
                context: Some(note),
                source,
            }
        }

        /// Wraps `source` with no context note.
        pub fn from_error(source: DiscordIpcError) -> Self {
            Self {
                context: None,
                source,
            }
        }

        /// Borrows the wrapped [`DiscordIpcError`].
        pub fn discord_error(&self) -> &DiscordIpcError {
            let AppError { source, .. } = self;
            source
        }

        /// Returns the context note, if one was attached.
        pub fn context(&self) -> Option<&str> {
            self.context.as_deref()
        }
    }

    impl Display for AppError {
        fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
            match self.context.as_deref() {
                Some(note) => write!(f, "{}: {}", note, self.source),
                None => write!(f, "{}", self.source),
            }
        }
    }

    impl Error for AppError {
        fn source(&self) -> Option<&(dyn Error + 'static)> {
            let inner: &(dyn Error + 'static) = &self.source;
            Some(inner)
        }
    }

    /// Extension methods for `Result<T, DiscordIpcError>`.
    pub trait ResultExt<T> {
        /// Converts an error into an [`AppError`] carrying `context`; `Ok` passes through.
        fn with_context(self, context: impl Into<String>) -> std::result::Result<T, AppError>;

        /// Maps the error with `f`; equivalent to `Result::map_err`.
        fn map_err_to<E>(self, f: impl FnOnce(DiscordIpcError) -> E) -> std::result::Result<T, E>;

        /// Calls `retry_op` once when this is an error for which `is_recoverable` returns
        /// `true`, and returns the retry's outcome (the original error is dropped).
        ///
        /// `Ok` values and non-recoverable errors are returned unchanged without calling
        /// `retry_op`.
        fn retry_if<F>(
            self,
            is_recoverable: fn(&DiscordIpcError) -> bool,
            retry_op: F,
        ) -> std::result::Result<T, DiscordIpcError>
        where
            F: FnOnce() -> std::result::Result<T, DiscordIpcError>;
    }

    impl<T> ResultExt<T> for std::result::Result<T, DiscordIpcError> {
        fn with_context(self, context: impl Into<String>) -> std::result::Result<T, AppError> {
            self.map_err(move |source| AppError::new(source, context))
        }

        fn map_err_to<E>(self, f: impl FnOnce(DiscordIpcError) -> E) -> std::result::Result<T, E> {
            self.map_err(f)
        }

        fn retry_if<F>(
            self,
            is_recoverable: fn(&DiscordIpcError) -> bool,
            retry_op: F,
        ) -> std::result::Result<T, DiscordIpcError>
        where
            F: FnOnce() -> std::result::Result<T, DiscordIpcError>,
        {
            // `or_else` leaves `Ok` untouched; only the error path decides whether to retry.
            self.or_else(|err| {
                if is_recoverable(&err) {
                    retry_op()
                } else {
                    Err(err)
                }
            })
        }
    }
}
#[cfg(test)]
mod tests {
    // `super::utils` (rather than a hard-coded `crate::error::utils`) keeps these tests
    // working if this module is moved or renamed.
    use super::utils::{AppError, ResultExt};
    use super::*;
    use std::cell::Cell;

    #[test]
    fn protocol_context_helpers_populate_fields() {
        let empty = ProtocolContext::new();
        assert!(empty.expected_opcode.is_none());
        assert!(empty.received_opcode.is_none());
        assert!(empty.payload_size.is_none());

        let with_opcodes = ProtocolContext::with_opcodes(1, 2);
        assert_eq!(with_opcodes.expected_opcode, Some(1));
        assert_eq!(with_opcodes.received_opcode, Some(2));
        assert_eq!(with_opcodes.payload_size, None);

        let with_payload = ProtocolContext::with_payload(3, 42);
        assert_eq!(with_payload.expected_opcode, None);
        assert_eq!(with_payload.received_opcode, Some(3));
        assert_eq!(with_payload.payload_size, Some(42));

        let defaulted = ProtocolContext::default();
        assert!(defaulted.expected_opcode.is_none());
        assert!(defaulted.received_opcode.is_none());
        assert!(defaulted.payload_size.is_none());
    }

    #[test]
    fn error_category_and_recoverable_detection() {
        let conn_err = DiscordIpcError::SocketClosed;
        assert_eq!(conn_err.category(), ErrorCategory::Connection);
        assert!(conn_err.is_connection_error());
        assert!(conn_err.is_recoverable());

        let proto_err = DiscordIpcError::InvalidResponse("oops".into());
        assert_eq!(proto_err.category(), ErrorCategory::Protocol);
        assert!(proto_err.is_recoverable());

        let app_err = DiscordIpcError::discord_error(4000, "bad");
        assert_eq!(app_err.category(), ErrorCategory::Application);
        assert!(!app_err.is_recoverable());

        // A connection error that is deliberately *not* recoverable.
        let missing = DiscordIpcError::NoValidSocket;
        assert_eq!(missing.category(), ErrorCategory::Connection);
        assert!(!missing.is_recoverable());

        let timeout = DiscordIpcError::connection_timeout(500, None);
        assert!(timeout.is_connection_error());
        assert!(timeout.is_recoverable());

        let violation = DiscordIpcError::protocol_violation(
            "bad frame",
            ProtocolContext::with_opcodes(1, 3),
        );
        assert_eq!(violation.category(), ErrorCategory::Protocol);
        assert!(!violation.is_connection_error());
        assert!(!violation.is_recoverable());
    }

    #[test]
    fn error_category_display_names() {
        let names: Vec<String> = [
            ErrorCategory::Connection,
            ErrorCategory::Protocol,
            ErrorCategory::Serialization,
            ErrorCategory::Application,
            ErrorCategory::Other,
        ]
        .iter()
        .map(ToString::to_string)
        .collect();
        assert_eq!(
            names,
            ["connection", "protocol", "serialization", "application", "other"]
        );
    }

    #[test]
    fn display_messages_include_details() {
        assert_eq!(
            DiscordIpcError::discord_error(4000, "bad").to_string(),
            "Discord error: 4000 - bad"
        );
        assert_eq!(
            DiscordIpcError::connection_timeout(250, None).to_string(),
            "Connection timeout after 250ms"
        );
        let discovery = DiscordIpcError::socket_discovery_failed(
            std::io::Error::from(std::io::ErrorKind::NotFound),
            vec!["/tmp/a".to_string(), "/tmp/b".to_string()],
        );
        assert_eq!(
            discovery.to_string(),
            "Failed to discover Discord socket. Attempted paths: /tmp/a, /tmp/b"
        );
    }

    #[test]
    fn io_errors_convert_to_connection_failed() {
        let err: DiscordIpcError =
            std::io::Error::from(std::io::ErrorKind::ConnectionRefused).into();
        assert!(matches!(
            err,
            DiscordIpcError::ConnectionFailed(ref e)
                if e.kind() == std::io::ErrorKind::ConnectionRefused
        ));
        assert!(err.is_connection_error());
        assert!(!err.is_recoverable());
    }

    #[test]
    fn app_error_preserves_context() {
        let err = DiscordIpcError::SocketClosed;
        let wrapped = AppError::new(err, "while sending message");
        assert!(matches!(
            wrapped.discord_error(),
            DiscordIpcError::SocketClosed
        ));
        assert_eq!(wrapped.context(), Some("while sending message"));
        assert!(format!("{}", wrapped).contains("while sending message"));

        let source = std::error::Error::source(&wrapped)
            .expect("AppError always exposes the wrapped error as its source");
        assert_eq!(source.to_string(), DiscordIpcError::SocketClosed.to_string());

        let bare = AppError::from_error(DiscordIpcError::NoValidSocket);
        assert_eq!(bare.context(), None);
        assert_eq!(bare.to_string(), DiscordIpcError::NoValidSocket.to_string());
    }

    #[test]
    fn result_ext_retry_if_retries_on_recoverable() {
        let attempts = Cell::new(0);
        let initial: Result<()> = Err(DiscordIpcError::SocketClosed);
        let outcome = initial.retry_if(DiscordIpcError::is_recoverable, || {
            attempts.set(attempts.get() + 1);
            Ok(())
        });
        assert!(outcome.is_ok());
        assert_eq!(attempts.get(), 1);
    }

    #[test]
    fn result_ext_retry_if_skips_non_recoverable_and_ok() {
        let attempts = Cell::new(0);

        let fatal: Result<()> = Err(DiscordIpcError::NoValidSocket);
        let outcome = fatal.retry_if(DiscordIpcError::is_recoverable, || {
            attempts.set(attempts.get() + 1);
            Ok(())
        });
        assert!(matches!(outcome, Err(DiscordIpcError::NoValidSocket)));
        assert_eq!(attempts.get(), 0);

        let fine: Result<u8> = Ok(7);
        let outcome = fine.retry_if(DiscordIpcError::is_recoverable, || {
            attempts.set(attempts.get() + 1);
            Ok(0)
        });
        assert_eq!(outcome.ok(), Some(7));
        assert_eq!(attempts.get(), 0);
    }

    #[test]
    fn result_ext_with_context_maps_error() {
        let result: Result<()> = Err(DiscordIpcError::SocketClosed);
        let app_result = result.with_context("connecting");
        let err = app_result.unwrap_err();
        assert_eq!(err.context(), Some("connecting"));
        assert!(matches!(err.discord_error(), DiscordIpcError::SocketClosed));
        assert_eq!(
            err.to_string(),
            format!("connecting: {}", DiscordIpcError::SocketClosed)
        );

        let ok: Result<u8> = Ok(1);
        assert_eq!(ok.with_context("unused").ok(), Some(1));
    }

    #[test]
    fn result_ext_map_err_to_transforms_error() {
        let failed: Result<u8> = Err(DiscordIpcError::InvalidOpcode(9));
        assert_eq!(
            failed.map_err_to(|e| e.to_string()),
            Err("Invalid opcode: 9".to_string())
        );
        let ok: Result<u8> = Ok(2);
        assert_eq!(ok.map_err_to(|e| e.to_string()), Ok(2));
    }
}