Skip to main content

entrenar/efficiency/paradigm/
model_paradigm.rs

//! Model training/inference paradigm definitions.
2
3use serde::{Deserialize, Serialize};
4
5use super::FineTuneMethod;
6
7/// Model training/inference paradigm
8#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, Default)]
9pub enum ModelParadigm {
10    /// Traditional ML (sklearn-style): logistic regression, random forests, etc.
11    TraditionalMl,
12    /// Deep learning: neural networks trained from scratch
13    #[default]
14    DeepLearning,
15    /// Fine-tuning a pretrained model
16    FineTuning(FineTuneMethod),
17    /// Knowledge distillation (teacher → student)
18    Distillation,
19    /// Mixture of Experts
20    MoE,
21    /// Ensemble of multiple models
22    Ensemble,
23}
24
25impl ModelParadigm {
26    /// Create a LoRA fine-tuning paradigm
27    pub fn lora(rank: u32, alpha: f32) -> Self {
28        Self::FineTuning(FineTuneMethod::LoRA { rank, alpha })
29    }
30
31    /// Create a QLoRA fine-tuning paradigm
32    pub fn qlora(rank: u32, bits: u8) -> Self {
33        Self::FineTuning(FineTuneMethod::QLoRA { rank, bits })
34    }
35
36    /// Get typical memory multiplier relative to model size
37    ///
38    /// Returns the factor by which memory usage increases during training
39    /// compared to inference-only memory.
40    pub fn typical_memory_multiplier(&self) -> f64 {
41        match self {
42            Self::TraditionalMl => 1.5,
43            Self::DeepLearning => 4.0,
44            Self::FineTuning(method) => method.memory_multiplier(),
45            Self::Distillation => 5.0,
46            Self::MoE => 2.0,
47            Self::Ensemble => 3.0,
48        }
49    }
50
51    /// Get typical training speedup relative to full training
52    ///
53    /// Returns the factor by which training is faster compared to
54    /// training from scratch.
55    pub fn typical_training_speedup(&self) -> f64 {
56        match self {
57            Self::TraditionalMl => 10.0,
58            Self::DeepLearning => 1.0,
59            Self::FineTuning(method) => method.training_speedup(),
60            Self::Distillation => 1.5,
61            Self::MoE => 0.8,
62            Self::Ensemble => 0.5,
63        }
64    }
65
66    /// Get typical quality retention compared to full training
67    ///
68    /// Returns expected quality as a fraction of full fine-tuning quality.
69    pub fn typical_quality_retention(&self) -> f64 {
70        match self {
71            Self::TraditionalMl => 0.7,
72            Self::DeepLearning => 1.0,
73            Self::FineTuning(method) => method.quality_retention(),
74            Self::Distillation => 0.85,
75            Self::MoE => 1.05,
76            Self::Ensemble => 1.02,
77        }
78    }
79
80    /// Check if this paradigm requires a pretrained model
81    pub fn requires_pretrained(&self) -> bool {
82        matches!(self, Self::FineTuning(_) | Self::Distillation)
83    }
84
85    /// Check if this paradigm is parameter-efficient
86    pub fn is_parameter_efficient(&self) -> bool {
87        matches!(
88            self,
89            Self::FineTuning(
90                FineTuneMethod::LoRA { .. }
91                    | FineTuneMethod::QLoRA { .. }
92                    | FineTuneMethod::Adapter
93                    | FineTuneMethod::Prefix
94                    | FineTuneMethod::IA3
95            )
96        )
97    }
98
99    /// Get recommended batch size multiplier
100    ///
101    /// Parameter-efficient methods allow larger batch sizes due to lower memory.
102    pub fn batch_size_multiplier(&self) -> f64 {
103        match self {
104            Self::TraditionalMl => 10.0,
105            Self::DeepLearning => 1.0,
106            Self::FineTuning(method) => method.batch_size_multiplier(),
107            Self::Distillation => 0.5,
108            Self::MoE => 1.2,
109            Self::Ensemble => 0.3,
110        }
111    }
112}
113
114impl std::fmt::Display for ModelParadigm {
115    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
116        match self {
117            Self::TraditionalMl => write!(f, "Traditional ML"),
118            Self::DeepLearning => write!(f, "Deep Learning"),
119            Self::FineTuning(method) => write!(f, "Fine-tuning ({method})"),
120            Self::Distillation => write!(f, "Knowledge Distillation"),
121            Self::MoE => write!(f, "Mixture of Experts"),
122            Self::Ensemble => write!(f, "Ensemble"),
123        }
124    }
125}