use std::{borrow::Cow, future::Future, pin::Pin, sync::Arc};
use bytes::BytesMut;
use serde::Serialize;
use tracing::warn;
use super::lifecycle::BoxError;
use super::publisher_registry::ErasedPublisher;
use crate::codec::Codec;
#[cfg(any(feature = "json", feature = "cbor", feature = "msgpack"))]
use crate::codec::DefaultCodec;
use crate::runtime::publish::sealed::Sealed;
use crate::{Extensions, Headers, Publisher, TransactionalPublisher};
/// Boxed, `Send` future produced by every stage of the publish chain; it
/// resolves to `Ok(())` or a boxed error.
type PublishFut<'a> =
    Pin<Box<dyn Future<Output = Result<(), BoxError>> + Send + 'a>>;
/// A message on its way out: destination name, encoded payload and headers.
///
/// The name is stored as a [`Cow`], so a borrowed `&'a str` is kept without
/// allocating while an owned `String` is moved in unchanged.
#[derive(Debug, Clone)]
pub struct Outgoing<'a> {
    name: Cow<'a, str>,
    payload: BytesMut,
    headers: Headers,
}

impl<'a> Outgoing<'a> {
    /// Builds a message from `name` and `payload` with an empty header set.
    #[must_use]
    pub fn new(name: impl Into<Cow<'a, str>>, payload: impl Into<BytesMut>) -> Self {
        let name = name.into();
        let payload = payload.into();
        Self {
            name,
            payload,
            headers: Headers::new(),
        }
    }

    /// Destination name of the message.
    #[must_use]
    pub fn name(&self) -> &str {
        &*self.name
    }

    /// Replaces the destination name (borrowed or owned).
    pub fn set_name(&mut self, name: impl Into<Cow<'a, str>>) {
        let replacement = name.into();
        self.name = replacement;
    }

    /// Read-only view of the encoded payload bytes.
    #[must_use]
    pub fn payload(&self) -> &[u8] {
        &self.payload[..]
    }

    /// Mutable access to the payload buffer, for in-place edits.
    pub fn payload_mut(&mut self) -> &mut BytesMut {
        let buffer = &mut self.payload;
        buffer
    }

    /// Replaces the whole payload.
    pub fn set_payload(&mut self, payload: impl Into<BytesMut>) {
        let replacement = payload.into();
        self.payload = replacement;
    }

    /// Headers attached to the message.
    #[must_use]
    pub fn headers(&self) -> &Headers {
        let headers = &self.headers;
        headers
    }

    /// Mutable access to the message headers.
    pub fn headers_mut(&mut self) -> &mut Headers {
        let headers = &mut self.headers;
        headers
    }
}
/// Asynchronous middleware wrapped around the publishing of an [`Outgoing`]
/// message.
///
/// An implementation may inspect or modify `out` and then forward it by
/// calling [`PublishNext::run`]. `run` is the only way to reach the remaining
/// middleware and the underlying publisher, so returning without calling it
/// stops the message there.
///
/// Note: `out` is borrowed for the same lifetime `'a` as its own name
/// parameter, so a name passed to `Outgoing::set_name` here must be owned or
/// live for `'a`.
pub trait PublishMiddleware: Send + Sync {
    /// Handles `out`, typically delegating to `next.run(out)`.
    fn on_publish<'a>(
        &'a self,
        out: &'a mut Outgoing<'a>,
        next: PublishNext<'a>,
    ) -> PublishFut<'a>;
}
/// Continuation handed to each [`PublishMiddleware`]: the middleware layers
/// that have not run yet, plus the publisher that terminates the chain.
pub struct PublishNext<'a> {
    rest: &'a [Arc<dyn PublishMiddleware>],
    publisher: &'a dyn ErasedPublisher,
    extensions: &'a Extensions,
}

impl<'a> PublishNext<'a> {
    /// Runs the remainder of the chain on `out`.
    ///
    /// If middleware layers remain, the first one is invoked with a
    /// continuation over the others. Otherwise the message's name, payload
    /// and headers are passed to the publisher's `publish_message`.
    #[must_use]
    pub fn run(self, out: &'a mut Outgoing<'a>) -> PublishFut<'a> {
        match self.rest.split_first() {
            Some((middleware, rest)) => middleware.on_publish(
                out,
                PublishNext {
                    rest,
                    publisher: self.publisher,
                    extensions: self.extensions,
                },
            ),
            None => self
                .publisher
                .publish_message(out.name(), out.payload(), out.headers()),
        }
    }

    /// Extensions shared by every stage of this publish call.
    ///
    /// The returned reference carries the chain lifetime `'a`, not the
    /// lifetime of `&self`. A middleware can therefore keep using it after
    /// `self` has been consumed by [`PublishNext::run`].
    #[must_use]
    pub fn extensions(&self) -> &'a Extensions {
        self.extensions
    }
}

