inference_lab/compute/
engine.rs

//! Compute engine for calculating inference timing.
2
3use super::arithmetic;
4use crate::config::{HardwareConfig, ModelConfig};
5use crate::request::Request;
6
7pub struct ComputeEngine {
8    hardware: HardwareConfig,
9    model: ModelConfig,
10}
11
12impl ComputeEngine {
13    pub fn new(hardware: HardwareConfig, model: ModelConfig) -> Self {
14        Self { hardware, model }
15    }
16
17    /// Calculate time to process an iteration (in seconds)
18    /// Takes batch of requests and number of tokens to process for each
19    /// Returns max(compute_time, memory_time) since they happen in parallel
20    pub fn calculate_iteration_time(
21        &self,
22        batch_requests: &[&Request],
23        tokens_per_request: &[u32],
24    ) -> f64 {
25        if batch_requests.is_empty() {
26            return 0.0;
27        }
28
29        let total_tokens: u32 = tokens_per_request.iter().sum();
30
31        // Calculate compute time: FLOPs / compute throughput
32        let flops = arithmetic::flops_for_tokens(total_tokens, &self.model, batch_requests, tokens_per_request);
33        let compute_time = flops / self.hardware.compute_flops;
34
35        // Calculate memory time: bytes transferred / memory bandwidth
36        let bytes = self.calculate_bytes_transferred(batch_requests, tokens_per_request);
37        let memory_time = bytes / self.hardware.memory_bandwidth;
38
39        // We're limited by whichever takes longer
40        compute_time.max(memory_time)
41    }
42
43    /// Calculate FLOPS utilization for this iteration (0.0 to 1.0)
44    pub fn calculate_flops_utilization(
45        &self,
46        batch_requests: &[&Request],
47        tokens_per_request: &[u32],
48        actual_time: f64,
49    ) -> f64 {
50        if actual_time == 0.0 {
51            return 0.0;
52        }
53
54        let total_tokens: u32 = tokens_per_request.iter().sum();
55        let flops = arithmetic::flops_for_tokens(total_tokens, &self.model, batch_requests, tokens_per_request);
56        let theoretical_time = flops / self.hardware.compute_flops;
57        (theoretical_time / actual_time).min(1.0)
58    }
59
60    /// Calculate memory bandwidth utilization for this iteration (0.0 to 1.0)
61    pub fn calculate_bandwidth_utilization(
62        &self,
63        bytes_transferred: f64,
64        actual_time: f64,
65    ) -> f64 {
66        if actual_time == 0.0 {
67            return 0.0;
68        }
69
70        let theoretical_time = bytes_transferred / self.hardware.memory_bandwidth;
71        (theoretical_time / actual_time).min(1.0)
72    }
73
74    /// Calculate total bytes transferred for a batch of requests
75    pub fn calculate_bytes_transferred(
76        &self,
77        batch_requests: &[&Request],
78        tokens_per_request: &[u32],
79    ) -> f64 {
80        // Model weights (constant per iteration)
81        let weight_bytes = arithmetic::model_weight_bytes(&self.model, &self.hardware);
82
83        // KV cache bytes (depends on sequence lengths)
84        let mut kv_cache_bytes = 0.0;
85        for (req, &tokens) in batch_requests.iter().zip(tokens_per_request) {
86            // Average sequence length during this iteration
87            let avg_seq_len = req.num_computed_tokens + tokens / 2;
88            kv_cache_bytes += arithmetic::kv_cache_bytes(avg_seq_len, &self.model);
89        }
90
91        weight_bytes + kv_cache_bytes
92    }
93}
94
#[cfg(test)]
mod tests {
    use super::*;
    use crate::config::Config;
    use crate::request::Request;

    fn create_test_engine() -> ComputeEngine {
        let config = Config::test_default();
        ComputeEngine::new(config.hardware, config.model)
    }

    fn create_test_request(id: &str, computed: u32, prompt: u32) -> Request {
        let mut req = Request::new(id.to_string(), 0, 0.0, prompt, 50);
        req.num_computed_tokens = computed;
        req
    }

