use std::{fmt, mem::replace, num::NonZeroU16, ops::Deref};
use bytes::{BufMut, Bytes, BytesMut};
use super::error::ProtocolError;
use crate::utf8;
/// The opcode of a WebSocket frame.
///
/// Wire values are assigned by the explicit `TryFrom<u8>` / `From<OpCode> for u8`
/// conversions, not by this enum's implicit discriminants.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub(super) enum OpCode {
    /// Continues a fragmented message.
    Continuation,
    /// Starts a text message.
    Text,
    /// Starts a binary message.
    Binary,
    /// Close control frame.
    Close,
    /// Ping control frame.
    Ping,
    /// Pong control frame.
    Pong,
}

impl OpCode {
    /// Returns `true` for the control opcodes `Close`, `Ping` and `Pong`.
    pub(super) fn is_control(self) -> bool {
        match self {
            Self::Close | Self::Ping | Self::Pong => true,
            Self::Continuation | Self::Text | Self::Binary => false,
        }
    }
}
impl TryFrom<u8> for OpCode {
    type Error = ProtocolError;

    /// Maps a raw opcode value to an [`OpCode`].
    ///
    /// Only `0x0..=0x2` and `0x8..=0xA` are accepted; anything else yields
    /// `ProtocolError::InvalidOpcode`.
    fn try_from(value: u8) -> Result<Self, Self::Error> {
        let opcode = match value {
            0x0 => Self::Continuation,
            0x1 => Self::Text,
            0x2 => Self::Binary,
            0x8 => Self::Close,
            0x9 => Self::Ping,
            0xA => Self::Pong,
            _ => return Err(ProtocolError::InvalidOpcode),
        };
        Ok(opcode)
    }
}
impl From<OpCode> for u8 {
    /// Returns the raw opcode value; the inverse of `OpCode::try_from(u8)`.
    fn from(value: OpCode) -> Self {
        // Data opcodes occupy 0x0..=0x2, control opcodes 0x8..=0xA.
        match value {
            OpCode::Continuation => 0x0,
            OpCode::Text => 0x1,
            OpCode::Binary => 0x2,
            OpCode::Close => 0x8,
            OpCode::Ping => 0x9,
            OpCode::Pong => 0xA,
        }
    }
}
/// The status code carried at the start of a close frame's payload.
///
/// The wrapped field is private. Outside this module a value can only come from
/// the associated constants or `TryFrom<u16>`, both of which go through
/// `try_from_u16`. The code is therefore always in `1000..=1015` or `3000..=4999`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct CloseCode(NonZeroU16);

#[rustfmt::skip]
impl CloseCode {
    /// `1000`: normal closure.
    pub const NORMAL_CLOSURE:             CloseCode = CloseCode::constant(1000);
    /// `1001`: going away.
    pub const GOING_AWAY:                 CloseCode = CloseCode::constant(1001);
    /// `1002`: protocol error.
    pub const PROTOCOL_ERROR:             CloseCode = CloseCode::constant(1002);
    /// `1003`: unsupported data.
    pub const UNSUPPORTED_DATA:           CloseCode = CloseCode::constant(1003);
    /// `1005`: no status received. Reserved, so it must not be sent.
    pub const NO_STATUS_RECEIVED:         CloseCode = CloseCode::constant(1005);
    /// `1007`: invalid frame payload data (used for invalid UTF-8).
    pub const INVALID_FRAME_PAYLOAD_DATA: CloseCode = CloseCode::constant(1007);
    /// `1008`: policy violation.
    pub const POLICY_VIOLATION:           CloseCode = CloseCode::constant(1008);
    /// `1009`: message too big.
    pub const MESSAGE_TOO_BIG:            CloseCode = CloseCode::constant(1009);
    /// `1010`: mandatory extension.
    pub const MANDATORY_EXTENSION:        CloseCode = CloseCode::constant(1010);
    /// `1011`: internal server error.
    pub const INTERNAL_SERVER_ERROR:      CloseCode = CloseCode::constant(1011);
    /// `1012`: service restart.
    pub const SERVICE_RESTART:            CloseCode = CloseCode::constant(1012);
    /// `1013`: service overload.
    pub const SERVICE_OVERLOAD:           CloseCode = CloseCode::constant(1013);
    /// `1014`: bad gateway.
    pub const BAD_GATEWAY:                CloseCode = CloseCode::constant(1014);
}

