use async_channel::{Receiver, Sender};
pub use async_channel::{RecvError, SendError, TryRecvError, TrySendError};
use futures::Stream;
use super::*;
/// A received packet: its raw payload bytes, the address that sent it, and
/// the instant it was received.
///
/// Fields are private. Read them through the generated `ref`-style getters,
/// replace them with the `with_*` builder setters or the `set_*` methods, or
/// take the packet apart with [`Packet::into_components`].
#[viewit::viewit(
  vis_all = "",
  getters(vis_all = "pub", style = "ref"),
  setters(vis_all = "pub", prefix = "with")
)]
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct Packet<A, T> {
  #[viewit(
    getter(
      const,
      style = "ref",
      attrs(doc = "Returns the payload of the packet")
    ),
    setter(attrs(doc = "Sets the payload of the packet (Builder pattern)"))
  )]
  payload: Bytes,
  #[viewit(
    getter(
      const,
      style = "ref",
      attrs(doc = "Returns the address that sent the packet")
    ),
    setter(attrs(doc = "Sets the address that sent the packet (Builder pattern)"))
  )]
  from: A,
  #[viewit(
    getter(
      const,
      style = "ref",
      attrs(doc = "Returns the instant when the packet was received")
    ),
    setter(attrs(doc = "Sets the instant when the packet was received (Builder pattern)"))
  )]
  timestamp: T,
}
impl<A, T> Packet<A, T> {
  /// Builds a packet sent by `from`, stamped with `timestamp`, carrying `payload`.
  #[inline]
  pub const fn new(from: A, timestamp: T, payload: Bytes) -> Self {
    Self {
      from,
      timestamp,
      payload,
    }
  }

  /// Consumes the packet and returns its parts as `(from, timestamp, payload)`.
  #[inline]
  pub fn into_components(self) -> (A, T, Bytes) {
    let Self {
      payload,
      from,
      timestamp,
    } = self;
    (from, timestamp, payload)
  }

  /// Replaces the sender address in place; returns `self` so calls can be chained.
  #[inline]
  pub fn set_from(&mut self, from: A) -> &mut Self {
    self.from = from;
    self
  }

  /// Replaces the receive timestamp in place; returns `self` so calls can be chained.
  #[inline]
  pub fn set_timestamp(&mut self, timestamp: T) -> &mut Self {
    self.timestamp = timestamp;
    self
  }

  /// Replaces the payload bytes in place; returns `self` so calls can be chained.
  #[inline]
  pub fn set_payload(&mut self, payload: Bytes) -> &mut Self {
    self.payload = payload;
    self
  }
}
/// Sending half of a packet channel, as returned by [`packet_stream`].
///
/// Cloning clones the wrapped `async_channel::Sender`, giving another handle
/// to the same channel.
#[derive(Debug)]
#[repr(transparent)]
pub struct PacketProducer<A, T> {
  sender: Sender<Packet<A, T>>,
}

// Written by hand instead of derived: `#[derive(Clone)]` would add
// `A: Clone` and `T: Clone` bounds, yet only the channel handle is cloned.
// This mirrors the manual `Clone` impl on `StreamProducer`.
impl<A, T> Clone for PacketProducer<A, T> {
  fn clone(&self) -> Self {
    Self {
      sender: self.sender.clone(),
    }
  }
}
/// Receiving half of a packet channel, as returned by [`packet_stream`].
///
/// Also usable as a [`Stream`] of packets. Cloning clones the wrapped
/// `async_channel::Receiver`, giving another handle to the same channel.
#[pin_project::pin_project]
#[derive(Debug)]
#[repr(transparent)]
pub struct PacketSubscriber<A, T> {
  #[pin]
  receiver: Receiver<Packet<A, T>>,
}

