//! Core inference arithmetic formulas based on the inference-arithmetic.mdx blog post

use crate::config::{HardwareConfig, ModelConfig};

4/// Calculate FLOPS for a given number of tokens
5/// Formula: FLOPS = 2 * num_tokens * active_parameters + attention_flops
6/// For MoE models, uses active_parameters (not total) since only some experts are activated
7/// Includes both matmul and attention FLOPs
8pub fn flops_for_tokens(
9    total_tokens: u32,
10    model: &ModelConfig,
11    requests: &[&crate::request::Request],
12    tokens_per_request: &[u32],
13) -> f64 {
14    // MatMul FLOPs: 2 * num_tokens * active_parameters
15    // For MoE: only counts activated expert parameters, not all experts
16    let matmul_flops = 2.0 * total_tokens as f64 * model.active_parameters() as f64;
17
18    // Attention FLOPs: 4 * L * B * S * T * D
19    // where L = num_layers, B = batch size, S = new tokens, T = attended tokens, D = hidden_dim
20    let mut attention_flops = 0.0;
21    for (req, &num_new_tokens) in requests.iter().zip(tokens_per_request) {
22        let batch_size = 1.0; // Each request is one sequence
23        let s = num_new_tokens as f64; // New tokens being processed
24        let t = (req.num_computed_tokens + num_new_tokens) as f64; // Total attended tokens
25        let d = model.hidden_dim as f64;
26        let l = model.num_layers as f64;
27
28        // 4LBSTD FLOPs for attention across all layers
29        // Note: Causal masking zeros out some values, but the matmul still computes them
30        attention_flops += 4.0 * l * batch_size * s * t * d;
31    }
32
33    matmul_flops + attention_flops
34}
35
36/// Calculate memory transfer bytes for model weights
37/// Formula: weight_bytes = num_parameters * bytes_per_param
38pub fn model_weight_bytes(model: &ModelConfig, hardware: &HardwareConfig) -> f64 {
39    model.num_parameters as f64 * hardware.bytes_per_param as f64
40}
41
42/// Calculate memory transfer bytes for KV cache for a given sequence length
43/// Formula: kv_bytes = kv_cache_bytes_per_token * seq_len
44pub fn kv_cache_bytes(seq_len: u32, model: &ModelConfig) -> f64 {
45    model.kv_cache_bytes_per_token as f64 * seq_len as f64
46}
47
48/// Calculate total memory transfer bytes for an iteration
49/// Formula: total_bytes = model_weights + sum(kv_cache for each request)
50pub fn total_memory_transfer(
51    model: &ModelConfig,
52    hardware: &HardwareConfig,
53    request_seq_lens: &[u32],
54) -> f64 {
55    let weight_bytes = model_weight_bytes(model, hardware);
56    let kv_bytes: f64 = request_seq_lens
57        .iter()
58        .map(|&seq_len| kv_cache_bytes(seq_len, model))
59        .sum();
60    weight_bytes + kv_bytes
61}
62
#[cfg(test)]
mod tests {
    use super::*;
    use crate::config::Config;

    #[test]
    fn test_flops_calculation() {
        let config = Config::test_default();

        // One request processing 100 fresh tokens (nothing cached yet).
        let mut request = crate::request::Request::new("test".to_string(), 0, 0.0, 100, 50);
        request.num_computed_tokens = 0;
        let requests = vec![&request];
        let new_tokens = vec![100];

        let flops = flops_for_tokens(100, &config.model, &requests, &new_tokens);

        // MatMul alone contributes 2 * 100 * 7e9 = 1.4e12 FLOPs; the
        // attention term (4 * L * S * T * D) only adds on top of that.
        assert!(flops >= 1.4e12);
    }

    #[test]
    fn test_model_weight_bytes() {
        let config = Config::test_default();

        // 7e9 parameters at 2 bytes each => 14 GB.
        let weight_bytes = model_weight_bytes(&config.model, &config.hardware);
        assert_eq!(weight_bytes, 14_000_000_000.0);
    }

    #[test]
    fn test_kv_cache_bytes() {
        let mut config = Config::test_default();
        config.model.kv_cache_bytes_per_token = 524_288; // 512 KB per token

        // 100 tokens * 524_288 bytes/token = 52,428,800 bytes.
        let cache_bytes = kv_cache_bytes(100, &config.model);
        assert_eq!(cache_bytes, 52_428_800.0);
    }

    #[test]
    fn test_total_memory_transfer() {
        let mut config = Config::test_default();
        config.model.kv_cache_bytes_per_token = 524_288;

        // Three requests with sequence lengths 50, 100, and 150.
        let seq_lens = vec![50, 100, 150];
        let total = total_memory_transfer(&config.model, &config.hardware, &seq_lens);

        // Weights (14 GB) plus KV cache for all 300 attended tokens.
        let expected_weights = 14_000_000_000.0;
        let expected_kv = 524_288.0 * (50.0 + 100.0 + 150.0);
        assert_eq!(total, expected_weights + expected_kv);
    }
}