impl CloseCode {
    /// Validates `code`, returning `None` unless it lies in `1000..=1015` or
    /// `3000..=4999`.
    const fn try_from_u16(code: u16) -> Option<Self> {
        if !matches!(code, 1000..=1015 | 3000..=4999) {
            return None;
        }
        // Every accepted value is >= 1000, so it is never zero.
        match NonZeroU16::new(code) {
            Some(non_zero) => Some(Self(non_zero)),
            None => unreachable!(),
        }
    }

    /// Builds a compile-time constant. Evaluation fails if `code` is not a
    /// valid close code.
    const fn constant(code: u16) -> Self {
        if let Some(close_code) = Self::try_from_u16(code) {
            close_code
        } else {
            unreachable!()
        }
    }

    /// Returns `true` for codes that must not be sent in a close frame
    /// (`1004`, `1005`, `1006`, `1015`).
    #[must_use]
    pub fn is_reserved(self) -> bool {
        let raw = self.0.get();
        if matches!(raw, 1004 | 1005 | 1006 | 1015) {
            return true;
        }
        // Construction is limited to `try_from_u16`, so anything outside
        // `1000..=4999` indicates a bug; report it in debug builds only.
        debug_assert!((1000..=4999).contains(&raw), "unexpected CloseCode");
        false
    }
}
impl From<CloseCode> for u16 {
fn from(value: CloseCode) -> Self {
value.0.get()
}
}
impl TryFrom<u16> for CloseCode {
type Error = ProtocolError;
fn try_from(value: u16) -> Result<Self, Self::Error> {
Self::try_from_u16(value).ok_or(ProtocolError::InvalidCloseCode)
}
}
/// The bytes carried by a [`Message`], plus a cached "known valid UTF-8" flag.
///
/// The flag lets text accessors skip re-validation. It is set for payloads built
/// from `String`/`&str`, or explicitly through `set_utf8_validated`.
#[derive(Clone)]
pub struct Payload {
    /// The payload bytes.
    data: Bytes,
    /// Whether `data` is already known to be valid UTF-8.
    utf8_validated: bool,
}

impl Payload {
    /// Wraps a static byte slice without copying; the UTF-8 flag starts unset.
    const fn from_static(bytes: &'static [u8]) -> Self {
        let data = Bytes::from_static(bytes);
        Self {
            data,
            utf8_validated: false,
        }
    }

    /// Records whether the payload is known to be valid UTF-8.
    pub(super) fn set_utf8_validated(&mut self, value: bool) {
        self.utf8_validated = value;
    }

    /// Shortens the payload to `len` bytes. Per `Bytes::truncate`, this has no
    /// effect if the payload is already shorter.
    pub(super) fn truncate(&mut self, len: usize) {
        self.data.truncate(len);
    }

    /// Detaches and returns the first `at` bytes, leaving the rest in `self`.
    ///
    /// Both halves lose the UTF-8 flag, because the split point may fall inside
    /// a multi-byte code point. Panics if `at` exceeds the length, as
    /// `Bytes::split_to` does.
    fn split_to(&mut self, at: usize) -> Self {
        self.utf8_validated = false;
        let head = self.data.split_to(at);
        Self {
            data: head,
            utf8_validated: false,
        }
    }
}

impl Deref for Payload {
    type Target = [u8];

    /// Exposes the payload bytes as a slice.
    fn deref(&self) -> &[u8] {
        &self.data[..]
    }
}

impl fmt::Debug for Payload {
    /// Prints `Payload(<bytes>)`; the UTF-8 flag is deliberately omitted.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let mut tuple = f.debug_tuple("Payload");
        tuple.field(&self.data);
        tuple.finish()
    }
}
impl From<Bytes> for Payload {
fn from(value: Bytes) -> Self {
Self {
data: value,
utf8_validated: false,
}
}
}
impl From<BytesMut> for Payload {
fn from(value: BytesMut) -> Self {
Self {
data: value.freeze(),
utf8_validated: false,
}
}
}
impl From<Payload> for Bytes {
fn from(value: Payload) -> Self {
value.data
}
}
impl From<Payload> for BytesMut {
fn from(value: Payload) -> Self {
value.data.into()
}
}
impl From<Vec<u8>> for Payload {
fn from(value: Vec<u8>) -> Self {
Self {
data: BytesMut::from_iter(value).freeze(),
utf8_validated: false,
}
}
}
impl From<String> for Payload {
fn from(value: String) -> Self {
Self {
data: BytesMut::from_iter(value.into_bytes()).freeze(),
utf8_validated: true,
}
}
}
impl From<&'static [u8]> for Payload {
    /// Wraps a static byte slice without copying; its UTF-8 validity is unknown.
    fn from(value: &'static [u8]) -> Self {
        Self::from_static(value)
    }
}

