pub(crate) mod driver;
pub(crate) use driver::{spawn_consumer_async, spawn_consumer_sync};
use crate::StreamReadError;
use std::time::Duration;
use thiserror::Error;
use tokio::sync::oneshot::Sender;
use tokio::task::JoinHandle;
use tokio::time::{Instant, sleep_until};
/// Errors produced while waiting for, cancelling or joining a consumer task.
// NOTE(review): both variants render their `source` inside the Display message *and* expose it
// via `Error::source`, so report formatters that walk the source chain will print the inner
// message twice — confirm this is intended.
#[derive(Debug, Error)]
#[non_exhaustive]
pub enum ConsumerError {
/// Awaiting the consumer task yielded a `tokio::task::JoinError` (tokio reports one when the
/// task panicked or was cancelled, e.g. via `JoinHandle::abort`).
#[error("Failed to join/terminate the consumer task over stream '{stream_name}': {source}")]
TaskJoin {
/// Name of the stream the consumer task was attached to; included in the message.
stream_name: &'static str,
/// The underlying join failure, also returned by `Error::source`.
#[source]
source: tokio::task::JoinError,
},
/// The consumer task completed but returned a read error instead of its sink.
///
/// `Display` renders the wrapped error's message verbatim.
#[error("{source}")]
StreamRead {
/// The error returned by the consumer task, also returned by `Error::source`.
#[source]
source: StreamReadError,
},
}
/// Marker for the value a consumer task hands back when it finishes.
///
/// Blanket-implemented for every `Send + 'static` type, so it never has to be implemented by
/// hand; it exists to give the bound used throughout this module a single name.
pub trait Sink
where
    Self: Send + 'static,
{
}

impl<T: Send + 'static> Sink for T {}
/// Outcome of [`Consumer::cancel`].
#[derive(Debug)]
pub enum ConsumerCancelOutcome<S>
where
    S: Sink,
{
    /// The task finished before the deadline and handed back its sink.
    Cancelled(S),
    /// The deadline passed first; the task was aborted and any sink it held was discarded.
    Aborted,
}
impl<S: Sink> ConsumerCancelOutcome<S> {
    /// Returns the sink if the consumer stopped within its deadline, or `None` if it was
    /// aborted.
    #[must_use]
    pub fn into_cancelled(self) -> Option<S> {
        match self {
            Self::Cancelled(sink) => Some(sink),
            Self::Aborted => None,
        }
    }

    /// Returns the sink, panicking with `message` if the consumer was aborted.
    ///
    /// # Panics
    ///
    /// Panics with `message` when `self` is [`ConsumerCancelOutcome::Aborted`]. Thanks to
    /// `#[track_caller]`, the reported panic location is the caller's line, not this method's.
    #[track_caller]
    pub fn expect_cancelled(self, message: &str) -> S {
        self.into_cancelled().expect(message)
    }
}
/// Handle to a spawned consumer task that eventually yields a sink of type `S`.
///
/// End it with [`Consumer::wait`], [`Consumer::cancel`] or [`Consumer::abort`]. Simply
/// dropping the handle sends the termination signal and aborts the task without waiting for it.
pub struct Consumer<S>
where
    S: Sink,
{
    /// Name of the stream this consumer is attached to; reported in `ConsumerError::TaskJoin`.
    pub(crate) stream_name: &'static str,
    /// Join handle of the consumer task; `None` once moved out by `into_wait`.
    pub(crate) task: Option<JoinHandle<Result<S, StreamReadError>>>,
    /// One-shot sender used to ask the task to stop cooperatively; `None` once moved out.
    pub(crate) task_termination_sender: Option<Sender<()>>,
}
/// Owned, `&mut`-driven form of a [`Consumer`] used by the wait / cancel / abort paths.
///
/// Cleanup when dropped is delegated to the inner guard.
pub(crate) struct ConsumerWait<S>
where
    S: Sink,
{
    /// Copied from the originating [`Consumer`] for error reporting.
    stream_name: &'static str,
    /// Owns the task handle and termination sender.
    guard: ConsumerWaitGuard<S>,
}
/// Owner of a consumer task's join handle and termination sender.
///
/// Its `Drop` impl sends the termination signal and aborts the task for whichever of the two
/// is still present, so dropping a guard mid-wait still aborts the task.
struct ConsumerWaitGuard<S>
where
    S: Sink,
{
    /// Join handle of the task; cleared by the wait/abort paths once no longer needed.
    task: Option<JoinHandle<Result<S, StreamReadError>>>,
    /// Cooperative stop signal; cleared once sent or no longer needed.
    task_termination_sender: Option<Sender<()>>,
}
impl<S: Sink> ConsumerWaitGuard<S> {
    /// Fires the cooperative termination signal if it has not been consumed yet.
    ///
    /// Delivery is best-effort. A send error (the receiver is already gone) is ignored.
    /// Calling this after the sender was consumed by an earlier `cancel`, `wait` or `abort`
    /// is a no-op, matching how `abort` and `Drop` treat the sender.
    fn cancel(&mut self) {
        if let Some(task_termination_sender) = self.task_termination_sender.take() {
            let _res = task_termination_sender.send(());
        }
    }

