use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use trustformers_core::errors::TrustformersError;
/// Configuration for progressive (growing) model training.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ProgressiveConfig {
    /// Which model axis is grown over training.
    pub growth_dimension: GrowthDimension,
    /// Shape of the schedule interpolating between initial and final size.
    pub growth_strategy: GrowthStrategy,
    /// Size at the start of training (units depend on `growth_dimension`).
    pub initial_size: usize,
    /// Target size once all growth steps have completed.
    pub final_size: usize,
    /// Epochs at which a growth step may occur.
    pub growth_epochs: Vec<usize>,
    /// Optimizer steps treated as warmup after each growth event.
    pub warmup_steps: usize,
    /// Zero-initialize newly added parameters
    /// (only consulted when `gradual_initialization` is false).
    pub zero_init_new_params: bool,
    /// Learning-rate scaling after growth.
    /// NOTE(review): not read anywhere in this file — confirm the consumer.
    pub lr_scaling_factor: f64,
    /// Scale new parameters gradually instead of zero-initializing them.
    pub gradual_initialization: bool,
    /// Smoothing factor used when gradually scaling new parameters.
    pub transition_smoothing: f64,
    /// Freeze pre-growth parameters while warmup steps remain.
    pub freeze_old_params_during_warmup: bool,
}
impl Default for ProgressiveConfig {
    /// Default plan: grow a 6-unit model to 12 units linearly along the layer
    /// dimension, with milestones at epochs 10/20/30/40 and a 1000-step
    /// warmup after each growth event.
    fn default() -> Self {
        Self {
            initial_size: 6,
            final_size: 12,
            growth_dimension: GrowthDimension::Layers,
            growth_strategy: GrowthStrategy::Linear,
            growth_epochs: vec![10, 20, 30, 40],
            warmup_steps: 1000,
            lr_scaling_factor: 0.5,
            zero_init_new_params: true,
            gradual_initialization: true,
            transition_smoothing: 0.1,
            freeze_old_params_during_warmup: false,
        }
    }
}
/// Which axis of the model is expanded during progressive growth.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum GrowthDimension {
    /// Add layers (depth growth).
    Layers,
    /// Widen the hidden dimension.
    HiddenDim,
    /// Increase the number of attention heads.
    AttentionHeads,
    /// Widen the feed-forward intermediate dimension.
    IntermediateDim,
    /// Expand the vocabulary/embedding table.
    VocabSize,
    /// Combine depth and width growth in a single step.
    MultiDimensional,
}
/// Shape of the size-vs-epoch growth schedule.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum GrowthStrategy {
    /// Even growth across the configured growth epochs.
    Linear,
    /// Quadratic ease-in: growth accelerates toward later epochs.
    Exponential,
    /// Logarithmic ease-out: most growth happens early.
    Logarithmic,
    /// Linear baseline plus plateau-triggered growth between milestones.
    Adaptive,
    /// No built-in schedule; points are supplied externally.
    Custom,
    /// Fixed, equal-sized growth stages.
    Staged,
}
/// Precomputed growth plan produced from a `ProgressiveConfig`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GrowthSchedule {
    /// Epoch -> target size along the growth dimension.
    pub growth_points: HashMap<usize, usize>,
    /// When true, plateau detection may also trigger growth between points.
    pub adaptive: bool,
    /// Minimum epochs between growth events.
    /// NOTE(review): not enforced anywhere in this file — confirm intended use.
    pub min_growth_interval: usize,
    /// Upper bound on growth per event.
    /// NOTE(review): not enforced anywhere in this file — confirm intended use.
    pub max_growth_per_step: usize,
}
/// Drives progressive growth over the course of training: tracks the current
/// model size, decides when growth should happen, applies it to a
/// `ProgressiveModel`, and manages the post-growth warmup/freeze window.
pub struct ProgressiveTrainer {
    /// Growth configuration supplied at construction.
    config: ProgressiveConfig,
    /// Current size along `config.growth_dimension`.
    current_size: usize,
    /// Last epoch reported via `set_epoch`.
    current_epoch: usize,
    /// Optimizer steps seen via `step`.
    current_step: usize,
    /// Epoch -> target-size plan derived from the config.
    growth_schedule: GrowthSchedule,
    /// Log of completed growth events.
    growth_history: Vec<GrowthEvent>,
    /// Remaining warmup steps after the most recent growth (0 = not warming up).
    warmup_remaining: usize,
    /// Names of parameters frozen during the current warmup window.
    frozen_parameters: HashSet<String>,
    /// Loss tracker used for plateau-based (adaptive) growth triggers.
    learning_progress: LearningProgress,
}
use std::collections::HashSet;
impl ProgressiveTrainer {
    /// Creates a trainer starting at `config.initial_size` with a growth
    /// schedule derived from `config.growth_strategy`.
    ///
    /// # Errors
    /// Propagates any error from schedule construction.
    pub fn new(config: ProgressiveConfig) -> Result<Self, TrustformersError> {
        let growth_schedule = Self::create_growth_schedule(&config)?;
        Ok(Self {
            current_size: config.initial_size,
            current_epoch: 0,
            current_step: 0,
            growth_schedule,
            growth_history: Vec::new(),
            warmup_remaining: 0,
            frozen_parameters: HashSet::new(),
            learning_progress: LearningProgress::new(),
            config,
        })
    }

