// inference_lab/compute/arithmetic.rs

//! Core inference arithmetic formulas, based on the `inference-arithmetic.mdx` blog post.
2
3use crate::config::{HardwareConfig, ModelConfig};
4
5/// Calculate the compute-bound threshold (number of tokens at which inference becomes compute-bound)
6/// Formula: threshold = (bytes_per_param * compute_flops) / memory_bandwidth
7pub fn compute_bound_threshold(hardware: &HardwareConfig) -> u32 {
8    (hardware.bytes_per_param as f64 * hardware.compute_flops / hardware.memory_bandwidth) as u32
9}
10
11/// Calculate FLOPS for a given number of tokens
12/// Formula: FLOPS = 2 * num_tokens * active_parameters + attention_flops
13/// For MoE models, uses active_parameters (not total) since only some experts are activated
14/// Includes both matmul and attention FLOPs
15pub fn flops_for_tokens(
16    total_tokens: u32,
17    model: &ModelConfig,
18    requests: &[&crate::request::Request],
19    tokens_per_request: &[u32],
20) -> f64 {
21    // MatMul FLOPs: 2 * num_tokens * active_parameters
22    // For MoE: only counts activated expert parameters, not all experts
23    let matmul_flops = 2.0 * total_tokens as f64 * model.active_parameters() as f64;
24
25    // Attention FLOPs: 4 * L * B * S * T * D
26    // where L = num_layers, B = batch size, S = new tokens, T = attended tokens, D = hidden_dim
27    let mut attention_flops = 0.0;
28    for (req, &num_new_tokens) in requests.iter().zip(tokens_per_request) {
29        let batch_size = 1.0; // Each request is one sequence
30        let s = num_new_tokens as f64; // New tokens being processed
31        let t = (req.num_computed_tokens + num_new_tokens) as f64; // Total attended tokens
32        let d = model.hidden_dim as f64;
33        let l = model.num_layers as f64;
34
35        // 4LBSTD FLOPs for attention across all layers
36        // Note: Causal masking zeros out some values, but the matmul still computes them
37        attention_flops += 4.0 * l * batch_size * s * t * d;
38    }
39
40    matmul_flops + attention_flops
41}
42
43/// Calculate memory transfer bytes for model weights
44/// Formula: weight_bytes = num_parameters * bytes_per_param
45pub fn model_weight_bytes(model: &ModelConfig, hardware: &HardwareConfig) -> f64 {
46    model.num_parameters as f64 * hardware.bytes_per_param as f64
47}
48
49/// Calculate memory transfer bytes for KV cache for a given sequence length
50/// Formula: kv_bytes = kv_cache_bytes_per_token * seq_len
51pub fn kv_cache_bytes(seq_len: u32, model: &ModelConfig) -> f64 {
52    model.kv_cache_bytes_per_token as f64 * seq_len as f64
53}
54
55/// Calculate total memory transfer bytes for an iteration
56/// Formula: total_bytes = model_weights + sum(kv_cache for each request)
57pub fn total_memory_transfer(
58    model: &ModelConfig,
59    hardware: &HardwareConfig,
60    request_seq_lens: &[u32],
61) -> f64 {
62    let weight_bytes = model_weight_bytes(model, hardware);
63    let kv_bytes: f64 = request_seq_lens
64        .iter()
65        .map(|&seq_len| kv_cache_bytes(seq_len, model))
66        .sum();
67    weight_bytes + kv_bytes
68}
69
70/// Check if a workload is compute-bound
71/// A workload is compute-bound if the number of tokens >= compute-bound threshold
72pub fn is_compute_bound(num_tokens: u32, hardware: &HardwareConfig) -> bool {
73    num_tokens >= hardware.compute_bound_threshold
74}
75
#[cfg(test)]
mod tests {
    use super::*;
    use crate::config::Config;

