use crate::tokens::Token;
use serde::{Deserialize, Serialize};
/// Request payload handed to the router.
///
/// Carries the token sequence that needs to be placed on a worker
/// (presumably the prompt tokens — confirm against the caller).
#[derive(Default, Clone, Debug, Deserialize, Serialize)]
pub struct RouterRequest {
    /// Tokens to route; `Token` comes from `crate::tokens`.
    pub tokens: Vec<Token>,
}
/// Reply from the router, naming a worker by id.
#[derive(Default, Clone, Debug, Deserialize, Serialize)]
pub struct RouterResponse {
    /// Identifier of the worker picked by the router.
    /// NOTE(review): signed type; meaning of negative ids is not established here.
    pub worker_id: i64,
}
/// Outcome of choosing a worker for a request.
///
/// Only `Debug` is derived, so this type is not serialized or cloned.
#[derive(Debug)]
pub struct WorkerSelectionResult {
    /// Id of the chosen worker (same type as `RouterResponse::worker_id`).
    pub worker_id: i64,
    /// Number of blocks the request needs (by name; confirm how it is computed).
    pub required_blocks: u64,
    /// Number of blocks already present on the worker (by name; confirm with the selector).
    pub overlap_blocks: usize,
}
/// Load and cache statistics reported by a worker.
///
/// Field meanings below follow the field names only — confirm against the publisher.
#[derive(Default, Clone, Debug, Deserialize, Serialize)]
pub struct ForwardPassMetrics {
    /// Request slots currently occupied.
    pub request_active_slots: u64,
    /// Total request slots available.
    pub request_total_slots: u64,
    /// KV-cache blocks currently in use.
    pub kv_active_blocks: u64,
    /// Total KV-cache blocks available.
    pub kv_total_blocks: u64,
    /// Requests queued and not yet scheduled.
    pub num_requests_waiting: u64,
    /// GPU cache usage. TODO confirm units: fraction in [0, 1] or percent in [0, 100].
    pub gpu_cache_usage_perc: f32,
    /// GPU prefix-cache hit rate. TODO confirm units, as above.
    pub gpu_prefix_cache_hit_rate: f32,
}
/// `u64` hash stored as the `tokens_hash` of a cached block.
///
/// Serde support is hand-written further down in this file and emits a bare `u64`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct LocalBlockHash(pub u64);

/// `u64` hash used as a block's `block_hash`, its `parent_hash`, and in removal lists.
///
/// Serde support is hand-written further down in this file and emits a bare `u64`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct ExternalSequenceBlockHash(pub u64);
/// A batch of KV-cache events.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct KvCacheEvents {
    /// Events in this batch.
    pub events: Vec<KvCacheEvent>,
    /// Shutdown flag (presumably signals that the publisher is stopping — confirm).
    pub shutdown: bool,
}
/// A single KV-cache event tagged with an id.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct KvCacheEvent {
    /// Event identifier (ordering/uniqueness guarantees not established here).
    pub event_id: u64,
    /// What happened: blocks stored or removed.
    pub data: KvCacheEventData,
}
/// Payload of a KV-cache event.
///
/// Uses serde's default (external) enum tagging with snake_case variant names,
/// so payloads appear on the wire as `{"stored": {...}}` or `{"removed": {...}}`.
#[derive(Clone, Debug, Deserialize, Serialize)]
#[serde(rename_all = "snake_case")]
pub enum KvCacheEventData {
    /// Blocks were added to the cache.
    Stored(KvCacheStoreData),
    /// Blocks were removed from the cache.
    Removed(KvCacheRemoveData),
}
/// Body of a `Stored` event.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct KvCacheStoreData {
    /// Hash of the block these blocks attach to; `None` serializes as `null`.
    pub parent_hash: Option<ExternalSequenceBlockHash>,
    /// Blocks that were stored.
    pub blocks: Vec<KvCacheStoredBlockData>,
}
/// One stored block, identified by two hashes.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct KvCacheStoredBlockData {
    /// Hash identifying the block (same type as `parent_hash` and removal hashes).
    pub block_hash: ExternalSequenceBlockHash,
    /// Hash of the block's tokens.
    pub tokens_hash: LocalBlockHash,
}
/// Body of a `Removed` event.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct KvCacheRemoveData {
    /// Hashes of the blocks that were removed.
    pub block_hashes: Vec<ExternalSequenceBlockHash>,
}
impl Serialize for LocalBlockHash {
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
where
S: serde::Serializer,
{
serializer.serialize_u64(self.0)
}
}
impl<'de> Deserialize<'de> for LocalBlockHash {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: serde::Deserializer<'de>,
{
let value = u64::deserialize(deserializer)?;
Ok(LocalBlockHash(value))
}
}
impl Serialize for ExternalSequenceBlockHash {
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
where
S: serde::Serializer,
{
serializer.serialize_u64(self.0)
}
}
impl<'de> Deserialize<'de> for ExternalSequenceBlockHash {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: serde::Deserializer<'de>,
{
let value = u64::deserialize(deserializer)?;
Ok(ExternalSequenceBlockHash(value))
}
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_local_block_hash_serialization() {
        let hash = LocalBlockHash(12345);
        let serialized = serde_json::to_string(&hash).unwrap();
        assert_eq!(serialized, "12345");
        let deserialized: LocalBlockHash = serde_json::from_str(&serialized).unwrap();
        assert_eq!(deserialized, hash);
    }

