use core::marker::PhantomData;
use picoserve_derive::ErrorWithStatusCode;
use crate::{
self as picoserve,
extract::FromRequestParts,
futures::Either,
io::{Read, Write, WriteExt},
};
use super::StatusCode;
/// Reasons a request could not be accepted as a WebSocket upgrade.
///
/// Every variant is reported to the client as `400 Bad Request`.
#[derive(Debug, thiserror::Error, ErrorWithStatusCode)]
#[cfg_attr(feature = "defmt", derive(defmt::Format))]
#[status_code(BAD_REQUEST)]
pub enum WebSocketUpgradeRejection {
    /// The request method was not `GET`.
    #[error("Websocket upgrades must use the `GET` method")]
    #[status_code(BAD_REQUEST)]
    MethodNotGet,
    /// The `Connection` header did not request an upgrade.
    #[error("Websocket upgrades must have a Connection header of `Upgrade`")]
    InvalidConnectionHeader,
    /// The `Upgrade` header was missing or did not name `websocket`.
    #[error("Websocket upgrades must have an Upgrade of `websocket`")]
    InvalidUpgradeHeader,
    /// `Sec-WebSocket-Version` was missing or not `13`.
    #[error("Websocket version must be 13")]
    InvalidWebSocketVersionHeader,
    /// The `Sec-WebSocket-Key` header was missing.
    #[error("Websocket upgrades must have a `Sec-WebSocket-Key` header")]
    WebSocketKeyHeaderMissing,
}
/// Decides what, if anything, is sent in the `Sec-WebSocket-Protocol` response header.
pub trait WebSocketProtocol {
    /// The selected subprotocol, or `None` to leave the header out of the response.
    fn name(&self) -> Option<&str>;
}

/// No subprotocol was selected, so no `Sec-WebSocket-Protocol` header is sent.
pub struct UnspecifiedProtocol;

impl WebSocketProtocol for UnspecifiedProtocol {
    fn name(&self) -> Option<&str> {
        Option::None
    }
}

/// A subprotocol chosen by the application, echoed back in `Sec-WebSocket-Protocol`.
pub struct SpecifiedProtocol<P: AsRef<str>>(P);

impl<P: AsRef<str>> WebSocketProtocol for SpecifiedProtocol<P> {
    fn name(&self) -> Option<&str> {
        let chosen: &str = AsRef::<str>::as_ref(&self.0);
        Some(chosen)
    }
}
/// Base64 text of the `Sec-WebSocket-Accept` response value (a 20-byte SHA-1 digest
/// encodes to exactly 28 ASCII bytes).
type WebSocketKey = [u8; 28];
/// Extractor for a validated WebSocket upgrade request.
///
/// Built by the `FromRequest` implementation in this module. Turn it into a response with
/// [`WebSocketUpgrade::on_upgrade`] or [`WebSocketUpgrade::on_upgrade_using_state`].
pub struct WebSocketUpgrade {
    // Despite the name, this is the precomputed `Sec-WebSocket-Accept` value, not the raw
    // client key.
    key: WebSocketKey,
    // Raw `Sec-WebSocket-Protocol` request header. `None` if the header was absent, could
    // not be read as a string, or exceeded the 32-byte capacity.
    protocols: Option<heapless::String<32>>,
    upgrade_token: crate::extract::UpgradeToken,
}
impl WebSocketUpgrade {
    /// Lists the subprotocols offered by the client in `Sec-WebSocket-Protocol`.
    ///
    /// Entries keep the client's order and have surrounding whitespace trimmed. Returns
    /// `None` when no usable header was captured during extraction.
    pub fn protocols(&self) -> Option<impl Iterator<Item = &str>> {
        let header: &str = self.protocols.as_deref()?;
        Some(header.split(',').map(str::trim))
    }
}
impl<'r, State> crate::extract::FromRequest<'r, State> for WebSocketUpgrade {
    type Rejection = WebSocketUpgradeRejection;

