use std::backtrace::{Backtrace, BacktraceStatus};
use std::panic::Location;
/// Broad category of an `Error`.
///
/// Marked `#[non_exhaustive]`: code outside this crate must keep a wildcard
/// arm when matching, so new kinds can be added later without a breaking change.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
#[non_exhaustive]
pub enum ErrorKind {
    /// Invalid input; built by `Error::invalid_input`.
    InvalidInput,
    /// Invalid data; built by `Error::invalid_data`.
    InvalidData,
    /// Invalid state; built by `Error::invalid_state`.
    InvalidState,
    /// A buffer was shorter than required; built by `Error::insufficient_buffer`
    /// and `Error::check_buffer_size`.
    InsufficientBuffer,
    /// Unsupported feature or operation; built by `Error::unsupported`.
    Unsupported,
    /// Protocol violation; built by `Error::protocol_violation`.
    ProtocolViolation,
    /// Handshake rejected; built by `Error::handshake_rejected`.
    HandshakeRejected,
    /// Unsupported version; built by `Error::version_not_supported`, whose
    /// reason holds the value that was received.
    VersionNotSupported,
    /// Invalid UTF-8; built by `Error::invalid_utf8`.
    InvalidUtf8,
    /// Carries an HTTP response (status, reason phrase, headers) in the
    /// error's `http_response` field; built by `Error::http_response`.
    HttpResponse,
}
/// Details of an HTTP response, attached (boxed) to errors of kind
/// `ErrorKind::HttpResponse`.
///
/// NOTE(review): presumably this is the peer's reply to an upgrade request
/// that did not succeed — confirm at the call sites.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct HttpResponseInfo {
    /// Numeric HTTP status code (e.g. `404`).
    pub status_code: u16,
    /// Reason phrase accompanying the status code (e.g. `"Not Found"`).
    pub reason_phrase: String,
    /// Response headers, presumably as `(name, value)` pairs in received order.
    pub headers: Vec<(String, String)>,
}
/// Error value carrying a category, a free-form reason, the location where it
/// was created and a backtrace, plus optional WebSocket / HTTP details.
pub struct Error {
    /// Category of the failure.
    pub kind: ErrorKind,
    /// Human-readable explanation; empty for errors built with `Error::new`.
    pub reason: String,
    /// Source location of the code that created the error (recorded through
    /// `#[track_caller]` constructors).
    pub location: &'static Location<'static>,
    /// Backtrace taken when the error was created; it may be disabled or
    /// unsupported, so check [`Backtrace::status`] before relying on it.
    pub backtrace: Backtrace,
    /// WebSocket close code associated with this error, if one was attached.
    pub close_code: Option<u16>,
    /// HTTP response details for `ErrorKind::HttpResponse` errors. Boxed, so
    /// the field costs a single pointer when `None`.
    pub http_response: Option<Box<HttpResponseInfo>>,
}
impl Error {
    /// Creates an error of the given `kind` with an empty reason.
    ///
    /// The caller's source location is recorded and a backtrace is captured,
    /// exactly as in [`Error::with_reason`].
    #[track_caller]
    pub fn new(kind: ErrorKind) -> Self {
        Self::with_reason(kind, String::new())
    }

    /// Creates an error of the given `kind` with a human-readable `reason`.
    ///
    /// Because this is `#[track_caller]`, `location` points at the code that
    /// called the (outermost `#[track_caller]`) constructor. The backtrace is
    /// taken with [`Backtrace::capture`], which only records frames when
    /// enabled via `RUST_BACKTRACE` / `RUST_LIB_BACKTRACE`. `close_code` and
    /// `http_response` start out as `None`.
    #[track_caller]
    pub fn with_reason<T: Into<String>>(kind: ErrorKind, reason: T) -> Self {
        Self {
            kind,
            reason: reason.into(),
            location: Location::caller(),
            backtrace: Backtrace::capture(),
            close_code: None,
            http_response: None,
        }
    }

    /// Shorthand for an `ErrorKind::InvalidInput` error with `reason`.
    #[track_caller]
    #[allow(dead_code)] // NOTE(review): may have no callers in the crate yet.
    pub(crate) fn invalid_input<T: Into<String>>(reason: T) -> Self {
        Self::with_reason(ErrorKind::InvalidInput, reason)
    }

    /// Shorthand for an `ErrorKind::InvalidData` error with `reason`.
    #[track_caller]
    pub(crate) fn invalid_data<T: Into<String>>(reason: T) -> Self {
        Self::with_reason(ErrorKind::InvalidData, reason)
    }

    /// Shorthand for an `ErrorKind::InvalidState` error with `reason`.
    #[track_caller]
    pub(crate) fn invalid_state<T: Into<String>>(reason: T) -> Self {
        Self::with_reason(ErrorKind::InvalidState, reason)
    }