    /// Precomputes the epoch -> target-size map for the configured strategy.
    ///
    /// All interpolated sizes are clamped to `final_size`, and size deltas use
    /// `saturating_sub` so a misconfigured `final_size < initial_size` yields
    /// an empty growth range instead of a usize underflow panic.
    fn create_growth_schedule(
        config: &ProgressiveConfig,
    ) -> Result<GrowthSchedule, TrustformersError> {
        let mut growth_points = HashMap::new();
        let total_growth = config.final_size.saturating_sub(config.initial_size);
        let num_steps = config.growth_epochs.len();
        match config.growth_strategy {
            GrowthStrategy::Linear => {
                // Multiply before dividing so growth is distributed
                // proportionally and the last milestone lands exactly on
                // `final_size` (per-step integer division could leave an
                // unreachable remainder).
                for (i, &epoch) in config.growth_epochs.iter().enumerate() {
                    let new_size =
                        config.initial_size + (i + 1) * total_growth / num_steps.max(1);
                    growth_points.insert(epoch, new_size.min(config.final_size));
                }
            },
            GrowthStrategy::Exponential => {
                // Quadratic ease-in: growth accelerates toward later epochs.
                for (i, &epoch) in config.growth_epochs.iter().enumerate() {
                    let progress = (i + 1) as f64 / num_steps as f64;
                    let exp_progress = progress.powf(2.0);
                    let new_size =
                        config.initial_size + (total_growth as f64 * exp_progress) as usize;
                    growth_points.insert(epoch, new_size.min(config.final_size));
                }
            },
            GrowthStrategy::Logarithmic => {
                // log2(1 + progress) ease-out: most growth happens early.
                for (i, &epoch) in config.growth_epochs.iter().enumerate() {
                    let progress = (i + 1) as f64 / num_steps as f64;
                    let log_progress = (1.0 + progress).ln() / (2.0_f64).ln();
                    let new_size =
                        config.initial_size + (total_growth as f64 * log_progress) as usize;
                    growth_points.insert(epoch, new_size.min(config.final_size));
                }
            },
            GrowthStrategy::Adaptive => {
                // Seed with a linear baseline; `should_grow` may additionally
                // trigger growth when the loss plateaus.
                for (i, &epoch) in config.growth_epochs.iter().enumerate() {
                    let progress = (i + 1) as f64 / num_steps as f64;
                    let new_size =
                        config.initial_size + (total_growth as f64 * progress) as usize;
                    growth_points.insert(epoch, new_size.min(config.final_size));
                }
            },
            GrowthStrategy::Staged => {
                // Fixed, equal-sized stages; any remainder is clamped away.
                let stage_size = total_growth / num_steps.max(1);
                for (i, &epoch) in config.growth_epochs.iter().enumerate() {
                    let new_size = config.initial_size + (i + 1) * stage_size;
                    growth_points.insert(epoch, new_size.min(config.final_size));
                }
            },
            GrowthStrategy::Custom => {
                // No built-in points; callers supply them via
                // `update_growth_schedule`.
            },
        }
        Ok(GrowthSchedule {
            growth_points,
            adaptive: matches!(config.growth_strategy, GrowthStrategy::Adaptive),
            min_growth_interval: 5,
            max_growth_per_step: total_growth / 2,
        })
    }