// Written by hand instead of derived: `#[derive(Clone)]` would add
// `A: Clone` and `T: Clone` bounds, yet only the channel handle is cloned.
// This mirrors the manual `Clone` impl on `StreamSubscriber`.
impl<A, T> Clone for PacketSubscriber<A, T> {
  fn clone(&self) -> Self {
    Self {
      receiver: self.receiver.clone(),
    }
  }
}
/// Creates an unbounded packet channel for transport `T`.
///
/// Packets carry `T`'s resolved address type and the `Instant` type of `T`'s
/// runtime. Returns `(producer, subscriber)`.
pub fn packet_stream<T: Transport>() -> (
  PacketProducer<T::ResolvedAddress, <T::Runtime as RuntimeLite>::Instant>,
  PacketSubscriber<T::ResolvedAddress, <T::Runtime as RuntimeLite>::Instant>,
) {
  let (tx, rx) = async_channel::unbounded();
  let producer = PacketProducer { sender: tx };
  let subscriber = PacketSubscriber { receiver: rx };
  (producer, subscriber)
}
impl<A, T> PacketProducer<A, T> {
pub async fn send(&self, packet: Packet<A, T>) -> Result<(), SendError<Packet<A, T>>> {
self.sender.send(packet).await
}
pub fn try_send(&self, packet: Packet<A, T>) -> Result<(), TrySendError<Packet<A, T>>> {
self.sender.try_send(packet)
}
pub fn is_empty(&self) -> bool {
self.sender.is_empty()
}
pub fn len(&self) -> usize {
self.sender.len()
}
pub fn is_full(&self) -> bool {
self.sender.is_full()
}
pub fn is_closed(&self) -> bool {
self.sender.is_closed()
}
pub fn close(&self) -> bool {
self.sender.close()
}
}
impl<A, T> PacketSubscriber<A, T> {
pub async fn recv(&self) -> Result<Packet<A, T>, RecvError> {
self.receiver.recv().await
}
pub fn try_recv(&self) -> Result<Packet<A, T>, TryRecvError> {
self.receiver.try_recv()
}
pub fn is_empty(&self) -> bool {
self.receiver.is_empty()
}
pub fn len(&self) -> usize {
self.receiver.len()
}
pub fn is_full(&self) -> bool {
self.receiver.is_full()
}
pub fn is_closed(&self) -> bool {
self.receiver.is_closed()
}
pub fn close(&self) -> bool {
self.receiver.close()
}
}
/// Yields packets as they arrive by delegating to the pinned inner receiver.
impl<A, T> Stream for PacketSubscriber<A, T> {
  type Item = Packet<A, T>;