    #[test]
    fn test_compute_bound_threshold_h100() {
        let mut hardware = crate::config::HardwareConfig {
            name: "H100".to_string(),
            compute_flops: 1.513e15,   // 1513 TFLOPS bf16
            memory_bandwidth: 3.35e12, // 3.35 TB/s
            memory_capacity: 80_000_000_000,
            kv_cache_capacity: 60_000_000_000,
            gpu_memory_utilization: 0.9,
            bytes_per_param: 2, // bf16
            compute_bound_threshold: 0,
        };
        hardware.compute_threshold();

        let threshold = compute_bound_threshold(&hardware);

        // 2 * 1.513e15 / 3.35e12 ≈ 903.28, truncated to 903 for H100 bf16.
        assert_eq!(threshold, 903);
        // The free function and the cached field must agree.
        assert_eq!(threshold, hardware.compute_bound_threshold);
    }

    #[test]
    fn test_flops_calculation() {
        let config = Config::test_default();
        let model = &config.model;

        // A single fresh request processing 100 new tokens (nothing cached yet).
        let mut req = crate::request::Request::new("test".to_string(), 0, 0.0, 100, 50);
        req.num_computed_tokens = 0;
        let requests = [&req];
        let tokens = [100_u32];

        let flops = flops_for_tokens(100, model, &requests, &tokens);

        // MatMul FLOPs: 2 * tokens * active_parameters (1.4e12 for a dense 7B model).
        let matmul = 2.0 * 100.0 * model.active_parameters() as f64;
        // Attention FLOPs: 4 * L * B * S * T * D with B = 1 and S = T = 100.
        let attention =
            4.0 * model.num_layers as f64 * 100.0 * 100.0 * model.hidden_dim as f64;
        let expected = matmul + attention;

        // Pin the full value, so a missing or wrong attention term is caught too.
        assert!(
            (flops - expected).abs() <= expected * 1e-12,
            "flops = {}, expected = {}",
            flops,
            expected
        );
    }

    #[test]
    fn test_model_weight_bytes() {
        let config = Config::test_default();

        let bytes = model_weight_bytes(&config.model, &config.hardware);

        // 7e9 parameters * 2 bytes = 14GB
        assert_eq!(bytes, 14_000_000_000.0);
    }

    #[test]
    fn test_kv_cache_bytes() {
        let mut config = Config::test_default();
        config.model.kv_cache_bytes_per_token = 524_288; // 512KB per token

        let bytes = kv_cache_bytes(100, &config.model);

        // 524_288 * 100 = 52,428,800
        assert_eq!(bytes, 52_428_800.0);
    }

    #[test]
    fn test_total_memory_transfer() {
        let mut config = Config::test_default();
        config.model.kv_cache_bytes_per_token = 524_288;

        // 3 requests with sequence lengths 50, 100, 150
        let seq_lens = [50_u32, 100, 150];
        let total = total_memory_transfer(&config.model, &config.hardware, &seq_lens);

        let expected_weights = 14_000_000_000.0;
        let expected_kv = 524_288.0 * (50.0 + 100.0 + 150.0);
        assert_eq!(total, expected_weights + expected_kv);
    }

    #[test]
    fn test_is_compute_bound() {
        let mut hardware = crate::config::HardwareConfig {
            name: "Test".to_string(),
            compute_flops: 1e15,
            memory_bandwidth: 1e12,
            memory_capacity: 80_000_000_000,
            kv_cache_capacity: 60_000_000_000,
            gpu_memory_utilization: 0.9,
            bytes_per_param: 2,
            compute_bound_threshold: 0,
        };
        hardware.compute_threshold();

        // Threshold should be 2 * 1e15 / 1e12 = 2000
        assert_eq!(hardware.compute_bound_threshold, 2000);
        assert_eq!(compute_bound_threshold(&hardware), 2000);

        assert!(!is_compute_bound(1000, &hardware));
        assert!(!is_compute_bound(1999, &hardware));
        assert!(is_compute_bound(2000, &hardware));
        assert!(is_compute_bound(3000, &hardware));
    }
}