mod private
{
use serde::{ Deserialize, Serialize };
use std::time::{ Duration, SystemTime };
use std::sync::{ Arc, Mutex };
use std::collections::HashMap;
use tokio::sync::broadcast;
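/// Lifecycle state of a fine-tuning job.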
#[ derive( Debug, Clone, PartialEq, Eq, Serialize, Deserialize ) ]
pub enum TrainingJobState
{
Pending,
Running,
Completed,
Cancelled,
Failed,
Paused,
}
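/// Hyperparameters controlling a fine-tuning run.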
#[ derive( Debug, Clone ) ]
pub struct HyperparameterConfig
{
pub learning_rate : f64,
pub batch_size : usize,
pub epochs : usize,
pub warmup_steps : usize,
pub weight_decay : f64,
pub gradient_clip_norm : f64,
pub scheduler : String,
pub optimizer : String,
}
impl Default for HyperparameterConfig
{
fn default() -> Self
{
Self {
learning_rate : 0.0001,
batch_size : 16,
epochs : 3,
warmup_steps : 50,
weight_decay : 0.01,
gradient_clip_norm : 1.0,
scheduler : "cosine".to_string(),
optimizer : "adamw".to_string(),
}
}
}
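/// Builder for [`HyperparameterConfig`] that validates settings in `build`.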
#[ derive( Debug, Clone ) ]
pub struct HyperparameterConfigBuilder
{
config : HyperparameterConfig,
}
impl HyperparameterConfigBuilder
{
pub fn new() -> Self
{
Self {
config : HyperparameterConfig::default(),
}
}
pub fn learning_rate( mut self, rate : f64 ) -> Self
{
self.config.learning_rate = rate;
self
}
pub fn batch_size( mut self, size : usize ) -> Self
{
self.config.batch_size = size;
self
}
pub fn epochs( mut self, epochs : usize ) -> Self
{
self.config.epochs = epochs;
self
}
pub fn warmup_steps( mut self, steps : usize ) -> Self
{
self.config.warmup_steps = steps;
self
}
pub fn weight_decay( mut self, decay : f64 ) -> Self
{
self.config.weight_decay = decay;
self
}
pub fn gradient_clip_norm( mut self, norm : f64 ) -> Self
{
self.config.gradient_clip_norm = norm;
self
}
pub fn scheduler( mut self, scheduler : &str ) -> Self
{
self.config.scheduler = scheduler.to_string();
self
}
pub fn optimizer( mut self, optimizer : &str ) -> Self
{
self.config.optimizer = optimizer.to_string();
self
}
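/// Validates the accumulated settings and returns the configuration,
/// or a `ConfigurationError` describing the first invalid field.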
pub fn build( self ) -> Result< HyperparameterConfig, crate::error::Error >
{
if self.config.learning_rate <= 0.0
{
return Err( crate::error::Error::ConfigurationError(
"Learning rate must be positive".to_string()
) );
}
if self.config.batch_size == 0
{
return Err( crate::error::Error::ConfigurationError(
"Batch size must be greater than 0".to_string()
) );
}
if self.config.epochs == 0
{
return Err( crate::error::Error::ConfigurationError(
"Epochs must be greater than 0".to_string()
) );
}
if self.config.weight_decay < 0.0
{
return Err( crate::error::Error::ConfigurationError(
"Weight decay must be non-negative".to_string()
) );
}
if self.config.gradient_clip_norm <= 0.0
{
return Err( crate::error::Error::ConfigurationError(
"Gradient clip norm must be positive".to_string()
) );
}
Ok( self.config )
}
}
impl HyperparameterConfig
{
pub fn builder() -> HyperparameterConfigBuilder
{
HyperparameterConfigBuilder::new()
}
}
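/// Low-Rank Adaptation (LoRA) settings for parameter-efficient fine-tuning.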
#[ derive( Debug, Clone ) ]
pub struct LoRAConfig
{
pub rank : usize,
pub alpha : f64,
pub dropout : f64,
pub target_modules : Vec< String >,
pub merge_weights : bool,
}
impl Default for LoRAConfig
{
fn default() -> Self
{
Self {
rank : 8,
alpha : 16.0,
dropout : 0.1,
target_modules : vec![ "query".to_string(), "value".to_string() ],
merge_weights : false,
}
}
}
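/// Builder for [`LoRAConfig`] that validates settings in `build`.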
#[ derive( Debug, Clone ) ]
pub struct LoRAConfigBuilder
{
config : LoRAConfig,
}
impl LoRAConfigBuilder
{
pub fn new() -> Self
{
Self {
config : LoRAConfig::default(),
}
}
pub fn rank( mut self, rank : usize ) -> Self
{
self.config.rank = rank;
self
}
pub fn alpha( mut self, alpha : f64 ) -> Self
{
self.config.alpha = alpha;
self
}
pub fn dropout( mut self, dropout : f64 ) -> Self
{
self.config.dropout = dropout;
self
}
pub fn target_modules( mut self, modules : Vec< String > ) -> Self
{
self.config.target_modules = modules;
self
}
pub fn merge_weights( mut self, merge : bool ) -> Self
{
self.config.merge_weights = merge;
self
}
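/// Validates the accumulated settings and returns the configuration,
/// or a `ConfigurationError` describing the first invalid field.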
pub fn build( self ) -> Result< LoRAConfig, crate::error::Error >
{
if self.config.rank == 0
{
return Err( crate::error::Error::ConfigurationError(
"LoRA rank must be greater than 0".to_string()
) );
}
if self.config.alpha <= 0.0
{
return Err( crate::error::Error::ConfigurationError(
"LoRA alpha must be positive".to_string()
) );
}
if self.config.dropout < 0.0 || self.config.dropout >= 1.0
{
return Err( crate::error::Error::ConfigurationError(
"LoRA dropout must be between 0.0 and 1.0".to_string()
) );
}
if self.config.target_modules.is_empty()
{
return Err( crate::error::Error::ConfigurationError(
"LoRA target modules cannot be empty".to_string()
) );
}
Ok( self.config )
}
}
impl LoRAConfig
{
pub fn builder() -> LoRAConfigBuilder
{
LoRAConfigBuilder::new()
}
}
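/// Task-specific training objective and its parameters.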
#[ derive( Debug, Clone ) ]
pub enum TrainingObjective
{
Completion {
max_sequence_length : usize,
temperature : f64,
},
Classification {
num_classes : usize,
label_smoothing : f64,
},
Seq2Seq {
max_input_length : usize,
max_output_length : usize,
},
}
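/// Snapshot of metrics reported at a given training step.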
#[ derive( Debug, Clone ) ]
pub struct TrainingMetrics
{
pub epoch : usize,
pub step : usize,
pub loss : f64,
pub learning_rate : f64,
pub gradient_norm : f64,
pub throughput_tokens_per_second : f64,
pub memory_usage_mb : f64,
pub elapsed_time_seconds : f64,
}
impl Default for TrainingMetrics
{
fn default() -> Self
{
Self {
epoch : 0,
step : 0,
loss : 0.0,
learning_rate : 0.0,
gradient_norm : 0.0,
throughput_tokens_per_second : 0.0,
memory_usage_mb : 0.0,
elapsed_time_seconds : 0.0,
}
}
}
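/// Metadata for a saved model checkpoint.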
#[ derive( Debug, Clone ) ]
pub struct ModelCheckpoint
{
pub checkpoint_id : String,
pub epoch : usize,
pub step : usize,
pub loss : f64,
pub metrics : HashMap< String, f64 >,
pub model_path : String,
pub created_at : SystemTime,
}
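/// Progress update broadcast to subscribers while a job runs.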
#[ derive( Debug, Clone ) ]
pub struct TrainingProgress
{
pub percentage : f64,
pub metrics : TrainingMetrics,
pub estimated_time_remaining : Option< Duration >,
}
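/// Handle to a fine-tuning job with shared state, metrics, checkpoints, and a progress broadcast channel.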
pub struct TrainingJob
{
pub job_id : String,
state : Arc< Mutex< TrainingJobState > >,
config : HyperparameterConfig,
metrics : Arc< Mutex< TrainingMetrics > >,
progress_tx : broadcast::Sender< TrainingProgress >,
checkpoints : Arc< Mutex< Vec< ModelCheckpoint > > >,
}
impl TrainingJob
{
pub fn new( job_id : String, config : HyperparameterConfig ) -> Self
{
let ( progress_tx, _progress_rx ) = broadcast::channel( 16 );
Self {
job_id,
state : Arc::new( Mutex::new( TrainingJobState::Pending ) ),
config,
metrics : Arc::new( Mutex::new( TrainingMetrics::default() ) ),
progress_tx,
checkpoints : Arc::new( Mutex::new( Vec::new() ) ),
}
}
pub fn state( &self ) -> TrainingJobState
{
self.state.lock().unwrap().clone()
}
pub fn get_metrics( &self ) -> TrainingMetrics
{
self.metrics.lock().unwrap().clone()
}
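/// Subscribes to progress updates broadcast while the job runs.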
pub fn subscribe_progress( &self ) -> broadcast::Receiver< TrainingProgress >
{
self.progress_tx.subscribe()
}
pub fn get_checkpoints( &self ) -> Vec< ModelCheckpoint >
{
self.checkpoints.lock().unwrap().clone()
}
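/// Marks the job as running. This sketch only updates the state; it does not spawn any training work.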
pub async fn start( &self ) -> Result< (), crate::error::Error >
{
*self.state.lock().unwrap() = TrainingJobState::Running;
Ok( () )
}
pub async fn pause( &self ) -> Result< (), crate::error::Error >
{
*self.state.lock().unwrap() = TrainingJobState::Paused;
Ok( () )
}
pub async fn cancel( &self ) -> Result< (), crate::error::Error >
{
*self.state.lock().unwrap() = TrainingJobState::Cancelled;
Ok( () )
}
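/// Resumes a paused job. Returns an error if the job is not currently `Paused`.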
pub async fn resume( &self ) -> Result< (), crate::error::Error >
{
let current_state = self.state();
if current_state != TrainingJobState::Paused
{
return Err( crate::error::Error::ApiError(
format!( "Cannot resume job in state : {:?}", current_state )
) );
}
*self.state.lock().unwrap() = TrainingJobState::Running;
Ok( () )
}
}
impl std::fmt::Debug for TrainingJob
{
fn fmt( &self, f : &mut std::fmt::Formatter< '_ > ) -> std::fmt::Result
{
f.debug_struct( "TrainingJob" )
.field( "job_id", &self.job_id )
.field( "state", &self.state() )
.field( "config", &self.config )
.field( "metrics", &self.get_metrics() )
.finish_non_exhaustive()
}
}
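/// Fluent builder that configures and launches a fine-tuning job for a model.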
pub struct FineTuningBuilder< 'a >
{
#[ allow( dead_code ) ]
model : &'a crate::models::api::ModelApi< 'a >,
training_data : Option< String >,
validation_data : Option< String >,
hyperparams : HyperparameterConfig,
lora_config : Option< LoRAConfig >,
objective : Option< TrainingObjective >,
monitoring_interval : Option< Duration >,
validate_data : bool,
checkpoint_frequency : usize,
progress_callback : Option< Box< dyn Fn( TrainingProgress ) + Send + Sync > >,
}
impl< 'a > std::fmt::Debug for FineTuningBuilder< 'a >
{
fn fmt( &self, f : &mut std::fmt::Formatter< '_ > ) -> std::fmt::Result
{
f.debug_struct( "FineTuningBuilder" )
.field( "training_data", &self.training_data )
.field( "validation_data", &self.validation_data )
.field( "hyperparams", &self.hyperparams )
.field( "lora_config", &self.lora_config )
.field( "objective", &self.objective )
.field( "monitoring_interval", &self.monitoring_interval )
.field( "validate_data", &self.validate_data )
.field( "checkpoint_frequency", &self.checkpoint_frequency )
.field( "progress_callback", &self.progress_callback.is_some() )
.finish_non_exhaustive()
}
}
impl< 'a > FineTuningBuilder< 'a >
{
pub fn new( model : &'a crate::models::api::ModelApi< 'a > ) -> Self
{
Self {
model,
training_data : None,
validation_data : None,
hyperparams : HyperparameterConfig::default(),
lora_config : None,
objective : None,
monitoring_interval : None,
validate_data : false,
checkpoint_frequency : 1000,
progress_callback : None,
}
}
pub fn with_training_data( mut self, path : &str ) -> Self
{
self.training_data = Some( path.to_string() );
self
}
pub fn with_validation_data( mut self, path : &str ) -> Self
{
self.validation_data = Some( path.to_string() );
self
}
pub fn with_epochs( mut self, epochs : usize ) -> Self
{
self.hyperparams.epochs = epochs;
self
}
pub fn with_learning_rate( mut self, rate : f64 ) -> Self
{
self.hyperparams.learning_rate = rate;
self
}
pub fn with_batch_size( mut self, size : usize ) -> Self
{
self.hyperparams.batch_size = size;
self
}
pub fn with_hyperparams( mut self, config : HyperparameterConfig ) -> Self
{
self.hyperparams = config;
self
}
pub fn with_lora_config( mut self, config : LoRAConfig ) -> Self
{
self.lora_config = Some( config );
self
}
pub fn with_objective( mut self, objective : TrainingObjective ) -> Self
{
self.objective = Some( objective );
self
}
pub fn with_monitoring_interval( mut self, interval : Duration ) -> Self
{
self.monitoring_interval = Some( interval );
self
}
pub fn validate_data( mut self, validate : bool ) -> Self
{
self.validate_data = validate;
self
}
pub fn with_checkpoint_frequency( mut self, frequency : usize ) -> Self
{
self.checkpoint_frequency = frequency;
self
}
pub fn with_progress_callback< F >( mut self, callback : F ) -> Self
where
F : Fn( TrainingProgress ) + Send + Sync + 'static,
{
self.progress_callback = Some( Box::new( callback ) );
self
}
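/// Validates that training data was provided, creates the job, and starts it.
/// Options such as the LoRA configuration, objective, and progress callback are
/// collected but not yet consumed in this sketch.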
pub async fn start_training( self ) -> Result< TrainingJob, crate::error::Error >
{
if self.training_data.is_none()
{
return Err( crate::error::Error::ApiError(
"Training data is required for fine-tuning".to_string()
) );
}
// Placeholder identifier; a real implementation would generate a unique job id here.
let job_id = format!( "tuning-{}", "generated-id" );
let job = TrainingJob::new( job_id, self.hyperparams );
job.start().await?;
Ok( job )
}
}
}
::mod_interface::mod_interface!
{
exposed use private::TrainingJobState;
exposed use private::HyperparameterConfig;
exposed use private::HyperparameterConfigBuilder;
exposed use private::LoRAConfig;
exposed use private::LoRAConfigBuilder;
exposed use private::TrainingObjective;
exposed use private::TrainingMetrics;
exposed use private::ModelCheckpoint;
exposed use private::TrainingProgress;
exposed use private::TrainingJob;
exposed use private::FineTuningBuilder;
}