use crate::error::AckError;
use crate::subscriber::lease_state::NACK_SHUTDOWN_ERROR;
use tokio::sync::mpsc::UnboundedSender;
use tokio::sync::oneshot::Receiver;
/// An action requested for a single message, sent by the handlers in this
/// module over their `ack_tx` channel.
///
/// Every variant carries the ack ID of the message it applies to.
// NOTE(review): the consumer of these actions lives outside this module
// (presumably the lease-management code in the parent module) — confirm.
#[derive(Debug, PartialEq)]
pub(super) enum Action {
    /// Ack requested through an `AtLeastOnce` handler.
    Ack(String),
    /// Nack requested through an `AtLeastOnce` handler, explicitly or on drop.
    Nack(String),
    /// Ack requested through an `ExactlyOnce` handler.
    ExactlyOnceAck(String),
    /// Nack requested through an `ExactlyOnce` handler, explicitly or on drop.
    ExactlyOnceNack(String),
}
/// A handler used to acknowledge (ack) or negatively acknowledge (nack) a
/// single received message.
///
/// Each variant wraps the handler type for one delivery mode. In both cases,
/// dropping the handler without calling [`ack`](Handler::ack) or
/// [`nack`](Handler::nack) nacks the message.
#[derive(Debug)]
#[non_exhaustive]
pub enum Handler {
    /// Handler for a message delivered with at-least-once semantics.
    AtLeastOnce(AtLeastOnce),
    /// Handler for a message delivered with exactly-once semantics.
    ExactlyOnce(ExactlyOnce),
}

impl Handler {
    /// Acknowledges the message, consuming the handler.
    ///
    /// The request is forwarded without waiting for any confirmation.
    pub fn ack(self) {
        match self {
            Self::AtLeastOnce(inner) => inner.ack(),
            Self::ExactlyOnce(inner) => inner.ack(),
        }
    }

    /// Negatively acknowledges the message, consuming the handler.
    ///
    /// The request is forwarded without waiting for any confirmation.
    pub fn nack(self) {
        match self {
            Self::AtLeastOnce(inner) => inner.nack(),
            Self::ExactlyOnce(inner) => inner.nack(),
        }
    }

    /// Returns the ack ID of the wrapped handler (test-only helper).
    #[cfg(test)]
    pub(crate) fn ack_id(&self) -> &str {
        match self {
            Self::AtLeastOnce(inner) => inner.ack_id(),
            Self::ExactlyOnce(inner) => inner.ack_id(),
        }
    }
}
/// The unsettled state of an at-least-once handler: the message's ack ID plus
/// the channel its action is reported on.
#[derive(Debug)]
struct AtLeastOnceImpl {
    ack_id: String,
    ack_tx: UnboundedSender<Action>,
}

impl AtLeastOnceImpl {
    /// Reports an [`Action::Ack`] for this message.
    fn ack(self) {
        self.send_with(Action::Ack);
    }

    /// Reports an [`Action::Nack`] for this message.
    fn nack(self) {
        self.send_with(Action::Nack);
    }

