use axum::response::sse::Event;
use dynamo_runtime::engine::AsyncEngineContext;
use futures::{Stream, StreamExt};
use std::sync::Arc;
use std::time::Duration;
use crate::http::service::metrics::{CancellationLabels, ErrorType, InflightGuard, Metrics};
use dynamo_runtime::config::environment_names::llm::DYN_HTTP_BACKEND_STREAM_TIMEOUT_SECS as BACKEND_STREAM_TIMEOUT_ENV;
/// Reads the backend stream inactivity timeout from the
/// `DYN_HTTP_BACKEND_STREAM_TIMEOUT_SECS` environment variable.
///
/// Returns `None` when the variable is unset, unparsable, or `0`
/// (all of which disable the timeout). Otherwise returns exactly the
/// configured number of seconds as a [`Duration`].
pub fn backend_stream_timeout() -> Option<Duration> {
    std::env::var(BACKEND_STREAM_TIMEOUT_ENV)
        .ok()
        .and_then(|s| s.parse::<u64>().ok())
        // 0 means "disabled", and non-numeric values were dropped above.
        .filter(|&secs| secs > 0)
        // Previously this doubled the configured value (`saturating_mul(2)`),
        // which contradicted the env var's `_SECS` contract and made the
        // `timeout_secs` log field misleading. Use the value as configured.
        .map(Duration::from_secs)
}
/// Outcome reported through a oneshot channel when a [`ConnectionHandle`]
/// is dropped.
#[derive(Debug, Clone, Copy)]
pub enum ConnectionStatus {
    /// Monitoring is disabled for this handle; the monitor ignores it.
    Disabled,
    /// The connection/stream ended before completing; triggers cancellation.
    ClosedUnexpectedly,
    /// The connection/stream completed normally; no action is taken.
    ClosedGracefully,
}
/// RAII-style guard that reports a [`ConnectionStatus`] over a oneshot
/// channel when dropped (see the `Drop` impl).
pub struct ConnectionHandle {
    // Taken in `Drop` so the status is sent at most once; `None` afterwards.
    sender: Option<tokio::sync::oneshot::Sender<ConnectionStatus>>,
    // Status to report on drop; toggled via `arm`/`disarm`.
    on_drop: ConnectionStatus,
}
impl ConnectionHandle {
    /// Shared constructor: wraps `sender` with the status reported on drop.
    fn with_status(
        sender: tokio::sync::oneshot::Sender<ConnectionStatus>,
        on_drop: ConnectionStatus,
    ) -> Self {
        Self {
            sender: Some(sender),
            on_drop,
        }
    }

    /// Handle that reports a graceful close on drop (until re-armed).
    pub fn create_disarmed(sender: tokio::sync::oneshot::Sender<ConnectionStatus>) -> Self {
        Self::with_status(sender, ConnectionStatus::ClosedGracefully)
    }

    /// Handle that reports an unexpected close on drop (until disarmed).
    pub fn create_armed(sender: tokio::sync::oneshot::Sender<ConnectionStatus>) -> Self {
        Self::with_status(sender, ConnectionStatus::ClosedUnexpectedly)
    }

    /// Handle whose drop report is ignored by the monitor.
    pub fn create_disabled(sender: tokio::sync::oneshot::Sender<ConnectionStatus>) -> Self {
        Self::with_status(sender, ConnectionStatus::Disabled)
    }

    /// After this call, dropping the handle reports a graceful close.
    pub fn disarm(&mut self) {
        self.on_drop = ConnectionStatus::ClosedGracefully;
    }

