use core::pin::Pin;
use core::task::{Context, Poll};
use crate::connection::State;
use crate::subject::ToSubject;
use crate::{PublishMessage, ServerInfo};
use super::{header::HeaderMap, status::StatusCode, Command, Message, Subscriber};
use crate::error::Error;
use bytes::Bytes;
use futures::future::TryFutureExt;
use futures::{Sink, SinkExt as _, StreamExt};
use once_cell::sync::Lazy;
use portable_atomic::AtomicU64;
use regex::Regex;
use std::fmt::Display;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::Arc;
use std::time::Duration;
use thiserror::Error;
use tokio::sync::{mpsc, oneshot};
use tokio_util::sync::PollSender;
use tracing::trace;
// Lazily compiled pattern for a version string, anchored at the start of the input: an optional
// leading `v`, then up to three numeric components with optional dots between them.
// Capture groups 1-3 hold the components; groups 2 and 3 are optional and may not participate
// in a match.
static VERSION_RE: Lazy<Regex> =
Lazy::new(|| Regex::new(r"\Av?([0-9]+)\.?([0-9]+)?\.?([0-9]+)?").unwrap());
/// Error returned by the publish methods of [`Client`] and by its `Sink` implementation.
/// The failure category is described by [`PublishErrorKind`].
pub type PublishError = Error<PublishErrorKind>;
impl From<tokio::sync::mpsc::error::SendError<Command>> for PublishError {
fn from(err: tokio::sync::mpsc::error::SendError<Command>) -> Self {
PublishError::with_source(PublishErrorKind::Send, err)
}
}
impl From<tokio_util::sync::PollSendError<Command>> for PublishError {
    /// Wraps a failed poll-based send as a [`PublishErrorKind::Send`] error, keeping the
    /// original error as the source.
    fn from(source: tokio_util::sync::PollSendError<Command>) -> Self {
        Self::with_source(PublishErrorKind::Send, source)
    }
}
/// Categories of failure reported by a publish operation.
#[derive(Copy, Clone, Debug, PartialEq)]
pub enum PublishErrorKind {
    /// The payload was larger than the maximum payload size known to the client.
    MaxPayloadExceeded,
    /// The message could not be handed over to the command channel.
    Send,
}