    /// Builds an action from the ack ID with `to_action` and sends it.
    ///
    /// The send result is deliberately discarded. A send only fails once the
    /// receiving side of `ack_tx` has been dropped, and reporting is best-effort.
    fn send_with(self, to_action: fn(String) -> Action) {
        let Self { ack_id, ack_tx } = self;
        let _ = ack_tx.send(to_action(ack_id));
    }
}
#[derive(Debug)]
pub struct AtLeastOnce {
inner: Option<AtLeastOnceImpl>,
}
impl AtLeastOnce {
pub(super) fn new(ack_id: String, ack_tx: UnboundedSender<Action>) -> Self {
Self {
inner: Some(AtLeastOnceImpl { ack_id, ack_tx }),
}
}
pub fn ack(mut self) {
if let Some(inner) = self.inner.take() {
inner.ack();
}
}
pub fn nack(mut self) {
if let Some(inner) = self.inner.take() {
inner.nack();
}
}
#[cfg(test)]
pub(crate) fn ack_id(&self) -> &str {
self.inner
.as_ref()
.map(|i| i.ack_id.as_str())
.unwrap_or_default()
}
}
impl Drop for AtLeastOnce {
fn drop(&mut self) {
if let Some(inner) = self.inner.take() {
inner.nack();
}
}
}
#[derive(Debug)]
pub struct ExactlyOnce {
inner: Option<ExactlyOnceImpl>,
}
impl ExactlyOnce {
pub(super) fn new(
ack_id: String,
ack_tx: UnboundedSender<Action>,
result_rx: Receiver<AckResult>,
) -> Self {
Self {
inner: Some(ExactlyOnceImpl {
ack_id,
ack_tx,
result_rx,
}),
}
}
pub(crate) fn ack(mut self) {
if let Some(inner) = self.inner.take() {
inner.ack();
}
}
pub(crate) fn nack(mut self) {
if let Some(inner) = self.inner.take() {
inner.nack();
}
}
pub async fn confirmed_ack(mut self) -> std::result::Result<(), AckError> {
let inner = self.inner.take().expect("handler impl is always some");
inner.confirmed_ack().await
}
pub async fn confirmed_nack(mut self) -> std::result::Result<(), AckError> {
let inner = self.inner.take().expect("handler impl is always some");
inner.confirmed_nack().await
}
#[cfg(test)]
pub(crate) fn ack_id(&self) -> &str {
self.inner
.as_ref()
.map(|i| i.ack_id.as_str())
.unwrap_or_default()
}
}
impl Drop for ExactlyOnce {
fn drop(&mut self) {
if let Some(inner) = self.inner.take() {
inner.nack();
}
}
}
/// The unsettled state of an exactly-once handler.
///
/// The type is private to this module, so its fields and methods are private
/// as well, matching `AtLeastOnceImpl`.
#[derive(Debug)]
struct ExactlyOnceImpl {
    /// Ack ID of the message this state belongs to.
    ack_id: String,
    /// Channel on which the ack/nack action is reported.
    ack_tx: UnboundedSender<Action>,
    /// Delivers the outcome of the action. Only the `confirmed_*` methods
    /// await it; the fire-and-forget methods simply drop it.
    result_rx: Receiver<AckResult>,
}

impl ExactlyOnceImpl {
    /// Reports an [`Action::ExactlyOnceAck`] without waiting for the outcome.
    ///
    /// A send failure (the receiving side of `ack_tx` is gone) is ignored.
    fn ack(self) {
        let _ = self.ack_tx.send(Action::ExactlyOnceAck(self.ack_id));
    }

    /// Reports an [`Action::ExactlyOnceNack`] without waiting for the outcome.
    ///
    /// A send failure (the receiving side of `ack_tx` is gone) is ignored.
    fn nack(self) {
        let _ = self.ack_tx.send(Action::ExactlyOnceNack(self.ack_id));
    }

    /// Reports an [`Action::ExactlyOnceAck`] and waits for its outcome.
    ///
    /// # Errors
    ///
    /// - [`AckError::ShutdownBeforeAck`] if the action cannot be sent.
    /// - [`AckError::Shutdown`] if the result channel closes before an outcome
    ///   is reported.
    /// - The reported error, if the outcome is an error.
    async fn confirmed_ack(self) -> AckResult {
        self.ack_tx
            .send(Action::ExactlyOnceAck(self.ack_id))
            .map_err(|_| AckError::ShutdownBeforeAck)?;
        Self::await_result(self.result_rx).await
    }

    /// Reports an [`Action::ExactlyOnceNack`] and waits for its outcome.
    ///
    /// # Errors
    ///
    /// - [`AckError::Shutdown`] with `NACK_SHUTDOWN_ERROR` as its source if the
    ///   action cannot be sent.
    /// - [`AckError::Shutdown`] if the result channel closes before an outcome
    ///   is reported.
    /// - The reported error, if the outcome is an error.
    async fn confirmed_nack(self) -> AckResult {
        self.ack_tx
            .send(Action::ExactlyOnceNack(self.ack_id))
            .map_err(|_| AckError::Shutdown(NACK_SHUTDOWN_ERROR.into()))?;
        Self::await_result(self.result_rx).await
    }

    /// Waits for the outcome of an action that was already sent.
    ///
    /// If the result sender is dropped without reporting an outcome, the
    /// `RecvError` is wrapped in [`AckError::Shutdown`].
    async fn await_result(result_rx: Receiver<AckResult>) -> AckResult {
        result_rx.await.map_err(|e| AckError::Shutdown(e.into()))?
    }
}
/// Outcome of an exactly-once ack or nack.
///
/// It is delivered over the result channel given to `ExactlyOnce::new` and
/// returned by the `confirmed_*` methods.
pub(super) type AckResult = std::result::Result<(), AckError>;
#[cfg(test)]
mod tests {
    use std::error::Error;

    use super::super::lease_state::tests::test_id;
    use super::*;
    use tokio::sync::mpsc::error::TryRecvError;
    use tokio::sync::mpsc::unbounded_channel;
    use tokio::sync::oneshot::channel;