    /// Awaits the consumer task and returns its sink.
    ///
    /// The handle is polled in place rather than taken. If this future is dropped mid-await
    /// (e.g. by the timeout branch of `ConsumerWait::wait_until`), the handle therefore stays
    /// owned by the guard and can still be aborted.
    ///
    /// # Errors
    ///
    /// - [`ConsumerError::TaskJoin`] if joining the task failed.
    /// - [`ConsumerError::StreamRead`] if the task returned a read error.
    ///
    /// # Panics
    ///
    /// Panics if no task handle is present (e.g. it was already joined).
    async fn wait(&mut self, stream_name: &'static str) -> Result<S, ConsumerError> {
        let joined = self
            .task
            .as_mut()
            .expect("`task` to be present.")
            .await;
        // The handle has produced its output and must not be polled again (a completed
        // future must not be re-polled; tokio panics). Release it on *every* path, including
        // the join-error path, so a later `abort` or `Drop` never touches a spent handle.
        self.task = None;
        self.task_termination_sender = None;
        joined
            .map_err(|source| ConsumerError::TaskJoin {
                stream_name,
                source,
            })?
            .map_err(|source| ConsumerError::StreamRead { source })
    }

    /// Signals termination, aborts the task and waits for it to stop, discarding its output.
    ///
    /// Each step is skipped when its resource has already been released, so this is safe to
    /// call after `wait` has completed.
    async fn abort(&mut self) {
        if let Some(task_termination_sender) = self.task_termination_sender.take() {
            let _res = task_termination_sender.send(());
        }
        if let Some(task) = self.task.as_mut() {
            task.abort();
            // Join the aborted task so it has finished (or been cancelled) before returning.
            // The join result (a cancellation error or a late sink) is intentionally dropped.
            let _res = task.await;
        }
        self.task = None;
    }
}
impl<S: Sink> Drop for ConsumerWaitGuard<S> {
    /// Best-effort teardown: first the cooperative stop signal, then a hard abort as a backstop.
    ///
    /// Anything already released by the wait/abort paths is simply skipped. The abort is
    /// fire-and-forget because `drop` cannot await the task.
    fn drop(&mut self) {
        let (pending_signal, pending_task) =
            (self.task_termination_sender.take(), self.task.take());

        if let Some(signal) = pending_signal {
            let _res = signal.send(());
        }

        if let Some(handle) = pending_task {
            handle.abort();
        }
    }
}
impl<S: Sink> Consumer<S> {
    /// Moves the task handle and termination sender out of `self` into a [`ConsumerWait`].
    ///
    /// Afterwards `self` holds `None` in both fields, so its own `Drop` does nothing and cleanup
    /// responsibility passes to the returned value.
    pub(crate) fn into_wait(mut self) -> ConsumerWait<S> {
        ConsumerWait {
            stream_name: self.stream_name,
            guard: ConsumerWaitGuard {
                task: self.task.take(),
                task_termination_sender: self.task_termination_sender.take(),
            },
        }
    }

    /// Returns `true` if the consumer task has finished, or if no task handle is held.
    #[must_use]
    pub fn is_finished(&self) -> bool {
        self.task.as_ref().is_none_or(JoinHandle::is_finished)
    }

    /// Waits for the consumer task to finish and returns its sink.
    ///
    /// # Errors
    ///
    /// - [`ConsumerError::TaskJoin`] if the task could not be joined (tokio reports a
    ///   `JoinError` when the task panicked or was cancelled).
    /// - [`ConsumerError::StreamRead`] if the task itself returned a read error.
    ///
    /// # Panics
    ///
    /// Panics if this consumer holds no task handle.
    pub async fn wait(self) -> Result<S, ConsumerError> {
        self.into_wait().wait().await
    }

    /// Sends the termination signal, aborts the task and waits for it to stop.
    ///
    /// Any sink the task may have produced is discarded.
    pub async fn abort(self) {
        self.into_wait().abort().await;
    }

    /// Asks the task to stop cooperatively, waiting at most `timeout` before aborting it.
    ///
    /// Returns [`ConsumerCancelOutcome::Cancelled`] with the sink if the task finished in time.
    /// Otherwise returns [`ConsumerCancelOutcome::Aborted`]. A `timeout` too large to be added
    /// to the current instant (e.g. `Duration::MAX`) is treated as "no deadline".
    ///
    /// # Errors
    ///
    /// Same as [`Consumer::wait`], when the task finishes before the deadline with an error.
    pub async fn cancel(
        self,
        timeout: Duration,
    ) -> Result<ConsumerCancelOutcome<S>, ConsumerError> {
        let mut wait = self.into_wait();
        wait.cancel();
        let finished = match Instant::now().checked_add(timeout) {
            Some(deadline) => wait.wait_until(deadline).await?,
            // `Instant + Duration` panics on overflow. A timeout that large can never elapse,
            // so just wait for the task to stop on its own.
            None => Some(wait.wait().await?),
        };
        Ok(match finished {
            Some(sink) => ConsumerCancelOutcome::Cancelled(sink),
            None => ConsumerCancelOutcome::Aborted,
        })
    }
}
impl<S: Sink> ConsumerWait<S> {
    /// Asks the consumer task to stop by firing its termination signal.
    pub(crate) fn cancel(&mut self) {
        self.guard.cancel();
    }