    /// Validates the client's WebSocket opening handshake (RFC 6455 §4.2.1) and precomputes
    /// the `Sec-WebSocket-Accept` value for the response.
    ///
    /// # Errors
    ///
    /// * [`WebSocketUpgradeRejection::MethodNotGet`]: the method is not `GET`.
    /// * [`WebSocketUpgradeRejection::InvalidConnectionHeader`]: the `Connection` header is
    ///   missing or has no `upgrade` token.
    /// * [`WebSocketUpgradeRejection::InvalidUpgradeHeader`]: no upgrade token could be
    ///   obtained, or the `Upgrade` header is missing or not `websocket`.
    /// * [`WebSocketUpgradeRejection::InvalidWebSocketVersionHeader`]:
    ///   `Sec-WebSocket-Version` is missing or not `13`.
    /// * [`WebSocketUpgradeRejection::WebSocketKeyHeaderMissing`]: no `Sec-WebSocket-Key`.
    async fn from_request<R: Read>(
        state: &'r State,
        request_parts: crate::request::RequestParts<'r>,
        _request_body: crate::request::RequestBody<'r, R>,
    ) -> Result<Self, Self::Rejection> {
        if !request_parts.method().eq_ignore_ascii_case("get") {
            return Err(WebSocketUpgradeRejection::MethodNotGet);
        }

        // RFC 6455 §4.2.1: `Connection` must contain the `Upgrade` token, compared
        // case-insensitively. It may appear among other tokens, e.g. `keep-alive, Upgrade`.
        let connection_requests_upgrade = request_parts
            .headers()
            .get("connection")
            .is_some_and(|connection| {
                connection
                    .value
                    .split(|&byte| byte == b',')
                    .any(|token| token.trim_ascii().eq_ignore_ascii_case(b"upgrade"))
            });
        if !connection_requests_upgrade {
            return Err(WebSocketUpgradeRejection::InvalidConnectionHeader);
        }

        let upgrade_token = crate::extract::UpgradeToken::from_request_parts(state, &request_parts)
            .await
            .map_err(|crate::extract::NoUpgradeHeaderError| {
                WebSocketUpgradeRejection::InvalidUpgradeHeader
            })?;

        // RFC 6455 §4.2.1: the `Upgrade` value "websocket" is ASCII case-insensitive.
        if request_parts
            .headers()
            .get("upgrade")
            .is_none_or(|upgrade| !upgrade.value.eq_ignore_ascii_case(b"websocket"))
        {
            return Err(WebSocketUpgradeRejection::InvalidUpgradeHeader);
        }

        if !request_parts
            .headers()
            .get("sec-websocket-version")
            .is_some_and(|version| version == "13")
        {
            return Err(WebSocketUpgradeRejection::InvalidWebSocketVersionHeader);
        }

        // Sec-WebSocket-Accept = base64(SHA-1(key ++ fixed GUID)). The 20-byte digest
        // encodes to exactly 28 base64 characters, which fills `WebSocketKey`.
        let key = request_parts
            .headers()
            .get("sec-websocket-key")
            .map(|key| {
                let hash = lhash::Sha1::new()
                    .const_update(key.value)
                    .const_update(b"258EAFA5-E914-47DA-95CA-C5AB0DC85B11")
                    .const_result();
                let mut buffer = [0; 28];
                data_encoding::BASE64.encode_mut(&hash, &mut buffer);
                buffer
            })
            .ok_or(WebSocketUpgradeRejection::WebSocketKeyHeaderMissing)?;

        // The header is kept as raw text for `protocols()`. If it cannot be read as a
        // string or is longer than the 32-byte capacity, it is treated as absent.
        let protocols = request_parts
            .headers()
            .get("sec-websocket-protocol")
            .and_then(|protocol| {
                let mut buffer = heapless::String::new();
                buffer.push_str(protocol.as_str().ok()?).ok()?;
                Some(buffer)
            });

        Ok(Self {
            key,
            protocols,
            upgrade_token,
        })
    }
}
/// Decoded frame opcode, split into data and control opcodes (RFC 6455 §5.2).
#[derive(Debug)]
#[cfg_attr(feature = "defmt", derive(defmt::Format))]
pub enum Opcode {
    Data(Data),
    Control(Control),
}

/// Data-frame opcodes (raw values 0 through 7).
#[derive(Debug)]
#[cfg_attr(feature = "defmt", derive(defmt::Format))]
pub enum Data {
    Continue,
    Text,
    Binary,
    Reserved(u8),
}

/// Control-frame opcodes (raw values 8 and above).
#[derive(Debug)]
#[cfg_attr(feature = "defmt", derive(defmt::Format))]
pub enum Control {
    Close,
    Ping,
    Pong,
    Reserved(u8),
}

impl From<u8> for Opcode {
    /// Classifies a raw opcode value.
    ///
    /// Values below 8 are data opcodes and all others are control opcodes. Anything without
    /// a defined meaning is kept as `Reserved` with the raw value.
    fn from(value: u8) -> Self {
        if value < 8 {
            Self::Data(match value {
                0 => Data::Continue,
                1 => Data::Text,
                2 => Data::Binary,
                reserved => Data::Reserved(reserved),
            })
        } else {
            Self::Control(match value {
                8 => Control::Close,
                9 => Control::Ping,
                10 => Control::Pong,
                reserved => Control::Reserved(reserved),
            })
        }
    }
}
/// Header information for a frame read by [`SocketRx::next_frame`].
///
/// The payload itself is left in the caller's buffer.
#[derive(Debug)]
#[cfg_attr(feature = "defmt", derive(defmt::Format))]
pub struct Frame {
    /// FIN bit: this is the last frame of its message.
    pub is_final: bool,
    /// Decoded opcode (low nibble of the first header byte).
    pub opcode: Opcode,
    /// Payload length in bytes. The payload occupies `buffer[..length]`, already unmasked.
    pub length: usize,
}
/// Errors produced while reading a single frame.
#[derive(Debug)]
#[cfg_attr(feature = "defmt", derive(defmt::Format))]
pub enum ReadFrameError {
    /// The stream ended part-way through a frame.
    UnexpectedEof,
    /// A 64-bit payload length does not fit in `usize`.
    MessageIsTooLong(u64),
    /// The payload is larger than the space left in the caller's buffer.
    OutOfSpace,
}

