1pub mod config;
16pub mod pipeline;
17pub mod validation;
18
19pub use config::DistillConfig;
20pub use pipeline::{Pipeline, PipelineResult};
21pub use validation::ConfigValidator;
22
23use entrenar_common::Result;
24
25pub fn run(config: &DistillConfig) -> Result<PipelineResult> {
27 ConfigValidator::validate(config)?;
29
30 Pipeline::new(config).execute()
32}
33
34pub fn estimate_memory(config: &DistillConfig) -> Result<MemoryEstimate> {
36 ConfigValidator::validate(config)?;
37 Pipeline::estimate_memory(config)
38}
39
/// Rough breakdown of the memory a training run needs, as computed by
/// [`MemoryEstimate::new`].
///
/// All sizes are in bytes.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct MemoryEstimate {
    /// Bytes occupied by the model weights.
    pub model_bytes: u64,
    /// Estimated activation memory for the whole batch.
    pub activation_bytes: u64,
    /// Bytes of optimizer state.
    pub optimizer_bytes: u64,
    /// Sum of model, activation and optimizer bytes.
    pub total_bytes: u64,
    /// Whether `total_bytes` is below the VRAM budget used by the estimator.
    pub fits_in_vram: bool,
    /// Suggested batch size for the available memory (at least 1).
    pub recommended_batch_size: usize,
}
56
57impl MemoryEstimate {
58 pub fn new(model_params: u64, batch_size: usize, seq_len: usize, hidden_dim: usize) -> Self {
60 let model_bytes = model_params * 2;
62
63 let activation_bytes = (batch_size * seq_len * hidden_dim * 32 * 2) as u64;
65
66 let optimizer_bytes = model_bytes * 2;
68
69 let total_bytes = model_bytes + activation_bytes + optimizer_bytes;
70
71 let available_vram = 24 * 1024 * 1024 * 1024u64;
73 let fits_in_vram = total_bytes < available_vram;
74
75 let target_memory = (available_vram as f64 * 0.8) as u64;
77 let per_sample = (seq_len * hidden_dim * 32 * 2) as u64;
78 let available_for_activations = target_memory.saturating_sub(model_bytes + optimizer_bytes);
79 let recommended_batch_size = if per_sample > 0 {
80 (available_for_activations / per_sample).max(1) as usize
81 } else {
82 1
83 };
84
85 Self {
86 model_bytes,
87 activation_bytes,
88 optimizer_bytes,
89 total_bytes,
90 fits_in_vram,
91 recommended_batch_size,
92 }
93 }
94
95 pub fn to_human_readable(&self) -> String {
97 format!(
98 "Memory Estimate:\n Model: {:.1} GB\n Activations: {:.1} GB\n Optimizer: {:.1} GB\n Total: {:.1} GB\n Fits in 24GB VRAM: {}",
99 self.model_bytes as f64 / 1e9,
100 self.activation_bytes as f64 / 1e9,
101 self.optimizer_bytes as f64 / 1e9,
102 self.total_bytes as f64 / 1e9,
103 if self.fits_in_vram { "Yes" } else { "No" }
104 )
105 }
106}
107
#[cfg(test)]
mod tests {
    use super::*;

    /// Pins the exact byte counts for a 7B-parameter model instead of loose ranges.
    #[test]
    fn test_memory_estimate_calculation() {
        let estimate = MemoryEstimate::new(7_000_000_000, 32, 512, 4096);

        // 2 bytes per parameter.
        assert_eq!(estimate.model_bytes, 14_000_000_000);
        // Optimizer state is twice the weights.
        assert_eq!(estimate.optimizer_bytes, 28_000_000_000);
        // 32 * 512 * 4096 * 32 * 2 == 2^32.
        assert_eq!(estimate.activation_bytes, 4_294_967_296);
        assert_eq!(estimate.total_bytes, 46_294_967_296);
        assert!(estimate.total_bytes > estimate.model_bytes);
    }

    #[test]
    fn test_memory_estimate_fits_calculation() {
        let small = MemoryEstimate::new(100_000_000, 8, 256, 768);
        assert!(small.fits_in_vram);

        let huge = MemoryEstimate::new(70_000_000_000, 32, 2048, 8192);
        assert!(!huge.fits_in_vram);
    }

    #[test]
    fn test_recommended_batch_size_positive() {
        // Weights + optimizer (42e9 bytes) already exceed the 80% target, so the
        // recommendation bottoms out at 1.
        let estimate = MemoryEstimate::new(7_000_000_000, 32, 512, 4096);
        assert_eq!(estimate.recommended_batch_size, 1);
    }

    #[test]
    fn test_recommended_batch_size_uses_remaining_budget() {
        // (0.8 * 24 GiB - 600_000_000) / (256 * 768 * 64) == 1590 (floor).
        let estimate = MemoryEstimate::new(100_000_000, 8, 256, 768);
        assert_eq!(estimate.recommended_batch_size, 1590);
    }

    #[test]
    fn test_zero_sized_activations_recommend_batch_of_one() {
        let estimate = MemoryEstimate::new(1_000, 8, 0, 768);
        assert_eq!(estimate.activation_bytes, 0);
        assert_eq!(estimate.recommended_batch_size, 1);
    }

    #[test]
    fn test_human_readable_reports_fit() {
        let text = MemoryEstimate::new(100_000_000, 8, 256, 768).to_human_readable();
        assert!(text.starts_with("Memory Estimate:"));
        assert!(text.contains("Fits in 24GB VRAM: Yes"));
    }
}
141}