    // The synchronous tests also check that, once the handler is consumed,
    // exactly one action was sent and the sender was released
    // (`TryRecvError::Disconnected`). In other words, `Drop` did not report a
    // second action.

    #[test]
    fn handler_at_least_once_ack() -> anyhow::Result<()> {
        let (ack_tx, mut ack_rx) = unbounded_channel();
        let h = Handler::AtLeastOnce(AtLeastOnce::new(test_id(1), ack_tx));
        assert_eq!(ack_rx.try_recv(), Err(TryRecvError::Empty));
        h.ack();
        let ack = ack_rx.try_recv()?;
        assert_eq!(ack, Action::Ack(test_id(1)));
        assert_eq!(ack_rx.try_recv(), Err(TryRecvError::Disconnected));
        Ok(())
    }

    #[test]
    fn handler_at_least_once_nack() -> anyhow::Result<()> {
        let (ack_tx, mut ack_rx) = unbounded_channel();
        let h = Handler::AtLeastOnce(AtLeastOnce::new(test_id(1), ack_tx));
        assert_eq!(ack_rx.try_recv(), Err(TryRecvError::Empty));
        h.nack();
        let nack = ack_rx.try_recv()?;
        assert_eq!(nack, Action::Nack(test_id(1)));
        assert_eq!(ack_rx.try_recv(), Err(TryRecvError::Disconnected));
        Ok(())
    }

    #[test]
    fn handler_exactly_once_ack() -> anyhow::Result<()> {
        let (ack_tx, mut ack_rx) = unbounded_channel();
        let (_result_tx, result_rx) = channel();
        let h = Handler::ExactlyOnce(ExactlyOnce::new(test_id(1), ack_tx, result_rx));
        assert_eq!(ack_rx.try_recv(), Err(TryRecvError::Empty));
        h.ack();
        let ack = ack_rx.try_recv()?;
        assert_eq!(ack, Action::ExactlyOnceAck(test_id(1)));
        assert_eq!(ack_rx.try_recv(), Err(TryRecvError::Disconnected));
        Ok(())
    }

    #[test]
    fn handler_exactly_once_nack() -> anyhow::Result<()> {
        let (ack_tx, mut ack_rx) = unbounded_channel();
        let (_result_tx, result_rx) = channel();
        let h = Handler::ExactlyOnce(ExactlyOnce::new(test_id(1), ack_tx, result_rx));
        assert_eq!(ack_rx.try_recv(), Err(TryRecvError::Empty));
        h.nack();
        let nack = ack_rx.try_recv()?;
        assert_eq!(nack, Action::ExactlyOnceNack(test_id(1)));
        assert_eq!(ack_rx.try_recv(), Err(TryRecvError::Disconnected));
        Ok(())
    }

    #[test]
    fn at_least_once_ack() -> anyhow::Result<()> {
        let (ack_tx, mut ack_rx) = unbounded_channel();
        let h = AtLeastOnce::new(test_id(1), ack_tx);
        assert_eq!(ack_rx.try_recv(), Err(TryRecvError::Empty));
        h.ack();
        let ack = ack_rx.try_recv()?;
        assert_eq!(ack, Action::Ack(test_id(1)));
        assert_eq!(ack_rx.try_recv(), Err(TryRecvError::Disconnected));
        Ok(())
    }

    #[test]
    fn at_least_once_nack() -> anyhow::Result<()> {
        let (ack_tx, mut ack_rx) = unbounded_channel();
        let h = AtLeastOnce::new(test_id(1), ack_tx);
        assert_eq!(ack_rx.try_recv(), Err(TryRecvError::Empty));
        h.nack();
        let nack = ack_rx.try_recv()?;
        assert_eq!(nack, Action::Nack(test_id(1)));
        assert_eq!(ack_rx.try_recv(), Err(TryRecvError::Disconnected));
        Ok(())
    }

    #[test]
    fn exactly_once_ack() -> anyhow::Result<()> {
        let (ack_tx, mut ack_rx) = unbounded_channel();
        let (_result_tx, result_rx) = channel();
        let h = ExactlyOnce::new(test_id(1), ack_tx, result_rx);
        assert_eq!(ack_rx.try_recv(), Err(TryRecvError::Empty));
        h.ack();
        let ack = ack_rx.try_recv()?;
        assert_eq!(ack, Action::ExactlyOnceAck(test_id(1)));
        assert_eq!(ack_rx.try_recv(), Err(TryRecvError::Disconnected));
        Ok(())
    }

