use std::future::Future;
use std::pin::Pin;
use std::task::{Context, Poll};
use std::time::Duration;
use bytes::Bytes;
use futures::Stream;
use pin_project::pin_project;
use tokio::time::Sleep;
use crate::error::ConnectError;
use crate::handler::BoxStream;
/// Server-side rules for moderating request timeouts and bounding response
/// streams.
///
/// The default value imposes nothing: `min` is zero, every optional limit is
/// `None`, and stream enforcement is off. Configure it through the `with_*`
/// builder methods.
#[non_exhaustive]
#[derive(Debug, Clone, Default)]
pub struct DeadlinePolicy {
/// Lower bound applied to a client-asserted timeout. When it exceeds `max`,
/// `max` is used as the lower bound instead.
min: Duration,
/// Upper bound applied to a client-asserted timeout; `None` means no cap.
max: Option<Duration>,
/// Timeout used when the client asserts none. It is returned as-is, without
/// being clamped to `min`/`max`.
default: Option<Duration>,
/// Whether the remaining request deadline is also applied to response
/// streams.
enforce_on_streams: bool,
/// Per-item timeout for response streams. It is applied whether or not
/// `enforce_on_streams` is set.
inter_message_timeout: Option<Duration>,
}
impl DeadlinePolicy {
    /// Creates a policy that imposes no limits (identical to `Default`).
    #[must_use]
    pub fn new() -> Self {
        Default::default()
    }

    /// Sets the smallest timeout a client is allowed to assert.
    #[must_use]
    pub fn with_min(self, min: Duration) -> Self {
        Self { min, ..self }
    }

    /// Sets the largest timeout a client is allowed to assert.
    #[must_use]
    pub fn with_max(self, max: Duration) -> Self {
        Self {
            max: Some(max),
            ..self
        }
    }

    /// Sets the timeout used when the client asserts none.
    #[must_use]
    pub fn with_default_timeout(self, default: Duration) -> Self {
        Self {
            default: Some(default),
            ..self
        }
    }

    /// Chooses whether the remaining request deadline also bounds response
    /// streams.
    #[must_use]
    pub fn with_enforce_on_streams(self, enforce: bool) -> Self {
        Self {
            enforce_on_streams: enforce,
            ..self
        }
    }

    /// Sets the per-item timeout applied to response streams.
    #[must_use]
    pub fn with_inter_message_timeout(self, timeout: Duration) -> Self {
        Self {
            inter_message_timeout: Some(timeout),
            ..self
        }
    }

    /// Returns the configured lower bound for client-asserted timeouts.
    #[must_use]
    pub fn min(&self) -> Duration {
        self.min
    }

    /// Returns the configured upper bound for client-asserted timeouts, if any.
    #[must_use]
    pub fn max(&self) -> Option<Duration> {
        self.max
    }

    /// Returns the timeout used when the client asserts none, if any.
    #[must_use]
    pub fn default_timeout(&self) -> Option<Duration> {
        self.default
    }

    /// Returns whether the request deadline is applied to response streams.
    #[must_use]
    pub fn enforce_on_streams(&self) -> bool {
        self.enforce_on_streams
    }

    /// Returns the per-item timeout for response streams, if any.
    #[must_use]
    pub fn inter_message_timeout(&self) -> Option<Duration> {
        self.inter_message_timeout
    }

    /// Computes the effective timeout for a request.
    ///
    /// A client-asserted timeout is clamped into `[min, max]`; a debug event
    /// tagged with `path` is emitted when clamping changed the value. When the
    /// client asserted nothing, the configured default is returned unclamped
    /// (`None` if there is no default).
    pub(crate) fn moderate(&self, client: Option<Duration>, path: &str) -> Option<Duration> {
        let Some(asserted) = client else {
            return self.default;
        };

        let ceiling = self.max.unwrap_or(Duration::MAX);
        // Cap the floor at the ceiling: `clamp` panics when min > max, and a
        // misconfigured policy must not bring the server down.
        let floor = self.min.min(ceiling);
        let effective = asserted.clamp(floor, ceiling);

        if effective != asserted {
            // Saturate rather than truncate when the millisecond count does
            // not fit in a u64.
            let to_ms = |d: Duration| u64::try_from(d.as_millis()).unwrap_or(u64::MAX);
            tracing::debug!(
                target: "connectrpc::deadline",
                path,
                client_timeout_ms = to_ms(asserted),
                effective_timeout_ms = to_ms(effective),
                "client-asserted timeout clamped by server DeadlinePolicy",
            );
        }

        Some(effective)
    }