    #[test]
    fn test_external_sequence_block_hash_serialization() {
        let hash = ExternalSequenceBlockHash(67890);
        let serialized = serde_json::to_string(&hash).unwrap();
        assert_eq!(serialized, "67890");
        let deserialized: ExternalSequenceBlockHash = serde_json::from_str(&serialized).unwrap();
        assert_eq!(deserialized, hash);
    }

    /// Hashes span the full 64-bit range; the largest value must round-trip exactly.
    #[test]
    fn test_block_hash_u64_max_round_trip() {
        let local = LocalBlockHash(u64::MAX);
        let serialized = serde_json::to_string(&local).unwrap();
        assert_eq!(serialized, u64::MAX.to_string());
        assert_eq!(serde_json::from_str::<LocalBlockHash>(&serialized).unwrap(), local);

        let external = ExternalSequenceBlockHash(u64::MAX);
        let serialized = serde_json::to_string(&external).unwrap();
        assert_eq!(serialized, u64::MAX.to_string());
        assert_eq!(
            serde_json::from_str::<ExternalSequenceBlockHash>(&serialized).unwrap(),
            external
        );
    }

    /// Values outside `u64` (e.g. negatives) must be rejected, not wrapped.
    #[test]
    fn test_block_hash_rejects_negative() {
        assert!(serde_json::from_str::<LocalBlockHash>("-1").is_err());
        assert!(serde_json::from_str::<ExternalSequenceBlockHash>("-1").is_err());
    }

    #[test]
    fn test_kv_cache_events_serialization() {
        let event_data = KvCacheEventData::Stored(KvCacheStoreData {
            parent_hash: Some(ExternalSequenceBlockHash(1)),
            blocks: vec![KvCacheStoredBlockData {
                block_hash: ExternalSequenceBlockHash(2),
                tokens_hash: LocalBlockHash(3),
            }],
        });
        let event = KvCacheEvent {
            event_id: 1,
            data: event_data,
        };
        let events = KvCacheEvents {
            events: vec![event],
            shutdown: false,
        };

        // Pin the wire format: externally tagged, snake_case variant name, bare u64 hashes.
        let value = serde_json::to_value(&events).unwrap();
        let stored = &value["events"][0]["data"]["stored"];
        assert_eq!(stored["parent_hash"].as_u64(), Some(1));
        assert_eq!(stored["blocks"][0]["block_hash"].as_u64(), Some(2));
        assert_eq!(stored["blocks"][0]["tokens_hash"].as_u64(), Some(3));

        let serialized = serde_json::to_string(&events).unwrap();
        let deserialized: KvCacheEvents = serde_json::from_str(&serialized).unwrap();
        assert_eq!(deserialized.events.len(), 1);
        assert_eq!(deserialized.events[0].event_id, 1);
        match &deserialized.events[0].data {
            KvCacheEventData::Stored(store_data) => {
                assert_eq!(store_data.parent_hash, Some(ExternalSequenceBlockHash(1)));
                assert_eq!(store_data.blocks.len(), 1);
                assert_eq!(store_data.blocks[0].block_hash, ExternalSequenceBlockHash(2));
                assert_eq!(store_data.blocks[0].tokens_hash, LocalBlockHash(3));
            }
            KvCacheEventData::Removed(_) => panic!("Expected KvCacheEventData::Stored variant"),
        }
        assert!(!deserialized.shutdown);
    }

    /// A root block has no parent; `None` must serialize as `null` and come back as `None`.
    #[test]
    fn test_kv_cache_store_data_without_parent() {
        let store_data = KvCacheStoreData {
            parent_hash: None,
            blocks: vec![],
        };
        let serialized = serde_json::to_string(&store_data).unwrap();
        assert_eq!(serialized, r#"{"parent_hash":null,"blocks":[]}"#);
        let deserialized: KvCacheStoreData = serde_json::from_str(&serialized).unwrap();
        assert!(deserialized.parent_hash.is_none());
        assert!(deserialized.blocks.is_empty());
    }

    #[test]
    fn test_kv_cache_remove_data_serialization() {
        let remove_data = KvCacheRemoveData {
            block_hashes: vec![ExternalSequenceBlockHash(4), ExternalSequenceBlockHash(5)],
        };
        let serialized = serde_json::to_string(&remove_data).unwrap();
        let deserialized: KvCacheRemoveData = serde_json::from_str(&serialized).unwrap();
        assert_eq!(deserialized.block_hashes.len(), 2);
        assert_eq!(deserialized.block_hashes[0].0, 4);
        assert_eq!(deserialized.block_hashes[1].0, 5);
    }

    /// The `Removed` variant was previously untested; check its tag and round trip.
    #[test]
    fn test_kv_cache_removed_event_serialization() {
        let event = KvCacheEvent {
            event_id: 7,
            data: KvCacheEventData::Removed(KvCacheRemoveData {
                block_hashes: vec![ExternalSequenceBlockHash(8)],
            }),
        };
        let value = serde_json::to_value(&event).unwrap();
        assert_eq!(value["event_id"].as_u64(), Some(7));
        assert_eq!(value["data"]["removed"]["block_hashes"][0].as_u64(), Some(8));

        let deserialized: KvCacheEvent = serde_json::from_value(value).unwrap();
        assert_eq!(deserialized.event_id, 7);
        match deserialized.data {
            KvCacheEventData::Removed(remove_data) => {
                assert_eq!(remove_data.block_hashes, vec![ExternalSequenceBlockHash(8)]);
            }
            KvCacheEventData::Stored(_) => panic!("Expected KvCacheEventData::Removed variant"),
        }
    }
}