    #[test]
    fn exactly_once_nack() -> anyhow::Result<()> {
        let (ack_tx, mut ack_rx) = unbounded_channel();
        let (_result_tx, result_rx) = channel();
        let h = ExactlyOnce::new(test_id(1), ack_tx, result_rx);
        assert_eq!(ack_rx.try_recv(), Err(TryRecvError::Empty));
        h.nack();
        let nack = ack_rx.try_recv()?;
        assert_eq!(nack, Action::ExactlyOnceNack(test_id(1)));
        assert_eq!(ack_rx.try_recv(), Err(TryRecvError::Disconnected));
        Ok(())
    }

    #[tokio::test]
    async fn exactly_once_success() -> anyhow::Result<()> {
        let (ack_tx, mut ack_rx) = unbounded_channel();
        let (result_tx, result_rx) = channel();
        let h = ExactlyOnce::new(test_id(1), ack_tx, result_rx);
        assert_eq!(ack_rx.try_recv(), Err(TryRecvError::Empty));
        let task = tokio::task::spawn(async move { h.confirmed_ack().await });
        let ack = ack_rx.recv().await.expect("ack should be sent");
        assert_eq!(ack, Action::ExactlyOnceAck(test_id(1)));
        result_tx
            .send(Ok(()))
            .expect("sending on a channel succeeds");
        task.await??;
        Ok(())
    }

    #[tokio::test]
    async fn exactly_once_nack_success() -> anyhow::Result<()> {
        let (ack_tx, mut ack_rx) = unbounded_channel();
        let (result_tx, result_rx) = channel();
        let h = ExactlyOnce::new(test_id(1), ack_tx, result_rx);
        assert_eq!(ack_rx.try_recv(), Err(TryRecvError::Empty));
        let task = tokio::task::spawn(async move { h.confirmed_nack().await });
        let nack = ack_rx.recv().await.expect("nack should be sent");
        assert_eq!(nack, Action::ExactlyOnceNack(test_id(1)));
        result_tx
            .send(Ok(()))
            .expect("sending on a channel succeeds");
        task.await??;
        Ok(())
    }

    #[tokio::test]
    async fn exactly_once_error() -> anyhow::Result<()> {
        let (ack_tx, mut ack_rx) = unbounded_channel();
        let (result_tx, result_rx) = channel();
        let h = ExactlyOnce::new(test_id(1), ack_tx, result_rx);
        assert_eq!(ack_rx.try_recv(), Err(TryRecvError::Empty));
        let task = tokio::task::spawn(async move { h.confirmed_ack().await });
        let ack = ack_rx.recv().await.expect("ack should be sent");
        assert_eq!(ack, Action::ExactlyOnceAck(test_id(1)));
        result_tx
            .send(Err(AckError::LeaseExpired))
            .expect("sending on a channel succeeds");
        let err = task.await?.expect_err("ack should fail");
        assert!(matches!(err, AckError::LeaseExpired), "{err:?}");
        Ok(())
    }

    #[tokio::test]
    async fn exactly_once_nack_error() -> anyhow::Result<()> {
        let (ack_tx, mut ack_rx) = unbounded_channel();
        let (result_tx, result_rx) = channel();
        let h = ExactlyOnce::new(test_id(1), ack_tx, result_rx);
        assert_eq!(ack_rx.try_recv(), Err(TryRecvError::Empty));
        let task = tokio::task::spawn(async move { h.confirmed_nack().await });
        let nack = ack_rx.recv().await.expect("nack should be sent");
        assert_eq!(nack, Action::ExactlyOnceNack(test_id(1)));
        result_tx
            .send(Err(AckError::LeaseExpired))
            .expect("sending on a channel succeeds");
        let err = task.await?.expect_err("nack should fail");
        assert!(matches!(err, AckError::LeaseExpired), "{err:?}");
        Ok(())
    }

    #[tokio::test]
    async fn exactly_once_action_channel_closed() -> anyhow::Result<()> {
        let (ack_tx, mut ack_rx) = unbounded_channel();
        let (_result_tx, result_rx) = channel();
        let h = ExactlyOnce::new(test_id(1), ack_tx, result_rx);
        assert_eq!(ack_rx.try_recv(), Err(TryRecvError::Empty));
        drop(ack_rx);
        let err = h.confirmed_ack().await.expect_err("ack should fail");
        assert!(matches!(err, AckError::ShutdownBeforeAck), "{err:?}");
        Ok(())
    }

