tokio_websockets/upgrade/
server_response.rs1use std::str::FromStr;
4
5use base64::{Engine, engine::general_purpose::STANDARD};
6use bytes::{Buf, BytesMut};
7use http::{HeaderValue, StatusCode, header::HeaderName};
8use httparse::{Header, Response};
9use tokio_util::codec::Decoder;
10
11use crate::{sha::digest, upgrade::Error};
12
/// HTTP status code `101 Switching Protocols`, the only status `Codec::decode` accepts.
const SWITCHING_PROTOCOLS: u16 = 101;
15
16fn header<'a, 'header: 'a>(
18 headers: &'a [Header<'header>],
19 name: &'static str,
20) -> Result<&'header [u8], Error> {
21 let header = headers
22 .iter()
23 .find(|header| header.name.eq_ignore_ascii_case(name))
24 .ok_or(Error::MissingHeader(name))?;
25
26 Ok(header.value)
27}
28
/// Decoder state for the server's reply to a WebSocket upgrade request.
///
/// See the `Decoder` implementation for the checks applied to the reply.
pub struct Codec {
    /// Expected accept digest: 20 bytes computed by `digest` from the key given
    /// to `Codec::new`. `decode` compares the base64-decoded
    /// `Sec-WebSocket-Accept` header against this value.
    ws_accept: [u8; 20],
}
35
36impl Codec {
37 #[must_use]
42 pub fn new(key: &[u8]) -> Self {
43 Self {
44 ws_accept: digest(key),
45 }
46 }
47}
48
49impl Decoder for Codec {
50 type Error = crate::Error;
51 type Item = super::Response;
52
53 fn decode(&mut self, src: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
54 let mut headers = [httparse::EMPTY_HEADER; 25];
55 let mut response = Response::new(&mut headers);
56 let status = response.parse(src).map_err(Error::Parsing)?;
57
58 if !status.is_complete() {
59 return Ok(None);
60 }
61
62 let response_len = status.unwrap();
63 let code = response.code.unwrap();
64
65 if code != SWITCHING_PROTOCOLS {
66 return Err(crate::Error::Upgrade(Error::DidNotSwitchProtocols(code)));
67 }
68
69 let ws_accept_header = header(response.headers, "Sec-WebSocket-Accept")?;
70 let mut ws_accept = [0; 20];
71 STANDARD
72 .decode_slice_unchecked(ws_accept_header, &mut ws_accept)
73 .map_err(|_| Error::WrongWebSocketAccept)?;
74
75 if self.ws_accept != ws_accept {
76 return Err(crate::Error::Upgrade(Error::WrongWebSocketAccept));
77 }
78
79 let mut parsed_response = http::Response::new(());
80 *parsed_response.status_mut() =
81 StatusCode::from_u16(code).map_err(|_| Error::Parsing(httparse::Error::Status))?;
82
83 match response.version {
84 Some(0) => *parsed_response.version_mut() = http::Version::HTTP_10,
85 Some(1) => *parsed_response.version_mut() = http::Version::HTTP_11,
86 _ => Err(Error::Parsing(httparse::Error::Version))?,
87 }
88
89 let header_map = parsed_response.headers_mut();
90
91 header_map.reserve(response.headers.len());
92
93 for header in response.headers {
94 let name = HeaderName::from_str(header.name)
95 .map_err(|_| Error::Parsing(httparse::Error::HeaderName))?;
96 let value = HeaderValue::from_bytes(header.value)
97 .map_err(|_| Error::Parsing(httparse::Error::HeaderValue))?;
98
99 header_map.insert(name, value);
100 }
101
102 src.advance(response_len);
103
104 Ok(Some(parsed_response))
105 }
106}