#![cfg_attr(not(feature = "std"), no_std)]
#![allow(async_fn_in_trait)]
#![warn(clippy::large_futures)]
#![allow(clippy::uninlined_format_args)]
#![allow(unknown_lints)]
use core::fmt::Display;
use core::str;
use httparse::{Header, EMPTY_HEADER};
use ws::{is_upgrade_accepted, is_upgrade_request, MAX_BASE64_KEY_RESPONSE_LEN, NONCE_LEN};
/// Default maximum number of headers kept in a header collection. The same value (64) is the
/// default capacity `N` of [`Headers`].
pub const DEFAULT_MAX_HEADERS_COUNT: usize = 64;
pub(crate) mod fmt;
#[cfg(feature = "io")]
pub mod io;
/// Error produced when the connection or body framing requested by a set of headers is not
/// compatible with the rest of the exchange.
#[derive(Debug, Copy, Clone, Eq, PartialEq, Hash)]
pub enum HeadersMismatchError {
    /// Returned by `ConnectionType::resolve` when the headers ask for Keep-Alive while the
    /// carried-over (peer) connection type is Close.
    ResponseConnectionTypeMismatchError,
    /// Returned by `BodyType::resolve`; the payload is a static description of the conflict.
    BodyTypeError(&'static str),
}

impl Display for HeadersMismatchError {
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        let reason = match self {
            Self::ResponseConnectionTypeMismatchError => {
                return f.write_str(
                    "Response connection type is different from the request connection type",
                )
            }
            Self::BodyTypeError(reason) => reason,
        };

        write!(f, "Body type mismatch: {}", reason)
    }
}

#[cfg(feature = "defmt")]
impl defmt::Format for HeadersMismatchError {
    fn format(&self, f: defmt::Formatter<'_>) {
        match *self {
            Self::ResponseConnectionTypeMismatchError => defmt::write!(
                f,
                "Response connection type is different from the request connection type"
            ),
            Self::BodyTypeError(reason) => defmt::write!(f, "Body type mismatch: {}", reason),
        }
    }
}

impl core::error::Error for HeadersMismatchError {}
/// An HTTP request method.
///
/// Besides the common methods (`GET`, `POST`, ...) this covers a set of extension methods
/// (WebDAV/CalDAV-style `PROPFIND`, `MKCOL`, `MKCALENDAR`, the SSDP `M-SEARCH`, ...).
/// Use [`Method::new`] to parse a method name and the [`Display`] impl to obtain its
/// canonical wire spelling.
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
#[cfg_attr(feature = "std", derive(Hash))]
pub enum Method {
    Delete,
    Get,
    Head,
    Post,
    Put,
    Connect,
    Options,
    Trace,
    Copy,
    Lock,
    MkCol,
    Move,
    Propfind,
    Proppatch,
    Search,
    Unlock,
    Bind,
    Rebind,
    Unbind,
    Acl,
    Report,
    MkActivity,
    Checkout,
    Merge,
    MSearch,
    Notify,
    Subscribe,
    Unsubscribe,
    Patch,
    Purge,
    MkCalendar,
    Link,
    Unlink,
}

impl Method {
    /// Every variant, in declaration order; `new` searches this table.
    const ALL: [Self; 33] = [
        Self::Delete,
        Self::Get,
        Self::Head,
        Self::Post,
        Self::Put,
        Self::Connect,
        Self::Options,
        Self::Trace,
        Self::Copy,
        Self::Lock,
        Self::MkCol,
        Self::Move,
        Self::Propfind,
        Self::Proppatch,
        Self::Search,
        Self::Unlock,
        Self::Bind,
        Self::Rebind,
        Self::Unbind,
        Self::Acl,
        Self::Report,
        Self::MkActivity,
        Self::Checkout,
        Self::Merge,
        Self::MSearch,
        Self::Notify,
        Self::Subscribe,
        Self::Unsubscribe,
        Self::Patch,
        Self::Purge,
        Self::MkCalendar,
        Self::Link,
        Self::Unlink,
    ];

    /// Parses a method name, ignoring ASCII case.
    ///
    /// Accepts the canonical wire spelling of every method (e.g. `"GET"`, `"MKCOL"`,
    /// `"M-SEARCH"`). The hyphen-less `"MSearch"` spelling is also accepted for `M-SEARCH`.
    /// Returns `None` for unknown names.
    pub fn new(method: &str) -> Option<Self> {
        // Legacy spelling: earlier versions only recognised "MSearch" for M-SEARCH.
        if method.eq_ignore_ascii_case("MSearch") {
            return Some(Self::MSearch);
        }

        Self::ALL
            .iter()
            .copied()
            .find(|candidate| candidate.as_str().eq_ignore_ascii_case(method))
    }

