1use crate::model_config::{ConfigError, TransformerConfig};
2use super::profile::RuntimeProfile;
3
/// Byte-level breakdown of an estimated runtime memory footprint.
///
/// Values of this type are produced by `estimate_runtime_memory`, which fills
/// in each component separately and sets `total_bytes` to their sum.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct RuntimeEstimate {
    /// Bytes attributed to the model parameters.
    pub parameter_bytes: usize,
    /// Bytes attributed to activations.
    pub activation_bytes: usize,
    /// Bytes attributed to the key/value cache.
    pub kv_cache_bytes: usize,
    /// Combined size in bytes of all components.
    pub total_bytes: usize,
}
11
/// Reasons why `estimate_runtime_memory` can fail.
///
/// Implements [`std::fmt::Display`] and [`std::error::Error`] so the error can
/// be printed and propagated with `?` into `Box<dyn std::error::Error>`.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum RuntimeError {
    /// The runtime profile failed its own validation.
    InvalidProfile,
    /// An intermediate or final size did not fit in `usize`.
    Overflow,
    /// The transformer configuration was rejected as invalid.
    InvalidConfig,
}

impl std::fmt::Display for RuntimeError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let message = match self {
            RuntimeError::InvalidProfile => "invalid runtime profile",
            RuntimeError::Overflow => "arithmetic overflow while estimating runtime memory",
            RuntimeError::InvalidConfig => "invalid transformer configuration",
        };
        f.write_str(message)
    }
}

impl std::error::Error for RuntimeError {}
18
19pub fn estimate_runtime_memory(
20 config: &TransformerConfig,
21 profile: RuntimeProfile,
22) -> Result<RuntimeEstimate, RuntimeError> {
23 config.validate().map_err(|_| RuntimeError::InvalidConfig)?;
24
25 if !profile.validate() {
26 return Err(RuntimeError::InvalidProfile);
27 }
28
29 let params = config
30 .approximate_parameter_count()
31 .map_err(|e| match e {
32 ConfigError::Invalid => RuntimeError::InvalidConfig,
33 ConfigError::Overflow => RuntimeError::Overflow,
34 })?;
35
36 let parameter_bytes = params.checked_mul(profile.bytes_per_param).ok_or(RuntimeError::Overflow)?;
37 let tokens = profile.batch_size.checked_mul(profile.sequence_len).ok_or(RuntimeError::Overflow)?;
38 let activation_elems = tokens
39 .checked_mul(config.hidden_size)
40 .and_then(|x| x.checked_mul(config.num_layers))
41 .ok_or(RuntimeError::Overflow)?;
42 let activation_bytes = activation_elems
43 .checked_mul(profile.bytes_per_activation)
44 .ok_or(RuntimeError::Overflow)?;
45
46 let kv_elems = tokens
47 .checked_mul(config.hidden_size)
48 .and_then(|x| x.checked_mul(config.num_layers))
49 .and_then(|x| x.checked_mul(2))
50 .ok_or(RuntimeError::Overflow)?;
51 let kv_cache_bytes = kv_elems
52 .checked_mul(profile.bytes_per_kv)
53 .ok_or(RuntimeError::Overflow)?;
54
55 let total_bytes = parameter_bytes
56 .checked_add(activation_bytes)
57 .and_then(|x| x.checked_add(kv_cache_bytes))
58 .ok_or(RuntimeError::Overflow)?;
59
60 Ok(RuntimeEstimate {
61 parameter_bytes,
62 activation_bytes,
63 kv_cache_bytes,
64 total_bytes,
65 })
66}