entrenar/prune/callback/
pruning_callback.rs1#![allow(clippy::field_reassign_with_default)]
7
8use crate::prune::calibrate::{CalibrationCollector, CalibrationConfig};
9use crate::prune::config::PruningConfig;
10use crate::prune::schedule::PruningSchedule;
11use crate::train::callback::{CallbackAction, CallbackContext, TrainerCallback};
12
13#[derive(Debug)]
41pub struct PruningCallback {
42 config: PruningConfig,
44 current_sparsity: f32,
46 parameters_pruned: usize,
48 pub(crate) calibration: Option<CalibrationCollector>,
50 enabled: bool,
52 pub(crate) last_prune_step: Option<usize>,
54}
55
56impl PruningCallback {
57 pub fn new(config: PruningConfig) -> Self {
63 let calibration = if config.requires_calibration() {
64 Some(CalibrationCollector::new(CalibrationConfig::default()))
65 } else {
66 None
67 };
68
69 Self {
70 config,
71 current_sparsity: 0.0,
72 parameters_pruned: 0,
73 calibration,
74 enabled: true,
75 last_prune_step: None,
76 }
77 }
78
79 pub fn with_calibration(config: PruningConfig, cal_config: CalibrationConfig) -> Self {
81 Self { calibration: Some(CalibrationCollector::new(cal_config)), ..Self::new(config) }
82 }
83
84 pub fn set_enabled(&mut self, enabled: bool) {
86 self.enabled = enabled;
87 }
88
89 pub fn is_enabled(&self) -> bool {
91 self.enabled
92 }
93
94 pub fn current_sparsity(&self) -> f32 {
96 self.current_sparsity
97 }
98
99 pub fn target_sparsity(&self) -> f32 {
101 self.config.target_sparsity()
102 }
103
104 pub fn parameters_pruned(&self) -> usize {
106 self.parameters_pruned
107 }
108
109 pub fn schedule(&self) -> &PruningSchedule {
111 self.config.schedule()
112 }
113
114 pub fn is_complete(&self) -> bool {
116 self.last_prune_step.is_some_and(|step| self.config.schedule().is_complete(step))
117 }
118
119 pub fn last_prune_step(&self) -> Option<usize> {
121 self.last_prune_step
122 }
123
124 pub fn set_current_sparsity(&mut self, sparsity: f32) {
126 self.current_sparsity = sparsity.clamp(0.0, 1.0);
127 }
128
129 pub fn config(&self) -> &PruningConfig {
131 &self.config
132 }
133
134 pub(crate) fn should_prune(&self, step: usize) -> bool {
136 if !self.enabled {
137 return false;
138 }
139 let target = self.config.schedule().sparsity_at_step(step);
140 target > self.current_sparsity && self.config.schedule().should_prune_at_step(step)
141 }
142
143 pub fn progress(&self) -> f32 {
145 let target = self.config.target_sparsity();
146 if target <= 0.0 {
147 return 1.0;
148 }
149 (self.current_sparsity / target).clamp(0.0, 1.0)
150 }
151}
152
153impl TrainerCallback for PruningCallback {
154 fn on_train_begin(&mut self, _ctx: &CallbackContext) -> CallbackAction {
155 if let Err(e) = self.config.schedule().validate() {
157 eprintln!("[PruningCallback] Invalid schedule configuration: {e}");
158 return CallbackAction::Stop;
159 }
160 CallbackAction::Continue
161 }
162
163 fn on_step_end(&mut self, ctx: &CallbackContext) -> CallbackAction {
164 if !self.enabled {
165 return CallbackAction::Continue;
166 }
167
168 let step = ctx.global_step;
169 let target_sparsity = self.config.schedule().sparsity_at_step(step);
170
171 if self.should_prune(step) {
173 self.current_sparsity = target_sparsity;
181 self.last_prune_step = Some(step);
182
183 }
186
187 CallbackAction::Continue
188 }
189
190 fn on_train_end(&mut self, _ctx: &CallbackContext) {
191 if self.parameters_pruned > 0 || self.current_sparsity > 0.0 {
193 eprintln!(
194 "[PruningCallback] Training complete. Final sparsity: {:.2}%, Parameters pruned: {}",
195 self.current_sparsity * 100.0,
196 self.parameters_pruned
197 );
198 }
199 }
200
201 fn name(&self) -> &'static str {
202 "PruningCallback"
203 }
204}
205
206impl Clone for PruningCallback {
207 fn clone(&self) -> Self {
208 Self {
209 config: self.config.clone(),
210 current_sparsity: self.current_sparsity,
211 parameters_pruned: self.parameters_pruned,
212 calibration: self.calibration.clone(),
213 enabled: self.enabled,
214 last_prune_step: self.last_prune_step,
215 }
216 }
217}