use std::{io, str::Utf8Error, string::FromUtf8Error};
use http::{HeaderName, Response};
use thiserror::Error;
use crate::protocol::frame::codec::Data;
pub type Result<T, E = Error> = std::result::Result<T, E>;
#[derive(Debug, Error)]
pub enum Error {
#[error("Connection closed")]
ConnectionClosed,
#[error("Connection already closed")]
AlreadyClosed,
#[error("I/O Error: {0}")]
Io(#[from] io::Error),
#[error("Protool Error: {0}")]
Protocol(#[from] ProtocolError),
#[error("UTF-8 Error: {0}")]
Utf8(String),
#[error("Write buffer is full")]
WriteBufferFull,
#[error("Capacity Error: {0}")]
Capacity(#[from] CapacityError),
#[error("HTTP Error: {}", .0.status())]
#[cfg(feature = "handshake")]
Http(Response<Option<Vec<u8>>>),
#[error("HTTP format error: {0}")]
#[cfg(feature = "handshake")]
HttpFormat(#[from] http::Error),
#[error("URL Error: {0}")]
Url(#[from] UrlError),
#[error("TLS Error: {0}")]
Tls(#[from] TlsError),
#[error("Detected attempted attack")]
AttackAttempt,
}
impl From<Utf8Error> for Error {
fn from(value: Utf8Error) -> Self {
Error::Utf8(value.to_string())
}
}
impl From<FromUtf8Error> for Error {
fn from(value: FromUtf8Error) -> Self {
Error::Utf8(value.to_string())
}
}
/// Wraps an invalid header name in [`Error::HttpFormat`] via `http::Error`'s own conversion.
#[cfg(feature = "handshake")]
impl From<http::header::InvalidHeaderName> for Error {
    fn from(err: http::header::InvalidHeaderName) -> Self {
        Self::HttpFormat(http::Error::from(err))
    }
}
/// Wraps an invalid header value in [`Error::HttpFormat`] via `http::Error`'s own conversion.
#[cfg(feature = "handshake")]
impl From<http::header::InvalidHeaderValue> for Error {
    fn from(err: http::header::InvalidHeaderValue) -> Self {
        Self::HttpFormat(http::Error::from(err))
    }
}
/// Maps a header-value-to-`str` failure to [`Error::Utf8`], keeping its rendered message.
#[cfg(feature = "handshake")]
impl From<http::header::ToStrError> for Error {
    fn from(err: http::header::ToStrError) -> Self {
        let message = err.to_string();
        Self::Utf8(message)
    }
}
/// Wraps an unparsable URI in [`Error::HttpFormat`] via `http::Error`'s own conversion.
#[cfg(feature = "handshake")]
impl From<http::uri::InvalidUri> for Error {
    fn from(err: http::uri::InvalidUri) -> Self {
        Self::HttpFormat(http::Error::from(err))
    }
}
/// Wraps an invalid HTTP status code in [`Error::HttpFormat`] via `http::Error`'s own conversion.
#[cfg(feature = "handshake")]
impl From<http::status::InvalidStatusCode> for Error {
    fn from(err: http::status::InvalidStatusCode) -> Self {
        Self::HttpFormat(http::Error::from(err))
    }
}
/// Classifies an `httparse` failure.
///
/// Too many headers is a limit breach, so it becomes [`Error::Capacity`]; every other
/// parse failure is wrapped as a protocol error.
#[cfg(feature = "handshake")]
impl From<httparse::Error> for Error {
    fn from(err: httparse::Error) -> Self {
        if matches!(err, httparse::Error::TooManyHeaders) {
            return Self::Capacity(CapacityError::TooManyHeaders);
        }
        Self::Protocol(ProtocolError::HttparseError(err))
    }
}
#[allow(missing_copy_implementations)]
#[derive(Debug, Error, PartialEq, Eq, Clone)]
pub enum ProtocolError {
#[error("Invalid HTTP method (must be GET)")]
InvalidHttpMethod,
#[error("Unsupported HTTP version (must be at least HTTP/1.1)")]
InvalidHttpVersion,
#[error("Missing, duplicated or incorrect header {0}")]
#[cfg(feature = "handshake")]
InvalidHeader(HeaderName),
#[error("Missing 'Connection: upgrade' header")]
MissingConnectionUpgradeHeader,
#[error("Missing 'Upgrade: websocket' header")]
MissingUpgradeHeader,
#[error("Missing 'Sec-WebSocket-Version: 13' header")]
MissingVersionHeader,
#[error("Missing 'Sec-WebSocket-Key' header")]
MissingKeyHeader,
#[error("Mismatched 'Sec-WebSocket-Accept' header")]
AcceptKeyMismatch,
#[error("SubProtocol error: {0}")]
SecWebSocketSubProtocolError(SubProtocolError),
#[error("Handshake incomplete")]
IncompleteHandshake,
#[error("httparse error: {0}")]
#[cfg(feature = "handshake")]
HttparseError(#[from] httparse::Error),
#[error("Encountered frame with non-zero reserved bits")]
NonZeroReservedBits,
#[error("Control frame must not be fragmented")]
FragmentedControlFrame,
#[error("Control frame payload too large")]
ControlFrameTooBig,
#[error("Received unmasked frame from client")]
UnmaskedFrameFromClient,
#[error("Received masked frame from server")]
MaskedFrameFromServer,
#[error("Received unknown control opcode: {0}")]
UnknownControlOpCode(u8),
#[error("Received unknown data opcode: {0}")]
UnknownDataOpCode(u8),
#[error("Received continue frame without open fragmentation context")]
UnexpectedContinue,
#[error("Expected fragment of type {0:?} but received something else")]
ExpectedFragment(Data),
#[error("Sent after close handshake started")]
SendAfterClose,
#[error("Received after close handshake completed")]
ReceiveAfterClose,
#[error("Invalid close frame payload")]
InvalidCloseFrame,
#[error("Connection closed without proper handshake")]
ResetWithoutClosing,
#[error("Junk after client request")]
JunkAfterRequest,
#[error("Custom response must not be successful")]
CustomResponseSuccessful,
}
#[derive(Error, Clone, PartialEq, Eq, Debug, Copy)]
pub enum SubProtocolError {
#[error("Server sent a subprotocol but none was requested")]
ServerSentSubProtocolNoneRequested,
#[error("Server sent an invalid subprotocol")]
InvalidSubProtocol,
#[error("Server sent no subprotocol")]
NoSubProtocol,
}
#[derive(Debug, Error, PartialEq, Eq, Clone, Copy)]
pub enum CapacityError {
#[error("Too many headers received")]
TooManyHeaders,
#[error("Payload too large: {size} > {max}")]
MessageTooLarge {
size: usize,
max: usize,
},
}
#[derive(Debug, Error, PartialEq, Eq, Clone)]
pub enum UrlError {
#[error("Missing host name in URL")]
MissingHost,
#[error("Empty host name in URL")]
EmptyHost,
#[error("Unsupported URL scheme (expected 'ws://' or 'wss://')")]
UnsupportedScheme,
#[error("TLS feature not enabled but 'wss://' URL used")]
TlsFeatureNotEnabled,
#[error("No path / query segment in URL")]
NoPathOrQuery,
#[error("Unable to connect to host: {0}")]
UnableToConnect(String),
}
#[allow(missing_copy_implementations)]
#[derive(Error, Debug)]
#[non_exhaustive]
pub enum TlsError {
#[cfg(feature = "native-tls")]
#[error("Native TLS Error: {0}")]
Native(#[from] native_tls_crate::Error),
#[cfg(feature = "rustls")]
#[error("Rustls Error: {0}")]
Rustls(#[from] rustls::Error),
#[cfg(feature = "rustls")]
#[error("Invalid DNS name for TLS")]
InvalidDnsName,
}