    /// Canonical upper-case token for this method as written in a request line.
    fn as_str(&self) -> &'static str {
        match self {
            Self::Delete => "DELETE",
            Self::Get => "GET",
            Self::Head => "HEAD",
            Self::Post => "POST",
            Self::Put => "PUT",
            Self::Connect => "CONNECT",
            Self::Options => "OPTIONS",
            Self::Trace => "TRACE",
            Self::Copy => "COPY",
            Self::Lock => "LOCK",
            Self::MkCol => "MKCOL",
            Self::Move => "MOVE",
            Self::Propfind => "PROPFIND",
            Self::Proppatch => "PROPPATCH",
            Self::Search => "SEARCH",
            Self::Unlock => "UNLOCK",
            Self::Bind => "BIND",
            Self::Rebind => "REBIND",
            Self::Unbind => "UNBIND",
            Self::Acl => "ACL",
            Self::Report => "REPORT",
            Self::MkActivity => "MKACTIVITY",
            Self::Checkout => "CHECKOUT",
            Self::Merge => "MERGE",
            // The SSDP/UPnP method token contains a hyphen.
            Self::MSearch => "M-SEARCH",
            Self::Notify => "NOTIFY",
            Self::Subscribe => "SUBSCRIBE",
            Self::Unsubscribe => "UNSUBSCRIBE",
            Self::Patch => "PATCH",
            Self::Purge => "PURGE",
            Self::MkCalendar => "MKCALENDAR",
            Self::Link => "LINK",
            Self::Unlink => "UNLINK",
        }
    }
}

/// Writes the canonical wire spelling of the method (e.g. `GET`, `M-SEARCH`).
impl Display for Method {
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        write!(f, "{}", self.as_str())
    }
}

#[cfg(feature = "defmt")]
impl defmt::Format for Method {
    fn format(&self, f: defmt::Formatter<'_>) {
        defmt::write!(f, "{}", self.as_str())
    }
}
#[derive(Debug)]
pub struct Headers<'b, const N: usize = 64>([httparse::Header<'b>; N]);
impl<'b, const N: usize> Headers<'b, N> {
#[inline(always)]
pub const fn new() -> Self {
Self([httparse::EMPTY_HEADER; N])
}
pub fn content_len(&self) -> Option<u64> {
self.get("Content-Length").map(|content_len_str| {
unwrap!(
content_len_str.parse::<u64>(),
"Invalid Content-Length header"
)
})
}
pub fn content_type(&self) -> Option<&str> {
self.get("Content-Type")
}
pub fn content_encoding(&self) -> Option<&str> {
self.get("Content-Encoding")
}
pub fn transfer_encoding(&self) -> Option<&str> {
self.get("Transfer-Encoding")
}
pub fn host(&self) -> Option<&str> {
self.get("Host")
}
pub fn connection(&self) -> Option<&str> {
self.get("Connection")
}
pub fn cache_control(&self) -> Option<&str> {
self.get("Cache-Control")
}
pub fn upgrade(&self) -> Option<&str> {
self.get("Upgrade")
}
pub fn iter(&self) -> impl Iterator<Item = (&str, &str)> {
self.iter_raw()
.filter_map(|(name, value)| str::from_utf8(value).ok().map(|value| (name, value)))
}
pub fn iter_raw(&self) -> impl Iterator<Item = (&str, &[u8])> {
self.0
.iter()
.filter(|header| !header.name.is_empty())
.map(|header| (header.name, header.value))
}
pub fn get(&self, name: &str) -> Option<&str> {
self.iter()
.find(|(hname, _)| name.eq_ignore_ascii_case(hname))
.map(|(_, value)| value)
}
pub fn get_raw(&self, name: &str) -> Option<&[u8]> {
self.iter_raw()
.find(|(hname, _)| name.eq_ignore_ascii_case(hname))
.map(|(_, value)| value)
}
pub fn set(&mut self, name: &'b str, value: &'b str) -> &mut Self {
self.set_raw(name, value.as_bytes())
}
pub fn set_raw(&mut self, name: &'b str, value: &'b [u8]) -> &mut Self {
if !name.is_empty() {
for header in &mut self.0 {
if header.name.is_empty() || header.name.eq_ignore_ascii_case(name) {
*header = Header { name, value };
return self;
}
}
panic!("No space left");
} else {
self.remove(name)
}
}
pub fn remove(&mut self, name: &str) -> &mut Self {
let index = self
.0
.iter()
.enumerate()
.find(|(_, header)| header.name.eq_ignore_ascii_case(name));
if let Some((mut index, _)) = index {
while index < self.0.len() - 1 {
self.0[index] = self.0[index + 1];
index += 1;
}
self.0[index] = EMPTY_HEADER;
}
self
}
pub fn set_content_len(
&mut self,
content_len: u64,
buf: &'b mut heapless::String<20>,
) -> &mut Self {
*buf = unwrap!(content_len.try_into());
self.set("Content-Length", buf.as_str())
}
pub fn set_content_type(&mut self, content_type: &'b str) -> &mut Self {
self.set("Content-Type", content_type)
}
pub fn set_content_encoding(&mut self, content_encoding: &'b str) -> &mut Self {
self.set("Content-Encoding", content_encoding)
}
pub fn set_transfer_encoding(&mut self, transfer_encoding: &'b str) -> &mut Self {
self.set("Transfer-Encoding", transfer_encoding)
}
pub fn set_transfer_encoding_chunked(&mut self) -> &mut Self {
self.set_transfer_encoding("Chunked")
}
pub fn set_host(&mut self, host: &'b str) -> &mut Self {
self.set("Host", host)
}
pub fn set_connection(&mut self, connection: &'b str) -> &mut Self {
self.set("Connection", connection)
}
pub fn set_connection_close(&mut self) -> &mut Self {
self.set_connection("Close")
}
pub fn set_connection_keep_alive(&mut self) -> &mut Self {
self.set_connection("Keep-Alive")
}
pub fn set_connection_upgrade(&mut self) -> &mut Self {
self.set_connection("Upgrade")
}
pub fn set_cache_control(&mut self, cache: &'b str) -> &mut Self {
self.set("Cache-Control", cache)
}
pub fn set_cache_control_no_cache(&mut self) -> &mut Self {
self.set_cache_control("No-Cache")
}
pub fn set_upgrade(&mut self, upgrade: &'b str) -> &mut Self {
self.set("Upgrade", upgrade)
}
pub fn set_upgrade_websocket(&mut self) -> &mut Self {
self.set_upgrade("websocket")
}
pub fn set_ws_upgrade_request_headers(
&mut self,
host: Option<&'b str>,
origin: Option<&'b str>,
version: Option<&'b str>,
nonce: &[u8; ws::NONCE_LEN],
buf: &'b mut [u8; ws::MAX_BASE64_KEY_LEN],
) -> &mut Self {
for (name, value) in ws::upgrade_request_headers(host, origin, version, nonce, buf) {
self.set(name, value);
}
self
}
pub fn set_ws_upgrade_response_headers<'a, H>(
&mut self,
request_headers: H,
version: Option<&'a str>,
buf: &'b mut [u8; ws::MAX_BASE64_KEY_RESPONSE_LEN],
) -> Result<&mut Self, ws::UpgradeError>
where
H: IntoIterator<Item = (&'a str, &'a str)>,
{
for (name, value) in ws::upgrade_response_headers(request_headers, version, buf)? {
self.set(name, value);
}
Ok(self)
}
}
impl<const N: usize> Default for Headers<'_, N> {
fn default() -> Self {
Self::new()
}
}
/// Connection mode carried by the `Connection` header.
#[derive(Copy, Clone, Eq, PartialEq, Hash, Debug)]
pub enum ConnectionType {
    /// `Connection: Keep-Alive`
    KeepAlive,
    /// `Connection: Close`
    Close,
    /// `Connection: Upgrade`
    Upgrade,
}