/// Errors produced while reading a complete message.
#[derive(Debug)]
#[cfg_attr(feature = "defmt", derive(defmt::Format))]
pub enum ReadMessageError {
    /// A frame could not be read.
    ReadFrameError(ReadFrameError),
    /// A frame used a reserved opcode. The raw value is kept.
    ReservedOpcode(u8),
    /// The first frame of a message was a continuation frame.
    MessageStartsWithContinuation,
    /// A non-continuation frame (data or control) arrived inside a fragmented message.
    UnexpectedMessageStart,
    /// A Text message or a Close reason was not valid UTF-8.
    TextIsNotUtf8,
}

impl ReadMessageError {
    /// Returns the close status code (RFC 6455 §7.4.1) to send when failing the connection
    /// because of this error:
    ///
    /// * `1009` (Message Too Big): a frame's payload could not be stored.
    /// * `1002` (Protocol Error): truncated frames and misplaced frames.
    /// * `1003`: reserved opcodes.
    /// * `1007` (Invalid Frame Payload Data): text that is not UTF-8.
    pub fn code(&self) -> u16 {
        match self {
            // Oversized payloads are a size limit on our side, not a malformed frame.
            Self::ReadFrameError(
                ReadFrameError::MessageIsTooLong(_) | ReadFrameError::OutOfSpace,
            ) => 1009,
            Self::ReadFrameError(ReadFrameError::UnexpectedEof)
            | Self::MessageStartsWithContinuation
            | Self::UnexpectedMessageStart => 1002,
            Self::ReservedOpcode(_) => 1003,
            Self::TextIsNotUtf8 => 1007,
        }
    }
}
enum InternalError<IoError, Error> {
Io(IoError),
Other(Error),
}
impl<IoError, Error> From<Error> for InternalError<IoError, Error> {
fn from(error: Error) -> Self {
Self::Other(error)
}
}
impl<IoError> From<crate::io::ReadExactError<IoError>> for InternalError<IoError, ReadFrameError> {
fn from(value: crate::io::ReadExactError<IoError>) -> Self {
match value {
crate::io::ReadExactError::UnexpectedEof => Self::Other(ReadFrameError::UnexpectedEof),
crate::io::ReadExactError::Other(error) => Self::Io(error),
}
}
}
impl<IOError> From<InternalError<IOError, ReadFrameError>>
for InternalError<IOError, ReadMessageError>
{
fn from(error: InternalError<IOError, ReadFrameError>) -> Self {
match error {
InternalError::Io(error) => InternalError::Io(error),
InternalError::Other(error) => {
InternalError::Other(ReadMessageError::ReadFrameError(error))
}
}
}
}
impl<IoError> From<core::str::Utf8Error> for InternalError<IoError, ReadMessageError> {
fn from(_: core::str::Utf8Error) -> Self {
ReadMessageError::TextIsNotUtf8.into()
}
}
/// Flattens an internal result into the public shape.
///
/// I/O errors become the outer `Err`. Protocol errors and successes sit in
/// `Ok(Either::First(_))`. A signal that fired becomes `Ok(Either::Second(_))`.
trait InternalResultExt<T, S, IoError, Error> {
    fn into_nested_result(self) -> Result<Either<Result<T, Error>, S>, IoError>;
}