    /// Returns true when the model should grow at `epoch`: either a scheduled
    /// growth point targets a larger size, or (for adaptive schedules) the
    /// loss has plateaued. Growth is suppressed while warmup is active.
    pub fn should_grow(&self, epoch: usize) -> bool {
        if self.warmup_remaining > 0 {
            return false;
        }
        if let Some(&target_size) = self.growth_schedule.growth_points.get(&epoch) {
            return target_size > self.current_size;
        }
        if self.growth_schedule.adaptive {
            return self.learning_progress.should_trigger_growth(epoch);
        }
        false
    }

    /// Grows `model` along the configured dimension to the size scheduled for
    /// `epoch` (or to an adaptively chosen size when the epoch has no entry).
    ///
    /// On success records a `GrowthEvent`, starts a warmup window, and
    /// optionally freezes pre-growth parameters for the warmup's duration.
    ///
    /// # Errors
    /// Propagates any error raised by the model's growth hooks.
    pub fn grow_model(
        &mut self,
        model: &mut dyn ProgressiveModel,
        epoch: usize,
    ) -> Result<GrowthResult, TrustformersError> {
        let target_size = self
            .growth_schedule
            .growth_points
            .get(&epoch)
            .copied()
            .unwrap_or_else(|| self.determine_adaptive_growth_size(epoch));
        if target_size <= self.current_size {
            return Ok(GrowthResult::NoGrowthNeeded);
        }
        // Capture the pre-growth size up front: the previous implementation
        // read `self.current_size` after it had already been updated to
        // `target_size`, so `GrowthResult::Grown.old_size` always equaled
        // `new_size`.
        let old_size = self.current_size;
        let growth_amount = target_size - old_size;
        let start_time = std::time::Instant::now();
        let growth_info = match self.config.growth_dimension {
            GrowthDimension::Layers => self.grow_layers(model, growth_amount)?,
            GrowthDimension::HiddenDim => self.grow_hidden_dimension(model, target_size)?,
            GrowthDimension::AttentionHeads => self.grow_attention_heads(model, target_size)?,
            GrowthDimension::IntermediateDim => {
                self.grow_intermediate_dimension(model, target_size)?
            },
            GrowthDimension::VocabSize => self.grow_vocabulary(model, target_size)?,
            GrowthDimension::MultiDimensional => self.grow_multi_dimensional(model, target_size)?,
        };
        let growth_event = GrowthEvent {
            epoch,
            old_size,
            new_size: target_size,
            growth_dimension: self.config.growth_dimension,
            growth_time: start_time.elapsed(),
            growth_info: growth_info.clone(),
        };
        self.growth_history.push(growth_event);
        self.current_size = target_size;
        self.warmup_remaining = self.config.warmup_steps;
        if self.config.freeze_old_params_during_warmup {
            self.freeze_old_parameters(model)?;
        }
        Ok(GrowthResult::Grown {
            old_size,
            new_size: target_size,
            growth_info,
        })
    }

