use std::{
cell::UnsafeCell, fmt, hint::unreachable_unchecked, mem::replace, num::NonZeroU16, ops::Deref,
};
use bytes::{BufMut, Bytes, BytesMut};
use super::error::ProtocolError;
use crate::utf8;
/// WebSocket frame opcode (RFC 6455, section 5.2).
#[derive(Debug, PartialEq, Eq, Clone, Copy)]
pub(super) enum OpCode {
    /// `0x0`: continuation of a fragmented message.
    Continuation,
    /// `0x1`: text data.
    Text,
    /// `0x2`: binary data.
    Binary,
    /// `0x8`: connection close.
    Close,
    /// `0x9`: ping.
    Ping,
    /// `0xA`: pong.
    Pong,
}

impl OpCode {
    /// Returns `true` for the control opcodes `Close`, `Ping` and `Pong`.
    pub(super) fn is_control(self) -> bool {
        match self {
            Self::Close | Self::Ping | Self::Pong => true,
            Self::Continuation | Self::Text | Self::Binary => false,
        }
    }
}

impl TryFrom<u8> for OpCode {
    type Error = ProtocolError;

    /// Decodes an opcode value; any value without a variant is rejected with
    /// `ProtocolError::InvalidOpcode`.
    fn try_from(value: u8) -> Result<Self, Self::Error> {
        let opcode = match value {
            0x0 => Self::Continuation,
            0x1 => Self::Text,
            0x2 => Self::Binary,
            0x8 => Self::Close,
            0x9 => Self::Ping,
            0xA => Self::Pong,
            _ => return Err(ProtocolError::InvalidOpcode),
        };
        Ok(opcode)
    }
}

impl From<OpCode> for u8 {
    /// Encodes the opcode as its numeric wire value.
    fn from(value: OpCode) -> Self {
        match value {
            OpCode::Continuation => 0x0,
            OpCode::Text => 0x1,
            OpCode::Binary => 0x2,
            OpCode::Close => 0x8,
            OpCode::Ping => 0x9,
            OpCode::Pong => 0xA,
        }
    }
}
/// Status code carried in the body of a WebSocket close frame.
///
/// The inner value is private, so codes can only come from the constants below
/// or from the crate's own `TryFrom<u16>` conversion.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub struct CloseCode(NonZeroU16);

// SAFETY (every constant below): each literal passed to `new_unchecked` is
// non-zero.
#[rustfmt::skip]
impl CloseCode {
    /// 1000: normal closure.
    pub const NORMAL_CLOSURE: Self             = Self(unsafe { NonZeroU16::new_unchecked(1000) });
    /// 1001: the endpoint is going away.
    pub const GOING_AWAY: Self                 = Self(unsafe { NonZeroU16::new_unchecked(1001) });
    /// 1002: protocol error.
    pub const PROTOCOL_ERROR: Self             = Self(unsafe { NonZeroU16::new_unchecked(1002) });
    /// 1003: the received data type cannot be accepted.
    pub const UNSUPPORTED_DATA: Self           = Self(unsafe { NonZeroU16::new_unchecked(1003) });
    /// 1005: no status code was present.
    pub const NO_STATUS_RECEIVED: Self         = Self(unsafe { NonZeroU16::new_unchecked(1005) });
    /// 1007: message data was inconsistent with its type (e.g. invalid UTF-8).
    pub const INVALID_FRAME_PAYLOAD_DATA: Self = Self(unsafe { NonZeroU16::new_unchecked(1007) });
    /// 1008: policy violation.
    pub const POLICY_VIOLATION: Self           = Self(unsafe { NonZeroU16::new_unchecked(1008) });
    /// 1009: message too big to process.
    pub const MESSAGE_TOO_BIG: Self            = Self(unsafe { NonZeroU16::new_unchecked(1009) });
    /// 1010: a required extension was not negotiated.
    pub const MANDATORY_EXTENSION: Self        = Self(unsafe { NonZeroU16::new_unchecked(1010) });
    /// 1011: unexpected server-side condition.
    pub const INTERNAL_SERVER_ERROR: Self      = Self(unsafe { NonZeroU16::new_unchecked(1011) });
    /// 1012: the service is restarting.
    pub const SERVICE_RESTART: Self            = Self(unsafe { NonZeroU16::new_unchecked(1012) });
    /// 1013: the service is overloaded; try again later.
    pub const SERVICE_OVERLOAD: Self           = Self(unsafe { NonZeroU16::new_unchecked(1013) });
    /// 1014: bad gateway.
    pub const BAD_GATEWAY: Self                = Self(unsafe { NonZeroU16::new_unchecked(1014) });
}
impl CloseCode {
    /// Returns whether an endpoint may put this code into a close frame it
    /// sends. 1004, 1005, 1006 and 1015 are reserved and never sendable.
    pub(super) fn is_sendable(self) -> bool {
        let code = self.0.get();
        match code {
            1004..=1006 | 1015 => false,
            1000..=1003 | 1007..=1014 | 1016..=4999 => true,
            // SAFETY: the constants and `TryFrom<u16>` only ever produce codes
            // within 1000..=4999, and the inner field is private.
            _ => unsafe { unreachable_unchecked() },
        }
    }
}
impl From<CloseCode> for u16 {
    /// Returns the numeric close code.
    fn from(value: CloseCode) -> Self {
        Self::from(value.0)
    }
}
impl TryFrom<u16> for CloseCode {
    type Error = ProtocolError;

