#![ allow( clippy::missing_inline_in_public_items, clippy::unused_async ) ]
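//! Model fine-tuning job management: job configuration, lifecycle tracking,
//! checkpointing, metrics reporting, and rough time/cost estimation helpers.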
mod private
{
use std::
{
collections::HashMap,
time::SystemTime,
};
use core::time::Duration;
use serde::{ Deserialize, Serialize };
use tokio::sync::mpsc;
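/// Lifecycle states of a tuning job, from validation through the terminal
/// `Succeeded`, `Failed`, and `Cancelled` states.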
#[ derive( Debug, Clone, PartialEq, Serialize, Deserialize ) ]
pub enum TuningStatus
{
Validating,
Queued,
Running,
Succeeded,
Failed( String ),
Cancelled,
}
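/// Training objective the job optimizes, with a `Custom` variant for
/// objectives configured via free-form string parameters.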
#[ derive( Debug, Clone, PartialEq, Serialize, Deserialize ) ]
pub enum TrainingObjective
{
LanguageModeling,
SupervisedFineTuning,
RLHF,
Custom
{
name : String,
parameters : HashMap< String, String >,
},
}
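/// Fine-tuning strategy. `Full` updates all model weights; the remaining
/// variants are parameter-efficient methods that train only a small number of
/// added parameters ( LoRA adapters, bottleneck adapters, or prefix embeddings ).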
#[ derive( Debug, Clone, PartialEq, Serialize, Deserialize ) ]
pub enum FineTuningMethod
{
Full,
LoRA
{
rank : u32,
alpha : f64,
dropout : f64,
},
Adapter
{
hidden_dim : u32,
num_layers : u32,
},
PrefixTuning
{
prefix_length : u32,
embedding_dim : u32,
},
}
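/// Optimizer and schedule hyperparameters; `custom_params` carries
/// method-specific string overrides.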
#[ derive( Debug, Clone, Serialize, Deserialize ) ]
pub struct HyperParameters
{
pub learning_rate : f64,
pub batch_size : u32,
pub epochs : u32,
pub warmup_steps : u32,
pub weight_decay : f64,
pub gradient_clip_norm : f64,
pub lr_schedule : String,
pub custom_params : HashMap< String, String >,
}
impl Default for HyperParameters
{
fn default() -> Self
{
Self
{
learning_rate : 1e-4,
batch_size : 32,
epochs : 3,
warmup_steps : 100,
weight_decay : 0.01,
gradient_clip_norm : 1.0,
lr_schedule : "linear".to_string(),
custom_params : HashMap::new(),
}
}
}
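/// Location and preprocessing options for the training data and the optional
/// validation split.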
#[ derive( Debug, Clone, Serialize, Deserialize ) ]
pub struct TrainingDataConfig
{
pub training_file : String,
pub validation_file : Option< String >,
pub data_format : String,
pub max_sequence_length : u32,
pub preprocessing : HashMap< String, String >,
}
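/// A saved model snapshot, together with the loss and validation metrics
/// recorded at the step it was taken.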
#[ derive( Debug, Clone, Serialize, Deserialize ) ]
pub struct ModelCheckpoint
{
pub checkpoint_id : String,
pub step : u64,
pub loss : f64,
pub validation_metrics : HashMap< String, f64 >,
pub created_at : SystemTime,
pub file_path : String,
}
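/// Point-in-time training telemetry, including throughput ( tokens per second )
/// and an optional remaining-time estimate.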
#[ derive( Debug, Clone, Serialize, Deserialize ) ]
pub struct TrainingMetrics
{
pub step : u64,
pub epoch : u32,
pub training_loss : f64,
pub validation_loss : Option< f64 >,
pub learning_rate : f64,
pub throughput : f64,
pub eta_seconds : Option< u64 >,
pub custom_metrics : HashMap< String, f64 >,
}
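/// Complete specification of a tuning job: data, hyperparameters, method,
/// objective, resource requirements, and checkpointing policy.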
#[ derive( Debug, Clone, Serialize, Deserialize ) ]
pub struct TuningJobConfig
{
pub job_name : String,
pub base_model : String,
pub training_data : TrainingDataConfig,
pub hyperparameters : HyperParameters,
pub method : FineTuningMethod,
pub objective : TrainingObjective,
pub resource_requirements : TuningResourceRequirements,
pub checkpointing : CheckpointConfig,
pub env_vars : HashMap< String, String >,
}
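/// Hardware requested for the job. The default asks for a single GPU of
/// unspecified type.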
#[ derive( Debug, Clone, Serialize, Deserialize ) ]
pub struct TuningResourceRequirements
{
pub gpu_count : u32,
pub gpu_type : Option< String >,
pub memory_gb : u64,
pub cpu_cores : u32,
pub storage_gb : u64,
}
impl Default for TuningResourceRequirements
{
fn default() -> Self
{
Self
{
gpu_count : 1,
gpu_type : None,
memory_gb : 16,
cpu_cores : 4,
storage_gb : 100,
}
}
}
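/// Checkpointing policy: save interval ( in steps ), retention limit, and
/// output directory.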
#[ derive( Debug, Clone, Serialize, Deserialize ) ]
pub struct CheckpointConfig
{
pub enabled : bool,
pub save_interval : u64,
pub max_checkpoints : u32,
pub save_directory : String,
}
impl Default for CheckpointConfig
{
fn default() -> Self
{
Self
{
enabled : true,
save_interval : 1000,
max_checkpoints : 5,
save_directory : "./checkpoints".to_string(),
}
}
}
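/// A tuning job and its runtime state: current status, latest metrics,
/// saved checkpoints, and an append-only execution log.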
#[ derive( Debug, Clone, Serialize, Deserialize ) ]
pub struct TuningJob
{
pub config : TuningJobConfig,
pub status : TuningStatus,
pub created_at : SystemTime,
pub updated_at : SystemTime,
pub current_metrics : Option< TrainingMetrics >,
pub checkpoints : Vec< ModelCheckpoint >,
pub execution_log : Vec< TuningEvent >,
}
impl TuningJob
{
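/// Creates a job in the `Validating` state with empty metrics, checkpoints,
/// and execution log.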
#[ must_use ]
pub fn new( config : TuningJobConfig ) -> Self
{
let now = SystemTime::now();
Self
{
config,
status : TuningStatus::Validating,
created_at : now,
updated_at : now,
current_metrics : None,
checkpoints : Vec::new(),
execution_log : Vec::new(),
}
}
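/// Transitions the job to `status`, recording a `StatusChanged` event in the
/// execution log.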
pub fn update_status( &mut self, status : TuningStatus )
{
let event = TuningEvent
{
event_type : TuningEventType::StatusChanged
{
from : self.status.clone(),
to : status.clone(),
},
message : format!( "Status changed from {:?} to {:?}", self.status, status ),
timestamp : SystemTime::now(),
};
self.status = status;
self.updated_at = SystemTime::now();
self.execution_log.push( event );
}
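/// Replaces the current metrics snapshot, recording a `MetricsUpdated` event.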
pub fn update_metrics( &mut self, metrics : TrainingMetrics )
{
let event = TuningEvent
{
event_type : TuningEventType::MetricsUpdated
{
step : metrics.step,
loss : metrics.training_loss,
},
message : format!( "Metrics updated at step {}", metrics.step ),
timestamp : SystemTime::now(),
};
self.current_metrics = Some( metrics );
self.updated_at = SystemTime::now();
self.execution_log.push( event );
}
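/// Appends a checkpoint, recording a `CheckpointSaved` event. Note that the
/// `CheckpointConfig::max_checkpoints` retention limit is not enforced here.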
pub fn add_checkpoint( &mut self, checkpoint : ModelCheckpoint )
{
let event = TuningEvent
{
event_type : TuningEventType::CheckpointSaved
{
checkpoint_id : checkpoint.checkpoint_id.clone(),
step : checkpoint.step,
},
message : format!( "Checkpoint saved at step {}", checkpoint.step ),
timestamp : SystemTime::now(),
};
self.checkpoints.push( checkpoint );
self.updated_at = SystemTime::now();
self.execution_log.push( event );
}
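/// Wall-clock time between creation and the last update; zero if the system
/// clock went backwards in between.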
#[ must_use ]
pub fn duration( &self ) -> Duration
{
self.updated_at.duration_since( self.created_at ).unwrap_or( Duration::ZERO )
}
}
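/// A timestamped entry in a job's execution log.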
#[ derive( Debug, Clone, Serialize, Deserialize ) ]
pub struct TuningEvent
{
pub event_type : TuningEventType,
pub message : String,
pub timestamp : SystemTime,
}
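/// The kind of event that occurred: a status transition, a metrics update,
/// or a checkpoint save.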
#[ derive( Debug, Clone, Serialize, Deserialize ) ]
pub enum TuningEventType
{
StatusChanged
{
from : TuningStatus,
to : TuningStatus,
},
MetricsUpdated
{
step : u64,
loss : f64,
},
CheckpointSaved
{
checkpoint_id : String,
step : u64,
},
}
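/// In-memory registry of tuning jobs, keyed by unique job name.
///
/// A minimal usage sketch ( illustrative only; it assumes these types are
/// imported into scope and that `config` is a populated `TuningJobConfig` ):
///
/// ```rust,ignore
/// let mut manager = TuningManager::new();
/// let name = manager.create_job( config ).await?;
/// manager.update_job_status( &name, TuningStatus::Queued ).await?;
/// let stats = manager.tuning_stats();
/// assert_eq!( stats.queued, 1 );
/// ```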
#[ derive( Debug ) ]
pub struct TuningManager
{
jobs : HashMap< String, TuningJob >,
}
impl TuningManager
{
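/// Creates an empty manager with no registered jobs.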
#[ must_use ]
pub fn new() -> Self
{
Self
{
jobs : HashMap::new(),
}
}
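/// Registers a new job and returns its name.
///
/// # Errors
///
/// Returns an error if a job with the same name already exists.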
pub async fn create_job( &mut self, config : TuningJobConfig ) -> Result< String, String >
{
let job_name = config.job_name.clone();
if self.jobs.contains_key( &job_name )
{
return Err( format!( "Job with name '{job_name}' already exists" ) );
}
let job = TuningJob::new( config );
self.jobs.insert( job_name.clone(), job );
Ok( job_name )
}
pub async fn get_job( &self, job_name : &str ) -> Option< &TuningJob >
{
self.jobs.get( job_name )
}
pub async fn get_job_mut( &mut self, job_name : &str ) -> Option< &mut TuningJob >
{
self.jobs.get_mut( job_name )
}
pub async fn list_jobs( &self ) -> Vec< &TuningJob >
{
self.jobs.values().collect()
}
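/// Applies a status transition to the named job.
///
/// # Errors
///
/// Returns an error if no job named `job_name` exists.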
pub async fn update_job_status( &mut self, job_name : &str, status : TuningStatus ) -> Result< (), String >
{
match self.jobs.get_mut( job_name )
{
Some( job ) =>
{
job.update_status( status );
Ok( () )
}
None => Err( format!( "Job '{job_name}' not found" ) ),
}
}
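/// Marks the named job as `Cancelled`.
///
/// # Errors
///
/// Returns an error if no job named `job_name` exists.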
pub async fn cancel_job( &mut self, job_name : &str ) -> Result< (), String >
{
self.update_job_status( job_name, TuningStatus::Cancelled ).await
}
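/// Removes the named job from the registry.
///
/// # Errors
///
/// Returns an error if no job named `job_name` exists.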
pub async fn delete_job( &mut self, job_name : &str ) -> Result< (), String >
{
match self.jobs.remove( job_name )
{
Some( _ ) => Ok( () ),
None => Err( format!( "Job '{job_name}' not found" ) ),
}
}
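/// Aggregates job counts by status.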
#[ must_use ]
pub fn tuning_stats( &self ) -> TuningStats
{
let mut stats = TuningStats
{
total : self.jobs.len(),
validating : 0,
queued : 0,
running : 0,
succeeded : 0,
failed : 0,
cancelled : 0,
};
for job in self.jobs.values()
{
match job.status
{
TuningStatus::Validating => stats.validating += 1,
TuningStatus::Queued => stats.queued += 1,
TuningStatus::Running => stats.running += 1,
TuningStatus::Succeeded => stats.succeeded += 1,
TuningStatus::Failed( _ ) => stats.failed += 1,
TuningStatus::Cancelled => stats.cancelled += 1,
}
}
stats
}
}
impl Default for TuningManager
{
#[ inline ]
fn default() -> Self
{
Self::new()
}
}
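/// Job counts per status bucket; `total` is the sum across all buckets.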
#[ derive( Debug, Clone ) ]
pub struct TuningStats
{
pub total : usize,
pub validating : usize,
pub queued : usize,
pub running : usize,
pub succeeded : usize,
pub failed : usize,
pub cancelled : usize,
}
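/// A push-style notification of a job status or metrics change, suitable for
/// delivery over a channel.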
#[ derive( Debug, Clone ) ]
pub struct TuningNotification
{
pub job_name : String,
pub status : TuningStatus,
pub metrics : Option< TrainingMetrics >,
pub timestamp : SystemTime,
}
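/// Sending half of an unbounded notification channel.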
#[ derive( Debug ) ]
pub struct TuningEventSender
{
pub sender : mpsc::UnboundedSender< TuningNotification >,
}
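/// Receiving half of an unbounded notification channel.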
#[ derive( Debug ) ]
pub struct TuningEventReceiver
{
pub receiver : mpsc::UnboundedReceiver< TuningNotification >,
}
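/// Stateless helpers: notification channel setup, config validation, and
/// rough training time and cost estimates.
///
/// A minimal sketch of the validate-then-estimate flow ( illustrative only;
/// `config` is assumed to be a populated `TuningJobConfig`, and the dataset
/// size of one million tokens is an arbitrary example value ):
///
/// ```rust,ignore
/// ModelTuningUtils::validate_config( &config ).map_err( | errs | errs.join( "; " ) )?;
/// let eta = ModelTuningUtils::estimate_training_time( &config, 1_000_000 );
/// let cost = ModelTuningUtils::estimate_training_cost( &config, eta );
/// println!( "estimated {eta:?} at ${cost:.2}" );
/// ```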
#[ derive( Debug ) ]
pub struct ModelTuningUtils;
impl ModelTuningUtils
{
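/// Creates a connected ( sender, receiver ) pair backed by an unbounded
/// Tokio channel.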
#[ must_use ]
pub fn create_event_notifier() -> ( TuningEventSender, TuningEventReceiver )
{
let ( tx, rx ) = mpsc::unbounded_channel();
( TuningEventSender { sender : tx }, TuningEventReceiver { receiver : rx } )
}
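/// Checks a configuration for obviously invalid values.
///
/// # Errors
///
/// Returns every validation failure at once rather than stopping at the first.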
pub fn validate_config( config : &TuningJobConfig ) -> Result< (), Vec< String > >
{
let mut errors = Vec::new();
if config.job_name.is_empty()
{
errors.push( "Job name cannot be empty".to_string() );
}
if config.base_model.is_empty()
{
errors.push( "Base model cannot be empty".to_string() );
}
if config.training_data.training_file.is_empty()
{
errors.push( "Training file cannot be empty".to_string() );
}
if config.hyperparameters.learning_rate <= 0.0
{
errors.push( "Learning rate must be positive".to_string() );
}
if config.hyperparameters.batch_size == 0
{
errors.push( "Batch size must be positive".to_string() );
}
if config.hyperparameters.epochs == 0
{
errors.push( "Number of epochs must be positive".to_string() );
}
if errors.is_empty()
{
Ok( () )
}
else
{
Err( errors )
}
}
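/// Rough wall-clock estimate. Treats `dataset_size` as tokens per epoch and
/// assumes a fixed throughput of 1000 tokens per second regardless of the
/// configured hardware.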
#[ must_use ]
pub fn estimate_training_time( config : &TuningJobConfig, dataset_size : u64 ) -> Duration
{
let tokens_per_epoch = dataset_size;
let total_tokens = tokens_per_epoch * u64::from( config.hyperparameters.epochs );
let tokens_per_second = 1000;
let total_seconds = total_tokens / tokens_per_second;
Duration::from_secs( total_seconds )
}
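/// Estimated cost in dollars, using illustrative per-GPU-hour rates
/// ( A100: 4.0, V100: 2.0, T4: 0.5, otherwise 1.0 ) multiplied by GPU count.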
#[ must_use ]
pub fn estimate_training_cost( config : &TuningJobConfig, duration : Duration ) -> f64
{
let hours = duration.as_secs_f64() / 3600.0;
let gpu_cost_per_hour = match config.resource_requirements.gpu_type.as_deref()
{
Some( "A100" ) => 4.0,
Some( "V100" ) => 2.0,
Some( "T4" ) => 0.5,
_ => 1.0,
};
hours * gpu_cost_per_hour * f64::from( config.resource_requirements.gpu_count )
}
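/// Heuristic starting hyperparameters per method: full fine-tuning favors
/// lower learning rates ( especially on large datasets ), while
/// parameter-efficient methods tolerate higher rates and more epochs.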
#[ must_use ]
pub fn suggest_hyperparameters( method : &FineTuningMethod, dataset_size : u64 ) -> HyperParameters
{
let mut params = HyperParameters::default();
match method
{
FineTuningMethod::Full =>
{
params.learning_rate = if dataset_size > 100_000 { 1e-5 } else { 5e-5 };
params.batch_size = if dataset_size > 50_000 { 16 } else { 32 };
}
FineTuningMethod::LoRA { .. } =>
{
params.learning_rate = 1e-4;
params.batch_size = 64;
params.epochs = 5;
}
FineTuningMethod::Adapter { .. } =>
{
params.learning_rate = 5e-4;
params.batch_size = 32;
params.epochs = 10;
}
FineTuningMethod::PrefixTuning { .. } =>
{
params.learning_rate = 1e-3;
params.batch_size = 16;
params.epochs = 20;
}
}
params
}
}
}
crate::mod_interface!
{
exposed use private::TuningStatus;
exposed use private::TrainingObjective;
exposed use private::FineTuningMethod;
exposed use private::HyperParameters;
exposed use private::TrainingDataConfig;
exposed use private::ModelCheckpoint;
exposed use private::TrainingMetrics;
exposed use private::TuningJobConfig;
exposed use private::TuningResourceRequirements;
exposed use private::CheckpointConfig;
exposed use private::TuningJob;
exposed use private::TuningEvent;
exposed use private::TuningEventType;
exposed use private::TuningManager;
exposed use private::TuningStats;
exposed use private::TuningNotification;
exposed use private::TuningEventSender;
exposed use private::TuningEventReceiver;
exposed use private::ModelTuningUtils;
}