dynamo_llm/http/service/disconnect.rs
1// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2// SPDX-License-Identifier: Apache-2.0
3
//! The `disconnect` module provides a mechanism for our axum HTTP services to monitor and respond
//! to disconnects from the client.
6//!
7//! There are two potential phases in any request where we need to handle the disconnect.
8//!
9//! For unary, request-response, there is just a single phase where the primary task that axum kicks off
10//! to handle the request will be dropped if the client disconnects. In order for us to have a long running
11//! task, like an LLM request, we need to spawn our long running task in a separate task and then spawn
12//! a second task that will monitor for disconnects from the client. The primary task which spawned the
13//! two tasks will hold an "armed" [`ConnectionHandle`] which will issue a [`ConnectionStatus::ClosedUnexpectedly`]
14//! if the task is dropped before it is [`ConnectionHandle::disarm`]ed.
15//!
16//! For the streaming case, request in - stream out, we need a second [`ConnectionHandle`] which will be owned
//! by the stream. A streaming response is when the [`axum::response::Response`] is an [`axum::response::Sse`] stream.
18//! This means the primary task handle will go out of scope when it returns the stream. When we create our
19//! SSE stream, we capture the second [`ConnectionHandle`] and arm it. If the stream closes gracefully, the
20//! second handle will be disarmed, otherwise, the stream was dropped and the [`Drop`] trait on the [`ConnectionHandle`]
21//! triggers a [`ConnectionStatus::ClosedUnexpectedly`] signal.
22//!
//! The [`ConnectionHandle`] is a simple wrapper around a [`tokio::sync::oneshot::Sender`] which sends a
//! [`ConnectionStatus`] to a detached monitor task when the handle is dropped. The monitor task uses this
//! status to decide whether to cancel the request (by killing the engine context) or not.
26//!
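//! The streaming wrapper [`monitor_for_disconnects`] also reports upstream errors to the client as SSE
//! events with the event type "error" (the error text is carried as an SSE comment), and ends a stream
//! that completes normally with a final event whose data is "[DONE]". If the request context is stopped,
//! the stream ends without a "[DONE]" event.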
29//!
30
31use axum::response::sse::Event;
32use dynamo_runtime::engine::AsyncEngineContext;
33use futures::{Stream, StreamExt};
34use std::sync::Arc;
35
36use crate::http::service::metrics::{InflightGuard, Metrics};
37
/// Signal delivered by a [`ConnectionHandle`] to the connection monitor task when the handle is
/// dropped.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ConnectionStatus {
    /// The handle was never armed; the monitor takes no action for this phase.
    Disabled,
    /// The handle was dropped while armed; the monitor cancels the request.
    ClosedUnexpectedly,
    /// The handle was disarmed before it was dropped; the phase finished normally.
    ClosedGracefully,
}
44
45pub struct ConnectionHandle {
46 sender: Option<tokio::sync::oneshot::Sender<ConnectionStatus>>,
47 on_drop: ConnectionStatus,
48}
49
50impl ConnectionHandle {
51 /// Handle which by default will issue a [`ConnectionStatus::ClosedGracefully`] signal when dropped.
52 pub fn create_disarmed(sender: tokio::sync::oneshot::Sender<ConnectionStatus>) -> Self {
53 Self {
54 sender: Some(sender),
55 on_drop: ConnectionStatus::ClosedGracefully,
56 }
57 }
58
59 /// Handle which will issue a [`ConnectionStatus::ClosedUnexpectedly`] signal when dropped.
60 pub fn create_armed(sender: tokio::sync::oneshot::Sender<ConnectionStatus>) -> Self {
61 Self {
62 sender: Some(sender),
63 on_drop: ConnectionStatus::ClosedUnexpectedly,
64 }
65 }
66
67 /// Handle which will not issue a signal when dropped.
68 pub fn create_disabled(sender: tokio::sync::oneshot::Sender<ConnectionStatus>) -> Self {
69 Self {
70 sender: Some(sender),
71 on_drop: ConnectionStatus::Disabled,
72 }
73 }
74
75 /// Handle which will issue a [`ConnectionStatus::ClosedGracefully`] signal when dropped.
76 pub fn disarm(&mut self) {
77 self.on_drop = ConnectionStatus::ClosedGracefully;
78 }
79
80 /// Handle which will issue a [`ConnectionStatus::ClosedUnexpectedly`] signal when dropped.
81 pub fn arm(&mut self) {
82 self.on_drop = ConnectionStatus::ClosedUnexpectedly;
83 }
84}
85
86impl Drop for ConnectionHandle {
87 fn drop(&mut self) {
88 if let Some(sender) = self.sender.take() {
89 let _ = sender.send(self.on_drop);
90 }
91 }
92}
93
94/// Creates a pair of handles which will monitor for disconnects from the client.
95///
96/// The first handle is armed and will issue a [`ConnectionStatus::ClosedUnexpectedly`] signal when dropped.
97/// The second handle is disarmed and will issue a [`ConnectionStatus::ClosedGracefully`] signal when dropped.
98///
99/// The handles are returned in the order of the first being armed and the second being disarmed.
100pub async fn create_connection_monitor(
101 engine_context: Arc<dyn AsyncEngineContext>,
102 metrics: Option<Arc<Metrics>>,
103) -> (ConnectionHandle, ConnectionHandle) {
104 // these oneshot channels monitor possible disconnects from the client in two different scopes:
105 // - the local task (connection_handle)
106 // - an optionally streaming response (stream_handle)
107 let (connection_tx, connection_rx) = tokio::sync::oneshot::channel();
108 let (stream_tx, stream_rx) = tokio::sync::oneshot::channel();
109
110 // detached task that will naturally close when both handles are dropped
111 tokio::spawn(connection_monitor(
112 engine_context.clone(),
113 connection_rx,
114 stream_rx,
115 metrics,
116 ));
117
118 // Two handles, the first is armed, the second is disarmed
119 (
120 ConnectionHandle::create_armed(connection_tx),
121 ConnectionHandle::create_disabled(stream_tx),
122 )
123}
124
125#[tracing::instrument(level = "trace", skip_all, fields(request_id = %engine_context.id()))]
126async fn connection_monitor(
127 engine_context: Arc<dyn AsyncEngineContext>,
128 connection_rx: tokio::sync::oneshot::Receiver<ConnectionStatus>,
129 stream_rx: tokio::sync::oneshot::Receiver<ConnectionStatus>,
130 metrics: Option<Arc<Metrics>>,
131) {
132 match connection_rx.await {
133 Err(_) | Ok(ConnectionStatus::ClosedUnexpectedly) => {
134 // the client has disconnected, no need to gracefully cancel, just kill the context
135 tracing::trace!("Connection closed unexpectedly; issuing cancellation");
136 if let Some(metrics) = &metrics {
137 metrics.inc_client_disconnect();
138 }
139 engine_context.kill();
140 }
141 Ok(ConnectionStatus::ClosedGracefully) => {
142 tracing::trace!("Connection closed gracefully");
143 }
144 Ok(ConnectionStatus::Disabled) => {}
145 }
146
147 match stream_rx.await {
148 Err(_) | Ok(ConnectionStatus::ClosedUnexpectedly) => {
149 tracing::trace!("Stream closed unexpectedly; issuing cancellation");
150 if let Some(metrics) = &metrics {
151 metrics.inc_client_disconnect();
152 }
153 engine_context.kill();
154 }
155 Ok(ConnectionStatus::ClosedGracefully) => {
156 tracing::trace!("Stream closed gracefully");
157 }
158 Ok(ConnectionStatus::Disabled) => {}
159 }
160}
161
162/// This method will consume a stream of SSE events and monitor for disconnects or context cancellation.
163///
164/// Uses `tokio::select!` to choose between receiving events from the source stream or detecting when
165/// the context is stopped. If the context is stopped, we break the stream. If the source stream ends
166/// naturally, we mark the request as successful and send the final `[DONE]` event.
167pub fn monitor_for_disconnects(
168 stream: impl Stream<Item = Result<Event, axum::Error>>,
169 context: Arc<dyn AsyncEngineContext>,
170 mut inflight_guard: InflightGuard,
171 mut stream_handle: ConnectionHandle,
172) -> impl Stream<Item = Result<Event, axum::Error>> {
173 stream_handle.arm();
174 async_stream::try_stream! {
175 tokio::pin!(stream);
176 loop {
177 tokio::select! {
178 event = stream.next() => {
179 match event {
180 Some(Ok(event)) => {
181 yield event;
182 }
183 Some(Err(err)) => {
184 yield Event::default().event("error").comment(err.to_string());
185 }
186 None => {
187 // Stream ended normally
188 inflight_guard.mark_ok();
189 stream_handle.disarm();
190
191 // todo: if we yield a dynamo sentinel event, we need to do it before the done or the
192 // async-openai client will chomp it.
193 yield Event::default().data("[DONE]");
194 break;
195 }
196 }
197 }
198 _ = context.stopped() => {
199 tracing::trace!("Context stopped; breaking stream");
200 break;
201 }
202 }
203 }
204 }
205}