    /// Accepts codes in `1000..=1015` and `3000..=4999`; everything else is
    /// rejected with `ProtocolError::InvalidCloseCode`.
    fn try_from(value: u16) -> Result<Self, Self::Error> {
        let accepted = matches!(value, 1000..=1015 | 3000..=4999);
        if !accepted {
            return Err(ProtocolError::InvalidCloseCode);
        }
        // SAFETY: every accepted value is at least 1000 and therefore non-zero.
        Ok(Self(unsafe { NonZeroU16::new_unchecked(value) }))
    }
}
pub struct Payload {
data: UnsafeCell<PayloadStorage>,
utf8_validated: bool,
}
impl Payload {
const fn from_static(bytes: &'static [u8]) -> Self {
Self {
data: UnsafeCell::new(PayloadStorage::Shared(Bytes::from_static(bytes))),
utf8_validated: false,
}
}
pub(super) fn set_utf8_validated(&mut self, value: bool) {
self.utf8_validated = value;
}
pub(super) fn truncate(&mut self, len: usize) {
match self.data.get_mut() {
PayloadStorage::Unique(b) => b.truncate(len),
PayloadStorage::Shared(b) => b.truncate(len),
}
}
fn split_to(&mut self, at: usize) -> Self {
self.utf8_validated = false;
Self {
data: UnsafeCell::new(match self.data.get_mut() {
PayloadStorage::Unique(b) => PayloadStorage::Unique(b.split_to(at)),
PayloadStorage::Shared(b) => PayloadStorage::Shared(b.split_to(at)),
}),
utf8_validated: false,
}
}
fn as_bytes(&self) -> &Bytes {
if let PayloadStorage::Shared(bytes) = self.as_ref() {
bytes
} else {
unsafe {
let payload = self.data.get().read();
let bytes = match payload {
PayloadStorage::Unique(p) => p.freeze(),
PayloadStorage::Shared(_) => unreachable_unchecked(),
};
self.data.get().write(PayloadStorage::Shared(bytes));
}
match self.as_ref() {
PayloadStorage::Unique(_) => unsafe { unreachable_unchecked() },
PayloadStorage::Shared(p) => p,
}
}
}
pub(super) fn try_into_bytesmut(self) -> Result<BytesMut, Bytes> {
match self.data.into_inner() {
PayloadStorage::Unique(s) => Ok(s),
PayloadStorage::Shared(e) => Err(e),
}
}
}
impl AsRef<PayloadStorage> for Payload {
    /// Borrows the underlying storage.
    fn as_ref(&self) -> &PayloadStorage {
        let storage_ptr = self.data.get();
        // SAFETY: the storage is only replaced through `&mut self` or inside
        // `Payload::as_bytes`, and `Payload` is `!Sync`.
        unsafe { &*storage_ptr }
    }
}
impl Clone for Payload {
fn clone(&self) -> Self {
let bytes = self.as_bytes();
Self {
data: UnsafeCell::new(PayloadStorage::Shared(bytes.clone())),
utf8_validated: self.utf8_validated,
}
}
}
impl Deref for Payload {
type Target = [u8];
fn deref(&self) -> &Self::Target {
match self.as_ref() {
PayloadStorage::Unique(b) => b,
PayloadStorage::Shared(b) => b,
}
}
}
impl fmt::Debug for Payload {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_tuple("Payload").field(self.as_ref()).finish()
}
}
impl From<Bytes> for Payload {
    /// Wraps already-shared bytes; they are not assumed to be UTF-8.
    fn from(value: Bytes) -> Self {
        let storage = PayloadStorage::Shared(value);
        Self {
            utf8_validated: false,
            data: UnsafeCell::new(storage),
        }
    }
}
impl From<BytesMut> for Payload {
    /// Wraps a uniquely owned buffer; it is not assumed to be UTF-8.
    fn from(value: BytesMut) -> Self {
        let storage = PayloadStorage::Unique(value);
        Self {
            utf8_validated: false,
            data: UnsafeCell::new(storage),
        }
    }
}
impl From<Payload> for Bytes {
fn from(value: Payload) -> Self {
match value.data.into_inner() {
PayloadStorage::Unique(p) => p.freeze(),
PayloadStorage::Shared(p) => p,
}
}
}
impl From<String> for Payload {
    /// Takes ownership of the string's bytes; a `String` is always valid UTF-8,
    /// so the payload is marked as validated.
    fn from(value: String) -> Self {
        let mut payload = Self::from(Bytes::from(value.into_bytes()));
        payload.utf8_validated = true;
        payload
    }
}
impl From<&'static [u8]> for Payload {
    /// Borrows a static byte slice without copying; it is not assumed to be
    /// UTF-8.
    fn from(value: &'static [u8]) -> Self {
        Self::from(Bytes::from_static(value))
    }
}
impl From<&'static str> for Payload {
    /// Borrows a static string without copying; a `str` is always valid UTF-8,
    /// so the payload is marked as validated.
    fn from(value: &'static str) -> Self {
        let mut payload = Self::from(Bytes::from_static(value.as_bytes()));
        payload.utf8_validated = true;
        payload
    }
}
/// Backing storage of a [`Payload`].
#[derive(Debug)]
enum PayloadStorage {
    /// A uniquely owned, mutable buffer.
    Unique(BytesMut),
    /// A cheaply clonable, immutable buffer.
    Shared(Bytes),
}
/// A complete WebSocket message: an opcode together with its payload.
#[derive(Debug, Clone)]
pub struct Message {
    /// Kind of message (text, binary, close, ping or pong).
    pub(super) opcode: OpCode,
    /// Message body.
    pub(super) payload: Payload,
}
impl Message {
    /// Longest close reason, in bytes, that fits in a control frame: control
    /// frame payloads are limited to 125 bytes, two of which hold the close
    /// code (RFC 6455, sections 5.5 and 5.5.1).
    const MAX_CLOSE_REASON_LEN: usize = 123;

