Skip to main content

rnn/runtime/
flops.rs

1use crate::model_config::{ConfigError, TransformerConfig};
2use super::{RuntimeError, RuntimeProfile};
3
/// FLOP counts produced by `estimate_runtime_flops` for one model/profile pair.
///
/// All counters are `u128` so that large model and batch dimensions can be
/// represented without wrapping.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct RuntimeFlopsEstimate {
    /// Cost attributed to the prefill pass over the full batch and sequence.
    pub prefill_flops: u128,
    /// Cost attributed to a single decode step across the batch.
    pub decode_token_flops: u128,
    /// Sum of `prefill_flops` and `decode_token_flops`.
    pub total_flops: u128,
}
10
/// Throughput figure produced by `estimate_tokens_per_second`.
///
/// Only `PartialEq` is derived (not `Eq`) because the rate is stored as `f32`.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct ThroughputEstimate {
    /// FLOPs charged for producing one token.
    pub flops_per_token: u128,
    /// Sustained FLOP/s divided by `flops_per_token`.
    pub estimated_tokens_per_second: f32,
}
16
17pub fn estimate_runtime_flops(
18    config: &TransformerConfig,
19    profile: RuntimeProfile,
20) -> Result<RuntimeFlopsEstimate, RuntimeError> {
21    config.validate().map_err(|_| RuntimeError::InvalidConfig)?;
22    if !profile.validate() {
23        return Err(RuntimeError::InvalidProfile);
24    }
25
26    let h = config.hidden_size as u128;
27    let f = config.ffw_size as u128;
28    let l = config.num_layers as u128;
29    let b = profile.batch_size as u128;
30    let s = profile.sequence_len as u128;
31
32    let qkv = h
33        .checked_mul(h)
34        .and_then(|x| x.checked_mul(6))
35        .ok_or(RuntimeError::Overflow)?;
36    let o_proj = h.checked_mul(h).and_then(|x| x.checked_mul(2)).ok_or(RuntimeError::Overflow)?;
37    let ff = h
38        .checked_mul(f)
39        .and_then(|x| x.checked_mul(4))
40        .ok_or(RuntimeError::Overflow)?;
41
42    let per_token_dense = qkv
43        .checked_add(o_proj)
44        .and_then(|x| x.checked_add(ff))
45        .ok_or(RuntimeError::Overflow)?;
46
47    let per_token_attention = s
48        .checked_mul(h)
49        .and_then(|x| x.checked_mul(4))
50        .ok_or(RuntimeError::Overflow)?;
51
52    let per_token_per_layer = per_token_dense
53        .checked_add(per_token_attention)
54        .ok_or(RuntimeError::Overflow)?;
55
56    let prefill_flops = b
57        .checked_mul(s)
58        .and_then(|x| x.checked_mul(l))
59        .and_then(|x| x.checked_mul(per_token_per_layer))
60        .ok_or(RuntimeError::Overflow)?;
61
62    let decode_attention = s
63        .checked_mul(h)
64        .and_then(|x| x.checked_mul(2))
65        .ok_or(RuntimeError::Overflow)?;
66    let decode_per_token_per_layer = per_token_dense
67        .checked_add(decode_attention)
68        .ok_or(RuntimeError::Overflow)?;
69    let decode_token_flops = b
70        .checked_mul(l)
71        .and_then(|x| x.checked_mul(decode_per_token_per_layer))
72        .ok_or(RuntimeError::Overflow)?;
73
74    let total_flops = prefill_flops
75        .checked_add(decode_token_flops)
76        .ok_or(RuntimeError::Overflow)?;
77
78    Ok(RuntimeFlopsEstimate {
79        prefill_flops,
80        decode_token_flops,
81        total_flops,
82    })
83}
84
85pub fn estimate_tokens_per_second(
86    config: &TransformerConfig,
87    profile: RuntimeProfile,
88    sustained_flops_per_second: u128,
89) -> Result<ThroughputEstimate, RuntimeError> {
90    config
91        .validate()
92        .map_err(|e| match e {
93            ConfigError::Invalid => RuntimeError::InvalidConfig,
94            ConfigError::Overflow => RuntimeError::Overflow,
95        })?;
96
97    if !profile.validate() || sustained_flops_per_second == 0 {
98        return Err(RuntimeError::InvalidProfile);
99    }
100
101    let flops = estimate_runtime_flops(config, profile)?;
102    let flops_per_token = flops.decode_token_flops;
103    if flops_per_token == 0 {
104        return Err(RuntimeError::InvalidProfile);
105    }
106
107    let tps = sustained_flops_per_second as f32 / flops_per_token as f32;
108    Ok(ThroughputEstimate {
109        flops_per_token,
110        estimated_tokens_per_second: tps,
111    })
112}