use super::*;
use heapless::{String, Vec};
/// Handshake details extracted from a client's HTTP upgrade request by
/// [`read_http_header`].
pub struct WebSocketContext {
    /// Sub-protocols offered by the client in `Sec-WebSocket-Protocol`;
    /// the list has room for at most three entries.
    pub sec_websocket_protocol_list: Vec<WebSocketSubProtocol, 3>,
    /// The client's `Sec-WebSocket-Key`, from which the server derives its
    /// `Sec-WebSocket-Accept` value.
    pub sec_websocket_key: WebSocketKey,
}
pub fn read_http_header<'a>(
headers: impl Iterator<Item = (&'a str, &'a [u8])>,
) -> Result<Option<WebSocketContext>> {
let mut sec_websocket_protocol_list: Vec<String<24>, 3> = Vec::new();
let mut is_websocket_request = false;
let mut sec_websocket_key = String::new();
for (name, value) in headers {
match name {
"Upgrade" => is_websocket_request = str::from_utf8(value)? == "websocket",
"Sec-WebSocket-Protocol" => {
for item in str::from_utf8(value)?.split(", ") {
if sec_websocket_protocol_list.len() < sec_websocket_protocol_list.capacity() {
sec_websocket_protocol_list
.push(String::from(item))
.unwrap();
}
}
}
"Sec-WebSocket-Key" => {
sec_websocket_key = String::from(str::from_utf8(value)?);
}
&_ => {
}
}
}
if is_websocket_request {
Ok(Some(WebSocketContext {
sec_websocket_protocol_list,
sec_websocket_key,
}))
} else {
Ok(None)
}
}
pub fn read_server_connect_handshake_response(
sec_websocket_key: &WebSocketKey,
from: &[u8],
) -> Result<(usize, Option<WebSocketSubProtocol>)> {
let mut headers = [httparse::EMPTY_HEADER; 16];
let mut response = httparse::Response::new(&mut headers);
match response.parse(from)? {
httparse::Status::Complete(len) => {
match response.code {
Some(101) => {
}
code => {
return Err(Error::HttpResponseCodeInvalid(code));
}
};
let mut sec_websocket_protocol: Option<WebSocketSubProtocol> = None;
for item in response.headers.iter() {
match item.name {
"Sec-WebSocket-Accept" => {
let mut output = [0; 28];
build_accept_string(sec_websocket_key, &mut output)?;
let expected_accept_string = str::from_utf8(&output)?;
let actual_accept_string = str::from_utf8(item.value)?;
if actual_accept_string != expected_accept_string {
return Err(Error::AcceptStringInvalid);
}
}
"Sec-WebSocket-Protocol" => {
sec_websocket_protocol = Some(String::from(str::from_utf8(item.value)?));
}
_ => {
}
}
}
Ok((len, sec_websocket_protocol))
}
httparse::Status::Partial => Err(Error::HttpHeaderIncomplete),
}
}
/// Writes a client opening-handshake (HTTP `GET` upgrade) request into `to`.
///
/// A fresh `Sec-WebSocket-Key` is generated from 16 bytes drawn from `rng` and encoded
/// in base64. The request includes the path, host and origin from `websocket_options`,
/// the requested sub-protocols (if any), and any `additional_headers`; each additional
/// header is written verbatim followed by `\r\n`.
///
/// Returns the number of bytes written and the generated key. The caller needs the key
/// later to verify the server's `Sec-WebSocket-Accept`.
///
/// # Errors
///
/// Returns an error if the request exceeds the 1024-byte scratch buffer. It also
/// returns the same error if the request does not fit in `to`, which previously
/// panicked.
pub fn build_connect_handshake_request(
    websocket_options: &WebSocketOptions,
    rng: &mut impl RngCore,
    to: &mut [u8],
) -> Result<(usize, WebSocketKey)> {
    let mut http_request: String<1024> = String::new();

    // 16 random bytes encode to exactly 24 base64 characters.
    let mut key: [u8; 16] = [0; 16];
    rng.fill_bytes(&mut key);
    let mut key_as_base64: [u8; 24] = [0; 24];
    base64::encode_config_slice(key, base64::STANDARD, &mut key_as_base64);
    let sec_websocket_key: String<24> = String::from(str::from_utf8(&key_as_base64)?);

    http_request.push_str("GET ")?;
    http_request.push_str(websocket_options.path)?;
    http_request.push_str(" HTTP/1.1\r\nHost: ")?;
    http_request.push_str(websocket_options.host)?;
    http_request
        .push_str("\r\nUpgrade: websocket\r\nConnection: Upgrade\r\nSec-WebSocket-Key: ")?;
    http_request.push_str(sec_websocket_key.as_str())?;
    http_request.push_str("\r\nOrigin: ")?;
    http_request.push_str(websocket_options.origin)?;
    if let Some(sub_protocols) = websocket_options.sub_protocols {
        // An empty list would otherwise produce an empty, invalid header value.
        if !sub_protocols.is_empty() {
            http_request.push_str("\r\nSec-WebSocket-Protocol: ")?;
            for (i, sub_protocol) in sub_protocols.iter().enumerate() {
                if i > 0 {
                    http_request.push_str(", ")?;
                }
                http_request.push_str(sub_protocol)?;
            }
        }
    }
    http_request.push_str("\r\n")?;
    if let Some(additional_headers) = websocket_options.additional_headers {
        for additional_header in additional_headers.iter() {
            http_request.push_str(additional_header)?;
            http_request.push_str("\r\n")?;
        }
    }
    http_request.push_str("Sec-WebSocket-Version: 13\r\n\r\n")?;

    // Report a too-small output buffer through the same `()`-derived error used for
    // scratch-buffer overflow instead of panicking on the slice index.
    let len = http_request.len();
    to.get_mut(..len)
        .ok_or(())?
        .copy_from_slice(http_request.as_bytes());
    Ok((len, sec_websocket_key))
}
pub fn build_connect_handshake_response(
sec_websocket_key: &WebSocketKey,
sec_websocket_protocol: Option<&WebSocketSubProtocol>,
to: &mut [u8],
) -> Result<usize> {
let mut http_response: String<1024> = String::new();
http_response.push_str(
"HTTP/1.1 101 Switching Protocols\r\n\
Connection: Upgrade\r\nUpgrade: websocket\r\n",
)?;
if let Some(sec_websocket_protocol) = sec_websocket_protocol {
http_response.push_str("Sec-WebSocket-Protocol: ")?;
http_response.push_str(sec_websocket_protocol)?;
http_response.push_str("\r\n")?;
}
let mut output = [0; 28];
http::build_accept_string(sec_websocket_key, &mut output)?;
let accept_string = str::from_utf8(&output)?;
http_response.push_str("Sec-WebSocket-Accept: ")?;
http_response.push_str(accept_string)?;
http_response.push_str("\r\n\r\n")?;
to[..http_response.len()].copy_from_slice(http_response.as_bytes());
Ok(http_response.len())
}
/// Computes the `Sec-WebSocket-Accept` value for a `Sec-WebSocket-Key`.
///
/// The value is the standard-alphabet base64 encoding of the SHA-1 digest of the key
/// followed by the fixed GUID below. It is 28 ASCII bytes long and is written to the
/// first 28 bytes of `output`.
///
/// # Errors
///
/// Returns an error (rather than panicking inside the base64 encoder) if `output` is
/// shorter than 28 bytes.
pub fn build_accept_string(sec_websocket_key: &WebSocketKey, output: &mut [u8]) -> Result<()> {
    const WEBSOCKET_GUID: &[u8] = b"258EAFA5-E914-47DA-95CA-C5AB0DC85B11";
    // base64 of a 20-byte SHA-1 digest is ceil(20 / 3) * 4 = 28 characters.
    const ACCEPT_LEN: usize = 28;

    let output = output.get_mut(..ACCEPT_LEN).ok_or(())?;

    // Feeding the hasher in two steps yields the same digest as hashing the
    // concatenation, without needing an intermediate buffer.
    let mut sha1 = Sha1::new();
    sha1.update(sec_websocket_key.as_bytes());
    sha1.update(WEBSOCKET_GUID);
    let digest = sha1.finalize();
    base64::encode_config_slice(digest, base64::STANDARD, output);
    Ok(())
}