    #[tokio::test]
    async fn exactly_once_nack_action_channel_closed() -> anyhow::Result<()> {
        let (ack_tx, mut ack_rx) = unbounded_channel();
        let (_result_tx, result_rx) = channel();
        let h = ExactlyOnce::new(test_id(1), ack_tx, result_rx);
        assert_eq!(ack_rx.try_recv(), Err(TryRecvError::Empty));
        drop(ack_rx);
        let err = h.confirmed_nack().await.expect_err("nack should fail");
        assert!(matches!(err, AckError::Shutdown(_)), "{err:?}");
        assert_eq!(
            err.source()
                .expect("shutdown errors have a source")
                .to_string(),
            NACK_SHUTDOWN_ERROR
        );
        Ok(())
    }

    #[tokio::test]
    async fn exactly_once_result_channel_closed() -> anyhow::Result<()> {
        let (ack_tx, mut ack_rx) = unbounded_channel();
        let (result_tx, result_rx) = channel();
        let h = ExactlyOnce::new(test_id(1), ack_tx, result_rx);
        assert_eq!(ack_rx.try_recv(), Err(TryRecvError::Empty));
        let task = tokio::task::spawn(async move { h.confirmed_ack().await });
        let ack = ack_rx.recv().await.expect("ack should be sent");
        assert_eq!(ack, Action::ExactlyOnceAck(test_id(1)));
        drop(result_tx);
        let err = task.await?.expect_err("ack should fail");
        assert!(matches!(err, AckError::Shutdown(_)), "{err:?}");
        Ok(())
    }

    #[tokio::test]
    async fn exactly_once_nack_result_channel_closed() -> anyhow::Result<()> {
        let (ack_tx, mut ack_rx) = unbounded_channel();
        let (result_tx, result_rx) = channel();
        let h = ExactlyOnce::new(test_id(1), ack_tx, result_rx);
        assert_eq!(ack_rx.try_recv(), Err(TryRecvError::Empty));
        let task = tokio::task::spawn(async move { h.confirmed_nack().await });
        let nack = ack_rx.recv().await.expect("nack should be sent");
        assert_eq!(nack, Action::ExactlyOnceNack(test_id(1)));
        drop(result_tx);
        let err = task.await?.expect_err("nack should fail");
        assert!(matches!(err, AckError::Shutdown(_)), "{err:?}");
        Ok(())
    }

    #[test]
    fn handler_at_least_once_nack_on_drop() -> anyhow::Result<()> {
        let (ack_tx, mut ack_rx) = unbounded_channel();
        let h = Handler::AtLeastOnce(AtLeastOnce::new(test_id(1), ack_tx));
        assert_eq!(ack_rx.try_recv(), Err(TryRecvError::Empty));
        drop(h);
        let nack = ack_rx.try_recv()?;
        assert_eq!(nack, Action::Nack(test_id(1)));
        assert_eq!(ack_rx.try_recv(), Err(TryRecvError::Disconnected));
        Ok(())
    }

    #[test]
    fn handler_exactly_once_nack_on_drop() -> anyhow::Result<()> {
        let (ack_tx, mut ack_rx) = unbounded_channel();
        let (_result_tx, result_rx) = channel();
        let h = Handler::ExactlyOnce(ExactlyOnce::new(test_id(1), ack_tx, result_rx));
        assert_eq!(ack_rx.try_recv(), Err(TryRecvError::Empty));
        drop(h);
        let nack = ack_rx.try_recv()?;
        assert_eq!(nack, Action::ExactlyOnceNack(test_id(1)));
        assert_eq!(ack_rx.try_recv(), Err(TryRecvError::Disconnected));
        Ok(())
    }

    #[test]
    fn at_least_once_nack_on_drop() -> anyhow::Result<()> {
        let (ack_tx, mut ack_rx) = unbounded_channel();
        let h = AtLeastOnce::new(test_id(1), ack_tx);
        assert_eq!(ack_rx.try_recv(), Err(TryRecvError::Empty));
        drop(h);
        let nack = ack_rx.try_recv()?;
        assert_eq!(nack, Action::Nack(test_id(1)));
        assert_eq!(ack_rx.try_recv(), Err(TryRecvError::Disconnected));
        Ok(())
    }

    #[test]
    fn exactly_once_nack_on_drop() -> anyhow::Result<()> {
        let (ack_tx, mut ack_rx) = unbounded_channel();
        let (_result_tx, result_rx) = channel();
        let h = ExactlyOnce::new(test_id(1), ack_tx, result_rx);
        assert_eq!(ack_rx.try_recv(), Err(TryRecvError::Empty));
        drop(h);
        let nack = ack_rx.try_recv()?;
        assert_eq!(nack, Action::ExactlyOnceNack(test_id(1)));
        assert_eq!(ack_rx.try_recv(), Err(TryRecvError::Disconnected));
        Ok(())
    }
}