/// Rough byte-count estimate of the memory needed to hold a model and its activations.
///
/// All fields are in bytes. Every constructor in this impl sets `gradients` to 0, so an
/// estimate built by them covers the parameters plus a single
/// `batch_size × seq_len × hidden_size` activation tensor.
#[derive(Debug, Clone, Copy)]
pub struct MemoryEstimate {
    /// Bytes needed to store the model parameters.
    pub weights: u64,
    /// Bytes needed for one `batch_size × seq_len × hidden_size` activation tensor.
    pub activations: u64,
    /// Bytes reserved for gradients (always 0 from the constructors below).
    pub gradients: u64,
}

impl MemoryEstimate {
    /// Total estimated bytes: `weights + activations + gradients`.
    ///
    /// Saturates at `u64::MAX` instead of overflowing, so an absurdly large estimate
    /// still compares as "too big" rather than wrapping to a small number.
    #[must_use]
    pub fn total(&self) -> u64 {
        self.weights
            .saturating_add(self.activations)
            .saturating_add(self.gradients)
    }

    /// Returns `true` when [`total`](Self::total) is at most `available` bytes.
    #[must_use]
    pub fn fits_in(&self, available: u64) -> bool {
        self.total() <= available
    }

    /// Estimate with 4-byte (32-bit float) weights and activations.
    #[must_use]
    pub fn fp32(param_count: u64, batch_size: usize, seq_len: usize, hidden_size: usize) -> Self {
        Self {
            weights: param_count.saturating_mul(4),
            activations: Self::activation_bytes(batch_size, seq_len, hidden_size, 4),
            gradients: 0,
        }
    }

    /// Estimate with 2-byte (16-bit float) weights and activations.
    #[must_use]
    pub fn fp16(param_count: u64, batch_size: usize, seq_len: usize, hidden_size: usize) -> Self {
        Self {
            weights: param_count.saturating_mul(2),
            activations: Self::activation_bytes(batch_size, seq_len, hidden_size, 2),
            gradients: 0,
        }
    }

    /// Estimate with 4-bit packed weights and 2-byte activations.
    ///
    /// Two 4-bit weights share one byte; an odd trailing weight still occupies a whole
    /// byte, so the weight size is rounded up.
    /// NOTE(review): activations are counted at 2 bytes/element (same as `fp16`),
    /// presumably because only the weights are quantized — confirm with callers.
    #[must_use]
    pub fn int4(param_count: u64, batch_size: usize, seq_len: usize, hidden_size: usize) -> Self {
        Self {
            // ceil(param_count / 2) without risk of overflow.
            weights: param_count / 2 + param_count % 2,
            activations: Self::activation_bytes(batch_size, seq_len, hidden_size, 2),
            gradients: 0,
        }
    }

    /// Bytes for a `batch_size × seq_len × hidden_size` tensor at `bytes_per_elem` each.
    ///
    /// The product is formed in `u64` rather than `usize` so it cannot overflow on
    /// 32-bit targets, and saturates at `u64::MAX` if it exceeds even that.
    fn activation_bytes(
        batch_size: usize,
        seq_len: usize,
        hidden_size: usize,
        bytes_per_elem: u64,
    ) -> u64 {
        // `usize -> u64` is lossless on every platform where usize is at most 64 bits.
        (batch_size as u64)
            .saturating_mul(seq_len as u64)
            .saturating_mul(hidden_size as u64)
            .saturating_mul(bytes_per_elem)
    }
}