1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
// Copyright 2015-2021 Swim Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

use crate::protocol::{CloseCodeParseErr, OpCodeParseErr};
use http::header::{HeaderName, InvalidHeaderValue};
use http::status::InvalidStatusCode;
use http::uri::InvalidUri;
use http::StatusCode;
use std::any::Any;
use std::error::Error as StdError;
use std::fmt::{Display, Formatter};
use std::io;
use std::str::Utf8Error;
use std::string::FromUtf8Error;
use thiserror::Error;

/// A boxed, thread-safe error used as the underlying cause of an [`Error`].
pub(crate) type BoxError = Box<dyn StdError + Send + Sync + 'static>;

/// The errors that may occur during a WebSocket connection.
///
/// An `Error` pairs an [`ErrorKind`], describing the broad category of the failure, with an
/// optional underlying cause. The cause is exposed through [`StdError::source`] and can be
/// inspected by type with [`Error::downcast_ref`].
#[derive(Debug)]
pub struct Error {
    inner: Inner,
}

impl Display for Error {
    /// Formats the error for end users as `"<kind>"`, or `"<kind>: <cause>"` when a cause is
    /// attached. Use `{:?}` for the full diagnostic representation.
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        match &self.inner.source {
            Some(source) => write!(f, "{}: {}", self.inner.kind, source),
            None => write!(f, "{}", self.inner.kind),
        }
    }
}

impl StdError for Error {
    /// Returns the underlying cause, if one was supplied via [`Error::with_cause`].
    fn source(&self) -> Option<&(dyn StdError + 'static)> {
        self.inner.source.as_deref().map(|e| e as &dyn StdError)
    }
}

impl Error {
    /// Construct a new error with the provided kind and no cause.
    pub fn new(kind: ErrorKind) -> Error {
        Error {
            inner: Inner { kind, source: None },
        }
    }

    /// Construct a new error with the provided kind and a cause.
    ///
    /// The cause is converted into a [`BoxError`]. Pass the concrete error value, not an
    /// already-boxed one, so that [`Error::downcast_ref`] can later recover it by type.
    pub fn with_cause<E>(kind: ErrorKind, source: E) -> Error
    where
        E: Into<BoxError>,
    {
        Error {
            inner: Inner {
                kind,
                source: Some(source.into()),
            },
        }
    }

    /// Returns the category of this error.
    pub fn kind(&self) -> ErrorKind {
        self.inner.kind
    }

    /// Returns a reference to the cause if it is of type `T`, or `None` if there is no cause or
    /// it has a different type.
    pub fn downcast_ref<T: Any + StdError>(&self) -> Option<&T> {
        self.inner.source.as_deref()?.downcast_ref::<T>()
    }

    /// Whether this error is related to an IO error.
    pub fn is_io(&self) -> bool {
        matches!(self.inner.kind, ErrorKind::IO)
    }

    /// Whether this error is related to an HTTP error.
    pub fn is_http(&self) -> bool {
        matches!(self.inner.kind, ErrorKind::Http)
    }

    /// Whether this error is related to an extension error.
    pub fn is_extension(&self) -> bool {
        matches!(self.inner.kind, ErrorKind::Extension)
    }

    /// Whether this error is related to a protocol error.
    pub fn is_protocol(&self) -> bool {
        matches!(self.inner.kind, ErrorKind::Protocol)
    }

    /// Whether this error is related to an encoding error.
    pub fn is_encoding(&self) -> bool {
        matches!(self.inner.kind, ErrorKind::Encoding)
    }

    /// Whether this error is related to a close error.
    pub fn is_close(&self) -> bool {
        matches!(self.inner.kind, ErrorKind::Close)
    }
}

/// Private state of an [`Error`]: its category and optional cause.
#[derive(Debug)]
struct Inner {
    kind: ErrorKind,
    source: Option<BoxError>,
}

/// A type of error represented.
#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash)]
pub enum ErrorKind {
    /// An IO error.
    IO,
    /// An HTTP error.
    Http,
    /// An extension error.
    Extension,
    /// A protocol error.
    Protocol,
    /// An encoding error.
    Encoding,
    /// A close error.
    Close,
}