impl std::fmt::Debug for PublishNext<'_> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("PublishNext")
            .field("remaining", &self.rest.len())
            .finish_non_exhaustive()
    }
}
/// Sends `out` through every middleware in `pipeline`, in order, and finally
/// to `publisher`.
pub(crate) fn run_publish<'a>(
    pipeline: &'a [Arc<dyn PublishMiddleware>],
    publisher: &'a dyn ErasedPublisher,
    out: &'a mut Outgoing<'a>,
    extensions: &'a Extensions,
) -> PublishFut<'a> {
    let head = PublishNext {
        extensions,
        publisher,
        rest: pipeline,
    };
    head.run(out)
}
/// A publisher bound to a middleware pipeline and a set of extensions.
///
/// Every [`ScopedPublisher::publish`] call runs the message through the
/// pipeline before it reaches the underlying publisher.
pub struct ScopedPublisher<'a> {
    publisher: &'a dyn ErasedPublisher,
    pipeline: &'a [Arc<dyn PublishMiddleware>],
    extensions: &'a Extensions,
}

impl<'a> ScopedPublisher<'a> {
    /// Bundles the publisher, pipeline and extensions for later publishing.
    pub(crate) fn new(
        publisher: &'a dyn ErasedPublisher,
        pipeline: &'a [Arc<dyn PublishMiddleware>],
        extensions: &'a Extensions,
    ) -> Self {
        Self {
            extensions,
            pipeline,
            publisher,
        }
    }

    /// Publishes `out` through the bound middleware pipeline.
    ///
    /// # Errors
    /// Returns any error raised by a middleware layer or by the publisher.
    pub async fn publish(&self, mut out: Outgoing<'_>) -> Result<(), BoxError> {
        let chain = self.pipeline;
        run_publish(chain, self.publisher, &mut out, self.extensions).await
    }
}

impl std::fmt::Debug for ScopedPublisher<'_> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let layers = self.pipeline.len();
        let mut builder = f.debug_struct("ScopedPublisher");
        builder.field("layers", &layers);
        builder.finish_non_exhaustive()
    }
}
/// Synchronous, in-place transformation of an [`Outgoing`] message.
///
/// In `TypedPublisher::publish`, layers run after the value has been encoded
/// and before the message enters the middleware pipeline.
pub trait PublishLayer: Send + Sync {
    /// Modifies `out` in place.
    fn apply(&self, out: &mut Outgoing<'_>);
}

/// Layer that leaves the message untouched; the default for a fresh
/// `TypedPublisher`.
#[derive(Debug, Clone, Copy, Default)]
pub struct PublishIdentity;

impl PublishLayer for PublishIdentity {
    fn apply(&self, _: &mut Outgoing<'_>) {}
}

/// Composition of two layers: `inner` is applied first, then `outer`.
#[derive(Debug, Clone, Copy, Default)]
pub struct PublishStack<Inner, Outer> {
    inner: Inner,
    outer: Outer,
}

impl<Inner: PublishLayer, Outer: PublishLayer> PublishLayer for PublishStack<Inner, Outer> {
    fn apply(&self, out: &mut Outgoing<'_>) {
        let Self { inner, outer } = self;
        inner.apply(out);
        outer.apply(out);
    }
}
/// Publisher that encodes typed values with codec `C` and applies layers `PL`
/// before the message is sent through `P`.
pub struct TypedPublisher<P, C, PL = PublishIdentity> {
    publisher: P,
    codec: C,
    layers: PL,
}

impl<P, C> TypedPublisher<P, C, PublishIdentity> {
    /// Wraps `publisher` with an explicit `codec` and no extra layers.
    #[must_use]
    pub fn with_codec(publisher: P, codec: C) -> Self {
        let layers = PublishIdentity;
        Self {
            publisher,
            codec,
            layers,
        }
    }
}

#[cfg(any(feature = "json", feature = "cbor", feature = "msgpack"))]
impl<P> TypedPublisher<P, DefaultCodec, PublishIdentity> {
    /// Wraps `publisher` with the feature-selected default codec.
    #[must_use]
    pub fn new(publisher: P) -> Self {
        let codec = DefaultCodec::default();
        Self::with_codec(publisher, codec)
    }
}
impl<P, C, PL> TypedPublisher<P, C, PL> {
    /// Codec used to encode published values.
    pub(crate) const fn codec(&self) -> &C {
        &self.codec
    }

