use anyhow::{Context, Result};
use bytes::Bytes;
use rmp_serde::Serializer;
use serde::Serialize;
use std::sync::Arc;
use std::sync::atomic::{AtomicU64, Ordering};
use tokio::sync::RwLock;
use tokio::task::JoinHandle;
use zeromq::{PubSocket, Socket, SocketSend};
use super::tracker::{CacheStatusTracker, ConsolidatedEvent};
/// Wire-format batch: serialized (via msgpack) as the tuple
/// `(timestamp_secs, events, data_parallel_rank)`.
/// NOTE(review): field order presumably mirrors the consumer's (vLLM-style)
/// `EventBatch` schema — confirm against the subscriber before changing.
#[derive(Debug, Serialize)]
struct EventBatch(
    /// Publish time: fractional seconds since the Unix epoch.
    f64,
    /// The converted events carried by this batch.
    Vec<Event>,
    /// Data-parallel rank; this publisher always sends `Some(0)`.
    Option<i32>,
);
/// Events serialized inside an [`EventBatch`], externally tagged with a
/// `"type"` field via `#[serde(tag = "type")]`.
///
/// Block hashes go over the wire as `u64`; token ids and block size as `i32`
/// (narrowed with clamping in `from_consolidated`).
#[derive(Debug, Serialize)]
#[serde(tag = "type")]
enum Event {
    /// A block (single hash per event here) was stored in the KV cache.
    #[serde(rename = "BlockStored")]
    BlockStored {
        block_hashes: Vec<u64>,
        /// Hash of the parent block, if any.
        parent_block_hash: Option<u64>,
        token_ids: Vec<i32>,
        block_size: i32,
        /// Optional LoRA adapter name; omitted from the payload when `None`.
        /// NOTE(review): `serde(default)` only affects Deserialize, which is
        /// not derived here — it is inert on these fields.
        #[serde(default, skip_serializing_if = "Option::is_none")]
        lora_name: Option<String>,
        /// Storage medium; omitted from the payload when `None`.
        #[serde(default, skip_serializing_if = "Option::is_none")]
        medium: Option<String>,
    },
    /// A block was removed/evicted from the KV cache.
    #[serde(rename = "BlockRemoved")]
    BlockRemoved {
        block_hashes: Vec<u64>,
        /// Storage medium; omitted from the payload when `None`.
        #[serde(default, skip_serializing_if = "Option::is_none")]
        medium: Option<String>,
    },
    /// The entire cache was cleared.
    #[serde(rename = "AllBlocksCleared")]
    AllBlocksCleared {},
}
impl Event {
    /// Translate an internal [`ConsolidatedEvent`] into its wire-format
    /// [`Event`] counterpart.
    ///
    /// Block hashes arrive as decimal strings and must parse as `u64`; a
    /// parse failure is returned as an error. Token ids and the block size
    /// are narrowed to `i32`, clamping to `i32::MAX` (with a warning)
    /// when they do not fit.
    fn from_consolidated(event: ConsolidatedEvent) -> Result<Self> {
        match event {
            ConsolidatedEvent::Store {
                block_hash,
                parent_hash,
                token_ids,
                block_size,
                lora_name,
                source: _,
            } => {
                // Hashes are carried as strings internally; the wire format
                // wants raw u64 values.
                let hash = block_hash
                    .parse::<u64>()
                    .with_context(|| format!("Failed to parse block_hash: {}", block_hash))?;
                let parent = match parent_hash {
                    None => None,
                    Some(raw) => Some(
                        raw.parse::<u64>()
                            .with_context(|| format!("Failed to parse parent_hash: {}", raw))?,
                    ),
                };
                // Narrow each token id; clamp (never fail) on overflow.
                let tokens: Vec<i32> = token_ids
                    .into_iter()
                    .map(|raw| match i32::try_from(raw) {
                        Ok(v) => v,
                        Err(_) => {
                            tracing::warn!(
                                "Token ID {} exceeds i32::MAX, clamping to i32::MAX",
                                raw
                            );
                            i32::MAX
                        }
                    })
                    .collect();
                let size = match i32::try_from(block_size) {
                    Ok(v) => v,
                    Err(_) => {
                        tracing::warn!(
                            "Block size {} exceeds i32::MAX, clamping to i32::MAX",
                            block_size
                        );
                        i32::MAX
                    }
                };
                Ok(Event::BlockStored {
                    block_hashes: vec![hash],
                    parent_block_hash: parent,
                    token_ids: tokens,
                    block_size: size,
                    lora_name,
                    medium: None,
                })
            }
            ConsolidatedEvent::Remove {
                block_hash,
                source: _,
            } => {
                let hash = block_hash.parse::<u64>().with_context(|| {
                    format!("Failed to parse block_hash for removal: {}", block_hash)
                })?;
                Ok(Event::BlockRemoved {
                    block_hashes: vec![hash],
                    medium: None,
                })
            }
            ConsolidatedEvent::ClearAll => Ok(Event::AllBlocksCleared {}),
        }
    }
}
/// Publishes consolidated KV-cache events over a ZMQ PUB socket from a
/// background tokio task.
pub struct KvEventConsolidatorPublisher {
    /// ZMQ endpoint the PUB socket binds to.
    endpoint: String,
    /// Shared tracker; pending events are drained from it on each tick.
    tracker: Arc<RwLock<CacheStatusTracker>>,
    /// Monotonic sequence number stamped onto each published batch.
    sequence: Arc<AtomicU64>,
    /// Handle to the spawned publisher loop; aborted on `shutdown`.
    task_handle: Option<JoinHandle<()>>,
}
impl KvEventConsolidatorPublisher {
pub fn new(endpoint: &str, tracker: Arc<RwLock<CacheStatusTracker>>) -> Result<Self> {
let endpoint = endpoint.to_string();
let sequence = Arc::new(AtomicU64::new(0));
let publisher = Self {
endpoint: endpoint.clone(),
tracker: tracker.clone(),
sequence: sequence.clone(),
task_handle: None,
};
let handle = tokio::spawn(async move {
if let Err(e) = Self::run_publisher_loop(endpoint, tracker, sequence).await {
panic!("Publisher task failed: {}", e);
}
});
Ok(Self {
endpoint: publisher.endpoint,
tracker: publisher.tracker,
sequence: publisher.sequence,
task_handle: Some(handle),
})
}
pub async fn shutdown(self) -> Result<()> {
if let Some(handle) = self.task_handle {
handle.abort();
let _ = handle.await;
}
Ok(())
}
async fn run_publisher_loop(
endpoint: String,
tracker: Arc<RwLock<CacheStatusTracker>>,
sequence: Arc<AtomicU64>,
) -> Result<()> {
tracing::info!("Starting consolidated event publisher on {}", endpoint);
let mut socket = PubSocket::new();
socket
.bind(&endpoint)
.await
.with_context(|| format!("Failed to bind publisher to {}", endpoint))?;
tracing::info!("Publisher bound to {}", endpoint);
let mut interval = tokio::time::interval(tokio::time::Duration::from_millis(50));
loop {
interval.tick().await;
let events = {
let mut tracker_guard = tracker.write().await;
tracker_guard.drain_events()
};
if events.is_empty() {
continue;
}
tracing::debug!(
"Publishing {} consolidated event(s) to router",
events.len()
);
let vllm_events: Vec<Event> = events
.into_iter()
.filter_map(|event| match Event::from_consolidated(event) {
Ok(e) => Some(e),
Err(err) => {
tracing::error!("Failed to convert consolidated event, skipping: {}", err);
None
}
})
.collect();
if vllm_events.is_empty() {
tracing::warn!("All consolidated events failed validation, skipping publish");
continue;
}
let num_events = vllm_events.len();
let batch = EventBatch(
std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.unwrap()
.as_secs_f64(), vllm_events, Some(0), );
let mut payload = Vec::new();
batch
.serialize(&mut Serializer::new(&mut payload))
.context("Failed to serialize event batch")?;
let seq = sequence.fetch_add(1, Ordering::SeqCst);
let seq_bytes = seq.to_be_bytes();
let frames = vec![
Bytes::from(""),
Bytes::from(seq_bytes.to_vec()),
Bytes::from(payload),
];
let msg = match zeromq::ZmqMessage::try_from(frames) {
Ok(m) => m,
Err(e) => {
tracing::error!("Failed to create multipart ZMQ message: {:?}", e);
continue;
}
};
if let Err(e) = socket.send(msg).await {
tracing::error!("Failed to send consolidated events: {}", e);
} else {
tracing::debug!(
"Consolidator: Published batch with {} event(s) to ZMQ (seq={})",
num_events,
seq
);
}
}
}
}