impl Display for ErrorKind {
    /// Writes a short, lower-case description of the error category. Width, fill and alignment
    /// flags are honoured.
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        let description = match self {
            ErrorKind::IO => "IO error",
            ErrorKind::Http => "HTTP error",
            ErrorKind::Extension => "extension error",
            ErrorKind::Protocol => "protocol error",
            ErrorKind::Encoding => "encoding error",
            ErrorKind::Close => "close error",
        };
        f.pad(description)
    }
}

impl From<io::Error> for Error {
    fn from(e: io::Error) -> Self {
        Error::with_cause(ErrorKind::IO, e)
    }
}

impl From<httparse::Error> for Error {
    fn from(e: httparse::Error) -> Self {
        Error::with_cause(ErrorKind::Http, e)
    }
}

impl From<InvalidStatusCode> for Error {
    fn from(e: InvalidStatusCode) -> Self {
        Error::with_cause(ErrorKind::Http, e)
    }
}

/// HTTP errors that may occur while performing the WebSocket opening handshake.
#[derive(Error, Debug, PartialEq)]
pub enum HttpError {
    /// An invalid HTTP method was received. Carries the method, if one was present.
    #[error("Invalid HTTP method: `{0:?}`")]
    HttpMethod(Option<String>),
    /// The server responded with a redirect. Carries an associated string (presumably the
    /// redirect target — confirm with the producer).
    #[error("Redirected: `{0}`")]
    Redirected(String),
    /// The peer returned with a status code other than 101.
    #[error("Status code: `{0}`")]
    Status(StatusCode),
    /// An invalid HTTP version was received in a request. Carries the version, if one was
    /// present.
    #[error("Invalid HTTP version: `{0:?}`")]
    HttpVersion(Option<u8>),
    /// A request or response was missing an expected header.
    #[error("Missing header: `{0}`")]
    MissingHeader(HeaderName),
    /// A request or response contained an invalid header.
    #[error("Invalid header: `{0}`")]
    InvalidHeader(HeaderName),
    /// The received `Sec-WebSocket-Accept` value did not match the one expected for the
    /// `Sec-WebSocket-Key` that was sent (see RFC 6455 §4.2.2).
    #[error("Sec-WebSocket-Accept mismatch")]
    KeyMismatch,
    /// The provided URI was malformed. Carries an optional description of the problem. Note that
    /// the description is not included in the `Display` message.
    #[error("The provided URI was malformatted")]
    MalformattedUri(Option<String>),
    /// A provided header was malformed. Carries a description of the header. Note that the
    /// description is not included in the `Display` message.
    #[error("A provided header was malformatted")]
    MalformattedHeader(String),
}

impl From<HttpError> for Error {
    /// Wraps a handshake-level HTTP failure as an [`ErrorKind::Http`] error, retaining it as the
    /// cause so it can be recovered with `downcast_ref::<HttpError>()`.
    fn from(http_err: HttpError) -> Self {
        Self::with_cause(ErrorKind::Http, http_err)
    }
}

/// An invalid header was received.
#[derive(Debug)]
pub struct InvalidHeader(pub String);

impl From<InvalidHeader> for HttpError {
    fn from(e: InvalidHeader) -> Self {
        HttpError::MalformattedHeader(e.0)
    }
}

impl From<InvalidHeader> for Error {
    fn from(e: InvalidHeader) -> Self {
        Error::with_cause::<HttpError>(ErrorKind::Http, e.into())
    }
}

impl From<InvalidUri> for Error {
    fn from(e: InvalidUri) -> Self {
        Error::with_cause(ErrorKind::Http, e)
    }
}

impl From<InvalidUri> for HttpError {
    fn from(e: InvalidUri) -> Self {
        HttpError::MalformattedUri(Some(format!("{:?}", e)))
    }
}

impl From<http::Error> for Error {
    fn from(e: http::Error) -> Self {
        Error::with_cause(ErrorKind::Http, e)
    }
}

impl From<ProtocolError> for Error {
    /// Wraps a WebSocket protocol violation as an [`ErrorKind::Protocol`] error, retaining it as
    /// the cause.
    fn from(violation: ProtocolError) -> Self {
        Self::with_cause(ErrorKind::Protocol, violation)
    }
}