impl<T, S, IoError, Error> InternalResultExt<T, S, IoError, Error>
    for Result<Either<T, S>, InternalError<IoError, Error>>
{
    fn into_nested_result(self) -> Result<Either<Result<T, Error>, S>, IoError> {
        match self {
            Ok(outcome) => Ok(match outcome {
                Either::First(value) => Either::First(Ok(value)),
                Either::Second(signal) => Either::Second(signal),
            }),
            Err(InternalError::Other(error)) => Ok(Either::First(Err(error))),
            Err(InternalError::Io(io_error)) => Err(io_error),
        }
    }
}
/// Opcode of a whole message, taken from its first frame. Continuation and reserved
/// opcodes are rejected before this is produced.
enum MessageOpcode {
    Text,
    Binary,
    Close,
    Ping,
    Pong,
}
/// A complete message read by [`SocketRx::next_message`], borrowing the caller's buffer.
#[derive(Debug)]
#[cfg_attr(feature = "defmt", derive(defmt::Format))]
pub enum Message<'a> {
    /// Text message, already validated as UTF-8.
    Text(&'a str),
    /// Binary message payload.
    Binary(&'a [u8]),
    /// Close frame. `None` for an empty body, otherwise the status code and reason text.
    /// A one-byte body is reported as that byte's value with an empty reason.
    Close(Option<(u16, &'a str)>),
    /// Ping payload.
    Ping(&'a [u8]),
    /// Pong payload.
    Pong(&'a [u8]),
}
/// Reads exactly one byte from `reader`.
async fn next_byte<R: Read>(reader: &mut R) -> Result<u8, crate::io::ReadExactError<R::Error>> {
    let mut single = [0u8; 1];
    reader.read_exact(&mut single).await?;
    let [byte] = single;
    Ok(byte)
}
/// Receiving half of an upgraded WebSocket connection.
///
/// Frames and messages are read into caller-supplied buffers. `SocketRx` itself only holds
/// the reader.
pub struct SocketRx<R: Read> {
    reader: R,
}
impl<R: Read> SocketRx<R> {
    /// Reads one frame into the front of `buffer`, unless `other_signal` resolves first.
    ///
    /// Only the wait for the first header byte is raced against `other_signal`. Once that
    /// byte has arrived, the rest of the frame is read without watching the signal. The
    /// payload is written to `buffer[..length]` and unmasked in place if the mask bit is set.
    ///
    /// Returns `Either::Second` with the signal's output if the signal won, otherwise
    /// `Either::First` with the frame header. Frame errors:
    /// * `UnexpectedEof`: the stream ended mid-frame.
    /// * `MessageIsTooLong`: a 64-bit length does not fit in `usize`.
    /// * `OutOfSpace`: the payload is longer than `buffer`.
    async fn next_frame_internal<Signal: core::future::Future>(
        &mut self,
        buffer: &mut [u8],
        other_signal: Signal,
    ) -> Result<Either<Frame, Signal::Output>, InternalError<R::Error, ReadFrameError>> {
        let first = match crate::futures::select_either(
            core::pin::pin!(other_signal),
            next_byte(&mut self.reader),
        )
        .await
        {
            Either::First(signal) => return Ok(Either::Second(signal)),
            Either::Second(b) => b?,
        };
        let second = next_byte(&mut self.reader).await?;
        // Byte 0: FIN in bit 7, opcode in the low nibble. RSV1-3 are not examined.
        let is_final = first & 0x80 != 0;
        let opcode = Opcode::from(first & 0x0F);
        // Byte 1: MASK in bit 7, then a 7-bit length (or an extended-length marker).
        let is_masked = second & 0x80 != 0;
        let length_byte = second & 0x7F;
        let length = match length_byte {
            // 126: the real length follows as a 16-bit big-endian integer.
            126 => {
                let mut length_bytes = [0; 2];
                self.reader.read_exact(&mut length_bytes).await?;
                u16::from_be_bytes(length_bytes).into()
            }
            // 127: the real length follows as a 64-bit big-endian integer.
            127 => {
                let mut length_bytes = [0; 8];
                self.reader.read_exact(&mut length_bytes).await?;
                let length = u64::from_be_bytes(length_bytes);
                length
                    .try_into()
                    .map_err(|_| ReadFrameError::MessageIsTooLong(length))?
            }
            length => length.into(),
        };
        // NOTE(review): unmasked frames are accepted. RFC 6455 §5.1 says a server must fail
        // the connection when a client frame is not masked.
        let mut mask = [0; 4];
        if is_masked {
            self.reader.read_exact(&mut mask).await?;
        }
        let data = buffer.get_mut(..length).ok_or(ReadFrameError::OutOfSpace)?;
        self.reader.read_exact(data).await?;
        if is_masked {
            // Each payload byte is XORed with the 4-byte mask, repeating the mask cyclically.
            for (data, mask) in data.iter_mut().zip(mask.iter().cycle()) {
                *data ^= mask;
            }
        }
        Ok(Either::First(Frame {
            is_final,
            opcode,
            length,
        }))
    }
    /// Reads one frame, or returns the output of `signal` if it resolves before the frame
    /// starts.
    ///
    /// I/O errors are the outer `Err`. Frame errors are returned as
    /// `Ok(Either::First(Err(_)))`.
    pub async fn next_frame<Signal: core::future::Future>(
        &mut self,
        buffer: &mut [u8],
        signal: Signal,
    ) -> Result<Either<Result<Frame, ReadFrameError>, Signal::Output>, R::Error> {
        self.next_frame_internal(buffer, signal)
            .await
            .into_nested_result()
    }
    /// Reads one complete message into `buffer`, reassembling fragmented messages.
    ///
    /// Only the first frame is raced against `signal`. Continuation frames are read with
    /// `core::future::pending()` as the signal, so they are never interrupted, and each
    /// payload is appended right after the previous one.
    ///
    /// NOTE(review): a control frame arriving between fragments is rejected as
    /// `UnexpectedMessageStart`, although RFC 6455 §5.4 allows control frames to be
    /// interleaved with a fragmented message.
    async fn next_message_internal<'a, Signal: core::future::Future>(
        &mut self,
        buffer: &'a mut [u8],
        signal: Signal,
    ) -> Result<Either<Message<'a>, Signal::Output>, InternalError<R::Error, ReadMessageError>>
    {
        let Frame {
            is_final: is_single_frame,
            opcode,
            length: mut message_length,
        } = match self.next_frame_internal(buffer, signal).await? {
            Either::First(frame) => frame,
            Either::Second(signal) => return Ok(Either::Second(signal)),
        };
        let opcode = match opcode {
            Opcode::Data(Data::Continue) => {
                return Err(ReadMessageError::MessageStartsWithContinuation.into())
            }
            Opcode::Data(Data::Text) => MessageOpcode::Text,
            Opcode::Data(Data::Binary) => MessageOpcode::Binary,
            Opcode::Control(Control::Close) => MessageOpcode::Close,
            Opcode::Control(Control::Ping) => MessageOpcode::Ping,
            Opcode::Control(Control::Pong) => MessageOpcode::Pong,
            Opcode::Data(Data::Reserved(opcode)) | Opcode::Control(Control::Reserved(opcode)) => {
                return Err(ReadMessageError::ReservedOpcode(opcode).into())
            }
        };
        if !is_single_frame {
            loop {
                // `pending()` never resolves, so `ignore_never_b` drops the impossible
                // signal branch.
                let Frame {
                    is_final,
                    opcode,
                    length,
                } = self
                    .next_frame_internal(&mut buffer[message_length..], core::future::pending())
                    .await?
                    .ignore_never_b();
                match opcode {
                    Opcode::Data(Data::Continue) => (),
                    Opcode::Data(Data::Text | Data::Binary)
                    | Opcode::Control(Control::Close | Control::Ping | Control::Pong) => {
                        return Err(ReadMessageError::UnexpectedMessageStart.into())
                    }
                    Opcode::Data(Data::Reserved(opcode))
                    | Opcode::Control(Control::Reserved(opcode)) => {
                        return Err(ReadMessageError::ReservedOpcode(opcode).into())
                    }
                }
                message_length += length;
                if is_final {
                    break;
                }
            }
        }
        let data = &buffer[..message_length];
        Ok(Either::First(match opcode {
            MessageOpcode::Text => Message::Text(core::str::from_utf8(data)?),
            MessageOpcode::Binary => Message::Binary(data),
            // Close body: empty -> no reason. A single byte is taken as the code (RFC 6455
            // §5.5.1 expects at least 2 bytes). Otherwise a big-endian code plus UTF-8 reason.
            MessageOpcode::Close => Message::Close(match data {
                [] => None,
                &[code] => Some((code.into(), "")),
                [c1, c0, text @ ..] => {
                    Some((u16::from_be_bytes([*c1, *c0]), core::str::from_utf8(text)?))
                }
            }),
            MessageOpcode::Ping => Message::Ping(data),
            MessageOpcode::Pong => Message::Pong(data),
        }))
    }
    /// Reads one complete message, or returns the output of `signal` if it resolves before
    /// the message starts.
    ///
    /// I/O errors are the outer `Err`. Protocol errors are returned as
    /// `Ok(Either::First(Err(_)))`.
    pub async fn next_message<'a, Signal: core::future::Future>(
        &mut self,
        buffer: &'a mut [u8],
        signal: Signal,
    ) -> Result<Either<Result<Message<'a>, ReadMessageError>, Signal::Output>, R::Error> {
        self.next_message_internal(buffer, signal)
            .await
            .into_nested_result()
    }
}
/// Transmitting half of an upgraded WebSocket connection.
///
/// Frames are written unmasked (RFC 6455 §5.1 forbids servers from masking). Every public
/// send method flushes the underlying writer before returning.
pub struct SocketTx<W: Write> {
    writer: W,
}

impl<W: Write> SocketTx<W> {
    /// Largest payload a control frame (Close, Ping, Pong) may carry (RFC 6455 §5.5).
    const MAX_CONTROL_FRAME_PAYLOAD: usize = 125;

    /// Flushes the underlying writer.
    async fn flush(&mut self) -> Result<(), W::Error> {
        self.writer.flush().await
    }

    /// Writes the payload-length part of a frame header.
    ///
    /// Lengths up to 125 take a single byte. Larger lengths use the marker `126` followed by
    /// a 16-bit big-endian length, or `127` followed by a 64-bit one. The mask bit is always
    /// clear.
    async fn write_length(&mut self, length: usize) -> Result<(), W::Error> {
        if let Some(length_byte) = u8::try_from(length).ok().filter(|length| *length <= 125) {
            self.writer.write_all(&[length_byte]).await
        } else if let Ok(length) = u16::try_from(length) {
            self.writer.write_all(&[126]).await?;
            self.writer.write_all(&length.to_be_bytes()).await
        } else {
            self.writer.write_all(&[127]).await?;
            // `usize` is at most 64 bits wide on every Rust target, so this is lossless.
            self.writer.write_all(&(length as u64).to_be_bytes()).await
        }
    }

    /// Writes one frame without flushing: FIN bit plus `opcode`, the length, then `data`.
    async fn write_frame(
        &mut self,
        is_final: bool,
        opcode: u8,
        data: &[u8],
    ) -> Result<(), W::Error> {
        self.writer
            .write_all(&[if is_final { 0b10000000 } else { 0 } | opcode])
            .await?;
        self.write_length(data.len()).await?;
        self.writer.write_all(data).await
    }

    /// Returns the longest prefix of `message` that is at most `max_len` bytes long and
    /// ends on a UTF-8 character boundary, so the result is still valid UTF-8.
    fn truncate_to_char_boundary(message: &str, max_len: usize) -> &str {
        if message.len() <= max_len {
            return message;
        }
        let mut end = max_len;
        // Terminates because offset 0 is always a character boundary.
        while !message.is_char_boundary(end) {
            end -= 1;
        }
        &message[..end]
    }

    /// Sends `data` as a single final Text frame, then flushes.
    pub async fn send_text(&mut self, data: &str) -> Result<(), W::Error> {
        self.write_frame(true, 1, data.as_bytes()).await?;
        self.flush().await
    }

    /// Sends `data` as a single final Binary frame, then flushes.
    pub async fn send_binary(&mut self, data: &[u8]) -> Result<(), W::Error> {
        self.write_frame(true, 2, data).await?;
        self.flush().await
    }

    /// Sends the `Display` output of `data` as a Text message, then flushes.
    ///
    /// The text is streamed. Each write from the formatter becomes a non-final frame: the
    /// first uses the Text opcode, later ones are continuations. An empty final frame ends
    /// the message. If the formatter writes nothing, one empty final Text frame is sent.
    pub async fn send_display(&mut self, data: impl core::fmt::Display) -> Result<(), W::Error> {
        let opcode = &mut 1;
        write!(FrameWriter { opcode, tx: self }, "{data}").await?;
        self.write_frame(true, *opcode, &[]).await?;
        self.flush().await
    }

    /// Serializes `value` as JSON and sends it as a Text message, streamed the same way as
    /// [`Self::send_display`], then flushes.
    #[cfg(feature = "json")]
    pub async fn send_json(&mut self, value: impl serde::Serialize) -> Result<(), W::Error> {
        let opcode = &mut 1;
        super::json::Json(value)
            .do_write_to(&mut FrameWriter { opcode, tx: self })
            .await?;
        self.write_frame(true, *opcode, &[]).await?;
        self.flush().await
    }

    /// Sends a Close frame and flushes, consuming the transmitter.
    ///
    /// `reason` is `None` for an empty Close body, or a `(status code, reason text)` pair.
    /// A control frame may carry at most 125 bytes, so reason text longer than 123 bytes is
    /// truncated at a UTF-8 character boundary. Without this, the frame would be oversized
    /// and the peer would reject it as a protocol error.
    pub async fn close(mut self, reason: impl Into<Option<(u16, &str)>>) -> Result<(), W::Error> {
        self.writer.write_all(&[0b10000000 | 8]).await?;
        match reason.into() {
            Some((code, message)) => {
                let code_bytes = code.to_be_bytes();
                let message = Self::truncate_to_char_boundary(
                    message,
                    Self::MAX_CONTROL_FRAME_PAYLOAD - code_bytes.len(),
                );
                self.write_length(code_bytes.len() + message.len()).await?;
                self.writer.write_all(&code_bytes).await?;
                self.writer.write_all(message.as_bytes()).await
            }
            None => self.write_length(0).await,
        }?;
        self.flush().await
    }

    /// Sends a Ping frame carrying `data`, then flushes.
    ///
    /// NOTE(review): `data` is sent unchanged. RFC 6455 §5.5 limits control-frame payloads
    /// to 125 bytes, so callers must stay within that.
    pub async fn send_ping(&mut self, data: &[u8]) -> Result<(), W::Error> {
        self.write_frame(true, 9, data).await?;
        self.flush().await
    }

    /// Sends a Pong frame carrying `data`, then flushes.
    ///
    /// NOTE(review): `data` is sent unchanged. RFC 6455 §5.5 limits control-frame payloads
    /// to 125 bytes, so callers must stay within that.
    pub async fn send_pong(&mut self, data: &[u8]) -> Result<(), W::Error> {
        self.write_frame(true, 10, data).await?;
        self.flush().await
    }
}
/// Adapter that turns each `write` into one non-final WebSocket frame.
///
/// `opcode` is used for the first frame only and is then reset to `0` (continuation). The
/// caller can read it afterwards to choose the opcode of the closing frame.
struct FrameWriter<'w, W: Write> {
    opcode: &'w mut u8,
    tx: &'w mut SocketTx<W>,
}

impl<W: Write> crate::io::ErrorType for FrameWriter<'_, W> {
    type Error = W::Error;
}