impl Display for PublishErrorKind {
    /// Writes a short, human-readable description of the error kind.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let description = match self {
            PublishErrorKind::MaxPayloadExceeded => "max payload size exceeded",
            PublishErrorKind::Send => "failed to send message",
        };
        f.write_str(description)
    }
}
/// Cloneable handle through which the rest of the crate and its users publish, subscribe and
/// issue requests. Every operation is turned into a [`Command`] and sent over an mpsc channel;
/// the consumer of that channel lives outside this file.
#[derive(Clone, Debug)]
pub struct Client {
/// Watch receiver holding the most recent [`ServerInfo`]; read by `server_info`.
info: tokio::sync::watch::Receiver<ServerInfo>,
/// Watch receiver holding the most recent connection [`State`]; read by `connection_state`.
pub(crate) state: tokio::sync::watch::Receiver<State>,
/// Channel over which all commands are sent by the async methods.
pub(crate) sender: mpsc::Sender<Command>,
/// Poll-based wrapper around a clone of `sender`, used by the `Sink` implementation.
poll_sender: PollSender<Command>,
/// Source of subscription ids; shared between clones and started at 1 by `new`.
next_subscription_id: Arc<AtomicU64>,
/// Capacity of the bounded channel created for each new subscription.
subscription_capacity: usize,
/// Prefix prepended to generated inbox subjects (see `new_inbox`).
inbox_prefix: Arc<str>,
/// Default timeout applied to requests that do not set their own.
request_timeout: Option<Duration>,
/// Maximum payload size, in bytes, that `publish` accepts; shared and updated atomically.
max_payload: Arc<AtomicUsize>,
/// Shared statistics counters handed out by `statistics`.
connection_stats: Arc<Statistics>,
}
impl Sink<PublishMessage> for Client {
type Error = PublishError;
fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
self.poll_sender.poll_ready_unpin(cx).map_err(Into::into)
}
fn start_send(mut self: Pin<&mut Self>, msg: PublishMessage) -> Result<(), Self::Error> {
self.poll_sender
.start_send_unpin(Command::Publish(msg))
.map_err(Into::into)
}
fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
self.poll_sender.poll_flush_unpin(cx).map_err(Into::into)
}
fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
self.poll_sender.poll_close_unpin(cx).map_err(Into::into)
}
}
impl Client {
/// Assembles a [`Client`] from its channel endpoints and configuration.
///
/// Subscription ids start at 1, and the poll-based sender used by the `Sink` implementation
/// is built from a clone of `sender`.
#[allow(clippy::too_many_arguments)]
pub(crate) fn new(
    info: tokio::sync::watch::Receiver<ServerInfo>,
    state: tokio::sync::watch::Receiver<State>,
    sender: mpsc::Sender<Command>,
    capacity: usize,
    inbox_prefix: String,
    request_timeout: Option<Duration>,
    max_payload: Arc<AtomicUsize>,
    statistics: Arc<Statistics>,
) -> Client {
    Client {
        info,
        state,
        // Clone first: `sender` itself is moved into the struct on the next line.
        poll_sender: PollSender::new(sender.clone()),
        sender,
        next_subscription_id: Arc::new(AtomicU64::new(1)),
        subscription_capacity: capacity,
        inbox_prefix: Arc::from(inbox_prefix),
        request_timeout,
        max_payload,
        connection_stats: statistics,
    }
}
/// Returns the default request timeout configured on this client, if any.
///
/// `send_request` falls back to this value when a [`Request`] does not set its own timeout.
pub fn timeout(&self) -> Option<Duration> {
self.request_timeout
}
/// Returns a copy of the most recent [`ServerInfo`] held by the client's watch receiver.
pub fn server_info(&self) -> ServerInfo {
    let latest = self.info.borrow();
    latest.clone()
}
/// Returns `true` when the version reported by the server is at least `major.minor.patch`.
///
/// The version string from [`ServerInfo`] is parsed with `VERSION_RE`. Components that are
/// absent from the string (for example the patch in `"2.10"`) are treated as `0`.
///
/// Returns `false`, rather than panicking, when the version string does not match the
/// pattern or when a component does not fit in an `i64`.
pub fn is_server_compatible(&self, major: i64, minor: i64, patch: i64) -> bool {
    let info = self.server_info();
    let captures = match VERSION_RE.captures(&info.version) {
        Some(captures) => captures,
        None => return false,
    };
    // Groups 2 and 3 are optional in the pattern, so a group that did not participate counts
    // as 0. A group that matched but cannot be parsed makes the whole version unusable.
    let component = |index: usize| -> Option<i64> {
        captures
            .get(index)
            .map_or(Some(0), |found| found.as_str().parse::<i64>().ok())
    };
    match (component(1), component(2), component(3)) {
        // Tuples compare lexicographically: major first, then minor, then patch.
        (Some(server_major), Some(server_minor), Some(server_patch)) => {
            (server_major, server_minor, server_patch) >= (major, minor, patch)
        }
        _ => false,
    }
}
/// Publishes `payload` to `subject` without a reply subject or headers.
///
/// The payload length is checked against the client's current maximum payload size before
/// anything is sent.
///
/// # Errors
///
/// Returns [`PublishErrorKind::MaxPayloadExceeded`] when `payload` is longer than the maximum
/// payload size, and [`PublishErrorKind::Send`] when the command cannot be sent over the
/// command channel.
pub async fn publish<S: ToSubject>(
    &self,
    subject: S,
    payload: Bytes,
) -> Result<(), PublishError> {
    let subject = subject.to_subject();
    let max_payload = self.max_payload.load(Ordering::Relaxed);
    if payload.len() > max_payload {
        return Err(PublishError::with_source(
            PublishErrorKind::MaxPayloadExceeded,
            // The limit goes first and the offending size second, matching the wording.
            format!(
                "Payload size limit of {} exceeded by message size of {}",
                max_payload,
                payload.len()
            ),
        ));
    }
    self.sender
        .send(Command::Publish(PublishMessage {
            subject,
            payload,
            reply: None,
            headers: None,
        }))
        .await?;
    Ok(())
}
pub async fn publish_with_headers<S: ToSubject>(
&self,
subject: S,
headers: HeaderMap,
payload: Bytes,
) -> Result<(), PublishError> {
let subject = subject.to_subject();
self.sender
.send(Command::Publish(PublishMessage {
subject,
payload,
reply: None,
headers: Some(headers),
}))
.await?;
Ok(())
}
pub async fn publish_with_reply<S: ToSubject, R: ToSubject>(
&self,
subject: S,
reply: R,
payload: Bytes,
) -> Result<(), PublishError> {
let subject = subject.to_subject();
let reply = reply.to_subject();
self.sender
.send(Command::Publish(PublishMessage {
subject,
payload,
reply: Some(reply),
headers: None,
}))
.await?;
Ok(())
}
pub async fn publish_with_reply_and_headers<S: ToSubject, R: ToSubject>(
&self,
subject: S,
reply: R,
headers: HeaderMap,
payload: Bytes,
) -> Result<(), PublishError> {
let subject = subject.to_subject();
let reply = reply.to_subject();
self.sender
.send(Command::Publish(PublishMessage {
subject,
payload,
reply: Some(reply),
headers: Some(headers),
}))
.await?;
Ok(())
}
/// Sends `payload` to `subject` as a request and waits for a single reply.
///
/// Emits a trace event with the subject and payload length, then delegates to
/// `send_request` with a [`Request`] that carries only the payload.
pub async fn request<S: ToSubject>(
    &self,
    subject: S,
    payload: Bytes,
) -> Result<Message, RequestError> {
    let subject = subject.to_subject();
    trace!(
        "request sent to subject: {} ({})",
        subject.as_ref(),
        payload.len()
    );
    self.send_request(subject, Request::new().payload(payload))
        .await
}
/// Sends `payload` with `headers` to `subject` as a request and waits for a single reply.
///
/// Delegates to `send_request` with a [`Request`] that carries the headers and payload.
pub async fn request_with_headers<S: ToSubject>(
    &self,
    subject: S,
    headers: HeaderMap,
    payload: Bytes,
) -> Result<Message, RequestError> {
    let request = Request::new().headers(headers).payload(payload);
    self.send_request(subject.to_subject(), request).await
}
pub async fn send_request<S: ToSubject>(
&self,
subject: S,
request: Request,
) -> Result<Message, RequestError> {
let subject = subject.to_subject();
if let Some(inbox) = request.inbox {
let timeout = request.timeout.unwrap_or(self.request_timeout);
let mut subscriber = self.subscribe(inbox.clone()).await?;
let payload: Bytes = request.payload.unwrap_or_default();
match request.headers {
Some(headers) => {
self.publish_with_reply_and_headers(subject, inbox, headers, payload)
.await?
}
None => self.publish_with_reply(subject, inbox, payload).await?,
}
let request = match timeout {
Some(timeout) => {
tokio::time::timeout(timeout, subscriber.next())
.map_err(|err| RequestError::with_source(RequestErrorKind::TimedOut, err))
.await?
}
None => subscriber.next().await,
};
match request {
Some(message) => {
if message.status == Some(StatusCode::NO_RESPONDERS) {
return Err(RequestError::with_source(
RequestErrorKind::NoResponders,
"no responders",
));
}
Ok(message)
}
None => Err(RequestError::with_source(
RequestErrorKind::Other,
"broken pipe",
)),
}
} else {
let (sender, receiver) = oneshot::channel();
let payload = request.payload.unwrap_or_default();
let respond = self.new_inbox().into();
let headers = request.headers;
self.sender
.send(Command::Request {
subject,
payload,
respond,
headers,
sender,
})
.map_err(|err| RequestError::with_source(RequestErrorKind::Other, err))
.await?;
let timeout = request.timeout.unwrap_or(self.request_timeout);
let request = match timeout {
Some(timeout) => {
tokio::time::timeout(timeout, receiver)
.map_err(|err| RequestError::with_source(RequestErrorKind::TimedOut, err))
.await?
}
None => receiver.await,
};
match request {
Ok(message) => {
if message.status == Some(StatusCode::NO_RESPONDERS) {
return Err(RequestError::with_source(
RequestErrorKind::NoResponders,
"no responders",
));
}
Ok(message)
}
Err(err) => Err(RequestError::with_source(RequestErrorKind::Other, err)),
}
}
}
/// Builds a new inbox subject of the form `<inbox_prefix>.<id>`, where the id comes from
/// `nuid::next()`.
pub fn new_inbox(&self) -> String {
    let prefix = &self.inbox_prefix;
    let token = nuid::next();
    format!("{}.{}", prefix, token)
}
/// Subscribes to `subject` without a queue group.
///
/// Takes the next subscription id, creates a bounded channel with the client's subscription
/// capacity for incoming messages, and sends a `Command::Subscribe` over the command channel.
///
/// # Errors
///
/// Returns [`SubscribeError`] when the subscribe command cannot be sent.
pub async fn subscribe<S: ToSubject>(&self, subject: S) -> Result<Subscriber, SubscribeError> {
    let subject = subject.to_subject();
    let sid = self.next_subscription_id.fetch_add(1, Ordering::Relaxed);
    let (messages_tx, messages_rx) = mpsc::channel(self.subscription_capacity);
    let command = Command::Subscribe {
        sid,
        subject,
        queue_group: None,
        sender: messages_tx,
    };
    self.sender.send(command).await?;
    Ok(Subscriber::new(sid, self.sender.clone(), messages_rx))
}
/// Subscribes to `subject` as a member of `queue_group`.
///
/// Works like `subscribe`, except that the `Command::Subscribe` it sends carries the queue
/// group name.
///
/// # Errors
///
/// Returns [`SubscribeError`] when the subscribe command cannot be sent.
pub async fn queue_subscribe<S: ToSubject>(
    &self,
    subject: S,
    queue_group: String,
) -> Result<Subscriber, SubscribeError> {
    let subject = subject.to_subject();
    let sid = self.next_subscription_id.fetch_add(1, Ordering::Relaxed);
    let (messages_tx, messages_rx) = mpsc::channel(self.subscription_capacity);
    let command = Command::Subscribe {
        sid,
        subject,
        queue_group: Some(queue_group),
        sender: messages_tx,
    };
    self.sender.send(command).await?;
    Ok(Subscriber::new(sid, self.sender.clone(), messages_rx))
}
/// Sends a `Command::Flush` carrying a oneshot sender and waits until that oneshot resolves.
///
/// # Errors
///
/// Returns [`FlushErrorKind::SendError`] when the command cannot be sent, and
/// [`FlushErrorKind::FlushError`] when the oneshot receiver fails instead of yielding a value.
pub async fn flush(&self) -> Result<(), FlushError> {
    let (observer, flushed) = tokio::sync::oneshot::channel();
    if let Err(err) = self.sender.send(Command::Flush { observer }).await {
        return Err(FlushError::with_source(FlushErrorKind::SendError, err));
    }
    match flushed.await {
        Ok(_) => Ok(()),
        Err(err) => Err(FlushError::with_source(FlushErrorKind::FlushError, err)),
    }
}
pub async fn drain(&self) -> Result<(), DrainError> {
self.sender.send(Command::Drain { sid: None }).await?;
Ok(())
}
/// Returns a copy of the most recent connection [`State`] held by the client's watch receiver.
pub fn connection_state(&self) -> State {
    let current = self.state.borrow();
    current.clone()
}
pub async fn force_reconnect(&self) -> Result<(), ReconnectError> {
self.sender
.send(Command::Reconnect)
.await
.map_err(Into::into)
}
/// Returns a shared handle to the client's [`Statistics`] counters.
pub fn statistics(&self) -> Arc<Statistics> {
    Arc::clone(&self.connection_stats)
}
}
/// Builder describing a request to be sent with `Client::send_request`.
/// Every field is optional and starts out unset.
#[derive(Default)]
pub struct Request {
/// Payload to send; an empty payload is used when unset.
payload: Option<Bytes>,
/// Headers to attach to the request, if any.
headers: Option<HeaderMap>,
/// Per-request timeout. The outer `None` means "use the client's default"; `Some(None)`
/// means "wait without a timeout".
timeout: Option<Option<Duration>>,
/// Custom inbox to receive the reply on; when unset, `send_request` generates one.
inbox: Option<String>,
}
impl Request {
pub fn new() -> Request {
Default::default()
}
pub fn payload(mut self, payload: Bytes) -> Request {
self.payload = Some(payload);
self
}
pub fn headers(mut self, headers: HeaderMap) -> Request {
self.headers = Some(headers);
self
}
pub fn timeout(mut self, timeout: Option<Duration>) -> Request {
self.timeout = Some(timeout);
self
}
pub fn inbox(mut self, inbox: String) -> Request {
self.inbox = Some(inbox);
self
}
}
#[derive(Error, Debug)]
#[error("failed to send reconnect: {0}")]
pub struct ReconnectError(#[source] crate::Error);
impl From<tokio::sync::mpsc::error::SendError<Command>> for ReconnectError {
fn from(err: tokio::sync::mpsc::error::SendError<Command>) -> Self {
ReconnectError(Box::new(err))
}
}
#[derive(Error, Debug)]
#[error("failed to send subscribe: {0}")]
pub struct SubscribeError(#[source] crate::Error);
impl From<tokio::sync::mpsc::error::SendError<Command>> for SubscribeError {
fn from(err: tokio::sync::mpsc::error::SendError<Command>) -> Self {
SubscribeError(Box::new(err))
}
}
#[derive(Error, Debug)]
#[error("failed to send drain: {0}")]
pub struct DrainError(#[source] crate::Error);
impl From<tokio::sync::mpsc::error::SendError<Command>> for DrainError {
fn from(err: tokio::sync::mpsc::error::SendError<Command>) -> Self {
DrainError(Box::new(err))
}
}
/// Categories of failure reported by a request.
#[derive(Clone, Copy, Debug, PartialEq)]
pub enum RequestErrorKind {
    /// No reply arrived before the timeout elapsed.
    TimedOut,
    /// The reply carried the "no responders" status.
    NoResponders,
    /// Any other failure, such as a failed send or a closed reply channel.
    Other,
}

impl Display for RequestErrorKind {
    /// Writes a short, human-readable description of the error kind.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let description = match self {
            Self::TimedOut => "request timed out",
            Self::NoResponders => "no responders",
            Self::Other => "request failed",
        };
        f.write_str(description)
    }
}
/// Error returned by the request methods of `Client`. The failure category is described by
/// [`RequestErrorKind`].
pub type RequestError = Error<RequestErrorKind>;
impl From<PublishError> for RequestError {
    /// Wraps a publish failure as a [`RequestErrorKind::Other`] error, keeping it as the source.
    fn from(source: PublishError) -> Self {
        Self::with_source(RequestErrorKind::Other, source)
    }
}
impl From<SubscribeError> for RequestError {
    /// Wraps a subscribe failure as a [`RequestErrorKind::Other`] error, keeping it as the
    /// source.
    fn from(source: SubscribeError) -> Self {
        Self::with_source(RequestErrorKind::Other, source)
    }
}
/// Categories of failure reported by a flush.
#[derive(Clone, Copy, Debug, PartialEq)]
pub enum FlushErrorKind {
    /// The flush command could not be sent over the command channel.
    SendError,
    /// The flush command was sent but its completion was never reported.
    FlushError,
}

impl Display for FlushErrorKind {
    /// Writes a short, human-readable description of the error kind.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let description = match self {
            Self::SendError => "failed to send flush request",
            Self::FlushError => "flush failed",
        };
        f.write_str(description)
    }
}
/// Error returned by `Client::flush`. The failure category is described by [`FlushErrorKind`].
pub type FlushError = Error<FlushErrorKind>;
/// Atomic counters exposed through `Client::statistics`. All counters start at zero via
/// `Default`; they are updated outside this file.
// NOTE(review): the per-field descriptions below are inferred from the field names only —
// confirm against the code that increments them.
#[derive(Default, Debug)]
pub struct Statistics {
/// Presumably the number of bytes received.
pub in_bytes: AtomicU64,
/// Presumably the number of bytes sent.
pub out_bytes: AtomicU64,
/// Presumably the number of messages received.
pub in_messages: AtomicU64,
/// Presumably the number of messages sent.
pub out_messages: AtomicU64,
/// Presumably the number of times a connection was established.
pub connects: AtomicU64,
}