use axum::response::sse::Event;
use dynamo_runtime::engine::AsyncEngineContext;
use futures::{Stream, StreamExt};
use std::sync::Arc;
use crate::http::service::metrics::{ErrorType, InflightGuard, Metrics};
/// Status a [`ConnectionHandle`] sends to the connection monitor task when it is dropped.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ConnectionStatus {
    /// Monitoring is off for this handle; the monitor takes no action.
    Disabled,
    /// The connection/stream ended before completing; the monitor kills the engine context
    /// and, when metrics are configured, records a client disconnect.
    ClosedUnexpectedly,
    /// The connection/stream completed normally; the monitor only emits a trace log.
    ClosedGracefully,
}
/// Drop guard that tells the connection monitor how a connection or stream ended.
///
/// The handle owns the sending half of a one-shot channel. When it is dropped, it sends its
/// current "on drop" status once; [`ConnectionHandle::arm`] and [`ConnectionHandle::disarm`]
/// change which status that will be.
pub struct ConnectionHandle {
    /// Sending half of the monitor's channel; only taken (consumed) in `drop`.
    sender: Option<tokio::sync::oneshot::Sender<ConnectionStatus>>,
    /// Status reported when this handle is dropped.
    on_drop: ConnectionStatus,
}

impl ConnectionHandle {
    /// Builds a handle around `sender` that will report `on_drop` when dropped.
    fn with_drop_status(
        sender: tokio::sync::oneshot::Sender<ConnectionStatus>,
        on_drop: ConnectionStatus,
    ) -> Self {
        Self {
            sender: Some(sender),
            on_drop,
        }
    }

    /// Creates a handle that reports [`ConnectionStatus::ClosedGracefully`] unless it is armed.
    pub fn create_disarmed(sender: tokio::sync::oneshot::Sender<ConnectionStatus>) -> Self {
        Self::with_drop_status(sender, ConnectionStatus::ClosedGracefully)
    }

    /// Creates a handle that reports [`ConnectionStatus::ClosedUnexpectedly`] unless disarmed.
    pub fn create_armed(sender: tokio::sync::oneshot::Sender<ConnectionStatus>) -> Self {
        Self::with_drop_status(sender, ConnectionStatus::ClosedUnexpectedly)
    }

    /// Creates a handle that reports [`ConnectionStatus::Disabled`] until armed or disarmed.
    pub fn create_disabled(sender: tokio::sync::oneshot::Sender<ConnectionStatus>) -> Self {
        Self::with_drop_status(sender, ConnectionStatus::Disabled)
    }

    /// From now on, dropping this handle reports [`ConnectionStatus::ClosedGracefully`].
    pub fn disarm(&mut self) {
        self.on_drop = ConnectionStatus::ClosedGracefully;
    }

    /// From now on, dropping this handle reports [`ConnectionStatus::ClosedUnexpectedly`].
    pub fn arm(&mut self) {
        self.on_drop = ConnectionStatus::ClosedUnexpectedly;
    }
}