    /// Adds `layer` on top of the current layers. It runs after every
    /// previously added layer.
    #[must_use]
    pub fn layer<N>(self, layer: N) -> TypedPublisher<P, C, PublishStack<PL, N>> {
        let Self {
            publisher,
            codec,
            layers,
        } = self;
        let stacked = PublishStack {
            inner: layers,
            outer: layer,
        };
        TypedPublisher {
            publisher,
            codec,
            layers: stacked,
        }
    }

    /// Converts into a publisher whose reply batches run inside a transaction.
    #[must_use]
    pub fn transactional(self) -> Transactional<P, C, PL>
    where
        P: TransactionalPublisher,
    {
        Transactional { inner: self }
    }
}

impl<P: Publisher, C: Codec, PL: PublishLayer> TypedPublisher<P, C, PL> {
    /// Encodes `value`, applies the layers, then sends the message named
    /// `name` through `pipeline`.
    ///
    /// # Errors
    /// Returns the boxed codec error if encoding fails, otherwise any error
    /// from the pipeline or publisher.
    pub(crate) async fn publish<T: Serialize + Sync>(
        &self,
        name: &str,
        value: &T,
        pipeline: &[Arc<dyn PublishMiddleware>],
        extensions: &Extensions,
    ) -> Result<(), BoxError> {
        let encoded = match self.codec.encode(value) {
            Ok(bytes) => bytes,
            Err(e) => return Err(Box::new(e) as BoxError),
        };
        let mut out = Outgoing::new(name, encoded);
        self.layers.apply(&mut out);
        run_publish(pipeline, &self.publisher, &mut out, extensions).await
    }
}
impl<P, C, PL> std::fmt::Debug for TypedPublisher<P, C, PL> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // No field bounds on P/C/PL, so only the type name is shown.
        let mut builder = f.debug_struct("TypedPublisher");
        builder.finish_non_exhaustive()
    }
}

/// A [`TypedPublisher`] whose reply batches are published inside a single
/// transaction.
pub struct Transactional<P, C, PL = PublishIdentity> {
    inner: TypedPublisher<P, C, PL>,
}

impl<P, C, PL> std::fmt::Debug for Transactional<P, C, PL> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut builder = f.debug_struct("Transactional");
        builder.finish_non_exhaustive()
    }
}
mod sealed {
pub trait Sealed {}
impl<P, C, PL> Sealed for super::TypedPublisher<P, C, PL> {}
impl<P, C, PL> Sealed for super::Transactional<P, C, PL> {}
}
/// Sealed abstraction over publishers that can emit a batch of replies.
///
/// It is implemented for [`TypedPublisher`] (one publish per reply) and
/// [`Transactional`] (all replies inside one transaction).
pub trait ReplyPublisher: Sealed + Send + Sync {
    /// Codec used to encode each reply.
    type Codec: Codec;

    /// Returns the codec used to encode replies.
    #[doc(hidden)]
    fn reply_codec(&self) -> &Self::Codec;