    /// Appends `num_layers` new layers after the current ones and initializes
    /// them per the configured policy. `gradual_initialization` takes
    /// precedence over `zero_init_new_params` when both are set.
    fn grow_layers(
        &mut self,
        model: &mut dyn ProgressiveModel,
        num_layers: usize,
    ) -> Result<GrowthInfo, TrustformersError> {
        let mut added_parameters = 0;
        let mut initialization_method = String::new();
        for i in 0..num_layers {
            // New layer indices continue from the current depth.
            let layer_params = model.add_layer(self.current_size + i)?;
            added_parameters += layer_params;
            if self.config.gradual_initialization {
                // Later layers receive progressively larger scales, smoothing
                // the transition from the old to the grown architecture.
                let scale =
                    self.config.transition_smoothing * (i + 1) as f64 / num_layers as f64;
                model.scale_layer_parameters(self.current_size + i, scale)?;
                initialization_method = format!("Gradual scaling (factor: {})", scale);
            } else if self.config.zero_init_new_params {
                model.zero_initialize_layer(self.current_size + i)?;
                initialization_method = "Zero initialization".to_string();
            }
        }
        Ok(GrowthInfo {
            added_parameters,
            initialization_method,
            growth_type: "Layer addition".to_string(),
        })
    }

    /// Expands the hidden dimension to `target_dim`, optionally smoothing the
    /// newly added dimensions.
    fn grow_hidden_dimension(
        &mut self,
        model: &mut dyn ProgressiveModel,
        target_dim: usize,
    ) -> Result<GrowthInfo, TrustformersError> {
        let old_dim = model.get_hidden_dimension()?;
        let added_parameters = model.expand_hidden_dimension(target_dim)?;
        if self.config.gradual_initialization {
            model.initialize_expanded_dimensions(
                old_dim,
                target_dim,
                self.config.transition_smoothing,
            )?;
        }
        Ok(GrowthInfo {
            added_parameters,
            initialization_method: "Hidden dimension expansion".to_string(),
            growth_type: format!("Hidden dim: {} -> {}", old_dim, target_dim),
        })
    }

    /// Expands the attention-head count to `target_heads`.
    fn grow_attention_heads(
        &mut self,
        model: &mut dyn ProgressiveModel,
        target_heads: usize,
    ) -> Result<GrowthInfo, TrustformersError> {
        let old_heads = model.get_num_attention_heads()?;
        let added_parameters = model.expand_attention_heads(target_heads)?;
        Ok(GrowthInfo {
            added_parameters,
            initialization_method: "Attention head expansion".to_string(),
            growth_type: format!("Attention heads: {} -> {}", old_heads, target_heads),
        })
    }

    /// Expands the feed-forward intermediate dimension to `target_dim`.
    fn grow_intermediate_dimension(
        &mut self,
        model: &mut dyn ProgressiveModel,
        target_dim: usize,
    ) -> Result<GrowthInfo, TrustformersError> {
        let old_dim = model.get_intermediate_dimension()?;
        let added_parameters = model.expand_intermediate_dimension(target_dim)?;
        Ok(GrowthInfo {
            added_parameters,
            initialization_method: "Intermediate dimension expansion".to_string(),
            growth_type: format!("Intermediate dim: {} -> {}", old_dim, target_dim),
        })
    }

    /// Expands the vocabulary to `target_vocab` entries.
    fn grow_vocabulary(
        &mut self,
        model: &mut dyn ProgressiveModel,
        target_vocab: usize,
    ) -> Result<GrowthInfo, TrustformersError> {
        let old_vocab = model.get_vocab_size()?;
        let added_parameters = model.expand_vocabulary(target_vocab)?;
        Ok(GrowthInfo {
            added_parameters,
            initialization_method: "Vocabulary expansion".to_string(),
            growth_type: format!("Vocab size: {} -> {}", old_vocab, target_vocab),
        })
    }

    /// Combined depth + width growth heuristic: adds one layer while below
    /// half the final size, and widens the hidden dimension by 64 while it is
    /// below 1024 (fixed heuristic thresholds).
    fn grow_multi_dimensional(
        &mut self,
        model: &mut dyn ProgressiveModel,
        _target_size: usize,
    ) -> Result<GrowthInfo, TrustformersError> {
        let mut total_added_parameters = 0;
        if self.current_size < self.config.final_size / 2 {
            let layer_growth = self.grow_layers(model, 1)?;
            total_added_parameters += layer_growth.added_parameters;
        }
        let current_hidden = model.get_hidden_dimension()?;
        if current_hidden < 1024 {
            let width_growth = self.grow_hidden_dimension(model, current_hidden + 64)?;
            total_added_parameters += width_growth.added_parameters;
        }
        Ok(GrowthInfo {
            added_parameters: total_added_parameters,
            initialization_method: "Multi-dimensional growth".to_string(),
            growth_type: "Combined layer and width growth".to_string(),
        })
    }

