Skip to main content

entrenar/prune/config/
pruning_config.rs

1//! Main pruning configuration struct.
2
3use crate::prune::schedule::PruningSchedule;
4use serde::{Deserialize, Serialize};
5
6use super::{PruneMethod, SparsityPatternConfig};
7
8/// Configuration for pruning operations.
9///
10/// # Example
11///
12/// ```
13/// use entrenar::prune::{PruningConfig, PruningSchedule, PruneMethod};
14///
15/// let config = PruningConfig::default()
16///     .with_method(PruneMethod::Wanda)
17///     .with_schedule(PruningSchedule::Gradual {
18///         start_step: 1000,
19///         end_step: 5000,
20///         initial_sparsity: 0.0,
21///         final_sparsity: 0.5,
22///         frequency: 100,
23///     })
24///     .with_target_sparsity(0.5);
25/// ```
26#[derive(Debug, Clone, Serialize, Deserialize)]
27pub struct PruningConfig {
28    /// Pruning method to use.
29    method: PruneMethod,
30
31    /// Target sparsity (0.0 to 1.0).
32    target_sparsity: f32,
33
34    /// Sparsity pattern.
35    pattern: SparsityPatternConfig,
36
37    /// Pruning schedule.
38    schedule: PruningSchedule,
39
40    /// Whether to fine-tune after pruning.
41    fine_tune_after_pruning: bool,
42
43    /// Number of fine-tuning steps.
44    fine_tune_steps: usize,
45
46    /// Learning rate for fine-tuning.
47    fine_tune_lr: f32,
48
49    /// Whether to skip first and last layers (recommended for LLMs).
50    skip_embed_layers: bool,
51}
52
53impl Default for PruningConfig {
54    fn default() -> Self {
55        Self {
56            method: PruneMethod::default(),
57            target_sparsity: 0.5,
58            pattern: SparsityPatternConfig::default(),
59            schedule: PruningSchedule::default(),
60            fine_tune_after_pruning: true,
61            fine_tune_steps: 1000,
62            fine_tune_lr: 1e-5,
63            skip_embed_layers: true,
64        }
65    }
66}
67
68impl PruningConfig {
69    /// Create a new configuration with default values.
70    pub fn new() -> Self {
71        Self::default()
72    }
73
74    /// Set the pruning method.
75    pub fn with_method(mut self, method: PruneMethod) -> Self {
76        self.method = method;
77        self
78    }
79
80    /// Set the target sparsity.
81    pub fn with_target_sparsity(mut self, sparsity: f32) -> Self {
82        self.target_sparsity = sparsity.clamp(0.0, 1.0);
83        self
84    }
85
86    /// Set the sparsity pattern.
87    pub fn with_pattern(mut self, pattern: SparsityPatternConfig) -> Self {
88        self.pattern = pattern;
89        self
90    }
91
92    /// Set the pruning schedule.
93    pub fn with_schedule(mut self, schedule: PruningSchedule) -> Self {
94        self.schedule = schedule;
95        self
96    }
97
98    /// Enable or disable fine-tuning after pruning.
99    pub fn with_fine_tune(mut self, enabled: bool) -> Self {
100        self.fine_tune_after_pruning = enabled;
101        self
102    }
103
104    /// Set the number of fine-tuning steps.
105    pub fn with_fine_tune_steps(mut self, steps: usize) -> Self {
106        self.fine_tune_steps = steps;
107        self
108    }
109
110    /// Set the fine-tuning learning rate.
111    pub fn with_fine_tune_lr(mut self, lr: f32) -> Self {
112        self.fine_tune_lr = lr;
113        self
114    }
115
116    /// Enable or disable skipping embedding layers.
117    pub fn with_skip_embed_layers(mut self, skip: bool) -> Self {
118        self.skip_embed_layers = skip;
119        self
120    }
121
122    /// Get the pruning method.
123    pub fn method(&self) -> PruneMethod {
124        self.method
125    }
126
127    /// Get the target sparsity.
128    pub fn target_sparsity(&self) -> f32 {
129        self.target_sparsity
130    }
131
132    /// Get the sparsity pattern.
133    pub fn pattern(&self) -> &SparsityPatternConfig {
134        &self.pattern
135    }
136
137    /// Get the pruning schedule.
138    pub fn schedule(&self) -> &PruningSchedule {
139        &self.schedule
140    }
141
142    /// Check if fine-tuning is enabled.
143    pub fn fine_tune_after_pruning(&self) -> bool {
144        self.fine_tune_after_pruning
145    }
146
147    /// Get fine-tuning steps.
148    pub fn fine_tune_steps(&self) -> usize {
149        self.fine_tune_steps
150    }
151
152    /// Get fine-tuning learning rate.
153    pub fn fine_tune_lr(&self) -> f32 {
154        self.fine_tune_lr
155    }
156
157    /// Check if embedding layers should be skipped.
158    pub fn skip_embed_layers(&self) -> bool {
159        self.skip_embed_layers
160    }
161
162    /// Check if this configuration requires calibration data.
163    pub fn requires_calibration(&self) -> bool {
164        self.method.requires_calibration()
165    }
166
167    /// Validate the configuration.
168    pub fn validate(&self) -> Result<(), String> {
169        // Validate schedule
170        self.schedule.validate()?;
171
172        // Validate target sparsity
173        if self.target_sparsity < 0.0 || self.target_sparsity > 1.0 {
174            return Err(format!(
175                "target_sparsity ({}) must be between 0.0 and 1.0",
176                self.target_sparsity
177            ));
178        }
179
180        // Validate N:M pattern
181        if let SparsityPatternConfig::NM { n, m } = &self.pattern {
182            if *n >= *m {
183                return Err(format!("N ({n}) must be less than M ({m})"));
184            }
185            if *m == 0 {
186                return Err("M cannot be 0".to_string());
187            }
188        }
189
190        // Validate block pattern
191        if let SparsityPatternConfig::Block { height, width } = &self.pattern {
192            if *height == 0 || *width == 0 {
193                return Err("Block dimensions must be non-zero".to_string());
194            }
195        }
196
197        Ok(())
198    }
199}