use crate::error::PacketError;
use bytes::Bytes;
use bytestring::ByteString;
use miette::Diagnostic;
use serde::Deserialize;
use serde_json::{Map, Value};
use thiserror::Error;
use tokio::sync::{mpsc, oneshot};
/// A value tagged with the Socket.IO namespace it belongs to.
///
/// Field `0` is the namespace (the root namespace is `/`); field `1` is the
/// wrapped value. `Ns<Packet>` is what packet decoding produces and what
/// `Ns::<Packet>::encode` serializes.
#[derive(Debug)]
pub struct Ns<T>(pub ByteString, pub T);
/// Deserialized body of a CONNECT packet.
///
/// NOTE(review): presumably this is the server's acceptance of a namespace
/// connection, as described by the Socket.IO protocol — confirm against the caller.
#[derive(Debug, Deserialize)]
pub struct Connect {
/// Socket id carried in the CONNECT body (the required `sid` JSON field).
pub sid: ByteString,
/// Every other JSON field of the body, kept verbatim.
#[serde(flatten)]
pub extra: Map<String, Value>,
}
/// Deserialized body of a CONNECT_ERROR packet: the server rejected the
/// namespace connection.
///
/// Displays as its `message`; any other JSON fields are collected in `extra`.
#[derive(Debug, Error, Diagnostic, Deserialize)]
#[error("{message}")]
#[diagnostic(
code(sioc::connect_error),
help(
"Server rejected the namespace connection. Verify your auth payload and server middleware."
),
url("https://socket.io/docs/v4/socket-io-protocol/#connection-to-a-namespace")
)]
pub struct ConnectError {
/// Human-readable rejection reason (the required `message` JSON field).
pub message: ByteString,
/// Every other JSON field of the body, kept verbatim.
#[serde(flatten)]
pub extra: Map<String, Value>,
}
/// An event whose payload is kept as raw text rather than a typed value.
#[derive(Debug, Clone)]
pub struct DynEvent {
/// Raw payload text, stored without decoding.
pub payload: ByteString,
/// Ack id attached to the event, if any.
pub id: Option<u64>,
/// Binary attachments belonging to the event, if any.
pub attachments: Option<Vec<Bytes>>,
}
impl std::fmt::Display for DynEvent {
    /// Renders the event as a debug-style map: always the payload text, then
    /// `id` and the attachment `count` only when they are present.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut entries = f.debug_map();
        entries.entry(&"payload", &format_args!("{}", self.payload));
        if let Some(ref ack_id) = self.id {
            entries.entry(&"id", ack_id);
        }
        if let Some(attached) = self.attachments.as_ref().map(Vec::len) {
            entries.entry(&"count", &attached);
        }
        entries.finish()
    }
}
impl DynEvent {
    /// Builds an event from a raw payload and an optional ack id, with no
    /// binary attachments.
    pub fn new<T>(payload: T, id: Option<u64>) -> Self
    where
        T: Into<ByteString>,
    {
        let payload = payload.into();
        Self {
            id,
            payload,
            attachments: None,
        }
    }

    /// Returns the event with `attachments` set, replacing any previous ones.
    pub fn with_attachments(self, attachments: Vec<Bytes>) -> Self {
        Self {
            attachments: Some(attachments),
            ..self
        }
    }
}
/// An acknowledgement whose payload is kept as raw text rather than a typed value.
#[derive(Debug, Clone)]
pub struct DynAck {
/// Raw payload text, stored without decoding.
pub payload: ByteString,
/// Binary attachments belonging to the ack, if any.
pub attachments: Option<Vec<Bytes>>,
}
impl std::fmt::Display for DynAck {
    /// Renders the ack as a debug-style map: the payload text, plus the
    /// attachment `count` when attachments are present.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut entries = f.debug_map();
        entries.entry(&"payload", &format_args!("{}", self.payload));
        match self.attachments.as_ref() {
            Some(list) => {
                entries.entry(&"count", &list.len());
            }
            None => {}
        }
        entries.finish()
    }
}
impl DynAck {
    /// Builds an ack from a raw payload, with no binary attachments.
    pub fn new<T>(payload: T) -> Self
    where
        T: Into<ByteString>,
    {
        let payload = payload.into();
        Self {
            attachments: None,
            payload,
        }
    }

