use std::{
pin::Pin,
task::{Context, Poll, ready},
};
use bytes::Bytes;
use futures::{SinkExt, Stream, StreamExt};
use tokio::io::{AsyncRead, AsyncWrite};
use tokio_util::codec::Framed;
use tracing::{debug, trace};
use super::SubError;
use msg_wire::pubsub;
/// A subscriber-side connection to a single publisher.
///
/// Wraps a framed transport using the pubsub codec. Outgoing messages are queued via
/// `poll_send`; the flush of those queued messages is driven from this type's
/// `Stream::poll_next` implementation.
pub(super) struct PublisherStream<Io> {
    /// Set after `poll_send` hands a message to the sink; cleared once a flush attempt
    /// completes in `poll_next`.
    flush: bool,
    /// The framed connection used for both reading and writing pubsub messages.
    conn: Framed<Io, pubsub::Codec>,
}
impl<Io> PublisherStream<Io>
where
    Io: AsyncRead + AsyncWrite + Unpin,
{
    /// Queues `msg` on the underlying framed connection.
    ///
    /// Returns `Poll::Pending` while the sink cannot accept another item. Once the message
    /// has been handed to the sink, the `flush` flag is raised and the current task is woken.
    /// The flush itself is performed later by `Stream::poll_next`.
    ///
    /// # Errors
    ///
    /// Errors from readying or writing to the sink are converted into [`SubError`].
    pub fn poll_send(
        &mut self,
        cx: &mut Context<'_>,
        msg: pubsub::Message,
    ) -> Poll<Result<(), SubError>> {
        match self.conn.poll_ready_unpin(cx) {
            Poll::Pending => return Poll::Pending,
            Poll::Ready(Err(err)) => return Poll::Ready(Err(err.into())),
            Poll::Ready(Ok(())) => {}
        }

        debug!("Sending message to topic: {:?}", msg.topic());
        if let Err(err) = self.conn.start_send_unpin(msg) {
            return Poll::Ready(Err(err.into()));
        }

        // The queued message is only flushed inside `poll_next`, so ask for another poll.
        // NOTE(review): this assumes the stream is polled on the same task — confirm with caller.
        self.flush = true;
        cx.waker().wake_by_ref();
        Poll::Ready(Ok(()))
    }
}
impl<Io: AsyncRead + AsyncWrite + Unpin> From<Framed<Io, pubsub::Codec>> for PublisherStream<Io> {
fn from(conn: Framed<Io, pubsub::Codec>) -> Self {
Self { conn, flush: false }
}
}
/// A message received from a publisher, flattened out of the wire representation.
pub(super) struct TopicMessage {
    /// Topic the message was published on, decoded lossily from UTF-8 bytes.
    pub topic: String,
    /// Message body, as taken from the wire message.
    pub payload: Bytes,
    /// Timestamp carried by the wire message (units not established here — TODO confirm).
    pub timestamp: u64,
    /// Compression type identifier carried by the wire message.
    pub compression_type: u8,
}
impl<Io: AsyncRead + AsyncWrite + Unpin> Stream for PublisherStream<Io> {
type Item = Result<TopicMessage, pubsub::Error>;
fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
let this = self.get_mut();
if this.flush && this.conn.poll_flush_unpin(cx).is_ready() {
trace!("Flushed connection");
this.flush = false
}
if let Some(result) = ready!(this.conn.poll_next_unpin(cx)) {
return Poll::Ready(Some(result.map(|msg| {
let timestamp = msg.timestamp();
let compression_type = msg.compression_type();
let (topic, payload) = msg.into_parts();
let topic = String::from_utf8_lossy(&topic).to_string();
TopicMessage { compression_type, timestamp, topic, payload }
})));
}
Poll::Pending
}
}