impl From<&'static str> for Payload {
    /// Wraps a static string without copying. A `str` is always valid UTF-8,
    /// so the payload is marked as validated.
    fn from(value: &'static str) -> Self {
        Self {
            utf8_validated: true,
            data: Bytes::from_static(value.as_bytes()),
        }
    }
}
/// A complete WebSocket message: an opcode together with its payload.
#[derive(Debug, Clone)]
pub struct Message {
    pub(super) opcode: OpCode,
    pub(super) payload: Payload,
}

impl Message {
    /// Creates a text message.
    ///
    /// The payload is not checked for UTF-8 here. [`Message::as_text`] checks
    /// it on access unless it is already flagged as valid.
    #[must_use]
    pub fn text<P: Into<Payload>>(payload: P) -> Self {
        Self {
            opcode: OpCode::Text,
            payload: payload.into(),
        }
    }

    /// Creates a binary message.
    #[must_use]
    pub fn binary<P: Into<Payload>>(payload: P) -> Self {
        Self {
            opcode: OpCode::Binary,
            payload: payload.into(),
        }
    }

    /// Creates a close message.
    ///
    /// With `Some(code)`, the payload is the code as a big-endian `u16`
    /// followed by `reason`. With `None`, the payload is empty and `reason` is
    /// ignored.
    ///
    /// # Panics
    ///
    /// - If `code` is reserved (see [`CloseCode::is_reserved`]).
    /// - If `code` is `Some` and `reason` is longer than 123 bytes. A control
    ///   frame payload is limited to 125 bytes, and two of them hold the code.
    #[must_use]
    #[track_caller]
    pub fn close(code: Option<CloseCode>, reason: &str) -> Self {
        let mut payload = BytesMut::new();
        if let Some(code) = code {
            assert!(!code.is_reserved(), "reserved close codes must not be sent");
            assert!(
                reason.len() <= 123,
                "close reason must be at most 123 bytes"
            );
            payload.reserve(2 + reason.len());
            payload.put_u16(code.into());
            payload.extend_from_slice(reason.as_bytes());
        }
        Self {
            opcode: OpCode::Close,
            payload: payload.into(),
        }
    }

    /// Creates a ping message.
    ///
    /// # Panics
    ///
    /// If the payload is longer than 125 bytes (the control-frame limit).
    #[must_use]
    #[track_caller]
    pub fn ping<P: Into<Payload>>(payload: P) -> Self {
        let payload = payload.into();
        assert!(payload.len() <= 125, "ping payload must be at most 125 bytes");
        Self {
            opcode: OpCode::Ping,
            payload,
        }
    }

    /// Creates a pong message.
    ///
    /// # Panics
    ///
    /// If the payload is longer than 125 bytes (the control-frame limit).
    #[must_use]
    #[track_caller]
    pub fn pong<P: Into<Payload>>(payload: P) -> Self {
        let payload = payload.into();
        assert!(payload.len() <= 125, "pong payload must be at most 125 bytes");
        Self {
            opcode: OpCode::Pong,
            payload,
        }
    }

    /// Returns `true` if this is a text message.
    #[must_use]
    pub fn is_text(&self) -> bool {
        self.opcode == OpCode::Text
    }

    /// Returns `true` if this is a binary message.
    #[must_use]
    pub fn is_binary(&self) -> bool {
        self.opcode == OpCode::Binary
    }

    /// Returns `true` if this is a close message.
    #[must_use]
    pub fn is_close(&self) -> bool {
        self.opcode == OpCode::Close
    }

    /// Returns `true` if this is a ping message.
    #[must_use]
    pub fn is_ping(&self) -> bool {
        self.opcode == OpCode::Ping
    }

    /// Returns `true` if this is a pong message.
    #[must_use]
    pub fn is_pong(&self) -> bool {
        self.opcode == OpCode::Pong
    }

    /// Consumes the message and returns its payload.
    #[must_use]
    pub fn into_payload(self) -> Payload {
        self.payload
    }

    /// Borrows the message payload.
    #[must_use]
    pub fn as_payload(&self) -> &Payload {
        &self.payload
    }

    /// Returns the payload as text if this is a text message, otherwise `None`.
    ///
    /// # Panics
    ///
    /// If this is a text message whose payload is not flagged as validated and
    /// is rejected by `utf8::parse_str`.
    #[must_use]
    pub fn as_text(&self) -> Option<&str> {
        (self.opcode == OpCode::Text).then(|| {
            assert!(
                self.payload.utf8_validated || utf8::parse_str(&self.payload).is_ok(),
                "called as_text on message created from payload with invalid utf-8"
            );
            // SAFETY: the assertion above established that the bytes are valid
            // UTF-8. Either the payload is flagged as validated, or
            // `utf8::parse_str` accepted it.
            unsafe { std::str::from_utf8_unchecked(&self.payload) }
        })
    }

