use base64::Engine;
use core::fmt;
use futures_util::sink::Sink;
use pin_project::pin_project;
use redis::{
aio::ConnectionManager,
streams::{StreamTrimStrategy, StreamTrimmingMode},
AsyncCommands, RedisResult,
};
use std::{
pin::Pin,
task::{Context, Poll},
};
use tracing::warn;
use crate::{redis::EncodedMessage, EncodableMessage};
use super::{
RedisError, FORMAT_VERSION_ATTR, ID_KEY, MESSAGE_TIMESTAMP_KEY, PAYLOAD_KEY, PUBLISHER_KEY,
SCHEMA_KEY,
};
use super::{StreamName, ENCODING_ATTR};
/// Factory for Redis-stream [`Publisher`]s.
///
/// Holds the `redis::Client` used to open connections, and the id that every publisher
/// created from it writes under `PUBLISHER_KEY` in each stream entry.
#[derive(Debug, Clone)]
pub struct PublisherClient {
    /// Client used (and cloned) to create `ConnectionManager`s in the background task.
    client: redis::Client,
    /// Identifier of this publisher, attached to every published message.
    publisher_id: PublisherId,
}
impl PublisherClient {
    /// Wraps an existing `redis::Client`.
    ///
    /// `publisher_id` is stored and later attached to every message published through
    /// publishers created from the returned value.
    pub fn from_client(client: redis::Client, publisher_id: impl Into<String>) -> Self {
        Self {
            publisher_id: PublisherId::new(publisher_id),
            client,
        }
    }
}
/// Errors produced by the Redis [`PublishSink`].
///
/// `M` is the message type; `E` is the error type of the response sink supplied to
/// `publish_sink_with_responses`.
#[derive(Debug)]
pub enum PublishError<M: EncodableMessage, E> {
    /// Publishing to Redis failed; `messages` are the messages that could not be published.
    Publish {
        cause: RedisError,
        messages: Vec<M>,
    },
    /// A response for a successfully published message could not be forwarded to the
    /// response sink.
    Response(E),
    /// The message could not be validated/encoded; the rejected message is handed back.
    InvalidMessage {
        cause: M::Error,
        message: M,
    },
}
impl<M: EncodableMessage, E> fmt::Display for PublishError<M, E>
where
    M::Error: fmt::Display,
    E: fmt::Display,
{
    /// Writes a short summary of the failure.
    ///
    /// The underlying cause is intentionally not included in the text; it is exposed through
    /// `std::error::Error::source` instead.
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        match self {
            Self::Publish { messages, .. } => {
                write!(f, "could not publish {} messages", messages.len())
            }
            Self::Response(..) => f.write_str(
                "could not forward response for a successfully published message to the sink",
            ),
            Self::InvalidMessage { .. } => f.write_str("could not validate message"),
        }
    }
}
impl<M: EncodableMessage, E> std::error::Error for PublishError<M, E>
where
    M: fmt::Debug,
    M::Error: std::error::Error + 'static,
    E: std::error::Error + 'static,
{
    /// Every variant wraps an underlying error, so a source is always available.
    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
        let inner: &(dyn std::error::Error + 'static) = match self {
            Self::Publish { cause, .. } => cause,
            Self::Response(cause) => cause,
            Self::InvalidMessage { cause, .. } => cause,
        };
        Some(inner)
    }
}
/// Per-topic configuration.
///
/// NOTE(review): nothing in this module reads `TopicConfig`; confirm it is still used elsewhere.
pub struct TopicConfig {
    /// Name of the Redis stream backing the topic.
    pub name: StreamName,
}
impl PublisherClient {
    /// Spawns a background task that writes messages to Redis streams and returns a
    /// [`Publisher`] that feeds it through a channel of capacity 1.
    ///
    /// The task creates a `ConnectionManager`. If that fails, the error is logged and
    /// creation is retried. Each received message is pushed with `push`. A failed push is
    /// logged and the message is dropped; an I/O error additionally makes the task build a
    /// fresh `ConnectionManager`. The task exits once every sender has been dropped *and*
    /// the channel has been drained.
    ///
    /// Must be called from within a Tokio runtime, because it uses `tokio::spawn`.
    pub async fn publisher(&self) -> Publisher {
        let client = self.client.clone();
        let publisher_id = self.publisher_id.clone();
        let (tx, mut rx) = tokio::sync::mpsc::channel(1);
        tokio::spawn(async move {
            loop {
                // `is_closed()` becomes true as soon as the last sender is dropped, even if a
                // message is still buffered. Requiring the channel to also be empty ensures
                // such a message is still delivered once a connection is available.
                if rx.is_closed() && rx.is_empty() {
                    break;
                }
                let mut con = match ConnectionManager::new_with_config(
                    client.clone(),
                    super::connection_manager_config(),
                )
                .await
                {
                    Ok(con) => con,
                    Err(err) => {
                        // Previously swallowed silently; log it and try again.
                        warn!("{:?}", err);
                        continue;
                    }
                };
                // `recv()` yields `None` only when the channel is closed and drained, which
                // the check at the top of the loop then turns into an exit.
                while let Some(EncodedMessage {
                    id,
                    topic,
                    b64_data,
                    schema,
                }) = rx.recv().await
                {
                    let key = StreamName::from(topic);
                    let res = push(
                        &mut con,
                        &key,
                        b64_data.as_str(),
                        &schema,
                        &id,
                        &publisher_id.0,
                    )
                    .await;
                    if let Err(err) = res {
                        warn!("{:?}", err);
                        if err.is_io_error() {
                            // Discard this connection manager and build a new one.
                            break;
                        }
                    }
                }
            }
        });
        Publisher { sender: tx }
    }
}
/// Appends one message to the Redis stream `key` using `XADD` with an auto-generated (`*`)
/// entry id.
///
/// Each add also trims the stream approximately to 1000 entries. The entry stores the
/// payload, the format-version and encoding attributes, the hedwig message id, the current
/// time in milliseconds since the Unix epoch, the publisher id and the schema.
async fn push(
    con: &mut ConnectionManager,
    key: &StreamName,
    payload: &str,
    schema: &str,
    hedwig_id: &str,
    publisher_id: &str,
) -> RedisResult<()> {
    let trim = StreamTrimStrategy::maxlen(StreamTrimmingMode::Approx, 1_000);
    let options = redis::streams::StreamAddOptions::default().trim(trim);

    // A system clock set before the epoch yields "0" instead of failing the push.
    let elapsed = std::time::SystemTime::now()
        .duration_since(std::time::UNIX_EPOCH)
        .unwrap_or_default();
    let message_timestamp = elapsed.as_millis().to_string();

    let fields = [
        (PAYLOAD_KEY, payload),
        FORMAT_VERSION_ATTR,
        (ID_KEY, hedwig_id),
        (MESSAGE_TIMESTAMP_KEY, message_timestamp.as_str()),
        (PUBLISHER_KEY, publisher_id),
        (SCHEMA_KEY, schema),
        ENCODING_ATTR,
    ];
    con.xadd_options(&key.0, "*", &fields, &options).await
}
/// Identifier attached to every message published by a [`PublisherClient`].
#[derive(Debug, Clone)]
struct PublisherId(String);