impl<W: Write> Write for FrameWriter<'_, W> {
    async fn write(&mut self, data: &[u8]) -> Result<usize, W::Error> {
        // The first fragment keeps the message opcode; every later one is a continuation.
        let frame_opcode = core::mem::take(self.opcode);
        self.tx.write_frame(false, frame_opcode, data).await?;
        Ok(data.len())
    }

    async fn flush(&mut self) -> Result<(), Self::Error> {
        self.tx.flush().await
    }
}
/// Handler for an upgraded WebSocket connection that needs neither state nor the server's
/// shutdown signal.
pub trait WebSocketCallback {
    /// Drives the connection using its receiving and transmitting halves.
    async fn run<R: Read, W: Write<Error = R::Error>>(
        self,
        rx: SocketRx<R>,
        tx: SocketTx<W>,
    ) -> Result<(), W::Error>;
}

/// Handler for an upgraded WebSocket connection that is also handed the server's
/// shutdown signal.
pub trait WebSocketCallbackWithShutdownSignal {
    /// Drives the connection. `shutdown_signal` is a clone of the connection's shutdown
    /// signal.
    async fn run_with_shutdown_signal<
        R: Read,
        W: Write<Error = R::Error>,
        S: core::future::Future<Output = ()> + Clone + Unpin,
    >(
        self,
        rx: SocketRx<R>,
        tx: SocketTx<W>,
        shutdown_signal: S,
    ) -> Result<(), W::Error>;
}

