use std::sync::Arc;
use std::time::Duration;
use dashmap::{mapref::entry::Entry, DashMap, DashSet};
use tokio::sync::mpsc;
use tokio::task::JoinHandle;
use crate::handler::InterceptHandler;
use crate::metrics::ProxyMetricsStore;
use crate::runtime::handler_guard::HandlerCallbackGuard;
use crate::types::{FlowId, RawResponse, StreamChunk};
/// A unit of work queued for a single flow's dispatch worker.
///
/// Each variant is delivered to exactly one `InterceptHandler` callback by the
/// worker spawned in `spawn_flow_dispatch_worker`.
#[derive(Debug)]
enum DispatchWork {
    /// Delivered to `handler.on_response`.
    Response(RawResponse),
    /// Delivered to `handler.on_stream_chunk`.
    StreamChunk(StreamChunk),
    /// Delivered to `handler.on_websocket_start`.
    WebSocketStart(RawResponse),
}
/// Dispatch state for one flow: the producer end of its work queue and the
/// task that consumes that queue.
#[derive(Debug)]
struct FlowDispatcher {
    /// Cloned out to enqueuers. Once every clone (and this one) is dropped, the
    /// worker's `recv()` returns `None` after the queue empties and it exits.
    sender: mpsc::Sender<DispatchWork>,
    /// Handle to the worker task; awaited (with a timeout) or aborted on close.
    worker: JoinHandle<()>,
}
/// Registry of per-flow dispatch workers.
///
/// Each flow gets its own bounded queue and its own spawned worker task, so
/// handler callbacks for one flow run one at a time in enqueue order while
/// different flows are dispatched independently.
///
/// The `H: InterceptHandler` bound lives on the struct (not only on the impls)
/// because the `Drop` impl carries the same bound, and Rust requires a `Drop`
/// impl's bounds to match the type definition's.
#[derive(Debug)]
pub(crate) struct FlowDispatchers<H: InterceptHandler> {
    /// Handler whose callbacks the workers invoke.
    handler: Arc<H>,
    /// Wraps every handler callback made by a worker (`run_response`).
    callback_guard: Arc<HandlerCallbackGuard>,
    /// Records a dispatch drop whenever work is discarded or a worker join
    /// fails or times out.
    metrics_store: Arc<ProxyMetricsStore>,
    /// Flows marked closed; enqueues for these flows are rejected.
    /// NOTE(review): presumably populated by the code that finalizes flows —
    /// confirm the owner.
    closed_flow_live: Arc<DashSet<FlowId>>,
    /// Live workers keyed by flow id.
    per_flow: DashMap<FlowId, FlowDispatcher>,
    /// Capacity of each flow's bounded queue (clamped to >= 1 in `new`).
    queue_capacity: usize,
    /// How long an enqueue waits for queue space before dropping the item.
    queue_send_timeout: Duration,
    /// How long `close_and_drain` waits for a worker to finish before aborting it.
    close_join_timeout: Duration,
}
impl<H: InterceptHandler> FlowDispatchers<H> {
pub(crate) fn new(
handler: Arc<H>,
callback_guard: Arc<HandlerCallbackGuard>,
metrics_store: Arc<ProxyMetricsStore>,
closed_flow_live: Arc<DashSet<FlowId>>,
queue_capacity: usize,
queue_send_timeout: Duration,
close_join_timeout: Duration,
) -> Self {
Self {
handler,
callback_guard,
metrics_store,
closed_flow_live,
per_flow: DashMap::new(),
queue_capacity: queue_capacity.max(1),
queue_send_timeout: queue_send_timeout.max(Duration::from_millis(1)),
close_join_timeout: close_join_timeout.max(Duration::from_millis(1)),
}
}
pub(crate) async fn enqueue_response(&self, flow_id: FlowId, response: RawResponse) -> bool {
self.enqueue(flow_id, DispatchWork::Response(response))
.await
}
pub(crate) async fn enqueue_stream_chunk(&self, flow_id: FlowId, chunk: StreamChunk) -> bool {
self.enqueue(flow_id, DispatchWork::StreamChunk(chunk))
.await
}
pub(crate) async fn enqueue_websocket_start(
&self,
flow_id: FlowId,
response: RawResponse,
) -> bool {
self.enqueue(flow_id, DispatchWork::WebSocketStart(response))
.await
}
pub(crate) async fn close_and_drain(&self, flow_id: FlowId) {
let Some((_, mut dispatcher)) = self.per_flow.remove(&flow_id) else {
return;
};
drop(dispatcher.sender);
match tokio::time::timeout(self.close_join_timeout, &mut dispatcher.worker).await {
Ok(Ok(())) => {}
Ok(Err(error)) => {
self.metrics_store.record_dispatch_drop();
tracing::warn!(flow_id = flow_id.as_u64(), error = %error, "flow dispatcher worker join failed");
}
Err(_) => {
dispatcher.worker.abort();
let _ =
tokio::time::timeout(Duration::from_millis(100), &mut dispatcher.worker).await;
self.metrics_store.record_dispatch_drop();
tracing::warn!(
flow_id = flow_id.as_u64(),
timeout_ms = self.close_join_timeout.as_millis(),
"flow dispatcher worker join timed out; worker aborted"
);
}
}
}
#[allow(dead_code)]
pub(crate) async fn shutdown_all(&self) {
let flow_ids: Vec<FlowId> = self.per_flow.iter().map(|entry| *entry.key()).collect();
for flow_id in flow_ids {
self.close_and_drain(flow_id).await;
}
}
pub(crate) fn abort_all_now(&self) {
let flow_ids: Vec<FlowId> = self.per_flow.iter().map(|entry| *entry.key()).collect();
for flow_id in flow_ids {
if let Some((_, dispatcher)) = self.per_flow.remove(&flow_id) {
dispatcher.worker.abort();
}
}
}
async fn enqueue(&self, flow_id: FlowId, work: DispatchWork) -> bool {
let Some(sender) = self.sender_for_flow(flow_id).await else {
self.metrics_store.record_dispatch_drop();
tracing::warn!(
flow_id = flow_id.as_u64(),
"dropped dispatch work for finalized flow"
);
return false;
};
match tokio::time::timeout(self.queue_send_timeout, sender.send(work)).await {
Ok(Ok(())) => true,
Ok(Err(_)) => {
self.metrics_store.record_dispatch_drop();
tracing::warn!(
flow_id = flow_id.as_u64(),
"dropped dispatch work; flow worker closed"
);
false
}
Err(_) => {
self.metrics_store.record_dispatch_drop();
tracing::warn!(
flow_id = flow_id.as_u64(),
timeout_ms = self.queue_send_timeout.as_millis(),
"dispatch queue send timed out; dropping work item"
);
false
}
}
}
async fn sender_for_flow(&self, flow_id: FlowId) -> Option<mpsc::Sender<DispatchWork>> {
if self.is_flow_closed(flow_id) {
return None;
}
if let Some(existing) = self.per_flow.get(&flow_id) {
return Some(existing.sender.clone());
}
let (sender, receiver) = mpsc::channel(self.queue_capacity);
let worker = spawn_flow_dispatch_worker(
Arc::clone(&self.handler),
Arc::clone(&self.callback_guard),
receiver,
);
let selected_sender = match self.per_flow.entry(flow_id) {
Entry::Occupied(existing) => {
worker.abort();
existing.get().sender.clone()
}
Entry::Vacant(vacant) => {
vacant.insert(FlowDispatcher {
sender: sender.clone(),
worker,
});
sender.clone()
}
};
if !self.is_flow_closed(flow_id) {
return Some(selected_sender);
}
if let Some((_, dispatcher)) = self.per_flow.remove(&flow_id) {
dispatcher.worker.abort();
}
None
}
fn is_flow_closed(&self, flow_id: FlowId) -> bool {
self.closed_flow_live.contains(&flow_id)
}
}
impl<H: InterceptHandler> Drop for FlowDispatchers<H> {
    /// Best-effort teardown: every worker still registered is aborted rather
    /// than drained, because `drop` cannot await a graceful join. Any queued,
    /// undelivered work is discarded.
    fn drop(&mut self) {
        // With `&mut self` nothing else can touch the map, so take it wholesale
        // and abort each worker as its dispatcher is consumed.
        let registered = std::mem::take(&mut self.per_flow);
        for (_, dispatcher) in registered {
            dispatcher.worker.abort();
        }
    }
}
/// Spawns the task that delivers one flow's queued work to the handler.
///
/// Items are taken from `receiver` in FIFO order. Each handler callback is run
/// through `callback_guard.run_response` and awaited before the next item is
/// taken, so callbacks for a single flow never overlap. The task ends once all
/// senders for the channel are dropped and the queue is empty.
fn spawn_flow_dispatch_worker<H: InterceptHandler>(
    handler: Arc<H>,
    callback_guard: Arc<HandlerCallbackGuard>,
    mut receiver: mpsc::Receiver<DispatchWork>,
) -> JoinHandle<()> {
    tokio::spawn(async move {
        loop {
            let Some(next) = receiver.recv().await else {
                break;
            };
            // Each callback future takes its own owned handler reference rather
            // than borrowing the task's copy.
            let target = Arc::clone(&handler);
            match next {
                DispatchWork::Response(resp) => {
                    callback_guard
                        .run_response((), async move { target.on_response(&resp).await })
                        .await;
                }
                DispatchWork::StreamChunk(piece) => {
                    callback_guard
                        .run_response((), async move { target.on_stream_chunk(&piece).await })
                        .await;
                }
                DispatchWork::WebSocketStart(resp) => {
                    callback_guard
                        .run_response((), async move {
                            target.on_websocket_start(&resp).await
                        })
                        .await;
                }
            }
        }
    })
}