blitz_ws/handshake/
server.rs

1//! Server handshake machine
2
3use http::{
4    HeaderMap, Method, Request as HttpRequest, Response as HttpResponse, StatusCode, Version,
5};
6use httparse::{Status, EMPTY_HEADER};
7use std::{
8    io::{Read, Write},
9    marker::PhantomData,
10    result::Result as StdResult,
11};
12
13use crate::{
14    error::{Error, ProtocolError, Result},
15    handshake::{
16        core::{derive_accept_key, HandshakeRole, MidHandshake, ProcessingResult},
17        headers::{FromHttparse, MAX_HEADERS},
18        machine::{HandshakeMachine, StageResult, TryParse},
19    },
20    protocol::{
21        config::WebSocketConfig,
22        websocket::{OperationMode, WebSocket},
23    },
24};
25
26/// Server Request type
27pub type Request = HttpRequest<()>;
28/// Server Response type
29pub type Response = HttpResponse<()>;
30/// Server Error Response type
31pub type ErrorResponse = HttpResponse<Option<String>>;
32
33fn create_parts<T>(req: &HttpRequest<T>) -> Result<http::response::Builder> {
34    if req.method() != Method::GET {
35        return Err(Error::Protocol(ProtocolError::InvalidHttpMethod));
36    }
37
38    if req.version() < Version::HTTP_11 {
39        return Err(Error::Protocol(ProtocolError::InvalidHttpVersion));
40    }
41
42    let headers = req.headers();
43
44    if !headers
45        .get("Connection")
46        .and_then(|h| h.to_str().ok())
47        .map(|v| v.split([',', ' ']).any(|s| s.eq_ignore_ascii_case("Upgrade")))
48        .unwrap_or(false)
49    {
50        return Err(Error::Protocol(ProtocolError::MissingConnectionUpgradeHeader));
51    }
52
53    if !headers
54        .get("Upgrade")
55        .and_then(|h| h.to_str().ok())
56        .map(|v| v.eq_ignore_ascii_case("websocket"))
57        .unwrap_or(false)
58    {
59        return Err(Error::Protocol(ProtocolError::MissingUpgradeHeader));
60    }
61
62    if !headers.get("Sec-WebSocket-Version").map(|h| h == "13").unwrap_or(false) {
63        return Err(Error::Protocol(ProtocolError::MissingVersionHeader));
64    }
65
66    let key =
67        headers.get("Sec-WebSocket-Key").ok_or(Error::Protocol(ProtocolError::MissingKeyHeader))?;
68
69    let builder = Response::builder()
70        .status(StatusCode::SWITCHING_PROTOCOLS)
71        .version(req.version())
72        .header("Connection", "Upgrade")
73        .header("Upgrade", "websocket")
74        .header("Sec-WebSocket-Accept", derive_accept_key(key.as_bytes()));
75
76    Ok(builder)
77}
78
79/// Creates a response for the request
80pub fn create_response(req: &Request) -> Result<Response> {
81    Ok(create_parts(req)?.body(())?)
82}
83
84/// Creates a response for the request with a custom body
85pub fn create_response_with_body<T1, T2>(
86    req: &HttpRequest<T1>,
87    generate_body: impl FnOnce() -> T2,
88) -> Result<HttpResponse<T2>> {
89    Ok(create_parts(req)?.body(generate_body())?)
90}
91
92/// Writes `response` to the stream `w`
93pub fn write_response<T>(mut w: impl Write, res: &HttpResponse<T>) -> Result<()> {
94    writeln!(w, "{:?} {}\r", res.version(), res.status())?;
95    for (k, v) in res.headers() {
96        writeln!(w, "{}: {}\r", k, v.to_str()?)?;
97    }
98    writeln!(w, "\r")?;
99
100    Ok(())
101}
102
103impl TryParse for Request {
104    fn try_parse(data: &[u8]) -> Result<Option<(usize, Self)>> {
105        let mut header_buf = [EMPTY_HEADER; MAX_HEADERS];
106        let mut req = httparse::Request::new(&mut header_buf);
107
108        Ok(match req.parse(data)? {
109            Status::Complete(n) => Some((n, Request::from_httparse(req)?)),
110            Status::Partial => None,
111        })
112    }
113}
114
115impl<'b: 'h, 'h> FromHttparse<httparse::Request<'h, 'b>> for Request {
116    fn from_httparse(raw: httparse::Request<'h, 'b>) -> Result<Self> {
117        if raw.method != Some("GET") {
118            return Err(Error::Protocol(ProtocolError::InvalidHttpMethod));
119        }
120
121        if raw.version != Some(1) {
122            return Err(Error::Protocol(ProtocolError::InvalidHttpVersion));
123        }
124
125        let mut req = Request::new(());
126        *req.method_mut() = Method::GET;
127        *req.uri_mut() = raw.path.expect("Bug: no path in header").parse()?;
128        *req.version_mut() = Version::HTTP_11;
129        *req.headers_mut() = HeaderMap::from_httparse(raw.headers)?;
130
131        Ok(req)
132    }
133}
134
/// Callback trait
///
/// The callback is invoked when the server has received an incoming WebSocket handshake
/// request from the client. Specifying a callback allows you to inspect the incoming headers,
/// add headers to the response that the server sends to the client, and/or reject the
/// connection based on the incoming headers.
pub trait Callback: Sized {
    /// Called once the server has read the request from the client and is ready to respond.
    ///
    /// `req` is the client's request and `res` is the response the server has prepared for it.
    /// Return `Ok` with the response to send (optionally with additional headers) to accept
    /// the connection, or `Err` with an error response to reject it.
    fn on_request(self, req: &Request, res: Response) -> StdResult<Response, ErrorResponse>;
}
147
148impl<F> Callback for F
149where
150    F: FnOnce(&Request, Response) -> StdResult<Response, ErrorResponse>,
151{
152    fn on_request(self, req: &Request, res: Response) -> StdResult<Response, ErrorResponse> {
153        self(req, res)
154    }
155}
156
157/// Stub for an empty callback
158#[derive(Clone, Copy, Debug)]
159pub struct NoCallback;
160
161impl Callback for NoCallback {
162    fn on_request(self, _req: &Request, res: Response) -> StdResult<Response, ErrorResponse> {
163        Ok(res)
164    }
165}
166
/// Server handshake role
///
/// Holds the server-side state that is carried between the stages of the handshake.
#[allow(missing_copy_implementations)]
#[derive(Debug)]
pub struct ServerHandshake<S, C> {
    /// Callback which is called once the server has read the request from the client and is
    /// ready to reply to it. The callback may return a response with additional headers, which
    /// the server then sends to the client. Stored as an `Option` so that it can be taken out
    /// and consumed when it is invoked.
    callback: Option<C>,
    /// WebSocket configuration, passed on to the `WebSocket` created when the handshake
    /// completes.
    config: Option<WebSocketConfig>,
    /// Rejection response returned by the callback. If set, an error carrying this response is
    /// returned after the response has been sent to the client.
    error_response: Option<ErrorResponse>,
    /// Internal stream type.
    _marker: PhantomData<S>,
}
182
183impl<S: Read + Write, C: Callback> ServerHandshake<S, C> {
184    /// Start server handshake. `callback` specifies a custom callback which the user can pass to
185    /// the handshake, this callback will be called when the a websocket client connects to the
186    /// server, you can specify the callback if you want to add additional header to the client
187    /// upon join based on the incoming headers.
188    pub fn start(stream: S, callback: C, config: Option<WebSocketConfig>) -> MidHandshake<Self> {
189        MidHandshake {
190            machine: HandshakeMachine::start_read(stream),
191            role: ServerHandshake {
192                callback: Some(callback),
193                config,
194                error_response: None,
195                _marker: PhantomData,
196            },
197        }
198    }
199}
200
201impl<S: Read + Write, C: Callback> HandshakeRole for ServerHandshake<S, C> {
202    type IncomingData = Request;
203    type InternalStream = S;
204    type FinalResult = WebSocket<S>;
205
206    fn stage_finished(
207        &mut self,
208        finish: StageResult<Self::IncomingData, Self::InternalStream>,
209    ) -> Result<ProcessingResult<Self::InternalStream, Self::FinalResult>> {
210        match finish {
211            StageResult::DoneReading { result, stream, tail } => {
212                if !tail.is_empty() {
213                    return Err(Error::Protocol(ProtocolError::JunkAfterRequest));
214                }
215
216                let response = create_response(&result)?;
217                let callback_result = if let Some(callback) = self.callback.take() {
218                    callback.on_request(&result, response)
219                } else {
220                    Ok(response)
221                };
222
223                match callback_result {
224                    Ok(resp) => {
225                        let mut output = vec![];
226                        write_response(&mut output, &resp)?;
227
228                        Ok(ProcessingResult::Continue(HandshakeMachine::start_write(
229                            stream, output,
230                        )))
231                    }
232                    Err(resp) => {
233                        if resp.status().is_success() {
234                            return Err(Error::Protocol(ProtocolError::CustomResponseSuccessful));
235                        }
236
237                        self.error_response = Some(resp);
238                        let resp_ref = self.error_response.as_ref().unwrap();
239
240                        let mut output = vec![];
241                        write_response(&mut output, resp_ref)?;
242
243                        if let Some(body) = resp_ref.body() {
244                            output.extend_from_slice(body.as_bytes());
245                        }
246
247                        Ok(ProcessingResult::Continue(HandshakeMachine::start_write(
248                            stream, output,
249                        )))
250                    }
251                }
252            }
253            StageResult::DoneWriting(stream) => {
254                if let Some(err) = self.error_response.take() {
255                    let (parts, body) = err.into_parts();
256                    return Err(Error::Http(HttpResponse::from_parts(
257                        parts,
258                        body.map(|s| s.into_bytes()),
259                    )));
260                }
261
262                Ok(ProcessingResult::Done(WebSocket::new(
263                    stream,
264                    OperationMode::Server,
265                    self.config,
266                )))
267            }
268        }
269    }
270}