impl ConnectionType {
    /// Works out the effective connection type of a message.
    ///
    /// * `headers_connection_type` – type declared by the message's own headers, if any.
    /// * `carry_over_connection_type` – type carried over from the peer, if any
    ///   (presumably the request's type when resolving a response; confirm at call sites).
    /// * `http11` – whether HTTP/1.1 is in use; picks the default when nothing is known.
    ///
    /// A declared type wins. Otherwise the carried-over type is used, and the final fallback is
    /// Keep-Alive for HTTP/1.1 and Close for HTTP/1.0.
    ///
    /// # Errors
    /// [`HeadersMismatchError::ResponseConnectionTypeMismatchError`] when Keep-Alive is declared
    /// while the carried-over type is Close.
    pub fn resolve(
        headers_connection_type: Option<ConnectionType>,
        carry_over_connection_type: Option<ConnectionType>,
        http11: bool,
    ) -> Result<Self, HeadersMismatchError> {
        match (headers_connection_type, carry_over_connection_type) {
            (Some(ConnectionType::KeepAlive), Some(ConnectionType::Close)) => {
                warn!("Cannot set a Keep-Alive connection when the peer requested Close");
                Err(HeadersMismatchError::ResponseConnectionTypeMismatchError)
            }
            (Some(declared), _) => Ok(declared),
            (None, Some(inherited)) => Ok(inherited),
            (None, None) if http11 => Ok(Self::KeepAlive),
            (None, None) => Ok(Self::Close),
        }
    }

    /// Interprets a single header.
    ///
    /// Returns `Some` only for a `Connection` header whose whole value is `Close`,
    /// `Keep-Alive` or `Upgrade` (ASCII case-insensitive, both name and value).
    pub fn from_header(name: &str, value: &str) -> Option<Self> {
        if !"Connection".eq_ignore_ascii_case(name) {
            return None;
        }

        [Self::Close, Self::KeepAlive, Self::Upgrade]
            .iter()
            .copied()
            .find(|candidate| value.eq_ignore_ascii_case(candidate.header_value()))
    }

    /// Scans `headers` for a recognised `Connection` header.
    ///
    /// If several are found, a warning is logged for each repeat and the last one wins.
    pub fn from_headers<'a, H>(headers: H) -> Option<Self>
    where
        H: IntoIterator<Item = (&'a str, &'a str)>,
    {
        headers
            .into_iter()
            .filter_map(|(name, value)| Self::from_header(name, value))
            .fold(None, |previous, current| {
                if let Some(previous) = previous {
                    warn!(
                        "Multiple Connection headers found. Current {} and new {}",
                        previous, current
                    );
                }

                Some(current)
            })
    }

    /// Returns the `("Connection", value)` header pair for this type.
    pub fn raw_header(&self) -> (&str, &[u8]) {
        ("Connection", self.header_value().as_bytes())
    }

    /// Canonical `Connection` header value for this type.
    fn header_value(&self) -> &'static str {
        match self {
            Self::KeepAlive => "Keep-Alive",
            Self::Close => "Close",
            Self::Upgrade => "Upgrade",
        }
    }
}

impl Display for ConnectionType {
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        f.write_str(self.header_value())
    }
}

#[cfg(feature = "defmt")]
impl defmt::Format for ConnectionType {
    fn format(&self, f: defmt::Formatter<'_>) {
        match *self {
            Self::KeepAlive => defmt::write!(f, "Keep-Alive"),
            Self::Close => defmt::write!(f, "Close"),
            Self::Upgrade => defmt::write!(f, "Upgrade"),
        }
    }
}
/// How the length of a message body is determined.
#[derive(Copy, Clone, Eq, PartialEq, Hash, Debug)]
pub enum BodyType {
    /// `Transfer-Encoding: Chunked`.
    Chunked,
    /// `Content-Length` with the given number of bytes.
    ContentLen(u64),
    /// No framing header. `resolve` accepts this only for a response on a `Close` connection.
    Raw,
}

