use crate::{
buffer::Buffer,
codec::Codec,
engine::EngineId,
event::{Receiver, Sender},
protocol::{PacketRecepients, ProtocolPacketData},
replication::Replicable,
};
use serde::{Deserialize, Serialize};
use std::{
any::{Any, TypeId},
error::Error,
hash::Hash,
io::Cursor,
pin::Pin,
sync::{Arc, Mutex},
task::{Context, Poll},
};
/// Identifier of a channel, wrapping a single `u64`.
///
/// Equality, ordering and hashing are all derived from the wrapped value.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, Serialize, Deserialize)]
pub struct ChannelId(u64);

impl ChannelId {
    /// Wraps an explicit raw identifier.
    pub const fn new(id: u64) -> Self {
        ChannelId(id)
    }

    /// Derives an identifier by feeding `item` to the crate-level `hash` function.
    pub fn hashed<T: Hash>(item: &T) -> Self {
        let raw = crate::hash(item);
        ChannelId(raw)
    }

    /// Derives an identifier from the name of type `T`, as reported by
    /// `std::any::type_name`, hashed through [`ChannelId::hashed`].
    ///
    /// NOTE(review): the standard library documents `type_name` output as
    /// best-effort and not guaranteed to be stable between compiler versions —
    /// confirm this is acceptable wherever ids from different builds must match.
    pub fn typed<T: 'static>() -> Self {
        let type_name = std::any::type_name::<T>();
        Self::hashed(&type_name)
    }

    /// Returns the wrapped raw value.
    pub const fn id(&self) -> u64 {
        let Self(raw) = *self;
        raw
    }
}
/// Replication of a [`ChannelId`] is delegated entirely to the wrapped `u64`.
impl Replicable for ChannelId {
    /// Forwards to the wrapped value's `collect_changes` and returns its result.
    fn collect_changes(&self, buffer: &mut Buffer) -> Result<(), Box<dyn Error>> {
        let Self(raw) = self;
        raw.collect_changes(buffer)
    }

    /// Forwards to the wrapped value's `apply_changes` and returns its result.
    fn apply_changes(&mut self, buffer: &mut Buffer) -> Result<(), Box<dyn Error>> {
        let Self(raw) = self;
        raw.apply_changes(buffer)
    }
}
impl std::fmt::Display for ChannelId {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "#channel:{}", self.0)
}
}
/// A message bundled with its routing metadata.
///
/// Values are assembled builder-style: start from [`Dispatch::new`] and chain
/// the consuming methods below.
#[derive(Debug)]
pub struct Dispatch<Message: Send + 'static> {
    /// The payload itself.
    pub message: Message,
    /// Engine ids accumulated through the recepient builder methods.
    pub recepients: PacketRecepients,
    /// Engine id recorded as the sender, if one was set.
    pub sender: Option<EngineId>,
}