    /// After this call, dropping the handle reports an unexpected close.
    pub fn arm(&mut self) {
        self.on_drop = ConnectionStatus::ClosedUnexpectedly;
    }
}
impl Drop for ConnectionHandle {
    /// Sends the configured status exactly once; `sender` is `None` after.
    fn drop(&mut self) {
        if let Some(tx) = self.sender.take() {
            // The receiver may already be gone; the send result is irrelevant.
            tx.send(self.on_drop).ok();
        }
    }
}
pub async fn create_connection_monitor(
engine_context: Arc<dyn AsyncEngineContext>,
metrics: Option<Arc<Metrics>>,
cancellation_labels: CancellationLabels,
) -> (ConnectionHandle, ConnectionHandle) {
let (connection_tx, connection_rx) = tokio::sync::oneshot::channel();
let (stream_tx, stream_rx) = tokio::sync::oneshot::channel();
tokio::spawn(connection_monitor(
engine_context.clone(),
connection_rx,
stream_rx,
metrics,
cancellation_labels,
));
(
ConnectionHandle::create_armed(connection_tx),
ConnectionHandle::create_disabled(stream_tx),
)
}
#[tracing::instrument(level = "trace", skip_all, fields(request_id = %engine_context.id()))]
async fn connection_monitor(
engine_context: Arc<dyn AsyncEngineContext>,
connection_rx: tokio::sync::oneshot::Receiver<ConnectionStatus>,
stream_rx: tokio::sync::oneshot::Receiver<ConnectionStatus>,
metrics: Option<Arc<Metrics>>,
cancellation_labels: CancellationLabels,
) {
match connection_rx.await {
Err(_) | Ok(ConnectionStatus::ClosedUnexpectedly) => {
tracing::warn!("Connection closed unexpectedly; issuing cancellation");
if let Some(metrics) = &metrics {
metrics.inc_client_disconnect();
metrics.inc_cancellation(&cancellation_labels);
}
engine_context.kill();
}
Ok(ConnectionStatus::ClosedGracefully) => {
tracing::trace!("Connection closed gracefully");
}
Ok(ConnectionStatus::Disabled) => {}
}
match stream_rx.await {
Err(_) | Ok(ConnectionStatus::ClosedUnexpectedly) => {
tracing::warn!("Stream closed unexpectedly; issuing cancellation");
if let Some(metrics) = &metrics {
metrics.inc_client_disconnect();
metrics.inc_cancellation(&cancellation_labels);
}
engine_context.kill();
}
Ok(ConnectionStatus::ClosedGracefully) => {
tracing::trace!("Stream closed gracefully");
}
Ok(ConnectionStatus::Disabled) => {}
}
}
/// Wraps an SSE event stream with cancellation and backend-inactivity
/// monitoring.
///
/// The returned stream forwards events from `stream` and terminates when:
/// - the backend stream ends: emits `[DONE]`, marks the guard OK, disarms
///   the stream handle;
/// - the backend yields an error: emits an `error` event, marks Internal;
/// - `context.stopped()` resolves: marks Cancelled;
/// - no event arrives within `backend_stream_timeout()`: marks
///   ResponseTimeout and kills `context`.
pub fn monitor_for_disconnects(
    stream: impl Stream<Item = Result<Event, axum::Error>>,
    context: Arc<dyn AsyncEngineContext>,
    mut inflight_guard: InflightGuard,
    mut stream_handle: ConnectionHandle,
) -> impl Stream<Item = Result<Event, axum::Error>> {
    // Arm the handle so dropping the returned stream mid-flight reports an
    // unexpected close to the connection monitor.
    stream_handle.arm();
    // Pre-mark Cancelled: if the guard is dropped before one of the terminal
    // branches below overwrites the status, the request counts as cancelled.
    inflight_guard.mark_error(ErrorType::Cancelled);
    // Snapshot the env-derived timeout once, outside the loop.
    let inactivity_timeout = backend_stream_timeout();
    async_stream::try_stream! {
        tokio::pin!(stream);
        loop {
            tokio::select! {
                event = stream.next() => {
                    match event {
                        Some(Ok(event)) => {
                            yield event;
                        }
                        Some(Err(err)) => {
                            // Backend error: surface it to the client, then stop.
                            inflight_guard.mark_error(ErrorType::Internal);
                            yield Event::default().event("error").comment(err.to_string());
                            break;
                        }
                        None => {
                            // Normal end of stream: mark success and disarm
                            // BEFORE yielding, so a drop during the final yield
                            // still reports a graceful close.
                            inflight_guard.mark_ok();
                            stream_handle.disarm();
                            yield Event::default().data("[DONE]");
                            break;
                        }
                    }
                }
                _ = context.stopped() => {
                    // Engine context was cancelled out-of-band.
                    inflight_guard.mark_error(ErrorType::Cancelled);
                    tracing::warn!(
                        request_id = %inflight_guard.request_id(),
                        model = %inflight_guard.model(),
                        endpoint = %inflight_guard.endpoint(),
                        request_type = %inflight_guard.request_type(),
                        error_type = "cancelled",
                        elapsed_ms = %inflight_guard.elapsed_ms(),
                        "request cancelled"
                    );
                    break;
                }
                // The sleep future is re-created on every select! iteration, so
                // this measures inactivity (time since the last event), not the
                // total stream duration. With no timeout configured it pends
                // forever and this arm never fires.
                _ = async {
                    match inactivity_timeout {
                        Some(d) => tokio::time::sleep(d).await,
                        None => std::future::pending::<()>().await,
                    }
                } => {
                    inflight_guard.mark_error(ErrorType::ResponseTimeout);
                    // Disarm: the timeout path should not ALSO be reported as
                    // an unexpected stream close by the handle's drop.
                    stream_handle.disarm();
                    tracing::warn!(
                        request_id = %inflight_guard.request_id(),
                        model = %inflight_guard.model(),
                        endpoint = %inflight_guard.endpoint(),
                        request_type = %inflight_guard.request_type(),
                        error_type = "response_timeout",
                        elapsed_ms = %inflight_guard.elapsed_ms(),
                        timeout_secs = ?inactivity_timeout.map(|d| d.as_secs()),
                        "backend stream inactivity timeout; killing engine context to release inflight gauge"
                    );
                    context.kill();
                    break;
                }
            }
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::http::service::metrics::{Endpoint, ErrorType, RequestType, Status};
    use futures::StreamExt;
    use serial_test::serial;

    /// Inert engine context: `stopped()`/`killed()` pend forever and
    /// `kill()` is a no-op, so only the code under test drives termination.
    #[derive(Debug)]
    struct MockContext;

    impl MockContext {
        fn new() -> Self {
            Self
        }
    }

    #[async_trait::async_trait]
    impl dynamo_runtime::engine::AsyncEngineContext for MockContext {
        fn id(&self) -> &str {
            "test"
        }
        fn stop(&self) {}
        fn stop_generating(&self) {}
        fn kill(&self) {}
        fn is_stopped(&self) -> bool {
            false
        }
        fn is_killed(&self) -> bool {
            false
        }
        async fn stopped(&self) {
            std::future::pending::<()>().await;
        }
        async fn killed(&self) {
            std::future::pending::<()>().await;
        }
        fn link_child(&self, _: Arc<dyn dynamo_runtime::engine::AsyncEngineContext>) {}
    }

    /// Stream that never yields: simulates a zombie backend that stays
    /// silent forever (the yield below is unreachable).
    fn hanging_stream()
    -> impl futures::Stream<Item = Result<axum::response::sse::Event, axum::Error>> {
        async_stream::try_stream! {
            std::future::pending::<()>().await;
            yield axum::response::sse::Event::default().data("unreachable");
        }
    }

    /// Stream that yields `count` tokens, one every `interval`, then ends.
    fn timed_token_stream(
        count: usize,
        interval: Duration,
    ) -> impl futures::Stream<Item = Result<axum::response::sse::Event, axum::Error>> {
        async_stream::try_stream! {
            for i in 0..count {
                tokio::time::sleep(interval).await;
                yield axum::response::sse::Event::default().data(format!("token-{i}"));
            }
        }
    }

    /// Builds fresh metrics, an inflight guard, a mock context, a disabled
    /// connection handle, and sets the timeout env var to `timeout_secs`.
    fn setup_test(
        model: &str,
        req_id: &str,
        timeout_secs: &str,
    ) -> (
        Arc<Metrics>,
        InflightGuard,
        Arc<dyn AsyncEngineContext>,
        ConnectionHandle,
    ) {
        let metrics = Arc::new(Metrics::new());
        let guard = metrics
            .clone()
            .create_inflight_guard(model, Endpoint::ChatCompletions, true, req_id);
        let context: Arc<dyn AsyncEngineContext> = Arc::new(MockContext::new());
        // Receiver is dropped immediately; the handle's drop-send is ignored.
        let (tx, _rx) = tokio::sync::oneshot::channel();
        let handle = ConnectionHandle::create_disabled(tx);
        // SAFETY: env mutation is process-global; these tests run under
        // #[serial], so no other test mutates the env concurrently.
        unsafe { std::env::set_var(BACKEND_STREAM_TIMEOUT_ENV, timeout_secs) };
        (metrics, guard, context, handle)
    }

    /// Removes the timeout env var so later tests see a clean environment.
    fn cleanup_env() {
        // SAFETY: see setup_test — tests are serialized via #[serial].
        unsafe { std::env::remove_var(BACKEND_STREAM_TIMEOUT_ENV) };
    }

    // A silent backend must trip the inactivity timeout, terminate the
    // monitored stream, release the inflight gauge, and record the outcome
    // as ResponseTimeout (not Cancelled).
    #[tokio::test(start_paused = true)]
    #[serial]
    async fn test_backend_inactivity_timeout_releases_inflight_gauge() {
        let model = "zombie-model";
        let (metrics, guard, context, handle) = setup_test(model, "req-zombie", "1");
        assert_eq!(metrics.get_inflight_count(model), 1);
        let monitored = monitor_for_disconnects(hanging_stream(), context, guard, handle);
        tokio::pin!(monitored);
        // Paused time: jump past the configured timeout.
        tokio::time::advance(Duration::from_secs(3)).await;
        let completed = tokio::time::timeout(Duration::from_secs(2), async move {
            while monitored.next().await.is_some() {}
        })
        .await;
        cleanup_env();
        completed.expect("stream did not terminate — backend inactivity timeout is broken");
        assert_eq!(
            metrics.get_inflight_count(model),
            0,
            "inflight gauge leaked"
        );
        assert_eq!(
            metrics.get_request_counter(
                model,
                &Endpoint::ChatCompletions,
                &RequestType::Stream,
                &Status::Error,
                &ErrorType::ResponseTimeout,
            ),
            1,
            "inactivity timeout should be recorded as ResponseTimeout"
        );
        assert_eq!(
            metrics.get_request_counter(
                model,
                &Endpoint::ChatCompletions,
                &RequestType::Stream,
                &Status::Error,
                &ErrorType::Cancelled,
            ),
            0,
            "inactivity timeout should NOT be recorded as Cancelled"
        );
    }

    // Phase 1: tokens arriving faster than the timeout must keep the stream
    // alive (timeout is per-gap, not a hard deadline). Phase 2: with the same
    // env timeout, a silent stream must still be killed.
    #[tokio::test(start_paused = true)]
    #[serial]
    async fn test_inactivity_timeout_resets_on_each_token() {
        let model = "reset-model";
        let (metrics, guard_1, ctx_1, handle_1) = setup_test(model, "phase1", "5");
        assert_eq!(metrics.get_inflight_count(model), 1);
        let token_count = 5;
        let monitored_1 = monitor_for_disconnects(
            timed_token_stream(token_count, Duration::from_secs(2)),
            ctx_1,
            guard_1,
            handle_1,
        );
        tokio::pin!(monitored_1);
        let mut received = Vec::new();
        let phase1 = tokio::time::timeout(Duration::from_secs(30), async {
            while let Some(event) = monitored_1.next().await {
                received.push(event);
            }
        })
        .await;
        assert!(
            phase1.is_ok(),
            "inactivity timeout incorrectly fired as a hard deadline"
        );
        // Expect every token plus the trailing [DONE] event.
        assert_eq!(received.len(), token_count + 1);
        assert_eq!(metrics.get_inflight_count(model), 0);
        let guard_2 = metrics
            .clone()
            .create_inflight_guard(model, Endpoint::ChatCompletions, true, "phase2");
        assert_eq!(metrics.get_inflight_count(model), 1);
        let ctx_2: Arc<dyn AsyncEngineContext> = Arc::new(MockContext::new());
        let (tx_2, _rx_2) = tokio::sync::oneshot::channel();
        let handle_2 = ConnectionHandle::create_disabled(tx_2);
        let monitored_2 = monitor_for_disconnects(hanging_stream(), ctx_2, guard_2, handle_2);
        tokio::pin!(monitored_2);
        // Advance well past the timeout so the inactivity arm must fire.
        tokio::time::advance(Duration::from_secs(11)).await;
        let phase2 = tokio::time::timeout(Duration::from_secs(10), async {
            while monitored_2.next().await.is_some() {}
        })
        .await;
        cleanup_env();
        assert!(
            phase2.is_ok(),
            "hanging stream was not terminated by inactivity timeout"
        );
        assert_eq!(
            metrics.get_inflight_count(model),
            0,
            "inflight gauge leaked in phase 2"
        );
    }
}