    /// Returns the ack with `attachments` set, replacing any previous ones.
    pub fn with_attachments(self, attachments: Vec<Bytes>) -> Self {
        Self {
            attachments: Some(attachments),
            ..self
        }
    }
}
/// Notification about a namespace: connection accepted, disconnected,
/// connection rejected, or an incoming event.
///
/// `E` is the event representation; it defaults to [`DynEvent`] and can be
/// converted with [`Signal::map`].
#[derive(Debug)]
pub enum Signal<E = DynEvent> {
/// The namespace connection was accepted.
Connect(Connect),
/// The namespace was disconnected.
Disconnect,
/// The namespace connection was rejected.
ConnectError(ConnectError),
/// An incoming event.
Event(E),
}
impl<E> std::fmt::Display for Signal<E>
where
    E: std::fmt::Display,
{
    /// Renders `Disconnect` as bare text and every other variant as a
    /// one-field tuple holding the `Display` form of its content (the `sid`
    /// for `Connect`).
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let (label, inner): (&str, &dyn std::fmt::Display) = match self {
            Self::Disconnect => return f.write_str("Disconnect"),
            Self::Connect(accepted) => ("Connect", &accepted.sid),
            Self::ConnectError(rejected) => ("ConnectError", rejected),
            Self::Event(event) => ("Event", event),
        };
        f.debug_tuple(label)
            .field(&format_args!("{}", inner))
            .finish()
    }
}
impl<E> Signal<E> {
    /// Consumes the signal and returns the event, or `None` for any
    /// non-event variant.
    pub fn take_event(self) -> Option<E> {
        if let Self::Event(event) = self {
            Some(event)
        } else {
            None
        }
    }

    /// Applies `f` to the event if this is an event signal; every other
    /// variant yields `None` without calling `f`.
    pub fn and_then<F, T>(self, f: F) -> Option<T>
    where
        F: FnOnce(E) -> Option<T>,
    {
        self.take_event().and_then(f)
    }

    /// Converts the event type with `f`, passing every non-event variant
    /// through unchanged.
    pub fn map<F, U>(self, f: F) -> Signal<U>
    where
        F: FnOnce(E) -> U,
    {
        match self {
            Self::Event(event) => Signal::Event(f(event)),
            Self::Connect(accepted) => Signal::Connect(accepted),
            Self::ConnectError(rejected) => Signal::ConnectError(rejected),
            Self::Disconnect => Signal::Disconnect,
        }
    }
}
/// Command describing an outgoing action for a namespace.
///
/// NOTE(review): presumably sent from a user-facing handle to the task that
/// owns the connection; the consumer is not visible in this file — confirm.
#[derive(Debug)]
pub enum Directive {
/// Open the namespace with the given CONNECT `payload`.
/// `tx` receives [`Signal`]s (presumably for this namespace — confirm).
Connect {
tx: mpsc::Sender<Signal>,
payload: ByteString,
},
/// Leave the namespace.
Disconnect,
/// Emit an event. When `tx` is `Some`, it is the channel on which the
/// matching [`DynAck`] is delivered (presumably once the peer acknowledges).
Event {
payload: ByteString,
tx: Option<oneshot::Sender<DynAck>>,
attachments: Option<Vec<Bytes>>,
},
/// Answer a received event that carried ack id `id`.
Ack {
payload: ByteString,
id: u64,
attachments: Option<Vec<Bytes>>,
},
/// NOTE(review): meaning not visible here — presumably signals that the
/// issuing handle was dropped; confirm.
Dropped,
}
/// A Socket.IO text packet, without its namespace (see [`Ns`]).
///
/// Variants correspond, in declaration order, to the packet type digits
/// `0`–`6` written by [`Packet::encode`].
#[derive(Debug)]
pub enum Packet {
/// `0`: CONNECT with its raw body.
Connect(ByteString),
/// `1`: DISCONNECT.
Disconnect,
/// `2`: EVENT with an optional ack id.
Event {
payload: ByteString,
id: Option<u64>,
},
/// `3`: ACK answering ack id `id`.
Ack { payload: ByteString, id: u64 },
/// `4`: CONNECT_ERROR with its raw body.
ConnectError(ByteString),
/// `5`: BINARY_EVENT; `count` is the number of binary attachments
/// (encoded as the `<count>-` prefix).
BinaryEvent {
payload: ByteString,
id: Option<u64>,
count: usize,
},
/// `6`: BINARY_ACK; `count` is the number of binary attachments.
BinaryAck {
payload: ByteString,
id: u64,
count: usize,
},
}
impl Packet {
    /// Upper bound on the length of [`Packet::encode`]'s output for namespace
    /// `ns`; used to pre-size the encode buffer.
    pub fn size_hint(&self, ns: &str) -> usize {
        let (binary, ack, body): (bool, bool, Option<&str>) = match self {
            Self::Disconnect => (false, false, None),
            Self::Connect(raw) | Self::ConnectError(raw) => (false, false, Some(raw)),
            Self::Event { payload, id } => (false, id.is_some(), Some(payload)),
            Self::Ack { payload, .. } => (false, true, Some(payload)),
            Self::BinaryEvent { payload, id, .. } => (true, id.is_some(), Some(payload)),
            Self::BinaryAck { payload, .. } => (true, true, Some(payload)),
        };
        hint_packet_size(ns, binary, ack, body)
    }