    /// Returns the close code and reason if this is a close message, otherwise
    /// `None`.
    ///
    /// An empty close payload yields [`CloseCode::NO_STATUS_RECEIVED`] and an
    /// empty reason.
    ///
    /// # Panics
    ///
    /// Panics if a non-empty close payload:
    /// - is shorter than two bytes,
    /// - carries an invalid close code, or
    /// - has a reason that is not valid UTF-8.
    ///
    /// Payloads built by [`Message::close`] never do this. Received close frames
    /// are presumably validated before they become a `Message` (TODO confirm in
    /// the decoder). These checks replace unchecked reads, so an invariant
    /// violation panics instead of causing undefined behaviour.
    #[must_use]
    pub fn as_close(&self) -> Option<(CloseCode, &str)> {
        (self.opcode == OpCode::Close).then(|| {
            let code = if self.payload.is_empty() {
                CloseCode::NO_STATUS_RECEIVED
            } else {
                let raw = self
                    .payload
                    .get(..2)
                    .and_then(|bytes| <[u8; 2]>::try_from(bytes).ok())
                    .expect("non-empty close payload must start with a 2-byte code");
                CloseCode::try_from(u16::from_be_bytes(raw))
                    .unwrap_or_else(|_| panic!("close payload must carry a valid close code"))
            };
            let reason = std::str::from_utf8(self.payload.get(2..).unwrap_or_default())
                .expect("close reason must be valid UTF-8");
            (code, reason)
        })
    }

    /// Returns an iterator over the frames this message is sent as, using
    /// `frame_size` as the fragment size.
    pub(super) fn into_frames(self, frame_size: usize) -> MessageFrames {
        MessageFrames {
            frame_size,
            payload: self.payload,
            opcode: self.opcode,
        }
    }
}
/// Iterator over the frames a [`Message`] is sent as; built by
/// `Message::into_frames`.
pub(super) struct MessageFrames {
    /// Fragment size used to split data-message payloads.
    frame_size: usize,
    /// Payload bytes that have not been emitted yet.
    payload: Payload,
    /// Opcode for the next frame; replaced by `Continuation` after the first.
    opcode: OpCode,
}

impl Iterator for MessageFrames {
    type Item = Frame;

    /// Yields the next frame of the message.
    ///
    /// - The first frame carries the message opcode; later frames use
    ///   `Continuation`.
    /// - A frame is final once the remaining payload is exhausted.
    /// - An empty message still yields exactly one final frame.
    /// - Control frames are emitted whole.
    /// - A `frame_size` of zero is treated as one byte so the iterator always
    ///   makes progress.
    fn next(&mut self) -> Option<Self::Item> {
        // Once the first frame is out (opcode is now `Continuation`) and the
        // payload is drained, the message is complete.
        if self.opcode == OpCode::Continuation && self.payload.is_empty() {
            return None;
        }
        // RFC 6455 §5.5: control frames MUST NOT be fragmented.
        let chunk_len = if self.opcode.is_control() {
            self.payload.len()
        } else {
            self.frame_size.max(1).min(self.payload.len())
        };
        let payload = self.payload.split_to(chunk_len);
        Some(Frame {
            opcode: replace(&mut self.opcode, OpCode::Continuation),
            is_final: self.payload.is_empty(),
            payload,
        })
    }
}
/// Size limits for the protocol implementation.
///
/// NOTE(review): presumably enforced on received payloads; confirm at the use
/// site.
#[derive(Clone, Copy, Debug)]
pub struct Limits {
    /// Largest allowed payload length in bytes; `usize::MAX` means no cap.
    pub(super) max_payload_len: usize,
}

impl Limits {
    /// Returns limits with no payload-length cap.
    #[must_use]
    pub fn unlimited() -> Self {
        Self {
            max_payload_len: usize::MAX,
        }
    }

    /// Builder form of [`Limits::set_max_payload_len`].
    #[must_use]
    pub fn max_payload_len(self, size: Option<usize>) -> Self {
        let mut limits = self;
        limits.set_max_payload_len(size);
        limits
    }

    /// Sets the maximum payload length; `None` removes the cap.
    pub fn set_max_payload_len(&mut self, size: Option<usize>) {
        self.max_payload_len = match size {
            Some(limit) => limit,
            None => usize::MAX,
        };
    }
}