    /// Creates a text message.
    ///
    /// The payload is not checked here; [`Message::as_text`] validates it on
    /// access unless the payload is already marked as valid UTF-8.
    #[must_use]
    pub fn text<P: Into<Payload>>(payload: P) -> Self {
        Self {
            opcode: OpCode::Text,
            payload: payload.into(),
        }
    }

    /// Creates a binary message.
    #[must_use]
    pub fn binary<P: Into<Payload>>(payload: P) -> Self {
        Self {
            opcode: OpCode::Binary,
            payload: payload.into(),
        }
    }

    /// Creates a close message.
    ///
    /// The payload is empty when `code` is `None`; `reason` is only included
    /// together with a code. A reason longer than 123 bytes is shortened to the
    /// longest prefix that fits and ends on a character boundary. This keeps
    /// the frame within the control-frame size limit and the reason valid UTF-8.
    #[must_use]
    pub fn close(code: Option<CloseCode>, reason: &str) -> Self {
        let mut payload = BytesMut::new();
        if let Some(code) = code {
            let reason = Self::clamp_close_reason(reason);
            payload.reserve(2 + reason.len());
            payload.put_u16(code.into());
            payload.extend_from_slice(reason.as_bytes());
        }
        Self {
            opcode: OpCode::Close,
            payload: payload.into(),
        }
    }