    /// Serializes the packet into its Socket.IO text form for namespace `ns`:
    /// `<type>[<count>-][<ns>,][<id>][<payload>]`.
    pub fn encode(&self, ns: &str) -> String {
        let (kind, attachments, ack_id, body): (u8, Option<usize>, Option<u64>, Option<&str>) =
            match self {
                Self::Connect(raw) => (b'0', None, None, Some(raw)),
                Self::Disconnect => (b'1', None, None, None),
                Self::Event { payload, id } => (b'2', None, *id, Some(payload)),
                Self::Ack { payload, id } => (b'3', None, Some(*id), Some(payload)),
                Self::ConnectError(raw) => (b'4', None, None, Some(raw)),
                Self::BinaryEvent { payload, id, count } => {
                    (b'5', Some(*count), *id, Some(payload))
                }
                Self::BinaryAck { payload, id, count } => {
                    (b'6', Some(*count), Some(*id), Some(payload))
                }
            };
        let mut out = String::with_capacity(self.size_hint(ns));
        write_packet(&mut out, kind, attachments, ns, ack_id, body);
        out
    }
}
impl std::fmt::Display for Packet {
    /// Renders `Disconnect` as bare text, `Connect`/`ConnectError` as a
    /// one-field tuple, and the event/ack variants as a struct listing
    /// `payload`, then `id` when present, then `count` for binary variants.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Tuple-like variants return early; struct-like ones share one path.
        let (label, body, ack_id, attached) = match self {
            Self::Disconnect => return f.write_str("Disconnect"),
            Self::Connect(raw) => {
                return f
                    .debug_tuple("Connect")
                    .field(&format_args!("{}", raw))
                    .finish();
            }
            Self::ConnectError(raw) => {
                return f
                    .debug_tuple("ConnectError")
                    .field(&format_args!("{}", raw))
                    .finish();
            }
            Self::Event { payload, id } => ("Event", payload, *id, None),
            Self::Ack { payload, id } => ("Ack", payload, Some(*id), None),
            Self::BinaryEvent { payload, id, count } => {
                ("BinaryEvent", payload, *id, Some(*count))
            }
            Self::BinaryAck { payload, id, count } => {
                ("BinaryAck", payload, Some(*id), Some(*count))
            }
        };
        let mut out = f.debug_struct(label);
        out.field("payload", &format_args!("{}", body));
        if let Some(id) = &ack_id {
            out.field("id", id);
        }
        if let Some(count) = &attached {
            out.field("count", count);
        }
        out.finish()
    }
}
impl Ns<Packet> {
pub fn size_hint(&self) -> usize {
self.1.size_hint(&self.0)
}
pub fn encode(&self) -> String {
self.1.encode(&self.0)
}
}
impl TryFrom<ByteString> for Ns<Packet> {
    type Error = PacketError;

