use crate::prune::schedule::PruningSchedule;
use serde::{Deserialize, Serialize};
use super::{PruneMethod, SparsityPatternConfig};
/// Configuration describing how a model should be pruned.
///
/// All fields are private. Instances are obtained from `Default`/`new`, adjusted
/// with the consuming `with_*` builder methods, or produced by deserialization
/// (the type derives `Deserialize`). Because deserialization bypasses the
/// builders, call `validate` before relying on the values.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PruningConfig {
/// Pruning method to use; also decides the result of `requires_calibration`.
method: PruneMethod,
/// Target sparsity. `validate` rejects values outside `0.0..=1.0`; defaults to 0.5.
/// NOTE(review): presumably the fraction of weights to remove — confirm with the consumer.
target_sparsity: f32,
/// Sparsity pattern; the `NM` and `Block` variants carry dimensions checked by `validate`.
pattern: SparsityPatternConfig,
/// Pruning schedule; it is validated first by `validate` via its own `validate` method.
schedule: PruningSchedule,
/// Whether fine-tuning is enabled after pruning (defaults to `true`).
fine_tune_after_pruning: bool,
/// Number of fine-tuning steps (defaults to 1000).
/// NOTE(review): presumably ignored when `fine_tune_after_pruning` is false — confirm.
fine_tune_steps: usize,
/// Fine-tuning learning rate (defaults to 1e-5). Not range-checked anywhere in this type.
fine_tune_lr: f32,
/// Whether embedding layers are skipped (defaults to `true`).
/// NOTE(review): presumably means they are excluded from pruning — confirm with the consumer.
skip_embed_layers: bool,
}
impl Default for PruningConfig {
    /// Returns the baseline configuration.
    ///
    /// The method, pattern and schedule come from their own `Default`
    /// implementations. The scalar settings are: target sparsity `0.5`,
    /// fine-tuning enabled with `1000` steps at a learning rate of `1e-5`,
    /// and embedding layers skipped.
    fn default() -> Self {
        // Resolve the project-defined components first, in field order.
        let method = PruneMethod::default();
        let pattern = SparsityPatternConfig::default();
        let schedule = PruningSchedule::default();

        Self {
            method,
            target_sparsity: 0.5,
            pattern,
            schedule,
            fine_tune_after_pruning: true,
            fine_tune_steps: 1000,
            fine_tune_lr: 1e-5,
            skip_embed_layers: true,
        }
    }
}
impl PruningConfig {
    /// Creates a configuration populated with the `Default` values.
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }

    /// Replaces the pruning method and returns the updated configuration.
    #[must_use]
    pub fn with_method(mut self, method: PruneMethod) -> Self {
        self.method = method;
        self
    }

    /// Sets the target sparsity, clamped into `0.0..=1.0`.
    ///
    /// A NaN input is left unchanged by `f32::clamp` and is therefore stored
    /// as NaN; `validate` reports it as an error.
    #[must_use]
    pub fn with_target_sparsity(mut self, sparsity: f32) -> Self {
        self.target_sparsity = sparsity.clamp(0.0, 1.0);
        self
    }

    /// Replaces the sparsity pattern and returns the updated configuration.
    ///
    /// The pattern is stored as given; its dimensions are only checked by
    /// `validate`.
    #[must_use]
    pub fn with_pattern(mut self, pattern: SparsityPatternConfig) -> Self {
        self.pattern = pattern;
        self
    }

    /// Replaces the pruning schedule and returns the updated configuration.
    ///
    /// The schedule is stored as given; it is only checked by `validate`.
    #[must_use]
    pub fn with_schedule(mut self, schedule: PruningSchedule) -> Self {
        self.schedule = schedule;
        self
    }

    /// Enables or disables fine-tuning after pruning.
    #[must_use]
    pub fn with_fine_tune(mut self, enabled: bool) -> Self {
        self.fine_tune_after_pruning = enabled;
        self
    }

    /// Sets the number of fine-tuning steps.
    #[must_use]
    pub fn with_fine_tune_steps(mut self, steps: usize) -> Self {
        self.fine_tune_steps = steps;
        self
    }

    /// Sets the fine-tuning learning rate. The value is stored without any
    /// range check.
    #[must_use]
    pub fn with_fine_tune_lr(mut self, lr: f32) -> Self {
        self.fine_tune_lr = lr;
        self
    }

    /// Sets whether embedding layers are skipped.
    #[must_use]
    pub fn with_skip_embed_layers(mut self, skip: bool) -> Self {
        self.skip_embed_layers = skip;
        self
    }

    /// Returns the configured pruning method by value.
    pub fn method(&self) -> PruneMethod {
        self.method
    }

    /// Returns the configured target sparsity.
    pub fn target_sparsity(&self) -> f32 {
        self.target_sparsity
    }

    /// Returns a reference to the configured sparsity pattern.
    pub fn pattern(&self) -> &SparsityPatternConfig {
        &self.pattern
    }

    /// Returns a reference to the configured pruning schedule.
    pub fn schedule(&self) -> &PruningSchedule {
        &self.schedule
    }

    /// Returns whether fine-tuning after pruning is enabled.
    pub fn fine_tune_after_pruning(&self) -> bool {
        self.fine_tune_after_pruning
    }

    /// Returns the configured number of fine-tuning steps.
    pub fn fine_tune_steps(&self) -> usize {
        self.fine_tune_steps
    }

    /// Returns the configured fine-tuning learning rate.
    pub fn fine_tune_lr(&self) -> f32 {
        self.fine_tune_lr
    }

    /// Returns whether embedding layers are skipped.
    pub fn skip_embed_layers(&self) -> bool {
        self.skip_embed_layers
    }

    /// Reports whether the configured method requires calibration, as
    /// answered by `PruneMethod::requires_calibration`.
    pub fn requires_calibration(&self) -> bool {
        self.method.requires_calibration()
    }

    /// Checks the configuration for internal consistency.
    ///
    /// # Errors
    ///
    /// Returns a descriptive message when, in this order:
    /// - the schedule's own `validate` fails (its error is propagated);
    /// - `target_sparsity` is NaN or lies outside `0.0..=1.0`;
    /// - the pattern is `NM` with `m == 0`, or with `n >= m`;
    /// - the pattern is `Block` with a zero height or width.
    pub fn validate(&self) -> Result<(), String> {
        self.schedule.validate()?;

        // `contains` is false for NaN, so NaN is rejected here. A pair of
        // `<` / `>` comparisons would let NaN through, and both
        // deserialization and `with_target_sparsity(f32::NAN)` can produce it.
        if !(0.0..=1.0).contains(&self.target_sparsity) {
            return Err(format!(
                "target_sparsity ({}) must be between 0.0 and 1.0",
                self.target_sparsity
            ));
        }

        if let SparsityPatternConfig::NM { n, m } = &self.pattern {
            // Test the zero group size first: with `m == 0` the `n >= m`
            // comparison below would otherwise always win and hide this
            // more specific diagnostic.
            if *m == 0 {
                return Err("M cannot be 0".to_string());
            }
            if *n >= *m {
                return Err(format!("N ({n}) must be less than M ({m})"));
            }
        }

        if let SparsityPatternConfig::Block { height, width } = &self.pattern {
            if *height == 0 || *width == 0 {
                return Err("Block dimensions must be non-zero".to_string());
            }
        }

        Ok(())
    }
}