embedded_websocket/
http.rs

1use super::*;
2use heapless::{String, Vec};
3
4// NOTE: this struct is re-exported
5/// Websocket details extracted from the http header
6pub struct WebSocketContext {
7    /// The list of sub protocols is restricted to a maximum of 3
8    pub sec_websocket_protocol_list: Vec<WebSocketSubProtocol, 3>,
9    /// The websocket key user to build the accept string to complete the opening handshake
10    pub sec_websocket_key: WebSocketKey,
11}
12
13// NOTE: this function is re-exported
14/// Reads an http header and extracts websocket specific information from it.
15///
16/// # Examples
17/// ```
18/// use embedded_websocket as ws;
19/// let client_request = "GET /chat HTTP/1.1
20/// Host: myserver.example.com
21/// Upgrade: websocket
22/// Connection: Upgrade
23/// Sec-WebSocket-Key: Z7OY1UwHOx/nkSz38kfPwg==
24/// Origin: http://example.com
25/// Sec-WebSocket-Protocol: chat, advancedchat
26/// Sec-WebSocket-Version: 13
27///
28/// ";
29///
30/// let mut headers = [httparse::EMPTY_HEADER; 16];
31/// let mut request = httparse::Request::new(&mut headers);
32/// request.parse(client_request.as_bytes()).unwrap();
33/// let headers = request.headers.iter().map(|f| (f.name, f.value));
34/// let ws_context = ws::read_http_header(headers).unwrap().unwrap();
35/// assert_eq!("Z7OY1UwHOx/nkSz38kfPwg==", ws_context.sec_websocket_key);
36/// assert_eq!("chat", ws_context.sec_websocket_protocol_list.get(0).unwrap().as_str());
37/// assert_eq!("advancedchat", ws_context.sec_websocket_protocol_list.get(1).unwrap().as_str());
38///
39/// ```
40pub fn read_http_header<'a>(
41    headers: impl Iterator<Item = (&'a str, &'a [u8])>,
42) -> Result<Option<WebSocketContext>> {
43    let mut sec_websocket_protocol_list: Vec<String<24>, 3> = Vec::new();
44    let mut is_websocket_request = false;
45    let mut sec_websocket_key = String::new();
46
47    for (name, value) in headers {
48        match name {
49            "Upgrade" => is_websocket_request = str::from_utf8(value)? == "websocket",
50            "Sec-WebSocket-Protocol" => {
51                // extract a csv list of supported sub protocols
52                for item in str::from_utf8(value)?.split(", ") {
53                    if sec_websocket_protocol_list.len() < sec_websocket_protocol_list.capacity() {
54                        // it is safe to unwrap here because we have checked
55                        // the size of the list beforehand
56                        sec_websocket_protocol_list
57                            .push(String::from(item))
58                            .unwrap();
59                    }
60                }
61            }
62            "Sec-WebSocket-Key" => {
63                sec_websocket_key = String::from(str::from_utf8(value)?);
64            }
65            &_ => {
66                // ignore all other headers
67            }
68        }
69    }
70
71    if is_websocket_request {
72        Ok(Some(WebSocketContext {
73            sec_websocket_protocol_list,
74            sec_websocket_key,
75        }))
76    } else {
77        Ok(None)
78    }
79}
80
81pub fn read_server_connect_handshake_response(
82    sec_websocket_key: &WebSocketKey,
83    from: &[u8],
84) -> Result<(usize, Option<WebSocketSubProtocol>)> {
85    let mut headers = [httparse::EMPTY_HEADER; 16];
86    let mut response = httparse::Response::new(&mut headers);
87
88    match response.parse(from)? {
89        httparse::Status::Complete(len) => {
90            match response.code {
91                Some(101) => {
92                    // we are ok
93                }
94                code => {
95                    return Err(Error::HttpResponseCodeInvalid(code));
96                }
97            };
98
99            let mut sec_websocket_protocol: Option<WebSocketSubProtocol> = None;
100            for item in response.headers.iter() {
101                match item.name {
102                    "Sec-WebSocket-Accept" => {
103                        let mut output = [0; 28];
104                        build_accept_string(sec_websocket_key, &mut output)?;
105
106                        let expected_accept_string = str::from_utf8(&output)?;
107                        let actual_accept_string = str::from_utf8(item.value)?;
108
109                        if actual_accept_string != expected_accept_string {
110                            return Err(Error::AcceptStringInvalid);
111                        }
112                    }
113                    "Sec-WebSocket-Protocol" => {
114                        sec_websocket_protocol = Some(String::from(str::from_utf8(item.value)?));
115                    }
116                    _ => {
117                        // ignore all other headers
118                    }
119                }
120            }
121
122            Ok((len, sec_websocket_protocol))
123        }
124        httparse::Status::Partial => Err(Error::HttpHeaderIncomplete),
125    }
126}
127
128pub fn build_connect_handshake_request(
129    websocket_options: &WebSocketOptions,
130    rng: &mut impl RngCore,
131    to: &mut [u8],
132) -> Result<(usize, WebSocketKey)> {
133    let mut http_request: String<1024> = String::new();
134    let mut key_as_base64: [u8; 24] = [0; 24];
135
136    let mut key: [u8; 16] = [0; 16];
137    rng.fill_bytes(&mut key);
138    base64::encode_config_slice(key, base64::STANDARD, &mut key_as_base64);
139    let sec_websocket_key: String<24> = String::from(str::from_utf8(&key_as_base64)?);
140
141    http_request.push_str("GET ")?;
142    http_request.push_str(websocket_options.path)?;
143    http_request.push_str(" HTTP/1.1\r\nHost: ")?;
144    http_request.push_str(websocket_options.host)?;
145    http_request
146        .push_str("\r\nUpgrade: websocket\r\nConnection: Upgrade\r\nSec-WebSocket-Key: ")?;
147    http_request.push_str(sec_websocket_key.as_str())?;
148    http_request.push_str("\r\nOrigin: ")?;
149    http_request.push_str(websocket_options.origin)?;
150
151    // turn sub protocol list into a CSV list
152    if let Some(sub_protocols) = websocket_options.sub_protocols {
153        http_request.push_str("\r\nSec-WebSocket-Protocol: ")?;
154        for (i, sub_protocol) in sub_protocols.iter().enumerate() {
155            http_request.push_str(sub_protocol)?;
156            if i < (sub_protocols.len() - 1) {
157                http_request.push_str(", ")?;
158            }
159        }
160    }
161    http_request.push_str("\r\n")?;
162
163    if let Some(additional_headers) = websocket_options.additional_headers {
164        for additional_header in additional_headers.iter() {
165            http_request.push_str(additional_header)?;
166            http_request.push_str("\r\n")?;
167        }
168    }
169
170    http_request.push_str("Sec-WebSocket-Version: 13\r\n\r\n")?;
171    to[..http_request.len()].copy_from_slice(http_request.as_bytes());
172    Ok((http_request.len(), sec_websocket_key))
173}
174
175pub fn build_connect_handshake_response(
176    sec_websocket_key: &WebSocketKey,
177    sec_websocket_protocol: Option<&WebSocketSubProtocol>,
178    to: &mut [u8],
179) -> Result<usize> {
180    let mut http_response: String<1024> = String::new();
181    http_response.push_str(
182        "HTTP/1.1 101 Switching Protocols\r\n\
183         Connection: Upgrade\r\nUpgrade: websocket\r\n",
184    )?;
185
186    // if the user has specified a sub protocol
187    if let Some(sec_websocket_protocol) = sec_websocket_protocol {
188        http_response.push_str("Sec-WebSocket-Protocol: ")?;
189        http_response.push_str(sec_websocket_protocol)?;
190        http_response.push_str("\r\n")?;
191    }
192
193    let mut output = [0; 28];
194    http::build_accept_string(sec_websocket_key, &mut output)?;
195    let accept_string = str::from_utf8(&output)?;
196    http_response.push_str("Sec-WebSocket-Accept: ")?;
197    http_response.push_str(accept_string)?;
198    http_response.push_str("\r\n\r\n")?;
199
200    // save the response to the buffer
201    to[..http_response.len()].copy_from_slice(http_response.as_bytes());
202    Ok(http_response.len())
203}
204
205pub fn build_accept_string(sec_websocket_key: &WebSocketKey, output: &mut [u8]) -> Result<()> {
206    // concatenate the key with a known websocket GUID (as per the spec)
207    let mut accept_string: String<64> = String::new();
208    accept_string.push_str(sec_websocket_key)?;
209    accept_string.push_str("258EAFA5-E914-47DA-95CA-C5AB0DC85B11")?;
210
211    // calculate the base64 encoded sha1 hash of the accept string above
212    let mut sha1 = Sha1::new();
213    sha1.update(&accept_string);
214    let input = sha1.finalize();
215    base64::encode_config_slice(input, base64::STANDARD, output); // no need for slices since the output WILL be 28 bytes
216    Ok(())
217}