use super::generator::StreamGenerator;
use crate::{
error::Result,
packet::{Packet, PacketId},
Error, Event, Payload,
};
use async_stream::try_stream;
use bytes::Bytes;
use futures_util::{Stream, StreamExt};
use std::{
fmt::Debug,
pin::Pin,
sync::{
atomic::{AtomicBool, Ordering},
Arc,
},
};
use tf_rust_engineio::{
asynchronous::Client as EngineClient, Packet as EnginePacket, PacketId as EnginePacketId,
};
#[derive(Clone)]
pub(crate) struct Socket {
engine_client: Arc<EngineClient>,
connected: Arc<AtomicBool>,
generator: StreamGenerator<Packet>,
}
impl Socket {
pub(super) fn new(engine_client: EngineClient) -> Result<Self> {
let connected = Arc::new(AtomicBool::default());
Ok(Socket {
engine_client: Arc::new(engine_client.clone()),
connected: connected.clone(),
generator: StreamGenerator::new(Self::stream(engine_client, connected)),
})
}
pub async fn connect(&self) -> Result<()> {
self.engine_client.connect().await?;
self.connected.store(true, Ordering::Release);
Ok(())
}
pub async fn disconnect(&self) -> Result<()> {
if self.is_engineio_connected() {
self.engine_client.disconnect().await?;
}
if self.connected.load(Ordering::Acquire) {
self.connected.store(false, Ordering::Release);
}
Ok(())
}
pub async fn send(&self, packet: Packet) -> Result<()> {
if !self.is_engineio_connected() || !self.connected.load(Ordering::Acquire) {
return Err(Error::IllegalActionBeforeOpen());
}
let engine_packet = EnginePacket::new(EnginePacketId::Message, Bytes::from(&packet));
self.engine_client.emit(engine_packet).await?;
if let Some(attachments) = packet.attachments {
for attachment in attachments {
let engine_packet = EnginePacket::new(EnginePacketId::MessageBinary, attachment);
self.engine_client.emit(engine_packet).await?;
}
}
Ok(())
}
pub async fn emit(&self, nsp: &str, event: Event, data: Payload) -> Result<()> {
let socket_packet = Packet::new_from_payload(data, event, nsp, None)?;
self.send(socket_packet).await
}
pub async fn ack(&self, nsp: &str, data: Payload, ack_id: Option<i32>) -> Result<()> {
let socket_packet = Packet::ack_from_payload(data, nsp, ack_id)?;
self.send(socket_packet).await
}
fn stream(
client: EngineClient,
is_connected: Arc<AtomicBool>,
) -> Pin<Box<impl Stream<Item = Result<Packet>> + Send>> {
Box::pin(try_stream! {
for await received_data in client.clone() {
let packet = received_data?;
if packet.packet_id == EnginePacketId::Message
|| packet.packet_id == EnginePacketId::MessageBinary
{
let packet = Self::handle_engineio_packet(packet, client.clone()).await?;
Self::handle_socketio_packet(&packet, is_connected.clone());
yield packet;
}
}
})
}
#[inline]
fn handle_socketio_packet(socket_packet: &Packet, is_connected: Arc<AtomicBool>) {
match socket_packet.packet_type {
PacketId::Connect => {
is_connected.store(true, Ordering::Release);
}
PacketId::ConnectError => {
is_connected.store(false, Ordering::Release);
}
PacketId::Disconnect => {
is_connected.store(false, Ordering::Release);
}
_ => (),
}
}
async fn handle_engineio_packet(
packet: EnginePacket,
mut client: EngineClient,
) -> Result<Packet> {
let mut socket_packet = Packet::try_from(&packet.data)?;
if socket_packet.attachment_count > 0 {
let mut attachments_left = socket_packet.attachment_count;
let mut attachments = Vec::new();
while attachments_left > 0 {
let next = client.next().await.unwrap();
match next {
Err(err) => return Err(err.into()),
Ok(packet) => match packet.packet_id {
EnginePacketId::MessageBinary | EnginePacketId::Message => {
attachments.push(packet.data);
attachments_left -= 1;
}
_ => {
return Err(Error::InvalidAttachmentPacketType(
packet.packet_id.into(),
));
}
},
}
}
socket_packet.attachments = Some(attachments);
}
Ok(socket_packet)
}
fn is_engineio_connected(&self) -> bool {
self.engine_client.is_connected()
}
}
impl Stream for Socket {
type Item = Result<Packet>;
fn poll_next(
mut self: Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
) -> std::task::Poll<Option<Self::Item>> {
self.generator.poll_next_unpin(cx)
}
}
impl Debug for Socket {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("Socket")
.field("engine_client", &self.engine_client)
.field("connected", &self.connected)
.finish()
}
}