inference_lab/compute/
arithmetic.rs

1/// Core inference arithmetic formulas based on the inference-arithmetic.mdx blog post
2use crate::config::{HardwareConfig, ModelConfig};
3
4/// Calculate the compute-bound threshold (number of tokens at which inference becomes compute-bound)
5/// Formula: threshold = (bytes_per_param * compute_flops) / memory_bandwidth
6pub fn compute_bound_threshold(hardware: &HardwareConfig) -> u32 {
7    (hardware.bytes_per_param as f64 * hardware.compute_flops / hardware.memory_bandwidth) as u32
8}
9
10/// Calculate FLOPS for a given number of tokens
11/// Formula: FLOPS = 2 * num_tokens * active_parameters + attention_flops
12/// For MoE models, uses active_parameters (not total) since only some experts are activated
13/// Includes both matmul and attention FLOPs
14pub fn flops_for_tokens(
15    total_tokens: u32,
16    model: &ModelConfig,
17    requests: &[&crate::request::Request],
18    tokens_per_request: &[u32],
19) -> f64 {
20    // MatMul FLOPs: 2 * num_tokens * active_parameters
21    // For MoE: only counts activated expert parameters, not all experts
22    let matmul_flops = 2.0 * total_tokens as f64 * model.active_parameters() as f64;
23
24    // Attention FLOPs: 4 * L * B * S * T * D
25    // where L = num_layers, B = batch size, S = new tokens, T = attended tokens, D = hidden_dim
26    let mut attention_flops = 0.0;
27    for (req, &num_new_tokens) in requests.iter().zip(tokens_per_request) {
28        let batch_size = 1.0; // Each request is one sequence
29        let s = num_new_tokens as f64; // New tokens being processed
30        let t = (req.num_computed_tokens + num_new_tokens) as f64; // Total attended tokens
31        let d = model.hidden_dim as f64;
32        let l = model.num_layers as f64;
33
34        // 4LBSTD FLOPs for attention across all layers
35        // Note: Causal masking zeros out some values, but the matmul still computes them
36        attention_flops += 4.0 * l * batch_size * s * t * d;
37    }
38
39    matmul_flops + attention_flops
40}
41
42/// Calculate memory transfer bytes for model weights
43/// Formula: weight_bytes = num_parameters * bytes_per_param
44pub fn model_weight_bytes(model: &ModelConfig, hardware: &HardwareConfig) -> f64 {
45    model.num_parameters as f64 * hardware.bytes_per_param as f64
46}
47
48/// Calculate memory transfer bytes for KV cache for a given sequence length
49/// Formula: kv_bytes = kv_cache_bytes_per_token * seq_len
50pub fn kv_cache_bytes(seq_len: u32, model: &ModelConfig) -> f64 {
51    model.kv_cache_bytes_per_token as f64 * seq_len as f64
52}
53
54/// Calculate total memory transfer bytes for an iteration
55/// Formula: total_bytes = model_weights + sum(kv_cache for each request)
56pub fn total_memory_transfer(
57    model: &ModelConfig,
58    hardware: &HardwareConfig,
59    request_seq_lens: &[u32],
60) -> f64 {
61    let weight_bytes = model_weight_bytes(model, hardware);
62    let kv_bytes: f64 = request_seq_lens
63        .iter()
64        .map(|&seq_len| kv_cache_bytes(seq_len, model))
65        .sum();
66    weight_bytes + kv_bytes
67}
68
69/// Check if a workload is compute-bound
70/// A workload is compute-bound if the number of tokens >= compute-bound threshold
71pub fn is_compute_bound(num_tokens: u32, hardware: &HardwareConfig) -> bool {
72    num_tokens >= hardware.compute_bound_threshold
73}
74
#[cfg(test)]
mod tests {
    use super::*;
    use crate::config::Config;

    /// H100-like config: the roofline crossover should land near 903 tokens
    /// and agree with the value cached on the hardware config.
    #[test]
    fn test_compute_bound_threshold_h100() {
        let mut hw = crate::config::HardwareConfig {
            name: "H100".to_string(),
            compute_flops: 1.513e15,   // 1513 TFLOPS bf16
            memory_bandwidth: 3.35e12, // 3.35 TB/s
            memory_capacity: 80_000_000_000,
            kv_cache_capacity: 60_000_000_000,
            gpu_memory_utilization: 0.9,
            bytes_per_param: 2, // bf16
            compute_bound_threshold: 0,
        };
        hw.compute_threshold();

        let threshold = compute_bound_threshold(&hw);

        // Expected ~903 for H100 bf16.
        assert!(threshold > 900 && threshold < 910);
        assert_eq!(threshold, hw.compute_bound_threshold);
    }

    /// 100 tokens through the default model must cost at least the matmul
    /// term (2 * 100 * active params); attention only adds on top.
    #[test]
    fn test_flops_calculation() {
        let cfg = Config::test_default();

        // One request processing 100 fresh tokens (no prior KV).
        let mut req = crate::request::Request::new("test".to_string(), 0, 0.0, 100, 50);
        req.num_computed_tokens = 0;
        let reqs = vec![&req];
        let new_tokens = vec![100];

        let flops = flops_for_tokens(100, &cfg.model, &reqs, &new_tokens);

        // MatMul alone: 2 * 100 * 7e9 = 1.4e12; total must not be smaller.
        assert!(flops >= 1.4e12);
    }

    /// 7e9 parameters at 2 bytes each is exactly 14 GB of weights.
    #[test]
    fn test_model_weight_bytes() {
        let cfg = Config::test_default();
        let weight_bytes = model_weight_bytes(&cfg.model, &cfg.hardware);
        assert_eq!(weight_bytes, 14_000_000_000.0);
    }

    /// KV bytes scale linearly with sequence length.
    #[test]
    fn test_kv_cache_bytes() {
        let mut cfg = Config::test_default();
        cfg.model.kv_cache_bytes_per_token = 524_288; // 512KB per token

        let kv = kv_cache_bytes(100, &cfg.model);

        // 524_288 * 100
        assert_eq!(kv, 52_428_800.0);
    }

    /// Total transfer is weights plus the KV cache of every request.
    #[test]
    fn test_total_memory_transfer() {
        let mut cfg = Config::test_default();
        cfg.model.kv_cache_bytes_per_token = 524_288;

        let seq_lens = vec![50, 100, 150];
        let total = total_memory_transfer(&cfg.model, &cfg.hardware, &seq_lens);

        let want_weights = 14_000_000_000.0;
        let want_kv = 524_288.0 * (50.0 + 100.0 + 150.0);
        assert_eq!(total, want_weights + want_kv);
    }

    /// The compute-bound predicate is inclusive at the threshold.
    #[test]
    fn test_is_compute_bound() {
        let mut hw = crate::config::HardwareConfig {
            name: "Test".to_string(),
            compute_flops: 1e15,
            memory_bandwidth: 1e12,
            memory_capacity: 80_000_000_000,
            kv_cache_capacity: 60_000_000_000,
            gpu_memory_utilization: 0.9,
            bytes_per_param: 2,
            compute_bound_threshold: 0,
        };
        hw.compute_threshold();

        // 2 bytes * 1e15 / 1e12 = 2000 tokens.
        assert_eq!(hw.compute_bound_threshold, 2000);

        assert!(!is_compute_bound(1000, &hw));
        assert!(!is_compute_bound(1999, &hw));
        assert!(is_compute_bound(2000, &hw));
        assert!(is_compute_bound(3000, &hw));
    }
}