    /// Chooses a target size when growth is triggered adaptively: a 20% jump
    /// on plateau, otherwise a single increment. The result is clamped so an
    /// adaptive trigger can never grow the model past `final_size`.
    fn determine_adaptive_growth_size(&self, _epoch: usize) -> usize {
        let proposed = if self.learning_progress.is_plateau() {
            (self.current_size as f64 * 1.2) as usize
        } else {
            self.current_size + 1
        };
        proposed.min(self.config.final_size)
    }

    /// Freezes every parameter name the model currently reports; the frozen
    /// set is remembered so warmup completion can unfreeze exactly that set.
    fn freeze_old_parameters(
        &mut self,
        model: &mut dyn ProgressiveModel,
    ) -> Result<(), TrustformersError> {
        let old_param_names = model.get_parameter_names()?;
        for name in old_param_names {
            self.frozen_parameters.insert(name);
        }
        model.freeze_parameters(&self.frozen_parameters)?;
        Ok(())
    }

    /// Unfreezes the previously frozen parameter set and clears it.
    fn unfreeze_parameters(
        &mut self,
        model: &mut dyn ProgressiveModel,
    ) -> Result<(), TrustformersError> {
        model.unfreeze_parameters(&self.frozen_parameters)?;
        self.frozen_parameters.clear();
        Ok(())
    }

    /// Per-optimizer-step bookkeeping: records the loss, counts down any
    /// active warmup, and unfreezes parameters when warmup completes.
    pub fn step(
        &mut self,
        model: &mut dyn ProgressiveModel,
        loss: f64,
    ) -> Result<(), TrustformersError> {
        self.current_step += 1;
        self.learning_progress.update(loss);
        if self.warmup_remaining > 0 {
            self.warmup_remaining -= 1;
            if self.warmup_remaining == 0 && !self.frozen_parameters.is_empty() {
                self.unfreeze_parameters(model)?;
            }
        }
        Ok(())
    }

    /// Records the current epoch and notifies the progress tracker.
    pub fn set_epoch(&mut self, epoch: usize) {
        self.current_epoch = epoch;
        self.learning_progress.new_epoch();
    }

    /// Current size along the configured growth dimension.
    pub fn current_size(&self) -> usize {
        self.current_size
    }

    /// All growth events recorded so far, in order.
    pub fn growth_history(&self) -> &[GrowthEvent] {
        &self.growth_history
    }

    /// True while post-growth warmup steps remain.
    pub fn is_in_warmup(&self) -> bool {
        self.warmup_remaining > 0
    }

    /// Read-only access to the loss/plateau tracker.
    pub fn learning_progress(&self) -> &LearningProgress {
        &self.learning_progress
    }

