inference_lab/compute/
engine.rs1use super::arithmetic;
3use crate::config::{HardwareConfig, ModelConfig};
4use crate::request::Request;
5
6pub struct ComputeEngine {
7 hardware: HardwareConfig,
8 model: ModelConfig,
9}
10
11impl ComputeEngine {
12 pub fn new(hardware: HardwareConfig, model: ModelConfig) -> Self {
13 Self { hardware, model }
14 }
15
16 pub fn calculate_iteration_time(
20 &self,
21 batch_requests: &[&Request],
22 tokens_per_request: &[u32],
23 ) -> f64 {
24 if batch_requests.is_empty() {
25 return 0.0;
26 }
27
28 let total_tokens: u32 = tokens_per_request.iter().sum();
29
30 let flops = arithmetic::flops_for_tokens(
32 total_tokens,
33 &self.model,
34 batch_requests,
35 tokens_per_request,
36 );
37 let compute_time = flops / self.hardware.compute_flops;
38
39 let bytes = self.calculate_bytes_transferred(batch_requests, tokens_per_request);
41 let memory_time = bytes / self.hardware.memory_bandwidth;
42
43 compute_time.max(memory_time)
45 }
46
47 pub fn calculate_flops_utilization(
49 &self,
50 batch_requests: &[&Request],
51 tokens_per_request: &[u32],
52 actual_time: f64,
53 ) -> f64 {
54 if actual_time == 0.0 {
55 return 0.0;
56 }
57
58 let total_tokens: u32 = tokens_per_request.iter().sum();
59 let flops = arithmetic::flops_for_tokens(
60 total_tokens,
61 &self.model,
62 batch_requests,
63 tokens_per_request,
64 );
65 let theoretical_time = flops / self.hardware.compute_flops;
66 (theoretical_time / actual_time).min(1.0)
67 }
68
69 pub fn calculate_bandwidth_utilization(&self, bytes_transferred: f64, actual_time: f64) -> f64 {
71 if actual_time == 0.0 {
72 return 0.0;
73 }
74
75 let theoretical_time = bytes_transferred / self.hardware.memory_bandwidth;
76 (theoretical_time / actual_time).min(1.0)
77 }
78
79 pub fn calculate_bytes_transferred(
81 &self,
82 batch_requests: &[&Request],
83 tokens_per_request: &[u32],
84 ) -> f64 {
85 let weight_bytes = arithmetic::model_weight_bytes(&self.model, &self.hardware);
87
88 let mut kv_cache_bytes = 0.0;
90 for (req, &tokens) in batch_requests.iter().zip(tokens_per_request) {
91 let avg_seq_len = req.num_computed_tokens + tokens / 2;
93 kv_cache_bytes += arithmetic::kv_cache_bytes(avg_seq_len, &self.model);
94 }
95
96 weight_bytes + kv_cache_bytes
97 }
98}
99
#[cfg(test)]
mod tests {
    use super::*;
    use crate::config::Config;
    use crate::request::Request;

    /// Builds an engine from the project's default test configuration.
    fn create_test_engine() -> ComputeEngine {
        let Config {
            hardware, model, ..
        } = Config::test_default();
        ComputeEngine::new(hardware, model)
    }

    /// Builds a request with `prompt` prompt tokens of which `computed` are
    /// already marked as computed.
    fn create_test_request(id: &str, computed: u32, prompt: u32) -> Request {
        let mut request = Request::new(id.to_string(), 0, 0.0, prompt, 50);
        request.num_computed_tokens = computed;
        request
    }

    /// Asserts that `actual` is within `1e-10` of `expected`.
    fn assert_close(actual: f64, expected: f64) {
        assert!((actual - expected).abs() < 1e-10);
    }

    /// A two-request batch with 1000 tokens each yields a positive time.
    #[test]
    fn test_high_token_time() {
        let first = create_test_request("req-1", 0, 1000);
        let second = create_test_request("req-2", 0, 1000);
        let batch: &[&Request] = &[&first, &second];
        let token_counts: &[u32] = &[1000, 1000];

        let elapsed = create_test_engine().calculate_iteration_time(batch, token_counts);

        assert!(elapsed > 0.0);
    }

    /// A single request processing 50 tokens yields a positive time.
    #[test]
    fn test_low_token_time() {
        let only = create_test_request("req-1", 0, 100);
        let batch: &[&Request] = &[&only];
        let token_counts: &[u32] = &[50];

        let elapsed = create_test_engine().calculate_iteration_time(batch, token_counts);

        assert!(elapsed > 0.0);
    }

    /// An empty batch takes no time.
    #[test]
    fn test_empty_batch() {
        let batch: &[&Request] = &[];
        let token_counts: &[u32] = &[];

        let elapsed = create_test_engine().calculate_iteration_time(batch, token_counts);

        assert_eq!(elapsed, 0.0);
    }

    /// Utilization is 1.0 at the theoretical time, 0.5 at twice that time,
    /// and 0.0 when the measured time is zero.
    #[test]
    fn test_flops_utilization() {
        let engine = create_test_engine();
        let request = create_test_request("req-1", 0, 1000);
        let batch: &[&Request] = &[&request];
        let token_counts: &[u32] = &[1000];

        let theoretical_time =
            arithmetic::flops_for_tokens(1000, &engine.model, batch, token_counts)
                / engine.hardware.compute_flops;

        assert_close(
            engine.calculate_flops_utilization(batch, token_counts, theoretical_time),
            1.0,
        );
        assert_close(
            engine.calculate_flops_utilization(batch, token_counts, theoretical_time * 2.0),
            0.5,
        );
        assert_eq!(
            engine.calculate_flops_utilization(batch, token_counts, 0.0),
            0.0
        );
    }

    /// Utilization is 1.0 at the theoretical time, 0.5 at twice that time,
    /// and 0.0 when the measured time is zero.
    #[test]
    fn test_bandwidth_utilization() {
        let engine = create_test_engine();
        let bytes = 1e12;
        let theoretical_time = bytes / engine.hardware.memory_bandwidth;

        assert_close(
            engine.calculate_bandwidth_utilization(bytes, theoretical_time),
            1.0,
        );
        assert_close(
            engine.calculate_bandwidth_utilization(bytes, theoretical_time * 2.0),
            0.5,
        );
        assert_eq!(engine.calculate_bandwidth_utilization(bytes, 0.0), 0.0);
    }
}