use std::sync::{Arc, Mutex};
use anyhow::Result;
use dynamo_kv_router::protocols::{KvCacheEvent, RouterEvent, WorkerId};
use crate::common::protocols::{
ForwardPassSnapshot, FpmPublisher, KvCacheEventSink, KvEventPublishers, RawKvEvent,
RawKvEventSink,
};
/// Shared, mutex-protected list of captured [`RouterEvent`]s.
///
/// The storage lives behind an `Arc`, so every clone of this handle reads and
/// writes the same underlying vector.
#[derive(Clone, Default)]
pub(crate) struct CapturedRouterEventBuffer {
    events: Arc<Mutex<Vec<RouterEvent>>>,
}

impl CapturedRouterEventBuffer {
    /// Appends `event` to the end of the buffer.
    ///
    /// # Panics
    ///
    /// Panics if the internal mutex has been poisoned.
    pub(crate) fn push(&self, event: RouterEvent) {
        let mut captured = self.events.lock().unwrap();
        captured.push(event);
    }

    /// Takes every buffered event out of the buffer, in insertion order,
    /// leaving an empty vector behind.
    ///
    /// # Panics
    ///
    /// Panics if the internal mutex has been poisoned.
    pub(crate) fn drain(&self) -> Vec<RouterEvent> {
        let mut captured = self.events.lock().unwrap();
        std::mem::take(&mut *captured)
    }
}
/// [`KvCacheEventSink`] that records events instead of forwarding them.
///
/// Each published [`KvCacheEvent`] is wrapped in a [`RouterEvent`] carrying
/// `worker_id` and appended to `buffer`.
#[derive(Clone)]
struct RouterEventCaptureSink {
    worker_id: WorkerId,
    buffer: CapturedRouterEventBuffer,
}

impl KvCacheEventSink for RouterEventCaptureSink {
    /// Wraps `event` with this sink's worker id and stores it in the buffer.
    ///
    /// Always returns `Ok(())`.
    fn publish(&self, event: KvCacheEvent) -> Result<()> {
        let router_event = RouterEvent::new(self.worker_id, event);
        self.buffer.push(router_event);
        Ok(())
    }
}
/// Creates a capturing sink for `worker_id` together with the buffer it fills.
///
/// Returns `(buffer, sink)`: every event published through `sink` is stored in
/// `buffer` as a [`RouterEvent`] tagged with `worker_id`, and can later be
/// retrieved with [`CapturedRouterEventBuffer::drain`].
pub(crate) fn capture_router_event_sink(
    worker_id: WorkerId,
) -> (CapturedRouterEventBuffer, Arc<dyn KvCacheEventSink>) {
    let buffer = CapturedRouterEventBuffer::default();
    // The sink holds its own handle to the same shared storage.
    let capture = RouterEventCaptureSink {
        worker_id,
        buffer: buffer.clone(),
    };
    (buffer, Arc::new(capture) as Arc<dyn KvCacheEventSink>)
}
/// One buffered KV publish: the cache event plus the optional token ids that
/// accompany it when it is eventually forwarded.
#[derive(Debug, Clone)]
pub(crate) struct DeferredKvPublish {
    /// The cache event that was held back.
    pub(crate) event: KvCacheEvent,
    /// Token ids attached to the event, if any. `None` for entries recorded
    /// through the non-raw sink. Presumably one inner `Vec<u32>` per block —
    /// TODO confirm against the producer.
    pub(crate) block_token_ids: Option<Vec<Vec<u32>>>,
}
/// Shared, mutex-protected queue of [`DeferredKvPublish`] entries.
///
/// Clones share the same underlying vector through the inner `Arc`.
#[derive(Clone, Default)]
pub(crate) struct DeferredKvPublishBuffer {
    events: Arc<Mutex<Vec<DeferredKvPublish>>>,
}

impl DeferredKvPublishBuffer {
    /// Appends a new entry built from `event` and `block_token_ids`.
    ///
    /// # Panics
    ///
    /// Panics if the internal mutex has been poisoned.
    pub(crate) fn push(&self, event: KvCacheEvent, block_token_ids: Option<Vec<Vec<u32>>>) {
        let entry = DeferredKvPublish {
            event,
            block_token_ids,
        };
        let mut pending = self.events.lock().unwrap();
        pending.push(entry);
    }

    /// Takes all buffered entries out of the buffer, in insertion order,
    /// leaving it empty.
    ///
    /// # Panics
    ///
    /// Panics if the internal mutex has been poisoned.
    pub(crate) fn drain(&self) -> Vec<DeferredKvPublish> {
        let mut pending = self.events.lock().unwrap();
        std::mem::take(&mut *pending)
    }
}
/// [`KvCacheEventSink`] that queues events in a [`DeferredKvPublishBuffer`]
/// rather than publishing them immediately.
#[derive(Clone, Default)]
struct DeferredKvEventSink {
    buffer: DeferredKvPublishBuffer,
}

