Skip to main content

rnn/runtime/
estimate.rs

1use crate::model_config::{ConfigError, TransformerConfig};
2use super::profile::RuntimeProfile;
3
/// Byte-level breakdown of the memory a model run is expected to need.
///
/// Every field is a byte count. `total_bytes` is the sum of the other three
/// components when produced by `estimate_runtime_memory`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct RuntimeEstimate {
    /// Bytes occupied by the model parameters.
    pub parameter_bytes: usize,
    /// Bytes reserved for intermediate activations.
    pub activation_bytes: usize,
    /// Bytes reserved for the key/value cache.
    pub kv_cache_bytes: usize,
    /// Sum of the parameter, activation and KV-cache byte counts.
    pub total_bytes: usize,
}
11
/// Reasons a runtime memory estimate could not be produced.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum RuntimeError {
    /// The runtime profile failed its own validation check.
    InvalidProfile,
    /// An intermediate byte or element count did not fit in a `usize`.
    Overflow,
    /// The model configuration was rejected as invalid.
    InvalidConfig,
}
18
19pub fn estimate_runtime_memory(
20    config: &TransformerConfig,
21    profile: RuntimeProfile,
22) -> Result<RuntimeEstimate, RuntimeError> {
23    config.validate().map_err(|_| RuntimeError::InvalidConfig)?;
24
25    if !profile.validate() {
26        return Err(RuntimeError::InvalidProfile);
27    }
28
29    let params = config
30        .approximate_parameter_count()
31        .map_err(|e| match e {
32            ConfigError::Invalid => RuntimeError::InvalidConfig,
33            ConfigError::Overflow => RuntimeError::Overflow,
34        })?;
35
36    let parameter_bytes = params.checked_mul(profile.bytes_per_param).ok_or(RuntimeError::Overflow)?;
37    let tokens = profile.batch_size.checked_mul(profile.sequence_len).ok_or(RuntimeError::Overflow)?;
38    let activation_elems = tokens
39        .checked_mul(config.hidden_size)
40        .and_then(|x| x.checked_mul(config.num_layers))
41        .ok_or(RuntimeError::Overflow)?;
42    let activation_bytes = activation_elems
43        .checked_mul(profile.bytes_per_activation)
44        .ok_or(RuntimeError::Overflow)?;
45
46    let kv_elems = tokens
47        .checked_mul(config.hidden_size)
48        .and_then(|x| x.checked_mul(config.num_layers))
49        .and_then(|x| x.checked_mul(2))
50        .ok_or(RuntimeError::Overflow)?;
51    let kv_cache_bytes = kv_elems
52        .checked_mul(profile.bytes_per_kv)
53        .ok_or(RuntimeError::Overflow)?;
54
55    let total_bytes = parameter_bytes
56        .checked_add(activation_bytes)
57        .and_then(|x| x.checked_add(kv_cache_bytes))
58        .ok_or(RuntimeError::Overflow)?;
59
60    Ok(RuntimeEstimate {
61        parameter_bytes,
62        activation_bytes,
63        kv_cache_bytes,
64        total_bytes,
65    })
66}