embedded_websocket_embedded_io/
http.rs

1use super::*;
2use base64::Engine;
3use core::str::FromStr;
4use heapless::{String, Vec};
5
6// NOTE: this struct is re-exported
7/// Websocket details extracted from the http header
8pub struct WebSocketContext {
9    /// The list of sub protocols is restricted to a maximum of 3
10    pub sec_websocket_protocol_list: Vec<WebSocketSubProtocol, 3>,
11    /// The websocket key user to build the accept string to complete the opening handshake
12    pub sec_websocket_key: WebSocketKey,
13}
14
15// NOTE: this function is re-exported
16/// Reads an http header and extracts websocket specific information from it.
17///
18/// # Examples
19/// ```
20/// use embedded_websocket as ws;
21/// let client_request = "GET /chat HTTP/1.1
22/// Host: myserver.example.com
23/// Upgrade: websocket
24/// Connection: Upgrade
25/// Sec-WebSocket-Key: Z7OY1UwHOx/nkSz38kfPwg==
26/// Origin: http://example.com
27/// Sec-WebSocket-Protocol: chat, advancedchat
28/// Sec-WebSocket-Version: 13
29///
30/// ";
31///
32/// let mut headers = [httparse::EMPTY_HEADER; 16];
33/// let mut request = httparse::Request::new(&mut headers);
34/// request.parse(client_request.as_bytes()).unwrap();
35/// let headers = request.headers.iter().map(|f| (f.name, f.value));
36/// let ws_context = ws::read_http_header(headers).unwrap().unwrap();
37/// assert_eq!("Z7OY1UwHOx/nkSz38kfPwg==", ws_context.sec_websocket_key);
38/// assert_eq!("chat", ws_context.sec_websocket_protocol_list.get(0).unwrap().as_str());
39/// assert_eq!("advancedchat", ws_context.sec_websocket_protocol_list.get(1).unwrap().as_str());
40///
41/// ```
42pub fn read_http_header<'a>(
43    headers: impl Iterator<Item = (&'a str, &'a [u8])>,
44) -> Result<Option<WebSocketContext>> {
45    let mut sec_websocket_protocol_list: Vec<String<24>, 3> = Vec::new();
46    let mut is_websocket_request = false;
47    let mut sec_websocket_key = String::new();
48
49    for (name, value) in headers {
50        match name {
51            "Upgrade" => is_websocket_request = str::from_utf8(value)? == "websocket",
52            "Sec-WebSocket-Protocol" => {
53                // extract a csv list of supported sub protocols
54                for item in str::from_utf8(value)?.split(", ") {
55                    if sec_websocket_protocol_list.len() < sec_websocket_protocol_list.capacity() {
56                        // it is safe to unwrap here because we have checked
57                        // the size of the list beforehand
58                        sec_websocket_protocol_list
59                            .push(String::from_str(item)?)
60                            .unwrap();
61                    }
62                }
63            }
64            "Sec-WebSocket-Key" => {
65                sec_websocket_key = String::from_str(str::from_utf8(value)?)?;
66            }
67            &_ => {
68                // ignore all other headers
69            }
70        }
71    }
72
73    if is_websocket_request {
74        Ok(Some(WebSocketContext {
75            sec_websocket_protocol_list,
76            sec_websocket_key,
77        }))
78    } else {
79        Ok(None)
80    }
81}
82
83pub fn read_server_connect_handshake_response(
84    sec_websocket_key: &WebSocketKey,
85    from: &[u8],
86) -> Result<(usize, Option<WebSocketSubProtocol>)> {
87    let mut headers = [httparse::EMPTY_HEADER; 16];
88    let mut response = httparse::Response::new(&mut headers);
89
90    match response.parse(from)? {
91        httparse::Status::Complete(len) => {
92            match response.code {
93                Some(101) => {
94                    // we are ok
95                }
96                code => {
97                    return Err(Error::HttpResponseCodeInvalid(code));
98                }
99            };
100
101            let mut sec_websocket_protocol: Option<WebSocketSubProtocol> = None;
102            for item in response.headers.iter() {
103                match item.name {
104                    "Sec-WebSocket-Accept" => {
105                        let mut output = [0; 28];
106                        build_accept_string(sec_websocket_key, &mut output)?;
107
108                        let expected_accept_string = str::from_utf8(&output)?;
109                        let actual_accept_string = str::from_utf8(item.value)?;
110
111                        if actual_accept_string != expected_accept_string {
112                            return Err(Error::AcceptStringInvalid);
113                        }
114                    }
115                    "Sec-WebSocket-Protocol" => {
116                        sec_websocket_protocol =
117                            Some(String::from_str(str::from_utf8(item.value)?)?);
118                    }
119                    _ => {
120                        // ignore all other headers
121                    }
122                }
123            }
124
125            Ok((len, sec_websocket_protocol))
126        }
127        httparse::Status::Partial => Err(Error::HttpHeaderIncomplete),
128    }
129}
130
131pub fn build_connect_handshake_request(
132    websocket_options: &WebSocketOptions,
133    rng: &mut impl RngCore,
134    to: &mut [u8],
135) -> Result<(usize, WebSocketKey)> {
136    let mut http_request: String<1024> = String::new();
137    let mut key_as_base64: [u8; 24] = [0; 24];
138
139    let mut key: [u8; 16] = [0; 16];
140    rng.fill_bytes(&mut key);
141    base64::engine::general_purpose::STANDARD.encode_slice(key, &mut key_as_base64)?;
142    let sec_websocket_key: String<24> = String::from_str(str::from_utf8(&key_as_base64)?)?;
143
144    http_request.push_str("GET ")?;
145    http_request.push_str(websocket_options.path)?;
146    http_request.push_str(" HTTP/1.1\r\nHost: ")?;
147    http_request.push_str(websocket_options.host)?;
148    http_request
149        .push_str("\r\nUpgrade: websocket\r\nConnection: Upgrade\r\nSec-WebSocket-Key: ")?;
150    http_request.push_str(sec_websocket_key.as_str())?;
151    http_request.push_str("\r\nOrigin: ")?;
152    http_request.push_str(websocket_options.origin)?;
153
154    // turn sub protocol list into a CSV list
155    if let Some(sub_protocols) = websocket_options.sub_protocols {
156        http_request.push_str("\r\nSec-WebSocket-Protocol: ")?;
157        for (i, sub_protocol) in sub_protocols.iter().enumerate() {
158            http_request.push_str(sub_protocol)?;
159            if i < (sub_protocols.len() - 1) {
160                http_request.push_str(", ")?;
161            }
162        }
163    }
164    http_request.push_str("\r\n")?;
165
166    if let Some(additional_headers) = websocket_options.additional_headers {
167        for additional_header in additional_headers.iter() {
168            http_request.push_str(additional_header)?;
169            http_request.push_str("\r\n")?;
170        }
171    }
172
173    http_request.push_str("Sec-WebSocket-Version: 13\r\n\r\n")?;
174    to[..http_request.len()].copy_from_slice(http_request.as_bytes());
175    Ok((http_request.len(), sec_websocket_key))
176}
177
178pub fn build_connect_handshake_response(
179    sec_websocket_key: &WebSocketKey,
180    sec_websocket_protocol: Option<&WebSocketSubProtocol>,
181    to: &mut [u8],
182) -> Result<usize> {
183    let mut http_response: String<1024> = String::new();
184    http_response.push_str(
185        "HTTP/1.1 101 Switching Protocols\r\n\
186         Connection: Upgrade\r\nUpgrade: websocket\r\n",
187    )?;
188
189    // if the user has specified a sub protocol
190    if let Some(sec_websocket_protocol) = sec_websocket_protocol {
191        http_response.push_str("Sec-WebSocket-Protocol: ")?;
192        http_response.push_str(sec_websocket_protocol)?;
193        http_response.push_str("\r\n")?;
194    }
195
196    let mut output = [0; 28];
197    http::build_accept_string(sec_websocket_key, &mut output)?;
198    let accept_string = str::from_utf8(&output)?;
199    http_response.push_str("Sec-WebSocket-Accept: ")?;
200    http_response.push_str(accept_string)?;
201    http_response.push_str("\r\n\r\n")?;
202
203    // save the response to the buffer
204    to[..http_response.len()].copy_from_slice(http_response.as_bytes());
205    Ok(http_response.len())
206}
207
208pub fn build_accept_string(sec_websocket_key: &WebSocketKey, output: &mut [u8]) -> Result<()> {
209    // concatenate the key with a known websocket GUID (as per the spec)
210    let mut accept_string: String<64> = String::new();
211    accept_string.push_str(sec_websocket_key)?;
212    accept_string.push_str("258EAFA5-E914-47DA-95CA-C5AB0DC85B11")?;
213
214    // calculate the base64 encoded sha1 hash of the accept string above
215    let mut sha1 = Sha1::new();
216    sha1.update(&accept_string);
217    let input = sha1.finalize();
218    base64::engine::general_purpose::STANDARD.encode_slice(input, output)?;
219    Ok(())
220}