impl PublisherId {
    /// Builds an id from anything convertible into an owned `String`.
    fn new(s: impl Into<String>) -> Self {
        let owned: String = s.into();
        PublisherId(owned)
    }
}
/// Handle to the background Redis publishing task spawned by `PublisherClient::publisher`.
///
/// All clones share the same channel sender and therefore feed the same task.
#[derive(Clone)]
pub struct Publisher {
    /// Sending half of the channel drained by the background task.
    sender: tokio::sync::mpsc::Sender<EncodedMessage>,
}
impl<M, S> crate::publisher::Publisher<M, S> for Publisher
where
    M: EncodableMessage + Send + 'static,
    S: Sink<M> + Send + 'static,
{
    type PublishError = PublishError<M, S::Error>;
    type PublishSink = PublishSink<M, S>;

    /// Creates a sink that validates messages with `validator` and forwards them to this
    /// publisher's background Redis task.
    ///
    /// NOTE: `_response_sink` is dropped unused. This backend never forwards responses for
    /// published messages.
    fn publish_sink_with_responses(
        self,
        validator: M::Validator,
        _response_sink: S,
    ) -> Self::PublishSink {
        PublishSink {
            validator,
            // `self` is consumed here, so its sender can be moved instead of cloned.
            sender: self.sender,
            _m: std::marker::PhantomData,
            buffer: None,
        }
    }
}
/// `Sink` that validates and encodes messages and hands them to the background Redis task.
///
/// At most one message is buffered at a time (see `start_send`).
#[pin_project]
pub struct PublishSink<M: EncodableMessage, S: Sink<M>> {
    /// Validator passed to `EncodableMessage::encode` for each message.
    validator: M::Validator,
    /// Sending half of the channel drained by the background publishing task.
    sender: tokio::sync::mpsc::Sender<EncodedMessage>,
    /// Marks the message type `M` and response-sink type `S` without storing values of them.
    _m: std::marker::PhantomData<(M, S)>,
    /// Message accepted by `start_send` that has not been forwarded yet.
    buffer: Option<M>,
}
impl<M, S> Sink<M> for PublishSink<M, S>
where
    M: EncodableMessage + Send + 'static,
    S: Sink<M> + Send + 'static,
{
    type Error = PublishError<M, S::Error>;

    /// Ready once any previously buffered message has been handed to the background task.
    fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        self.poll_flush_buffered_message(cx)
    }

    /// Buffers `message`. It is encoded and forwarded on the next `poll_ready`,
    /// `poll_flush` or `poll_close`.
    ///
    /// # Panics
    ///
    /// Panics if a message is already buffered, i.e. `poll_ready` did not complete
    /// successfully before this call.
    fn start_send(self: Pin<&mut Self>, message: M) -> Result<(), Self::Error> {
        let this = self.project();
        if this.buffer.replace(message).is_some() {
            panic!("each `start_send` must be preceded by a successful call to `poll_ready`");
        }
        Ok(())
    }

    /// Forwards the buffered message, if any, to the background task.
    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        self.poll_flush_buffered_message(cx)
    }

    /// Flushes the buffered message before reporting the sink as closed.
    ///
    /// Previously this returned `Ready` immediately, so a message passed to `start_send` and
    /// followed directly by `close()` was silently dropped. The `Sink` contract requires
    /// `poll_close` to flush pending items first.
    fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        self.poll_flush_buffered_message(cx)
    }
}
/// Validates `message` with `validator` and converts it into the form sent to the
/// background Redis task.
///
/// The payload is base64-encoded with the standard engine (standard alphabet, padded). The
/// id is the validated message's UUID as a string, and the schema is copied from the
/// validated message.
///
/// # Errors
///
/// Returns [`PublishError::InvalidMessage`], containing the original message, when
/// `encode` fails.
fn encode_message<M>(
    validator: &M::Validator,
    message: M,
) -> Result<EncodedMessage, PublishError<M, std::convert::Infallible>>
where
    M: EncodableMessage + Send + 'static,
{
    match message.encode(validator) {
        Err(cause) => Err(PublishError::InvalidMessage { cause, message }),
        Ok(validated) => {
            let b64_data = base64::engine::general_purpose::STANDARD.encode(validated.data());
            Ok(EncodedMessage {
                id: validated.uuid().to_string(),
                topic: message.topic(),
                schema: validated.schema().to_string().into(),
                b64_data,
            })
        }
    }
}
impl<M, S> PublishSink<M, S>
where
M: EncodableMessage + Send + 'static,
S: Sink<M> + Send + 'static,
{
fn poll_flush_buffered_message(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
) -> Poll<Result<(), PublishError<M, S::Error>>> {
let this = self.project();
if this.sender.capacity() == 0 {
cx.waker().wake_by_ref();
return Poll::Pending;
}
let Some(message) = this.buffer.take() else {
return Poll::Ready(Ok(()));
};
let Ok(encoded_message) = encode_message(this.validator, message) else {
return Poll::Ready(Ok(()));
};
this.sender.try_send(encoded_message).unwrap();
Poll::Ready(Ok(()))
}
}