inference_lab/config/hardware.rs
1use serde::Deserialize;
2
/// Fallback for `HardwareConfig::gpu_memory_utilization` when the config omits it.
///
/// Referenced by name from the field's `#[serde(default = "...")]` attribute.
fn default_gpu_memory_utilization() -> f64 {
    // Mirrors the default that vLLM uses for the same setting.
    const VLLM_DEFAULT_UTILIZATION: f64 = 0.9;
    VLLM_DEFAULT_UTILIZATION
}
6
7#[derive(Debug, Clone, Deserialize)]
8pub struct HardwareConfig {
9 /// Accelerator name (e.g., "H100", "A100")
10 pub name: String,
11
12 /// Compute capacity in FLOPS (for specific precision, e.g., bf16)
13 pub compute_flops: f64,
14
15 /// Memory bandwidth in bytes/sec
16 pub memory_bandwidth: f64,
17
18 /// Total memory capacity in bytes
19 pub memory_capacity: u64,
20
21 /// KV cache capacity in bytes (subset of memory_capacity)
22 /// If not specified, calculated from gpu_memory_utilization
23 #[serde(default)]
24 pub kv_cache_capacity: u64,
25
26 /// Fraction of GPU memory to use (vLLM default: 0.9)
27 /// Used to calculate kv_cache_capacity if not explicitly set
28 #[serde(default = "default_gpu_memory_utilization")]
29 pub gpu_memory_utilization: f64,
30
31 /// Number of bytes per parameter (1 for fp8, 2 for bf16)
32 pub bytes_per_param: u32,
33
34 /// Compute-bound threshold (derived from flops/bandwidth ratio)
35 /// This is calculated: bytes_per_param * compute_flops / memory_bandwidth
36 #[serde(skip)]
37 pub compute_bound_threshold: u32,
38}
39
40impl HardwareConfig {
41 /// Calculate and set the compute-bound threshold
42 pub fn compute_threshold(&mut self) {
43 self.compute_bound_threshold =
44 (self.bytes_per_param as f64 * self.compute_flops / self.memory_bandwidth) as u32;
45 }
46
47 /// Calculate KV cache capacity if not explicitly set
48 /// Formula: (memory_capacity * gpu_memory_utilization) - model_size
49 /// This matches vLLM's behavior: requested_memory - non_kv_cache_memory
50 pub fn compute_kv_cache_capacity(&mut self, model_size_bytes: u64) {
51 if self.kv_cache_capacity == 0 {
52 let requested_memory = (self.memory_capacity as f64 * self.gpu_memory_utilization) as u64;
53 // In vLLM, non_kv_cache_memory includes weights + activations + overhead
54 // For simplicity, we approximate this as just the model weights
55 self.kv_cache_capacity = requested_memory.saturating_sub(model_size_bytes);
56 }
57 }
58
59 /// Initialize with threshold pre-computed
60 pub fn with_threshold(mut self) -> Self {
61 self.compute_threshold();
62 self
63 }
64}