  fn poll_next(
    self: std::pin::Pin<&mut Self>,
    cx: &mut std::task::Context<'_>,
  ) -> std::task::Poll<Option<Self::Item>> {
    // `project()` turns `Pin<&mut Self>` into `Pin<&mut Receiver<_>>` for the `#[pin]` field.
    let projected = self.project();
    projected.receiver.poll_next(cx)
  }
}
/// Creates a channel with capacity 1 that carries `(address, connection)` pairs
/// for transport `T`.
///
/// Returns `(producer, subscriber)`.
pub fn promised_stream<T: Transport>() -> (
  StreamProducer<T::ResolvedAddress, T::Connection>,
  StreamSubscriber<T::ResolvedAddress, T::Connection>,
) {
  let (tx, rx) = async_channel::bounded(1);
  let producer = StreamProducer { sender: tx };
  let subscriber = StreamSubscriber { receiver: rx };
  (producer, subscriber)
}
#[derive(Debug)]
#[repr(transparent)]
pub struct StreamProducer<A, S> {
sender: Sender<(A, S)>,
}
impl<A, S> Clone for StreamProducer<A, S> {
fn clone(&self) -> Self {
Self {
sender: self.sender.clone(),
}
}
}
#[pin_project::pin_project]
#[derive(Debug)]
#[repr(transparent)]
pub struct StreamSubscriber<A, S> {
#[pin]
receiver: Receiver<(A, S)>,
}
impl<A, S> Clone for StreamSubscriber<A, S> {
fn clone(&self) -> Self {
Self {
receiver: self.receiver.clone(),
}
}
}
impl<A, S> StreamProducer<A, S> {
pub async fn send(&self, addr: A, conn: S) -> Result<(), SendError<(A, S)>> {
self.sender.send((addr, conn)).await
}
pub fn try_send(&self, addr: A, conn: S) -> Result<(), TrySendError<(A, S)>> {
self.sender.try_send((addr, conn))
}
pub fn is_empty(&self) -> bool {
self.sender.is_empty()
}
pub fn len(&self) -> usize {
self.sender.len()
}
pub fn is_full(&self) -> bool {
self.sender.is_full()
}
pub fn is_closed(&self) -> bool {
self.sender.is_closed()
}
pub fn close(&self) -> bool {
self.sender.close()
}
}
impl<A, S> StreamSubscriber<A, S> {
pub async fn recv(&self) -> Result<(A, S), RecvError> {
self.receiver.recv().await
}
pub fn try_recv(&self) -> Result<(A, S), TryRecvError> {
self.receiver.try_recv()
}
pub fn is_empty(&self) -> bool {
self.receiver.is_empty()
}
pub fn len(&self) -> usize {
self.receiver.len()
}
pub fn is_full(&self) -> bool {
self.receiver.is_full()
}
pub fn is_closed(&self) -> bool {
self.receiver.is_closed()
}
pub fn close(&self) -> bool {
self.receiver.close()
}
}
/// Yields `(address, connection)` pairs by delegating to the pinned inner receiver.
impl<A, S> Stream for StreamSubscriber<A, S> {
  type Item = (A, S);

  fn poll_next(
    self: std::pin::Pin<&mut Self>,
    cx: &mut std::task::Context<'_>,
  ) -> std::task::Poll<Option<Self::Item>> {
    // `project()` turns `Pin<&mut Self>` into `Pin<&mut Receiver<_>>` for the `#[pin]` field.
    let projected = self.project();
    projected.receiver.poll_next(cx)
  }
}
#[cfg(test)]
mod tests {
  use std::net::SocketAddr;

  use bytes::Bytes;
  use smol_str::SmolStr;

  use super::*;

  /// Exercises `Packet::new`, the `set_*` mutators, the generated getters and
  /// `into_components` using timestamps from runtime `R`.
  async fn access<R: RuntimeLite>() {
    let initial_payload = Message::<SmolStr, SocketAddr>::user_data(Bytes::new())
      .encode_to_bytes()
      .unwrap();
    let mut packet = Packet::<SocketAddr, _>::new(
      "127.0.0.1:8080".parse().unwrap(),
      R::now(),
      initial_payload,
    );

    let from: SocketAddr = "127.0.0.1:8081".parse().unwrap();
    packet.set_from(from);
    let start = R::now();
    packet.set_timestamp(start);
    let payload = Message::<SmolStr, SocketAddr>::user_data(Bytes::from_static(b"a"))
      .encode_to_bytes()
      .unwrap();
    packet.set_payload(payload.clone());

    assert_eq!(packet.payload(), &payload);
    assert_eq!(*packet.from(), from);
    assert_eq!(*packet.timestamp(), start);

    // `into_components` must hand back exactly what the setters stored, in
    // `(from, timestamp, payload)` order.
    let (got_from, got_timestamp, got_payload) = packet.into_components();
    assert_eq!(got_from, from);
    assert_eq!(got_timestamp, start);
    assert_eq!(got_payload, payload);
  }

  #[test]
  fn tokio_access() {
    // `worker_threads` has no effect on a current-thread runtime, so it is not set.
    tokio::runtime::Builder::new_current_thread()
      .enable_all()
      .build()
      .unwrap()
      .block_on(access::<agnostic_lite::tokio::TokioRuntime>());
  }
}