    /// Returns the longest prefix of `reason` that is at most
    /// `MAX_CLOSE_REASON_LEN` bytes long and ends on a `char` boundary.
    fn clamp_close_reason(reason: &str) -> &str {
        let mut end = reason.len().min(Self::MAX_CLOSE_REASON_LEN);
        // Index 0 is always a char boundary, so this loop terminates.
        while !reason.is_char_boundary(end) {
            end -= 1;
        }
        &reason[..end]
    }

    /// Creates a ping message.
    #[must_use]
    pub fn ping<P: Into<Payload>>(payload: P) -> Self {
        Self {
            opcode: OpCode::Ping,
            payload: payload.into(),
        }
    }

    /// Creates a pong message.
    #[must_use]
    pub fn pong<P: Into<Payload>>(payload: P) -> Self {
        Self {
            opcode: OpCode::Pong,
            payload: payload.into(),
        }
    }

    /// Returns whether this is a text message.
    #[must_use]
    pub fn is_text(&self) -> bool {
        self.opcode == OpCode::Text
    }

    /// Returns whether this is a binary message.
    #[must_use]
    pub fn is_binary(&self) -> bool {
        self.opcode == OpCode::Binary
    }

    /// Returns whether this is a close message.
    #[must_use]
    pub fn is_close(&self) -> bool {
        self.opcode == OpCode::Close
    }

    /// Returns whether this is a ping message.
    #[must_use]
    pub fn is_ping(&self) -> bool {
        self.opcode == OpCode::Ping
    }

    /// Returns whether this is a pong message.
    #[must_use]
    pub fn is_pong(&self) -> bool {
        self.opcode == OpCode::Pong
    }

    /// Consumes the message and returns its payload.
    #[must_use]
    pub fn into_payload(self) -> Payload {
        self.payload
    }

    /// Borrows the message payload.
    #[must_use]
    pub fn as_payload(&self) -> &Payload {
        &self.payload
    }

    /// Returns the payload as text if this is a text message.
    ///
    /// # Panics
    ///
    /// Panics if this is a text message whose payload is not valid UTF-8.
    #[must_use]
    pub fn as_text(&self) -> Option<&str> {
        (self.opcode == OpCode::Text).then(|| {
            assert!(
                self.payload.utf8_validated || utf8::parse_str(&self.payload).is_ok(),
                "called as_text on message created from payload with invalid utf-8"
            );
            // SAFETY: the payload was either marked as valid UTF-8 or just
            // checked by the assertion above.
            unsafe { std::str::from_utf8_unchecked(&self.payload) }
        })
    }

