use core::cell::{Cell, RefCell};
use core::fmt;
use core::marker::PhantomData;
use core::mem;
use core::ops::Deref;
use core::ptr::NonNull;
use std::collections::VecDeque;
use alloc::boxed::Box;
use alloc::format;
use alloc::rc::Rc;
use alloc::rc::Weak;
use alloc::string::{String, ToString};
use alloc::vec::Vec;
use std::collections::hash_map::{Entry, HashMap};
use musli::Decode;
use musli::alloc::Global;
use musli::mode::Binary;
use musli::reader::SliceReader;
use musli::storage;
use slab::Slab;
use crate::api::{self, ChannelId, Event, MessageId};
/// Capacity, in bytes, that reusable buffers are shrunk back to after use.
///
/// Applied to the shared output buffer after every send and to recycled receive
/// buffers, so one unusually large message does not keep that much memory alive.
const MAX_CAPACITY: usize = 1048576;
/// Placeholder body of a request builder before `RequestBuilder::body` is called.
#[non_exhaustive]
pub struct EmptyBody;
/// Callback that ignores its input; the default callback of the builders.
#[non_exhaustive]
pub struct EmptyCallback;
/// Type-erased callback stored for a pending request.
///
/// Every implementation can receive an error. The success path is exposed through
/// `as_request` (ordinary responses) or `as_channel` (responses to `CONNECT`). Both
/// default to `None`; `Shared::message` reports an "unexpected response" error to the
/// callback when the needed accessor returns `None`.
trait RequestCallback {
    /// Delivers `error` to the wrapped callback.
    fn error(&self, error: Error);
    /// Returns the callback for an ordinary request response, if this is one.
    fn as_request(&self) -> Option<&(dyn Callback<Result<RawPacket, Error>> + 'static)> {
        None
    }
    /// Returns the callback for a channel-open (`CONNECT`) response, if this is one.
    fn as_channel(&self) -> Option<&(dyn Fn(Result<ChannelId, Error>) + 'static)> {
        None
    }
}
/// State-change listeners, keyed by the slab index stored in `StateListener`.
type StateListeners = Slab<Rc<dyn Callback<State>>>;
/// Broadcast listeners grouped by broadcast id, keyed by the slab index in `Listener`.
type Broadcasts = HashMap<MessageId, Slab<Rc<dyn Callback<Result<RawPacket>>>>>;
/// Pool of recycled receive buffers.
type Buffers = VecDeque<Box<BufData>>;
/// Requests awaiting a response, keyed by request serial.
type Requests = HashMap<u32, Box<Pending<dyn RequestCallback>>>;
/// Parts of the current page location, used to build a same-host socket URL.
///
/// `Shared::build` matches `protocol` against `"https:"` / `"http:"`, so it is
/// expected to include the trailing colon.
#[doc(hidden)]
pub struct Location {
    pub(crate) protocol: String,
    pub(crate) host: String,
    pub(crate) port: String,
}
/// Crate-private supertrait that seals `SocketImpl`.
pub(crate) mod sealed_socket {
    pub trait Sealed {}
}
/// Platform abstraction over the underlying socket.
pub(crate) trait SocketImpl
where
    Self: Sized + self::sealed_socket::Sealed,
{
    /// Platform handles passed to `new`; provided by `WebImpl::handles`.
    #[doc(hidden)]
    type Handles;
    /// Creates a socket for `url`; `Shared::build` calls this on every (re)connect.
    #[doc(hidden)]
    fn new(url: &str, handles: &Self::Handles) -> Result<Self, Error>;
    /// Sends `data`; called once per encoded request frame.
    #[doc(hidden)]
    fn send(&self, data: &[u8]) -> Result<(), Error>;
    /// Closes the socket, consuming it.
    #[doc(hidden)]
    fn close(self) -> Result<(), Error>;
}
/// Crate-private supertrait that seals `PerformanceImpl`.
pub(crate) mod sealed_performance {
    pub trait Sealed {}
}
/// Platform clock used to measure how long a connection has been open.
pub trait PerformanceImpl
where
    Self: Sized + self::sealed_performance::Sealed,
{
    /// Returns the current time.
    ///
    /// The difference of two readings is compared against `250.0` in
    /// `Shared::is_open_for_a_while` (presumably milliseconds — confirm).
    #[doc(hidden)]
    fn now(&self) -> f64;
}
/// Crate-private supertrait that seals `WindowImpl`.
pub(crate) mod sealed_window {
    pub trait Sealed {}
}
/// Platform window abstraction: clock, page location and timers.
pub(crate) trait WindowImpl
where
    Self: Sized + self::sealed_window::Sealed,
{
    /// Clock returned by `performance`.
    #[doc(hidden)]
    type Performance: PerformanceImpl;
    /// Handle to a scheduled timeout; `Shared` keeps the latest in `reconnect_timeout`.
    #[doc(hidden)]
    type Timeout;
    /// Acquires the window. `ServiceBuilder::build` panics if this fails.
    #[doc(hidden)]
    fn new() -> Result<Self, Error>;
    /// Returns the clock. `ServiceBuilder::build` panics if this fails.
    #[doc(hidden)]
    fn performance(&self) -> Result<Self::Performance, Error>;
    /// Returns the current page location; used for `Connect::location` targets.
    #[doc(hidden)]
    fn location(&self) -> Result<Location, Error>;
    /// Schedules `callback` to run once after `millis` milliseconds.
    #[doc(hidden)]
    fn set_timeout(
        &self,
        millis: u32,
        callback: impl FnOnce() + 'static,
    ) -> Result<Self::Timeout, Error>;
}
/// Crate-private supertrait that seals `WebImpl`.
pub(crate) mod sealed_web {
    pub trait Sealed {}
}
/// Platform binding tying together the window, socket and randomness source.
pub trait WebImpl
where
    Self: 'static + Copy + Sized + self::sealed_web::Sealed,
{
    #[doc(hidden)]
    #[allow(private_bounds)]
    type Window: WindowImpl;
    #[doc(hidden)]
    type Handles;
    #[doc(hidden)]
    #[allow(private_bounds)]
    type Socket: SocketImpl<Handles = Self::Handles>;
    /// Creates the socket handles of a service from a weak reference to its state.
    ///
    /// Called inside `Rc::new_cyclic` in `ServiceBuilder::build`, so `shared` cannot be
    /// upgraded during this call.
    #[doc(hidden)]
    #[allow(private_interfaces)]
    fn handles(shared: &Weak<Shared<Self>>) -> Self::Handles;
    /// Returns a random number bounded by `range` (presumably `0..range` — confirm);
    /// used as jitter for the reconnect delay.
    #[doc(hidden)]
    fn random(range: u32) -> u32;
}
/// Starts configuring a service that reaches its server as described by `connect`.
///
/// The builder starts without an error callback; add one with
/// `ServiceBuilder::on_error` and finish with `ServiceBuilder::build`.
pub fn connect<H>(connect: Connect) -> ServiceBuilder<H, EmptyCallback>
where
    H: WebImpl,
{
    ServiceBuilder {
        _marker: PhantomData::<H>,
        on_error: EmptyCallback,
        connect,
    }
}
/// Connection state reported to state-change listeners.
#[derive(Debug, PartialEq, Eq, Clone, Copy)]
#[non_exhaustive]
pub enum State {
    /// The connection was marked open by `Shared::set_open` (e.g. on `SERVER_HELLO`).
    Open,
    /// No usable connection; the initial state and the state after a close.
    Closed,
}
impl State {
pub fn is_open(&self) -> bool {
matches!(self, Self::Open)
}
}
/// Error produced by the web service: transport, encoding/decoding and server errors.
#[derive(Debug)]
pub struct Error {
    kind: ErrorKind,
}
impl Error {
    /// Returns `true` if this error comes from decoding an empty packet.
    pub fn is_empty_packet(&self) -> bool {
        matches!(&self.kind, ErrorKind::EmptyPacket)
    }

