entrenar/prune/config/
pruning_config.rs1use crate::prune::schedule::PruningSchedule;
4use serde::{Deserialize, Serialize};
5
6use super::{PruneMethod, SparsityPatternConfig};
7
8#[derive(Debug, Clone, Serialize, Deserialize)]
27pub struct PruningConfig {
28 method: PruneMethod,
30
31 target_sparsity: f32,
33
34 pattern: SparsityPatternConfig,
36
37 schedule: PruningSchedule,
39
40 fine_tune_after_pruning: bool,
42
43 fine_tune_steps: usize,
45
46 fine_tune_lr: f32,
48
49 skip_embed_layers: bool,
51}
52
53impl Default for PruningConfig {
54 fn default() -> Self {
55 Self {
56 method: PruneMethod::default(),
57 target_sparsity: 0.5,
58 pattern: SparsityPatternConfig::default(),
59 schedule: PruningSchedule::default(),
60 fine_tune_after_pruning: true,
61 fine_tune_steps: 1000,
62 fine_tune_lr: 1e-5,
63 skip_embed_layers: true,
64 }
65 }
66}
67
68impl PruningConfig {
69 pub fn new() -> Self {
71 Self::default()
72 }
73
74 pub fn with_method(mut self, method: PruneMethod) -> Self {
76 self.method = method;
77 self
78 }
79
80 pub fn with_target_sparsity(mut self, sparsity: f32) -> Self {
82 self.target_sparsity = sparsity.clamp(0.0, 1.0);
83 self
84 }
85
86 pub fn with_pattern(mut self, pattern: SparsityPatternConfig) -> Self {
88 self.pattern = pattern;
89 self
90 }
91
92 pub fn with_schedule(mut self, schedule: PruningSchedule) -> Self {
94 self.schedule = schedule;
95 self
96 }
97
98 pub fn with_fine_tune(mut self, enabled: bool) -> Self {
100 self.fine_tune_after_pruning = enabled;
101 self
102 }
103
104 pub fn with_fine_tune_steps(mut self, steps: usize) -> Self {
106 self.fine_tune_steps = steps;
107 self
108 }
109
110 pub fn with_fine_tune_lr(mut self, lr: f32) -> Self {
112 self.fine_tune_lr = lr;
113 self
114 }
115
116 pub fn with_skip_embed_layers(mut self, skip: bool) -> Self {
118 self.skip_embed_layers = skip;
119 self
120 }
121
122 pub fn method(&self) -> PruneMethod {
124 self.method
125 }
126
127 pub fn target_sparsity(&self) -> f32 {
129 self.target_sparsity
130 }
131
132 pub fn pattern(&self) -> &SparsityPatternConfig {
134 &self.pattern
135 }
136
137 pub fn schedule(&self) -> &PruningSchedule {
139 &self.schedule
140 }
141
142 pub fn fine_tune_after_pruning(&self) -> bool {
144 self.fine_tune_after_pruning
145 }
146
147 pub fn fine_tune_steps(&self) -> usize {
149 self.fine_tune_steps
150 }
151
152 pub fn fine_tune_lr(&self) -> f32 {
154 self.fine_tune_lr
155 }
156
157 pub fn skip_embed_layers(&self) -> bool {
159 self.skip_embed_layers
160 }
161
162 pub fn requires_calibration(&self) -> bool {
164 self.method.requires_calibration()
165 }
166
167 pub fn validate(&self) -> Result<(), String> {
169 self.schedule.validate()?;
171
172 if self.target_sparsity < 0.0 || self.target_sparsity > 1.0 {
174 return Err(format!(
175 "target_sparsity ({}) must be between 0.0 and 1.0",
176 self.target_sparsity
177 ));
178 }
179
180 if let SparsityPatternConfig::NM { n, m } = &self.pattern {
182 if *n >= *m {
183 return Err(format!("N ({n}) must be less than M ({m})"));
184 }
185 if *m == 0 {
186 return Err("M cannot be 0".to_string());
187 }
188 }
189
190 if let SparsityPatternConfig::Block { height, width } = &self.pattern {
192 if *height == 0 || *width == 0 {
193 return Err("Block dimensions must be non-zero".to_string());
194 }
195 }
196
197 Ok(())
198 }
199}