    /// Encodes and publishes every element of `replies` under `name`,
    /// running each one through `pipeline` with `extensions`.
    #[doc(hidden)]
    fn publish_batch<'a, T>(
        &'a self,
        name: &'a str,
        replies: &'a [T],
        pipeline: &'a [Arc<dyn PublishMiddleware>],
        extensions: &'a Extensions,
    ) -> impl Future<Output = Result<(), BoxError>> + Send
    where
        T: Serialize + Sync;
}
impl<P, C, PL> ReplyPublisher for TypedPublisher<P, C, PL>
where
    P: Publisher,
    C: Codec,
    PL: PublishLayer,
{
    type Codec = C;

    fn reply_codec(&self) -> &C {
        self.codec()
    }

    /// Publishes the replies one at a time, in order, and stops at the first
    /// failure. Replies already sent stay sent.
    async fn publish_batch<'a, T>(
        &'a self,
        name: &'a str,
        replies: &'a [T],
        pipeline: &'a [Arc<dyn PublishMiddleware>],
        extensions: &'a Extensions,
    ) -> Result<(), BoxError>
    where
        T: Serialize + Sync,
    {
        for item in replies.iter() {
            self.publish(name, item, pipeline, extensions).await?;
        }
        Ok(())
    }
}
impl<P, C, PL> ReplyPublisher for Transactional<P, C, PL>
where
    P: TransactionalPublisher,
    C: Codec,
    PL: PublishLayer,
{
    type Codec = C;

    fn reply_codec(&self) -> &C {
        self.inner.codec()
    }

    /// Begins a transaction, publishes every reply, then commits.
    ///
    /// If a publish or the commit fails, the transaction is aborted and the
    /// triggering error is returned. Abort failures are only logged. If the
    /// transaction fails to begin, that error is returned with no abort.
    async fn publish_batch<'a, T>(
        &'a self,
        name: &'a str,
        replies: &'a [T],
        pipeline: &'a [Arc<dyn PublishMiddleware>],
        extensions: &'a Extensions,
    ) -> Result<(), BoxError>
    where
        T: Serialize + Sync,
    {
        let txn = &self.inner.publisher;
        txn.begin_transaction()
            .await
            .map_err(|e| Box::new(e) as BoxError)?;

        for item in replies {
            let sent = self.inner.publish(name, item, pipeline, extensions).await;
            if sent.is_err() {
                abort_quietly(txn).await;
                return sent;
            }
        }

        match txn.commit().await {
            Ok(_) => Ok(()),
            Err(err) => {
                abort_quietly(txn).await;
                Err(Box::new(err) as BoxError)
            }
        }
    }
}
/// Aborts the current transaction and logs any abort error at `warn`
/// instead of returning it.
///
/// The caller can then return the error that triggered the abort.
async fn abort_quietly<P: TransactionalPublisher>(publisher: &P) {
    let Err(err) = publisher.abort().await else {
        return;
    };
    warn!(target: "ruststream::dispatch", error = %err, "transaction abort failed");
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Test layer that appends a fixed suffix to the payload, so the order in
    /// which layers run is visible in the resulting bytes.
    struct Append(&'static [u8]);

    impl PublishLayer for Append {
        fn apply(&self, out: &mut Outgoing<'_>) {
            out.payload_mut().extend_from_slice(self.0);
        }
    }

    #[test]
    fn borrowed_name_is_not_owned() {
        let out = Outgoing::new("orders.created", b"payload".as_slice());
        assert!(matches!(out.name, Cow::Borrowed(_)));
        assert_eq!(out.name(), "orders.created");
        assert_eq!(out.payload(), b"payload");
    }

    #[test]
    fn owned_name_moves_in() {
        let computed = format!("orders.{}", 42);
        let out = Outgoing::new(computed, BytesMut::from(&b"x"[..]));
        assert!(matches!(out.name, Cow::Owned(_)));
        assert_eq!(out.name(), "orders.42");
    }

    #[test]
    fn payload_mutates_in_place() {
        let mut out = Outgoing::new("t", BytesMut::from(&b"body"[..]));
        out.payload_mut().extend_from_slice(b"!");
        assert_eq!(out.payload(), b"body!");
        out.set_payload(b"fresh".as_slice());
        assert_eq!(out.payload(), b"fresh");
    }

    #[test]
    fn set_name_and_headers() {
        let mut out = Outgoing::new("a", b"".as_slice());
        out.set_name("b");
        out.headers_mut().insert("k", "v");
        assert_eq!(out.name(), "b");
        assert_eq!(out.headers().get_str("k"), Some("v"));
    }

    #[test]
    fn identity_layer_is_a_no_op() {
        let mut out = Outgoing::new("t", b"body".as_slice());
        PublishIdentity.apply(&mut out);
        assert_eq!(out.name(), "t");
        assert_eq!(out.payload(), b"body");
    }

    #[test]
    fn stack_applies_inner_before_outer() {
        let stack = PublishStack {
            inner: Append(b"1"),
            outer: Append(b"2"),
        };
        let mut out = Outgoing::new("t", b"".as_slice());
        stack.apply(&mut out);
        assert_eq!(out.payload(), b"12");
    }

    #[test]
    fn nested_stacks_preserve_insertion_order() {
        // Mirrors `TypedPublisher::layer` chaining: ((1, 2), 3).
        let stack = PublishStack {
            inner: PublishStack {
                inner: Append(b"1"),
                outer: Append(b"2"),
            },
            outer: Append(b"3"),
        };
        let mut out = Outgoing::new("t", b"".as_slice());
        stack.apply(&mut out);
        assert_eq!(out.payload(), b"123");
    }
}