// Every plain callback is also a shutdown-aware one that ignores the signal.
impl<C: WebSocketCallback> WebSocketCallbackWithShutdownSignal for C {
    async fn run_with_shutdown_signal<
        R: Read,
        W: Write<Error = R::Error>,
        S: core::future::Future<Output = ()> + Clone + Unpin,
    >(
        self,
        rx: SocketRx<R>,
        tx: SocketTx<W>,
        _shutdown_signal: S,
    ) -> Result<(), W::Error> {
        WebSocketCallback::run(self, rx, tx).await
    }
}
pub trait WebSocketCallbackWithState<State> {
async fn run_with_state<R: Read, W: Write<Error = R::Error>>(
self,
state: &State,
rx: SocketRx<R>,
tx: SocketTx<W>,
) -> Result<(), W::Error>;
}
impl<State, C: WebSocketCallback> WebSocketCallbackWithState<State> for C {
async fn run_with_state<R: Read, W: Write<Error = R::Error>>(
self,
_state: &State,
rx: SocketRx<R>,
tx: SocketTx<W>,
) -> Result<(), W::Error> {
self.run(rx, tx).await
}
}
pub trait WebSocketCallbackWithStateAndShutdownSignal<State> {
async fn run_with_state_and_shutdown_signal<
R: Read,
W: Write<Error = R::Error>,
S: core::future::Future<Output = ()> + Clone + Unpin,
>(
self,
state: &State,
rx: SocketRx<R>,
tx: SocketTx<W>,
shutdown_signal: S,
) -> Result<(), W::Error>;
}
impl<State, C: WebSocketCallbackWithState<State>> WebSocketCallbackWithStateAndShutdownSignal<State>
for C
{
async fn run_with_state_and_shutdown_signal<
R: Read,
W: Write<Error = R::Error>,
S: core::future::Future<Output = ()> + Clone + Unpin,
>(
self,
state: &State,
rx: SocketRx<R>,
tx: SocketTx<W>,
_shutdown_signal: S,
) -> Result<(), W::Error> {
self.run_with_state(state, rx, tx).await
}
}
/// Response that completes a WebSocket handshake and then hands the connection to `callback`.
pub struct UpgradedWebSocket<P: WebSocketProtocol, C> {
    sec_websocket_accept: WebSocketKey,
    sec_websocket_protocol: P,
    upgrade_token: crate::extract::UpgradeToken,
    callback: C,
}