impl BodyType {
    /// Works out how the body of a message is framed.
    ///
    /// * `headers_body_type` – framing declared by the message's own headers, if any.
    /// * `connection_type` – the already-resolved connection type.
    /// * `request` – `true` for a request, `false` for a response.
    /// * `http11` – whether HTTP/1.1 is in use.
    /// * `chunked_if_unspecified` – with HTTP/1.1, fall back to chunked framing when nothing
    ///   is declared (instead of `ContentLen(0)` for requests or an error for Keep-Alive
    ///   responses).
    ///
    /// Undeclared responses on a `Close` connection resolve to `Raw`. Undeclared responses on
    /// an HTTP/1.1 `Upgrade` connection resolve to `ContentLen(0)`.
    ///
    /// # Errors
    /// [`HeadersMismatchError::BodyTypeError`] for:
    /// * a `Raw` request;
    /// * a `Raw` response on a non-`Close` connection;
    /// * chunked framing without HTTP/1.1;
    /// * an undeclared `Upgrade` response without HTTP/1.1;
    /// * an undeclared Keep-Alive response, unless `chunked_if_unspecified && http11`.
    pub fn resolve(
        headers_body_type: Option<BodyType>,
        connection_type: ConnectionType,
        request: bool,
        http11: bool,
        chunked_if_unspecified: bool,
    ) -> Result<Self, HeadersMismatchError> {
        if let Some(declared) = headers_body_type {
            match declared {
                BodyType::Raw if request => {
                    warn!("Raw body in a request. This is not allowed.");
                    return Err(HeadersMismatchError::BodyTypeError(
                        "Raw body in a request. This is not allowed.",
                    ));
                }
                BodyType::Raw if !matches!(connection_type, ConnectionType::Close) => {
                    warn!("Raw body response with a Keep-Alive connection. This is not allowed.");
                    return Err(HeadersMismatchError::BodyTypeError(
                        "Raw body response with a Keep-Alive connection. This is not allowed.",
                    ));
                }
                BodyType::Chunked if !http11 => {
                    warn!("Chunked body with an HTTP/1.0 connection. This is not allowed.");
                    return Err(HeadersMismatchError::BodyTypeError(
                        "Chunked body with an HTTP/1.0 connection. This is not allowed.",
                    ));
                }
                BodyType::Raw | BodyType::Chunked | BodyType::ContentLen(_) => {}
            }

            return Ok(declared);
        }

        if request {
            return if chunked_if_unspecified && http11 {
                Ok(BodyType::Chunked)
            } else {
                debug!("Unknown body type in a request. Assuming Content-Length=0.");
                Ok(BodyType::ContentLen(0))
            };
        }

        match connection_type {
            ConnectionType::Close => Ok(BodyType::Raw),
            ConnectionType::Upgrade if http11 => {
                debug!("Unknown body type in response but the Connection is Upgrade. Assuming Content-Length=0.");
                Ok(BodyType::ContentLen(0))
            }
            ConnectionType::Upgrade => {
                warn!("Connection is set to Upgrade but the HTTP protocol version is not 1.1. This is not allowed.");
                Err(HeadersMismatchError::BodyTypeError(
                    "Connection is set to Upgrade but the HTTP protocol version is not 1.1. This is not allowed.",
                ))
            }
            ConnectionType::KeepAlive if chunked_if_unspecified && http11 => Ok(BodyType::Chunked),
            ConnectionType::KeepAlive => {
                warn!("Unknown body type in a response with a Keep-Alive connection. This is not allowed.");
                Err(HeadersMismatchError::BodyTypeError(
                    "Unknown body type in a response with a Keep-Alive connection. This is not allowed.",
                ))
            }
        }
    }

    /// Interprets a single header.
    ///
    /// `Content-Length` yields `ContentLen`. `Transfer-Encoding` with the whole value `Chunked`
    /// yields `Chunked`. Anything else yields `None`. Name and value are compared ASCII
    /// case-insensitively.
    ///
    /// # Panics
    /// Panics (via `unwrap!`) if a `Content-Length` value does not parse as `u64`.
    /// NOTE(review): `value` may be peer-supplied, so a malformed header aborts here. Consider
    /// surfacing an error instead.
    pub fn from_header(name: &str, value: &str) -> Option<Self> {
        if "Content-Length".eq_ignore_ascii_case(name) {
            let len = unwrap!(value.parse::<u64>(), "Invalid Content-Length header");
            Some(Self::ContentLen(len))
        } else if "Transfer-Encoding".eq_ignore_ascii_case(name)
            && value.eq_ignore_ascii_case("Chunked")
        {
            Some(Self::Chunked)
        } else {
            None
        }
    }

    /// Scans `headers` for body-framing headers.
    ///
    /// If several are found, a warning is logged for each repeat and the last one wins.
    pub fn from_headers<'a, H>(headers: H) -> Option<Self>
    where
        H: IntoIterator<Item = (&'a str, &'a str)>,
    {
        headers
            .into_iter()
            .filter_map(|(name, value)| Self::from_header(name, value))
            .fold(None, |previous, current| {
                if let Some(previous) = previous {
                    warn!(
                        "Multiple body type headers found. Current {} and new {}",
                        previous, current
                    );
                }

                Some(current)
            })
    }