impl Drop for ConnectionHandle {
    fn drop(&mut self) {
        let Some(sender) = self.sender.take() else {
            return;
        };
        // A send error only means the receiving side (the monitor task) is gone; there is
        // nobody left to notify, so the result is deliberately ignored.
        let _ = sender.send(self.on_drop);
    }
}
/// Spawns a detached task that watches a request's connection and response stream, and
/// returns the two handles that feed it: `(connection_handle, stream_handle)`.
///
/// - `connection_handle` starts **armed**: if it is dropped without
///   [`ConnectionHandle::disarm`] being called, the monitor kills `engine_context`.
/// - `stream_handle` starts **disabled**: dropping it as-is makes the monitor take no action.
///   Call [`ConnectionHandle::arm`] to opt the response stream into disconnect detection.
///
/// `metrics`, when present, is handed to the monitor, which uses it to count client
/// disconnects.
pub async fn create_connection_monitor(
    engine_context: Arc<dyn AsyncEngineContext>,
    metrics: Option<Arc<Metrics>>,
) -> (ConnectionHandle, ConnectionHandle) {
    let (connection_tx, connection_rx) = tokio::sync::oneshot::channel();
    let (stream_tx, stream_rx) = tokio::sync::oneshot::channel();
    // The task is detached; its lifetime is driven entirely by the two returned handles.
    // `engine_context` is moved in directly since it is not needed here afterwards.
    tokio::spawn(connection_monitor(engine_context, connection_rx, stream_rx, metrics));
    (
        ConnectionHandle::create_armed(connection_tx),
        ConnectionHandle::create_disabled(stream_tx),
    )
}
/// Background task that waits on the connection handle's channel, then on the stream
/// handle's channel, and reacts to how each one was closed.
///
/// For each channel:
/// - A `ClosedUnexpectedly` report, or the sender being dropped without reporting (receive
///   error), kills `engine_context`. When `metrics` is set, it also increments the
///   client-disconnect counter.
/// - `ClosedGracefully` only logs.
/// - `Disabled` does nothing.
///
/// Once the connection is reported as closed unexpectedly, the context has already been
/// killed, so the stream channel is not awaited at all. This keeps one aborted request from
/// being counted as two client disconnects and avoids a redundant second `kill()`.
#[tracing::instrument(level = "trace", skip_all, fields(request_id = %engine_context.id()))]
async fn connection_monitor(
    engine_context: Arc<dyn AsyncEngineContext>,
    connection_rx: tokio::sync::oneshot::Receiver<ConnectionStatus>,
    stream_rx: tokio::sync::oneshot::Receiver<ConnectionStatus>,
    metrics: Option<Arc<Metrics>>,
) {
    match connection_rx.await {
        Err(_) | Ok(ConnectionStatus::ClosedUnexpectedly) => {
            tracing::trace!("Connection closed unexpectedly; issuing cancellation");
            if let Some(metrics) = &metrics {
                metrics.inc_client_disconnect();
            }
            engine_context.kill();
            // The request is already cancelled. Dropping `stream_rx` here is safe: the
            // stream handle ignores a failed send when it is dropped.
            return;
        }
        Ok(ConnectionStatus::ClosedGracefully) => {
            tracing::trace!("Connection closed gracefully");
        }
        Ok(ConnectionStatus::Disabled) => {}
    }
    match stream_rx.await {
        Err(_) | Ok(ConnectionStatus::ClosedUnexpectedly) => {
            tracing::trace!("Stream closed unexpectedly; issuing cancellation");
            if let Some(metrics) = &metrics {
                metrics.inc_client_disconnect();
            }
            engine_context.kill();
        }
        Ok(ConnectionStatus::ClosedGracefully) => {
            tracing::trace!("Stream closed gracefully");
        }
        Ok(ConnectionStatus::Disabled) => {}
    }
}
/// Wraps an SSE event `stream` so that cancellation and client disconnects are tracked.
///
/// Behavior of the returned stream:
/// - Every `Ok` event from `stream` is forwarded unchanged.
/// - On the first `Err` from `stream`:
///   - `inflight_guard` is marked [`ErrorType::Internal`];
///   - one `error` event carrying the error text as an SSE comment is emitted;
///   - the stream ends.
/// - When `stream` is exhausted:
///   - `inflight_guard` is marked ok and `stream_handle` is disarmed;
///   - a final `data: [DONE]` event is emitted;
///   - the stream ends.
/// - If `context` reports it has stopped first, `inflight_guard` is marked
///   [`ErrorType::Cancelled`] and the stream ends without a `[DONE]` event.
///
/// `stream_handle` is armed as soon as this function is called, before the returned stream is
/// first polled. It is disarmed only on normal completion. Any other ending or early drop of
/// the returned stream, including a client disconnect, drops an armed handle. That reports
/// [`ConnectionStatus::ClosedUnexpectedly`] to the connection monitor.
///
/// NOTE(review): the error and context-stopped paths also leave the handle armed. The monitor
/// will therefore record them as client disconnects and `kill()` the context. Confirm that is
/// the intended metric attribution.
pub fn monitor_for_disconnects(
    stream: impl Stream<Item = Result<Event, axum::Error>>,
    context: Arc<dyn AsyncEngineContext>,
    mut inflight_guard: InflightGuard,
    mut stream_handle: ConnectionHandle,
) -> impl Stream<Item = Result<Event, axum::Error>> {
    stream_handle.arm();
    // Pessimistic default: if the returned stream is dropped before reaching one of the
    // terminal branches below, "cancelled" is the last mark on the guard. (Assumes
    // `InflightGuard` reports its most recent mark when dropped — confirm in `metrics`.)
    inflight_guard.mark_error(ErrorType::Cancelled);
    async_stream::try_stream! {
        tokio::pin!(stream);
        loop {
            tokio::select! {
                event = stream.next() => {
                    match event {
                        Some(Ok(event)) => {
                            yield event;
                        }
                        Some(Err(err)) => {
                            inflight_guard.mark_error(ErrorType::Internal);
                            // `Event::comment` panics if its text contains CR or LF, and error
                            // messages are not guaranteed to be single-line.
                            let comment = sanitize_sse_comment(&err.to_string());
                            // NOTE(review): per the SSE spec, clients do not dispatch an event
                            // whose data buffer is empty, so most clients will never surface
                            // this event; consider also sending the message as `data`.
                            yield Event::default().event("error").comment(comment);
                            break;
                        }
                        None => {
                            inflight_guard.mark_ok();
                            // Disarmed before yielding `[DONE]`, so a drop at that final yield
                            // point is reported to the monitor as a graceful close.
                            stream_handle.disarm();
                            yield Event::default().data("[DONE]");
                            break;
                        }
                    }
                }
                _ = context.stopped() => {
                    tracing::trace!("Context stopped; breaking stream");
                    inflight_guard.mark_error(ErrorType::Cancelled);
                    break;
                }
            }
        }
    }
}

/// Makes `text` safe to pass to [`Event::comment`] by replacing every CR and LF with a space.
fn sanitize_sse_comment(text: &str) -> String {
    text.replace(['\r', '\n'], " ")
}