use async_nats::jetstream::consumer::{PullConsumer, pull::Config as ConsumerConfig};
use async_nats::{Client, ToServerAddrs};
use ruststream::{Broker, DescribeServer, ServerSpec, Subscribe};
use std::sync::Arc;
use tokio::sync::OnceCell;
use crate::{
error::NatsError, publisher::NatsPublisher, subscribe_options::SubscribeOptions,
subscriber::NatsSubscriber,
};
/// NATS-backed message broker.
///
/// The connection lives in a shared, set-once slot: clones of a `NatsBroker` share the
/// same `Arc`, so a client stored through one clone is visible to all of them.
#[derive(Clone)]
pub struct NatsBroker {
    /// Set-once client slot. It is empty until `Broker::connect` succeeds, or it is
    /// pre-filled by `NatsBroker::from_client` / `NatsBroker::connect`.
    client: Arc<OnceCell<Client>>,
    /// Address string given to `NatsBroker::new`; `None` when built from an existing client.
    addrs: Option<String>,
}
impl std::fmt::Debug for NatsBroker {
    /// Renders as `NatsBroker { .. }` without exposing any field.
    ///
    /// NOTE(review): the address string may embed credentials (`user:pass@host`), so
    /// keep it out of this output if fields are ever added here.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut out = f.debug_struct("NatsBroker");
        out.finish_non_exhaustive()
    }
}
impl NatsBroker {
    /// Creates a broker that will dial `addrs` lazily, on the first [`Broker::connect`] call.
    ///
    /// No network I/O happens here. Until a connection is made, subscribing (and
    /// publishing through [`NatsBroker::publisher`]) fails with [`NatsError::NotConnected`].
    #[must_use]
    pub fn new(addrs: impl Into<String>) -> Self {
        let addrs: String = addrs.into();
        Self {
            addrs: Some(addrs),
            client: Arc::new(OnceCell::new()),
        }
    }

    /// Connects to `addrs` right away and wraps the resulting client.
    ///
    /// # Errors
    ///
    /// Returns [`NatsError::Connect`] if the connection cannot be established.
    pub async fn connect(addrs: impl ToServerAddrs) -> Result<Self, NatsError> {
        match async_nats::connect(addrs).await {
            Ok(client) => Ok(Self::from_client(client)),
            Err(err) => Err(NatsError::Connect(Box::new(err))),
        }
    }

    /// Wraps an already-connected client.
    ///
    /// The client is stored immediately, so a later [`Broker::connect`] reuses it instead of dialling.
    #[must_use]
    pub fn from_client(client: Client) -> Self {
        let ready = OnceCell::new_with(Some(client));
        Self {
            addrs: None,
            client: Arc::new(ready),
        }
    }

    /// Returns a handle to the underlying client.
    ///
    /// # Panics
    ///
    /// Panics if no client has been stored yet, i.e. the broker was created with
    /// [`NatsBroker::new`] and [`Broker::connect`] has not succeeded.
    #[must_use]
    pub fn client(&self) -> Client {
        let Some(client) = self.client.get() else {
            panic!("NatsBroker::client() called before connect()");
        };
        client.clone()
    }

    /// Fallible counterpart of [`NatsBroker::client`] used by the subscribe paths.
    fn connected(&self) -> Result<Client, NatsError> {
        match self.client.get() {
            Some(client) => Ok(client.clone()),
            None => Err(NatsError::NotConnected),
        }
    }

    /// Subscribes according to `opts`.
    ///
    /// If `opts` asks for JetStream, this creates a JetStream pull consumer. Otherwise it
    /// creates a core-NATS subscription.
    ///
    /// # Errors
    ///
    /// - Propagates the error from `opts.validate()`.
    /// - Returns [`NatsError::NotConnected`] when no client is stored yet.
    /// - Returns [`NatsError::Subscribe`] / [`NatsError::JetStream`] when the server side fails.
    pub async fn subscribe(&self, opts: SubscribeOptions) -> Result<NatsSubscriber, NatsError> {
        opts.validate()?;
        if opts.is_jetstream() {
            return self.subscribe_jetstream(opts).await;
        }
        self.subscribe_core(opts).await
    }

