inference_lab/compute/
arithmetic.rs1use crate::config::{HardwareConfig, ModelConfig};
3
4pub fn compute_bound_threshold(hardware: &HardwareConfig) -> u32 {
7 (hardware.bytes_per_param as f64 * hardware.compute_flops / hardware.memory_bandwidth) as u32
8}
9
10pub fn flops_for_tokens(
15 total_tokens: u32,
16 model: &ModelConfig,
17 requests: &[&crate::request::Request],
18 tokens_per_request: &[u32],
19) -> f64 {
20 let matmul_flops = 2.0 * total_tokens as f64 * model.active_parameters() as f64;
23
24 let mut attention_flops = 0.0;
27 for (req, &num_new_tokens) in requests.iter().zip(tokens_per_request) {
28 let batch_size = 1.0; let s = num_new_tokens as f64; let t = (req.num_computed_tokens + num_new_tokens) as f64; let d = model.hidden_dim as f64;
32 let l = model.num_layers as f64;
33
34 attention_flops += 4.0 * l * batch_size * s * t * d;
37 }
38
39 matmul_flops + attention_flops
40}
41
42pub fn model_weight_bytes(model: &ModelConfig, hardware: &HardwareConfig) -> f64 {
45 model.num_parameters as f64 * hardware.bytes_per_param as f64
46}
47
48pub fn kv_cache_bytes(seq_len: u32, model: &ModelConfig) -> f64 {
51 model.kv_cache_bytes_per_token as f64 * seq_len as f64
52}
53
54pub fn total_memory_transfer(
57 model: &ModelConfig,
58 hardware: &HardwareConfig,
59 request_seq_lens: &[u32],
60) -> f64 {
61 let weight_bytes = model_weight_bytes(model, hardware);
62 let kv_bytes: f64 = request_seq_lens
63 .iter()
64 .map(|&seq_len| kv_cache_bytes(seq_len, model))
65 .sum();
66 weight_bytes + kv_bytes
67}
68
69pub fn is_compute_bound(num_tokens: u32, hardware: &HardwareConfig) -> bool {
72 num_tokens >= hardware.compute_bound_threshold
73}
74
#[cfg(test)]
mod tests {
    use super::*;
    use crate::config::Config;

    /// Builds a hardware config with fixed memory settings and the given
    /// throughput figures. `compute_bound_threshold` starts at `0` and is then
    /// filled in by `HardwareConfig::compute_threshold`.
    fn hardware_with_threshold(
        name: &str,
        compute_flops: f64,
        memory_bandwidth: f64,
    ) -> crate::config::HardwareConfig {
        let mut hardware = crate::config::HardwareConfig {
            name: name.to_string(),
            compute_flops,
            memory_bandwidth,
            memory_capacity: 80_000_000_000,
            kv_cache_capacity: 60_000_000_000,
            gpu_memory_utilization: 0.9,
            bytes_per_param: 2,
            compute_bound_threshold: 0,
        };
        hardware.compute_threshold();
        hardware
    }

    #[test]
    fn test_compute_bound_threshold_h100() {
        let hardware = hardware_with_threshold("H100", 1.513e15, 3.35e12);

        let threshold = compute_bound_threshold(&hardware);

        // Strictly between 900 and 910.
        assert!((901..910).contains(&threshold));
        // The free function must agree with the cached field.
        assert_eq!(threshold, hardware.compute_bound_threshold);
    }

    #[test]
    fn test_flops_calculation() {
        let config = Config::test_default();

        let mut req = crate::request::Request::new("test".to_string(), 0, 0.0, 100, 50);
        req.num_computed_tokens = 0;

        let flops = flops_for_tokens(100, &config.model, &[&req], &[100]);

        assert!(flops >= 1.4e12);
    }

    #[test]
    fn test_model_weight_bytes() {
        let config = Config::test_default();
        assert_eq!(
            model_weight_bytes(&config.model, &config.hardware),
            14_000_000_000.0
        );
    }

    #[test]
    fn test_kv_cache_bytes() {
        let mut config = Config::test_default();
        config.model.kv_cache_bytes_per_token = 524_288;

        assert_eq!(kv_cache_bytes(100, &config.model), 52_428_800.0);
    }

    #[test]
    fn test_total_memory_transfer() {
        let mut config = Config::test_default();
        config.model.kv_cache_bytes_per_token = 524_288;

        let total = total_memory_transfer(&config.model, &config.hardware, &[50, 100, 150]);

        let expected_weights = 14_000_000_000.0;
        let expected_kv = 524_288.0 * (50.0 + 100.0 + 150.0);
        assert_eq!(total, expected_weights + expected_kv);
    }

    #[test]
    fn test_is_compute_bound() {
        let hardware = hardware_with_threshold("Test", 1e15, 1e12);

        assert_eq!(hardware.compute_bound_threshold, 2000);

        let cases = [(1000, false), (1999, false), (2000, true), (3000, true)];
        for (tokens, expected) in cases {
            assert_eq!(
                is_compute_bound(tokens, &hardware),
                expected,
                "tokens = {tokens}"
            );
        }
    }
}