Skip to main content

entrenar/prune/callback/
pruning_callback.rs

1//! Pruning callback implementation
2//!
3//! This module contains the `PruningCallback` struct that integrates with
4//! Entrenar's training callback system to apply pruning during training.
5
6#![allow(clippy::field_reassign_with_default)]
7
8use crate::prune::calibrate::{CalibrationCollector, CalibrationConfig};
9use crate::prune::config::PruningConfig;
10use crate::prune::schedule::PruningSchedule;
11use crate::train::callback::{CallbackAction, CallbackContext, TrainerCallback};
12
13/// Callback for applying pruning during training.
14///
15/// Integrates with the training loop to apply pruning at scheduled steps,
16/// collect calibration data for activation-weighted methods, and log
17/// pruning metrics.
18///
19/// # Toyota Way: Kaizen (Continuous Improvement)
20/// Gradual pruning allows the model to adapt incrementally to sparsity.
21///
22/// # Example
23///
24/// ```ignore
25/// use entrenar::prune::{PruningCallback, PruningConfig, PruningSchedule};
26///
27/// let config = PruningConfig::new()
28///     .with_schedule(PruningSchedule::Gradual {
29///         start_step: 1000,
30///         end_step: 5000,
31///         initial_sparsity: 0.0,
32///         final_sparsity: 0.5,
33///         frequency: 100,
34///     })
35///     .with_target_sparsity(0.5);
36///
37/// let callback = PruningCallback::new(config);
38/// trainer.add_callback(callback);
39/// ```
40#[derive(Debug)]
41pub struct PruningCallback {
42    /// Configuration for pruning
43    config: PruningConfig,
44    /// Current achieved sparsity
45    current_sparsity: f32,
46    /// Total parameters pruned so far
47    parameters_pruned: usize,
48    /// Calibration data collector
49    pub(crate) calibration: Option<CalibrationCollector>,
50    /// Whether pruning is enabled
51    enabled: bool,
52    /// Step when last pruning occurred
53    pub(crate) last_prune_step: Option<usize>,
54}
55
56impl PruningCallback {
57    /// Create a new pruning callback with the given configuration.
58    ///
59    /// # Arguments
60    ///
61    /// * `config` - Pruning configuration
62    pub fn new(config: PruningConfig) -> Self {
63        let calibration = if config.requires_calibration() {
64            Some(CalibrationCollector::new(CalibrationConfig::default()))
65        } else {
66            None
67        };
68
69        Self {
70            config,
71            current_sparsity: 0.0,
72            parameters_pruned: 0,
73            calibration,
74            enabled: true,
75            last_prune_step: None,
76        }
77    }
78
79    /// Create a pruning callback with custom calibration configuration.
80    pub fn with_calibration(config: PruningConfig, cal_config: CalibrationConfig) -> Self {
81        Self { calibration: Some(CalibrationCollector::new(cal_config)), ..Self::new(config) }
82    }
83
84    /// Enable or disable the callback.
85    pub fn set_enabled(&mut self, enabled: bool) {
86        self.enabled = enabled;
87    }
88
89    /// Check if the callback is enabled.
90    pub fn is_enabled(&self) -> bool {
91        self.enabled
92    }
93
94    /// Get the current achieved sparsity.
95    pub fn current_sparsity(&self) -> f32 {
96        self.current_sparsity
97    }
98
99    /// Get the target sparsity from the configuration.
100    pub fn target_sparsity(&self) -> f32 {
101        self.config.target_sparsity()
102    }
103
104    /// Get the total number of parameters pruned.
105    pub fn parameters_pruned(&self) -> usize {
106        self.parameters_pruned
107    }
108
109    /// Get the pruning schedule.
110    pub fn schedule(&self) -> &PruningSchedule {
111        self.config.schedule()
112    }
113
114    /// Check if pruning is complete.
115    pub fn is_complete(&self) -> bool {
116        self.last_prune_step.is_some_and(|step| self.config.schedule().is_complete(step))
117    }
118
119    /// Get the step at which pruning last occurred.
120    pub fn last_prune_step(&self) -> Option<usize> {
121        self.last_prune_step
122    }
123
124    /// Update current sparsity (for testing or manual updates).
125    pub fn set_current_sparsity(&mut self, sparsity: f32) {
126        self.current_sparsity = sparsity.clamp(0.0, 1.0);
127    }
128
129    /// Get the configuration.
130    pub fn config(&self) -> &PruningConfig {
131        &self.config
132    }
133
134    /// Check if pruning should be applied at the given step.
135    pub(crate) fn should_prune(&self, step: usize) -> bool {
136        if !self.enabled {
137            return false;
138        }
139        let target = self.config.schedule().sparsity_at_step(step);
140        target > self.current_sparsity && self.config.schedule().should_prune_at_step(step)
141    }
142
143    /// Compute progress through the pruning schedule (0.0 to 1.0).
144    pub fn progress(&self) -> f32 {
145        let target = self.config.target_sparsity();
146        if target <= 0.0 {
147            return 1.0;
148        }
149        (self.current_sparsity / target).clamp(0.0, 1.0)
150    }
151}
152
153impl TrainerCallback for PruningCallback {
154    fn on_train_begin(&mut self, _ctx: &CallbackContext) -> CallbackAction {
155        // Validate configuration at training start
156        if let Err(e) = self.config.schedule().validate() {
157            eprintln!("[PruningCallback] Invalid schedule configuration: {e}");
158            return CallbackAction::Stop;
159        }
160        CallbackAction::Continue
161    }
162
163    fn on_step_end(&mut self, ctx: &CallbackContext) -> CallbackAction {
164        if !self.enabled {
165            return CallbackAction::Continue;
166        }
167
168        let step = ctx.global_step;
169        let target_sparsity = self.config.schedule().sparsity_at_step(step);
170
171        // Check if we should prune at this step
172        if self.should_prune(step) {
173            // In a real implementation, we would:
174            // 1. Collect calibration data if needed
175            // 2. Compute importance scores
176            // 3. Apply pruning masks
177            // 4. Update metrics
178
179            // For now, simulate pruning progress
180            self.current_sparsity = target_sparsity;
181            self.last_prune_step = Some(step);
182
183            // Log pruning event (placeholder for actual logging)
184            // In production, this would integrate with the monitoring system
185        }
186
187        CallbackAction::Continue
188    }
189
190    fn on_train_end(&mut self, _ctx: &CallbackContext) {
191        // Log final pruning summary
192        if self.parameters_pruned > 0 || self.current_sparsity > 0.0 {
193            eprintln!(
194                "[PruningCallback] Training complete. Final sparsity: {:.2}%, Parameters pruned: {}",
195                self.current_sparsity * 100.0,
196                self.parameters_pruned
197            );
198        }
199    }
200
201    fn name(&self) -> &'static str {
202        "PruningCallback"
203    }
204}
205
206impl Clone for PruningCallback {
207    fn clone(&self) -> Self {
208        Self {
209            config: self.config.clone(),
210            current_sparsity: self.current_sparsity,
211            parameters_pruned: self.parameters_pruned,
212            calibration: self.calibration.clone(),
213            enabled: self.enabled,
214            last_prune_step: self.last_prune_step,
215        }
216    }
217}