    /// Returns the close code and reason if this is a close message.
    ///
    /// A close message with an empty body reports
    /// [`CloseCode::NO_STATUS_RECEIVED`] and an empty reason.
    ///
    /// # Panics
    ///
    /// Panics if the body is one byte long, carries an unknown close code, or
    /// has a reason that is not valid UTF-8. Messages built with
    /// [`Message::close`] never do this. NOTE(review): received close frames
    /// are presumably validated by the decoder as well — confirm.
    #[must_use]
    pub fn as_close(&self) -> Option<(CloseCode, &str)> {
        (self.opcode == OpCode::Close).then(|| {
            let body: &[u8] = &self.payload;
            if body.is_empty() {
                (CloseCode::NO_STATUS_RECEIVED, "")
            } else {
                // Checked rather than unchecked access: a malformed body must
                // not turn into undefined behavior.
                assert!(
                    body.len() >= 2,
                    "close message payload is too short to hold a close code"
                );
                let (code_bytes, reason_bytes) = body.split_at(2);
                let code = CloseCode::try_from(u16::from_be_bytes([code_bytes[0], code_bytes[1]]))
                    .expect("close message carries an invalid close code");
                let reason = std::str::from_utf8(reason_bytes)
                    .expect("close message carries a reason that is not valid utf-8");
                (code, reason)
            }
        })
    }

    /// Turns the message into an iterator of frames carrying at most
    /// `frame_size` payload bytes each.
    pub(super) fn into_frames(self, frame_size: usize) -> MessageFrames {
        // A zero frame size would never consume payload bytes, so the
        // iterator would not terminate for a non-empty message.
        debug_assert!(frame_size > 0, "frame_size must be non-zero");
        MessageFrames {
            frame_size,
            payload: self.payload,
            opcode: self.opcode,
        }
    }
}
/// Iterator that splits a message payload into frames of at most `frame_size`
/// bytes.
///
/// The first frame carries the message's opcode and every later frame is a
/// continuation; the frame that exhausts the payload is marked final. At least
/// one frame is produced unless the starting opcode is already `Continuation`.
/// A `frame_size` of 0 never shrinks a non-empty payload, so the iterator would
/// not terminate.
pub(super) struct MessageFrames {
    /// Maximum payload bytes per frame.
    frame_size: usize,
    /// Bytes not yet emitted.
    payload: Payload,
    /// Opcode for the next frame: the message opcode first, then `Continuation`.
    opcode: OpCode,
}

impl Iterator for MessageFrames {
    type Item = Frame;

    fn next(&mut self) -> Option<Self::Item> {
        let exhausted = self.opcode == OpCode::Continuation && self.payload.is_empty();
        if exhausted {
            return None;
        }
        let chunk_len = std::cmp::min(self.frame_size, self.payload.len());
        let chunk = self.payload.split_to(chunk_len);
        let opcode = replace(&mut self.opcode, OpCode::Continuation);
        Some(Frame {
            opcode,
            is_final: self.payload.is_empty(),
            payload: chunk,
        })
    }
}
/// Size limits for a WebSocket stream.
#[derive(Debug, Clone, Copy)]
pub struct Limits {
    /// Maximum payload length in bytes; `usize::MAX` means no limit.
    pub(super) max_payload_len: usize,
}

impl Limits {
    /// Returns limits that place no cap on payload length.
    #[must_use]
    pub fn unlimited() -> Self {
        Self::default().max_payload_len(None)
    }

    /// Sets the maximum payload length; `None` removes the cap.
    #[must_use]
    pub fn max_payload_len(mut self, size: Option<usize>) -> Self {
        let limit = size.unwrap_or(usize::MAX);
        self.max_payload_len = limit;
        self
    }
}

impl Default for Limits {
    /// Caps payloads at 64 MiB.
    fn default() -> Self {
        const DEFAULT_MAX_PAYLOAD_LEN: usize = 64 << 20;
        Self {
            max_payload_len: DEFAULT_MAX_PAYLOAD_LEN,
        }
    }
}
/// Configuration for a WebSocket stream.
#[derive(Debug, Clone, Copy)]
pub struct Config {
    /// Maximum payload size, in bytes, of each outgoing frame.
    pub(super) frame_size: usize,
}