    /// Creates a core-NATS subscription, joining the queue group when one is configured.
    async fn subscribe_core(&self, opts: SubscribeOptions) -> Result<NatsSubscriber, NatsError> {
        let client = self.connected()?;
        let subject = opts.subject().to_owned();

        let subscription = match opts.queue_group_ref() {
            Some(group) => client
                .queue_subscribe(subject.clone(), group.to_owned())
                .await
                .map_err(|err| NatsError::Subscribe(Box::new(err)))?,
            None => client
                .subscribe(subject.clone())
                .await
                .map_err(|err| NatsError::Subscribe(Box::new(err)))?,
        };

        Ok(NatsSubscriber::from_core(subject, subscription))
    }

    /// Creates a pull consumer on the configured JetStream stream and opens its message stream.
    ///
    /// The stream-name `expect` below assumes `opts` already passed `validate()` via
    /// [`NatsBroker::subscribe`].
    async fn subscribe_jetstream(
        &self,
        opts: SubscribeOptions,
    ) -> Result<NatsSubscriber, NatsError> {
        // The connection check deliberately runs before the stream-name lookup.
        let jetstream = async_nats::jetstream::new(self.connected()?);
        let stream_name = opts
            .stream_ref()
            .expect("validated jetstream option")
            .to_owned();

        let stream = jetstream
            .get_stream(&stream_name)
            .await
            .map_err(|err| NatsError::JetStream(Box::new(err)))?;

        let durable_name = opts.durable_ref().map(str::to_owned);
        let consumer_cfg = ConsumerConfig {
            durable_name,
            filter_subject: opts.filter_subject_or_default(),
            max_ack_pending: opts.max_ack_pending_or_default(),
            ack_wait: opts.ack_wait_or_default(),
            deliver_policy: opts.deliver_policy_or_default(),
            ..Default::default()
        };

        let consumer: PullConsumer = stream
            .create_consumer(consumer_cfg)
            .await
            .map_err(|err| NatsError::JetStream(Box::new(err)))?;
        let messages = consumer
            .messages()
            .await
            .map_err(|err| NatsError::JetStream(Box::new(err)))?;

        let subject = opts.subject().to_owned();
        let batch = opts.pull_batch_or_default();
        let expires = opts.pull_expires_or_default();
        Ok(NatsSubscriber::from_jetstream(
            subject,
            stream_name,
            messages,
            consumer,
            batch,
            expires,
        ))
    }

    /// Returns a publisher that shares this broker's client slot (the same `Arc`).
    #[must_use]
    pub fn publisher(&self) -> NatsPublisher {
        let shared_slot = Arc::clone(&self.client);
        NatsPublisher::new(shared_slot)
    }

    /// Drains the client if one is stored; otherwise does nothing.
    pub async fn shutdown_client(&self) {
        let Some(client) = self.client.get() else {
            return;
        };
        // Best-effort: the drain result is discarded so shutdown itself cannot fail here.
        let _ = client.drain().await;
    }
}
impl Broker for NatsBroker {
    type Error = NatsError;

    /// Establishes the connection if no client is stored yet.
    ///
    /// An existing client is reused as-is. This covers one stored by an earlier successful
    /// call or by [`NatsBroker::from_client`].
    ///
    /// Without a stored client, the method needs the address from [`NatsBroker::new`]:
    /// - If no address is stored either, it returns [`NatsError::NotConnected`].
    /// - If dialling the address fails, it returns [`NatsError::Connect`].
    async fn connect(&self) -> Result<(), Self::Error> {
        let dial = || async {
            let Some(addrs) = self.addrs.as_deref() else {
                return Err(NatsError::NotConnected);
            };
            async_nats::connect(addrs)
                .await
                .map_err(|err| NatsError::Connect(Box::new(err)))
        };
        self.client.get_or_try_init(dial).await.map(|_client| ())
    }