impl KvCacheEventSink for DeferredKvEventSink {
    /// Queues `event` without any token ids. Always returns `Ok(())`.
    fn publish(&self, event: KvCacheEvent) -> Result<()> {
        // This sink receives no token ids, so the entry is stored without them.
        let token_ids: Option<Vec<Vec<u32>>> = None;
        self.buffer.push(event, token_ids);
        Ok(())
    }
}
/// [`RawKvEventSink`] that queues raw events (event plus token ids) in a
/// [`DeferredKvPublishBuffer`] rather than publishing them immediately.
#[derive(Clone, Default)]
struct DeferredRawKvEventSink {
    buffer: DeferredKvPublishBuffer,
}

impl RawKvEventSink for DeferredRawKvEventSink {
    /// Queues `event`, merging it into the newest buffered entry when possible.
    ///
    /// If the most recently buffered entry has the same `event_id` and
    /// `dp_rank` as the incoming event, only that entry's `block_token_ids`
    /// is replaced with the incoming value and nothing new is appended.
    /// Otherwise a new entry is pushed. Always returns `Ok(())`.
    ///
    /// NOTE(review): the merge presumably exists because the same event was
    /// already recorded, without token ids, by the non-raw sink that shares
    /// this buffer — confirm against the publisher's call order.
    ///
    /// # Panics
    ///
    /// Panics if the buffer's mutex has been poisoned.
    fn publish(&self, event: RawKvEvent) -> Result<()> {
        // Hold the lock across the check and the write so the two are atomic.
        let mut pending = self.buffer.events.lock().unwrap();
        match pending.last_mut() {
            Some(tail)
                if tail.event.event_id == event.event.event_id
                    && tail.event.dp_rank == event.event.dp_rank =>
            {
                tail.block_token_ids = event.block_token_ids;
            }
            _ => pending.push(DeferredKvPublish {
                event: event.event,
                block_token_ids: event.block_token_ids,
            }),
        }
        Ok(())
    }
}
/// Creates a deferred-publish buffer and a [`KvEventPublishers`] that writes
/// into it.
///
/// The returned publishers always include an event sink backed by the buffer.
/// A raw-event sink backed by the same buffer is included only when
/// `capture_raw` is `true`.
pub(crate) fn capture_deferred_kv_publish_sink(
    capture_raw: bool,
) -> (DeferredKvPublishBuffer, KvEventPublishers) {
    let buffer = DeferredKvPublishBuffer::default();

    let event_sink = Arc::new(DeferredKvEventSink {
        buffer: buffer.clone(),
    }) as Arc<dyn KvCacheEventSink>;

    let raw_sink: Option<Arc<dyn RawKvEventSink>> = if capture_raw {
        let raw = DeferredRawKvEventSink {
            buffer: buffer.clone(),
        };
        Some(Arc::new(raw) as Arc<dyn RawKvEventSink>)
    } else {
        None
    };

    let publishers = KvEventPublishers::new(Some(event_sink), raw_sink);
    (buffer, publishers)
}
pub(crate) fn publish_deferred_kv_events(
sinks: &KvEventPublishers,
events: Vec<DeferredKvPublish>,
) {
for event in events {
if let Err(error) = sinks.publish(event.event, event.block_token_ids.as_deref()) {
tracing::warn!("Failed to forward buffered KV event: {error}");
}
}
}
/// Shared, mutex-protected queue of [`ForwardPassSnapshot`]s awaiting
/// publication.
///
/// Clones share the same underlying vector through the inner `Arc`.
#[derive(Clone, Default)]
pub(crate) struct DeferredFpmBuffer {
    snapshots: Arc<Mutex<Vec<ForwardPassSnapshot>>>,
}

impl DeferredFpmBuffer {
    /// Appends `snapshot` to the end of the buffer.
    ///
    /// # Panics
    ///
    /// Panics if the internal mutex has been poisoned.
    pub(crate) fn push(&self, snapshot: ForwardPassSnapshot) {
        let mut queued = self.snapshots.lock().unwrap();
        queued.push(snapshot);
    }

    /// Takes all buffered snapshots out of the buffer, in insertion order,
    /// leaving it empty.
    ///
    /// # Panics
    ///
    /// Panics if the internal mutex has been poisoned.
    pub(crate) fn drain(&self) -> Vec<ForwardPassSnapshot> {
        let mut queued = self.snapshots.lock().unwrap();
        std::mem::take(&mut *queued)
    }
}
/// Forwards previously buffered forward-pass snapshots to `sink`, in order.
///
/// Publishing is best-effort: a failure is logged as a warning and the
/// remaining snapshots are still forwarded.
pub(crate) fn publish_deferred_fpm(sink: &FpmPublisher, snapshots: Vec<ForwardPassSnapshot>) {
    snapshots.into_iter().for_each(|snapshot| {
        if let Err(error) = sink.publish(snapshot) {
            tracing::warn!("Failed to forward buffered FPM snapshot: {error}");
        }
    });
}