    /// Decodes one Socket.IO text packet and tags it with its namespace.
    ///
    /// The wire form is `<type>[<count>-][<namespace>,][<id>][<payload>]`,
    /// where `<type>` is one digit `0`–`6` (the digits [`Packet::encode`]
    /// writes). Only BINARY_EVENT (`5`) and BINARY_ACK (`6`) may carry the
    /// `<count>-` attachment prefix. When the text after the type/prefix does
    /// not start with `/`, the packet belongs to the root namespace `/`.
    ///
    /// # Errors
    ///
    /// - [`PacketError::Empty`] for an empty input.
    /// - [`PacketError::InvalidId`] for an unknown type digit.
    /// - [`PacketError::UnexpectedAttachments`] when a plain EVENT or ACK has
    ///   an attachment prefix.
    /// - [`PacketError::MissingAttachmentCount`] when a binary packet lacks it.
    /// - [`PacketError::MissingAckId`] when an ACK or BINARY_ACK has no id.
    /// - Any error returned by `split_attachments`, `split_namespace` or `split_id`.
    fn try_from(bytes: ByteString) -> Result<Self, PacketError> {
        let mut chars = bytes.chars();
        let kind = chars.next().ok_or(PacketError::Empty)?;
        let bytes = bytes.slice_ref(chars.as_str());
        let packet = match kind {
            '0' => {
                let (ns, payload) = split_namespace(bytes)?;
                Ns(ns, Packet::Connect(payload))
            }
            '1' => {
                let (ns, _) = split_namespace(bytes)?;
                Ns(ns, Packet::Disconnect)
            }
            '2' => {
                let (count, bytes) = split_attachments(bytes)?;
                if let Some(count) = count {
                    return Err(PacketError::UnexpectedAttachments { count });
                }
                let (ns, bytes) = split_namespace(bytes)?;
                let (id, payload) = split_id(bytes)?;
                Ns(ns, Packet::Event { payload, id })
            }
            '3' => {
                // Like a plain EVENT, a plain ACK must not carry an attachment
                // prefix; without this check the `<count>` digits of `N-` would
                // be misread as the ack id and `-` left at the front of the payload.
                let (count, bytes) = split_attachments(bytes)?;
                if let Some(count) = count {
                    return Err(PacketError::UnexpectedAttachments { count });
                }
                let (ns, bytes) = split_namespace(bytes)?;
                let (id, payload) = split_id(bytes)?;
                let id = id.ok_or(PacketError::MissingAckId)?;
                Ns(ns, Packet::Ack { payload, id })
            }
            '4' => {
                let (ns, payload) = split_namespace(bytes)?;
                Ns(ns, Packet::ConnectError(payload))
            }
            '5' => {
                let (count, bytes) = split_attachments(bytes)?;
                let count = count.ok_or(PacketError::MissingAttachmentCount)?;
                let (ns, bytes) = split_namespace(bytes)?;
                let (id, payload) = split_id(bytes)?;
                Ns(ns, Packet::BinaryEvent { payload, id, count })
            }
            '6' => {
                let (count, bytes) = split_attachments(bytes)?;
                let count = count.ok_or(PacketError::MissingAttachmentCount)?;
                let (ns, bytes) = split_namespace(bytes)?;
                let (id, payload) = split_id(bytes)?;
                let id = id.ok_or(PacketError::MissingAckId)?;
                Ns(ns, Packet::BinaryAck { payload, id, count })
            }
            other => return Err(PacketError::InvalidId { id: other }),
        };
        Ok(packet)
    }
}
/// Maximum number of decimal digits in a `u64` (`u64::MAX` is
/// `18446744073709551615`, i.e. 20 digits).
const U64_MAX_LEN: usize = 20;

/// Bytes reserved for an ack id: a bare decimal `u64` with no separator.
const fn ack_size_hint() -> usize {
    U64_MAX_LEN
}

