use std::path::PathBuf;
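/// Settings that control how a pretrained model is adapted to a new task.
///
/// Values are supplied through the fluent `with_*` builder methods; any field
/// left untouched keeps the value from `Default`.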
#[derive(Debug, Clone)]
pub struct TransferLearningConfig {
    /// Path to the pretrained model to start from; `None` disables transfer learning.
    pub pretrained_model: Option<PathBuf>,
    /// Names of layers whose weights are kept fixed during fine-tuning.
    pub frozen_layers: Vec<String>,
    /// Learning rate used during fine-tuning.
    pub fine_tune_lr: f64,
    /// When `true`, the pretrained weights are used as a fixed feature extractor rather than fine-tuned.
    pub feature_extraction_only: bool,
    /// Number of fine-tuning epochs to run.
    pub fine_tune_epochs: usize,
    /// Mini-batch size for fine-tuning.
    pub batch_size: usize,
    /// Epochs without validation improvement tolerated before stopping early.
    pub early_stopping_patience: usize,
    /// Fraction of the data held out for validation (clamped to `[0.0, 0.5]`).
    pub validation_split: f64,
}
impl Default for TransferLearningConfig {
fn default() -> Self {
Self {
pretrained_model: None,
frozen_layers: Vec::new(),
fine_tune_lr: 0.001,
feature_extraction_only: false,
fine_tune_epochs: 10,
batch_size: 32,
early_stopping_patience: 3,
validation_split: 0.2,
}
}
}
impl TransferLearningConfig {
#[must_use]
pub fn new() -> Self {
Self::default()
}
#[must_use]
pub fn with_pretrained(mut self, path: impl Into<PathBuf>) -> Self {
self.pretrained_model = Some(path.into());
self
}
#[must_use]
pub fn with_learning_rate(mut self, lr: f64) -> Self {
self.fine_tune_lr = lr;
self
}
#[must_use]
pub fn with_feature_extraction_only(mut self, enabled: bool) -> Self {
self.feature_extraction_only = enabled;
self
}
#[must_use]
pub fn freeze_layer(mut self, layer: impl Into<String>) -> Self {
self.frozen_layers.push(layer.into());
self
}
#[must_use]
pub fn with_epochs(mut self, epochs: usize) -> Self {
self.fine_tune_epochs = epochs;
self
}
#[must_use]
pub fn with_batch_size(mut self, size: usize) -> Self {
self.batch_size = size;
self
}
#[must_use]
pub fn with_early_stopping(mut self, patience: usize) -> Self {
self.early_stopping_patience = patience;
self
}
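    /// Sets the validation split, clamped to the `[0.0, 0.5]` range.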
#[must_use]
pub fn with_validation_split(mut self, split: f64) -> Self {
self.validation_split = split.clamp(0.0, 0.5);
self
}
#[must_use]
pub fn is_configured(&self) -> bool {
self.pretrained_model.is_some()
}
}
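/// Lifecycle of a transfer-learning run, from idle (`Disabled`) through
/// `PretrainedLoaded` and `FineTuning` to a terminal `Complete` or `Error` state.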
#[derive(Debug, Clone, PartialEq)]
pub enum TransferStatus {
Disabled,
PretrainedLoaded {
model_path: PathBuf,
frozen_layers: usize,
},
FineTuning {
epoch: usize,
total_epochs: usize,
loss: f64,
},
Complete {
accuracy: f64,
epochs_trained: usize,
},
Error {
message: String,
},
}
impl TransferStatus {
#[must_use]
pub fn is_active(&self) -> bool {
matches!(
self,
TransferStatus::PretrainedLoaded { .. } | TransferStatus::FineTuning { .. }
)
}
#[must_use]
pub fn is_complete(&self) -> bool {
matches!(self, TransferStatus::Complete { .. })
}
#[must_use]
pub fn is_error(&self) -> bool {
matches!(self, TransferStatus::Error { .. })
}
}
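/// Tracks a transfer-learning session: the active configuration, the current
/// `TransferStatus`, a per-epoch `(train_loss, val_loss)` history, and the
/// bookkeeping needed for early stopping.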
#[derive(Debug)]
pub struct TransferLearner {
config: TransferLearningConfig,
status: TransferStatus,
history: Vec<(f64, f64)>,
best_val_accuracy: f64,
epochs_without_improvement: usize,
}
impl TransferLearner {
#[must_use]
pub fn new(config: TransferLearningConfig) -> Self {
        // Every learner starts out Disabled, whether or not a pretrained model
        // is configured; the status only advances once `load_pretrained` succeeds.
        let status = TransferStatus::Disabled;
Self {
config,
status,
history: Vec::new(),
best_val_accuracy: 0.0,
epochs_without_improvement: 0,
}
}
#[must_use]
pub fn status(&self) -> &TransferStatus {
&self.status
}
#[must_use]
pub fn config(&self) -> &TransferLearningConfig {
&self.config
}
#[must_use]
pub fn history(&self) -> &[(f64, f64)] {
&self.history
}
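    /// Verifies that a pretrained model path is configured and that the file
    /// exists, then records the `PretrainedLoaded` status. This only validates
    /// the path and updates the status; it does not read the file.
    ///
    /// # Errors
    ///
    /// Returns an error if no path is configured or if the file does not exist.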
pub fn load_pretrained(&mut self) -> Result<(), String> {
let Some(ref path) = self.config.pretrained_model else {
return Err("No pretrained model path configured".to_string());
};
if !path.exists() {
return Err(format!("Pretrained model not found: {}", path.display()));
}
self.status = TransferStatus::PretrainedLoaded {
model_path: path.clone(),
frozen_layers: self.config.frozen_layers.len(),
};
Ok(())
}
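    /// Records one fine-tuning epoch: appends `(train_loss, val_loss)` to the
    /// history, updates the status to `FineTuning`, and resets or increments
    /// the early-stopping patience counter based on `val_accuracy`.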
pub fn record_epoch(
&mut self,
epoch: usize,
train_loss: f64,
val_loss: f64,
val_accuracy: f64,
) {
self.history.push((train_loss, val_loss));
self.status = TransferStatus::FineTuning {
epoch,
total_epochs: self.config.fine_tune_epochs,
loss: train_loss,
};
if val_accuracy > self.best_val_accuracy {
self.best_val_accuracy = val_accuracy;
self.epochs_without_improvement = 0;
} else {
self.epochs_without_improvement += 1;
}
}
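    /// Returns `true` once validation accuracy has failed to improve for
    /// `early_stopping_patience` consecutive epochs.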
#[must_use]
pub fn should_early_stop(&self) -> bool {
self.epochs_without_improvement >= self.config.early_stopping_patience
}
pub fn mark_complete(&mut self, final_accuracy: f64) {
self.status = TransferStatus::Complete {
accuracy: final_accuracy,
epochs_trained: self.history.len(),
};
}
pub fn mark_error(&mut self, message: impl Into<String>) {
self.status = TransferStatus::Error {
message: message.into(),
};
}
pub fn reset(&mut self) {
self.status = TransferStatus::Disabled;
self.history.clear();
self.best_val_accuracy = 0.0;
self.epochs_without_improvement = 0;
}
}
impl Default for TransferLearner {
fn default() -> Self {
Self::new(TransferLearningConfig::default())
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_transfer_config_default() {
let config = TransferLearningConfig::default();
assert!(config.pretrained_model.is_none());
assert!(config.frozen_layers.is_empty());
assert!((config.fine_tune_lr - 0.001).abs() < f64::EPSILON);
assert!(!config.feature_extraction_only);
assert_eq!(config.fine_tune_epochs, 10);
}
#[test]
fn test_transfer_config_with_pretrained() {
let config = TransferLearningConfig::new().with_pretrained("/path/to/model.apr");
assert_eq!(
config.pretrained_model,
Some(PathBuf::from("/path/to/model.apr"))
);
}
#[test]
fn test_transfer_config_with_learning_rate() {
let config = TransferLearningConfig::new().with_learning_rate(0.0001);
assert!((config.fine_tune_lr - 0.0001).abs() < f64::EPSILON);
}
#[test]
fn test_transfer_config_with_feature_extraction_only() {
let config = TransferLearningConfig::new().with_feature_extraction_only(true);
assert!(config.feature_extraction_only);
}
#[test]
fn test_transfer_config_freeze_layer() {
let config = TransferLearningConfig::new()
.freeze_layer("encoder")
.freeze_layer("embedding");
assert_eq!(config.frozen_layers.len(), 2);
assert!(config.frozen_layers.contains(&"encoder".to_string()));
}
#[test]
fn test_transfer_config_with_epochs() {
let config = TransferLearningConfig::new().with_epochs(20);
assert_eq!(config.fine_tune_epochs, 20);
}
#[test]
fn test_transfer_config_with_batch_size() {
let config = TransferLearningConfig::new().with_batch_size(64);
assert_eq!(config.batch_size, 64);
}
#[test]
fn test_transfer_config_with_early_stopping() {
let config = TransferLearningConfig::new().with_early_stopping(5);
assert_eq!(config.early_stopping_patience, 5);
}
#[test]
fn test_transfer_config_with_validation_split() {
let config = TransferLearningConfig::new().with_validation_split(0.3);
assert!((config.validation_split - 0.3).abs() < f64::EPSILON);
}
#[test]
fn test_transfer_config_validation_split_clamped() {
let config = TransferLearningConfig::new().with_validation_split(0.8);
assert!((config.validation_split - 0.5).abs() < f64::EPSILON);
let config2 = TransferLearningConfig::new().with_validation_split(-0.1);
assert!((config2.validation_split - 0.0).abs() < f64::EPSILON);
}
#[test]
fn test_transfer_config_is_configured() {
let config1 = TransferLearningConfig::new();
assert!(!config1.is_configured());
let config2 = TransferLearningConfig::new().with_pretrained("/path/to/model.apr");
assert!(config2.is_configured());
}
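    // A sketch of a fully chained configuration; the model path is a
    // placeholder used only to exercise the builder, not a real file.
    #[test]
    fn test_transfer_config_builder_chaining() {
        let config = TransferLearningConfig::new()
            .with_pretrained("/path/to/model.apr")
            .with_learning_rate(0.0005)
            .freeze_layer("encoder")
            .with_epochs(5)
            .with_batch_size(16)
            .with_early_stopping(2)
            .with_validation_split(0.25);
        assert!(config.is_configured());
        assert_eq!(config.frozen_layers.len(), 1);
        assert_eq!(config.fine_tune_epochs, 5);
        assert_eq!(config.batch_size, 16);
        assert!((config.validation_split - 0.25).abs() < f64::EPSILON);
    }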
#[test]
fn test_transfer_status_disabled() {
let status = TransferStatus::Disabled;
assert!(!status.is_active());
assert!(!status.is_complete());
assert!(!status.is_error());
}
#[test]
fn test_transfer_status_pretrained_loaded() {
let status = TransferStatus::PretrainedLoaded {
model_path: PathBuf::from("/path/to/model.apr"),
frozen_layers: 2,
};
assert!(status.is_active());
assert!(!status.is_complete());
}
#[test]
fn test_transfer_status_fine_tuning() {
let status = TransferStatus::FineTuning {
epoch: 5,
total_epochs: 10,
loss: 0.5,
};
assert!(status.is_active());
assert!(!status.is_complete());
}
#[test]
fn test_transfer_status_complete() {
let status = TransferStatus::Complete {
accuracy: 0.95,
epochs_trained: 10,
};
assert!(!status.is_active());
assert!(status.is_complete());
}
#[test]
fn test_transfer_status_error() {
let status = TransferStatus::Error {
message: "Model not found".to_string(),
};
assert!(!status.is_active());
assert!(!status.is_complete());
assert!(status.is_error());
}
#[test]
fn test_transfer_learner_new() {
let config = TransferLearningConfig::default();
let learner = TransferLearner::new(config);
assert_eq!(*learner.status(), TransferStatus::Disabled);
assert!(learner.history().is_empty());
}
#[test]
fn test_transfer_learner_load_pretrained_no_path() {
let config = TransferLearningConfig::default();
let mut learner = TransferLearner::new(config);
let result = learner.load_pretrained();
assert!(result.is_err());
}
#[test]
fn test_transfer_learner_load_pretrained_not_found() {
let config = TransferLearningConfig::new().with_pretrained("/nonexistent/model.apr");
let mut learner = TransferLearner::new(config);
let result = learner.load_pretrained();
assert!(result.is_err());
}
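    // A sketch of the success path for `load_pretrained`, writing a throwaway
    // file under the OS temp directory so the existence check passes; the file
    // name is arbitrary and is removed at the end of the test.
    #[test]
    fn test_transfer_learner_load_pretrained_found() {
        let path = std::env::temp_dir().join("transfer_learning_test_model.apr");
        std::fs::write(&path, b"stub").expect("write temp model file");
        let config = TransferLearningConfig::new()
            .with_pretrained(path.clone())
            .freeze_layer("encoder");
        let mut learner = TransferLearner::new(config);
        assert!(learner.load_pretrained().is_ok());
        assert!(matches!(
            learner.status(),
            TransferStatus::PretrainedLoaded { frozen_layers: 1, .. }
        ));
        let _ = std::fs::remove_file(&path);
    }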
#[test]
fn test_transfer_learner_record_epoch() {
let config = TransferLearningConfig::default();
let mut learner = TransferLearner::new(config);
learner.record_epoch(1, 0.5, 0.4, 0.8);
assert_eq!(learner.history().len(), 1);
assert!(matches!(
learner.status(),
TransferStatus::FineTuning { epoch: 1, .. }
));
}
#[test]
fn test_transfer_learner_should_early_stop() {
let config = TransferLearningConfig::new().with_early_stopping(2);
let mut learner = TransferLearner::new(config);
learner.record_epoch(1, 0.5, 0.4, 0.8);
assert!(!learner.should_early_stop());
learner.record_epoch(2, 0.4, 0.35, 0.75);
assert!(!learner.should_early_stop());
learner.record_epoch(3, 0.35, 0.3, 0.7);
assert!(learner.should_early_stop());
}
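    // A small additional check sketching the flip side of early stopping:
    // per `record_epoch`, an improvement in validation accuracy resets the
    // patience counter, so a mid-run recovery postpones the stop.
    #[test]
    fn test_transfer_learner_early_stop_reset_on_improvement() {
        let config = TransferLearningConfig::new().with_early_stopping(2);
        let mut learner = TransferLearner::new(config);
        learner.record_epoch(1, 0.5, 0.4, 0.8);
        learner.record_epoch(2, 0.45, 0.42, 0.75); // no improvement: counter = 1
        assert!(!learner.should_early_stop());
        learner.record_epoch(3, 0.4, 0.38, 0.85); // improvement: counter resets
        assert!(!learner.should_early_stop());
        learner.record_epoch(4, 0.38, 0.39, 0.8); // no improvement: counter = 1
        assert!(!learner.should_early_stop());
    }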
#[test]
fn test_transfer_learner_mark_complete() {
let config = TransferLearningConfig::default();
let mut learner = TransferLearner::new(config);
learner.record_epoch(1, 0.5, 0.4, 0.8);
learner.mark_complete(0.95);
assert!(learner.status().is_complete());
        match learner.status() {
            TransferStatus::Complete {
                accuracy,
                epochs_trained,
            } => {
                assert!((*accuracy - 0.95).abs() < f64::EPSILON);
                assert_eq!(*epochs_trained, 1);
            }
            other => panic!("expected Complete status, got {:?}", other),
        }
}
#[test]
fn test_transfer_learner_mark_error() {
let config = TransferLearningConfig::default();
let mut learner = TransferLearner::new(config);
learner.mark_error("Training failed");
assert!(learner.status().is_error());
}
#[test]
fn test_transfer_learner_reset() {
let config = TransferLearningConfig::default();
let mut learner = TransferLearner::new(config);
learner.record_epoch(1, 0.5, 0.4, 0.8);
learner.reset();
assert_eq!(*learner.status(), TransferStatus::Disabled);
assert!(learner.history().is_empty());
}
#[test]
fn test_transfer_learner_default() {
let learner = TransferLearner::default();
assert_eq!(*learner.status(), TransferStatus::Disabled);
}
}