impl<Message: Send + 'static> Dispatch<Message> {
    /// Creates a dispatch for `message` with default recepients and no sender.
    pub fn new(message: Message) -> Self {
        let recepients = Default::default();
        Self {
            message,
            recepients,
            sender: None,
        }
    }

    /// Appends a single engine id to the recepients.
    pub fn recepient(mut self, engine_id: EngineId) -> Self {
        self.recepients.push(engine_id);
        self
    }

    /// Appends `engine_id` to the recepients when it is `Some`; a `None`
    /// leaves the dispatch untouched.
    pub fn maybe_recepient(self, engine_id: Option<EngineId>) -> Self {
        match engine_id {
            Some(engine_id) => self.recepient(engine_id),
            None => self,
        }
    }

    /// Extends the recepients with every id yielded by `engine_ids`.
    pub fn recepients(mut self, engine_ids: impl IntoIterator<Item = EngineId>) -> Self {
        self.recepients.extend(engine_ids);
        self
    }

    /// Sets the sender to `engine_id`, replacing any previous value.
    pub fn sender(self, engine_id: EngineId) -> Self {
        self.maybe_sender(Some(engine_id))
    }

    /// Overwrites the sender with `engine_id`; passing `None` clears it.
    pub fn maybe_sender(mut self, engine_id: Option<EngineId>) -> Self {
        self.sender = engine_id;
        self
    }
}
impl<Message: Send + 'static> From<Message> for Dispatch<Message> {
fn from(message: Message) -> Self {
Self::new(message)
}
}
/// Mode selector for a channel.
///
/// NOTE(review): nothing in the visible part of this file reads this enum; the
/// variant names suggest delivery guarantees (reliability / ordering) — confirm
/// against the code that consumes it.
///
/// The derived `PartialOrd`/`Ord` follow the declaration order of the variants.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, Serialize, Deserialize)]
pub enum ChannelMode {
ReliableOrdered,
ReliableUnordered,
Unreliable,
}
/// Tag recording which constructor produced a `Channel`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, Serialize, Deserialize)]
pub enum ChannelKind {
// Assigned by `Channel::read`.
Read,
// Assigned by `Channel::write`.
Write,
}
/// Pump that moves items between a packet queue and a typed message queue.
///
/// The concrete message type and codec are captured inside the boxed closures,
/// so `Channel` itself carries no generic parameters.
pub struct Channel {
// Which constructor (`read` or `write`) built this channel.
kind: ChannelKind,
// `std::any::type_name` of the message type; exposed by `message_type_name()`
// and shown in the `Debug` output.
message_type_name: &'static str,
// `TypeId` of the message type.
message_type_id: TypeId,
// Transfers at most one item; yields `true` when an item was transferred.
pump: Box<dyn FnMut() -> Result<bool, Box<dyn Error>> + Send>,
// Transfers every item yielded by the input receiver's `iter()`; yields the count.
pump_all: Box<dyn FnMut() -> Result<usize, Box<dyn Error>> + Send>,
}
impl Channel {
pub fn read<C: Codec<Value = Message>, Message: Send + 'static>(
packet_receiver: Receiver<ProtocolPacketData>,
message_sender: Sender<Dispatch<Message>>,
) -> Self {
struct State<Message: Send + 'static> {
packet_receiver: Receiver<ProtocolPacketData>,
message_sender: Sender<Dispatch<Message>>,
}
let state_pump = Arc::new(Mutex::new(State {
packet_receiver,
message_sender,
})) as Arc<Mutex<dyn Any + Send>>;
let state_pump_all = state_pump.clone();
let pump = Box::new(move || -> Result<bool, Box<dyn Error>> {
let mut state = state_pump
.lock()
.map_err(|_| "Failed to lock read channel state")?;
let state = state
.downcast_mut::<State<Message>>()
.ok_or("Failed to downcast read channel state")?;
if let Some(ProtocolPacketData {
data,
recepients,
sender,
}) = state.packet_receiver.try_recv()
{
let mut buffer = Cursor::new(data);
let message = C::decode(&mut buffer)?;
state
.message_sender
.send(Dispatch {
message,
recepients,
sender,
})
.map_err(|err| format!("Pump message sender error: {err}"))?;
Ok(true)
} else {
Ok(false)
}
});
let pump_all = Box::new(move || -> Result<usize, Box<dyn Error>> {
let mut count = 0;
let mut state = state_pump_all
.lock()
.map_err(|_| "Failed to lock read channel state")?;
let state = state
.downcast_mut::<State<Message>>()
.ok_or("Failed to downcast read channel state")?;
for ProtocolPacketData {
data,
recepients,
sender,
} in state.packet_receiver.iter()
{
let mut buffer = Cursor::new(data);
let message = C::decode(&mut buffer)?;
state
.message_sender
.send(Dispatch {
message,
recepients,
sender,
})
.map_err(|err| format!("Pump-all message sender error: {err}"))?;
count += 1;
}
Ok(count)
});
Self {
kind: ChannelKind::Read,
message_type_name: std::any::type_name::<Message>(),
message_type_id: TypeId::of::<Message>(),
pump,
pump_all,
}
}
pub fn write<C: Codec<Value = Message>, Message: Send + 'static>(
packet_sender: Sender<ProtocolPacketData>,
message_receiver: Receiver<Dispatch<Message>>,
) -> Self {
struct State<Message: Send + 'static> {
packet_sender: Sender<ProtocolPacketData>,
message_receiver: Receiver<Dispatch<Message>>,
}
let state_pump = Arc::new(Mutex::new(State {
packet_sender,
message_receiver,
})) as Arc<Mutex<dyn Any + Send>>;
let state_pump_all = state_pump.clone();
let pump = Box::new(move || -> Result<bool, Box<dyn Error>> {
let mut state = state_pump
.lock()
.map_err(|_| "Failed to lock write channel state")?;
let state = state
.downcast_mut::<State<Message>>()
.ok_or("Failed to downcast write channel state")?;
if let Some(Dispatch {
message,
recepients,
sender,
}) = state.message_receiver.try_recv()
{
let mut buffer = Cursor::new(Vec::new());
C::encode(&message, &mut buffer)?;
state
.packet_sender
.send(ProtocolPacketData {
data: buffer.into_inner(),
recepients,
sender,
})
.map_err(|err| format!("Pump packet sender error: {err}"))?;
Ok(true)
} else {
Ok(false)
}
});
let pump_all = Box::new(move || -> Result<usize, Box<dyn Error>> {
let mut count = 0;
let mut state = state_pump_all
.lock()
.map_err(|_| "Failed to lock write channel state")?;
let state = state
.downcast_mut::<State<Message>>()
.ok_or("Failed to downcast write channel state")?;
for Dispatch {
message,
recepients,
sender,
} in state.message_receiver.iter()
{
let mut buffer = Cursor::new(Vec::new());
C::encode(&message, &mut buffer)?;
state
.packet_sender
.send(ProtocolPacketData {
data: buffer.into_inner(),
recepients,
sender,
})
.map_err(|err| format!("Pump-all packet sender error: {err}"))?;
count += 1;
}
Ok(count)
});
Self {
kind: ChannelKind::Write,
message_type_name: std::any::type_name::<Message>(),
message_type_id: TypeId::of::<Message>(),
pump,
pump_all,
}
}
pub fn kind(&self) -> ChannelKind {
self.kind
}
pub fn message_type_name(&self) -> &'static str {
self.message_type_name
}
pub fn message_type_id(&self) -> TypeId {
self.message_type_id
}
pub fn pump(&mut self) -> Result<bool, Box<dyn Error>> {
(self.pump)()
}
pub fn pump_all(&mut self) -> Result<usize, Box<dyn Error>> {
(self.pump_all)()
}
}
/// Debug output shows the channel kind and message type information; the
/// boxed pump closures are left out.
impl std::fmt::Debug for Channel {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self {
            kind,
            message_type_name,
            message_type_id,
            ..
        } = self;
        let mut out = f.debug_struct("Channel");
        out.field("kind", kind);
        out.field("message_type_name", message_type_name);
        out.field("message_type_id", message_type_id);
        out.finish()
    }
}
/// Polling a [`Channel`] performs a single `pump` step.
///
/// The future resolves with `Ok(())` once a step transfers an item, or with
/// the error if the step fails.
impl Future for Channel {
    type Output = Result<(), Box<dyn Error>>;

    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        let transferred = match self.pump() {
            Ok(transferred) => transferred,
            Err(error) => return Poll::Ready(Err(error)),
        };
        if transferred {
            return Poll::Ready(Ok(()));
        }
        // Nothing was available. The waker is signalled right away, before
        // returning `Pending`, so the task asks to be polled again.
        cx.waker().wake_by_ref();
        Poll::Pending
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{codec::postcard::PostcardCodec, event::unbounded};
    use serde::{Deserialize, Serialize};