impl Default for Limits {
    /// Caps payloads at 64 MiB.
    fn default() -> Self {
        Self {
            max_payload_len: 64 << 20,
        }
    }
}
/// Tuning options for the protocol implementation.
#[derive(Clone, Copy, Debug)]
pub struct Config {
    /// Fragment size in bytes; presumably passed to `Message::into_frames`
    /// (confirm at the call site).
    pub(super) frame_size: usize,
    /// Byte threshold, presumably for flushing buffered output (confirm at the
    /// use site).
    pub(super) flush_threshold: usize,
}

impl Config {
    /// Sets the frame size.
    ///
    /// # Panics
    ///
    /// If `frame_size` is zero.
    #[must_use]
    pub fn frame_size(self, frame_size: usize) -> Self {
        assert_ne!(frame_size, 0, "frame_size must be non-zero");
        Self { frame_size, ..self }
    }

    /// Sets the flush threshold.
    #[must_use]
    pub fn flush_threshold(self, threshold: usize) -> Self {
        Self {
            flush_threshold: threshold,
            ..self
        }
    }
}

impl Default for Config {
    /// 4 MiB frame size and 8 KiB flush threshold.
    fn default() -> Self {
        Self {
            frame_size: 4 << 20,
            flush_threshold: 8 << 10,
        }
    }
}
/// Which side of the connection this endpoint plays.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub(crate) enum Role {
    /// Client side.
    Client,
    /// Server side.
    Server,
}

/// Close-handshake progress of a stream.
///
/// NOTE(review): variant meanings are inferred from their names; confirm
/// against the state machine that drives them.
#[derive(PartialEq, Debug)]
pub(super) enum StreamState {
    /// No close frame has been exchanged.
    Active,
    /// The peer sent a close frame.
    ClosedByPeer,
    /// This endpoint sent a close frame.
    ClosedByUs,
    /// The close handshake has completed.
    CloseAcknowledged,
}
/// A single WebSocket frame: opcode, FIN flag and payload.
#[derive(Debug, Clone)]
pub(super) struct Frame {
    /// Frame opcode.
    pub opcode: OpCode,
    /// FIN bit: `true` when this is the last frame of its message.
    pub is_final: bool,
    /// Frame payload bytes.
    pub payload: Payload,
}

impl Frame {
    /// A final close frame carrying [`CloseCode::NORMAL_CLOSURE`] and no reason.
    // `Payload` wraps `Bytes`, which clippy flags as interior-mutable. Every use
    // of a `const` produces a fresh value, so no state is shared through it.
    #[allow(clippy::declare_interior_mutable_const)]
    pub const DEFAULT_CLOSE: Self = Self {
        opcode: OpCode::Close,
        is_final: true,
        payload: Payload::from_static(&CloseCode::NORMAL_CLOSURE.0.get().to_be_bytes()),
    };

    /// Writes the frame header into `out` and returns the 4 bytes that follow
    /// the length field, i.e. where a masking key goes.
    ///
    /// - `out[0]` is the FIN bit (0x80) ORed with the opcode.
    /// - `out[1]` is either the length itself (up to 125), or 126 / 127 to select
    ///   a 16-bit / 64-bit big-endian extended length in the following bytes.
    ///
    /// The MASK bit of `out[1]` is left clear. NOTE(review): the caller
    /// presumably sets it and fills the returned key when masking; confirm.
    pub fn encode<'a>(&self, out: &'a mut [u8; 14]) -> &'a mut [u8; 4] {
        let len = self.payload.len();
        out[0] = u8::from(self.opcode) | (u8::from(self.is_final) << 7);
        let mask_start = match u16::try_from(len) {
            Ok(short) if short <= 125 => {
                out[1] = u8::try_from(short).expect("checked by previous branch");
                2
            }
            Ok(medium) => {
                out[1] = 126;
                out[2..4].copy_from_slice(&medium.to_be_bytes());
                4
            }
            Err(_) => {
                out[1] = 127;
                out[2..10].copy_from_slice(&u64::try_from(len).unwrap().to_be_bytes());
                10
            }
        };
        <&mut [u8; 4]>::try_from(&mut out[mask_start..mask_start + 4]).unwrap()
    }
}
impl From<Message> for Frame {
fn from(value: Message) -> Self {
Self {
opcode: value.opcode,
is_final: true,
payload: value.payload,
}
}
}
impl From<&ProtocolError> for Frame {
fn from(val: &ProtocolError) -> Self {
match val {
ProtocolError::InvalidUtf8 => {
Message::close(Some(CloseCode::INVALID_FRAME_PAYLOAD_DATA), "invalid utf8")
}
_ => Message::close(Some(CloseCode::PROTOCOL_ERROR), val.as_str()),
}
.into()
}
}