// inference_lab/compute/engine.rs
1/// Compute engine for calculating inference timing
2use super::arithmetic;
3use crate::config::{HardwareConfig, ModelConfig};
4use crate::request::Request;
5
/// Engine that estimates iteration timing as max(compute time, memory time).
pub struct ComputeEngine {
    /// Hardware capabilities; `compute_flops` (FLOP/s) and
    /// `memory_bandwidth` (bytes/s) drive the timing math below.
    hardware: HardwareConfig,
    /// Model configuration forwarded to the FLOP/byte arithmetic helpers.
    model: ModelConfig,
}
10
11impl ComputeEngine {
12    pub fn new(hardware: HardwareConfig, model: ModelConfig) -> Self {
13        Self { hardware, model }
14    }
15
16    /// Calculate time to process an iteration (in seconds)
17    /// Takes batch of requests and number of tokens to process for each
18    /// Returns max(compute_time, memory_time) since they happen in parallel
19    pub fn calculate_iteration_time(
20        &self,
21        batch_requests: &[&Request],
22        tokens_per_request: &[u32],
23    ) -> f64 {
24        if batch_requests.is_empty() {
25            return 0.0;
26        }
27
28        let total_tokens: u32 = tokens_per_request.iter().sum();
29
30        // Calculate compute time: FLOPs / compute throughput
31        let flops = arithmetic::flops_for_tokens(
32            total_tokens,
33            &self.model,
34            batch_requests,
35            tokens_per_request,
36        );
37        let compute_time = flops / self.hardware.compute_flops;
38
39        // Calculate memory time: bytes transferred / memory bandwidth
40        let bytes = self.calculate_bytes_transferred(batch_requests, tokens_per_request);
41        let memory_time = bytes / self.hardware.memory_bandwidth;
42
43        // We're limited by whichever takes longer
44        compute_time.max(memory_time)
45    }
46
47    /// Calculate FLOPS utilization for this iteration (0.0 to 1.0)
48    pub fn calculate_flops_utilization(
49        &self,
50        batch_requests: &[&Request],
51        tokens_per_request: &[u32],
52        actual_time: f64,
53    ) -> f64 {
54        if actual_time == 0.0 {
55            return 0.0;
56        }
57
58        let total_tokens: u32 = tokens_per_request.iter().sum();
59        let flops = arithmetic::flops_for_tokens(
60            total_tokens,
61            &self.model,
62            batch_requests,
63            tokens_per_request,
64        );
65        let theoretical_time = flops / self.hardware.compute_flops;
66        (theoretical_time / actual_time).min(1.0)
67    }
68
69    /// Calculate memory bandwidth utilization for this iteration (0.0 to 1.0)
70    pub fn calculate_bandwidth_utilization(&self, bytes_transferred: f64, actual_time: f64) -> f64 {
71        if actual_time == 0.0 {
72            return 0.0;
73        }
74
75        let theoretical_time = bytes_transferred / self.hardware.memory_bandwidth;
76        (theoretical_time / actual_time).min(1.0)
77    }
78
79    /// Calculate total bytes transferred for a batch of requests
80    pub fn calculate_bytes_transferred(
81        &self,
82        batch_requests: &[&Request],
83        tokens_per_request: &[u32],
84    ) -> f64 {
85        // Model weights (constant per iteration)
86        let weight_bytes = arithmetic::model_weight_bytes(&self.model, &self.hardware);
87
88        // KV cache bytes (depends on sequence lengths)
89        let mut kv_cache_bytes = 0.0;
90        for (req, &tokens) in batch_requests.iter().zip(tokens_per_request) {
91            // Average sequence length during this iteration
92            let avg_seq_len = req.num_computed_tokens + tokens / 2;
93            kv_cache_bytes += arithmetic::kv_cache_bytes(avg_seq_len, &self.model);
94        }
95
96        weight_bytes + kv_cache_bytes
97    }
98}
99
#[cfg(test)]
mod tests {
    use super::*;
    use crate::config::Config;
    use crate::request::Request;

    /// Engine backed by the default test configuration.
    fn create_test_engine() -> ComputeEngine {
        let cfg = Config::test_default();
        ComputeEngine::new(cfg.hardware, cfg.model)
    }

    /// Request with `computed` tokens already processed and a `prompt`-token
    /// prompt; arrival metadata and output budget are fixed test values.
    fn create_test_request(id: &str, computed: u32, prompt: u32) -> Request {
        let mut request = Request::new(String::from(id), 0, 0.0, prompt, 50);
        request.num_computed_tokens = computed;
        request
    }

    #[test]
    fn test_high_token_time() {
        let engine = create_test_engine();

        // Two 1000-token prefills: 2000 tokens total, likely compute-bound.
        let first = create_test_request("req-1", 0, 1000);
        let second = create_test_request("req-2", 0, 1000);

        let batch = [&first, &second];
        let token_counts = [1000, 1000];

        let elapsed = engine.calculate_iteration_time(&batch, &token_counts);

        // Result is max(compute_time, memory_time); it must be positive.
        assert!(elapsed > 0.0);
    }

    #[test]
    fn test_low_token_time() {
        let engine = create_test_engine();

        // A single small batch: with only 50 tokens, likely memory-bound.
        let only = create_test_request("req-1", 0, 100);

        let batch = [&only];
        let token_counts = [50];

        let elapsed = engine.calculate_iteration_time(&batch, &token_counts);

        // Result is max(compute_time, memory_time); it must be positive.
        assert!(elapsed > 0.0);
    }

    #[test]
    fn test_empty_batch() {
        let engine = create_test_engine();

        // An empty batch takes no time at all.
        let batch: [&Request; 0] = [];
        let token_counts: [u32; 0] = [];

        assert_eq!(engine.calculate_iteration_time(&batch, &token_counts), 0.0);
    }

    #[test]
    fn test_flops_utilization() {
        let engine = create_test_engine();

        // Single request processing 1000 tokens.
        let request = create_test_request("req-1", 0, 1000);
        let batch = [&request];
        let token_counts = [1000];

        let flops = arithmetic::flops_for_tokens(1000, &engine.model, &batch, &token_counts);
        let ideal = flops / engine.hardware.compute_flops;

        // Running exactly at the theoretical time means 100% utilization.
        let full = engine.calculate_flops_utilization(&batch, &token_counts, ideal);
        assert!((full - 1.0).abs() < 1e-10);

        // Taking twice as long halves the utilization.
        let half = engine.calculate_flops_utilization(&batch, &token_counts, ideal * 2.0);
        assert!((half - 0.5).abs() < 1e-10);

        // Zero elapsed time is defined as zero utilization.
        let none = engine.calculate_flops_utilization(&batch, &token_counts, 0.0);
        assert_eq!(none, 0.0);
    }

    #[test]
    fn test_bandwidth_utilization() {
        let engine = create_test_engine();

        let bytes = 1e12; // 1 TB
        let ideal = bytes / engine.hardware.memory_bandwidth;

        // Running exactly at the theoretical time means 100% utilization.
        let full = engine.calculate_bandwidth_utilization(bytes, ideal);
        assert!((full - 1.0).abs() < 1e-10);

        // Taking twice as long halves the utilization.
        let half = engine.calculate_bandwidth_utilization(bytes, ideal * 2.0);
        assert!((half - 0.5).abs() < 1e-10);

        // Zero elapsed time is defined as zero utilization.
        let none = engine.calculate_bandwidth_utilization(bytes, 0.0);
        assert_eq!(none, 0.0);
    }
}