    /// Recomputes the expected iteration time independently.
    ///
    /// This is the larger of the compute-bound time (FLOPs / compute throughput) and
    /// the memory-bound time (bytes / memory bandwidth).
    fn expected_iteration_time(
        engine: &ComputeEngine,
        requests: &[&Request],
        tokens: &[u32],
    ) -> f64 {
        let total_tokens: u32 = tokens.iter().sum();
        let flops = arithmetic::flops_for_tokens(total_tokens, &engine.model, requests, tokens);
        let compute_time = flops / engine.hardware.compute_flops;
        let bytes = engine.calculate_bytes_transferred(requests, tokens);
        let memory_time = bytes / engine.hardware.memory_bandwidth;
        compute_time.max(memory_time)
    }

    /// Asserts two floats agree to within a tiny relative (or absolute, near zero) tolerance.
    fn assert_close(actual: f64, expected: f64) {
        let tolerance = 1e-12 * expected.abs().max(1.0);
        assert!(
            (actual - expected).abs() <= tolerance,
            "expected {}, got {}",
            expected,
            actual
        );
    }

    #[test]
    fn test_high_token_time() {
        let engine = create_test_engine();

        // For 2000+ tokens, likely compute-bound
        let req1 = create_test_request("req-1", 0, 1000);
        let req2 = create_test_request("req-2", 0, 1000);

        let requests = vec![&req1, &req2];
        let tokens = vec![1000, 1000];

        let time = engine.calculate_iteration_time(&requests, &tokens);

        // Time should be exactly max(compute_time, memory_time), not merely positive.
        assert!(time > 0.0);
        assert_close(time, expected_iteration_time(&engine, &requests, &tokens));
    }

    #[test]
    fn test_low_token_time() {
        let engine = create_test_engine();

        // For few tokens, likely memory-bound
        let req1 = create_test_request("req-1", 0, 100);

        let requests = vec![&req1];
        let tokens = vec![50]; // Only 50 tokens

        let time = engine.calculate_iteration_time(&requests, &tokens);

        // Time should be exactly max(compute_time, memory_time), not merely positive.
        assert!(time > 0.0);
        assert_close(time, expected_iteration_time(&engine, &requests, &tokens));
    }

    #[test]
    fn test_empty_batch() {
        let engine = create_test_engine();

        let requests: Vec<&Request> = vec![];
        let tokens: Vec<u32> = vec![];

        let time = engine.calculate_iteration_time(&requests, &tokens);
        assert_eq!(time, 0.0);
    }

    #[test]
    fn test_bytes_transferred_empty_batch_is_weights_only() {
        let engine = create_test_engine();

        let requests: Vec<&Request> = vec![];
        let tokens: Vec<u32> = vec![];

        // With no requests there is no KV-cache traffic, only the model weights.
        let bytes = engine.calculate_bytes_transferred(&requests, &tokens);
        let weight_bytes = arithmetic::model_weight_bytes(&engine.model, &engine.hardware);
        assert_close(bytes, weight_bytes);
    }

    #[test]
    fn test_flops_utilization() {
        let engine = create_test_engine();

        // Test with 1000 tokens
        let req = create_test_request("req-1", 0, 1000);
        let requests = vec![&req];
        let tokens = vec![1000];

        let flops = arithmetic::flops_for_tokens(1000, &engine.model, &requests, &tokens);
        let theoretical_time = flops / engine.hardware.compute_flops;

        // If actual time equals theoretical, utilization should be 100%
        let util = engine.calculate_flops_utilization(&requests, &tokens, theoretical_time);
        assert!((util - 1.0).abs() < 1e-10);

        // If actual time is 2x theoretical, utilization should be 50%
        let util = engine.calculate_flops_utilization(&requests, &tokens, theoretical_time * 2.0);
        assert!((util - 0.5).abs() < 1e-10);

        // Test with zero time
        let util = engine.calculate_flops_utilization(&requests, &tokens, 0.0);
        assert_eq!(util, 0.0);
    }

    #[test]
    fn test_bandwidth_utilization() {
        let engine = create_test_engine();

        let bytes = 1e12; // 1 TB
        let theoretical_time = bytes / engine.hardware.memory_bandwidth;

        // If actual time equals theoretical, utilization should be 100%
        let util = engine.calculate_bandwidth_utilization(bytes, theoretical_time);
        assert!((util - 1.0).abs() < 1e-10);

        // If actual time is 2x theoretical, utilization should be 50%
        let util = engine.calculate_bandwidth_utilization(bytes, theoretical_time * 2.0);
        assert!((util - 0.5).abs() < 1e-10);

        // Test with zero time
        let util = engine.calculate_bandwidth_utilization(bytes, 0.0);
        assert_eq!(util, 0.0);
    }
}