// inference_lab/compute/engine.rs
use super::arithmetic;

use crate::config::{HardwareConfig, ModelConfig};
use crate::request::Request;

7pub struct ComputeEngine {
8 hardware: HardwareConfig,
9 model: ModelConfig,
10}
11
12impl ComputeEngine {
13 pub fn new(hardware: HardwareConfig, model: ModelConfig) -> Self {
14 Self { hardware, model }
15 }
16
17 pub fn calculate_iteration_time(
21 &self,
22 batch_requests: &[&Request],
23 tokens_per_request: &[u32],
24 ) -> f64 {
25 if batch_requests.is_empty() {
26 return 0.0;
27 }
28
29 let total_tokens: u32 = tokens_per_request.iter().sum();
30
31 let flops = arithmetic::flops_for_tokens(total_tokens, &self.model, batch_requests, tokens_per_request);
33 let compute_time = flops / self.hardware.compute_flops;
34
35 let bytes = self.calculate_bytes_transferred(batch_requests, tokens_per_request);
37 let memory_time = bytes / self.hardware.memory_bandwidth;
38
39 compute_time.max(memory_time)
41 }
42
43 pub fn calculate_flops_utilization(
45 &self,
46 batch_requests: &[&Request],
47 tokens_per_request: &[u32],
48 actual_time: f64,
49 ) -> f64 {
50 if actual_time == 0.0 {
51 return 0.0;
52 }
53
54 let total_tokens: u32 = tokens_per_request.iter().sum();
55 let flops = arithmetic::flops_for_tokens(total_tokens, &self.model, batch_requests, tokens_per_request);
56 let theoretical_time = flops / self.hardware.compute_flops;
57 (theoretical_time / actual_time).min(1.0)
58 }
59
60 pub fn calculate_bandwidth_utilization(
62 &self,
63 bytes_transferred: f64,
64 actual_time: f64,
65 ) -> f64 {
66 if actual_time == 0.0 {
67 return 0.0;
68 }
69
70 let theoretical_time = bytes_transferred / self.hardware.memory_bandwidth;
71 (theoretical_time / actual_time).min(1.0)
72 }
73
74 pub fn calculate_bytes_transferred(
76 &self,
77 batch_requests: &[&Request],
78 tokens_per_request: &[u32],
79 ) -> f64 {
80 let weight_bytes = arithmetic::model_weight_bytes(&self.model, &self.hardware);
82
83 let mut kv_cache_bytes = 0.0;
85 for (req, &tokens) in batch_requests.iter().zip(tokens_per_request) {
86 let avg_seq_len = req.num_computed_tokens + tokens / 2;
88 kv_cache_bytes += arithmetic::kv_cache_bytes(avg_seq_len, &self.model);
89 }
90
91 weight_bytes + kv_cache_bytes
92 }
93}
94
#[cfg(test)]
mod tests {
    // `super::*` also brings in the parent's `Request` and `arithmetic` imports.
    use super::*;
    use crate::config::Config;

    fn create_test_engine() -> ComputeEngine {
        let config = Config::test_default();
        ComputeEngine::new(config.hardware, config.model)
    }

    fn create_test_request(id: &str, computed: u32, prompt: u32) -> Request {
        let mut req = Request::new(id.to_string(), 0, 0.0, prompt, 50);
        req.num_computed_tokens = computed;
        req
    }

    #[test]
    fn test_high_token_time() {
        let engine = create_test_engine();

        let req1 = create_test_request("req-1", 0, 1000);
        let req2 = create_test_request("req-2", 0, 1000);

        let requests = vec![&req1, &req2];
        let tokens = vec![1000, 1000];

        let time = engine.calculate_iteration_time(&requests, &tokens);

        assert!(time > 0.0);
    }

    #[test]
    fn test_low_token_time() {
        let engine = create_test_engine();

        let req1 = create_test_request("req-1", 0, 100);

        let requests = vec![&req1];
        let tokens = vec![50];
        let time = engine.calculate_iteration_time(&requests, &tokens);

        assert!(time > 0.0);
    }

    #[test]
    fn test_more_tokens_take_at_least_as_long() {
        let engine = create_test_engine();

        let big1 = create_test_request("req-1", 0, 1000);
        let big2 = create_test_request("req-2", 0, 1000);
        let high = engine.calculate_iteration_time(&[&big1, &big2], &[1000, 1000]);

        let small = create_test_request("req-3", 0, 100);
        let low = engine.calculate_iteration_time(&[&small], &[50]);

        assert!(
            high >= low,
            "2x1000-token batch ({high}) estimated faster than 1x50-token batch ({low})"
        );
    }

    #[test]
    fn test_empty_batch() {
        let engine = create_test_engine();

        let requests: Vec<&Request> = vec![];
        let tokens: Vec<u32> = vec![];

        let time = engine.calculate_iteration_time(&requests, &tokens);
        assert_eq!(time, 0.0);
    }

    #[test]
    fn test_flops_utilization() {
        let engine = create_test_engine();

        let req = create_test_request("req-1", 0, 1000);
        let requests = vec![&req];
        let tokens = vec![1000];

        let flops = arithmetic::flops_for_tokens(1000, &engine.model, &requests, &tokens);
        let theoretical_time = flops / engine.hardware.compute_flops;

        let util = engine.calculate_flops_utilization(&requests, &tokens, theoretical_time);
        assert!((util - 1.0).abs() < 1e-10);

        let util = engine.calculate_flops_utilization(&requests, &tokens, theoretical_time * 2.0);
        assert!((util - 0.5).abs() < 1e-10);

        let util = engine.calculate_flops_utilization(&requests, &tokens, 0.0);
        assert_eq!(util, 0.0);
    }

    #[test]
    fn test_bandwidth_utilization() {
        let engine = create_test_engine();

        let bytes = 1e12;
        let theoretical_time = bytes / engine.hardware.memory_bandwidth;

        let util = engine.calculate_bandwidth_utilization(bytes, theoretical_time);
        assert!((util - 1.0).abs() < 1e-10);

        let util = engine.calculate_bandwidth_utilization(bytes, theoretical_time * 2.0);
        assert!((util - 0.5).abs() < 1e-10);

        let util = engine.calculate_bandwidth_utilization(bytes, 0.0);
        assert_eq!(util, 0.0);
    }
}
202}