impl From<OpCodeParseErr> for ProtocolError {
    /// An unparseable opcode is reported as [`ProtocolError::OpCode`].
    fn from(parse_err: OpCodeParseErr) -> Self {
        Self::OpCode(parse_err)
    }
}

impl From<OpCodeParseErr> for Error {
    fn from(e: OpCodeParseErr) -> Self {
        Error::with_cause(ErrorKind::Protocol, Box::new(ProtocolError::from(e)))
    }
}

impl From<Utf8Error> for Error {
    /// Wraps a UTF-8 validation failure as an [`ErrorKind::Encoding`] error.
    fn from(err: Utf8Error) -> Self {
        Self::with_cause(ErrorKind::Encoding, err)
    }
}

impl From<CloseCodeParseErr> for Error {
    /// Wraps a close-code parsing failure as an [`ErrorKind::Protocol`] error.
    fn from(err: CloseCodeParseErr) -> Self {
        Self::with_cause(ErrorKind::Protocol, err)
    }
}

impl From<InvalidHeaderValue> for Error {
    /// Wraps an invalid header value as an [`ErrorKind::Http`] error.
    fn from(err: InvalidHeaderValue) -> Self {
        Self::with_cause(ErrorKind::Http, err)
    }
}

#[derive(Clone, Copy, Error, Debug, PartialEq)]
/// The reason a channel is closed.
///
/// NOTE(review): every variant carries its own `#[error]` message, so the enum-level
/// `#[error("The channel is already closed")]` below looks redundant. Confirm how the `thiserror`
/// version in use treats it before removing it.
#[error("The channel is already closed")]
pub enum CloseCause {
    /// The channel closed nominally. This is **not** an error. It indicates a clean closure of the
    /// channel, initiated by either this endpoint or the peer.
    #[error("The channel closed as expected")]
    Stopped,
    /// Produced only when a user attempts to reuse a closed channel. It indicates a bug in the
    /// calling code.
    #[error("Attempted to use a closed channel")]
    Error,
}

/// WebSocket protocol errors: violations of the framing or messaging rules by the peer (or by a
/// frame being processed).
#[derive(Copy, Clone, Debug, PartialEq, Error)]
pub enum ProtocolError {
    /// Data that was expected to be valid UTF-8 was not.
    #[error("Not valid UTF-8 encoding")]
    Encoding,
    /// The peer selected a subprotocol that was not offered.
    #[error("Received an unknown subprotocol")]
    UnknownProtocol,
    /// An invalid opcode was received. Carries the underlying parse error.
    #[error("Bad OpCode: `{0}`")]
    OpCode(OpCodeParseErr),
    /// The peer sent an unmasked frame when a masked one was expected.
    #[error("Received an unexpected unmasked frame")]
    UnmaskedFrame,
    /// The peer sent a masked frame when an unmasked one was expected.
    #[error("Received an unexpected masked frame")]
    MaskedFrame,
    /// A control frame was received fragmented.
    #[error("Received a fragmented control frame")]
    FragmentedControl,
    /// A received frame exceeded the maximum permitted size.
    #[error("A frame exceeded the maximum permitted size")]
    FrameOverflow,
    /// The peer attempted to use an extension that has not been negotiated.
    #[error("Attempted to use an extension that has not been negotiated")]
    UnknownExtension,
    /// A continuation frame was received before a fragmented message had been started.
    #[error("Received a continuation frame before one has been started")]
    ContinuationNotStarted,
    /// The peer attempted to start a new fragmented message before the previous one completed.
    #[error("Attempted to start another continuation before the previous one has completed")]
    ContinuationAlreadyStarted,
    /// An illegal close code was received. Carries the offending code.
    #[error("Received an illegal close code: `{0}`")]
    CloseCode(u16),
    /// Control frame data was received that did not match what was expected.
    #[error("Received unexpected control frame data")]
    ControlDataMismatch,
}

impl From<FromUtf8Error> for Error {
    /// Wraps a failed owned-buffer UTF-8 conversion as an [`ErrorKind::Encoding`] error.
    fn from(err: FromUtf8Error) -> Self {
        Self::with_cause(ErrorKind::Encoding, err)
    }
}