    /// Merges externally supplied growth points into the schedule
    /// (overwriting any existing entries for the same epochs). Intended for
    /// `GrowthStrategy::Custom`, but usable with any strategy.
    pub fn update_growth_schedule(&mut self, new_points: HashMap<usize, usize>) {
        self.growth_schedule.growth_points.extend(new_points);
    }
}
/// Details of a single growth operation, embedded in `GrowthEvent` and
/// `GrowthResult::Grown`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GrowthInfo {
    /// Number of parameters added, as reported by the model's growth hooks.
    pub added_parameters: usize,
    /// Human-readable description of how the new parameters were initialized.
    pub initialization_method: String,
    /// Human-readable description of what kind of growth was performed.
    pub growth_type: String,
}
/// Outcome of a `ProgressiveTrainer::grow_model` call.
#[derive(Debug)]
pub enum GrowthResult {
    /// The model was grown from `old_size` to `new_size`.
    Grown {
        old_size: usize,
        new_size: usize,
        growth_info: GrowthInfo,
    },
    /// The scheduled/derived target did not exceed the current size.
    NoGrowthNeeded,
}
/// Record of one completed growth operation, kept in the trainer's history.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GrowthEvent {
    /// Epoch at which the growth occurred.
    pub epoch: usize,
    /// Size before growth.
    pub old_size: usize,
    /// Size after growth.
    pub new_size: usize,
    /// Axis along which the model grew.
    pub growth_dimension: GrowthDimension,
    /// Wall-clock time the growth operation took.
    pub growth_time: std::time::Duration,
    /// Details of what was added and how it was initialized.
    pub growth_info: GrowthInfo,
}
/// Tracks training loss to detect plateaus for adaptive growth triggers.
#[derive(Debug)]
pub struct LearningProgress {
    /// Full loss trace since construction.
    loss_history: Vec<f64>,
    /// Sliding window of the (at most ten) most recent losses.
    recent_losses: std::collections::VecDeque<f64>,
    /// Average improvement below which training counts as plateaued.
    plateau_threshold: f64,
    /// Minimum number of recent losses required before plateau detection runs.
    plateau_patience: usize,
    // NOTE(review): never read in this file (hence the dead_code allow) —
    // confirm whether it belongs to a planned improvement check.
    #[allow(dead_code)]
    improvement_threshold: f64,
    /// Epochs seen so far (incremented by `new_epoch`).
    current_epoch: usize,
}
impl Default for LearningProgress {
fn default() -> Self {
Self::new()
}
}
impl LearningProgress {
    /// Creates a tracker with the default plateau-detection settings.
    pub fn new() -> Self {
        Self {
            loss_history: Vec::new(),
            recent_losses: std::collections::VecDeque::with_capacity(10),
            plateau_threshold: 0.001,
            plateau_patience: 5,
            improvement_threshold: 0.01,
            current_epoch: 0,
        }
    }

    /// Records one loss value, keeping at most the ten most recent losses in
    /// the sliding window.
    pub fn update(&mut self, loss: f64) {
        self.loss_history.push(loss);
        self.recent_losses.push_back(loss);
        while self.recent_losses.len() > 10 {
            self.recent_losses.pop_front();
        }
    }

    /// Reports a plateau when the mean of the recent-loss window has improved
    /// by less than `plateau_threshold` relative to the mean of the ten
    /// losses preceding that window. Returns false until enough history
    /// exists for both windows.
    pub fn is_plateau(&self) -> bool {
        if self.recent_losses.len() < self.plateau_patience {
            return false;
        }
        let total = self.loss_history.len();
        let older = &self.loss_history[total.saturating_sub(20)..total.saturating_sub(10)];
        if older.is_empty() {
            return false;
        }
        let recent_avg =
            self.recent_losses.iter().sum::<f64>() / self.recent_losses.len() as f64;
        let older_avg = older.iter().sum::<f64>() / older.len() as f64;
        older_avg - recent_avg < self.plateau_threshold
    }

    /// Adaptive-growth trigger: plateaued, and enough history collected to
    /// trust the signal.
    pub fn should_trigger_growth(&self, _epoch: usize) -> bool {
        self.is_plateau() && self.loss_history.len() > 100
    }

