pub(crate) mod backend;
use {
crate::WebSocketRuntime,
aeronet_io::{
AeronetIoPlugin, IoSystems, Session,
connection::{DROP_DISCONNECT_REASON, Disconnect},
packet::{IP_MTU, RecvPacket},
},
bevy_app::prelude::*,
bevy_ecs::prelude::*,
bevy_platform::time::Instant,
bytes::Bytes,
core::num::Saturating,
derive_more::{Display, Error},
futures::channel::{mpsc, oneshot},
std::io,
tracing::{trace, trace_span},
};
cfg_if::cfg_if! {
if #[cfg(target_family = "wasm")] {
type ConnectionError = crate::JsError;
type SendError = crate::JsError;
} else {
use futures::never::Never;
type ConnectionError = crate::tungstenite::Error;
type SendError = Never;
}
}
/// Internal plugin that wires WebSocket sessions into the `aeronet_io` pipeline.
///
/// On build it:
/// - adds [`AeronetIoPlugin`] if no one else has added it yet;
/// - on native (non-WASM) targets, tries to install `aws-lc-rs` as the default rustls
///   crypto provider;
/// - initializes the [`WebSocketRuntime`] resource;
/// - registers `poll` in [`PreUpdate`] ([`IoSystems::Poll`]), `flush` in [`PostUpdate`]
///   ([`IoSystems::Flush`]), and the `on_disconnect` observer.
pub(crate) struct WebSocketSessionPlugin;