    /// Returns the header pair that announces this body type, or `None` for `Raw`.
    ///
    /// For `ContentLen`, the decimal length is written into `buf` (cleared first), which the
    /// returned value borrows.
    pub fn raw_header<'a>(&self, buf: &'a mut heapless::String<20>) -> Option<(&str, &'a [u8])> {
        match *self {
            Self::Raw => None,
            Self::Chunked => Some(("Transfer-Encoding", "Chunked".as_bytes())),
            Self::ContentLen(len) => {
                use core::fmt::Write;

                buf.clear();
                write_unwrap!(buf, "{}", len);

                Some(("Content-Length", buf.as_bytes()))
            }
        }
    }
}

impl Display for BodyType {
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        match *self {
            Self::Chunked => f.write_str("Chunked"),
            Self::ContentLen(len) => write!(f, "Content-Length: {}", len),
            Self::Raw => f.write_str("Raw"),
        }
    }
}

#[cfg(feature = "defmt")]
impl defmt::Format for BodyType {
    fn format(&self, f: defmt::Formatter<'_>) {
        match self {
            Self::Chunked => defmt::write!(f, "Chunked"),
            Self::ContentLen(len) => defmt::write!(f, "Content-Length: {}", len),
            Self::Raw => defmt::write!(f, "Raw"),
        }
    }
}
/// Request head: protocol version, method, path and headers.
#[derive(Debug)]
pub struct RequestHeaders<'b, const N: usize> {
    /// `true` for HTTP/1.1, `false` for HTTP/1.0.
    pub http11: bool,
    /// Request method.
    pub method: Method,
    /// Request path.
    pub path: &'b str,
    /// Request headers.
    pub headers: Headers<'b, N>,
}

impl<const N: usize> RequestHeaders<'_, N> {
    /// Creates an HTTP/1.1 `GET /` request with no headers.
    #[inline(always)]
    pub const fn new() -> Self {
        Self {
            path: "/",
            method: Method::Get,
            http11: true,
            headers: Headers::new(),
        }
    }

    /// Returns `true` when this request asks for a WebSocket upgrade, as decided by
    /// [`ws::is_upgrade_request`].
    pub fn is_ws_upgrade_request(&self) -> bool {
        let headers = self.headers.iter();

        is_upgrade_request(self.method, headers)
    }
}

impl<const N: usize> Default for RequestHeaders<'_, N> {
    #[inline(always)]
    fn default() -> Self {
        Self::new()
    }
}

/// Renders `<version> <method> <path>` followed by one `name: value` line per header.
impl<const N: usize> Display for RequestHeaders<'_, N> {
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        let version = if self.http11 { "HTTP/1.1" } else { "HTTP/1.0" };

        writeln!(f, "{} {} {}", version, self.method, self.path)?;

        self.headers
            .iter()
            .take_while(|(name, _)| !name.is_empty())
            .try_for_each(|(name, value)| writeln!(f, "{}: {}", name, value))
    }
}

#[cfg(feature = "defmt")]
impl<const N: usize> defmt::Format for RequestHeaders<'_, N> {
    fn format(&self, f: defmt::Formatter<'_>) {
        let version = if self.http11 { "HTTP/1.1" } else { "HTTP/1.0" };

        defmt::write!(f, "{} ", version);
        defmt::write!(f, "{} {}\n", self.method, self.path);

        for (name, value) in self.headers.iter().take_while(|(name, _)| !name.is_empty()) {
            defmt::write!(f, "{}: {}\n", name, value);
        }
    }
}
/// Response head: protocol version, status code, optional reason phrase and headers.
#[derive(Debug)]
pub struct ResponseHeaders<'b, const N: usize> {
    /// `true` for HTTP/1.1, `false` for HTTP/1.0.
    pub http11: bool,
    /// Status code.
    pub code: u16,
    /// Reason phrase, if any.
    pub reason: Option<&'b str>,
    /// Response headers.
    pub headers: Headers<'b, N>,
}

impl<const N: usize> ResponseHeaders<'_, N> {
    /// Creates an HTTP/1.1 `200` response without a reason phrase or headers.
    #[inline(always)]
    pub const fn new() -> Self {
        Self {
            code: 200,
            http11: true,
            headers: Headers::new(),
            reason: None,
        }
    }

    /// Returns `true` when this response accepts the WebSocket upgrade initiated with `nonce`,
    /// as decided by [`ws::is_upgrade_accepted`].
    ///
    /// `buf` is scratch space for computing the expected `Sec-WebSocket-Accept` value.
    pub fn is_ws_upgrade_accepted(
        &self,
        nonce: &[u8; NONCE_LEN],
        buf: &mut [u8; MAX_BASE64_KEY_RESPONSE_LEN],
    ) -> bool {
        let (code, headers) = (self.code, self.headers.iter());

        is_upgrade_accepted(code, headers, nonce, buf)
    }
}

impl<const N: usize> Default for ResponseHeaders<'_, N> {
    #[inline(always)]
    fn default() -> Self {
        Self::new()
    }
}
/// Renders the status line `<version> <code> <reason>` followed by one `name: value` line per
/// header.
impl<const N: usize> Display for ResponseHeaders<'_, N> {
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        // The version literals carry no trailing space: the separator comes from the format
        // string, so HTTP/1.1 and HTTP/1.0 status lines are laid out identically.
        let version = if self.http11 { "HTTP/1.1" } else { "HTTP/1.0" };

        writeln!(f, "{} {} {}", version, self.code, self.reason.unwrap_or(""))?;

        for (name, value) in self.headers.iter() {
            if name.is_empty() {
                break;
            }

            writeln!(f, "{name}: {value}")?;
        }

        Ok(())
    }
}

