Skip to main content

embedded_websocket/
http.rs

1use super::*;
2use heapless::{String, Vec};
3
4// NOTE: this struct is re-exported
5/// Websocket details extracted from the http header
6pub struct WebSocketContext {
7    /// The list of sub protocols is restricted to a maximum of 3
8    pub sec_websocket_protocol_list: Vec<WebSocketSubProtocol, 3>,
9    /// The websocket key user to build the accept string to complete the opening handshake
10    pub sec_websocket_key: WebSocketKey,
11}
12
13// NOTE: this function is re-exported
14/// Reads an http header and extracts websocket specific information from it.
15///
16/// # Examples
17/// ```
18/// use embedded_websocket as ws;
19/// let client_request = "GET /chat HTTP/1.1
20/// Host: myserver.example.com
21/// Upgrade: websocket
22/// Connection: Upgrade
23/// Sec-WebSocket-Key: Z7OY1UwHOx/nkSz38kfPwg==
24/// Origin: http://example.com
25/// Sec-WebSocket-Protocol: chat, advancedchat
26/// Sec-WebSocket-Version: 13
27///
28/// ";
29///
30/// let mut headers = [httparse::EMPTY_HEADER; 16];
31/// let mut request = httparse::Request::new(&mut headers);
32/// request.parse(client_request.as_bytes()).unwrap();
33/// let headers = request.headers.iter().map(|f| (f.name, f.value));
34/// let ws_context = ws::read_http_header(headers).unwrap().unwrap();
35/// assert_eq!("Z7OY1UwHOx/nkSz38kfPwg==", ws_context.sec_websocket_key);
36/// assert_eq!("chat", ws_context.sec_websocket_protocol_list.get(0).unwrap().as_str());
37/// assert_eq!("advancedchat", ws_context.sec_websocket_protocol_list.get(1).unwrap().as_str());
38///
39/// ```
40pub fn read_http_header<'a>(
41    headers: impl Iterator<Item = (&'a str, &'a [u8])>,
42) -> Result<Option<WebSocketContext>> {
43    let mut sec_websocket_protocol_list: Vec<String<24>, 3> = Vec::new();
44    let mut is_websocket_request = false;
45    let mut sec_websocket_key = String::new();
46
47    for (name, value) in headers {
48        if name.eq_ignore_ascii_case("upgrade") {
49            is_websocket_request = str::from_utf8(value)?.eq_ignore_ascii_case("websocket");
50        } else if name.eq_ignore_ascii_case("sec-websocket-protocol") {
51            // extract a csv list of supported sub protocols
52            for item in str::from_utf8(value)?.split(", ") {
53                if sec_websocket_protocol_list.len() < sec_websocket_protocol_list.capacity() {
54                    // it is safe to unwrap here because we have checked
55                    // the size of the list beforehand
56                    sec_websocket_protocol_list
57                        .push(String::from(item))
58                        .unwrap();
59                }
60            }
61        } else if name.eq_ignore_ascii_case("sec-websocket-key") {
62            sec_websocket_key = String::from(str::from_utf8(value)?);
63        }
64    }
65
66    if is_websocket_request {
67        Ok(Some(WebSocketContext {
68            sec_websocket_protocol_list,
69            sec_websocket_key,
70        }))
71    } else {
72        Ok(None)
73    }
74}
75
76pub fn read_server_connect_handshake_response(
77    sec_websocket_key: &WebSocketKey,
78    from: &[u8],
79) -> Result<(usize, Option<WebSocketSubProtocol>)> {
80    let mut headers = [httparse::EMPTY_HEADER; 64];
81    let mut response = httparse::Response::new(&mut headers);
82
83    match response.parse(from)? {
84        httparse::Status::Complete(len) => {
85            match response.code {
86                Some(101) => {
87                    // we are ok
88                }
89                code => {
90                    return Err(Error::HttpResponseCodeInvalid(code));
91                }
92            };
93
94            let mut sec_websocket_protocol: Option<WebSocketSubProtocol> = None;
95            for item in response.headers.iter() {
96                if item.name.eq_ignore_ascii_case("sec-websocket-accept") {
97                    let mut output = [0; 28];
98                    build_accept_string(sec_websocket_key, &mut output)?;
99
100                    let expected_accept_string = str::from_utf8(&output)?;
101                    let actual_accept_string = str::from_utf8(item.value)?;
102
103                    if actual_accept_string != expected_accept_string {
104                        return Err(Error::AcceptStringInvalid);
105                    }
106                } else if item.name.eq_ignore_ascii_case("sec-websocket-protocol") {
107                    sec_websocket_protocol = Some(String::from(str::from_utf8(item.value)?));
108                }
109            }
110
111            Ok((len, sec_websocket_protocol))
112        }
113        httparse::Status::Partial => Err(Error::HttpHeaderIncomplete),
114    }
115}
116
117pub fn build_connect_handshake_request(
118    websocket_options: &WebSocketOptions,
119    rng: &mut impl RngCore,
120    to: &mut [u8],
121) -> Result<(usize, WebSocketKey)> {
122    let mut http_request: String<1024> = String::new();
123    let mut key_as_base64: [u8; 24] = [0; 24];
124
125    let mut key: [u8; 16] = [0; 16];
126    rng.fill_bytes(&mut key);
127    base64::encode_config_slice(key, base64::STANDARD, &mut key_as_base64);
128    let sec_websocket_key: String<24> = String::from(str::from_utf8(&key_as_base64)?);
129
130    http_request.push_str("GET ")?;
131    http_request.push_str(websocket_options.path)?;
132    http_request.push_str(" HTTP/1.1\r\nHost: ")?;
133    http_request.push_str(websocket_options.host)?;
134    http_request
135        .push_str("\r\nUpgrade: websocket\r\nConnection: Upgrade\r\nSec-WebSocket-Key: ")?;
136    http_request.push_str(sec_websocket_key.as_str())?;
137    http_request.push_str("\r\nOrigin: ")?;
138    http_request.push_str(websocket_options.origin)?;
139
140    // turn sub protocol list into a CSV list
141    if let Some(sub_protocols) = websocket_options.sub_protocols {
142        http_request.push_str("\r\nSec-WebSocket-Protocol: ")?;
143        for (i, sub_protocol) in sub_protocols.iter().enumerate() {
144            http_request.push_str(sub_protocol)?;
145            if i < (sub_protocols.len() - 1) {
146                http_request.push_str(", ")?;
147            }
148        }
149    }
150    http_request.push_str("\r\n")?;
151
152    if let Some(additional_headers) = websocket_options.additional_headers {
153        for additional_header in additional_headers.iter() {
154            http_request.push_str(additional_header)?;
155            http_request.push_str("\r\n")?;
156        }
157    }
158
159    http_request.push_str("Sec-WebSocket-Version: 13\r\n\r\n")?;
160    to[..http_request.len()].copy_from_slice(http_request.as_bytes());
161    Ok((http_request.len(), sec_websocket_key))
162}
163
164pub fn build_connect_handshake_response(
165    sec_websocket_key: &WebSocketKey,
166    sec_websocket_protocol: Option<&WebSocketSubProtocol>,
167    to: &mut [u8],
168) -> Result<usize> {
169    let mut http_response: String<1024> = String::new();
170    http_response.push_str(
171        "HTTP/1.1 101 Switching Protocols\r\n\
172         Connection: Upgrade\r\nUpgrade: websocket\r\n",
173    )?;
174
175    // if the user has specified a sub protocol
176    if let Some(sec_websocket_protocol) = sec_websocket_protocol {
177        http_response.push_str("Sec-WebSocket-Protocol: ")?;
178        http_response.push_str(sec_websocket_protocol)?;
179        http_response.push_str("\r\n")?;
180    }
181
182    let mut output = [0; 28];
183    http::build_accept_string(sec_websocket_key, &mut output)?;
184    let accept_string = str::from_utf8(&output)?;
185    http_response.push_str("Sec-WebSocket-Accept: ")?;
186    http_response.push_str(accept_string)?;
187    http_response.push_str("\r\n\r\n")?;
188
189    // save the response to the buffer
190    to[..http_response.len()].copy_from_slice(http_response.as_bytes());
191    Ok(http_response.len())
192}
193
194pub fn build_accept_string(sec_websocket_key: &WebSocketKey, output: &mut [u8]) -> Result<()> {
195    // concatenate the key with a known websocket GUID (as per the spec)
196    let mut accept_string: String<64> = String::new();
197    accept_string.push_str(sec_websocket_key)?;
198    accept_string.push_str("258EAFA5-E914-47DA-95CA-C5AB0DC85B11")?;
199
200    // calculate the base64 encoded sha1 hash of the accept string above
201    let mut sha1 = Sha1::new();
202    sha1.update(&accept_string);
203    let input = sha1.finalize();
204    base64::encode_config_slice(input, base64::STANDARD, output); // no need for slices since the output WILL be 28 bytes
205    Ok(())
206}