    /// Builds an error carrying the `Display` output of `message`.
    pub fn message(message: impl fmt::Display) -> Self {
        let text = message.to_string();
        Self::new(ErrorKind::Message(text))
    }
}
/// The different failure modes behind an [`Error`].
#[derive(Debug)]
enum ErrorKind {
    /// Decoding was attempted on a packet whose id is `MessageId::EMPTY`.
    EmptyPacket,
    /// Free-form message.
    Message(String),
    /// The header of an incoming message could not be decoded.
    DecodeResponseHeader(storage::Error),
    /// The error payload of an incoming message could not be decoded.
    DecodeErrorMessage(storage::Error),
    /// The body of a packet could not be decoded.
    DecodePacket(storage::Error),
    /// An outgoing request header could not be encoded.
    EncodingHeader(storage::Error),
    /// An outgoing request body could not be encoded.
    EncodingBody(storage::Error),
    /// The packet cursor (first) lies past the packet length (second).
    Overflow(usize, usize),
}
impl Error {
#[inline]
fn new(kind: ErrorKind) -> Self {
Self { kind }
}
#[inline]
pub(crate) fn decode_response_header(error: storage::Error) -> Self {
Self::new(ErrorKind::DecodeResponseHeader(error))
}
#[inline]
pub(crate) fn decode_error_message(error: storage::Error) -> Self {
Self::new(ErrorKind::DecodeErrorMessage(error))
}
#[inline]
pub(crate) fn decode_packet(error: storage::Error) -> Self {
Self::new(ErrorKind::DecodePacket(error))
}
#[inline]
pub(crate) fn encoding_header(error: storage::Error) -> Self {
Self::new(ErrorKind::EncodingHeader(error))
}
#[inline]
pub(crate) fn encoding_body(error: storage::Error) -> Self {
Self::new(ErrorKind::EncodingBody(error))
}
#[inline]
pub(crate) fn msg(message: impl fmt::Display) -> Self {
Self::new(ErrorKind::Message(message.to_string()))
}
}
impl fmt::Display for Error {
#[inline]
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match &self.kind {
ErrorKind::EmptyPacket => write!(f, "Packet is empty"),
ErrorKind::Message(message) => write!(f, "{message}"),
ErrorKind::DecodeResponseHeader(..) => {
write!(f, "Encoding error when decoding response header")
}
ErrorKind::DecodeErrorMessage(..) => {
write!(f, "Encoding error when decoding error response")
}
ErrorKind::DecodePacket(..) => write!(f, "Encoding error when decoding packet"),
ErrorKind::EncodingHeader(..) => write!(f, "Encoding error when encoding header"),
ErrorKind::EncodingBody(..) => write!(f, "Encoding error when encoding body"),
ErrorKind::Overflow(at, len) => {
write!(f, "Internal packet overflow, {at} not in range 0-{len}")
}
}
}
}
impl core::error::Error for Error {
#[inline]
fn source(&self) -> Option<&(dyn core::error::Error + 'static)> {
match &self.kind {
ErrorKind::DecodeResponseHeader(error) => Some(error),
ErrorKind::DecodeErrorMessage(error) => Some(error),
ErrorKind::DecodePacket(error) => Some(error),
ErrorKind::EncodingHeader(error) => Some(error),
ErrorKind::EncodingBody(error) => Some(error),
_ => None,
}
}
}
/// Converts a JavaScript exception/value into a message error using its `Debug` form.
#[cfg(feature = "wasm_bindgen02")]
impl From<wasm_bindgen02::JsValue> for Error {
    #[inline]
    fn from(error: wasm_bindgen02::JsValue) -> Self {
        let rendered = format!("{:?}", error);
        Error::new(ErrorKind::Message(rendered))
    }
}
/// Crate-local `Result` whose error type defaults to [`Error`].
type Result<T, E = Error> = core::result::Result<T, E>;
/// First reconnect delay, in the milliseconds passed to `WindowImpl::set_timeout`.
const INITIAL_TIMEOUT: u32 = 250;
/// Cap on the reconnect delay, which roughly doubles after short-lived connections.
const MAX_TIMEOUT: u32 = 4000;
/// How the socket URL is determined.
#[derive(Debug)]
enum ConnectKind {
    /// Same host as the current page, with the given path.
    Location { path: String },
    /// An explicit URL, used verbatim.
    Url { url: String },
}
/// Describes where a service connects to; see [`Connect::location`] and [`Connect::url`].
pub struct Connect {
    kind: ConnectKind,
}
impl Connect {
#[inline]
pub fn location(path: impl AsRef<str>) -> Self {
Self {
kind: ConnectKind::Location {
path: String::from(path.as_ref()),
},
}
}
#[inline]
pub fn url(url: String) -> Self {
Self {
kind: ConnectKind::Url { url },
}
}
}
impl fmt::Debug for Connect {
    /// Formats as the underlying connection kind.
    #[inline]
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        fmt::Debug::fmt(&self.kind, f)
    }
}
/// Part of the shared state that does not depend on the platform type `H`.
///
/// `Request`, `Listener`, `StateListener` and `BufData` hold a `Weak<Generic>` rather
/// than a reference to `Shared<H>` (presumably so they need no `H` parameter).
struct Generic {
    /// State-change listeners.
    state_listeners: RefCell<StateListeners>,
    /// Requests awaiting a response, by serial.
    requests: RefCell<Requests>,
    /// Broadcast listeners, by broadcast id.
    broadcasts: RefCell<Broadcasts>,
    /// Recycled receive buffers.
    buffers: RefCell<Buffers>,
}
/// State of one service, shared between the service, its handles and the socket callbacks.
///
/// Created by `ServiceBuilder::build`. `Service` holds the `Rc`; `Handle` and `Channel`
/// hold `Weak` references.
pub(crate) struct Shared<H>
where
    H: WebImpl,
{
    /// How to reach the server.
    connect: Connect,
    /// Last state emitted to state listeners; starts as `Closed`.
    state: Cell<State>,
    /// Receives errors that are not delivered to a specific request callback.
    pub(crate) on_error: Box<dyn Callback<Error>>,
    /// Platform window: page location and reconnect timers.
    window: H::Window,
    /// Clock used to time how long the connection has been open.
    performance: <H::Window as WindowImpl>::Performance,
    /// Platform handles passed to `SocketImpl::new` on every (re)connect.
    handles: H::Handles,
    /// Timestamp of the last `set_open`; `None` while not open.
    opened: Cell<Option<f64>>,
    /// Serial to assign to the next request.
    serial: Cell<u32>,
    /// Snapshot of the broadcast callbacks for the message being dispatched, so they
    /// are invoked without a borrow of the broadcast registry.
    defer_broadcasts: RefCell<VecDeque<Weak<dyn Callback<Result<RawPacket>>>>>,
    /// Snapshot of the state listeners for the state change being dispatched.
    defer_state_listeners: RefCell<VecDeque<Weak<dyn Callback<State>>>>,
    /// Currently installed socket, if any.
    pub(crate) socket: RefCell<Option<H::Socket>>,
    /// Reusable buffer that outgoing frames are encoded into.
    output: RefCell<Vec<u8>>,
    /// Current reconnect delay.
    current_timeout: Cell<u32>,
    /// Handle of the scheduled reconnect; replaced on every close.
    reconnect_timeout: RefCell<Option<<H::Window as WindowImpl>::Timeout>>,
    /// Registries and buffer pool shared with the non-generic handle types.
    g: Rc<Generic>,
}
impl<H> Drop for Shared<H>
where
    H: WebImpl,
{
    /// Closes the socket, then notifies state listeners (`Closed`) and fails every
    /// pending request with "Websocket service closed".
    fn drop(&mut self) {
        if let Some(s) = self.socket.take()
            && let Err(e) = s.close()
        {
            self.on_error.call(e);
        }

        // Move listeners and pending requests out *before* running any callback.
        // `self.g` is only dropped after this function returns, so `Request` /
        // `StateListener` handles can still upgrade their `Weak<Generic>`. A callback
        // that drops one of them borrows these `RefCell`s mutably, which would panic
        // if a borrow were still held across the call.
        let state_listeners = mem::take(&mut *self.g.state_listeners.borrow_mut());
        let requests = mem::take(&mut *self.g.requests.borrow_mut());

        for (_, listener) in state_listeners {
            listener.call(State::Closed);
        }

        for (_, p) in requests {
            p.callback.error(Error::msg("Websocket service closed"));
        }
    }
}
/// Builder for a [`Service`]; created by [`connect`].
pub struct ServiceBuilder<H, E>
where
    H: WebImpl,
{
    connect: Connect,
    /// Callback receiving errors not tied to a specific request.
    on_error: E,
    _marker: PhantomData<H>,
}
impl<H, E> ServiceBuilder<H, E>
where
H: WebImpl,
E: Callback<Error>,
{
pub fn on_error<U>(self, on_error: U) -> ServiceBuilder<H, U>
where
U: Callback<Error>,
{
ServiceBuilder {
connect: self.connect,
on_error,
_marker: self._marker,
}
}
pub fn build(self) -> Service<H> {
let window = match H::Window::new() {
Ok(window) => window,
Err(error) => {
panic!("{error}")
}
};
let performance = match WindowImpl::performance(&window) {
Ok(performance) => performance,
Err(error) => {
panic!("{error}")
}
};
let shared = Rc::<Shared<H>>::new_cyclic(move |shared| Shared {
connect: self.connect,
state: Cell::new(State::Closed),
on_error: Box::new(self.on_error),
window,
performance,
handles: H::handles(shared),
opened: Cell::new(None),
serial: Cell::new(0),
defer_broadcasts: RefCell::new(VecDeque::new()),
defer_state_listeners: RefCell::new(VecDeque::new()),
socket: RefCell::new(None),
output: RefCell::new(Vec::new()),
current_timeout: Cell::new(INITIAL_TIMEOUT),
reconnect_timeout: RefCell::new(None),
g: Rc::new(Generic {
state_listeners: RefCell::new(Slab::new()),
broadcasts: RefCell::new(HashMap::new()),
requests: RefCell::new(Requests::new()),
buffers: RefCell::new(VecDeque::new()),
}),
});
let handle = Handle {
shared: Rc::downgrade(&shared),
};
Service { shared, handle }
}
}
/// A web service client; owns the strong reference to the shared state.
///
/// The [`Handle`] obtained from it, and everything derived from that handle, hold only
/// weak references.
pub struct Service<H>
where
    H: WebImpl,
{
    shared: Rc<Shared<H>>,
    handle: Handle<H>,
}
impl<H> Service<H>
where
H: WebImpl,
{
pub fn connect(&self) {
self.shared.connect()
}
pub fn handle(&self) -> &Handle<H> {
&self.handle
}
}
impl<H> Shared<H>
where
H: WebImpl,
{
fn send_client_request<T>(&self, serial: u32, channel: ChannelId, body: &T) -> Result<()>
where
T: api::Request,
{
let Some(ref socket) = *self.socket.borrow() else {
return Err(Error::msg("Socket is not connected"));
};
let header = api::RequestHeader {
serial,
id: <T::Endpoint as api::Endpoint>::ID.get(),
channel,
};
let out = &mut *self.output.borrow_mut();
storage::to_writer(&mut *out, &header).map_err(Error::encoding_header)?;
storage::to_writer(&mut *out, &body).map_err(Error::encoding_body)?;
tracing::debug!(
header.serial,
?header.id,
len = out.len(),
"Sending request"
);
socket.send(out.as_slice())?;
out.clear();
out.shrink_to(MAX_CAPACITY);
Ok(())
}
fn send_connect(&self, serial: u32) -> Result<()> {
let Some(ref socket) = *self.socket.borrow() else {
return Err(Error::msg("Socket is not connected"));
};
let header = api::RequestHeader {
serial,
id: MessageId::CONNECT.get(),
channel: ChannelId::NONE,
};
let out = &mut *self.output.borrow_mut();
storage::to_writer(&mut *out, &header).map_err(Error::encoding_header)?;
tracing::debug!(
header.serial,
?header.id,
len = out.len(),
"Sending request"
);
socket.send(out.as_slice())?;
out.clear();
out.shrink_to(MAX_CAPACITY);
Ok(())
}
fn remove_channel(&self, channel: ChannelId) {
if let Err(error) = self._send_disconnect(channel) {
self.on_error.call(error);
}
}
fn _send_disconnect(&self, channel: ChannelId) -> Result<()> {
let Some(ref socket) = *self.socket.borrow() else {
return Err(Error::msg("Socket is not connected"));
};
let header = api::RequestHeader {
serial: 0,
id: MessageId::DISCONNECT.get(),
channel,
};
let out = &mut *self.output.borrow_mut();
storage::to_writer(&mut *out, &header).map_err(Error::encoding_header)?;
tracing::debug!(
header.serial,
?header.id,
len = out.len(),
"Sending request"
);
socket.send(out.as_slice())?;
out.clear();
out.shrink_to(MAX_CAPACITY);
Ok(())
}
pub(crate) fn next_buffer(self: &Rc<Self>, needed: usize) -> Box<BufData> {
match self.g.buffers.borrow_mut().pop_front() {
Some(mut buf) => {
if buf.data.capacity() < needed {
buf.data.reserve(needed - buf.data.len());
}
buf
}
None => Box::new(BufData::with_capacity(Rc::downgrade(&self.g), needed)),
}
}
pub(crate) fn message(self: &Rc<Self>, buf: Box<BufData>) -> Result<()> {
let buf = BufRc::new(buf);
let mut reader = SliceReader::new(&buf);
let header: api::ResponseHeader =
storage::decode(&mut reader).map_err(Error::decode_response_header)?;
if let Some(broadcast) = MessageId::new(header.broadcast) {
tracing::debug!(?header, "Got broadcast");
if broadcast == MessageId::SERVER_HELLO {
self.set_open();
return Ok(());
}
if !self.defer_broadcasts(broadcast) {
return Ok(());
};
if let Some(id) = MessageId::new(header.error) {
let error = match id {
MessageId::ERROR_MESSAGE => {
storage::decode(&mut reader).map_err(Error::decode_error_message)?
}
_ => api::ErrorMessage {
message: "Unsupported broadcast",
},
};
while let Some(callback) = self.defer_broadcasts.borrow_mut().pop_front() {
if let Some(callback) = callback.upgrade() {
callback.call(Err(Error::msg(format_args!(
"Server error: {}",
error.message
))));
}
}
return Ok(());
}
let at = buf.len().saturating_sub(reader.remaining());
let packet = RawPacket {
buf: Some(buf.clone()),
at: Cell::new(at),
id: broadcast,
channel: header.channel,
};
while let Some(callback) = self.defer_broadcasts.borrow_mut().pop_front() {
if let Some(callback) = callback.upgrade() {
callback.call(Ok(packet.clone()));
}
}
} else {
tracing::debug!(?header, "Got response");
let p = self.g.requests.borrow_mut().remove(&header.serial);
if let Some(p) = p {
if let Some(id) = MessageId::new(header.error) {
let error = match id {
MessageId::ERROR_MESSAGE => {
storage::decode(&mut reader).map_err(Error::decode_error_message)?
}
_ => api::ErrorMessage {
message: "Unsupported request",
},
};
p.callback
.error(Error::msg(format_args!("Server error: {}", error.message)));
return Ok(());
}
match p.id {
MessageId::CONNECT => {
let Some(callback) = &p.callback.as_channel() else {
p.callback.error(Error::msg("Unexpected channel response"));
return Ok(());
};
callback(Ok(header.channel));
return Ok(());
}
_ => {
let Some(callback) = p.callback.as_request() else {
p.callback.error(Error::msg("Unexpected channel response"));
return Ok(());
};
let at = buf.len().saturating_sub(reader.remaining());
let packet = RawPacket {
id: p.id,
buf: Some(buf),
at: Cell::new(at),
channel: header.channel,
};
callback.call(Ok(packet));
return Ok(());
}
}
}
tracing::warn!(?header.serial, "Got message with unknown serial");
}
Ok(())
}
fn defer_broadcasts(self: &Rc<Self>, kind: MessageId) -> bool {
let mut defer = self.defer_broadcasts.borrow_mut();
let broadcasts = self.g.broadcasts.borrow();
let Some(broadcasts) = broadcasts.get(&kind) else {
return false;
};
for (_, callback) in broadcasts.iter() {
defer.push_back(Rc::downgrade(callback));
}
!defer.is_empty()
}
pub(crate) fn set_open(&self) {
tracing::debug!("Set open");
self.opened
.set(Some(PerformanceImpl::now(&self.performance)));
self.emit_state_change(State::Open);
}
fn is_open_for_a_while(&self) -> bool {
let Some(at) = self.opened.get() else {
return false;
};
let now = PerformanceImpl::now(&self.performance);
(now - at) >= 250.0
}
pub(crate) fn close(self: &Rc<Self>) -> Result<(), Error> {
tracing::debug!("Close connection");
let shared = Rc::downgrade(self);
tracing::debug!(
"Set closed timeout={}, opened={:?}",
self.current_timeout.get(),
self.opened.get(),
);
if !self.is_open_for_a_while() {
let current_timeout = self.current_timeout.get();
if current_timeout < MAX_TIMEOUT {
let fuzz = H::random(50);
self.current_timeout.set(
current_timeout
.saturating_mul(2)
.saturating_add(fuzz)
.min(MAX_TIMEOUT),
);
}
} else {
self.current_timeout.set(INITIAL_TIMEOUT);
}
self.opened.set(None);
self.emit_state_change(State::Closed);
if let Some(s) = self.socket.take() {
s.close()?;
}
self.close_pending();
let timeout = self
.window
.set_timeout(self.current_timeout.get(), move || {
if let Some(shared) = shared.upgrade() {
Self::connect(&shared);
}
})?;
drop(self.reconnect_timeout.borrow_mut().replace(timeout));
Ok(())
}
fn close_pending(self: &Rc<Self>) {
loop {
let Some(serial) = self.g.requests.borrow().keys().next().copied() else {
break;
};
let p = {
let mut requests = self.g.requests.borrow_mut();
let Some(p) = requests.remove(&serial) else {
break;
};
p
};
p.callback.error(Error::msg("Connection closed"));
}
}
fn emit_state_change(&self, state: State) {
if self.state.get() == state {
return;
}
self.state.set(state);
{
let mut defer = self.defer_state_listeners.borrow_mut();
for (_, callback) in self.g.state_listeners.borrow().iter() {
defer.push_back(Rc::downgrade(callback));
}
if defer.is_empty() {
return;
}
}
while let Some(callback) = self.defer_state_listeners.borrow_mut().pop_front() {
if let Some(callback) = callback.upgrade() {
callback.call(state);
}
}
}
fn is_closed(&self) -> bool {
self.opened.get().is_none()
}
fn connect(self: &Rc<Self>) {
tracing::debug!("Connect");
if let Err(e) = self.build() {
self.on_error.call(e);
} else {
return;
}
if let Err(e) = self.close() {
self.on_error.call(e);
}
}
fn build(self: &Rc<Self>) -> Result<()> {
let url = match &self.connect.kind {
ConnectKind::Location { path } => {
let Location {
protocol,
host,
port,
} = WindowImpl::location(&self.window)?;
let protocol = match protocol.as_str() {
"https:" => "wss:",
"http:" => "ws:",
other => {
return Err(Error::msg(format_args!(
"Same host connection is not supported for protocol `{other}`"
)));
}
};
let path = ForcePrefix(path, '/');
format!("{protocol}//{host}:{port}{path}")
}
ConnectKind::Url { url } => url.clone(),
};
let ws = SocketImpl::new(&url, &self.handles)?;
let old = self.socket.borrow_mut().replace(ws);
if let Some(old) = old {
old.close()?;
}
Ok(())
}
}
/// A `'static` callback that receives values of type `I`.
///
/// Implemented for [`EmptyCallback`] and for every `'static` closure `Fn(I)`.
pub trait Callback<I>
where
    Self: 'static,
{
    /// Invokes the callback with `input`.
    fn call(&self, input: I);
}
impl<I> Callback<I> for EmptyCallback {
    /// Discards `input`.
    #[inline]
    fn call(&self, input: I) {
        let _ = input;
    }
}
/// Any `'static` closure taking `I` is a [`Callback<I>`].
impl<F, I> Callback<I> for F
where
    F: 'static + Fn(I),
{
    #[inline]
    fn call(&self, input: I) {
        (self)(input)
    }
}
/// Builder for opening a server channel; created by [`Handle::channel`].
pub struct ChannelBuilder<'a, H, C>
where
    H: WebImpl,
{
    shared: &'a Weak<Shared<H>>,
    /// Called with the opened channel or an error.
    callback: C,
}
impl<'a, H, C> ChannelBuilder<'a, H, C>
where
    H: WebImpl,
{
    /// Sets the callback that receives the opened [`Channel`] or an error.
    pub fn on_open<U>(self, callback: U) -> ChannelBuilder<'a, H, U>
    where
        U: Callback<Result<Channel<H>, Error>>,
    {
        let shared = self.shared;
        ChannelBuilder { callback, shared }
    }
}
impl<'a, H, C> ChannelBuilder<'a, H, C>
where
    C: Callback<Result<Channel<H>, Error>>,
    H: WebImpl,
{
    /// Asks the server to open a channel by sending a `CONNECT` request.
    ///
    /// The callback is invoked with the resulting [`Channel`] when the server answers.
    /// It is invoked with an error if the service is gone, the socket is not open, or
    /// the pending request is failed or replaced. If sending the request itself fails,
    /// the error goes to the service's `on_error` callback and this callback is not called.
    ///
    /// Dropping the returned [`Request`] removes the pending request.
    pub fn send(self) -> Request {
        // Adapts the channel callback to the type-erased `RequestCallback` kept in the
        // request map; only the channel (`as_channel`) success path is exposed.
        struct RequestCallbackImpl<C>(C);
        impl<C> RequestCallback for RequestCallbackImpl<C>
        where
            C: Fn(Result<ChannelId, Error>) + 'static,
        {
            #[inline]
            fn as_channel(&self) -> Option<&(dyn Fn(Result<ChannelId, Error>) + 'static)> {
                Some(&self.0)
            }
            #[inline]
            fn error(&self, error: Error) {
                (self.0)(Err(error));
            }
        }
        let Some(shared) = self.shared.upgrade() else {
            self.callback
                .call(Err(Error::msg("WebSocket service is down")));
            return Request::new();
        };
        if shared.is_closed() {
            self.callback
                .call(Err(Error::msg("WebSocket is not connected")));
            return Request::new();
        }
        // The serial is only advanced after the request was sent successfully.
        let serial = shared.serial.get();
        if let Err(error) = shared.send_connect(serial) {
            shared.on_error.call(error);
            return Request::new();
        }
        shared.serial.set(serial.wrapping_add(1));
        // Turns the server-assigned channel id into a `Channel` that holds a weak
        // reference back to the service.
        let callback = {
            let shared = Rc::downgrade(&shared);
            move |result| {
                let result = match result {
                    Err(error) => Err(error),
                    Ok(id) => Ok(Channel {
                        shared: shared.clone(),
                        id,
                    }),
                };
                self.callback.call(result)
            }
        };
        let pending = Pending {
            id: MessageId::CONNECT,
            serial,
            callback: RequestCallbackImpl(callback),
        };
        // A serial collision (serials wrap around) replaces the older entry; fail the
        // replaced request so its callback is not silently dropped.
        let existing = shared
            .g
            .requests
            .borrow_mut()
            .insert(serial, Box::new(pending));
        if let Some(p) = existing {
            p.callback.error(Error::msg("Request cancelled"));
        }
        Request {
            serial,
            g: Rc::downgrade(&shared.g),
        }
    }
}
/// Builder for a request; created by [`Handle::request`] or [`Channel::request`].
pub struct RequestBuilder<'a, H, B, C>
where
    H: WebImpl,
{
    shared: &'a Weak<Shared<H>>,
    /// Target channel; `None` makes `send` fail with "request over closed channel".
    channel: Option<ChannelId>,
    /// Request body.
    body: B,
    /// Called with the response packet or an error.
    callback: C,
}
impl<'a, H, B, C> RequestBuilder<'a, H, B, C>
where
    H: WebImpl,
{
    /// Sets the request body; its endpoint determines the request id.
    #[inline]
    pub fn body<U>(self, body: U) -> RequestBuilder<'a, H, U, C>
    where
        U: api::Request,
    {
        let RequestBuilder {
            shared,
            channel,
            callback,
            ..
        } = self;

        RequestBuilder {
            shared,
            channel,
            body,
            callback,
        }
    }

    /// Sets a callback that receives the response wrapped as a typed [`Packet<E>`].
    pub fn on_packet<E>(
        self,
        callback: impl Callback<Result<Packet<E>>>,
    ) -> RequestBuilder<'a, H, B, impl Callback<Result<RawPacket>>>
    where
        E: api::Endpoint,
    {
        self.on_raw_packet(move |result: Result<RawPacket>| {
            callback.call(result.map(Packet::new))
        })
    }

    /// Sets a callback that receives the raw response packet.
    pub fn on_raw_packet<U>(self, callback: U) -> RequestBuilder<'a, H, B, U>
    where
        U: Callback<Result<RawPacket, Error>>,
    {
        let RequestBuilder {
            shared,
            channel,
            body,
            ..
        } = self;

        RequestBuilder {
            shared,
            channel,
            body,
            callback,
        }
    }
}
impl<'a, H, B, C> RequestBuilder<'a, H, B, C>
where
    B: api::Request,
    C: Callback<Result<RawPacket>>,
    H: WebImpl,
{
    /// Sends the request.
    ///
    /// The callback receives the raw response packet. It receives an error instead if
    /// the service is gone, the socket is not open, the builder has no channel, or the
    /// server or connection fails the request. If encoding or sending fails, the error
    /// goes to the service's `on_error` callback and the request callback is not called.
    ///
    /// Dropping the returned [`Request`] removes the pending request.
    pub fn send(self) -> Request {
        // Adapts the response callback to the type-erased `RequestCallback` kept in the
        // request map; only the regular-response (`as_request`) path is exposed.
        struct RequestCallbackImpl<C>(C);
        impl<C> RequestCallback for RequestCallbackImpl<C>
        where
            C: Callback<Result<RawPacket>>,
        {
            #[inline]
            fn as_request(&self) -> Option<&(dyn Callback<Result<RawPacket>> + 'static)> {
                Some(&self.0)
            }
            #[inline]
            fn error(&self, error: Error) {
                self.0.call(Err(error));
            }
        }
        let Some(shared) = self.shared.upgrade() else {
            self.callback
                .call(Err(Error::msg("WebSocket service is down")));
            return Request::new();
        };
        if shared.is_closed() {
            self.callback
                .call(Err(Error::msg("WebSocket is not connected")));
            return Request::new();
        }
        // `Channel::request` leaves the channel unset when the channel id is `NONE`.
        let Some(channel) = self.channel else {
            self.callback
                .call(Err(Error::msg("WebSocket request over closed channel")));
            return Request::new();
        };
        // The serial is only advanced after the request was sent successfully.
        let serial = shared.serial.get();
        if let Err(error) = shared.send_client_request(serial, channel, &self.body) {
            shared.on_error.call(error);
            return Request::new();
        }
        shared.serial.set(serial.wrapping_add(1));
        let pending = Pending {
            id: <B::Endpoint as api::Endpoint>::ID,
            serial,
            callback: RequestCallbackImpl(self.callback),
        };
        // A serial collision (serials wrap around) replaces the older entry; fail the
        // replaced request so its callback is not silently dropped.
        let existing = shared
            .g
            .requests
            .borrow_mut()
            .insert(serial, Box::new(pending));
        if let Some(p) = existing {
            p.callback.error(Error::msg("Request cancelled"));
        }
        Request {
            serial,
            g: Rc::downgrade(&shared.g),
        }
    }
}
/// Handle to a pending request.
///
/// Dropping it, or calling [`Request::clear`], removes the pending entry so its
/// callback is no longer invoked.
pub struct Request {
    /// Serial of the tracked request.
    serial: u32,
    /// Registry holding the pending entry; empty when nothing is tracked.
    g: Weak<Generic>,
}
impl Request {
#[inline]
pub const fn new() -> Self {
Self {
serial: 0,
g: Weak::new(),
}
}
pub fn clear(&mut self) {
let removed = {
let serial = mem::take(&mut self.serial);
let Some(g) = self.g.upgrade() else {
return;
};
self.g = Weak::new();
let Some(p) = g.requests.borrow_mut().remove(&serial) else {
return;
};
p
};
drop(removed);
}
#[inline]
pub fn is_pending(&self) -> bool {
let Some(g) = self.g.upgrade() else {
return false;
};
g.requests.borrow().contains_key(&self.serial)
}
}
impl Default for Request {
#[inline]
fn default() -> Self {
Self::new()
}
}
impl Drop for Request {
#[inline]
fn drop(&mut self) {
self.clear();
}
}
/// Handle to a broadcast listener registration.
///
/// Dropping it, or calling [`Listener::clear`], unregisters the callback.
pub struct Listener {
    /// Broadcast id the callback is registered under.
    kind: Option<MessageId>,
    /// Slab index of the callback within that broadcast's listeners.
    index: usize,
    /// Registry holding the listener; empty when nothing is registered.
    g: Weak<Generic>,
}
impl Listener {
#[inline]
pub const fn new() -> Self {
Self {
kind: None,
index: 0,
g: Weak::new(),
}
}
#[inline]
pub(crate) const fn empty_with_kind(kind: MessageId) -> Self {
Self {
kind: Some(kind),
index: 0,
g: Weak::new(),
}
}
pub fn clear(&mut self) {
let removed;
let removed_value;
{
let Some(g) = self.g.upgrade() else {
return;
};
self.g = Weak::new();
let index = mem::take(&mut self.index);
let Some(kind) = self.kind.take() else {
return;
};
let mut broadcasts = g.broadcasts.borrow_mut();
let Entry::Occupied(mut e) = broadcasts.entry(kind) else {
return;
};
removed = e.get_mut().try_remove(index);
if e.get().is_empty() {
removed_value = Some(e.remove());
} else {
removed_value = None;
}
}
drop(removed);
drop(removed_value);
}
}
impl Default for Listener {
#[inline]
fn default() -> Self {
Self::new()
}
}
impl Drop for Listener {
#[inline]
fn drop(&mut self) {
self.clear();
}
}
/// Handle to a state-change listener registration.
///
/// Dropping it, or calling [`StateListener::clear`], unregisters the callback.
pub struct StateListener {
    /// Slab index of the callback.
    index: usize,
    /// Registry holding the listener; empty when nothing is registered.
    g: Weak<Generic>,
}
impl StateListener {
#[inline]
pub const fn new() -> Self {
Self {
index: 0,
g: Weak::new(),
}
}
pub fn clear(&mut self) {
let removed = {
let Some(g) = self.g.upgrade() else {
return;
};
self.g = Weak::new();
g.state_listeners.borrow_mut().try_remove(self.index)
};
drop(removed);
}
}
impl Default for StateListener {
#[inline]
fn default() -> Self {
Self::new()
}
}
impl Drop for StateListener {
#[inline]
fn drop(&mut self) {
self.clear();
}
}
/// Heap buffer for one received message, shared through `BufRc`.
///
/// When the last `BufRc` goes away the buffer is returned to `g`'s pool, or freed if
/// the service is gone.
pub(crate) struct BufData {
    /// Message bytes.
    pub(crate) data: Vec<u8>,
    /// Number of live `BufRc` handles.
    strong: Cell<usize>,
    /// Owner of the pool this buffer is recycled into.
    g: Weak<Generic>,
}
impl BufData {
fn with_capacity(g: Weak<Generic>, capacity: usize) -> Self {
Self {
data: Vec::with_capacity(capacity),
strong: Cell::new(0),
g,
}
}
unsafe fn dec(ptr: NonNull<BufData>) {
unsafe {
let count = ptr.as_ref().strong.get().wrapping_sub(1);
ptr.as_ref().strong.set(count);
if count > 0 {
return;
}
let mut buf = Box::from_raw(ptr.as_ptr());
let Some(g) = buf.as_ref().g.upgrade() else {
return;
};
let mut buffers = g.buffers.borrow_mut();
buf.data.set_len(buf.data.len().min(MAX_CAPACITY));
buf.data.shrink_to(MAX_CAPACITY);
buffers.push_back(buf);
}
}
unsafe fn inc(ptr: NonNull<BufData>) {
unsafe {
let count = ptr.as_ref().strong.get().wrapping_add(1);
if count == 0 {
std::process::abort();
}
ptr.as_ref().strong.set(count);
}
}
}
/// Reference-counted handle to a leaked `BufData`.
///
/// The count lives in a `Cell`, so handles are single-threaded. Dropping the last
/// handle recycles or frees the buffer via `BufData::dec`.
struct BufRc {
    data: NonNull<BufData>,
}
impl BufRc {
    /// Takes ownership of `data`, leaks it and returns the first counted handle.
    fn new(data: Box<BufData>) -> Self {
        let data = NonNull::from(Box::leak(data));
        // SAFETY: `data` was just leaked from a box and is valid; the increment accounts
        // for the handle returned below.
        unsafe {
            BufData::inc(data);
        }
        Self { data }
    }
}
impl Deref for BufRc {
    type Target = [u8];
    fn deref(&self) -> &Self::Target {
        // SAFETY: this handle holds a counted reference, so the buffer stays alive at
        // least as long as `&self`.
        unsafe { &(*self.data.as_ptr()).data }
    }
}
impl Clone for BufRc {
    fn clone(&self) -> Self {
        // SAFETY: `self` keeps the buffer alive; the increment accounts for the clone.
        unsafe {
            BufData::inc(self.data);
        }
        Self { data: self.data }
    }
}
impl Drop for BufRc {
    fn drop(&mut self) {
        // SAFETY: this handle owns one counted reference, which is given up here.
        unsafe {
            BufData::dec(self.data);
        }
    }
}
/// An undecoded message: a shared buffer plus a read cursor.
///
/// Cloning shares the buffer; each clone has its own cursor.
#[derive(Clone)]
pub struct RawPacket {
    /// Request endpoint id or broadcast id; `MessageId::EMPTY` for empty packets.
    id: MessageId,
    /// Backing buffer; `None` for [`RawPacket::empty`].
    buf: Option<BufRc>,
    /// Offset of the next byte to decode.
    at: Cell<usize>,
    /// Channel the message belongs to.
    channel: ChannelId,
}
impl RawPacket {
pub const fn empty() -> Self {
Self {
id: MessageId::EMPTY,
buf: None,
at: Cell::new(0),
channel: ChannelId::NONE,
}
}
#[inline]
pub fn channel(&self) -> ChannelId {
self.channel
}
pub fn decode<'this, T>(&'this self) -> Result<T>
where
T: Decode<'this, Binary, Global>,
{
let at = self.at.get();
if self.id == MessageId::EMPTY {
return Err(Error::new(ErrorKind::EmptyPacket));
}
let Some(bytes) = self.as_slice().get(at..) else {
return Err(Error::new(ErrorKind::Overflow(at, self.as_slice().len())));
};
let mut reader = SliceReader::new(bytes);
match storage::decode(&mut reader) {
Ok(value) => {
self.at.set(at + bytes.len() - reader.remaining());
Ok(value)
}
Err(error) => {
self.at.set(self.len());
Err(Error::decode_packet(error))
}
}
}
pub fn as_slice(&self) -> &[u8] {
match &self.buf {
Some(buf) => buf.as_ref(),
None => &[],
}
}
pub fn len(&self) -> usize {
match &self.buf {
Some(buf) => buf.len(),
None => 0,
}
}
pub fn is_empty(&self) -> bool {
self.at.get() >= self.len()
}
pub fn id(&self) -> MessageId {
self.id
}
}
/// A [`RawPacket`] tagged with the endpoint or broadcast type `T` it belongs to.
#[derive(Clone)]
pub struct Packet<T> {
    raw: RawPacket,
    _marker: PhantomData<T>,
}
impl<T> Packet<T> {
pub const fn empty() -> Self {
Self {
raw: RawPacket::empty(),
_marker: PhantomData,
}
}
#[inline]
pub fn channel(&self) -> ChannelId {
self.raw.channel()
}
#[inline]
pub fn new(raw: RawPacket) -> Self {
Self {
raw,
_marker: PhantomData,
}
}
pub fn into_raw(self) -> RawPacket {
self.raw
}
pub fn is_empty(&self) -> bool {
self.raw.is_empty()
}
pub fn id(&self) -> MessageId {
self.raw.id()
}
}
impl<T> Packet<T>
where
T: api::Decodable,
{
pub fn decode(&self) -> Result<T::Type<'_>> {
self.decode_any()
}
pub fn decode_any<'de, R>(&'de self) -> Result<R>
where
R: Decode<'de, Binary, Global>,
{
self.raw.decode()
}
}
impl<T> Packet<T>
where
T: api::Endpoint,
{
pub fn decode_response(&self) -> Result<T::Response<'_>> {
self.decode_any_response()
}
pub fn decode_any_response<'de, R>(&'de self) -> Result<R>
where
R: Decode<'de, Binary, Global>,
{
self.raw.decode()
}
}
impl<T> Packet<T>
where
T: api::Broadcast,
{
pub fn decode_event<'de>(&'de self) -> Result<T::Event<'de>>
where
T: api::BroadcastWithEvent,
{
self.decode_event_any()
}
pub fn decode_event_any<'de, E>(&'de self) -> Result<E>
where
E: Event<Broadcast = T> + Decode<'de, Binary, Global>,
{
self.raw.decode()
}
}
/// Cloneable handle to a service: sends requests and registers listeners.
///
/// Holds only a weak reference, so it does not keep the service alive.
pub struct Handle<H>
where
    H: WebImpl,
{
    shared: Weak<Shared<H>>,
}
impl<H> Handle<H>
where
    H: WebImpl,
{
    /// Starts building a request that opens a new server channel.
    pub fn channel(&self) -> ChannelBuilder<'_, H, EmptyCallback> {
        let shared = &self.shared;
        ChannelBuilder {
            callback: EmptyCallback,
            shared,
        }
    }

    /// Starts building a request on the default channel (`ChannelId::NONE`).
    pub fn request(&self) -> RequestBuilder<'_, H, EmptyBody, EmptyCallback> {
        RequestBuilder {
            callback: EmptyCallback,
            body: EmptyBody,
            channel: Some(ChannelId::NONE),
            shared: &self.shared,
        }
    }

    /// Registers `callback` for broadcast `T`, receiving typed packets.
    ///
    /// The registration lasts as long as the returned [`Listener`].
    pub fn on_broadcast<T>(&self, callback: impl Callback<Result<Packet<T>>>) -> Listener
    where
        T: api::Broadcast,
    {
        self.on_raw_broadcast::<T>(move |result: Result<RawPacket>| {
            callback.call(result.map(Packet::new))
        })
    }

    /// Registers `callback` for broadcast `T`, receiving raw packets.
    ///
    /// If the service is gone, an unregistered listener is returned.
    pub fn on_raw_broadcast<T>(&self, callback: impl Callback<Result<RawPacket>>) -> Listener
    where
        T: api::Broadcast,
    {
        let Some(shared) = self.shared.upgrade() else {
            return Listener::empty_with_kind(T::ID);
        };

        let index = shared
            .g
            .broadcasts
            .borrow_mut()
            .entry(T::ID)
            .or_default()
            .insert(Rc::new(callback));

        Listener {
            g: Rc::downgrade(&shared.g),
            index,
            kind: Some(T::ID),
        }
    }

    /// Registers a state-change callback.
    ///
    /// Returns the current state together with the registration handle. If the service
    /// is gone, `Closed` and an unregistered handle are returned.
    pub fn on_state_change(&self, callback: impl Callback<State>) -> (State, StateListener) {
        match self.shared.upgrade() {
            None => (State::Closed, StateListener::new()),
            Some(shared) => {
                let index = shared
                    .g
                    .state_listeners
                    .borrow_mut()
                    .insert(Rc::new(callback));

                let listener = StateListener {
                    g: Rc::downgrade(&shared.g),
                    index,
                };

                (shared.state.get(), listener)
            }
        }
    }
}
impl<H> Clone for Handle<H>
where
H: WebImpl,
{
#[inline]
fn clone(&self) -> Self {
Self {
shared: self.shared.clone(),
}
}
}
impl<H> Default for Handle<H>
where
H: WebImpl,
{
#[inline]
fn default() -> Self {
Self {
shared: Weak::new(),
}
}
}
impl<H> PartialEq for Handle<H>
where
H: WebImpl,
{
#[inline]
fn eq(&self, other: &Self) -> bool {
Weak::ptr_eq(&self.shared, &other.shared)
}
}
impl<H> fmt::Debug for Handle<H>
where
H: WebImpl,
{
#[inline]
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
let mut f = f.debug_struct("Handle");
if let Some(shared) = self.shared.upgrade() {
f.field("connect", &shared.connect);
f.field("state", &shared.state.get());
}
f.finish()
}
}
/// A server channel opened through [`Handle::channel`].
///
/// Dropping it sends a `DISCONNECT` for its id while the service still exists.
pub struct Channel<H>
where
    H: WebImpl,
{
    shared: Weak<Shared<H>>,
    /// Server-assigned channel id.
    id: ChannelId,
}
impl<H> Channel<H>
where
    H: WebImpl,
{
    /// Server-assigned channel id.
    #[inline]
    pub fn id(&self) -> ChannelId {
        self.id
    }

    /// Starts building a request on this channel.
    ///
    /// A channel whose id is `ChannelId::NONE` yields a builder whose `send` fails.
    pub fn request(&self) -> RequestBuilder<'_, H, EmptyBody, EmptyCallback> {
        let channel = if self.id == ChannelId::NONE {
            None
        } else {
            Some(self.id)
        };

        RequestBuilder {
            shared: &self.shared,
            channel,
            body: EmptyBody,
            callback: EmptyCallback,
        }
    }
}
impl<H> Default for Channel<H>
where
H: WebImpl,
{
#[inline]
fn default() -> Self {
Self {
shared: Weak::new(),
id: ChannelId::NONE,
}
}
}
impl<H> PartialEq for Channel<H>
where
H: WebImpl,
{
#[inline]
fn eq(&self, other: &Self) -> bool {
Weak::ptr_eq(&self.shared, &other.shared) && self.id == other.id
}
}
impl<H> Drop for Channel<H>
where
H: WebImpl,
{
#[inline]
fn drop(&mut self) {
if let Some(shared) = self.shared.upgrade() {
shared.remove_channel(self.id);
}
}
}
impl<H> fmt::Debug for Channel<H>
where
H: WebImpl,
{
#[inline]
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
let mut f = f.debug_struct("Channel");
if let Some(shared) = self.shared.upgrade() {
f.field("connect", &shared.connect);
f.field("state", &shared.state.get());
}
f.field("id", &self.id);
f.finish()
}
}
/// A request awaiting its response; stored type-erased in the request map.
struct Pending<C>
where
    C: ?Sized,
{
    /// Message id of the request; `CONNECT` selects the channel-open response path.
    id: MessageId,
    /// Serial the response is matched by.
    serial: u32,
    /// Callback to notify with the response or an error.
    callback: C,
}
/// Shows the serial and id; the callback is omitted.
///
/// Also implemented for unsized `C`, so the `Pending<dyn RequestCallback>` values
/// actually stored in the request map can be debugged.
impl<C> fmt::Debug for Pending<C>
where
    C: ?Sized,
{
    #[inline]
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_struct("Pending")
            .field("serial", &self.serial)
            .field("id", &self.id)
            .finish_non_exhaustive()
    }
}
/// Displays `.0` starting with exactly one `.1`.
///
/// All leading occurrences of `.1` are collapsed into a single one, and one is added
/// if none was present. Used to normalize URL paths to begin with a single `/`.
struct ForcePrefix<'a>(&'a str, char);

impl fmt::Display for ForcePrefix<'_> {
    #[inline]
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let ForcePrefix(text, prefix) = *self;
        let rest = text.trim_start_matches(prefix);
        fmt::Display::fmt(&prefix, f)?;
        fmt::Display::fmt(rest, f)
    }
}