use std::{
future::Future,
marker::PhantomData,
pin::Pin,
sync::{
atomic::{AtomicBool, Ordering},
Arc,
},
task::{Context, Poll},
};
use tonic::Streaming;
use tracing::{debug, warn};
/// A client that can ask the server to abort an in-flight request.
///
/// `AbortOnDropStream` clones its client and calls [`abort_for_drop`] from a
/// spawned task when the stream is dropped before being marked completed.
///
/// [`abort_for_drop`]: AbortOnDropClient::abort_for_drop
pub trait AbortOnDropClient
where
    Self: Clone + Send + Sync + 'static,
{
    /// Sends an abort for `request_id`, consuming this (cloned) client.
    ///
    /// The returned future must be `Send` because it is driven inside a
    /// spawned tokio task; any error it yields is only logged by the caller.
    fn abort_for_drop(
        self,
        request_id: String,
    ) -> Pin<Box<dyn Future<Output = Result<(), tonic::Status>> + Send + 'static>>;
}
/// A `tonic::Streaming<T>` wrapper that sends an abort for `request_id`
/// when it is dropped without having been marked completed.
///
/// Call `mark_completed` once the request has finished normally; otherwise
/// the `Drop` impl spawns a task that calls
/// [`AbortOnDropClient::abort_for_drop`].
pub struct AbortOnDropStream<T, C: AbortOnDropClient> {
    /// The wrapped response stream; polling is forwarded to it unchanged.
    inner: Streaming<T>,
    /// Identifier passed to `abort_for_drop` and used in log messages.
    request_id: String,
    /// Client that is cloned on drop to send the abort.
    client: C,
    /// Set to `true` by `mark_completed` and by `Drop`; once `true`, `Drop`
    /// sends nothing.
    /// NOTE(review): despite the name, this also records normal completion —
    /// it effectively means "no abort needed any more".
    aborted: Arc<AtomicBool>,
    /// Type marker for `T`.
    /// NOTE(review): `inner` already mentions `T`, so this looks redundant —
    /// confirm before removing.
    _marker: PhantomData<fn() -> T>,
}
impl<T, C: AbortOnDropClient> AbortOnDropStream<T, C> {
    /// Wraps `stream` so that dropping it before [`mark_completed`] has been
    /// called sends an abort for `request_id` through `client`.
    ///
    /// [`mark_completed`]: AbortOnDropStream::mark_completed
    pub fn new(stream: Streaming<T>, request_id: String, client: C) -> Self {
        let wrapper = Self {
            inner: stream,
            request_id,
            client,
            aborted: Arc::new(AtomicBool::new(false)),
            _marker: PhantomData,
        };
        debug!("Created AbortOnDropStream for request {}", wrapper.request_id);
        wrapper
    }

