dynamo_llm/http/service/
disconnect.rs

1// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2// SPDX-License-Identifier: Apache-2.0
3
//! The `disconnect` module provides a mechanism for our axum HTTP services to monitor and respond
//! to client disconnects.
6//!
7//! There are two potential phases in any request where we need to handle the disconnect.
8//!
9//! For unary, request-response, there is just a single phase where the primary task that axum kicks off
10//! to handle the request will be dropped if the client disconnects. In order for us to have a long running
11//! task, like an LLM request, we need to spawn our long running task in a separate task and then spawn
12//! a second task that will monitor for disconnects from the client. The primary task which spawned the
13//! two tasks will hold an "armed" [`ConnectionHandle`] which will issue a [`ConnectionStatus::ClosedUnexpectedly`]
14//! if the task is dropped before it is [`ConnectionHandle::disarm`]ed.
15//!
//! For the streaming case (request in, stream out), we need a second [`ConnectionHandle`] which will
//! be owned by the stream. A streaming response is one where the [`axum::response::Response`] is an
//! [`axum::response::Sse`] stream.
18//! This means the primary task handle will go out of scope when it returns the stream. When we create our
19//! SSE stream, we capture the second [`ConnectionHandle`] and arm it. If the stream closes gracefully, the
20//! second handle will be disarmed, otherwise, the stream was dropped and the [`Drop`] trait on the [`ConnectionHandle`]
21//! triggers a [`ConnectionStatus::ClosedUnexpectedly`] signal.
22//!
//! The [`ConnectionHandle`] is a simple wrapper around a [`tokio::sync::oneshot::Sender`] which sends a
//! [`ConnectionStatus`] to a detached monitor task when the handle is dropped. The monitor uses that
//! status to decide whether to cancel the request by killing the engine context.
26//!
//! On the client-facing side, [`monitor_for_disconnects`] ends a stream that completed normally with a
//! final `[DONE]` data event, and forwards errors from the source stream as `error` events. If the
//! engine context is stopped first, the stream simply ends without sending `[DONE]`.
29//!
30
31use axum::response::sse::Event;
32use dynamo_runtime::engine::AsyncEngineContext;
33use futures::{Stream, StreamExt};
34use std::sync::Arc;
35
36use crate::http::service::metrics::{InflightGuard, Metrics};
37
/// Status reported by a [`ConnectionHandle`] when it is dropped.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ConnectionStatus {
    /// Reported by a handle created with [`ConnectionHandle::create_disabled`] that was never
    /// armed or disarmed; the connection monitor takes no action.
    Disabled,
    /// Reported by a handle that was still armed when dropped (e.g. the owning task or stream was
    /// dropped early); the connection monitor counts a client disconnect and kills the context.
    ClosedUnexpectedly,
    /// Reported by a handle that was disarmed before being dropped; the request ended normally.
    ClosedGracefully,
}
44
45pub struct ConnectionHandle {
46    sender: Option<tokio::sync::oneshot::Sender<ConnectionStatus>>,
47    on_drop: ConnectionStatus,
48}
49
50impl ConnectionHandle {
51    /// Handle which by default will issue a [`ConnectionStatus::ClosedGracefully`] signal when dropped.
52    pub fn create_disarmed(sender: tokio::sync::oneshot::Sender<ConnectionStatus>) -> Self {
53        Self {
54            sender: Some(sender),
55            on_drop: ConnectionStatus::ClosedGracefully,
56        }
57    }
58
59    /// Handle which will issue a [`ConnectionStatus::ClosedUnexpectedly`] signal when dropped.
60    pub fn create_armed(sender: tokio::sync::oneshot::Sender<ConnectionStatus>) -> Self {
61        Self {
62            sender: Some(sender),
63            on_drop: ConnectionStatus::ClosedUnexpectedly,
64        }
65    }
66
67    /// Handle which will not issue a signal when dropped.
68    pub fn create_disabled(sender: tokio::sync::oneshot::Sender<ConnectionStatus>) -> Self {
69        Self {
70            sender: Some(sender),
71            on_drop: ConnectionStatus::Disabled,
72        }
73    }
74
75    /// Handle which will issue a [`ConnectionStatus::ClosedGracefully`] signal when dropped.
76    pub fn disarm(&mut self) {
77        self.on_drop = ConnectionStatus::ClosedGracefully;
78    }
79
80    /// Handle which will issue a [`ConnectionStatus::ClosedUnexpectedly`] signal when dropped.
81    pub fn arm(&mut self) {
82        self.on_drop = ConnectionStatus::ClosedUnexpectedly;
83    }
84}
85
86impl Drop for ConnectionHandle {
87    fn drop(&mut self) {
88        if let Some(sender) = self.sender.take() {
89            let _ = sender.send(self.on_drop);
90        }
91    }
92}
93
94/// Creates a pair of handles which will monitor for disconnects from the client.
95///
96/// The first handle is armed and will issue a [`ConnectionStatus::ClosedUnexpectedly`] signal when dropped.
97/// The second handle is disarmed and will issue a [`ConnectionStatus::ClosedGracefully`] signal when dropped.
98///
99/// The handles are returned in the order of the first being armed and the second being disarmed.
100pub async fn create_connection_monitor(
101    engine_context: Arc<dyn AsyncEngineContext>,
102    metrics: Option<Arc<Metrics>>,
103) -> (ConnectionHandle, ConnectionHandle) {
104    // these oneshot channels monitor possible disconnects from the client in two different scopes:
105    // - the local task (connection_handle)
106    // - an optionally streaming response (stream_handle)
107    let (connection_tx, connection_rx) = tokio::sync::oneshot::channel();
108    let (stream_tx, stream_rx) = tokio::sync::oneshot::channel();
109
110    // detached task that will naturally close when both handles are dropped
111    tokio::spawn(connection_monitor(
112        engine_context.clone(),
113        connection_rx,
114        stream_rx,
115        metrics,
116    ));
117
118    // Two handles, the first is armed, the second is disarmed
119    (
120        ConnectionHandle::create_armed(connection_tx),
121        ConnectionHandle::create_disabled(stream_tx),
122    )
123}
124
125#[tracing::instrument(level = "trace", skip_all, fields(request_id = %engine_context.id()))]
126async fn connection_monitor(
127    engine_context: Arc<dyn AsyncEngineContext>,
128    connection_rx: tokio::sync::oneshot::Receiver<ConnectionStatus>,
129    stream_rx: tokio::sync::oneshot::Receiver<ConnectionStatus>,
130    metrics: Option<Arc<Metrics>>,
131) {
132    match connection_rx.await {
133        Err(_) | Ok(ConnectionStatus::ClosedUnexpectedly) => {
134            // the client has disconnected, no need to gracefully cancel, just kill the context
135            tracing::trace!("Connection closed unexpectedly; issuing cancellation");
136            if let Some(metrics) = &metrics {
137                metrics.inc_client_disconnect();
138            }
139            engine_context.kill();
140        }
141        Ok(ConnectionStatus::ClosedGracefully) => {
142            tracing::trace!("Connection closed gracefully");
143        }
144        Ok(ConnectionStatus::Disabled) => {}
145    }
146
147    match stream_rx.await {
148        Err(_) | Ok(ConnectionStatus::ClosedUnexpectedly) => {
149            tracing::trace!("Stream closed unexpectedly; issuing cancellation");
150            if let Some(metrics) = &metrics {
151                metrics.inc_client_disconnect();
152            }
153            engine_context.kill();
154        }
155        Ok(ConnectionStatus::ClosedGracefully) => {
156            tracing::trace!("Stream closed gracefully");
157        }
158        Ok(ConnectionStatus::Disabled) => {}
159    }
160}
161
162/// This method will consume a stream of SSE events and monitor for disconnects or context cancellation.
163///
164/// Uses `tokio::select!` to choose between receiving events from the source stream or detecting when
165/// the context is stopped. If the context is stopped, we break the stream. If the source stream ends
166/// naturally, we mark the request as successful and send the final `[DONE]` event.
167pub fn monitor_for_disconnects(
168    stream: impl Stream<Item = Result<Event, axum::Error>>,
169    context: Arc<dyn AsyncEngineContext>,
170    mut inflight_guard: InflightGuard,
171    mut stream_handle: ConnectionHandle,
172) -> impl Stream<Item = Result<Event, axum::Error>> {
173    stream_handle.arm();
174    async_stream::try_stream! {
175        tokio::pin!(stream);
176        loop {
177            tokio::select! {
178                event = stream.next() => {
179                    match event {
180                        Some(Ok(event)) => {
181                            yield event;
182                        }
183                        Some(Err(err)) => {
184                            yield Event::default().event("error").comment(err.to_string());
185                        }
186                        None => {
187                            // Stream ended normally
188                            inflight_guard.mark_ok();
189                            stream_handle.disarm();
190
191                            // todo: if we yield a dynamo sentinel event, we need to do it before the done or the
192                            // async-openai client will chomp it.
193                            yield Event::default().data("[DONE]");
194                            break;
195                        }
196                    }
197                }
198                _ = context.stopped() => {
199                    tracing::trace!("Context stopped; breaking stream");
200                    break;
201                }
202            }
203        }
204    }
205}