    /// Advances the internal epoch counter.
    pub fn new_epoch(&mut self) {
        self.current_epoch += 1;
    }
}
/// Hooks a model must implement to be trained progressively.
///
/// The `usize` returned by `add_layer` and the `expand_*` methods is
/// interpreted by the trainer as the number of parameters added.
pub trait ProgressiveModel {
    /// Appends a new layer at `layer_index`.
    fn add_layer(&mut self, layer_index: usize) -> Result<usize, TrustformersError>;
    /// Widens the hidden dimension to `target_dim`.
    fn expand_hidden_dimension(&mut self, target_dim: usize) -> Result<usize, TrustformersError>;
    /// Increases the attention-head count to `target_heads`.
    fn expand_attention_heads(&mut self, target_heads: usize) -> Result<usize, TrustformersError>;
    /// Widens the feed-forward intermediate dimension to `target_dim`.
    fn expand_intermediate_dimension(
        &mut self,
        target_dim: usize,
    ) -> Result<usize, TrustformersError>;
    /// Expands the vocabulary to `target_vocab` entries.
    fn expand_vocabulary(&mut self, target_vocab: usize) -> Result<usize, TrustformersError>;
    /// Current hidden dimension.
    fn get_hidden_dimension(&self) -> Result<usize, TrustformersError>;
    /// Current number of attention heads.
    fn get_num_attention_heads(&self) -> Result<usize, TrustformersError>;
    /// Current feed-forward intermediate dimension.
    fn get_intermediate_dimension(&self) -> Result<usize, TrustformersError>;
    /// Current vocabulary size.
    fn get_vocab_size(&self) -> Result<usize, TrustformersError>;
    /// Zero-initializes the parameters of the layer at `layer_index`.
    fn zero_initialize_layer(&mut self, layer_index: usize) -> Result<(), TrustformersError>;
    /// Multiplies the parameters of the layer at `layer_index` by `scale`.
    fn scale_layer_parameters(
        &mut self,
        layer_index: usize,
        scale: f64,
    ) -> Result<(), TrustformersError>;
    /// Initializes dimensions added by a hidden-size expansion from `old_dim`
    /// to `new_dim`, using `smoothing` as the transition factor.
    fn initialize_expanded_dimensions(
        &mut self,
        old_dim: usize,
        new_dim: usize,
        smoothing: f64,
    ) -> Result<(), TrustformersError>;
    /// Names of all current parameters (used to freeze pre-growth weights).
    fn get_parameter_names(&self) -> Result<Vec<String>, TrustformersError>;
    /// Excludes the named parameters from optimization.
    fn freeze_parameters(&mut self, param_names: &HashSet<String>)
        -> Result<(), TrustformersError>;
    /// Re-enables optimization for the named parameters.
    fn unfreeze_parameters(
        &mut self,
        param_names: &HashSet<String>,
    ) -> Result<(), TrustformersError>;
}
/// Free-standing helpers for building growth schedules and sizing models.
pub mod utils {
    /// Returns `num_steps` evenly spaced growth epochs starting at
    /// `start_epoch`, `epoch_interval` apart.
    ///
    /// The size arguments are accepted for signature compatibility but do not
    /// affect epoch placement. (The previous version computed an unused
    /// per-step growth from them, which panicked on usize underflow whenever
    /// `final_size < initial_size`.)
    pub fn create_linear_schedule(
        _initial_size: usize,
        _final_size: usize,
        num_steps: usize,
        start_epoch: usize,
        epoch_interval: usize,
    ) -> Vec<usize> {
        (0..num_steps).map(|i| start_epoch + i * epoch_interval).collect()
    }

    /// Returns `num_steps` growth epochs whose spacing grows geometrically
    /// (factor 1.5 per step) from `start_epoch`.
    pub fn create_exponential_schedule(
        _initial_size: usize,
        _final_size: usize,
        num_steps: usize,
        start_epoch: usize,
        epoch_interval: usize,
    ) -> Vec<usize> {
        (0..num_steps)
            .map(|i| start_epoch + (epoch_interval as f64 * (1.5_f64.powi(i as i32))) as usize)
            .collect()
    }

    /// Rough transformer parameter count: token embeddings plus, per layer,
    /// four attention projections, two feed-forward matrices, and norm
    /// parameters, plus a final `hidden_dim`-sized term.
    ///
    /// `_num_heads` does not affect the total: head count partitions the
    /// hidden dimension without changing projection-matrix sizes.
    pub fn estimate_parameter_count(
        vocab_size: usize,
        hidden_dim: usize,
        num_layers: usize,
        _num_heads: usize,
        intermediate_dim: usize,
    ) -> usize {
        let embedding_params = vocab_size * hidden_dim;
        // Q, K, V and output projections: four square matrices.
        let attention_params = 4 * hidden_dim * hidden_dim;
        // Up- and down-projections of the feed-forward block.
        let ffn_params = 2 * hidden_dim * intermediate_dim;
        // Normalization parameters (2 * hidden per layer).
        let norm_params = 2 * hidden_dim;
        let layer_params = attention_params + ffn_params + norm_params;
        embedding_params + num_layers * layer_params + hidden_dim
    }