#[cfg(feature = "defmt")]
impl<const N: usize> defmt::Format for ResponseHeaders<'_, N> {
    fn format(&self, f: defmt::Formatter<'_>) {
        // No trailing space in the literal; the "{} " format string adds the separator.
        defmt::write!(f, "{} ", if self.http11 { "HTTP/1.1" } else { "HTTP/1.0" });
        defmt::write!(f, "{} {}\n", self.code, self.reason.unwrap_or(""));

        for (name, value) in self.headers.iter() {
            if name.is_empty() {
                break;
            }

            defmt::write!(f, "{}: {}\n", name, value);
        }
    }
}
/// Helpers for the WebSocket opening handshake (RFC 6455, section 4).
pub mod ws {
    use base64::Engine;

    use crate::Method;

    /// Length in bytes of the random nonce that is base64-encoded into `Sec-WebSocket-Key`.
    pub const NONCE_LEN: usize = 16;
    /// Size of the buffer receiving the base64-encoded `Sec-WebSocket-Key`
    /// (a `NONCE_LEN`-byte nonce encodes to 24 characters).
    pub const MAX_BASE64_KEY_LEN: usize = 28;
    /// Size of the buffer receiving the base64-encoded `Sec-WebSocket-Accept` value.
    pub const MAX_BASE64_KEY_RESPONSE_LEN: usize = 33;
    /// Number of `(name, value)` pairs returned by [`upgrade_request_headers`].
    pub const UPGRADE_REQUEST_HEADERS_LEN: usize = 7;
    /// Number of `(name, value)` pairs returned by [`upgrade_response_headers`].
    pub const UPGRADE_RESPONSE_HEADERS_LEN: usize = 4;

