inference_lab/compute/
arithmetic.rs1use crate::config::{HardwareConfig, ModelConfig};
4
5pub fn compute_bound_threshold(hardware: &HardwareConfig) -> u32 {
8 (hardware.bytes_per_param as f64 * hardware.compute_flops / hardware.memory_bandwidth) as u32
9}
10
11pub fn flops_for_tokens(
16 total_tokens: u32,
17 model: &ModelConfig,
18 requests: &[&crate::request::Request],
19 tokens_per_request: &[u32],
20) -> f64 {
21 let matmul_flops = 2.0 * total_tokens as f64 * model.active_parameters() as f64;
24
25 let mut attention_flops = 0.0;
28 for (req, &num_new_tokens) in requests.iter().zip(tokens_per_request) {
29 let batch_size = 1.0; let s = num_new_tokens as f64; let t = (req.num_computed_tokens + num_new_tokens) as f64; let d = model.hidden_dim as f64;
33 let l = model.num_layers as f64;
34
35 attention_flops += 4.0 * l * batch_size * s * t * d;
38 }
39
40 matmul_flops + attention_flops
41}
42
43pub fn model_weight_bytes(model: &ModelConfig, hardware: &HardwareConfig) -> f64 {
46 model.num_parameters as f64 * hardware.bytes_per_param as f64
47}
48
49pub fn kv_cache_bytes(seq_len: u32, model: &ModelConfig) -> f64 {
52 model.kv_cache_bytes_per_token as f64 * seq_len as f64
53}
54
55pub fn total_memory_transfer(
58 model: &ModelConfig,
59 hardware: &HardwareConfig,
60 request_seq_lens: &[u32],
61) -> f64 {
62 let weight_bytes = model_weight_bytes(model, hardware);
63 let kv_bytes: f64 = request_seq_lens
64 .iter()
65 .map(|&seq_len| kv_cache_bytes(seq_len, model))
66 .sum();
67 weight_bytes + kv_bytes
68}
69
70pub fn is_compute_bound(num_tokens: u32, hardware: &HardwareConfig) -> bool {
73 num_tokens >= hardware.compute_bound_threshold
74}
75
#[cfg(test)]
mod tests {
    use super::*;
    use crate::config::Config;

    /// Builds a `HardwareConfig` with the capacities shared by the tests below
    /// and then calls `compute_threshold()` on it; the assertions expect that
    /// call to fill in `compute_bound_threshold`.
    fn hardware_fixture(
        name: &str,
        compute_flops: f64,
        memory_bandwidth: f64,
    ) -> crate::config::HardwareConfig {
        let mut hardware = crate::config::HardwareConfig {
            name: name.to_string(),
            compute_flops,
            memory_bandwidth,
            memory_capacity: 80_000_000_000,
            kv_cache_capacity: 60_000_000_000,
            gpu_memory_utilization: 0.9,
            bytes_per_param: 2,
            compute_bound_threshold: 0,
        };
        hardware.compute_threshold();
        hardware
    }

    #[test]
    fn test_compute_bound_threshold_h100() {
        let hardware = hardware_fixture("H100", 1.513e15, 3.35e12);

        let threshold = compute_bound_threshold(&hardware);

        // 2 * 1.513e15 / 3.35e12 is a little over 903.
        assert!((901..910).contains(&threshold));
        // The free function and the value stored on the config must agree.
        assert_eq!(threshold, hardware.compute_bound_threshold);
    }

    #[test]
    fn test_flops_calculation() {
        let config = Config::test_default();

        let mut req = crate::request::Request::new("test".to_string(), 0, 0.0, 100, 50);
        req.num_computed_tokens = 0;

        let flops = flops_for_tokens(100, &config.model, &[&req], &[100]);

        // Only a lower bound is checked; the attention term can only add to
        // the matmul term.
        assert!(flops >= 1.4e12);
    }

    #[test]
    fn test_model_weight_bytes() {
        let config = Config::test_default();

        let bytes = model_weight_bytes(&config.model, &config.hardware);

        // Expected size for the default test configuration.
        assert_eq!(bytes, 14_000_000_000.0);
    }

    #[test]
    fn test_kv_cache_bytes() {
        let mut config = Config::test_default();
        config.model.kv_cache_bytes_per_token = 524_288;

        let bytes = kv_cache_bytes(100, &config.model);

        // 524_288 bytes per token * 100 tokens.
        assert_eq!(bytes, 52_428_800.0);
    }

    #[test]
    fn test_total_memory_transfer() {
        let mut config = Config::test_default();
        config.model.kv_cache_bytes_per_token = 524_288;

        let seq_lens = [50, 100, 150];
        let total = total_memory_transfer(&config.model, &config.hardware, &seq_lens);

        let expected_weights = 14_000_000_000.0;
        let expected_kv = 524_288.0 * (50.0 + 100.0 + 150.0);
        assert_eq!(total, expected_weights + expected_kv);
    }

    #[test]
    fn test_is_compute_bound() {
        let hardware = hardware_fixture("Test", 1e15, 1e12);

        // 2 * 1e15 / 1e12 = 2000.
        assert_eq!(hardware.compute_bound_threshold, 2000);

        // The boundary is inclusive: exactly 2000 tokens counts as compute-bound.
        assert!(!is_compute_bound(1000, &hardware));
        assert!(!is_compute_bound(1999, &hardware));
        assert!(is_compute_bound(2000, &hardware));
        assert!(is_compute_bound(3000, &hardware));
    }
}