use std::{
collections::VecDeque,
pin::Pin,
sync::Arc,
task::{Context, Poll},
};
use bytes::Bytes;
use futures::{Future, StreamExt};
use tokio::io::{AsyncRead, AsyncWrite};
use tracing::{debug, error, warn};
use msg_common::{Channel, unix_micros};
use msg_transport::Address;
use msg_wire::pubsub;
use super::{
stats::SessionStats,
stream::{PublisherStream, TopicMessage},
};
/// Commands the driver sends to a [`PublisherSession`] over its driver channel.
///
/// Each command is turned into an outgoing control message that is queued on the
/// session's egress queue and later written to the publisher stream.
pub(super) enum SessionCommand {
    /// Queue a subscribe message (`pubsub::Message::new_sub`) for the given topic.
    Subscribe(String),
    /// Queue an unsubscribe message (`pubsub::Message::new_unsub`) for the given topic.
    Unsubscribe(String),
}
/// A single connection to a publisher, driven as a future.
///
/// The session forwards every message received from the publisher to the driver,
/// and writes subscribe/unsubscribe control messages requested by the driver back
/// to the publisher. It resolves when either the publisher stream or the driver
/// channel closes.
#[must_use = "This future must be spawned"]
pub(super) struct PublisherSession<Io, A>
where
    A: Address,
{
    /// Peer address; attached to log events for context.
    addr: A,
    /// Outgoing control messages waiting to be written to `stream`, in FIFO order.
    egress: VecDeque<pubsub::Message>,
    /// Connection to the publisher: yields incoming topic messages and accepts
    /// outgoing `pubsub::Message`s.
    stream: PublisherStream<Io>,
    /// Per-session counters, shared with observers through `stats()`.
    stats: Arc<SessionStats>,
    /// Link to the driver: received `TopicMessage`s go out, `SessionCommand`s come in.
    driver_channel: Channel<TopicMessage, SessionCommand>,
}
impl<Io: AsyncRead + AsyncWrite + Unpin, A: Address> PublisherSession<Io, A> {
pub(super) fn new(
addr: A,
stream: PublisherStream<Io>,
channel: Channel<TopicMessage, SessionCommand>,
) -> Self {
Self {
addr,
stream,
egress: VecDeque::with_capacity(4),
stats: Arc::new(SessionStats::default()),
driver_channel: channel,
}
}
pub(super) fn stats(&self) -> Arc<SessionStats> {
Arc::clone(&self.stats)
}
fn subscribe(&mut self, topic: String) {
self.egress.push_back(pubsub::Message::new_sub(Bytes::from(topic)));
}
fn unsubscribe(&mut self, topic: String) {
self.egress.push_back(pubsub::Message::new_unsub(Bytes::from(topic)));
}
fn on_incoming(&mut self, incoming: Result<TopicMessage, pubsub::Error>) {
match incoming {
Ok(msg) => {
let now = unix_micros();
self.stats.increment_rx(msg.payload.len());
self.stats.update_latency(now.saturating_sub(msg.timestamp));
if let Err(e) = self.driver_channel.try_send(msg) {
warn!(err = ?e, addr = ?self.addr, "Failed to send message to driver");
}
}
Err(e) => {
error!(err = ?e, addr = ?self.addr, "Error receiving message");
}
}
}
fn on_command(&mut self, cmd: SessionCommand) {
match cmd {
SessionCommand::Subscribe(topic) => self.subscribe(topic),
SessionCommand::Unsubscribe(topic) => self.unsubscribe(topic),
}
}
}
impl<Io: AsyncRead + AsyncWrite + Unpin, A: Address + Unpin> Future for PublisherSession<Io, A> {
    type Output = ();

    /// Drives the session until the publisher stream or the driver channel closes.
    ///
    /// Each loop iteration does three things in priority order:
    /// 1. drain incoming messages from the publisher and forward them to the driver;
    /// 2. write queued control messages to the publisher, in FIFO order;
    /// 3. accept the next command from the driver.
    ///
    /// Whenever a step makes progress, the loop restarts from step 1. `Pending` is
    /// returned only after every source has reported `Pending`, so each one has
    /// registered the waker from `cx`.
    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        let this = self.get_mut();

        loop {
            // 1. Incoming traffic from the publisher.
            match this.stream.poll_next_unpin(cx) {
                Poll::Ready(Some(result)) => {
                    this.on_incoming(result);
                    continue;
                }
                Poll::Ready(None) => {
                    error!(addr = ?this.addr, "Publisher stream closed");
                    return Poll::Ready(());
                }
                Poll::Pending => {}
            }

            // 2. Outgoing control messages.
            let mut progress = false;
            while let Some(msg) = this.egress.pop_front() {
                // NOTE(review): `is_ready()` treats any `Ready` as success; if
                // `poll_send` returns a `Result` inside `Ready`, errors are ignored
                // here. Confirm against `PublisherStream::poll_send`.
                if this.stream.poll_send(cx, msg.clone()).is_ready() {
                    progress = true;
                    debug!("Queued message for sending: {:?}", msg);
                } else {
                    // Put the message back at the *front*. `push_back` would move it
                    // behind later messages and break FIFO order, e.g. a subscribe
                    // followed by an unsubscribe for the same topic would be reversed.
                    this.egress.push_front(msg);
                    break;
                }
            }
            if progress {
                continue;
            }

            // 3. Commands from the driver.
            match this.driver_channel.poll_recv(cx) {
                Poll::Ready(Some(cmd)) => {
                    this.on_command(cmd);
                    continue;
                }
                Poll::Ready(None) => {
                    warn!(addr = ?this.addr, "Driver channel closed, shutting down session");
                    return Poll::Ready(());
                }
                Poll::Pending => return Poll::Pending,
            }
        }
    }
}