// entrenar_distill/lib.rs

1//! End-to-end knowledge distillation CLI.
2//!
3//! This crate provides a complete pipeline for knowledge distillation:
4//! - Fetch teacher models from HuggingFace
5//! - Configure distillation parameters via YAML
6//! - Train student models with progressive/attention distillation
7//! - Export to SafeTensors, GGUF, or APR formats
8//!
9//! # Toyota Way Principles
10//!
11//! - **Jidoka**: Pre-flight validation catches errors before expensive training
12//! - **Heijunka**: Memory estimation enables level scheduling of GPU resources
13//! - **Kaizen**: Configurable hyperparameters enable continuous improvement
14
15pub mod config;
16pub mod pipeline;
17pub mod validation;
18
19pub use config::DistillConfig;
20pub use pipeline::{Pipeline, PipelineResult};
21pub use validation::ConfigValidator;
22
23use entrenar_common::Result;
24
25/// Run the distillation pipeline with the given configuration.
26pub fn run(config: &DistillConfig) -> Result<PipelineResult> {
27    // Validate configuration first (Jidoka)
28    ConfigValidator::validate(config)?;
29
30    // Execute pipeline
31    Pipeline::new(config).execute()
32}
33
34/// Estimate memory requirements without running training.
35pub fn estimate_memory(config: &DistillConfig) -> Result<MemoryEstimate> {
36    ConfigValidator::validate(config)?;
37    Pipeline::estimate_memory(config)
38}
39
/// Memory estimation result.
///
/// All sizes are in bytes. The estimate assumes FP16 weights, Adam-style
/// optimizer state (2x the model weights), and a fixed 32-layer transformer
/// for the activation term.
#[derive(Debug, Clone)]
pub struct MemoryEstimate {
    /// Model weights memory in bytes
    pub model_bytes: u64,
    /// Activation memory in bytes
    pub activation_bytes: u64,
    /// Optimizer state memory in bytes
    pub optimizer_bytes: u64,
    /// Total memory in bytes
    pub total_bytes: u64,
    /// Whether this fits in available VRAM
    pub fits_in_vram: bool,
    /// Recommended batch size for available memory
    pub recommended_batch_size: usize,
}

impl MemoryEstimate {
    /// Bytes per parameter when training in FP16.
    const BYTES_PER_PARAM_FP16: u64 = 2;
    /// Layer count assumed by the activation estimate.
    /// NOTE(review): hard-coded at 32 — confirm against the student model depth.
    const ASSUMED_LAYERS: u64 = 32;
    /// Activation copies kept live (forward + backward).
    const ACTIVATION_COPIES: u64 = 2;
    /// Adam keeps momentum + variance, i.e. 2x the model weights.
    const OPTIMIZER_MULTIPLIER: u64 = 2;
    /// Target VRAM budget: 24 GiB.
    const AVAILABLE_VRAM: u64 = 24 * 1024 * 1024 * 1024;
    /// Fraction of VRAM budgeted when recommending a batch size.
    const VRAM_HEADROOM: f64 = 0.8;

    /// Create a new memory estimate.
    ///
    /// All arithmetic widens each `usize` input to `u64` *before*
    /// multiplying, so large batch/sequence/hidden sizes cannot overflow
    /// `usize` on 32-bit targets (the original multiplied in `usize` and
    /// cast afterwards).
    pub fn new(model_params: u64, batch_size: usize, seq_len: usize, hidden_dim: usize) -> Self {
        // Model weights (assume FP16 for training).
        let model_bytes = model_params * Self::BYTES_PER_PARAM_FP16;

        // Per-sample activations: seq * hidden * layers * (forward + backward).
        let per_sample =
            seq_len as u64 * hidden_dim as u64 * Self::ASSUMED_LAYERS * Self::ACTIVATION_COPIES;
        let activation_bytes = batch_size as u64 * per_sample;

        // Optimizer state: 2x model size for Adam (momentum + variance).
        let optimizer_bytes = model_bytes * Self::OPTIMIZER_MULTIPLIER;

        let total_bytes = model_bytes + activation_bytes + optimizer_bytes;

        let fits_in_vram = total_bytes < Self::AVAILABLE_VRAM;

        // Recommend the largest batch whose activations fit in 80% of VRAM
        // after reserving space for weights and optimizer state (Heijunka).
        let target_memory = (Self::AVAILABLE_VRAM as f64 * Self::VRAM_HEADROOM) as u64;
        let available_for_activations = target_memory.saturating_sub(model_bytes + optimizer_bytes);
        let recommended_batch_size = if per_sample > 0 {
            (available_for_activations / per_sample).max(1) as usize
        } else {
            // Degenerate seq_len/hidden_dim of zero: fall back to one sample.
            1
        };

        Self {
            model_bytes,
            activation_bytes,
            optimizer_bytes,
            total_bytes,
            fits_in_vram,
            recommended_batch_size,
        }
    }

    /// Format as human-readable string.
    ///
    /// Sizes are printed in decimal GB (1e9 bytes), while the VRAM-fit flag
    /// is computed against a 24 GiB (binary) budget — a small, deliberate
    /// display approximation.
    pub fn to_human_readable(&self) -> String {
        format!(
            "Memory Estimate:\n  Model: {:.1} GB\n  Activations: {:.1} GB\n  Optimizer: {:.1} GB\n  Total: {:.1} GB\n  Fits in 24GB VRAM: {}",
            self.model_bytes as f64 / 1e9,
            self.activation_bytes as f64 / 1e9,
            self.optimizer_bytes as f64 / 1e9,
            self.total_bytes as f64 / 1e9,
            if self.fits_in_vram { "Yes" } else { "No" }
        )
    }
}
107
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_memory_estimate_calculation() {
        // 7B parameters at FP16 => roughly 14 GB of weights.
        let est = MemoryEstimate::new(7_000_000_000, 32, 512, 4096);

        assert!(est.model_bytes > 10_000_000_000);
        assert!(est.model_bytes < 20_000_000_000);

        // Total adds activations and optimizer state on top of the weights.
        assert!(est.total_bytes > est.model_bytes);
    }

    #[test]
    fn test_memory_estimate_fits_calculation() {
        // A 100M-parameter model comfortably fits the VRAM budget.
        assert!(MemoryEstimate::new(100_000_000, 8, 256, 768).fits_in_vram);

        // A 70B-parameter model cannot fit.
        assert!(!MemoryEstimate::new(70_000_000_000, 32, 2048, 8192).fits_in_vram);
    }

    #[test]
    fn test_recommended_batch_size_positive() {
        // The recommendation is clamped to at least one sample.
        let est = MemoryEstimate::new(7_000_000_000, 32, 512, 4096);
        assert!(est.recommended_batch_size >= 1);
    }
}