impl<C> UpgradedWebSocket<UnspecifiedProtocol, C> {
    /// Selects `protocol` as the subprotocol echoed in `Sec-WebSocket-Protocol`.
    ///
    /// Only available while no protocol has been chosen yet.
    pub fn with_protocol<P: AsRef<str>>(
        self,
        protocol: P,
    ) -> UpgradedWebSocket<SpecifiedProtocol<P>, C> {
        let chosen = SpecifiedProtocol(protocol);
        UpgradedWebSocket {
            sec_websocket_accept: self.sec_websocket_accept,
            sec_websocket_protocol: chosen,
            upgrade_token: self.upgrade_token,
            callback: self.callback,
        }
    }
}
/// Wraps a callback that does not need application state (see
/// [`WebSocketUpgrade::on_upgrade`]).
pub struct CallbackNotUsingState<C: WebSocketCallbackWithShutdownSignal> {
    callback: C,
}
/// Wraps a callback that borrows application state (see
/// [`WebSocketUpgrade::on_upgrade_using_state`]).
pub struct CallbackUsingState<State, C: WebSocketCallbackWithStateAndShutdownSignal<State>> {
    callback: C,
    // Records the `State` type the callback expects without storing a value of it.
    state: PhantomData<fn(&State)>,
}
impl WebSocketUpgrade {
    /// Completes the upgrade with a callback that does not need application state.
    ///
    /// No subprotocol is selected; use `with_protocol` on the result to pick one.
    pub fn on_upgrade<C: WebSocketCallbackWithShutdownSignal>(
        self,
        callback: C,
    ) -> UpgradedWebSocket<UnspecifiedProtocol, CallbackNotUsingState<C>> {
        let Self {
            key, upgrade_token, ..
        } = self;
        let upgraded = UpgradedWebSocket {
            sec_websocket_accept: key,
            sec_websocket_protocol: UnspecifiedProtocol,
            upgrade_token,
            callback: CallbackNotUsingState { callback },
        };
        super::assert_implements_into_response(upgraded)
    }