    /// Builds the headers a client sends to request a WebSocket upgrade.
    ///
    /// A `None` `host` or `origin` yields an `("", "")` placeholder so the array length stays
    /// fixed; callers should ignore pairs with an empty name. `version` defaults to `"13"`.
    /// The base64 `Sec-WebSocket-Key` derived from `nonce` is written into `buf`, which the
    /// returned array borrows.
    pub fn upgrade_request_headers<'a>(
        host: Option<&'a str>,
        origin: Option<&'a str>,
        version: Option<&'a str>,
        nonce: &[u8; NONCE_LEN],
        buf: &'a mut [u8; MAX_BASE64_KEY_LEN],
    ) -> [(&'a str, &'a str); UPGRADE_REQUEST_HEADERS_LEN] {
        let host = host.map(|host| ("Host", host)).unwrap_or(("", ""));
        let origin = origin.map(|origin| ("Origin", origin)).unwrap_or(("", ""));

        [
            host,
            origin,
            ("Content-Length", "0"),
            ("Connection", "Upgrade"),
            ("Upgrade", "websocket"),
            ("Sec-WebSocket-Version", version.unwrap_or("13")),
            ("Sec-WebSocket-Key", sec_key_encode(nonce, buf)),
        ]
    }

    /// Returns `true` if the comma-separated header `value` lists `token`.
    ///
    /// Elements are trimmed and compared ASCII case-insensitively. This is how the
    /// `Connection` and `Upgrade` headers are defined (RFC 9110, sections 7.6.1 and 7.8).
    fn has_token(value: &str, token: &str) -> bool {
        value
            .split(',')
            .any(|element| element.trim().eq_ignore_ascii_case(token))
    }

    /// Returns `true` if a request with `method` and `request_headers` asks for a WebSocket
    /// upgrade.
    ///
    /// Requires a `GET` request whose `Connection` header lists the `Upgrade` token and whose
    /// `Upgrade` header lists `websocket`. Both headers are treated as token lists, so a value
    /// such as `Connection: keep-alive, Upgrade` is recognised. A token found in any repeated
    /// header counts.
    pub fn is_upgrade_request<'a, H>(method: Method, request_headers: H) -> bool
    where
        H: IntoIterator<Item = (&'a str, &'a str)>,
    {
        if method != Method::Get {
            return false;
        }

        let mut connection = false;
        let mut upgrade = false;

        for (name, value) in request_headers {
            if name.eq_ignore_ascii_case("Connection") {
                connection |= has_token(value, "Upgrade");
            } else if name.eq_ignore_ascii_case("Upgrade") {
                upgrade |= has_token(value, "websocket");
            }
        }

        connection && upgrade
    }

    /// Reasons a WebSocket upgrade request cannot be answered.
    #[derive(Debug, Copy, Clone, Eq, PartialEq, Hash)]
    pub enum UpgradeError {
        /// The request has no `Sec-WebSocket-Version` header.
        NoVersion,
        /// The request has no `Sec-WebSocket-Key` header.
        NoSecKey,
        /// The request's `Sec-WebSocket-Version` is not the supported version.
        UnsupportedVersion,
    }

    impl core::fmt::Display for UpgradeError {
        fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
            match self {
                Self::NoVersion => write!(f, "No Sec-WebSocket-Version header"),
                Self::NoSecKey => write!(f, "No Sec-WebSocket-Key header"),
                Self::UnsupportedVersion => write!(f, "Unsupported Sec-WebSocket-Version"),
            }
        }
    }

    #[cfg(feature = "defmt")]
    impl defmt::Format for UpgradeError {
        fn format(&self, f: defmt::Formatter<'_>) {
            match self {
                Self::NoVersion => defmt::write!(f, "No Sec-WebSocket-Version header"),
                Self::NoSecKey => defmt::write!(f, "No Sec-WebSocket-Key header"),
                Self::UnsupportedVersion => defmt::write!(f, "Unsupported Sec-WebSocket-Version"),
            }
        }
    }

    impl core::error::Error for UpgradeError {}

    /// Builds the server's handshake response headers for a WebSocket upgrade request.
    ///
    /// `version` is the supported protocol version (default `"13"`). The
    /// `Sec-WebSocket-Accept` value is computed from the request's `Sec-WebSocket-Key` and
    /// written into `buf`, which the returned array borrows.
    ///
    /// # Errors
    /// * [`UpgradeError::UnsupportedVersion`] if `Sec-WebSocket-Version` is not `version`.
    /// * [`UpgradeError::NoVersion`] if the request has no `Sec-WebSocket-Version` header.
    /// * [`UpgradeError::NoSecKey`] if the request has no `Sec-WebSocket-Key` header.
    pub fn upgrade_response_headers<'a, 'b, H>(
        request_headers: H,
        version: Option<&'a str>,
        buf: &'b mut [u8; MAX_BASE64_KEY_RESPONSE_LEN],
    ) -> Result<[(&'b str, &'b str); UPGRADE_RESPONSE_HEADERS_LEN], UpgradeError>
    where
        H: IntoIterator<Item = (&'a str, &'a str)>,
    {
        let mut version_ok = false;
        let mut sec_key_resp_len = None;

        for (name, value) in request_headers {
            if name.eq_ignore_ascii_case("Sec-WebSocket-Version") {
                // The header is present but names a different version: report it as
                // unsupported rather than missing.
                if !value.eq_ignore_ascii_case(version.unwrap_or("13")) {
                    return Err(UpgradeError::UnsupportedVersion);
                }

                version_ok = true;
            } else if name.eq_ignore_ascii_case("Sec-WebSocket-Key") {
                sec_key_resp_len = Some(sec_key_response(value, buf).len());
            }
        }

        if version_ok {
            if let Some(sec_key_resp_len) = sec_key_resp_len {
                Ok([
                    ("Content-Length", "0"),
                    ("Connection", "Upgrade"),
                    ("Upgrade", "websocket"),
                    (
                        "Sec-WebSocket-Accept",
                        unwrap!(core::str::from_utf8(&buf[..sec_key_resp_len]).map_err(|_| ())),
                    ),
                ])
            } else {
                Err(UpgradeError::NoSecKey)
            }
        } else {
            Err(UpgradeError::NoVersion)
        }
    }

    /// Client-side check that a response accepts the upgrade initiated with `nonce`.
    ///
    /// Requires status `101`, a `Connection` header listing `Upgrade`, an `Upgrade` header
    /// listing `websocket`, and a `Sec-WebSocket-Accept` value equal to the one derived from
    /// `nonce`. `Connection` and `Upgrade` are treated as token lists, as in
    /// [`is_upgrade_request`]. `buf` is scratch space for the expected accept value.
    pub fn is_upgrade_accepted<'a, H>(
        code: u16,
        response_headers: H,
        nonce: &[u8; NONCE_LEN],
        buf: &'a mut [u8; MAX_BASE64_KEY_RESPONSE_LEN],
    ) -> bool
    where
        H: IntoIterator<Item = (&'a str, &'a str)>,
    {
        if code != 101 {
            return false;
        }

        let mut connection = false;
        let mut upgrade = false;
        let mut sec_key_response = false;

        for (name, value) in response_headers {
            if name.eq_ignore_ascii_case("Connection") {
                connection |= has_token(value, "Upgrade");
            } else if name.eq_ignore_ascii_case("Upgrade") {
                upgrade |= has_token(value, "websocket");
            } else if name.eq_ignore_ascii_case("Sec-WebSocket-Accept") {
                // Re-derive the key we sent, then the accept value the server must echo.
                let sec_key = sec_key_encode(nonce, buf);

                let mut sha1 = sha1_smol::Sha1::new();
                sha1.update(sec_key.as_bytes());

                let sec_key_resp = sec_key_response_finalize(&mut sha1, buf);

                sec_key_response = value.eq(sec_key_resp);
            }
        }

        connection && upgrade && sec_key_response
    }

    /// Base64-encodes `nonce` into `buf` and returns the encoded text.
    fn sec_key_encode<'a>(nonce: &[u8], buf: &'a mut [u8]) -> &'a str {
        let nonce_base64_len = unwrap!(base64::engine::general_purpose::STANDARD
            .encode_slice(nonce, buf)
            .map_err(|_| ()));

        unwrap!(core::str::from_utf8(&buf[..nonce_base64_len]).map_err(|_| ()))
    }

    /// Computes the `Sec-WebSocket-Accept` value for `sec_key`.
    ///
    /// The value is base64(SHA-1(`sec_key` + RFC 6455 GUID)), written into `buf` and
    /// returned as text.
    pub fn sec_key_response<'a>(
        sec_key: &str,
        buf: &'a mut [u8; MAX_BASE64_KEY_RESPONSE_LEN],
    ) -> &'a str {
        let mut sha1 = sha1_smol::Sha1::new();

        sec_key_response_start(sec_key, &mut sha1);
        sec_key_response_finalize(&mut sha1, buf)
    }

    /// Feeds `sec_key` into the running SHA-1.
    fn sec_key_response_start(sec_key: &str, sha1: &mut sha1_smol::Sha1) {
        debug!("Computing response for key: {}", sec_key);

        sha1.update(sec_key.as_bytes());
    }

    /// Appends the RFC 6455 GUID, then base64-encodes the SHA-1 digest into `buf`.
    fn sec_key_response_finalize<'a>(sha1: &mut sha1_smol::Sha1, buf: &'a mut [u8]) -> &'a str {
        const WS_MAGIC_GUID: &str = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11";

        sha1.update(WS_MAGIC_GUID.as_bytes());

        let len = unwrap!(base64::engine::general_purpose::STANDARD
            .encode_slice(sha1.digest().bytes(), buf)
            .map_err(|_| ()));

        let sec_key_response = unwrap!(core::str::from_utf8(&buf[..len]).map_err(|_| ()));

        debug!("Computed response: {}", sec_key_response);

        sec_key_response
    }
}
#[cfg(test)]
mod test {
    use crate::{
        ws::{sec_key_response, MAX_BASE64_KEY_RESPONSE_LEN},
        BodyType, ConnectionType,
    };

