inference_lab/compute/
arithmetic.rs1use crate::config::{HardwareConfig, ModelConfig};
3
4pub fn flops_for_tokens(
9 total_tokens: u32,
10 model: &ModelConfig,
11 requests: &[&crate::request::Request],
12 tokens_per_request: &[u32],
13) -> f64 {
14 let matmul_flops = 2.0 * total_tokens as f64 * model.active_parameters() as f64;
17
18 let mut attention_flops = 0.0;
21 for (req, &num_new_tokens) in requests.iter().zip(tokens_per_request) {
22 let batch_size = 1.0; let s = num_new_tokens as f64; let t = (req.num_computed_tokens + num_new_tokens) as f64; let d = model.hidden_dim as f64;
26 let l = model.num_layers as f64;
27
28 attention_flops += 4.0 * l * batch_size * s * t * d;
31 }
32
33 matmul_flops + attention_flops
34}
35
36pub fn model_weight_bytes(model: &ModelConfig, hardware: &HardwareConfig) -> f64 {
39 model.num_parameters as f64 * hardware.bytes_per_param as f64
40}
41
42pub fn kv_cache_bytes(seq_len: u32, model: &ModelConfig) -> f64 {
45 model.kv_cache_bytes_per_token as f64 * seq_len as f64
46}
47
48pub fn total_memory_transfer(
51 model: &ModelConfig,
52 hardware: &HardwareConfig,
53 request_seq_lens: &[u32],
54) -> f64 {
55 let weight_bytes = model_weight_bytes(model, hardware);
56 let kv_bytes: f64 = request_seq_lens
57 .iter()
58 .map(|&seq_len| kv_cache_bytes(seq_len, model))
59 .sum();
60 weight_bytes + kv_bytes
61}
62
#[cfg(test)]
mod tests {
    use super::*;
    use crate::config::Config;

    #[test]
    fn test_flops_calculation() {
        let config = Config::test_default();
        let model = &config.model;

        let mut req = crate::request::Request::new("test".to_string(), 0, 0.0, 100, 50);
        req.num_computed_tokens = 0;
        let requests = [&req];
        let tokens = [100_u32];

        let flops = flops_for_tokens(100, model, &requests, &tokens);

        // Pin both terms, not just a lower bound. This assumes the test model's
        // dimensions are modest: the products are then integers below 2^53, hence
        // exact in f64, and assert_eq! is safe.
        let expected_matmul = 2.0 * 100.0 * model.active_parameters() as f64;
        let expected_attention =
            4.0 * model.num_layers as f64 * 100.0 * 100.0 * model.hidden_dim as f64;
        assert_eq!(flops, expected_matmul + expected_attention);
        // Sanity floor kept from the original test.
        assert!(flops >= 1.4e12);
    }

    #[test]
    fn test_flops_empty_batch_is_zero() {
        let config = Config::test_default();

        let flops = flops_for_tokens(0, &config.model, &[], &[]);

        assert_eq!(flops, 0.0);
    }

    #[test]
    fn test_attention_flops_grow_with_context() {
        let config = Config::test_default();

        let mut fresh = crate::request::Request::new("fresh".to_string(), 0, 0.0, 100, 50);
        fresh.num_computed_tokens = 0;
        let mut warm = crate::request::Request::new("warm".to_string(), 0, 0.0, 100, 50);
        warm.num_computed_tokens = 50;

        // Same number of new tokens; only the already-computed context differs.
        let flops_fresh = flops_for_tokens(50, &config.model, &[&fresh], &[50]);
        let flops_warm = flops_for_tokens(50, &config.model, &[&warm], &[50]);

        assert!(flops_warm > flops_fresh);
    }

    #[test]
    fn test_model_weight_bytes() {
        let config = Config::test_default();

        let bytes = model_weight_bytes(&config.model, &config.hardware);

        assert_eq!(bytes, 14_000_000_000.0);
    }

    #[test]
    fn test_kv_cache_bytes() {
        let mut config = Config::test_default();
        config.model.kv_cache_bytes_per_token = 524_288;

        let bytes = kv_cache_bytes(100, &config.model);

        assert_eq!(bytes, 52_428_800.0);
    }

    #[test]
    fn test_total_memory_transfer() {
        let mut config = Config::test_default();
        config.model.kv_cache_bytes_per_token = 524_288;

        let seq_lens = [50, 100, 150];
        let total = total_memory_transfer(&config.model, &config.hardware, &seq_lens);

        let expected_weights = 14_000_000_000.0;
        let expected_kv = 524_288.0 * (50.0 + 100.0 + 150.0);
        assert_eq!(total, expected_weights + expected_kv);
    }
}