    /// Drains the client (best-effort) and always reports success.
    async fn shutdown(&self) -> Result<(), Self::Error> {
        self.shutdown_client().await;
        Ok(())
    }
}
// The delegation below is spelled `NatsBroker::subscribe` rather than `Self::subscribe`.
// That keeps the call to the inherent, options-taking method visually distinct from this
// trait method of the same name.
#[allow(clippy::use_self)]
impl Subscribe for NatsBroker {
    type Subscriber = NatsSubscriber;

    /// Subscribes to `name` using default [`SubscribeOptions`] (core NATS, no queue group
    /// unless the defaults say otherwise).
    async fn subscribe(&self, name: &str) -> Result<Self::Subscriber, Self::Error> {
        let opts = SubscribeOptions::new(name);
        NatsBroker::subscribe(self, opts).await
    }
}
impl DescribeServer for NatsBroker {
    /// Describes the NATS server this broker targets, with protocol `"nats"`.
    ///
    /// Once a client is stored, host and port come from `Client::server_info()`; IPv6
    /// literals are bracketed (`[::1]:4222`).
    ///
    /// Before that, the host is derived from the address given to [`NatsBroker::new`]:
    /// - Only the first comma-separated entry is used.
    /// - A leading `scheme://` (`nats://`, `tls://`, `ws://`, …) is dropped.
    /// - Any `user[:password]@` credentials are dropped, so secrets never reach the spec.
    ///
    /// With no address at all, the host is the empty string.
    fn describe_server(&self) -> ServerSpec {
        if let Some(client) = self.client.get() {
            let info = client.server_info();
            return ServerSpec::new(format_host_port(&info.host, info.port), "nats");
        }
        let host = self
            .addrs
            .as_deref()
            .map_or_else(String::new, display_host_from_addrs);
        ServerSpec::new(host, "nats")
    }
}

/// Joins `host` and `port`, bracketing IPv6 literals so the port separator stays unambiguous.
fn format_host_port(host: &str, port: impl std::fmt::Display) -> String {
    if host.contains(':') && !host.starts_with('[') {
        format!("[{host}]:{port}")
    } else {
        format!("{host}:{port}")
    }
}

/// Reduces a NATS address string to a credential-free `host[:port]` for display.
///
/// Only the first comma-separated entry is kept. The function then strips, exactly once
/// each:
/// - a leading `scheme://`;
/// - a `user[:password]@` prefix;
/// - trailing `/` characters.
fn display_host_from_addrs(addrs: &str) -> String {
    let first = addrs.split_once(',').map_or(addrs, |(head, _rest)| head).trim();
    let authority = first
        .split_once("://")
        .map_or(first, |(_scheme, rest)| rest);
    // `rsplit_once` so the host is whatever follows the last `@`.
    let host = authority
        .rsplit_once('@')
        .map_or(authority, |(_userinfo, host)| host);
    host.trim_end_matches('/').to_owned()
}
#[cfg(test)]
mod tests {
    use super::*;
    use ruststream::{OutgoingMessage, Publisher};

    /// `NatsBroker::new` must not dial.
    ///
    /// Both publishing and subscribing should report `NotConnected` until `connect`
    /// is called.
    #[tokio::test]
    async fn new_does_not_connect() {
        let lazy = NatsBroker::new("nats://127.0.0.1:4222");

        let message = OutgoingMessage::new("orders", b"{}".as_slice());
        let published = lazy.publisher().publish(message).await;
        assert!(matches!(published, Err(NatsError::NotConnected)));

        let subscribed = lazy.subscribe(SubscribeOptions::new("orders")).await;
        assert!(matches!(subscribed, Err(NatsError::NotConnected)));
    }
}