/// Bytes reserved for the `<count>-` attachment prefix.
///
/// The count is a `usize`; this assumes `usize` is at most 64 bits wide.
const fn binary_size_hint() -> usize {
    let separator = '-'.len_utf8();
    U64_MAX_LEN + separator
}
/// Bytes the namespace prefix occupies on the wire: nothing for the root
/// namespace `/`, otherwise the namespace plus its trailing `,`.
fn namespace_size(ns: &str) -> usize {
    match ns {
        "/" => 0,
        named => named.len() + 1,
    }
}
/// Returns an upper bound on the encoded length of a packet.
///
/// The bound sums one byte for the packet type, the namespace prefix
/// (`namespace_size`), `ack_size_hint()` bytes when `ack` is set,
/// `binary_size_hint()` bytes when `binary` is set, and the payload's byte length.
pub fn hint_packet_size(ns: &str, binary: bool, ack: bool, payload: Option<&str>) -> usize {
    const TYPE_BYTE: usize = 1;
    let ack_digits = if ack { ack_size_hint() } else { 0 };
    let attachment_prefix = if binary { binary_size_hint() } else { 0 };
    let body = payload.map_or(0, str::len);
    TYPE_BYTE + namespace_size(ns) + ack_digits + attachment_prefix + body
}
/// Appends the `<count>-` attachment prefix of a binary packet to `buffer`.
fn write_attachments(buffer: &mut String, count: usize) {
    use std::fmt::Write as _;
    // Format straight into the buffer instead of allocating a temporary
    // `String` via `to_string()`. Writing to a `String` cannot fail, so the
    // `fmt::Result` is safe to ignore.
    let _ = write!(buffer, "{count}-");
}
/// Appends the `<ns>,` prefix to `buffer`; the root namespace `/` is implicit
/// on the wire and writes nothing.
fn write_namespace(buffer: &mut String, ns: &str) {
    if ns == "/" {
        return;
    }
    buffer.push_str(ns);
    buffer.push(',');
}
/// Appends an ack id as bare decimal digits to `buffer`.
fn write_id(buffer: &mut String, id: u64) {
    use std::fmt::Write as _;
    // Format directly into the buffer rather than via a temporary `String`.
    // Writing to a `String` cannot fail, so the `fmt::Result` is safe to ignore.
    let _ = write!(buffer, "{id}");
}
/// Appends the raw payload text to `buffer` unchanged.
fn write_payload(buffer: &mut String, payload: &str) {
    *buffer += payload;
}
/// Appends one packet in Socket.IO text form to `buffer`:
/// `<type>[<count>-][<ns>,][<id>][<payload>]`.
///
/// `type_id` is written as a single character (callers pass ASCII digits).
/// Each optional piece is written only when present; the namespace prefix is
/// omitted for the root namespace `/`.
pub fn write_packet(
    buffer: &mut String,
    type_id: u8,
    count: Option<usize>,
    ns: &str,
    id: Option<u64>,
    payload: Option<&str>,
) {
    buffer.push(char::from(type_id));
    if let Some(attached) = count {
        write_attachments(buffer, attached);
    }
    write_namespace(buffer, ns);
    if let Some(ack_id) = id {
        write_id(buffer, ack_id);
    }
    if let Some(body) = payload {
        write_payload(buffer, body);
    }
}
/// Splits a leading `<count>-` attachment prefix off `bytes`.
///
/// If the first non-digit character is `-`, the digits before it are parsed
/// as the count and the text after the `-` is returned. Otherwise `bytes` is
/// returned unchanged with `None`.
///
/// # Errors
///
/// [`PacketError::InvalidAttachmentCount`] if the digits before `-` are empty
/// or do not fit in a `usize`.
pub fn split_attachments(bytes: ByteString) -> Result<(Option<usize>, ByteString), PacketError> {
    let first_non_digit = bytes.char_indices().find(|&(_, ch)| !ch.is_ascii_digit());
    if let Some((dash, '-')) = first_non_digit {
        let count = bytes[..dash]
            .parse::<usize>()
            .map_err(PacketError::InvalidAttachmentCount)?;
        let rest = bytes.slice_ref(&bytes[dash + 1..]);
        return Ok((Some(count), rest));
    }
    Ok((None, bytes))
}
/// Splits a leading `<namespace>,` prefix off `bytes`.
///
/// Input that does not start with `/` belongs to the root namespace, so `"/"`
/// is returned together with the untouched input. Otherwise everything up to
/// the first `,` is the namespace, and the text after the comma is the rest.
///
/// # Errors
///
/// [`PacketError::MissingNamespaceDelimiter`] if the input starts with `/` but
/// contains no `,`.
pub fn split_namespace(bytes: ByteString) -> Result<(ByteString, ByteString), PacketError> {
    if !bytes.starts_with('/') {
        return Ok((ByteString::from_static("/"), bytes));
    }
    let (namespace, rest) = bytes
        .split_once(',')
        .ok_or(PacketError::MissingNamespaceDelimiter)?;
    Ok((bytes.slice_ref(namespace), bytes.slice_ref(rest)))
}
/// Splits an optional leading decimal ack id off `bytes`.
///
/// Returns `(Some(id), rest)` when `bytes` starts with at least one ASCII
/// digit, where `rest` is everything after the digit run. `rest` may be empty,
/// e.g. for `"12"`. Otherwise returns `(None, bytes)` unchanged.
///
/// # Errors
///
/// [`PacketError::InvalidAckId`] if the digit run does not fit in a `u64`.
pub fn split_id(bytes: ByteString) -> Result<(Option<u64>, ByteString), PacketError> {
    // Use the end of the input as the boundary when every character is a
    // digit, so an id with no payload after it is still recognised as the id
    // rather than misread as payload.
    let end = bytes
        .char_indices()
        .find(|(_, c)| !c.is_ascii_digit())
        .map_or(bytes.len(), |(i, _)| i);
    if end == 0 {
        return Ok((None, bytes));
    }
    let id = bytes[..end]
        .parse::<u64>()
        .map_err(PacketError::InvalidAckId)?;
    let rest = bytes.slice_ref(&bytes[end..]);
    Ok((Some(id), rest))
}