impl Config {
    /// Sets the maximum payload size of each frame. The default is 4 MiB.
    ///
    /// # Panics
    ///
    /// Panics if `frame_size` is `0`.
    #[must_use]
    #[track_caller]
    pub fn frame_size(mut self, frame_size: usize) -> Self {
        // NOTE(review): this value is presumably what gets passed to
        // `Message::into_frames`. With a frame size of 0, splitting a
        // non-empty message never consumes any bytes and never terminates,
        // so 0 is rejected up front.
        assert_ne!(frame_size, 0, "frame_size must be greater than zero");
        self.frame_size = frame_size;
        self
    }
}

impl Default for Config {
    /// Uses a frame size of 4 MiB.
    fn default() -> Self {
        Self {
            frame_size: 4 * 1024 * 1024,
        }
    }
}
/// Which side of the connection this endpoint is.
#[derive(Debug, PartialEq, Eq, Clone, Copy)]
pub(crate) enum Role {
    /// The client side.
    Client,
    /// The server side.
    Server,
}
/// Closing-handshake state of a stream.
///
/// NOTE(review): the variant descriptions below are inferred from their
/// names; confirm against the code that drives the state transitions.
#[derive(Debug, PartialEq)]
pub(super) enum StreamState {
    /// No close has started.
    Active,
    /// Presumably: the peer sent a close frame first.
    ClosedByPeer,
    /// Presumably: this endpoint sent a close frame first.
    ClosedByUs,
    /// Presumably: close frames have been exchanged in both directions.
    CloseAcknowledged,
}
/// A single WebSocket frame.
#[derive(Clone, Debug)]
pub(super) struct Frame {
    /// Frame opcode.
    pub opcode: OpCode,
    /// Whether this is the last frame of its message (the FIN bit).
    pub is_final: bool,
    /// Frame payload.
    pub payload: Payload,
}

impl Frame {
    /// A final close frame whose payload is the `NORMAL_CLOSURE` code without a
    /// reason.
    // `Payload` contains an `UnsafeCell`, so clippy warns that each use of this
    // const creates a fresh value; a fresh frame per use is the intent.
    #[allow(clippy::declare_interior_mutable_const)]
    pub const DEFAULT_CLOSE: Self = Self {
        opcode: OpCode::Close,
        is_final: true,
        payload: Payload::from_static(&CloseCode::NORMAL_CLOSURE.0.get().to_be_bytes()),
    };

    /// Writes the frame header into `out` and returns how many bytes were
    /// written (2, 4 or 10).
    ///
    /// It writes the FIN bit and opcode, then the 7-bit payload length, using
    /// the 16-bit or 64-bit extended length when needed. It does not set the
    /// MASK bit or write a masking key.
    pub fn encode(&self, out: &mut [u8; 10]) -> u8 {
        out[0] = (u8::from(self.is_final) << 7) | u8::from(self.opcode);
        let payload_len = self.payload.len();
        if payload_len <= 125 {
            out[1] = u8::try_from(payload_len).expect("checked by previous branch");
            2
        } else if let Ok(short_len) = u16::try_from(payload_len) {
            out[1] = 126;
            out[2..4].copy_from_slice(&short_len.to_be_bytes());
            4
        } else {
            out[1] = 127;
            let long_len = u64::try_from(payload_len).unwrap();
            out[2..10].copy_from_slice(&long_len.to_be_bytes());
            10
        }
    }
}
impl From<Message> for Frame {
fn from(value: Message) -> Self {
Self {
opcode: value.opcode,
is_final: true,
payload: value.payload,
}
}
}
impl From<&ProtocolError> for Frame {
    /// Builds the close frame sent in response to a protocol error.
    ///
    /// Invalid UTF-8 maps to `INVALID_FRAME_PAYLOAD_DATA`; every other error
    /// maps to `PROTOCOL_ERROR` with the error's own description as the reason.
    fn from(val: &ProtocolError) -> Self {
        let message = if matches!(val, ProtocolError::InvalidUtf8) {
            Message::close(Some(CloseCode::INVALID_FRAME_PAYLOAD_DATA), "invalid utf8")
        } else {
            Message::close(Some(CloseCode::PROTOCOL_ERROR), val.as_str())
        };
        Self::from(message)
    }
}