use futures_util::{Stream, StreamExt};
use std::{fmt::Display, time::Duration};
use async_nats::{
Request, client,
jetstream::{self, message::OutboundMessage, response::Response},
subject::ToSubject,
};
use serde::Deserialize;
/// Client-side cap on the number of messages accepted into one batch
/// (mirrors the server limit — see the `MaxMessagesExceeded` error text).
const MAX_BATCH_SIZE: u64 = 1000;
/// Extension trait adding atomic batch-publish entry points to any client
/// that can publish messages, send requests, and supply a default timeout.
pub trait BatchPublishExt:
    client::traits::Requester
    + client::traits::Publisher
    + jetstream::context::traits::TimeoutProvider
    + Clone
{
    /// Starts building an incremental batch ([`BatchPublishBuilder`]) that is
    /// fed message by message and finished with an explicit commit.
    fn batch_publish(&self) -> BatchPublishBuilder<Self>;
    /// Starts building a one-shot batch ([`BatchPublishAllBuilder`]) that
    /// consumes an entire stream of messages and commits on the last one.
    fn batch_publish_all(&self) -> BatchPublishAllBuilder<Self>;
}
/// Blanket implementation: any client satisfying the required traits gets
/// batch publishing for free.
impl<C> BatchPublishExt for C
where
    C: client::traits::Requester
        + client::traits::Publisher
        + jetstream::context::traits::TimeoutProvider
        + Clone,
{
    fn batch_publish(&self) -> BatchPublishBuilder<Self> {
        // Builders own their client, so hand over a clone.
        BatchPublishBuilder::new(self.clone())
    }
    fn batch_publish_all(&self) -> BatchPublishAllBuilder<Self> {
        BatchPublishAllBuilder::new(self.clone())
    }
}
/// Configures an incremental [`BatchPublish`]; finalize with [`BatchPublishBuilder::build`].
pub struct BatchPublishBuilder<C> {
    /// Client used for publishes and ack requests.
    client: C,
    /// Timeout applied to acked publishes and the commit request.
    timeout: Duration,
    /// Request a server ack for the first message of the batch (default `true`).
    ack_first: bool,
    /// If set, request a server ack on every N-th message.
    ack_every: Option<u64>,
}
impl<C> BatchPublishBuilder<C>
where
    C: client::traits::Requester
        + client::traits::Publisher
        + jetstream::context::traits::TimeoutProvider
        + Clone,
{
    /// Creates a builder that inherits its request timeout from the client.
    pub fn new(context: C) -> Self {
        // Read the timeout before moving the client into the builder; this
        // avoids the redundant `context.clone()` the field move otherwise forced.
        let timeout = context.timeout();
        Self {
            client: context,
            ack_first: true,
            timeout,
            ack_every: None,
        }
    }
    /// Requests a server ack on every `count`-th message, so long batches
    /// surface errors before the final commit.
    pub fn ack_every(mut self, count: u64) -> Self {
        self.ack_every = Some(count);
        self
    }
    /// Whether the first message should be sent as an acked request
    /// (defaults to `true`), catching misconfiguration immediately.
    pub fn ack_first(mut self, ack_first: bool) -> Self {
        self.ack_first = ack_first;
        self
    }
    /// Overrides the timeout used for acked publishes and the commit.
    pub fn timeout(mut self, duration: std::time::Duration) -> Self {
        self.timeout = duration;
        self
    }
    /// Finalizes the configuration, producing a [`BatchPublish`] with a
    /// freshly generated NUID batch id and a sequence counter at zero.
    pub fn build(self) -> BatchPublish<C> {
        BatchPublish {
            context: self.client,
            sequence: 0,
            batch_id: nuid::next().to_string(),
            ack_every: self.ack_every,
            ack_first: self.ack_first,
            timeout: self.timeout,
            closed: false,
        }
    }
}
/// An in-progress atomic batch publish.
///
/// Created via [`BatchPublishBuilder::build`]; messages are appended with
/// `add`/`add_message` and the batch is finalized with `commit`/`commit_message`.
pub struct BatchPublish<C> {
    /// Client used for publishes and ack requests.
    pub(crate) context: C,
    /// Sequence number of the most recently appended message (0 = empty batch).
    pub(crate) sequence: u64,
    /// Unique id stamped on every message via the `Nats-Batch-Id` header.
    pub(crate) batch_id: String,
    /// If set, request a server ack on every N-th message.
    ack_every: Option<u64>,
    /// Request a server ack for the first message of the batch.
    ack_first: bool,
    /// Timeout applied to acked publishes and the commit request.
    timeout: Duration,
    /// Set after any error; a closed batch rejects all further operations.
    closed: bool,
}
impl<C> BatchPublish<C>
where
    C: client::traits::Requester
        + client::traits::Publisher
        + jetstream::context::traits::TimeoutProvider
        + Clone,
{
    /// Returns the identifier shared by every message in this batch.
    pub fn batch_id(&self) -> &str {
        &self.batch_id
    }
    /// Number of messages appended to the batch so far.
    pub fn size(&self) -> u64 {
        self.sequence
    }
    /// Convenience wrapper around [`Self::add_message`] for a bare
    /// subject/payload pair with no headers.
    pub async fn add<S: ToSubject>(
        &mut self,
        subject: S,
        payload: bytes::Bytes,
    ) -> Result<(), BatchPublishError> {
        self.add_message(OutboundMessage {
            subject: subject.to_subject(),
            payload,
            headers: None,
        })
        .await
    }
    /// Appends one message to the batch.
    ///
    /// Fails if the batch was closed by a previous error, if the caller set a
    /// header reserved for the batch protocol, or if the batch already holds
    /// `MAX_BATCH_SIZE` messages. Depending on `ack_first`/`ack_every` the
    /// message is either sent as a request (awaiting a server ack) or
    /// published fire-and-forget. Any failure permanently closes the batch.
    pub async fn add_message(
        &mut self,
        mut message: jetstream::message::OutboundMessage,
    ) -> Result<(), BatchPublishError> {
        if self.closed {
            return Err(BatchPublishError::new(BatchPublishErrorKind::BatchClosed));
        }
        // Validate user headers against the current (pre-increment) sequence.
        Self::reject_protocol_headers(message.headers.as_ref(), self.sequence)?;
        if self.sequence >= MAX_BATCH_SIZE {
            return Err(BatchPublishError::new(
                BatchPublishErrorKind::MaxMessagesExceeded,
            ));
        }
        self.sequence += 1;
        // Stamp protocol headers with the just-incremented sequence.
        self.add_header(&mut message);
        // Ack points: every `ack_every`-th message, and/or the very first
        // message (`ack_first`) to surface configuration errors early.
        let result = if let Some(ack_every) = self.ack_every
            && self.sequence.is_multiple_of(ack_every)
        {
            self.add_request(message).await
        } else if self.ack_first && self.sequence == 1 {
            self.add_request(message).await
        } else {
            self.context
                .publish_message(message.into())
                .await
                .map_err(|e| BatchPublishError::with_source(BatchPublishErrorKind::Publish, e))
        };
        // A failed send leaves the server-side batch in an unknown state,
        // so the whole batch is poisoned and can no longer be used.
        if let Err(e) = result {
            self.closed = true;
            return Err(e);
        }
        Ok(())
    }
    /// True once an earlier error has poisoned the batch.
    pub fn is_closed(&self) -> bool {
        self.closed
    }
    /// Convenience wrapper around [`Self::commit_message`] for a bare
    /// subject/payload pair with no headers.
    pub async fn commit<S: ToSubject>(
        self,
        subject: S,
        payload: bytes::Bytes,
    ) -> Result<BatchPubAck, BatchPublishError> {
        self.commit_message(OutboundMessage {
            subject: subject.to_subject(),
            payload,
            headers: None,
        })
        .await
    }
    /// Sends the final message, marked with `Nats-Batch-Commit: 1`, and waits
    /// for the server's [`BatchPubAck`]. Consumes the batch either way; the
    /// commit message itself counts against `MAX_BATCH_SIZE`.
    pub async fn commit_message(
        mut self,
        mut message: jetstream::message::OutboundMessage,
    ) -> Result<BatchPubAck, BatchPublishError> {
        if self.closed {
            return Err(BatchPublishError::new(BatchPublishErrorKind::BatchClosed));
        }
        Self::reject_protocol_headers(message.headers.as_ref(), self.sequence)?;
        if self.sequence >= MAX_BATCH_SIZE {
            return Err(BatchPublishError::new(
                BatchPublishErrorKind::MaxMessagesExceeded,
            ));
        }
        self.sequence += 1;
        self.add_header(&mut message);
        let headers = message
            .headers
            .get_or_insert_with(async_nats::HeaderMap::new);
        headers.insert("Nats-Batch-Commit", "1");
        self.commit_request(message).await
    }
    /// Rejects user-supplied headers that collide with the batch protocol.
    /// `Nats-Expected-Last-Sequence` is only permitted on the first message
    /// of the batch (`prior_sequence == 0`).
    fn reject_protocol_headers(
        headers: Option<&async_nats::HeaderMap>,
        prior_sequence: u64,
    ) -> Result<(), BatchPublishError> {
        let Some(headers) = headers else {
            return Ok(());
        };
        // Headers owned by the batch protocol; never settable by the caller.
        const REJECTED: &[&str] = &[
            "Nats-Msg-Id",
            "Nats-Expected-Last-Msg-Id",
            "Nats-Batch-Commit",
            "Nats-Batch-Id",
            "Nats-Batch-Sequence",
        ];
        if REJECTED.iter().any(|h| headers.get(*h).is_some()) {
            return Err(BatchPublishError::new(
                BatchPublishErrorKind::BatchPublishUnsupportedHeader,
            ));
        }
        if prior_sequence >= 1 && headers.get("Nats-Expected-Last-Sequence").is_some() {
            return Err(BatchPublishError::new(
                BatchPublishErrorKind::BatchPublishUnsupportedHeader,
            ));
        }
        Ok(())
    }
    /// Abandons the batch without committing; simply drops `self`.
    /// NOTE(review): presumably the server expires the uncommitted batch on
    /// its own — confirm against server behavior.
    pub fn discard(self) {
    }
    /// Stamps the protocol headers (`Nats-Batch-Id`, `Nats-Batch-Sequence`)
    /// onto an outgoing message; callers invoke this after incrementing
    /// `self.sequence`, so the header carries the 1-based position.
    fn add_header(&self, message: &mut jetstream::message::OutboundMessage) {
        let headers = message
            .headers
            .get_or_insert_with(async_nats::HeaderMap::new);
        headers.insert("Nats-Batch-Id", self.batch_id.clone());
        headers.insert("Nats-Batch-Sequence", self.sequence.to_string());
    }
    /// Publishes one batch message as a request and checks the server reply.
    /// An empty reply payload counts as success; otherwise the JSON response
    /// is decoded and any API error is mapped to a [`BatchPublishErrorKind`].
    async fn add_request(&self, message: OutboundMessage) -> Result<(), BatchPublishError> {
        let request = Request {
            payload: Some(message.payload),
            headers: message.headers,
            timeout: Some(Some(self.timeout)),
            inbox: None,
        };
        let response = self
            .context
            .send_request(message.subject, request)
            .await
            .map_err(|e| BatchPublishError::with_source(BatchPublishErrorKind::Request, e))?;
        if response.payload.is_empty() {
            return Ok(());
        }
        let resp: Response<()> = serde_json::from_slice(response.payload.as_ref())
            .map_err(|e| BatchPublishError::with_source(BatchPublishErrorKind::Serialization, e))?;
        match resp {
            Response::Err { error } => {
                let kind = BatchPublishErrorKind::from_api_error(&error);
                Err(BatchPublishError::with_source(kind, error))
            }
            Response::Ok(()) => Ok(()),
        }
    }
    /// Sends the commit message and validates the resulting ack: the stream
    /// name must be non-empty and the echoed batch id and size must match
    /// what this client sent, otherwise [`BatchPublishErrorKind::InvalidAck`].
    async fn commit_request(
        &self,
        message: OutboundMessage,
    ) -> Result<BatchPubAck, BatchPublishError> {
        let request = Request {
            payload: Some(message.payload),
            headers: message.headers,
            timeout: Some(Some(self.timeout)),
            inbox: None,
        };
        let response = self
            .context
            .send_request(message.subject, request)
            .await
            .map_err(|e| BatchPublishError::with_source(BatchPublishErrorKind::Request, e))?;
        let resp: Response<BatchPubAck> = serde_json::from_slice(response.payload.as_ref())
            .map_err(|e| BatchPublishError::with_source(BatchPublishErrorKind::Serialization, e))?;
        match resp {
            Response::Err { error } => {
                let kind = BatchPublishErrorKind::from_api_error(&error);
                Err(BatchPublishError::with_source(kind, error))
            }
            Response::Ok(ack) => {
                // Reject acks that don't match what we sent: at this point
                // `self.sequence` includes the commit message itself.
                if ack.stream.is_empty()
                    || ack.batch_id != self.batch_id
                    || ack.batch_size != self.sequence
                {
                    return Err(BatchPublishError::new(BatchPublishErrorKind::InvalidAck));
                }
                Ok(ack)
            }
        }
    }
}
/// Server acknowledgment returned when a batch is committed.
#[derive(Debug, Deserialize)]
pub struct BatchPubAck {
    /// Name of the stream that stored the batch.
    pub stream: String,
    /// Sequence number reported in the ack (wire field `seq`).
    #[serde(rename = "seq")]
    pub sequence: u64,
    /// JetStream domain, when reported by the server.
    #[serde(default)]
    pub domain: Option<String>,
    /// Batch id echoed back by the server (wire field `batch`).
    #[serde(rename = "batch")]
    pub batch_id: String,
    /// Number of messages the server counted in the batch (wire field `count`).
    #[serde(rename = "count")]
    pub batch_size: u64,
    /// Optional server-provided value (wire field `val`); not validated here.
    #[serde(default, rename = "val")]
    pub value: Option<String>,
}
/// Configures a one-shot batch publish over a whole stream of messages;
/// run with [`BatchPublishAllBuilder::publish`].
pub struct BatchPublishAllBuilder<C> {
    /// Client used for publishes and ack requests.
    client: C,
    /// Timeout applied to acked publishes and the commit request.
    timeout: Duration,
    /// Request a server ack for the first message of the batch (default `true`).
    ack_first: bool,
    /// If set, request a server ack on every N-th message.
    ack_every: Option<u64>,
}
impl<C> BatchPublishAllBuilder<C>
where
    C: client::traits::Requester
        + client::traits::Publisher
        + jetstream::context::traits::TimeoutProvider
        + Clone,
{
    /// Creates a builder that inherits its request timeout from the client.
    pub fn new(client: C) -> Self {
        // Read the timeout before moving the client into the builder; this
        // avoids the redundant `client.clone()` the field move otherwise forced.
        let timeout = client.timeout();
        Self {
            client,
            ack_first: true,
            timeout,
            ack_every: None,
        }
    }
    /// Requests a server ack on every `count`-th message, so long batches
    /// surface errors before the final commit.
    pub fn ack_every(mut self, count: u64) -> Self {
        self.ack_every = Some(count);
        self
    }
    /// Whether the first message should be sent as an acked request
    /// (defaults to `true`), catching misconfiguration immediately.
    pub fn ack_first(mut self, ack_first: bool) -> Self {
        self.ack_first = ack_first;
        self
    }
    /// Overrides the timeout used for acked publishes and the commit.
    pub fn timeout(mut self, duration: std::time::Duration) -> Self {
        self.timeout = duration;
        self
    }
    /// Publishes every message from `messages` as one atomic batch and
    /// returns the commit ack.
    ///
    /// The last message of the stream is sent as the commit; an empty stream
    /// yields [`BatchPublishErrorKind::EmptyBatch`]. Errors from individual
    /// messages propagate immediately via `?`.
    pub async fn publish<S>(self, messages: S) -> Result<BatchPubAck, BatchPublishError>
    where
        S: Stream<Item = OutboundMessage> + Unpin,
    {
        let mut batch = BatchPublish {
            context: self.client,
            sequence: 0,
            batch_id: nuid::next().to_string(),
            ack_every: self.ack_every,
            ack_first: self.ack_first,
            timeout: self.timeout,
            closed: false,
        };
        // Hold one message back so the stream's final item can be sent as
        // the commit message.
        let mut last_msg = None;
        futures_util::pin_mut!(messages);
        while let Some(msg) = messages.next().await {
            if let Some(prev) = last_msg.replace(msg) {
                batch.add_message(prev).await?;
            }
        }
        match last_msg {
            Some(msg) => batch.commit_message(msg).await,
            None => Err(BatchPublishError::new(BatchPublishErrorKind::EmptyBatch)),
        }
    }
}
/// Error returned by batch-publish operations, carrying a [`BatchPublishErrorKind`].
pub type BatchPublishError = async_nats::error::Error<BatchPublishErrorKind>;
/// Failure categories for batch publishing.
///
/// The `BatchPublish*` variants correspond to JetStream API error codes
/// (see `from_api_error`); the rest are produced client-side.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[non_exhaustive]
pub enum BatchPublishErrorKind {
    /// An acked publish or commit request to the server failed.
    Request,
    /// A fire-and-forget publish of a batch message failed.
    Publish,
    /// Decoding a server response (or encoding) failed.
    Serialization,
    /// The batch reached the `MAX_BATCH_SIZE` message limit.
    MaxMessagesExceeded,
    /// A commit was attempted on a batch containing no messages.
    EmptyBatch,
    /// The batch was closed by a prior error and rejects further use.
    BatchClosed,
    /// The server's commit ack failed invariant checks (stream/batch id/size).
    InvalidAck,
    /// Server: batch (atomic) publishing is not enabled on the stream.
    BatchPublishNotEnabled,
    /// Server: the batch was incomplete and abandoned.
    BatchPublishIncomplete,
    /// Server: too many inflight batches.
    BatchPublishTooManyInflight,
    /// Server: the batch sequence header is missing or malformed.
    BatchPublishMissingSeq,
    /// Server: the batch id is invalid.
    BatchPublishInvalidId,
    /// Server: the batch commit marker is invalid.
    BatchPublishInvalidCommit,
    /// Server: two messages in the batch share the same `Nats-Msg-Id`.
    BatchPublishDuplicateMsgId,
    /// Server: the stream is a mirror, incompatible with atomic publish.
    BatchPublishMirror,
    /// A header unsupported in batch mode was present.
    BatchPublishUnsupportedHeader,
    /// Any other (unmapped) error.
    Other,
}
impl BatchPublishErrorKind {
    /// Maps a JetStream API error code onto the matching error kind;
    /// unrecognized codes collapse into [`Self::Other`].
    fn from_api_error(error: &async_nats::jetstream::Error) -> Self {
        use async_nats::jetstream::ErrorCode;
        match error.error_code() {
            ErrorCode::ATOMIC_PUBLISH_DISABLED => Self::BatchPublishNotEnabled,
            ErrorCode::ATOMIC_PUBLISH_INCOMPLETE_BATCH => Self::BatchPublishIncomplete,
            ErrorCode::ATOMIC_PUBLISH_TOO_MANY_INFLIGHT => Self::BatchPublishTooManyInflight,
            ErrorCode::ATOMIC_PUBLISH_UNSUPPORTED_HEADER => Self::BatchPublishUnsupportedHeader,
            ErrorCode::ATOMIC_PUBLISH_TOO_LARGE_BATCH => Self::MaxMessagesExceeded,
            ErrorCode::ATOMIC_PUBLISH_MISSING_SEQ => Self::BatchPublishMissingSeq,
            ErrorCode::ATOMIC_PUBLISH_INVALID_BATCH_ID => Self::BatchPublishInvalidId,
            ErrorCode::ATOMIC_PUBLISH_INVALID_BATCH_COMMIT => Self::BatchPublishInvalidCommit,
            ErrorCode::ATOMIC_PUBLISH_CONTAINS_DUPLICATE_MESSAGE => {
                Self::BatchPublishDuplicateMsgId
            }
            ErrorCode::MIRROR_WITH_ATOMIC_PUBLISH => Self::BatchPublishMirror,
            _ => Self::Other,
        }
    }
}
impl Display for BatchPublishErrorKind {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Every kind maps to a fixed description; resolve it first, then
        // emit it with a single write.
        let description = match self {
            Self::Request => "request failed",
            Self::Publish => "publish failed",
            Self::Serialization => "serialization/deserialization error",
            Self::MaxMessagesExceeded => "batch exceeds server limit (1000 messages)",
            Self::EmptyBatch => "empty batch cannot be committed",
            Self::BatchPublishNotEnabled => "batch publishing not enabled on stream",
            Self::BatchPublishIncomplete => "batch publish is incomplete and was abandoned",
            Self::BatchPublishTooManyInflight => {
                "server has too many inflight batches (50 per stream, 1000 server-wide)"
            }
            Self::BatchPublishMissingSeq => "batch sequence header missing or malformed",
            Self::BatchPublishInvalidId => "batch id is invalid (e.g. exceeds 64 characters)",
            Self::BatchPublishInvalidCommit => "batch commit marker is invalid",
            Self::BatchPublishDuplicateMsgId => {
                "two messages in the batch share the same Nats-Msg-Id"
            }
            Self::BatchPublishMirror => {
                "stream is a mirror; mirrors are incompatible with atomic publish"
            }
            Self::BatchPublishUnsupportedHeader => {
                "batch contains an unsupported header (e.g. Nats-Msg-Id, Nats-Expected-Last-Msg-Id, or a protocol header set by the user)"
            }
            Self::BatchClosed => "batch was closed by a prior error and cannot be reused",
            Self::InvalidAck => "server commit ack failed invariant checks",
            Self::Other => "other error",
        };
        f.write_str(description)
    }
}