use super::*;
#[test]
fn test_fine_tune_method_lora() {
let lora = FineTuneMethod::lora(64);
match lora {
FineTuneMethod::LoRA { rank, alpha } => {
assert_eq!(rank, 64);
assert!((alpha - 64.0).abs() < f32::EPSILON);
}
_ => panic!("Expected LoRA"),
}
}
#[test]
fn test_fine_tune_method_qlora() {
let qlora = FineTuneMethod::qlora(32);
match qlora {
FineTuneMethod::QLoRA { rank, bits } => {
assert_eq!(rank, 32);
assert_eq!(bits, 4);
}
_ => panic!("Expected QLoRA"),
}
}
#[test]
fn test_fine_tune_method_memory_reduction() {
let full = FineTuneMethod::Full;
let lora = FineTuneMethod::lora(64);
let qlora = FineTuneMethod::qlora(32);
assert!((full.memory_reduction_factor() - 1.0).abs() < 0.01);
assert!(lora.memory_reduction_factor() > 1.0);
assert!(qlora.memory_reduction_factor() > lora.memory_reduction_factor());
}
#[test]
fn test_fine_tune_method_trainable_params() {
let full = FineTuneMethod::Full;
let lora = FineTuneMethod::lora(8);
let ia3 = FineTuneMethod::IA3;
assert!((full.trainable_params_percent() - 100.0).abs() < 0.01);
assert!(lora.trainable_params_percent() < 1.0);
assert!(ia3.trainable_params_percent() < lora.trainable_params_percent());
}
#[test]
fn test_fine_tune_method_display() {
assert_eq!(format!("{}", FineTuneMethod::lora(64)), "LoRA(r=64, α=64)");
assert_eq!(format!("{}", FineTuneMethod::qlora(32)), "QLoRA(r=32, 4-bit)");
assert_eq!(format!("{}", FineTuneMethod::IA3), "IA³");
}
#[test]
fn test_model_paradigm_lora() {
let paradigm = ModelParadigm::lora(64, 16.0);
match paradigm {
ModelParadigm::FineTuning(FineTuneMethod::LoRA { rank, alpha }) => {
assert_eq!(rank, 64);
assert!((alpha - 16.0).abs() < f32::EPSILON);
}
_ => panic!("Expected LoRA fine-tuning"),
}
}
#[test]
fn test_model_paradigm_memory_multiplier() {
let traditional = ModelParadigm::TraditionalMl;
let deep = ModelParadigm::DeepLearning;
let lora = ModelParadigm::lora(64, 64.0);
let distill = ModelParadigm::Distillation;
assert!(traditional.typical_memory_multiplier() < deep.typical_memory_multiplier());
assert!(lora.typical_memory_multiplier() < deep.typical_memory_multiplier());
assert!(distill.typical_memory_multiplier() > deep.typical_memory_multiplier());
}
#[test]
fn test_model_paradigm_training_speedup() {
let deep = ModelParadigm::DeepLearning;
let lora = ModelParadigm::lora(64, 64.0);
let traditional = ModelParadigm::TraditionalMl;
assert!((deep.typical_training_speedup() - 1.0).abs() < 0.01);
assert!(lora.typical_training_speedup() > deep.typical_training_speedup());
assert!(traditional.typical_training_speedup() > lora.typical_training_speedup());
}
#[test]
fn test_model_paradigm_quality_retention() {
let deep = ModelParadigm::DeepLearning;
let lora = ModelParadigm::lora(64, 64.0);
let distill = ModelParadigm::Distillation;
let ensemble = ModelParadigm::Ensemble;
assert!((deep.typical_quality_retention() - 1.0).abs() < 0.01);
assert!(lora.typical_quality_retention() > 0.9);
assert!(distill.typical_quality_retention() < 1.0);
assert!(ensemble.typical_quality_retention() > 1.0);
}
#[test]
fn test_model_paradigm_requires_pretrained() {
assert!(!ModelParadigm::TraditionalMl.requires_pretrained());
assert!(!ModelParadigm::DeepLearning.requires_pretrained());
assert!(ModelParadigm::lora(64, 64.0).requires_pretrained());
assert!(ModelParadigm::Distillation.requires_pretrained());
}
#[test]
fn test_model_paradigm_is_parameter_efficient() {
assert!(!ModelParadigm::DeepLearning.is_parameter_efficient());
assert!(!ModelParadigm::FineTuning(FineTuneMethod::Full).is_parameter_efficient());
assert!(ModelParadigm::lora(64, 64.0).is_parameter_efficient());
assert!(ModelParadigm::qlora(32, 4).is_parameter_efficient());
}
#[test]
fn test_model_paradigm_batch_size_multiplier() {
let deep = ModelParadigm::DeepLearning;
let qlora = ModelParadigm::qlora(32, 4);
let distill = ModelParadigm::Distillation;
assert!((deep.batch_size_multiplier() - 1.0).abs() < 0.01);
assert!(qlora.batch_size_multiplier() > deep.batch_size_multiplier());
assert!(distill.batch_size_multiplier() < deep.batch_size_multiplier());
}
#[test]
fn test_model_paradigm_default() {
let default = ModelParadigm::default();
assert!(matches!(default, ModelParadigm::DeepLearning));
}
#[test]
fn test_model_paradigm_display() {
assert_eq!(format!("{}", ModelParadigm::TraditionalMl), "Traditional ML");
assert_eq!(format!("{}", ModelParadigm::DeepLearning), "Deep Learning");
assert_eq!(format!("{}", ModelParadigm::MoE), "Mixture of Experts");
assert!(format!("{}", ModelParadigm::lora(64, 64.0)).contains("LoRA"));
}
#[test]
fn test_model_paradigm_serialization() {
let paradigm = ModelParadigm::lora(32, 16.0);
let json = serde_json::to_string(¶digm).expect("JSON serialization should succeed");
let parsed: ModelParadigm =
serde_json::from_str(&json).expect("JSON deserialization should succeed");
assert!(parsed.is_parameter_efficient());
}
#[test]
fn test_fine_tune_method_display_all_variants() {
assert_eq!(format!("{}", FineTuneMethod::Adapter), "Adapter");
assert_eq!(format!("{}", FineTuneMethod::Prefix), "Prefix");
assert_eq!(format!("{}", FineTuneMethod::Full), "Full");
}
#[test]
fn test_fine_tune_method_memory_reduction_all_variants() {
assert!((FineTuneMethod::Adapter.memory_reduction_factor() - 10.0).abs() < 0.01);
assert!((FineTuneMethod::Prefix.memory_reduction_factor() - 20.0).abs() < 0.01);
assert!((FineTuneMethod::IA3.memory_reduction_factor() - 50.0).abs() < 0.01);
assert!((FineTuneMethod::Full.memory_reduction_factor() - 1.0).abs() < 0.01);
}
#[test]
fn test_fine_tune_method_trainable_params_all_variants() {
assert!((FineTuneMethod::Adapter.trainable_params_percent() - 5.0).abs() < 0.01);
assert!((FineTuneMethod::Prefix.trainable_params_percent() - 1.0).abs() < 0.01);
assert!((FineTuneMethod::IA3.trainable_params_percent() - 0.01).abs() < 0.001);
assert!((FineTuneMethod::Full.trainable_params_percent() - 100.0).abs() < 0.01);
}
#[test]
fn test_fine_tune_method_qlora_memory_with_different_bits() {
let qlora_4bit = FineTuneMethod::QLoRA { rank: 32, bits: 4 };
let qlora_8bit = FineTuneMethod::QLoRA { rank: 32, bits: 8 };
assert!(qlora_4bit.memory_reduction_factor() > qlora_8bit.memory_reduction_factor());
}
#[test]
fn test_fine_tune_method_lora_rank_zero() {
let lora = FineTuneMethod::LoRA { rank: 0, alpha: 0.0 };
let reduction = lora.memory_reduction_factor();
assert!((reduction - 100.0).abs() < 0.01); }
#[test]
fn test_fine_tune_method_high_rank_trainable_params() {
let lora_high = FineTuneMethod::LoRA { rank: 256, alpha: 256.0 };
let lora_low = FineTuneMethod::LoRA { rank: 8, alpha: 8.0 };
assert!((lora_high.trainable_params_percent() - 0.2).abs() < 0.01);
assert!((lora_low.trainable_params_percent() - 0.1).abs() < 0.01);
}
#[test]
fn test_model_paradigm_typical_memory_multiplier_all_finetuning() {
let full = ModelParadigm::FineTuning(FineTuneMethod::Full);
let adapter = ModelParadigm::FineTuning(FineTuneMethod::Adapter);
let prefix = ModelParadigm::FineTuning(FineTuneMethod::Prefix);
let ia3 = ModelParadigm::FineTuning(FineTuneMethod::IA3);
assert!((full.typical_memory_multiplier() - 4.0).abs() < 0.01);
assert!((adapter.typical_memory_multiplier() - 1.5).abs() < 0.01);
assert!((prefix.typical_memory_multiplier() - 1.3).abs() < 0.01);
assert!((ia3.typical_memory_multiplier() - 1.1).abs() < 0.01);
}
#[test]
fn test_model_paradigm_typical_training_speedup_all_finetuning() {
let full = ModelParadigm::FineTuning(FineTuneMethod::Full);
let qlora = ModelParadigm::FineTuning(FineTuneMethod::qlora(32));
let adapter = ModelParadigm::FineTuning(FineTuneMethod::Adapter);
let prefix = ModelParadigm::FineTuning(FineTuneMethod::Prefix);
let ia3 = ModelParadigm::FineTuning(FineTuneMethod::IA3);
assert!((full.typical_training_speedup() - 2.0).abs() < 0.01);
assert!((qlora.typical_training_speedup() - 6.0).abs() < 0.01);
assert!((adapter.typical_training_speedup() - 4.0).abs() < 0.01);
assert!((prefix.typical_training_speedup() - 5.0).abs() < 0.01);
assert!((ia3.typical_training_speedup() - 8.0).abs() < 0.01);
}
#[test]
fn test_model_paradigm_lora_training_speedup_varies_with_rank() {
let lora_low = ModelParadigm::lora(8, 8.0);
let lora_high = ModelParadigm::lora(128, 128.0);
assert!(lora_low.typical_training_speedup() > lora_high.typical_training_speedup());
}
#[test]
fn test_model_paradigm_typical_quality_retention_all_finetuning() {
let full = ModelParadigm::FineTuning(FineTuneMethod::Full);
let qlora = ModelParadigm::FineTuning(FineTuneMethod::qlora(32));
let adapter = ModelParadigm::FineTuning(FineTuneMethod::Adapter);
let prefix = ModelParadigm::FineTuning(FineTuneMethod::Prefix);
let ia3 = ModelParadigm::FineTuning(FineTuneMethod::IA3);
assert!((full.typical_quality_retention() - 1.0).abs() < 0.01);
assert!((qlora.typical_quality_retention() - 0.93).abs() < 0.01);
assert!((adapter.typical_quality_retention() - 0.92).abs() < 0.01);
assert!((prefix.typical_quality_retention() - 0.88).abs() < 0.01);
assert!((ia3.typical_quality_retention() - 0.90).abs() < 0.01);
}
#[test]
fn test_model_paradigm_lora_quality_varies_with_rank() {
let lora_low = ModelParadigm::lora(8, 8.0);
let lora_high = ModelParadigm::lora(64, 64.0);
assert!(lora_high.typical_quality_retention() > lora_low.typical_quality_retention());
}
#[test]
fn test_model_paradigm_batch_size_multiplier_all_finetuning() {
let full = ModelParadigm::FineTuning(FineTuneMethod::Full);
let adapter = ModelParadigm::FineTuning(FineTuneMethod::Adapter);
let prefix = ModelParadigm::FineTuning(FineTuneMethod::Prefix);
let ia3 = ModelParadigm::FineTuning(FineTuneMethod::IA3);
assert!((full.batch_size_multiplier() - 1.0).abs() < 0.01);
assert!((adapter.batch_size_multiplier() - 1.5).abs() < 0.01);
assert!((prefix.batch_size_multiplier() - 1.8).abs() < 0.01);
assert!((ia3.batch_size_multiplier() - 3.0).abs() < 0.01);
}
#[test]
fn test_model_paradigm_moe_characteristics() {
let moe = ModelParadigm::MoE;
assert!((moe.typical_memory_multiplier() - 2.0).abs() < 0.01);
assert!((moe.typical_training_speedup() - 0.8).abs() < 0.01);
assert!(moe.typical_quality_retention() > 1.0); assert!((moe.batch_size_multiplier() - 1.2).abs() < 0.01);
assert!(!moe.requires_pretrained());
assert!(!moe.is_parameter_efficient());
}
#[test]
fn test_model_paradigm_ensemble_characteristics() {
let ensemble = ModelParadigm::Ensemble;
assert!((ensemble.typical_memory_multiplier() - 3.0).abs() < 0.01);
assert!((ensemble.typical_training_speedup() - 0.5).abs() < 0.01);
assert!(ensemble.typical_quality_retention() > 1.0); assert!((ensemble.batch_size_multiplier() - 0.3).abs() < 0.01);
assert!(!ensemble.requires_pretrained());
assert!(!ensemble.is_parameter_efficient());
}
#[test]
fn test_model_paradigm_distillation_characteristics() {
let distill = ModelParadigm::Distillation;
assert!((distill.typical_memory_multiplier() - 5.0).abs() < 0.01);
assert!((distill.typical_training_speedup() - 1.5).abs() < 0.01);
assert!((distill.typical_quality_retention() - 0.85).abs() < 0.01);
assert!((distill.batch_size_multiplier() - 0.5).abs() < 0.01);
assert!(distill.requires_pretrained());
assert!(!distill.is_parameter_efficient());
}
#[test]
fn test_model_paradigm_traditional_ml_characteristics() {
let trad = ModelParadigm::TraditionalMl;
assert!((trad.typical_memory_multiplier() - 1.5).abs() < 0.01);
assert!((trad.typical_training_speedup() - 10.0).abs() < 0.01);
assert!((trad.typical_quality_retention() - 0.7).abs() < 0.01);
assert!((trad.batch_size_multiplier() - 10.0).abs() < 0.01);
assert!(!trad.requires_pretrained());
assert!(!trad.is_parameter_efficient());
}
#[test]
fn test_model_paradigm_display_all_variants() {
assert_eq!(format!("{}", ModelParadigm::Distillation), "Knowledge Distillation");
assert_eq!(format!("{}", ModelParadigm::Ensemble), "Ensemble");
let prefix = ModelParadigm::FineTuning(FineTuneMethod::Prefix);
assert!(format!("{prefix}").contains("Prefix"));
}
#[test]
fn test_model_paradigm_serialization_all_variants() {
let variants = vec![
ModelParadigm::TraditionalMl,
ModelParadigm::DeepLearning,
ModelParadigm::Distillation,
ModelParadigm::MoE,
ModelParadigm::Ensemble,
ModelParadigm::FineTuning(FineTuneMethod::Adapter),
ModelParadigm::FineTuning(FineTuneMethod::Prefix),
ModelParadigm::FineTuning(FineTuneMethod::IA3),
ModelParadigm::FineTuning(FineTuneMethod::Full),
];
for paradigm in variants {
let json = serde_json::to_string(¶digm).expect("JSON serialization should succeed");
let parsed: ModelParadigm =
serde_json::from_str(&json).expect("JSON deserialization should succeed");
assert_eq!(paradigm, parsed);
}
}
#[test]
fn test_fine_tune_method_serde_roundtrip() {
let methods = vec![
FineTuneMethod::lora(64),
FineTuneMethod::qlora(32),
FineTuneMethod::Adapter,
FineTuneMethod::Prefix,
FineTuneMethod::IA3,
FineTuneMethod::Full,
];
for method in methods {
let json = serde_json::to_string(&method).expect("JSON serialization should succeed");
let parsed: FineTuneMethod =
serde_json::from_str(&json).expect("JSON deserialization should succeed");
assert_eq!(method, parsed);
}
}
#[test]
fn test_model_paradigm_is_parameter_efficient_adapter_prefix_ia3() {
let adapter = ModelParadigm::FineTuning(FineTuneMethod::Adapter);
let prefix = ModelParadigm::FineTuning(FineTuneMethod::Prefix);
let ia3 = ModelParadigm::FineTuning(FineTuneMethod::IA3);
assert!(adapter.is_parameter_efficient());
assert!(prefix.is_parameter_efficient());
assert!(ia3.is_parameter_efficient());
}
#[test]
fn test_fine_tune_method_qlora_trainable_params() {
let qlora = FineTuneMethod::qlora(16);
let expected = 0.1 * (16.0 / 8.0_f64).min(2.0);
assert!((qlora.trainable_params_percent() - expected).abs() < 0.01);
}