    /// Sqrt-spaced growth schedule: one event per sqrt(total growth), denser
    /// early in training. `saturating_sub` guards against a misconfigured
    /// `final_size < initial_size` (yields an empty schedule instead of a
    /// usize underflow panic).
    pub fn calculate_optimal_schedule(
        initial_size: usize,
        final_size: usize,
        total_epochs: usize,
        _computational_budget: f64,
    ) -> Vec<usize> {
        let num_growth_steps =
            ((final_size.saturating_sub(initial_size)) as f64).sqrt() as usize;
        (0..num_growth_steps)
            .map(|i| {
                let progress = i as f64 / num_growth_steps as f64;
                (total_epochs as f64 * progress.sqrt()) as usize
            })
            .collect()
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Default config sanity: 6 -> 12 with zero-init enabled.
    #[test]
    fn test_progressive_config_default() {
        let config = ProgressiveConfig::default();
        assert_eq!(config.initial_size, 6);
        assert_eq!(config.final_size, 12);
        assert!(config.zero_init_new_params);
    }

    /// A linear strategy produces one growth point per configured epoch.
    #[test]
    fn test_growth_schedule_creation() {
        let config = ProgressiveConfig {
            growth_strategy: GrowthStrategy::Linear,
            initial_size: 4,
            final_size: 12,
            growth_epochs: vec![10, 20, 30, 40],
            ..Default::default()
        };
        let schedule =
            ProgressiveTrainer::create_growth_schedule(&config).expect("operation failed");
        assert!(!schedule.growth_points.is_empty());
        assert_eq!(schedule.growth_points.len(), 4);
    }

    /// A fresh trainer starts at the initial size and is not in warmup.
    #[test]
    fn test_progressive_trainer_creation() {
        let config = ProgressiveConfig::default();
        let trainer = ProgressiveTrainer::new(config);
        assert!(trainer.is_ok());
        let trainer = trainer.expect("operation failed");
        assert_eq!(trainer.current_size(), 6);
        assert!(!trainer.is_in_warmup());
    }

    /// Steadily decreasing losses are not a plateau; constant losses are.
    #[test]
    fn test_learning_progress() {
        let mut progress = LearningProgress::new();
        for i in 0..20 {
            // Loss improves by 0.01 per step: clearly still learning.
            progress.update(1.0 - i as f64 * 0.01); }
        assert!(!progress.is_plateau());
        for _ in 0..25 {
            // Flat loss: no improvement between windows.
            progress.update(0.8); }
        assert!(progress.is_plateau());
    }

    /// Enum discriminants start at 0 and variants compare distinctly.
    #[test]
    fn test_growth_dimensions() {
        assert_eq!(GrowthDimension::Layers as u8, 0);
        assert_ne!(GrowthDimension::Layers, GrowthDimension::HiddenDim);
    }

    /// Enum discriminants start at 0 and variants compare distinctly.
    #[test]
    fn test_growth_strategies() {
        assert_eq!(GrowthStrategy::Linear as u8, 0);
        assert_ne!(GrowthStrategy::Linear, GrowthStrategy::Exponential);
    }

    /// BERT-base-like dimensions should exceed 100M parameters.
    #[test]
    fn test_utils_parameter_estimation() {
        let params = utils::estimate_parameter_count(30000, 768, 12, 12, 3072);
        assert!(params > 100_000_000); }

    /// Linear schedule spaces epochs evenly from the start epoch.
    #[test]
    fn test_utils_linear_schedule() {
        let schedule = utils::create_linear_schedule(6, 12, 3, 10, 5);
        assert_eq!(schedule.len(), 3);
        assert_eq!(schedule[0], 10);
        assert_eq!(schedule[1], 15);
        assert_eq!(schedule[2], 20);
    }
}