1use http::{
4 HeaderMap, Method, Request as HttpRequest, Response as HttpResponse, StatusCode, Version,
5};
6use httparse::{Status, EMPTY_HEADER};
7use std::{
8 io::{Read, Write},
9 marker::PhantomData,
10 result::Result as StdResult,
11};
12
13use crate::{
14 error::{Error, ProtocolError, Result},
15 handshake::{
16 core::{derive_accept_key, HandshakeRole, MidHandshake, ProcessingResult},
17 headers::{FromHttparse, MAX_HEADERS},
18 machine::{HandshakeMachine, StageResult, TryParse},
19 },
20 protocol::{
21 config::WebSocketConfig,
22 websocket::{OperationMode, WebSocket},
23 },
24};
25
/// Server request type: an HTTP request whose body is the unit type `()`.
pub type Request = HttpRequest<()>;
/// Server response type: an HTTP response whose body is the unit type `()`.
pub type Response = HttpResponse<()>;
/// Response a [`Callback`] returns to reject the handshake. When the optional `String` body is
/// present, its bytes are sent to the client right after the status line and headers.
pub type ErrorResponse = HttpResponse<Option<String>>;
32
33fn create_parts<T>(req: &HttpRequest<T>) -> Result<http::response::Builder> {
34 if req.method() != Method::GET {
35 return Err(Error::Protocol(ProtocolError::InvalidHttpMethod));
36 }
37
38 if req.version() < Version::HTTP_11 {
39 return Err(Error::Protocol(ProtocolError::InvalidHttpVersion));
40 }
41
42 let headers = req.headers();
43
44 if !headers
45 .get("Connection")
46 .and_then(|h| h.to_str().ok())
47 .map(|v| v.split([',', ' ']).any(|s| s.eq_ignore_ascii_case("Upgrade")))
48 .unwrap_or(false)
49 {
50 return Err(Error::Protocol(ProtocolError::MissingConnectionUpgradeHeader));
51 }
52
53 if !headers
54 .get("Upgrade")
55 .and_then(|h| h.to_str().ok())
56 .map(|v| v.eq_ignore_ascii_case("websocket"))
57 .unwrap_or(false)
58 {
59 return Err(Error::Protocol(ProtocolError::MissingUpgradeHeader));
60 }
61
62 if !headers.get("Sec-WebSocket-Version").map(|h| h == "13").unwrap_or(false) {
63 return Err(Error::Protocol(ProtocolError::MissingVersionHeader));
64 }
65
66 let key =
67 headers.get("Sec-WebSocket-Key").ok_or(Error::Protocol(ProtocolError::MissingKeyHeader))?;
68
69 let builder = Response::builder()
70 .status(StatusCode::SWITCHING_PROTOCOLS)
71 .version(req.version())
72 .header("Connection", "Upgrade")
73 .header("Upgrade", "websocket")
74 .header("Sec-WebSocket-Accept", derive_accept_key(key.as_bytes()));
75
76 Ok(builder)
77}
78
79pub fn create_response(req: &Request) -> Result<Response> {
81 Ok(create_parts(req)?.body(())?)
82}
83
84pub fn create_response_with_body<T1, T2>(
86 req: &HttpRequest<T1>,
87 generate_body: impl FnOnce() -> T2,
88) -> Result<HttpResponse<T2>> {
89 Ok(create_parts(req)?.body(generate_body())?)
90}
91
92pub fn write_response<T>(mut w: impl Write, res: &HttpResponse<T>) -> Result<()> {
94 writeln!(w, "{:?} {}\r", res.version(), res.status())?;
95 for (k, v) in res.headers() {
96 writeln!(w, "{}: {}\r", k, v.to_str()?)?;
97 }
98 writeln!(w, "\r")?;
99
100 Ok(())
101}
102
103impl TryParse for Request {
104 fn try_parse(data: &[u8]) -> Result<Option<(usize, Self)>> {
105 let mut header_buf = [EMPTY_HEADER; MAX_HEADERS];
106 let mut req = httparse::Request::new(&mut header_buf);
107
108 Ok(match req.parse(data)? {
109 Status::Complete(n) => Some((n, Request::from_httparse(req)?)),
110 Status::Partial => None,
111 })
112 }
113}
114
115impl<'b: 'h, 'h> FromHttparse<httparse::Request<'h, 'b>> for Request {
116 fn from_httparse(raw: httparse::Request<'h, 'b>) -> Result<Self> {
117 if raw.method != Some("GET") {
118 return Err(Error::Protocol(ProtocolError::InvalidHttpMethod));
119 }
120
121 if raw.version != Some(1) {
122 return Err(Error::Protocol(ProtocolError::InvalidHttpVersion));
123 }
124
125 let mut req = Request::new(());
126 *req.method_mut() = Method::GET;
127 *req.uri_mut() = raw.path.expect("Bug: no path in header").parse()?;
128 *req.version_mut() = Version::HTTP_11;
129 *req.headers_mut() = HeaderMap::from_httparse(raw.headers)?;
130
131 Ok(req)
132 }
133}
134
/// Hook run by the server handshake once a valid upgrade request has been read.
///
/// The callback receives the parsed request and the default `101 Switching Protocols`
/// response. What happens next depends on its return value:
///
/// - `Ok(response)` sends the (possibly modified) response and continues the handshake.
/// - `Err(error_response)` sends the error response instead; the handshake then fails with
///   `Error::Http` after it has been written.
/// - `Err` carrying a 2xx status is refused with `ProtocolError::CustomResponseSuccessful`.
pub trait Callback: Sized {
    /// Inspects `req` and either accepts the handshake (possibly modifying `res`) or rejects
    /// it with an [`ErrorResponse`].
    fn on_request(self, req: &Request, res: Response) -> StdResult<Response, ErrorResponse>;
}
147
148impl<F> Callback for F
149where
150 F: FnOnce(&Request, Response) -> StdResult<Response, ErrorResponse>,
151{
152 fn on_request(self, req: &Request, res: Response) -> StdResult<Response, ErrorResponse> {
153 self(req, res)
154 }
155}
156
157#[derive(Clone, Copy, Debug)]
159pub struct NoCallback;
160
161impl Callback for NoCallback {
162 fn on_request(self, _req: &Request, res: Response) -> StdResult<Response, ErrorResponse> {
163 Ok(res)
164 }
165}
166
/// Server-side role of a WebSocket handshake over a stream of type `S`, using callback `C`.
// NOTE(review): `missing_copy_implementations` only fires for types that could be `Copy`;
// the `Option<ErrorResponse>` field appears to prevent that already. Confirm whether this
// allow is still needed.
#[allow(missing_copy_implementations)]
#[derive(Debug)]
pub struct ServerHandshake<S, C> {
    /// Callback run on the incoming request. It is taken out (and consumed) once the request
    /// has been read, so it runs at most once.
    callback: Option<C>,
    /// Configuration handed to the `WebSocket` created when the handshake completes.
    config: Option<WebSocketConfig>,
    /// Rejection produced by the callback. It is kept until it has been written so that it
    /// can then be returned to the caller as `Error::Http`.
    error_response: Option<ErrorResponse>,
    /// Records the stream type `S` without storing a stream.
    _marker: PhantomData<S>,
}
182
183impl<S: Read + Write, C: Callback> ServerHandshake<S, C> {
184 pub fn start(stream: S, callback: C, config: Option<WebSocketConfig>) -> MidHandshake<Self> {
189 MidHandshake {
190 machine: HandshakeMachine::start_read(stream),
191 role: ServerHandshake {
192 callback: Some(callback),
193 config,
194 error_response: None,
195 _marker: PhantomData,
196 },
197 }
198 }
199}
200
201impl<S: Read + Write, C: Callback> HandshakeRole for ServerHandshake<S, C> {
202 type IncomingData = Request;
203 type InternalStream = S;
204 type FinalResult = WebSocket<S>;
205
206 fn stage_finished(
207 &mut self,
208 finish: StageResult<Self::IncomingData, Self::InternalStream>,
209 ) -> Result<ProcessingResult<Self::InternalStream, Self::FinalResult>> {
210 match finish {
211 StageResult::DoneReading { result, stream, tail } => {
212 if !tail.is_empty() {
213 return Err(Error::Protocol(ProtocolError::JunkAfterRequest));
214 }
215
216 let response = create_response(&result)?;
217 let callback_result = if let Some(callback) = self.callback.take() {
218 callback.on_request(&result, response)
219 } else {
220 Ok(response)
221 };
222
223 match callback_result {
224 Ok(resp) => {
225 let mut output = vec![];
226 write_response(&mut output, &resp)?;
227
228 Ok(ProcessingResult::Continue(HandshakeMachine::start_write(
229 stream, output,
230 )))
231 }
232 Err(resp) => {
233 if resp.status().is_success() {
234 return Err(Error::Protocol(ProtocolError::CustomResponseSuccessful));
235 }
236
237 self.error_response = Some(resp);
238 let resp_ref = self.error_response.as_ref().unwrap();
239
240 let mut output = vec![];
241 write_response(&mut output, resp_ref)?;
242
243 if let Some(body) = resp_ref.body() {
244 output.extend_from_slice(body.as_bytes());
245 }
246
247 Ok(ProcessingResult::Continue(HandshakeMachine::start_write(
248 stream, output,
249 )))
250 }
251 }
252 }
253 StageResult::DoneWriting(stream) => {
254 if let Some(err) = self.error_response.take() {
255 let (parts, body) = err.into_parts();
256 return Err(Error::Http(HttpResponse::from_parts(
257 parts,
258 body.map(|s| s.into_bytes()),
259 )));
260 }
261
262 Ok(ProcessingResult::Done(WebSocket::new(
263 stream,
264 OperationMode::Server,
265 self.config,
266 )))
267 }
268 }
269 }
270}