    /// Wraps `stream` so that it respects this policy's stream limits.
    ///
    /// `remaining` is the time left on the request deadline; it is only
    /// honoured when `enforce_on_streams` is set. The inter-message timeout is
    /// applied regardless. When neither limit is active the stream is returned
    /// untouched.
    pub(crate) fn enforce_on_response_stream(
        &self,
        stream: BoxStream<Result<Bytes, ConnectError>>,
        remaining: Option<Duration>,
    ) -> BoxStream<Result<Bytes, ConnectError>> {
        let absolute = remaining.filter(|_| self.enforce_on_streams);
        let per_item = self.inter_message_timeout;

        if absolute.is_some() || per_item.is_some() {
            Box::pin(DeadlineStream::new(stream, absolute, per_item))
        } else {
            stream
        }
    }
}
/// Stream adapter that ends the wrapped stream with a deadline-exceeded error
/// when an overall deadline or a per-item timeout elapses.
#[pin_project]
struct DeadlineStream<S> {
/// The wrapped stream; replaced with `None` (dropping it) once this adapter
/// has terminated.
#[pin]
inner: Option<S>,
/// Timer for the overall deadline, started at construction. `None` when no
/// overall deadline applies.
#[pin]
absolute: Option<Sleep>,
/// Timer for the next item, started at construction and re-armed after every
/// yielded item. `None` when no per-item timeout applies.
#[pin]
per_item: Option<Sleep>,
/// Duration used to re-arm `per_item`.
inter_message: Option<Duration>,
/// Set once the adapter has terminated; later polls return `None`.
finished: bool,
}
impl<S> DeadlineStream<S> {
    /// Wraps `inner`, starting both timers immediately.
    ///
    /// `absolute` is the total time the stream may run; `inter_message` is the
    /// time allowed before the first item and between consecutive items.
    /// Either may be `None` to disable that limit.
    fn new(inner: S, absolute: Option<Duration>, inter_message: Option<Duration>) -> Self {
        let absolute = absolute.map(tokio::time::sleep);
        // The first per-item window opens now; later windows are opened by
        // `poll_next` each time an item is yielded.
        let per_item = inter_message.map(tokio::time::sleep);
        Self {
            inner: Some(inner),
            absolute,
            per_item,
            inter_message,
            finished: false,
        }
    }
}
impl<S> Stream for DeadlineStream<S>
where
    S: Stream<Item = Result<Bytes, ConnectError>>,
{
    type Item = Result<Bytes, ConnectError>;

    /// Polls the wrapped stream while enforcing both timers.
    ///
    /// Order of checks:
    /// 1. The absolute deadline is a hard cut-off and wins even over an item
    ///    that is ready.
    /// 2. The inner stream is polled; a ready item re-arms the per-item timer.
    /// 3. Only if the inner stream is `Pending` is the per-item timer
    ///    consulted.
    ///
    /// Step 3 comes after step 2 on purpose: the per-item timer is meant to
    /// catch a stalled *producer*. Checking it first would let a consumer that
    /// is slow to poll (e.g. client backpressure) turn a ready item, or a
    /// clean end of stream, into a spurious "stalled" error.
    ///
    /// After a timeout error or the end of the inner stream, the inner stream
    /// is dropped and every later poll returns `None`.
    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        let mut this = self.project();
        if *this.finished {
            return Poll::Ready(None);
        }

        if let Some(sleep) = this.absolute.as_mut().as_pin_mut()
            && sleep.poll(cx).is_ready()
        {
            *this.finished = true;
            this.inner.set(None);
            return Poll::Ready(Some(Err(ConnectError::deadline_exceeded(
                "request deadline exceeded while streaming",
            ))));
        }

        let Some(inner) = this.inner.as_mut().as_pin_mut() else {
            *this.finished = true;
            return Poll::Ready(None);
        };

        match inner.poll_next(cx) {
            Poll::Ready(Some(item)) => {
                // Open a fresh window for the next item.
                if let Some(d) = this.inter_message {
                    this.per_item.set(Some(tokio::time::sleep(*d)));
                }
                Poll::Ready(Some(item))
            }
            Poll::Ready(None) => {
                *this.finished = true;
                this.inner.set(None);
                Poll::Ready(None)
            }
            Poll::Pending => {
                // The producer has nothing for us. Polling the timer here both
                // detects an elapsed window and registers the waker so the
                // task is woken when the window closes.
                if let Some(sleep) = this.per_item.as_mut().as_pin_mut()
                    && sleep.poll(cx).is_ready()
                {
                    *this.finished = true;
                    this.inner.set(None);
                    return Poll::Ready(Some(Err(ConnectError::deadline_exceeded(
                        "stream stalled past inter-message timeout",
                    ))));
                }
                Poll::Pending
            }
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use futures::StreamExt;

    /// Boxed response-stream type that every streaming test wraps.
    type Frames = BoxStream<Result<Bytes, ConnectError>>;

    /// RPC path handed to `moderate`; it only feeds the debug log.
    const PATH: &str = "/svc/m";

    /// Shorthand for a millisecond `Duration`.
    fn ms(n: u64) -> Duration {
        Duration::from_millis(n)
    }

    /// A successful stream item carrying `payload`.
    fn frame(payload: &'static str) -> Result<Bytes, ConnectError> {
        Ok(Bytes::from_static(payload.as_bytes()))
    }

    /// Stream that yields the single frame `a` and then never resolves again.
    fn one_frame_then_stall() -> Frames {
        Box::pin(futures::stream::iter([frame("a")]).chain(futures::stream::pending()))
    }

    /// Pulls the next item and asserts it is a successful frame with `payload`.
    async fn expect_frame(stream: &mut Frames, payload: &'static str) {
        let got = stream.next().await.unwrap().unwrap();
        assert_eq!(got, Bytes::from_static(payload.as_bytes()));
    }

    /// Pulls the next item, asserts it is a deadline-exceeded error, and
    /// returns it for further inspection.
    async fn expect_deadline_error(stream: &mut Frames) -> ConnectError {
        let err = stream.next().await.unwrap().unwrap_err();
        assert_eq!(err.code, crate::ErrorCode::DeadlineExceeded);
        err
    }

    /// Runs `moderate` over `(client, expected)` pairs.
    fn check_moderation(
        policy: &DeadlinePolicy,
        cases: &[(Option<Duration>, Option<Duration>)],
    ) {
        for &(client, expected) in cases {
            assert_eq!(policy.moderate(client, PATH), expected, "client = {client:?}");
        }
    }

    #[test]
    fn no_op_policy_is_passthrough() {
        check_moderation(
            &DeadlinePolicy::new(),
            &[(Some(ms(5)), Some(ms(5))), (None, None)],
        );
    }

    #[test]
    fn min_clamps_up() {
        check_moderation(
            &DeadlinePolicy::new().with_min(ms(10)),
            &[(Some(ms(1)), Some(ms(10))), (Some(ms(50)), Some(ms(50)))],
        );
    }

    #[test]
    fn max_clamps_down() {
        check_moderation(
            &DeadlinePolicy::new().with_max(ms(100)),
            &[(Some(ms(500)), Some(ms(100))), (Some(ms(50)), Some(ms(50)))],
        );
    }

    #[test]
    fn default_applies_when_client_absent() {
        check_moderation(
            &DeadlinePolicy::new().with_default_timeout(ms(200)),
            &[(None, Some(ms(200))), (Some(ms(50)), Some(ms(50)))],
        );
    }

    #[test]
    fn no_default_no_client_is_unbounded() {
        check_moderation(
            &DeadlinePolicy::new().with_min(ms(1)).with_max(ms(100)),
            &[(None, None)],
        );
    }

    #[test]
    fn misconfigured_min_above_max_does_not_panic() {
        // With min > max, max acts as both bounds.
        check_moderation(
            &DeadlinePolicy::new().with_min(ms(500)).with_max(ms(100)),
            &[(Some(ms(1)), Some(ms(100))), (Some(ms(700)), Some(ms(100)))],
        );
    }

    #[test]
    fn full_matrix() {
        let policy = DeadlinePolicy::new()
            .with_min(ms(10))
            .with_max(ms(100))
            .with_default_timeout(ms(50));
        check_moderation(
            &policy,
            &[
                (Some(ms(20)), Some(ms(20))),
                (Some(ms(1)), Some(ms(10))),
                (Some(ms(500)), Some(ms(100))),
                (None, Some(ms(50))),
                // Values exactly on the bounds pass through unchanged.
                (Some(ms(10)), Some(ms(10))),
                (Some(ms(100)), Some(ms(100))),
            ],
        );
    }

    #[test]
    fn accessors_round_trip() {
        let policy = DeadlinePolicy::new()
            .with_min(ms(1))
            .with_max(ms(2))
            .with_default_timeout(ms(3))
            .with_enforce_on_streams(true)
            .with_inter_message_timeout(ms(4));
        assert_eq!(policy.min(), ms(1));
        assert_eq!(policy.max(), Some(ms(2)));
        assert_eq!(policy.default_timeout(), Some(ms(3)));
        assert!(policy.enforce_on_streams());
        assert_eq!(policy.inter_message_timeout(), Some(ms(4)));
    }

    #[tokio::test(start_paused = true)]
    async fn enforce_no_op_returns_original_stream() {
        // Enforcement is off, so the 1 ms budget must be ignored entirely.
        let policy = DeadlinePolicy::new();
        let mut wrapped = policy.enforce_on_response_stream(one_frame_then_stall(), Some(ms(1)));
        expect_frame(&mut wrapped, "a").await;
        tokio::time::advance(ms(10)).await;
        assert!(futures::poll!(wrapped.next()).is_pending());
    }

    #[tokio::test(start_paused = true)]
    async fn fast_stream_completes_under_deadline() {
        let policy = DeadlinePolicy::new().with_enforce_on_streams(true);
        let source: Frames = Box::pin(futures::stream::iter([frame("a"), frame("b")]));
        let mut wrapped = policy.enforce_on_response_stream(source, Some(Duration::from_secs(60)));
        expect_frame(&mut wrapped, "a").await;
        expect_frame(&mut wrapped, "b").await;
        assert!(wrapped.next().await.is_none());
    }

    #[tokio::test(start_paused = true)]
    async fn slow_stream_cut_off_at_deadline() {
        let policy = DeadlinePolicy::new().with_enforce_on_streams(true);
        let mut wrapped = policy.enforce_on_response_stream(one_frame_then_stall(), Some(ms(100)));
        expect_frame(&mut wrapped, "a").await;
        tokio::time::advance(ms(200)).await;
        expect_deadline_error(&mut wrapped).await;
        // The adapter is terminated after reporting the error.
        assert!(wrapped.next().await.is_none());
    }

    #[tokio::test(start_paused = true)]
    async fn inter_message_timeout_cuts_off_stalled_stream() {
        let policy = DeadlinePolicy::new()
            .with_enforce_on_streams(true)
            .with_inter_message_timeout(ms(50));
        // The absolute budget is far away, so only the per-item timer can fire.
        let mut wrapped = policy
            .enforce_on_response_stream(one_frame_then_stall(), Some(Duration::from_secs(3600)));
        expect_frame(&mut wrapped, "a").await;
        tokio::time::advance(ms(100)).await;
        let err = expect_deadline_error(&mut wrapped).await;
        assert!(err.message.as_deref().unwrap().contains("inter-message"));
    }

    #[tokio::test(start_paused = true)]
    async fn inter_message_timeout_independent_of_enforce_on_streams() {
        let policy = DeadlinePolicy::new().with_inter_message_timeout(ms(50));
        assert!(!policy.enforce_on_streams());
        let mut wrapped = policy.enforce_on_response_stream(one_frame_then_stall(), None);
        expect_frame(&mut wrapped, "a").await;
        tokio::time::advance(ms(100)).await;
        let err = expect_deadline_error(&mut wrapped).await;
        assert!(err.message.as_deref().unwrap().contains("inter-message"));
    }

    #[tokio::test(start_paused = true)]
    async fn no_deadline_no_inter_message_streams_unbounded() {
        let policy = DeadlinePolicy::new().with_enforce_on_streams(true);
        let source: Frames = Box::pin(futures::stream::iter([frame("a")]));
        let mut wrapped = policy.enforce_on_response_stream(source, None);
        expect_frame(&mut wrapped, "a").await;
        assert!(wrapped.next().await.is_none());
    }
}