    /// Completes the upgrade with a callback that borrows application state of type `State`.
    ///
    /// No subprotocol is selected; use `with_protocol` on the result to pick one.
    pub fn on_upgrade_using_state<State, C: WebSocketCallbackWithStateAndShutdownSignal<State>>(
        self,
        callback: C,
    ) -> UpgradedWebSocket<UnspecifiedProtocol, CallbackUsingState<State, C>> {
        let Self {
            key, upgrade_token, ..
        } = self;
        let upgraded = UpgradedWebSocket {
            sec_websocket_accept: key,
            sec_websocket_protocol: UnspecifiedProtocol,
            upgrade_token,
            callback: CallbackUsingState {
                callback,
                state: PhantomData,
            },
        };
        super::assert_implements_into_response_with_state::<State, _>(upgraded)
    }
}
/// Builds the `101 Switching Protocols` response that completes the handshake.
///
/// A `Sec-WebSocket-Protocol` header is added when `sec_websocket_protocol` is `Some`.
fn websocket_response<'a, B: super::Body + 'a>(
    sec_websocket_accept: &'a WebSocketKey,
    sec_websocket_protocol: Option<&'a str>,
    body: B,
) -> super::Response<impl super::HeadersIter + 'a, B> {
    // SAFETY: every `WebSocketKey` in this module is filled by
    // `data_encoding::BASE64.encode_mut`, whose output is ASCII and therefore valid UTF-8.
    #[allow(unsafe_code)]
    let accept = unsafe { core::str::from_utf8_unchecked(sec_websocket_accept) };

    let handshake_headers = [
        ("Upgrade", "websocket"),
        ("Connection", "upgrade"),
        ("Sec-WebSocket-Accept", accept),
    ];
    let protocol_header =
        sec_websocket_protocol.map(|protocol| ("Sec-WebSocket-Protocol", protocol));

    super::Response {
        status_code: StatusCode::SWITCHING_PROTOCOLS,
        headers: handshake_headers,
        body,
    }
    .with_headers(protocol_header)
}
/// Writes the handshake response for a callback that does not use application state. The
/// response body then upgrades the connection and runs the callback.
impl<P: WebSocketProtocol, C: WebSocketCallbackWithShutdownSignal> super::IntoResponse
    for UpgradedWebSocket<P, CallbackNotUsingState<C>>
{
    async fn write_to<R: Read, W: super::ResponseWriter<Error = R::Error>>(
        self,
        connection: super::Connection<'_, R>,
        response_writer: W,
    ) -> Result<crate::ResponseSent, W::Error> {
        // Response "body" that takes over the connection after the headers are written.
        struct Body<C: WebSocketCallbackWithShutdownSignal> {
            upgrade_token: crate::extract::UpgradeToken,
            callback: CallbackNotUsingState<C>,
        }
        impl<C: WebSocketCallbackWithShutdownSignal> super::Body for Body<C> {
            async fn write_response_body<
                R: crate::io::Read,
                W: crate::io::Write<Error = R::Error>,
            >(
                self,
                connection: super::Connection<'_, R>,
                writer: W,
            ) -> Result<(), W::Error> {
                // Cloned before `connection.upgrade(...)` below, presumably because upgrading
                // consumes `connection`.
                let shutdown_signal = connection.shutdown_signal.clone();
                self.callback
                    .callback
                    .run_with_shutdown_signal(
                        SocketRx {
                            reader: connection.upgrade(self.upgrade_token),
                        },
                        SocketTx { writer },
                        shutdown_signal,
                    )
                    .await
            }
        }
        let UpgradedWebSocket {
            sec_websocket_accept,
            sec_websocket_protocol,
            upgrade_token,
            callback,
        } = self;
        response_writer
            .write_response(
                connection,
                websocket_response(
                    &sec_websocket_accept,
                    sec_websocket_protocol.name(),
                    Body {
                        upgrade_token,
                        callback,
                    },
                ),
            )
            .await
    }
}
/// Writes the handshake response for a callback that borrows application state. The
/// response body then upgrades the connection and runs the callback with `state`.
impl<State, P: WebSocketProtocol, C: WebSocketCallbackWithStateAndShutdownSignal<State>>
    super::IntoResponseWithState<State> for UpgradedWebSocket<P, CallbackUsingState<State, C>>
{
    async fn write_to_with_state<R: Read, W: super::ResponseWriter<Error = R::Error>>(
        self,
        state: &State,
        connection: super::Connection<'_, R>,
        response_writer: W,
    ) -> Result<crate::ResponseSent, W::Error> {
        // Response "body" that takes over the connection after the headers are written. It
        // carries the borrowed state through to the callback.
        struct Body<'s, State, C: WebSocketCallbackWithStateAndShutdownSignal<State>> {
            state: &'s State,
            upgrade_token: crate::extract::UpgradeToken,
            callback: CallbackUsingState<State, C>,
        }
        impl<State, C: WebSocketCallbackWithStateAndShutdownSignal<State>> super::Body
            for Body<'_, State, C>
        {
            async fn write_response_body<
                R: crate::io::Read,
                W: crate::io::Write<Error = R::Error>,
            >(
                self,
                connection: super::Connection<'_, R>,
                writer: W,
            ) -> Result<(), W::Error> {
                // Cloned before `connection.upgrade(...)` below, presumably because upgrading
                // consumes `connection`.
                let shutdown_signal = connection.shutdown_signal.clone();
                self.callback
                    .callback
                    .run_with_state_and_shutdown_signal(
                        self.state,
                        SocketRx {
                            reader: connection.upgrade(self.upgrade_token),
                        },
                        SocketTx { writer },
                        shutdown_signal,
                    )
                    .await
            }
        }
        let UpgradedWebSocket {
            sec_websocket_accept,
            sec_websocket_protocol,
            upgrade_token,
            callback,
        } = self;
        response_writer
            .write_response(
                connection,
                websocket_response(
                    &sec_websocket_accept,
                    sec_websocket_protocol.name(),
                    Body {
                        state,
                        upgrade_token,
                        callback,
                    },
                ),
            )
            .await
    }
}