impl Plugin for WebSocketSessionPlugin {
    fn build(&self, app: &mut App) {
        // Another IO plugin may already have pulled in `AeronetIoPlugin`, and Bevy refuses
        // to add the same unique plugin twice, so only add it when it is missing.
        if !app.is_plugin_added::<AeronetIoPlugin>() {
            app.add_plugins(AeronetIoPlugin);
        }

        // Native builds use rustls. Installing the default provider fails if some other
        // code installed one first, which is fine: we only log which case happened.
        #[cfg(not(target_family = "wasm"))]
        {
            use tracing::debug;

            match rustls::crypto::aws_lc_rs::default_provider().install_default() {
                Ok(_) => debug!("Installed default `aws-lc-rs` CryptoProvider"),
                Err(_) => debug!("CryptoProvider is already installed"),
            }
        }

        app.init_resource::<WebSocketRuntime>();
        app.add_systems(PreUpdate, poll.in_set(IoSystems::Poll));
        app.add_systems(PostUpdate, flush.in_set(IoSystems::Flush));
        app.add_observer(on_disconnect);
    }
}
/// ECS component holding the frontend ends of a WebSocket session's channels.
///
/// Inserting this component also requires a [`Session`], which is created at the current
/// [`Instant`] with this module's [`MTU`].
///
/// Each frame, the `poll` system moves packets from `rx_packet_b2f` into [`Session::recv`],
/// and the `flush` system pushes the contents of [`Session::send`] into `tx_packet_f2b`.
///
/// When this value is dropped, [`DROP_DISCONNECT_REASON`] is sent through `tx_user_dc`,
/// unless a [`Disconnect`] has already consumed that sender.
#[derive(Debug, Component)]
#[require(Session::new(Instant::now(), MTU))]
pub struct WebSocketIo {
    /// Packets coming from the backend (backend-to-frontend), drained by `poll`.
    pub(crate) rx_packet_b2f: mpsc::UnboundedReceiver<RecvPacket>,
    /// Outgoing packet payloads for the backend (frontend-to-backend), fed by `flush`.
    pub(crate) tx_packet_f2b: mpsc::UnboundedSender<Bytes>,
    /// One-shot sender for the disconnect reason. It is taken, leaving `None`, the first
    /// time a reason is sent (by the `on_disconnect` observer or by `Drop`).
    pub(crate) tx_user_dc: Option<oneshot::Sender<String>>,
}
/// MTU, in bytes, that every [`WebSocketIo`] session's [`Session`] is created with.
///
/// Computed as [`IP_MTU`] minus a fixed 60 + 40 + 14 byte allowance for the layers that
/// wrap each WebSocket message.
/// NOTE(review): the split is presumably a maximum TCP header (60), an IPv6 header (40) and
/// a maximum WebSocket frame header (14). Confirm before relying on it.
pub const MTU: usize = IP_MTU - (60 + 40 + 14);
/// Error that can occur on a WebSocket session.
///
/// The payloads of [`SessionError::Connection`] and [`SessionError::Send`] depend on the
/// platform. On WASM targets both are `crate::JsError`. On native targets `Connection`
/// wraps a `tungstenite` error and `Send` wraps the uninhabited `futures::never::Never`,
/// so `Send` can never actually be constructed there.
#[derive(Debug, Display, Error)]
#[non_exhaustive]
pub enum SessionError {
    /// The frontend side of the session was closed.
    #[display("frontend closed")]
    FrontendClosed,
    /// The backend side of the session was closed.
    #[display("backend closed")]
    BackendClosed,
    /// Reading the local socket address failed.
    #[display("failed to get local socket address")]
    GetLocalAddr(io::Error),
    /// Reading the peer socket address failed.
    #[display("failed to get peer socket address")]
    GetPeerAddr(io::Error),
    /// The stream of incoming messages ended.
    #[display("receiver stream closed")]
    RecvStreamClosed,
    /// The underlying connection reported an error.
    #[display("connection lost")]
    Connection(ConnectionError),
    /// The connection was closed with the given close code. `#[error(not(source))]` stops
    /// the plain `u16` from being treated as an error source.
    #[display("connection closed with code {_0}")]
    Closed(#[error(not(source))] u16),
    /// The peer disconnected without providing a reason.
    #[display("peer disconnected without reason")]
    DisconnectedWithoutReason,
    /// Sending data failed.
    #[display("failed to send data")]
    Send(SendError),
}
impl Drop for WebSocketIo {
    /// Reports [`DROP_DISCONNECT_REASON`] through the one-shot disconnect channel, unless
    /// the sender was already consumed (for example by the `on_disconnect` observer).
    ///
    /// The send is best-effort: if the receiving end is gone, the error is ignored.
    fn drop(&mut self) {
        let Some(sender) = self.tx_user_dc.take() else {
            return;
        };
        _ = sender.send(DROP_DISCONNECT_REASON.to_owned());
    }
}
/// Frontend ends of a session's channels, laid out like the fields of [`WebSocketIo`]
/// except that the disconnect sender is not wrapped in `Option`.
///
/// NOTE(review): presumably produced by the `backend` module and turned into a
/// `WebSocketIo` once the connection is set up. Confirm against the constructor.
#[derive(Debug)]
pub(crate) struct SessionFrontend {
    /// Packets coming from the backend (backend-to-frontend).
    pub rx_packet_b2f: mpsc::UnboundedReceiver<RecvPacket>,
    /// Outgoing packet payloads for the backend (frontend-to-backend).
    pub tx_packet_f2b: mpsc::UnboundedSender<Bytes>,
    /// One-shot sender for the user's disconnect reason.
    pub tx_user_dc: oneshot::Sender<String>,
}
/// Observer for [`Disconnect`]: sends the disconnect reason through the session's one-shot
/// `tx_user_dc` channel (presumably read by the backend).
///
/// Does nothing if the target entity has no [`WebSocketIo`], or if the sender was already
/// consumed. A failed send (receiver gone) is ignored.
fn on_disconnect(trigger: On<Disconnect>, mut sessions: Query<&mut WebSocketIo>) {
    let target = trigger.event_target();
    let Some(sender) = sessions
        .get_mut(target)
        .ok()
        .and_then(|mut io| io.tx_user_dc.take())
    else {
        return;
    };
    _ = sender.send(trigger.reason.clone());
}
/// Moves every packet the backend has queued for each session into [`Session::recv`].
///
/// Registered in [`PreUpdate`] under [`IoSystems::Poll`]. Draining stops as soon as
/// `try_next` yields something other than a packet (the channel is empty or closed).
/// Each received packet bumps the session's `packets_recv` and `bytes_recv` stats, and a
/// trace event summarizes the frame when anything arrived.
pub(crate) fn poll(mut sessions: Query<(Entity, &mut Session, &mut WebSocketIo)>) {
    for (entity, mut session, mut io) in &mut sessions {
        let span = trace_span!("poll", %entity);
        let _entered = span.enter();

        let mut packet_count = Saturating(0);
        let mut byte_count = Saturating(0);
        // `session` is only dereferenced mutably inside the loop, so a frame with no
        // incoming packets leaves the `Session` untouched.
        while let Ok(Some(packet)) = io.rx_packet_b2f.try_next() {
            let len = packet.payload.len();
            packet_count += 1;
            byte_count += len;
            session.stats.packets_recv += 1;
            session.stats.bytes_recv += len;
            session.recv.push(packet);
        }

        if packet_count.0 > 0 {
            trace!(num_packets = %packet_count, num_bytes = %byte_count, "Received packets");
        }
    }
}
fn flush(mut sessions: Query<(Entity, &mut Session, &WebSocketIo)>) {
for (entity, mut session, io) in &mut sessions {
let span = trace_span!("flush", %entity);
let _span = span.enter();
let session = &mut *session;
let mut num_packets = Saturating(0);
let mut num_bytes = Saturating(0);
for packet in session.send.drain(..) {
num_packets += 1;
session.stats.packets_sent += 1;
num_bytes += packet.len();
session.stats.bytes_sent += packet.len();
_ = io.tx_packet_f2b.unbounded_send(packet);
}
if num_packets.0 > 0 {
trace!(%num_packets, %num_bytes, "Flushed packets");
}
}
}