    /// Minimal serde-serializable payload used by the channel tests.
    #[derive(Serialize, Deserialize)]
    struct TestMessage {
        pub id: u32,
    }

    /// Compile-time check that `Channel` is `Send`.
    #[test]
    fn test_async() {
        fn is_send<T: Send>() {}
        is_send::<Channel>();
    }

    /// A queued message must leave the write channel as a packet that decodes
    /// back to the same payload.
    #[test]
    fn test_channel_write() {
        let (msg_tx, msg_rx) = unbounded();
        let (pkt_tx, pkt_rx) = unbounded();
        let mut channel = Channel::write::<PostcardCodec<TestMessage>, _>(pkt_tx, msg_rx);

        let message = TestMessage { id: 42 };
        msg_tx.send(message.into()).unwrap();

        // Assert the pump result so that a pump which transferred nothing
        // fails here rather than at the blocking receive below.
        assert!(
            channel.pump().unwrap(),
            "pump should report that the queued message was transferred"
        );

        let packet = pkt_rx.recv_blocking().unwrap();
        let mut buffer = Cursor::new(packet.data);
        let message = PostcardCodec::<TestMessage>::decode(&mut buffer).unwrap();
        assert_eq!(message.id, 42);
    }

    /// A queued packet must leave the read channel as a dispatch holding the
    /// decoded payload.
    #[test]
    fn test_channel_read() {
        let (msg_tx, msg_rx) = unbounded();
        let (pkt_tx, pkt_rx) = unbounded();
        let mut channel = Channel::read::<PostcardCodec<TestMessage>, _>(pkt_rx, msg_tx);

        let message = TestMessage { id: 42 };
        let mut buffer = Cursor::new(Vec::new());
        PostcardCodec::<TestMessage>::encode(&message, &mut buffer).unwrap();
        pkt_tx
            .send(ProtocolPacketData {
                data: buffer.into_inner(),
                recepients: Default::default(),
                sender: None,
            })
            .unwrap();

        // Assert the pump result so that a pump which transferred nothing
        // fails here rather than at the blocking receive below.
        assert!(
            channel.pump().unwrap(),
            "pump should report that the queued packet was transferred"
        );

        let message = msg_rx.recv_blocking().unwrap();
        assert_eq!(message.message.id, 42);
    }
}