    /// The sample key from RFC 6455 section 1.3 must produce the documented accept value.
    #[test]
    fn test_resp() {
        let mut buf = [0_u8; MAX_BASE64_KEY_RESPONSE_LEN];

        assert_eq!(
            sec_key_response("dGhlIHNhbXBsZSBub25jZQ==", &mut buf),
            "s3pPLMBiTxaQ9kYGzzhZRbK+xOo="
        );
    }

    #[test]
    fn test_resolve_conn() {
        use crate::ConnectionType::{Close, KeepAlive};

        // (declared by headers, carried over, http11, expected); `None` expects an error.
        let cases = [
            (None, None, true, Some(KeepAlive)),
            (None, None, false, Some(Close)),
            (None, Some(KeepAlive), false, Some(KeepAlive)),
            (None, Some(KeepAlive), true, Some(KeepAlive)),
            (Some(Close), None, false, Some(Close)),
            (Some(KeepAlive), None, false, Some(KeepAlive)),
            (Some(Close), None, true, Some(Close)),
            (Some(KeepAlive), None, true, Some(KeepAlive)),
            (Some(Close), Some(Close), false, Some(Close)),
            (Some(KeepAlive), Some(KeepAlive), false, Some(KeepAlive)),
            (Some(Close), Some(Close), true, Some(Close)),
            (Some(KeepAlive), Some(KeepAlive), true, Some(KeepAlive)),
            (Some(Close), Some(KeepAlive), false, Some(Close)),
            (Some(KeepAlive), Some(Close), false, None),
            (Some(Close), Some(KeepAlive), true, Some(Close)),
            (Some(KeepAlive), Some(Close), true, None),
        ];

        for (declared, carried, http11, expected) in cases {
            let resolved = ConnectionType::resolve(declared, carried, http11);

            match expected {
                Some(expected) => assert_eq!(
                    unwrap!(resolved),
                    expected,
                    "case {:?}",
                    (declared, carried, http11)
                ),
                None => assert!(
                    resolved.is_err(),
                    "case {:?} should fail",
                    (declared, carried, http11)
                ),
            }
        }
    }

    #[test]
    fn test_resolve_body() {
        use crate::BodyType::{Chunked, ContentLen, Raw};
        use crate::ConnectionType::{Close, KeepAlive, Upgrade};

        // (declared, connection, request, http11, chunked_if_unspecified, expected);
        // `None` expects an error.
        let cases = [
            (None, KeepAlive, true, true, false, Some(ContentLen(0))),
            (None, Close, true, true, false, Some(ContentLen(0))),
            (None, KeepAlive, true, false, false, Some(ContentLen(0))),
            (None, Close, true, false, false, Some(ContentLen(0))),
            (None, Upgrade, false, true, false, Some(ContentLen(0))),
            (None, Upgrade, false, false, false, None),
            (Some(Chunked), Close, true, false, false, None),
            (Some(Chunked), KeepAlive, true, false, false, None),
            (Some(Chunked), Close, false, false, false, None),
            (Some(Chunked), KeepAlive, false, false, false, None),
            (Some(Raw), Close, true, true, false, None),
            (Some(Raw), KeepAlive, true, true, false, None),
            (Some(Raw), Close, true, false, false, None),
            (Some(Raw), KeepAlive, true, false, false, None),
            (Some(Raw), KeepAlive, false, true, false, None),
            (Some(Raw), KeepAlive, false, false, false, None),
            (Some(Raw), Close, false, true, false, Some(Raw)),
            (Some(Raw), Close, false, false, false, Some(Raw)),
            (None, Close, true, true, true, Some(Chunked)),
            (None, KeepAlive, true, true, true, Some(Chunked)),
            (None, Close, true, false, true, Some(ContentLen(0))),
            (None, KeepAlive, true, false, true, Some(ContentLen(0))),
            (None, KeepAlive, false, true, true, Some(Chunked)),
            (None, Close, false, true, true, Some(Raw)),
        ];

        for (declared, connection, request, http11, chunked, expected) in cases {
            let resolved = BodyType::resolve(declared, connection, request, http11, chunked);
            let case = (declared, connection, request, http11, chunked);

            match expected {
                Some(expected) => assert_eq!(unwrap!(resolved), expected, "case {:?}", case),
                None => assert!(resolved.is_err(), "case {:?} should fail", case),
            }
        }
    }
}