    /// Shorthand for an `ErrorKind::Unsupported` error with `reason`.
    #[track_caller]
    #[allow(dead_code)] // NOTE(review): may have no callers in the crate yet.
    pub(crate) fn unsupported<T: Into<String>>(reason: T) -> Self {
        Self::with_reason(ErrorKind::Unsupported, reason)
    }

    /// Shorthand for an `ErrorKind::InsufficientBuffer` error (empty reason).
    #[track_caller]
    pub(crate) fn insufficient_buffer() -> Self {
        Self::new(ErrorKind::InsufficientBuffer)
    }

    /// Shorthand for an `ErrorKind::ProtocolViolation` error with `reason`.
    #[track_caller]
    pub(crate) fn protocol_violation<T: Into<String>>(reason: T) -> Self {
        Self::with_reason(ErrorKind::ProtocolViolation, reason)
    }

    /// Shorthand for an `ErrorKind::HandshakeRejected` error with `reason`.
    #[track_caller]
    #[allow(dead_code)] // NOTE(review): may have no callers in the crate yet.
    pub(crate) fn handshake_rejected<T: Into<String>>(reason: T) -> Self {
        Self::with_reason(ErrorKind::HandshakeRejected, reason)
    }

    /// Shorthand for an `ErrorKind::VersionNotSupported` error; `received`
    /// (presumably the version value the peer sent) becomes the reason.
    #[track_caller]
    pub(crate) fn version_not_supported<T: Into<String>>(received: T) -> Self {
        Self::with_reason(ErrorKind::VersionNotSupported, received)
    }

    /// Shorthand for an `ErrorKind::InvalidUtf8` error with `reason`.
    #[track_caller]
    pub(crate) fn invalid_utf8<T: Into<String>>(reason: T) -> Self {
        Self::with_reason(ErrorKind::InvalidUtf8, reason)
    }

    /// Checks that `buf` holds at least `required_size` bytes.
    ///
    /// # Errors
    ///
    /// Returns an `ErrorKind::InsufficientBuffer` error, located at the caller
    /// of this function, when `buf.len() < required_size`.
    #[track_caller]
    pub(crate) fn check_buffer_size(required_size: usize, buf: &[u8]) -> Result<(), Self> {
        if buf.len() < required_size {
            Err(Self::insufficient_buffer())
        } else {
            Ok(())
        }
    }

    /// Attaches a WebSocket close `code` to this error and returns it.
    ///
    /// The error is consumed, so ignoring the return value drops it together
    /// with the code — hence `#[must_use]`.
    #[allow(dead_code)] // NOTE(review): may have no callers in the crate yet.
    #[must_use = "`with_close_code` consumes the error and returns the updated one"]
    pub fn with_close_code(mut self, code: u16) -> Self {
        self.close_code = Some(code);
        self
    }

    /// Builds an `ErrorKind::HttpResponse` error from `info`.
    ///
    /// The reason is `"HTTP <status> <phrase>"`, or just `"HTTP <status>"`
    /// when the reason phrase is empty (HTTP permits an empty reason-phrase),
    /// so no trailing space is produced. `info` is stored boxed in
    /// `http_response`.
    #[track_caller]
    pub(crate) fn http_response(info: HttpResponseInfo) -> Self {
        let reason = if info.reason_phrase.is_empty() {
            format!("HTTP {}", info.status_code)
        } else {
            format!("HTTP {} {}", info.status_code, info.reason_phrase)
        };
        let mut err = Self::with_reason(ErrorKind::HttpResponse, reason);
        err.http_response = Some(Box::new(info));
        err
    }
}
/// `Debug` renders exactly the same text as `Display`.
///
/// `Result::unwrap` / `expect` print the error through `Debug`, so panics show
/// the readable message (kind, reason, location, any captured backtrace)
/// rather than a raw field dump.
impl std::fmt::Debug for Error {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        std::fmt::Display::fmt(self, f)
    }
}
impl std::fmt::Display for Error {
    /// Formats the error as `Kind: reason`, then the WebSocket close code (if
    /// any), then ` (at file:line)` for the creation site, and finally the
    /// backtrace when one was actually captured.
    ///
    /// The `: reason` part is omitted when `reason` is empty (e.g. errors built
    /// with `Error::new`), avoiding a dangling colon and double space.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "{:?}", self.kind)?;
        if !self.reason.is_empty() {
            write!(f, ": {}", self.reason)?;
        }
        if let Some(code) = self.close_code {
            write!(f, " (WebSocket close code: {code})")?;
        }
        write!(f, " (at {}:{})", self.location.file(), self.location.line())?;
        // Backtraces are only recorded when enabled (or supported); skip the
        // section entirely otherwise instead of printing a placeholder.
        if self.backtrace.status() == BacktraceStatus::Captured {
            write!(f, "\n\nBacktrace:\n{}", self.backtrace)?;
        }
        Ok(())
    }
}
// `Error` stores no wrapped underlying error, so the default `source()`
// (which returns `None`) is kept.
impl std::error::Error for Error {}