1use serde::{Deserialize, Serialize};
2use std::collections::HashMap;
3use std::fmt;
4use std::str::FromStr;
5
6#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
7pub struct KvBlockKey {
8 pub model_id: String,
9 pub tokenizer_id: String,
10 pub adapter_id: Option<String>,
11 pub tenant_id: String,
12 pub prefix_hash: String,
13 pub block_hash: String,
14 pub block_index: u32,
15 pub token_count: u32,
16}
17
18impl KvBlockKey {
19 pub fn new(
20 model_id: impl Into<String>,
21 tokenizer_id: impl Into<String>,
22 tenant_id: impl Into<String>,
23 prefix_hash: impl Into<String>,
24 block_hash: impl Into<String>,
25 block_index: u32,
26 token_count: u32,
27 ) -> Self {
28 Self {
29 model_id: model_id.into(),
30 tokenizer_id: tokenizer_id.into(),
31 adapter_id: None,
32 tenant_id: tenant_id.into(),
33 prefix_hash: prefix_hash.into(),
34 block_hash: block_hash.into(),
35 block_index,
36 token_count,
37 }
38 }
39
40 pub fn external_hash(parts: ExternalKvBlockKey) -> Self {
41 Self {
42 model_id: parts.model_id,
43 tokenizer_id: parts.tokenizer_id,
44 adapter_id: parts.adapter_id,
45 tenant_id: parts.tenant_id,
46 prefix_hash: parts.prefix_hash,
47 block_hash: parts.block_hash,
48 block_index: parts.block_index,
49 token_count: parts.token_count,
50 }
51 }
52}
53
54#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
55pub struct ExternalKvBlockKey {
56 pub model_id: String,
57 pub tokenizer_id: String,
58 pub adapter_id: Option<String>,
59 pub tenant_id: String,
60 pub prefix_hash: String,
61 pub block_hash: String,
62 pub block_index: u32,
63 pub token_count: u32,
64}
65
66#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
67pub enum CacheTier {
68 Hbm,
69 RemoteHbm,
70 CpuDram,
71 LocalSsd,
72 ObjectStore,
73}
74
75impl fmt::Display for CacheTier {
76 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
77 match self {
78 CacheTier::Hbm => f.write_str("hbm"),
79 CacheTier::RemoteHbm => f.write_str("remote_hbm"),
80 CacheTier::CpuDram => f.write_str("cpu_dram"),
81 CacheTier::LocalSsd => f.write_str("local_ssd"),
82 CacheTier::ObjectStore => f.write_str("object_store"),
83 }
84 }
85}
86
87impl FromStr for CacheTier {
88 type Err = CacheTierParseError;
89
90 fn from_str(value: &str) -> Result<Self, Self::Err> {
91 match value.to_ascii_lowercase().as_str() {
92 "hbm" | "gpu" | "gpu_memory" | "vram" => Ok(Self::Hbm),
93 "remote_hbm" | "remote_gpu" | "remote_gpu_memory" => Ok(Self::RemoteHbm),
94 "cpu" | "dram" | "cpu_dram" | "host" | "host_memory" => Ok(Self::CpuDram),
95 "ssd" | "local_ssd" | "nvme" | "disk" => Ok(Self::LocalSsd),
96 "object" | "object_store" | "s3" | "blob" => Ok(Self::ObjectStore),
97 _ => Err(CacheTierParseError {
98 value: value.to_string(),
99 }),
100 }
101 }
102}
103
104#[derive(Debug, Clone, PartialEq, Eq, thiserror::Error)]
105#[error("unknown cache tier or medium: {value}")]
106pub struct CacheTierParseError {
107 pub value: String,
108}
109
110#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
111pub struct CacheResidency {
112 pub key: KvBlockKey,
113 pub worker_id: String,
114 pub tier: CacheTier,
115 pub bytes: u64,
116 pub last_access_ms: u64,
117 pub ref_count: u32,
118 pub pinned: bool,
119}
120
121impl CacheResidency {
122 pub fn hbm(worker_id: impl Into<String>, key: KvBlockKey, bytes: u64) -> Self {
123 Self {
124 key,
125 worker_id: worker_id.into(),
126 tier: CacheTier::Hbm,
127 bytes,
128 last_access_ms: 0,
129 ref_count: 0,
130 pinned: false,
131 }
132 }
133}
134
135#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
136pub struct WorkerState {
137 pub id: String,
138 pub locality_domain: String,
139 pub hbm_capacity_bytes: u64,
140 pub hbm_used_bytes: u64,
141 pub cpu_capacity_bytes: u64,
142 pub cpu_used_bytes: u64,
143 pub running_decodes: u32,
144 pub queued_prefill_tokens: u32,
145}
146
147impl WorkerState {
148 pub fn new(id: impl Into<String>, locality_domain: impl Into<String>) -> Self {
149 Self {
150 id: id.into(),
151 locality_domain: locality_domain.into(),
152 hbm_capacity_bytes: 80 * 1024 * 1024 * 1024,
153 hbm_used_bytes: 0,
154 cpu_capacity_bytes: 512 * 1024 * 1024 * 1024,
155 cpu_used_bytes: 0,
156 running_decodes: 0,
157 queued_prefill_tokens: 0,
158 }
159 }
160
161 pub fn with_load(mut self, queued_prefill_tokens: u32, running_decodes: u32) -> Self {
162 self.queued_prefill_tokens = queued_prefill_tokens;
163 self.running_decodes = running_decodes;
164 self
165 }
166}
167
168#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
169pub enum EngineKind {
170 Vllm,
171 Sglang,
172 Lmcache,
173 Mock,
174}
175
176#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
177pub struct EngineEndpoint {
178 pub id: String,
179 pub kind: EngineKind,
180 pub base_url: String,
181 pub model_id: String,
182 pub tokenizer_id: String,
183 pub tenant_id: String,
184 pub locality_domain: String,
185}
186
187impl EngineEndpoint {
188 pub fn worker_state(&self) -> WorkerState {
189 WorkerState::new(self.id.clone(), self.locality_domain.clone())
190 }
191}
192
193#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
194pub struct BlockHint {
195 pub block_hash: String,
196 pub token_count: u32,
197 pub bytes: Option<u64>,
198}
199
200#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
201pub struct RequestKvHints {
202 pub request_id: Option<String>,
203 pub model_id: Option<String>,
204 pub tokenizer_id: Option<String>,
205 pub adapter_id: Option<String>,
206 pub tenant_id: Option<String>,
207 pub session_id: Option<String>,
208 pub block_hashes: Vec<String>,
209 pub block_tokens: Option<u32>,
210 pub estimated_decode_tokens: Option<u32>,
211 pub block_bytes: Option<u64>,
212}
213
214impl RequestKvHints {
215 pub fn to_blocks(
216 &self,
217 fallback_model_id: &str,
218 fallback_tokenizer_id: &str,
219 fallback_tenant_id: &str,
220 ) -> Vec<KvBlockKey> {
221 let model_id = self.model_id.as_deref().unwrap_or(fallback_model_id);
222 let tokenizer_id = self
223 .tokenizer_id
224 .as_deref()
225 .unwrap_or(fallback_tokenizer_id);
226 let tenant_id = self.tenant_id.as_deref().unwrap_or(fallback_tenant_id);
227 let token_count = self.block_tokens.unwrap_or(64);
228 let mut parent = self.session_id.as_deref().unwrap_or("root").to_string();
229
230 self.block_hashes
231 .iter()
232 .enumerate()
233 .map(|(idx, block_hash)| {
234 let key = KvBlockKey::external_hash(ExternalKvBlockKey {
235 model_id: model_id.to_string(),
236 tokenizer_id: tokenizer_id.to_string(),
237 adapter_id: self.adapter_id.clone(),
238 tenant_id: tenant_id.to_string(),
239 prefix_hash: parent.clone(),
240 block_hash: block_hash.clone(),
241 block_index: idx as u32,
242 token_count,
243 });
244 parent = block_hash.clone();
245 key
246 })
247 .collect()
248 }
249}
250
251#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
252pub struct KvEventBatch {
253 pub engine_id: String,
254 pub ts_ms: Option<u64>,
255 #[serde(default)]
256 pub model_id: Option<String>,
257 #[serde(default)]
258 pub tokenizer_id: Option<String>,
259 #[serde(default)]
260 pub adapter_id: Option<String>,
261 #[serde(default)]
262 pub tenant_id: Option<String>,
263 #[serde(default)]
264 pub bytes_per_block: Option<u64>,
265 pub events: Vec<KvEvent>,
266}
267
268#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
269#[serde(tag = "type", rename_all = "snake_case")]
270pub enum KvEvent {
271 BlockStored(BlockStoredEvent),
272 BlockRemoved(BlockRemovedEvent),
273 AllBlocksCleared,
274}
275
276#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
277pub struct BlockStoredEvent {
278 pub block_hashes: Vec<String>,
279 #[serde(default)]
280 pub parent_block_hash: Option<String>,
281 #[serde(default)]
282 pub token_ids: Vec<u32>,
283 pub block_size: u32,
284 #[serde(default)]
285 pub medium: Option<String>,
286 #[serde(default)]
287 pub lora_name: Option<String>,
288 #[serde(default)]
289 pub group_idx: Option<u32>,
290 #[serde(default)]
291 pub bytes_per_block: Option<u64>,
292}
293
294#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
295pub struct BlockRemovedEvent {
296 pub block_hashes: Vec<String>,
297 #[serde(default)]
298 pub medium: Option<String>,
299 #[serde(default)]
300 pub group_idx: Option<u32>,
301}
302
303#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
304pub struct SloTarget {
305 pub ttft_ms: u64,
306 pub tpot_ms: u64,
307}
308
309impl Default for SloTarget {
310 fn default() -> Self {
311 Self {
312 ttft_ms: 800,
313 tpot_ms: 80,
314 }
315 }
316}
317
318#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
319pub struct RequestShape {
320 pub id: String,
321 pub model_id: String,
322 pub tokenizer_id: String,
323 pub adapter_id: Option<String>,
324 pub tenant_id: String,
325 pub blocks: Vec<KvBlockKey>,
326 pub estimated_decode_tokens: u32,
327 pub slo: SloTarget,
328}
329
330impl RequestShape {
331 pub fn input_tokens(&self) -> u32 {
332 self.blocks.iter().map(|block| block.token_count).sum()
333 }
334}
335
336#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
337pub struct CostModel {
338 pub prefill_us_per_token: u64,
339 pub decode_us_per_token: u64,
340 pub queue_us_per_prefill_token: u64,
341 pub running_decode_penalty_us: u64,
342 pub hbm_hit_us: u64,
343 pub remote_hbm_us_per_mb: u64,
344 pub cpu_dram_us_per_mb: u64,
345 pub local_ssd_us_per_mb: u64,
346 pub object_store_us_per_mb: u64,
347 pub cross_domain_penalty_us: u64,
348}
349
350impl Default for CostModel {
351 fn default() -> Self {
352 Self {
353 prefill_us_per_token: 45,
354 decode_us_per_token: 80,
355 queue_us_per_prefill_token: 4,
356 running_decode_penalty_us: 1_500,
357 hbm_hit_us: 5,
358 remote_hbm_us_per_mb: 20,
359 cpu_dram_us_per_mb: 55,
360 local_ssd_us_per_mb: 280,
361 object_store_us_per_mb: 1_800,
362 cross_domain_penalty_us: 350,
363 }
364 }
365}
366
367impl CostModel {
368 pub fn prefill_cost_us(&self, tokens: u32) -> u64 {
369 self.prefill_us_per_token * u64::from(tokens)
370 }
371
372 pub fn decode_cost_us(&self, tokens: u32, running_decodes: u32) -> u64 {
373 self.decode_us_per_token * u64::from(tokens)
374 + self.running_decode_penalty_us * u64::from(running_decodes)
375 }
376
377 pub fn queue_cost_us(&self, worker: &WorkerState) -> u64 {
378 self.queue_us_per_prefill_token * u64::from(worker.queued_prefill_tokens)
379 }
380
381 pub fn transfer_cost_us(
382 &self,
383 tier: CacheTier,
384 bytes: u64,
385 same_worker: bool,
386 same_locality_domain: bool,
387 ) -> u64 {
388 if same_worker && tier == CacheTier::Hbm {
389 return self.hbm_hit_us;
390 }
391
392 let mb = bytes.div_ceil(1024 * 1024).max(1);
393 let base = match tier {
394 CacheTier::Hbm | CacheTier::RemoteHbm => self.remote_hbm_us_per_mb,
395 CacheTier::CpuDram => self.cpu_dram_us_per_mb,
396 CacheTier::LocalSsd => self.local_ssd_us_per_mb,
397 CacheTier::ObjectStore => self.object_store_us_per_mb,
398 } * mb;
399
400 if same_worker || same_locality_domain {
401 base
402 } else {
403 base + self.cross_domain_penalty_us
404 }
405 }
406}
407
408#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
416pub struct IdentityScope {
417 pub model_id: String,
418 pub tokenizer_id: String,
419 pub adapter_id: Option<String>,
420 pub tenant_id: String,
421}
422
423impl IdentityScope {
424 pub fn from_key(key: &KvBlockKey) -> Self {
425 Self {
426 model_id: key.model_id.clone(),
427 tokenizer_id: key.tokenizer_id.clone(),
428 adapter_id: key.adapter_id.clone(),
429 tenant_id: key.tenant_id.clone(),
430 }
431 }
432
433 pub fn matches(&self, key: &KvBlockKey) -> bool {
437 self.model_id == key.model_id
438 && self.tokenizer_id == key.tokenizer_id
439 && self.adapter_id == key.adapter_id
440 && self.tenant_id == key.tenant_id
441 }
442}
443
444#[derive(Debug, Clone, Default, PartialEq, Eq, Serialize, Deserialize)]
451pub struct IndexMetrics {
452 pub resident_blocks: u64,
453 pub resident_bytes: u64,
454 pub puts: u64,
455 pub removes: u64,
456 pub prefix_scans: u64,
457 pub bytes_written: u64,
460}
461
462pub trait IndexBackend: std::fmt::Debug + Send + Sync {
478 fn name(&self) -> &str;
480
481 fn put(&mut self, residency: CacheResidency);
483
484 fn locate(&self, key: &KvBlockKey) -> Vec<CacheResidency>;
487
488 fn prefix_scan(&self, scope: &IdentityScope, prefix_hash: &str) -> Vec<CacheResidency>;
492
493 fn remove_block(&mut self, scope: &IdentityScope, worker_id: &str, block_hash: &str) -> usize;
496
497 fn clear_worker(&mut self, worker_id: &str);
500
501 fn clear(&mut self);
503
504 fn snapshot(&self) -> Vec<CacheResidency>;
507
508 fn len(&self) -> usize;
510
511 fn is_empty(&self) -> bool {
512 self.len() == 0
513 }
514
515 fn metrics(&self) -> IndexMetrics {
519 let snapshot = self.snapshot();
520 IndexMetrics {
521 resident_blocks: snapshot.len() as u64,
522 resident_bytes: snapshot.iter().map(|entry| entry.bytes).sum(),
523 ..IndexMetrics::default()
524 }
525 }
526
527 fn persistent(&self) -> bool {
532 false
533 }
534}
535
536#[derive(Debug, Default)]
541pub struct MemoryIndex {
542 entries: HashMap<KvBlockKey, Vec<CacheResidency>>,
543 puts: u64,
544 removes: u64,
545}
546
547impl MemoryIndex {
548 pub fn new() -> Self {
549 Self::default()
550 }
551}
552
553impl IndexBackend for MemoryIndex {
554 fn name(&self) -> &str {
555 "memory"
556 }
557
558 fn put(&mut self, residency: CacheResidency) {
559 let entries = self.entries.entry(residency.key.clone()).or_default();
560 entries.retain(|entry| {
561 !(entry.worker_id == residency.worker_id && entry.tier == residency.tier)
562 });
563 entries.push(residency);
564 self.puts += 1;
565 }
566
567 fn locate(&self, key: &KvBlockKey) -> Vec<CacheResidency> {
568 self.entries.get(key).cloned().unwrap_or_default()
569 }
570
571 fn prefix_scan(&self, scope: &IdentityScope, prefix_hash: &str) -> Vec<CacheResidency> {
572 self.entries
573 .iter()
574 .filter(|(key, _)| scope.matches(key) && key.prefix_hash == prefix_hash)
575 .flat_map(|(_, entries)| entries.iter().cloned())
576 .collect()
577 }
578
579 fn remove_block(&mut self, scope: &IdentityScope, worker_id: &str, block_hash: &str) -> usize {
580 let mut removed = 0;
581 self.entries.retain(|key, entries| {
582 if scope.matches(key) && key.block_hash == block_hash {
583 let before = entries.len();
584 entries.retain(|entry| entry.worker_id != worker_id);
585 removed += before - entries.len();
586 }
587 !entries.is_empty()
588 });
589 self.removes += removed as u64;
590 removed
591 }
592
593 fn clear_worker(&mut self, worker_id: &str) {
594 self.entries.retain(|_, entries| {
595 entries.retain(|entry| entry.worker_id != worker_id);
596 !entries.is_empty()
597 });
598 }
599
600 fn clear(&mut self) {
601 self.entries.clear();
602 }
603
604 fn snapshot(&self) -> Vec<CacheResidency> {
605 self.entries
606 .values()
607 .flat_map(|entries| entries.iter().cloned())
608 .collect()
609 }
610
611 fn len(&self) -> usize {
612 self.entries.values().map(Vec::len).sum()
613 }
614
615 fn metrics(&self) -> IndexMetrics {
616 IndexMetrics {
617 resident_blocks: self.len() as u64,
618 resident_bytes: self
619 .entries
620 .values()
621 .flatten()
622 .map(|entry| entry.bytes)
623 .sum(),
624 puts: self.puts,
625 removes: self.removes,
626 prefix_scans: 0,
627 bytes_written: 0,
628 }
629 }
630}
631
#[cfg(test)]
mod tests {
    use super::*;

    /// With default coefficients, a same-worker HBM hit on a 4 MiB block must
    /// cost less than prefilling 64 tokens from scratch.
    #[test]
    fn hbm_hit_is_cheaper_than_recompute_for_a_block() {
        let model = CostModel::default();
        let block_bytes = 4 * 1024 * 1024;

        let hit_cost = model.transfer_cost_us(CacheTier::Hbm, block_bytes, true, true);
        let recompute_cost = model.prefill_cost_us(64);

        assert!(hit_cost < recompute_cost);
    }
}