    /// Waits for the consumer task to finish and returns its sink.
    ///
    /// # Errors
    ///
    /// Returns [`ConsumerError::TaskJoin`] or [`ConsumerError::StreamRead`], tagged with this
    /// consumer's stream name.
    pub(crate) async fn wait(&mut self) -> Result<S, ConsumerError> {
        self.guard.wait(self.stream_name).await
    }

    /// Waits until `deadline` for the task to finish.
    ///
    /// Returns `Ok(Some(sink))` if the task finished in time. Once the deadline passes, the
    /// task is aborted and `Ok(None)` is returned.
    ///
    /// # Errors
    ///
    /// Same as [`ConsumerWait::wait`], when the task finishes before the deadline.
    pub(crate) async fn wait_until(
        &mut self,
        deadline: Instant,
    ) -> Result<Option<S>, ConsumerError> {
        let timeout = sleep_until(deadline);
        tokio::pin!(timeout);
        tokio::select! {
            // Poll the task before the timer. Otherwise `select!` picks a ready branch at
            // random, and a task that already finished could be reported as timed out (its
            // sink then discarded) whenever the deadline elapsed in the same poll.
            biased;
            result = self.wait() => result.map(Some),
            () = &mut timeout => {
                self.abort().await;
                Ok(None)
            }
        }
    }

    /// Signals termination, aborts the task and waits for it to stop, discarding its output.
    pub(crate) async fn abort(&mut self) {
        self.guard.abort().await;
    }
}
impl<S: Sink> Drop for Consumer<S> {
fn drop(&mut self) {
if let Some(task_termination_sender) = self.task_termination_sender.take() {
let _res = task_termination_sender.send(());
}
if let Some(task) = self.task.take() {
task.abort();
}
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use assertr::prelude::*;
    use std::io;
    use tokio::sync::oneshot;

    #[test]
    fn stream_read_display_uses_source_context() {
        let source = StreamReadError::new("stdout", io::Error::from(io::ErrorKind::BrokenPipe));
        let expected = source.to_string();
        let err = ConsumerError::StreamRead { source };
        assert_that!(err.to_string()).is_equal_to(expected);
    }

    #[tokio::test]
    async fn cancel_returns_cancelled_when_cooperative() {
        let (task_termination_sender, task_termination_receiver) = oneshot::channel();
        let consumer = Consumer {
            stream_name: "custom",
            task: Some(tokio::spawn(async move {
                let _res = task_termination_receiver.await;
                Ok(Vec::<u8>::new())
            })),
            task_termination_sender: Some(task_termination_sender),
        };
        let outcome = consumer.cancel(Duration::from_secs(1)).await.unwrap();
        match outcome {
            ConsumerCancelOutcome::Cancelled(bytes) => {
                assert_that!(bytes).is_empty();
            }
            ConsumerCancelOutcome::Aborted => {
                assert_that!(()).fail("expected cooperative cancellation");
            }
        }
    }

    #[tokio::test]
    async fn cancel_aborts_when_task_ignores_termination() {
        // Keep the receiver alive but never read it, so the task can only stop via abort.
        let (task_termination_sender, _task_termination_receiver) = oneshot::channel();
        let consumer = Consumer {
            stream_name: "custom",
            task: Some(tokio::spawn(async {
                std::future::pending::<()>().await;
                Ok::<_, StreamReadError>(Vec::<u8>::new())
            })),
            task_termination_sender: Some(task_termination_sender),
        };
        let outcome = consumer.cancel(Duration::from_millis(10)).await.unwrap();
        assert!(matches!(outcome, ConsumerCancelOutcome::Aborted));
    }

    #[tokio::test]
    async fn wait_reports_task_join_error_with_stream_name() {
        let (task_termination_sender, _task_termination_receiver) = oneshot::channel();
        let task = tokio::spawn(async {
            std::future::pending::<()>().await;
            Ok::<_, StreamReadError>(Vec::<u8>::new())
        });
        // An aborted task makes the join fail with a cancellation `JoinError`.
        task.abort();
        let consumer = Consumer {
            stream_name: "custom",
            task: Some(task),
            task_termination_sender: Some(task_termination_sender),
        };
        let err = consumer.wait().await.unwrap_err();
        assert!(err.to_string().contains("'custom'"));
        assert!(matches!(
            err,
            ConsumerError::TaskJoin {
                stream_name: "custom",
                ..
            }
        ));
    }
}