#![allow(clippy::field_reassign_with_default)]
use crate::prune::calibrate::{CalibrationCollector, CalibrationConfig};
use crate::prune::config::PruningConfig;
use crate::prune::schedule::PruningSchedule;
use crate::train::callback::{CallbackAction, CallbackContext, TrainerCallback};
/// Trainer callback that tracks how far pruning has progressed along the
/// schedule held in its [`PruningConfig`].
///
/// The callback only does bookkeeping in the code visible here: at the end of
/// each step it may adopt the schedule's sparsity for that step and remember
/// the step at which it did so.
#[derive(Debug)]
pub struct PruningCallback {
/// Pruning configuration; source of the schedule and of the target sparsity.
config: PruningConfig,
/// Sparsity most recently adopted. Starts at `0.0`; updated from the schedule
/// in `on_step_end` or explicitly through `set_current_sparsity` (which clamps
/// to `[0.0, 1.0]`).
current_sparsity: f32,
/// Number of parameters pruned, as reported by `parameters_pruned()`.
/// Initialised to `0`.
// NOTE(review): nothing in this part of the file ever updates this counter —
// confirm whether another module is expected to maintain it.
parameters_pruned: usize,
/// Optional calibration collector. Present when the config reports that it
/// requires calibration, or when the callback is built via `with_calibration`.
pub(crate) calibration: Option<CalibrationCollector>,
/// When `false`, `should_prune` always answers `false` and `on_step_end`
/// leaves the state untouched.
enabled: bool,
/// Global step at which `on_step_end` last raised `current_sparsity`;
/// `None` until that has happened at least once.
pub(crate) last_prune_step: Option<usize>,
}
impl PruningCallback {
pub fn new(config: PruningConfig) -> Self {
let calibration = if config.requires_calibration() {
Some(CalibrationCollector::new(CalibrationConfig::default()))
} else {
None
};
Self {
config,
current_sparsity: 0.0,
parameters_pruned: 0,
calibration,
enabled: true,
last_prune_step: None,
}
}
pub fn with_calibration(config: PruningConfig, cal_config: CalibrationConfig) -> Self {
Self { calibration: Some(CalibrationCollector::new(cal_config)), ..Self::new(config) }
}
pub fn set_enabled(&mut self, enabled: bool) {
self.enabled = enabled;
}
pub fn is_enabled(&self) -> bool {
self.enabled
}
pub fn current_sparsity(&self) -> f32 {
self.current_sparsity
}
pub fn target_sparsity(&self) -> f32 {
self.config.target_sparsity()
}
pub fn parameters_pruned(&self) -> usize {
self.parameters_pruned
}
pub fn schedule(&self) -> &PruningSchedule {
self.config.schedule()
}
pub fn is_complete(&self) -> bool {
self.last_prune_step.is_some_and(|step| self.config.schedule().is_complete(step))
}
pub fn last_prune_step(&self) -> Option<usize> {
self.last_prune_step
}
pub fn set_current_sparsity(&mut self, sparsity: f32) {
self.current_sparsity = sparsity.clamp(0.0, 1.0);
}
pub fn config(&self) -> &PruningConfig {
&self.config
}
pub(crate) fn should_prune(&self, step: usize) -> bool {
if !self.enabled {
return false;
}
let target = self.config.schedule().sparsity_at_step(step);
target > self.current_sparsity && self.config.schedule().should_prune_at_step(step)
}
pub fn progress(&self) -> f32 {
let target = self.config.target_sparsity();
if target <= 0.0 {
return 1.0;
}
(self.current_sparsity / target).clamp(0.0, 1.0)
}
}
impl TrainerCallback for PruningCallback {
    /// Validates the pruning schedule before training starts.
    ///
    /// An invalid schedule is reported on stderr and stops training; otherwise
    /// training continues.
    fn on_train_begin(&mut self, _ctx: &CallbackContext) -> CallbackAction {
        match self.config.schedule().validate() {
            Ok(_) => CallbackAction::Continue,
            Err(e) => {
                eprintln!("[PruningCallback] Invalid schedule configuration: {e}");
                CallbackAction::Stop
            }
        }
    }

    /// After each step, adopts the schedule's sparsity for the current global
    /// step when a prune is due, and records that step. Always continues.
    fn on_step_end(&mut self, ctx: &CallbackContext) -> CallbackAction {
        let step = ctx.global_step;
        // `should_prune` already answers `false` for a disabled callback, so a
        // disabled callback falls straight through without touching any state.
        if self.should_prune(step) {
            self.current_sparsity = self.config.schedule().sparsity_at_step(step);
            self.last_prune_step = Some(step);
        }
        CallbackAction::Continue
    }

    /// Prints a one-line summary to stderr when any sparsity or pruned
    /// parameter count was recorded; stays silent otherwise.
    fn on_train_end(&mut self, _ctx: &CallbackContext) {
        let pruned_anything = self.parameters_pruned > 0 || self.current_sparsity > 0.0;
        if !pruned_anything {
            return;
        }
        let sparsity_percent = self.current_sparsity * 100.0;
        eprintln!(
            "[PruningCallback] Training complete. Final sparsity: {:.2}%, Parameters pruned: {}",
            sparsity_percent, self.parameters_pruned
        );
    }

    /// Returns the fixed name of this callback.
    fn name(&self) -> &'static str {
        "PruningCallback"
    }
}
impl Clone for PruningCallback {
    /// Produces an independent copy of the callback, cloning the configuration
    /// and the optional calibration collector and copying the scalar state.
    fn clone(&self) -> Self {
        // Exhaustive destructuring: adding a field to the struct without
        // handling it here becomes a compile error.
        let Self {
            config,
            current_sparsity,
            parameters_pruned,
            calibration,
            enabled,
            last_prune_step,
        } = self;
        Self {
            config: config.clone(),
            current_sparsity: *current_sparsity,
            parameters_pruned: *parameters_pruned,
            calibration: calibration.clone(),
            enabled: *enabled,
            last_prune_step: *last_prune_step,
        }
    }
}