use crate::prune::calibrate::CalibrationCollector;
use crate::prune::data_loader::CalibrationDataLoader;
use crate::prune::pipeline::{PruneFinetunePipeline, PruningMetrics, PruningStage};
use super::config::PruneTrainerConfig;
/// Drives a [`PruneFinetunePipeline`] through its stages (calibration, pruning,
/// fine-tuning, evaluation, export) using the settings in a [`PruneTrainerConfig`].
#[derive(Debug)]
pub struct PruneTrainer {
    /// Trainer configuration; its `pruning` and `calibration` sub-configs seed the
    /// pipeline and data loader.
    config: PruneTrainerConfig,
    /// Stage machine that tracks progress and collects metrics.
    pipeline: PruneFinetunePipeline,
    /// Source of calibration batches, built from `config.calibration`.
    data_loader: CalibrationDataLoader,
    /// Calibration collector prepared by `initialize`/`calibrate`. It is taken (left `None`)
    /// when it is handed to the pipeline at the start of calibration.
    pub(crate) calibration: Option<CalibrationCollector>,
    /// Index of the most recently run fine-tuning epoch (0 before any fine-tuning).
    current_epoch: usize,
}
impl PruneTrainer {
pub fn new(config: PruneTrainerConfig) -> Self {
let pipeline = PruneFinetunePipeline::new(config.pruning.clone());
let data_loader = CalibrationDataLoader::new(config.calibration.clone());
Self { config, pipeline, data_loader, calibration: None, current_epoch: 0 }
}
pub fn config(&self) -> &PruneTrainerConfig {
&self.config
}
pub fn pipeline(&self) -> &PruneFinetunePipeline {
&self.pipeline
}
pub fn pipeline_mut(&mut self) -> &mut PruneFinetunePipeline {
&mut self.pipeline
}
pub fn stage(&self) -> PruningStage {
self.pipeline.stage()
}
pub fn metrics(&self) -> &PruningMetrics {
self.pipeline.metrics()
}
pub fn current_epoch(&self) -> usize {
self.current_epoch
}
pub fn is_complete(&self) -> bool {
self.pipeline.is_complete()
}
pub fn succeeded(&self) -> bool {
self.pipeline.succeeded()
}
pub fn error(&self) -> Option<&str> {
self.pipeline.error()
}
pub fn initialize(&mut self) -> Result<(), String> {
self.config.validate()?;
self.data_loader.load()?;
if self.config.pruning.requires_calibration() {
self.calibration = Some(CalibrationCollector::new(
crate::prune::calibrate::CalibrationConfig::new()
.with_num_samples(self.config.calibration.num_samples()),
));
}
Ok(())
}
pub fn calibrate(&mut self) -> Result<(), String> {
if self.pipeline.stage() != PruningStage::Idle
&& self.pipeline.stage() != PruningStage::Calibrating
{
return Err("Cannot calibrate in current stage".to_string());
}
if self.calibration.is_none() && self.config.pruning.requires_calibration() {
self.calibration = Some(CalibrationCollector::new(
crate::prune::calibrate::CalibrationConfig::new()
.with_num_samples(self.config.calibration.num_samples()),
));
}
if self.pipeline.stage() == PruningStage::Idle {
if let Some(cal) = self.calibration.take() {
self.pipeline.start_calibration(cal);
} else {
self.pipeline.advance();
}
}
for _batch in &self.data_loader {
}
if self.pipeline.stage() == PruningStage::Calibrating {
self.pipeline.advance();
}
Ok(())
}
pub fn prune(&mut self) -> Result<(), String> {
while self.pipeline.stage() == PruningStage::ComputingImportance {
self.pipeline.advance();
}
if self.pipeline.stage() != PruningStage::Pruning {
return Err(format!("Cannot prune in stage {:?}", self.pipeline.stage()));
}
let target_sparsity = self.config.pruning.target_sparsity();
self.pipeline.metrics_mut().target_sparsity = target_sparsity;
self.pipeline.metrics_mut().achieved_sparsity = target_sparsity;
self.pipeline.advance();
Ok(())
}
pub fn finetune(&mut self) -> Result<(), String> {
if self.pipeline.stage() != PruningStage::FineTuning {
return Err(format!("Cannot finetune in stage {:?}", self.pipeline.stage()));
}
for epoch in 0..self.config.finetune_epochs {
self.current_epoch = epoch;
let loss = 1.0 / (epoch + 1) as f32;
self.pipeline.metrics_mut().record_finetune_loss(loss);
}
self.pipeline.advance();
Ok(())
}
pub fn evaluate(&mut self) -> Result<(), String> {
if self.pipeline.stage() != PruningStage::Evaluating {
return Err(format!("Cannot evaluate in stage {:?}", self.pipeline.stage()));
}
self.pipeline.advance();
Ok(())
}
pub fn export(&mut self) -> Result<(), String> {
if self.pipeline.stage() != PruningStage::Exporting {
return Err(format!("Cannot export in stage {:?}", self.pipeline.stage()));
}
self.pipeline.advance();
Ok(())
}
pub fn run(&mut self) -> Result<PruningMetrics, String> {
self.initialize()?;
self.calibrate()?;
self.prune()?;
if self.config.pruning.fine_tune_after_pruning() {
self.finetune()?;
}
if self.config.evaluate_pre_post {
self.evaluate()?;
}
self.export()?;
Ok(self.metrics().clone())
}
pub fn reset(&mut self) {
self.pipeline.reset();
self.calibration = None;
self.current_epoch = 0;
self.data_loader.reset();
}
}
impl Clone for PruneTrainer {
    /// Produces an independent copy of the trainer: configuration, pipeline state, data
    /// loader, any pending calibration collector, and the epoch counter.
    fn clone(&self) -> Self {
        // Exhaustive destructuring: adding a field to `PruneTrainer` becomes a compile
        // error here instead of silently being left out of the copy.
        let Self { config, pipeline, data_loader, calibration, current_epoch } = self;
        Self {
            config: config.clone(),
            pipeline: pipeline.clone(),
            data_loader: data_loader.clone(),
            calibration: calibration.clone(),
            current_epoch: *current_epoch,
        }
    }
}