1use crate::model_config::{ConfigError, TransformerConfig};
2use super::{RuntimeError, RuntimeProfile};
3
/// FLOP counts produced by `estimate_runtime_flops` for one model
/// configuration and one runtime profile.
#[derive(Debug, Clone, Copy)]
#[derive(PartialEq, Eq)]
pub struct RuntimeFlopsEstimate {
    /// FLOPs to process the whole prompt: `batch_size · sequence_len · num_layers`
    /// times the per-token, per-layer cost (dense work plus attention).
    pub prefill_flops: u128,
    /// FLOPs for a single decode step across the batch:
    /// `batch_size · num_layers` times the per-layer decode cost.
    pub decode_token_flops: u128,
    /// `prefill_flops + decode_token_flops`, i.e. prefill plus one decode step.
    pub total_flops: u128,
}
10
/// Decode throughput derived from a sustained compute rate by
/// `estimate_tokens_per_second`.
#[derive(Debug, Clone, Copy)]
#[derive(PartialEq)]
pub struct ThroughputEstimate {
    /// The `decode_token_flops` of the matching `RuntimeFlopsEstimate`.
    /// NOTE(review): that figure already includes `batch_size`, so this is the
    /// cost of one decode step for the whole batch — confirm the naming.
    pub flops_per_token: u128,
    /// Sustained FLOP/s divided by `flops_per_token`.
    pub estimated_tokens_per_second: f32,
}
16
17pub fn estimate_runtime_flops(
18 config: &TransformerConfig,
19 profile: RuntimeProfile,
20) -> Result<RuntimeFlopsEstimate, RuntimeError> {
21 config.validate().map_err(|_| RuntimeError::InvalidConfig)?;
22 if !profile.validate() {
23 return Err(RuntimeError::InvalidProfile);
24 }
25
26 let h = config.hidden_size as u128;
27 let f = config.ffw_size as u128;
28 let l = config.num_layers as u128;
29 let b = profile.batch_size as u128;
30 let s = profile.sequence_len as u128;
31
32 let qkv = h
33 .checked_mul(h)
34 .and_then(|x| x.checked_mul(6))
35 .ok_or(RuntimeError::Overflow)?;
36 let o_proj = h.checked_mul(h).and_then(|x| x.checked_mul(2)).ok_or(RuntimeError::Overflow)?;
37 let ff = h
38 .checked_mul(f)
39 .and_then(|x| x.checked_mul(4))
40 .ok_or(RuntimeError::Overflow)?;
41
42 let per_token_dense = qkv
43 .checked_add(o_proj)
44 .and_then(|x| x.checked_add(ff))
45 .ok_or(RuntimeError::Overflow)?;
46
47 let per_token_attention = s
48 .checked_mul(h)
49 .and_then(|x| x.checked_mul(4))
50 .ok_or(RuntimeError::Overflow)?;
51
52 let per_token_per_layer = per_token_dense
53 .checked_add(per_token_attention)
54 .ok_or(RuntimeError::Overflow)?;
55
56 let prefill_flops = b
57 .checked_mul(s)
58 .and_then(|x| x.checked_mul(l))
59 .and_then(|x| x.checked_mul(per_token_per_layer))
60 .ok_or(RuntimeError::Overflow)?;
61
62 let decode_attention = s
63 .checked_mul(h)
64 .and_then(|x| x.checked_mul(2))
65 .ok_or(RuntimeError::Overflow)?;
66 let decode_per_token_per_layer = per_token_dense
67 .checked_add(decode_attention)
68 .ok_or(RuntimeError::Overflow)?;
69 let decode_token_flops = b
70 .checked_mul(l)
71 .and_then(|x| x.checked_mul(decode_per_token_per_layer))
72 .ok_or(RuntimeError::Overflow)?;
73
74 let total_flops = prefill_flops
75 .checked_add(decode_token_flops)
76 .ok_or(RuntimeError::Overflow)?;
77
78 Ok(RuntimeFlopsEstimate {
79 prefill_flops,
80 decode_token_flops,
81 total_flops,
82 })
83}
84
85pub fn estimate_tokens_per_second(
86 config: &TransformerConfig,
87 profile: RuntimeProfile,
88 sustained_flops_per_second: u128,
89) -> Result<ThroughputEstimate, RuntimeError> {
90 config
91 .validate()
92 .map_err(|e| match e {
93 ConfigError::Invalid => RuntimeError::InvalidConfig,
94 ConfigError::Overflow => RuntimeError::Overflow,
95 })?;
96
97 if !profile.validate() || sustained_flops_per_second == 0 {
98 return Err(RuntimeError::InvalidProfile);
99 }
100
101 let flops = estimate_runtime_flops(config, profile)?;
102 let flops_per_token = flops.decode_token_flops;
103 if flops_per_token == 0 {
104 return Err(RuntimeError::InvalidProfile);
105 }
106
107 let tps = sustained_flops_per_second as f32 / flops_per_token as f32;
108 Ok(ThroughputEstimate {
109 flops_per_token,
110 estimated_tokens_per_second: tps,
111 })
112}