    /// Records that the request finished normally, so a later drop of this
    /// stream does not send an abort.
    pub fn mark_completed(&self) {
        // Release store; the check performed in `Drop` reads it with acquire
        // semantics.
        let done_flag = &self.aborted;
        done_flag.store(true, Ordering::Release);
        debug!("Request {} marked as completed", self.request_id);
    }
}
impl<T, C: AbortOnDropClient> Drop for AbortOnDropStream<T, C> {
    /// Sends a best-effort abort unless the request was already marked
    /// completed or an abort was already issued.
    ///
    /// `drop` cannot await, so the abort RPC runs in a detached tokio task.
    /// If no tokio runtime is available on the current thread, the abort is
    /// skipped with a warning. `tokio::spawn` would panic in that situation,
    /// and a panic inside `drop` during unwinding aborts the whole process.
    fn drop(&mut self) {
        // `swap` returns the previous value. `true` means completion or an
        // earlier abort has already been recorded, so there is nothing to do.
        if self.aborted.swap(true, Ordering::AcqRel) {
            return;
        }

        // `tokio::spawn` panics outside a runtime context; check first.
        if tokio::runtime::Handle::try_current().is_err() {
            warn!(
                "No tokio runtime available; cannot send abort on drop for request {}",
                self.request_id
            );
            return;
        }

        // The wrapper is being destroyed, so its request id can be moved out
        // rather than cloned a second time.
        let request_id = std::mem::take(&mut self.request_id);
        let request_id_for_log = request_id.clone();
        let client = self.client.clone();
        #[expect(
            clippy::disallowed_methods,
            reason = "fire-and-forget abort on Drop is intentional"
        )]
        tokio::spawn(async move {
            debug!(
                "Stream dropped without completion for request {}, sending abort",
                request_id_for_log
            );
            if let Err(e) = client.abort_for_drop(request_id).await {
                warn!(
                    "Failed to send abort on drop for request {}: {}",
                    request_id_for_log, e
                );
            }
        });
    }
}
// Without this impl, `AbortOnDropStream: Unpin` would require `C: Unpin`, and
// `poll_next` (which reaches `inner` through `Pin<&mut Self>`) would only
// compile for `Unpin` clients. The impl is correct because no field is
// structurally pinned: `poll_next` re-pins `inner` with `Pin::new`, which
// itself requires `Streaming<T>: Unpin`.
impl<T, C> Unpin for AbortOnDropStream<T, C> where C: AbortOnDropClient {}
impl<T, C: AbortOnDropClient> futures::Stream for AbortOnDropStream<T, C> {
    type Item = Result<T, tonic::Status>;

    /// Forwards polling to the wrapped `Streaming` unchanged.
    ///
    /// The call is fully qualified because method-call syntax on a trait
    /// method only resolves when the trait is imported with `use`, and this
    /// file names `futures::Stream` only by path.
    ///
    /// Reaching the end of the stream does not mark the request completed.
    /// Only `mark_completed` suppresses the abort on drop.
    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        // Reaching `inner` through `Pin<&mut Self>` relies on `Self: Unpin`
        // (the manual `Unpin` impl). `Pin::new` relies on
        // `Streaming<T>: Unpin`.
        futures::Stream::poll_next(Pin::new(&mut self.inner), cx)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// The wrapper is `Send + Sync` for a trivially implementable client.
    #[test]
    fn trait_is_implementable_by_simple_client() {
        #[derive(Clone)]
        struct DummyClient;
        impl AbortOnDropClient for DummyClient {
            fn abort_for_drop(
                self,
                _request_id: String,
            ) -> Pin<Box<dyn Future<Output = Result<(), tonic::Status>> + Send>> {
                Box::pin(async { Ok(()) })
            }
        }
        fn assert_send_sync<X: Send + Sync>() {}
        assert_send_sync::<AbortOnDropStream<(), DummyClient>>();
    }

    /// The manual `Unpin` impl must make the wrapper `Unpin` even when the
    /// client type itself is `!Unpin`.
    #[test]
    fn stream_is_unpin_even_for_non_unpin_client() {
        #[derive(Clone)]
        struct PinnedClient {
            _pinned: std::marker::PhantomPinned,
        }
        impl AbortOnDropClient for PinnedClient {
            fn abort_for_drop(
                self,
                _request_id: String,
            ) -> Pin<Box<dyn Future<Output = Result<(), tonic::Status>> + Send>> {
                Box::pin(async { Ok(()) })
            }
        }
        fn assert_unpin<X: Unpin>() {}
        assert_unpin::<AbortOnDropStream<(), PinnedClient>>();
    }

    /// The wrapper yields the same item type as the underlying `Streaming`.
    #[test]
    fn implements_stream_of_status_results() {
        #[derive(Clone)]
        struct DummyClient;
        impl AbortOnDropClient for DummyClient {
            fn abort_for_drop(
                self,
                _request_id: String,
            ) -> Pin<Box<dyn Future<Output = Result<(), tonic::Status>> + Send>> {
                Box::pin(async { Ok(()) })
            }
        }
        fn assert_stream<S: futures::Stream<Item = Result<(), tonic::Status>>>() {}
        assert_stream::<AbortOnDropStream<(), DummyClient>>();
    }
}