use super::instruct_corpus::InstructSample;
use super::instruct_pipeline::{InstructConfig, InstructPipeline, InstructStepResult};
use super::instruct_trainer::InstructEpochMetrics;
use crate::lora::LoRALayer;
use serde::Deserialize;
use std::path::{Path, PathBuf};
/// Policy for deciding which adapter trains on the next step.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum AdapterSchedule {
    /// Every adapter takes one step per `batch_train_step` call, in index order.
    Synchronized,
    /// Adapters take turns: adapter index is `global_step % num_adapters` (the default).
    #[default]
    RoundRobin,
    /// The adapter with the highest (worst) `best_val_loss` trains next.
    PriorityValLoss,
}
/// Fully-resolved configuration for one adapter: data location, checkpoint
/// destination, and its training hyperparameters.
#[derive(Debug, Clone)]
pub struct AdapterConfig {
    /// Path to this adapter's training corpus.
    pub data_path: PathBuf,
    /// Directory where this adapter's checkpoints are written.
    pub checkpoint_dir: PathBuf,
    /// Instruct-tuning hyperparameters (LoRA rank/alpha, lr, epochs, ...).
    pub instruct_config: InstructConfig,
}
/// Deserialized adapters TOML file: a list of `[[adapter]]` tables.
#[derive(Debug, Clone, Deserialize)]
pub struct AdaptersConfigFile {
    // Each `[[adapter]]` table in the TOML becomes one `AdapterEntry`.
    #[serde(rename = "adapter")]
    pub adapters: Vec<AdapterEntry>,
}
/// One `[[adapter]]` entry from the TOML file. `data` and `checkpoint` are
/// required; the remaining fields are optional per-adapter overrides layered
/// on top of a shared base `InstructConfig` (see `to_adapter_configs`).
#[derive(Debug, Clone, Deserialize)]
pub struct AdapterEntry {
    /// Required: path to the training corpus.
    pub data: PathBuf,
    /// Required: checkpoint output directory.
    pub checkpoint: PathBuf,
    /// Optional human-readable label (informational only; not used in training).
    #[serde(default)]
    pub label: Option<String>,
    /// Optional LoRA rank override (also rescales alpha to `2 * rank`).
    #[serde(default)]
    pub rank: Option<usize>,
    /// Optional learning-rate override.
    #[serde(default)]
    pub learning_rate: Option<f32>,
    /// Optional epoch-count override.
    #[serde(default)]
    pub epochs: Option<usize>,
    /// Optional maximum sequence length override.
    #[serde(default)]
    pub max_seq_len: Option<usize>,
}
impl AdaptersConfigFile {
    /// Reads and parses an adapters TOML file from disk.
    ///
    /// Errors are returned as human-readable strings including the path.
    pub fn from_file(path: &Path) -> Result<Self, String> {
        let raw = std::fs::read_to_string(path)
            .map_err(|e| format!("failed to read {}: {e}", path.display()))?;
        Self::from_toml(&raw)
    }

    /// Parses the adapters configuration from a TOML string.
    ///
    /// Rejects configurations with zero `[[adapter]]` entries.
    pub fn from_toml(toml_str: &str) -> Result<Self, String> {
        let parsed: Self = toml::from_str(toml_str)
            .map_err(|e| format!("failed to parse adapters TOML: {e}"))?;
        match parsed.adapters.len() {
            0 => Err("adapters config must have at least one [[adapter]] entry".to_string()),
            _ => Ok(parsed),
        }
    }

    /// Expands every entry into a full `AdapterConfig`, overlaying the
    /// per-entry overrides on top of the shared `base` configuration.
    pub fn to_adapter_configs(&self, base: &InstructConfig) -> Vec<AdapterConfig> {
        let mut configs = Vec::with_capacity(self.adapters.len());
        for entry in &self.adapters {
            let mut cfg = base.clone();
            if let Some(rank) = entry.rank {
                cfg.lora_rank = rank;
                // Convention: alpha tracks rank at a fixed 2x ratio whenever
                // the rank is overridden.
                cfg.lora_alpha = rank as f32 * 2.0;
            }
            if let Some(lr) = entry.learning_rate {
                cfg.learning_rate = lr;
            }
            if let Some(n_epochs) = entry.epochs {
                cfg.epochs = n_epochs;
            }
            if let Some(seq_len) = entry.max_seq_len {
                cfg.max_seq_len = seq_len;
            }
            configs.push(AdapterConfig {
                data_path: entry.data.clone(),
                checkpoint_dir: entry.checkpoint.clone(),
                instruct_config: cfg,
            });
        }
        configs
    }
}
/// Runtime state for one adapter under training: its LoRA weights, data
/// splits, checkpoint destination, and scheduling bookkeeping.
pub struct AdapterSlot {
    /// This adapter's LoRA layers; swapped into the base pipeline for the
    /// duration of each training step (see `train_step_adapter`).
    pub lora_layers: Vec<LoRALayer>,
    /// Training split, consumed in order via `cursor` within an epoch.
    pub train_samples: Vec<InstructSample>,
    /// Validation split.
    pub val_samples: Vec<InstructSample>,
    /// Directory where this adapter's checkpoints are written.
    pub checkpoint_dir: PathBuf,
    /// Per-epoch metrics accumulated so far.
    pub metrics: Vec<InstructEpochMetrics>,
    /// This adapter's training hyperparameters.
    pub config: InstructConfig,
    /// Index of the next unconsumed training sample in the current epoch.
    pub cursor: usize,
    /// Best validation loss seen so far; starts at +infinity.
    pub best_val_loss: f32,
    // NOTE(review): unused in this file (hence dead_code) — presumably
    // populated by CUDA training paths elsewhere in the crate; confirm.
    #[cfg(feature = "cuda")]
    #[allow(dead_code)]
    pub(crate) optimizer_states: Option<Vec<crate::transformer::GpuLoraOptimizerState>>,
    /// Step counter for the CUDA LoRA optimizer path.
    #[cfg(feature = "cuda")]
    pub lora_step: u32,
}
/// Trains several LoRA adapters against one shared base model, interleaving
/// their steps according to an `AdapterSchedule`.
pub struct MultiAdapterPipeline {
    /// Shared base model and training machinery; adapters temporarily swap
    /// their LoRA layers into it for each step.
    pub base_pipeline: InstructPipeline,
    /// One slot per registered adapter.
    pub adapters: Vec<AdapterSlot>,
    /// Policy selecting which adapter trains next.
    pub schedule: AdapterSchedule,
    /// Count of successfully dispatched training steps across all adapters.
    pub global_step: usize,
}
impl MultiAdapterPipeline {
    /// Creates an empty pipeline around `base_pipeline` with the given schedule.
    pub fn new(base_pipeline: InstructPipeline, schedule: AdapterSchedule) -> Self {
        Self { base_pipeline, adapters: Vec::new(), schedule, global_step: 0 }
    }

    /// Registers a new adapter: builds fresh LoRA layers for the base model
    /// and stores the provided train/validation splits in a new slot.
    pub fn add_adapter(
        &mut self,
        config: AdapterConfig,
        train_samples: Vec<InstructSample>,
        val_samples: Vec<InstructSample>,
    ) {
        let model_config = &self.base_pipeline.model.config;
        let lora_layers = InstructPipeline::build_lora_layers(
            &self.base_pipeline.model,
            model_config,
            &config.instruct_config,
        );
        let slot = AdapterSlot {
            lora_layers,
            train_samples,
            val_samples,
            checkpoint_dir: config.checkpoint_dir,
            metrics: Vec::new(),
            config: config.instruct_config,
            cursor: 0,
            // No validation run yet, so "best" starts at +infinity.
            best_val_loss: f32::INFINITY,
            #[cfg(feature = "cuda")]
            optimizer_states: None,
            #[cfg(feature = "cuda")]
            lora_step: 0,
        };
        self.adapters.push(slot);
    }

    /// Number of registered adapters.
    pub fn num_adapters(&self) -> usize {
        self.adapters.len()
    }

    /// Chooses which adapter should train next, or `None` if there are no
    /// adapters.
    ///
    /// - `Synchronized`: always returns index 0 (callers step every adapter
    ///   via `batch_train_step` instead of using this choice directly).
    /// - `RoundRobin`: cycles by `global_step % num_adapters`.
    /// - `PriorityValLoss`: picks the adapter with the LARGEST (worst)
    ///   `best_val_loss`; incomparable values (NaN) are treated as equal.
    pub fn select_next_adapter(&self) -> Option<usize> {
        if self.adapters.is_empty() {
            return None;
        }
        match self.schedule {
            AdapterSchedule::Synchronized => {
                Some(0)
            }
            AdapterSchedule::RoundRobin => Some(self.global_step % self.adapters.len()),
            AdapterSchedule::PriorityValLoss => {
                self.adapters
                    .iter()
                    .enumerate()
                    .max_by(|(_, a), (_, b)| {
                        a.best_val_loss
                            .partial_cmp(&b.best_val_loss)
                            .unwrap_or(std::cmp::Ordering::Equal)
                    })
                    .map(|(i, _)| i)
            }
        }
    }

    /// Runs one training step for the adapter at `adapter_idx`.
    ///
    /// Returns `None` when the adapter's epoch is exhausted, the pipeline has
    /// no tokenizer, or either side of the sample tokenizes to nothing.
    /// Note: the cursor advances even when the step is skipped, but
    /// `global_step` only increments when a step actually runs.
    pub fn train_step_adapter(&mut self, adapter_idx: usize) -> Option<InstructStepResult> {
        let slot = &mut self.adapters[adapter_idx];
        if slot.cursor >= slot.train_samples.len() {
            return None;
        }
        let sample = &slot.train_samples[slot.cursor];
        slot.cursor += 1;
        if !self.base_pipeline.has_tokenizer() {
            return None;
        }
        let prompt_ids = self.base_pipeline.tokenize(&sample.instruction);
        let response_ids = self.base_pipeline.tokenize(&sample.response);
        if prompt_ids.is_empty() || response_ids.is_empty() {
            return None;
        }
        // Swap this adapter's LoRA layers into the shared base pipeline for
        // the duration of the step, then swap them back out, so the base
        // pipeline never permanently owns any adapter's weights.
        std::mem::swap(&mut slot.lora_layers, &mut self.base_pipeline.lora_layers);
        let result = self.base_pipeline.train_step(&prompt_ids, &response_ids);
        std::mem::swap(&mut slot.lora_layers, &mut self.base_pipeline.lora_layers);
        self.global_step += 1;
        Some(result)
    }

    /// Starts a new epoch: rewinds every adapter's cursor and reshuffles its
    /// training samples with a per-adapter seed derived from `seed`.
    pub fn reset_epoch(&mut self, seed: u64) {
        for (i, slot) in self.adapters.iter_mut().enumerate() {
            slot.cursor = 0;
            // Offset the seed by the adapter index so adapters shuffle
            // independently of each other.
            shuffle_samples(&mut slot.train_samples, seed.wrapping_add(i as u64));
        }
    }

    /// True when every adapter has consumed its training split for this
    /// epoch (vacuously true when there are no adapters).
    pub fn all_exhausted(&self) -> bool {
        self.adapters.iter().all(|s| s.cursor >= s.train_samples.len())
    }

    /// Runs one scheduling round and returns one result slot per adapter.
    ///
    /// `Synchronized` steps every adapter once; the other schedules step only
    /// the selected adapter, leaving the rest as `None`.
    pub fn batch_train_step(&mut self) -> Vec<Option<InstructStepResult>> {
        let n = self.adapters.len();
        let mut results = vec![None; n];
        match self.schedule {
            AdapterSchedule::Synchronized => {
                for i in 0..n {
                    results[i] = self.train_step_adapter(i);
                }
            }
            AdapterSchedule::RoundRobin | AdapterSchedule::PriorityValLoss => {
                if let Some(idx) = self.select_next_adapter() {
                    results[idx] = self.train_step_adapter(idx);
                }
            }
        }
        results
    }

    /// Saves an epoch checkpoint for one adapter under
    /// `<checkpoint_dir>/epoch-<epoch>/`: a `metadata.json` plus the LoRA
    /// weights as `model.safetensors`. Returns the checkpoint directory.
    pub fn save_adapter_checkpoint(
        &self,
        adapter_idx: usize,
        epoch: usize,
        avg_loss: f32,
    ) -> Result<PathBuf, Box<dyn std::error::Error>> {
        let slot = &self.adapters[adapter_idx];
        let ckpt_dir = slot.checkpoint_dir.join(format!("epoch-{epoch}"));
        std::fs::create_dir_all(&ckpt_dir)?;
        let metadata = serde_json::json!({
            "mode": "multi_adapter",
            "adapter_index": adapter_idx,
            "epoch": epoch,
            "avg_loss": avg_loss,
            "best_val_loss": slot.best_val_loss,
            "lora_rank": slot.config.lora_rank,
            "lora_alpha": slot.config.lora_alpha,
            "train_samples": slot.train_samples.len(),
            "global_step": self.global_step,
        });
        std::fs::write(ckpt_dir.join("metadata.json"), serde_json::to_string_pretty(&metadata)?)?;
        save_adapter_lora_weights(&slot.lora_layers, &ckpt_dir)?;
        Ok(ckpt_dir)
    }

    /// Saves (overwrites) the "best so far" checkpoint for one adapter under
    /// `<checkpoint_dir>/best/`. Same layout as `save_adapter_checkpoint`
    /// minus the val-loss/sample-count fields.
    pub fn save_best_checkpoint(
        &self,
        adapter_idx: usize,
        epoch: usize,
        avg_loss: f32,
    ) -> Result<PathBuf, Box<dyn std::error::Error>> {
        let slot = &self.adapters[adapter_idx];
        let best_dir = slot.checkpoint_dir.join("best");
        std::fs::create_dir_all(&best_dir)?;
        let metadata = serde_json::json!({
            "mode": "multi_adapter",
            "adapter_index": adapter_idx,
            "epoch": epoch,
            "avg_loss": avg_loss,
            "lora_rank": slot.config.lora_rank,
            "lora_alpha": slot.config.lora_alpha,
            "global_step": self.global_step,
        });
        std::fs::write(best_dir.join("metadata.json"), serde_json::to_string_pretty(&metadata)?)?;
        save_adapter_lora_weights(&slot.lora_layers, &best_dir)?;
        Ok(best_dir)
    }
}
/// Serializes a set of LoRA layers to `<dir>/model.safetensors`.
///
/// Layers are assumed to be interleaved as `[q_proj, v_proj]` pairs per
/// transformer layer: even indices map to `q`, odd to `v`, and the
/// transformer layer index is `idx / 2`. Tensor names follow the pattern
/// `lora.<layer>.<q|v>_proj.lora_{a,b}`.
fn save_adapter_lora_weights(
    lora_layers: &[LoRALayer],
    dir: &std::path::Path,
) -> Result<(), Box<dyn std::error::Error>> {
    // Collect owned (name, bytes, shape) triples first: the TensorViews
    // built below only borrow these buffers, so they must outlive the views.
    let mut tensor_data: Vec<(String, Vec<u8>, Vec<usize>)> = Vec::new();
    for (idx, lora) in lora_layers.iter().enumerate() {
        let layer = idx / 2;
        let proj = if idx % 2 == 0 { "q" } else { "v" };
        // A matrix: shape [rank, d_in], stored as raw f32 bytes.
        let a_data = lora.lora_a().data();
        let a_bytes: Vec<u8> =
            bytemuck::cast_slice(a_data.as_slice().expect("contiguous lora_a")).to_vec();
        let a_shape = vec![lora.rank(), lora.d_in()];
        tensor_data.push((format!("lora.{layer}.{proj}_proj.lora_a"), a_bytes, a_shape));
        // B matrix: shape [d_out, rank].
        let b_data = lora.lora_b().data();
        let b_bytes: Vec<u8> =
            bytemuck::cast_slice(b_data.as_slice().expect("contiguous lora_b")).to_vec();
        let b_shape = vec![lora.d_out(), lora.rank()];
        tensor_data.push((format!("lora.{layer}.{proj}_proj.lora_b"), b_bytes, b_shape));
    }
    // Build borrowed views over the owned buffers for SafeTensors.
    let views: Vec<(&str, safetensors::tensor::TensorView<'_>)> = tensor_data
        .iter()
        .map(|(name, bytes, shape)| {
            let view = safetensors::tensor::TensorView::new(
                safetensors::tensor::Dtype::F32,
                shape.clone(),
                bytes,
            )
            .expect("valid tensor view");
            (name.as_str(), view)
        })
        .collect();
    let safetensor_bytes = safetensors::serialize(views, None)
        .map_err(|e| format!("SafeTensors serialization failed: {e}"))?;
    std::fs::write(dir.join("model.safetensors"), safetensor_bytes)?;
    Ok(())
}
/// Deterministically shuffles `samples` in place with a Fisher-Yates pass
/// driven by an xorshift64 PRNG seeded from `seed`.
///
/// Generalized over the element type (callers with `InstructSample` slices
/// are unaffected). A `seed` of zero is remapped to a fixed nonzero
/// constant: xorshift64 has a fixed point at state 0 (it would emit 0
/// forever), which previously degenerated into always swapping with index 0
/// instead of a uniform shuffle. All nonzero seeds behave exactly as before.
fn shuffle_samples<T>(samples: &mut [T], seed: u64) {
    // Any nonzero replacement works; the 64-bit golden-ratio constant gives
    // good bit dispersion from the very first xorshift step.
    let mut rng = if seed == 0 { 0x9E37_79B9_7F4A_7C15 } else { seed };
    for i in (1..samples.len()).rev() {
        // One xorshift64 step (Marsaglia's 13/7/17 variant).
        rng ^= rng << 13;
        rng ^= rng >> 7;
        rng ^= rng << 17;
        // Modulo bias is negligible here and matches the original behavior.
        let j = (rng as usize) % (i + 1);
        samples.swap(i, j);
    }
}
#[cfg(test)]
mod tests {
use super::*;
// RoundRobin selects global_step % num_adapters.
#[test]
fn test_schedule_round_robin() {
    let sched = AdapterSchedule::RoundRobin;
    let pipeline = MultiAdapterPipeline {
        base_pipeline: create_dummy_pipeline(),
        adapters: vec![dummy_slot(), dummy_slot(), dummy_slot()],
        schedule: sched,
        global_step: 0,
    };
    assert_eq!(pipeline.select_next_adapter(), Some(0));
    let pipeline = MultiAdapterPipeline { global_step: 1, ..pipeline };
    assert_eq!(pipeline.select_next_adapter(), Some(1));
    let pipeline = MultiAdapterPipeline { global_step: 5, ..pipeline };
    assert_eq!(pipeline.select_next_adapter(), Some(2));
}
// PriorityValLoss picks the adapter with the highest (worst) val loss.
#[test]
fn test_schedule_priority_val_loss() {
    let mut slot0 = dummy_slot();
    slot0.best_val_loss = 1.0;
    let mut slot1 = dummy_slot();
    slot1.best_val_loss = 3.0;
    let mut slot2 = dummy_slot();
    slot2.best_val_loss = 2.0;
    let pipeline = MultiAdapterPipeline {
        base_pipeline: create_dummy_pipeline(),
        adapters: vec![slot0, slot1, slot2],
        schedule: AdapterSchedule::PriorityValLoss,
        global_step: 0,
    };
    assert_eq!(pipeline.select_next_adapter(), Some(1));
}
// No adapters: nothing to select, and "all exhausted" holds vacuously.
#[test]
fn test_empty_pipeline() {
    let pipeline = MultiAdapterPipeline {
        base_pipeline: create_dummy_pipeline(),
        adapters: vec![],
        schedule: AdapterSchedule::RoundRobin,
        global_step: 0,
    };
    assert_eq!(pipeline.select_next_adapter(), None);
    assert!(pipeline.all_exhausted());
}
// Same seed must yield the same shuffle order.
#[test]
fn test_shuffle_deterministic() {
    let mut samples1 = vec![
        InstructSample {
            instruction: "a".into(),
            response: "1".into(),
            system: None,
            metadata: None,
        },
        InstructSample {
            instruction: "b".into(),
            response: "2".into(),
            system: None,
            metadata: None,
        },
        InstructSample {
            instruction: "c".into(),
            response: "3".into(),
            system: None,
            metadata: None,
        },
    ];
    let mut samples2 = samples1.clone();
    shuffle_samples(&mut samples1, 42);
    shuffle_samples(&mut samples2, 42);
    for (s1, s2) in samples1.iter().zip(samples2.iter()) {
        assert_eq!(s1.instruction, s2.instruction);
    }
}
// Synchronized batches return one result slot per adapter.
#[test]
fn test_batch_train_step_synchronized() {
    let mut pipeline = MultiAdapterPipeline {
        base_pipeline: create_dummy_pipeline(),
        adapters: vec![dummy_slot(), dummy_slot()],
        schedule: AdapterSchedule::Synchronized,
        global_step: 0,
    };
    let results = pipeline.batch_train_step();
    assert_eq!(results.len(), 2);
}
// RoundRobin batches also return one slot per adapter (only one is stepped).
#[test]
fn test_batch_train_step_round_robin() {
    let mut pipeline = MultiAdapterPipeline {
        base_pipeline: create_dummy_pipeline(),
        adapters: vec![dummy_slot(), dummy_slot(), dummy_slot()],
        schedule: AdapterSchedule::RoundRobin,
        global_step: 0,
    };
    let results = pipeline.batch_train_step();
    assert_eq!(results.len(), 3);
}
// Parses two [[adapter]] tables with mixed optional fields.
#[test]
fn test_adapters_config_parse() {
    let toml = r#"
[[adapter]]
data = "data/corpus-a.jsonl"
checkpoint = "checkpoints/adapter-a"
label = "code-review"
rank = 16
learning_rate = 0.0002
[[adapter]]
data = "data/corpus-b.jsonl"
checkpoint = "checkpoints/adapter-b"
label = "bug-fixing"
rank = 8
"#;
    let config = AdaptersConfigFile::from_toml(toml).expect("valid TOML");
    assert_eq!(config.adapters.len(), 2);
    assert_eq!(config.adapters[0].data, PathBuf::from("data/corpus-a.jsonl"));
    assert_eq!(config.adapters[0].rank, Some(16));
    assert_eq!(config.adapters[0].learning_rate, Some(0.0002));
    assert_eq!(config.adapters[1].rank, Some(8));
    assert!(config.adapters[1].learning_rate.is_none());
}
// Entry overrides replace the corresponding base config fields.
#[test]
fn test_adapters_config_to_adapter_configs() {
    let toml = r#"
[[adapter]]
data = "data/a.jsonl"
checkpoint = "ckpt/a"
rank = 32
learning_rate = 0.001
epochs = 5
max_seq_len = 256
"#;
    let config = AdaptersConfigFile::from_toml(toml).expect("valid");
    let base = InstructConfig::default();
    let adapters = config.to_adapter_configs(&base);
    assert_eq!(adapters.len(), 1);
    assert_eq!(adapters[0].instruct_config.lora_rank, 32);
    assert!((adapters[0].instruct_config.learning_rate - 0.001).abs() < f32::EPSILON);
    assert_eq!(adapters[0].instruct_config.epochs, 5);
    assert_eq!(adapters[0].instruct_config.max_seq_len, 256);
}
// An empty document has no [[adapter]] entries and must be rejected.
#[test]
fn test_adapters_config_empty_fails() {
    let toml = "";
    assert!(AdaptersConfigFile::from_toml(toml).is_err());
}
// Entries without overrides inherit every field from the base config.
#[test]
fn test_adapters_config_defaults_from_base() {
    let toml = r#"
[[adapter]]
data = "data/x.jsonl"
checkpoint = "ckpt/x"
"#;
    let config = AdaptersConfigFile::from_toml(toml).expect("valid");
    let base = InstructConfig {
        lora_rank: 16,
        learning_rate: 0.0002,
        epochs: 3,
        max_seq_len: 512,
        ..Default::default()
    };
    let adapters = config.to_adapter_configs(&base);
    assert_eq!(adapters[0].instruct_config.lora_rank, 16);
    assert!((adapters[0].instruct_config.learning_rate - 0.0002).abs() < f32::EPSILON);
    assert_eq!(adapters[0].instruct_config.epochs, 3);
    assert_eq!(adapters[0].instruct_config.max_seq_len, 512);
}
// Builds a minimal pipeline over the tiny transformer config.
// NOTE(review): tests below rely on this pipeline having no tokenizer
// attached, so train steps are skipped — confirm InstructPipeline::new
// does not install one.
fn create_dummy_pipeline() -> InstructPipeline {
    use crate::transformer::TransformerConfig;
    let config = TransformerConfig::tiny();
    InstructPipeline::new(&config, InstructConfig::default())
}
// An empty adapter slot: no layers, no samples, default config.
fn dummy_slot() -> AdapterSlot {
    AdapterSlot {
        lora_layers: Vec::new(),
        train_samples: Vec::new(),
        val_samples: Vec::new(),
        checkpoint_dir: PathBuf::from("/tmp/test"),
        metrics: Vec::new(),
        config: InstructConfig::default(),
        cursor: 0,
        best_val_loss: f32::INFINITY,
        #[cfg(feature = "cuda")]
        optimizer_states: None,
        #[cfg(feature = "cuda")]
        lora_step: 0,
    }
}
// A slot pre-filled with `n_samples` synthetic training samples.
fn dummy_slot_with_data(n_samples: usize) -> AdapterSlot {
    let samples: Vec<InstructSample> = (0..n_samples)
        .map(|i| InstructSample {
            instruction: format!("inst_{i}"),
            response: format!("resp_{i}"),
            system: None,
            metadata: None,
        })
        .collect();
    AdapterSlot {
        lora_layers: Vec::new(),
        train_samples: samples,
        val_samples: Vec::new(),
        checkpoint_dir: PathBuf::from("/tmp/test"),
        metrics: Vec::new(),
        config: InstructConfig::default(),
        cursor: 0,
        best_val_loss: f32::INFINITY,
        #[cfg(feature = "cuda")]
        optimizer_states: None,
        #[cfg(feature = "cuda")]
        lora_step: 0,
    }
}
// Default schedule is RoundRobin (via #[default]).
#[test]
fn test_adapter_schedule_default() {
    let sched: AdapterSchedule = Default::default();
    assert_eq!(sched, AdapterSchedule::RoundRobin);
}
// Debug output matches the variant names.
#[test]
fn test_adapter_schedule_debug() {
    assert_eq!(format!("{:?}", AdapterSchedule::Synchronized), "Synchronized");
    assert_eq!(format!("{:?}", AdapterSchedule::RoundRobin), "RoundRobin");
    assert_eq!(format!("{:?}", AdapterSchedule::PriorityValLoss), "PriorityValLoss");
}
// Copy/Clone semantics.
#[test]
fn test_adapter_schedule_clone() {
    let sched = AdapterSchedule::PriorityValLoss;
    let cloned = sched;
    assert_eq!(sched, cloned);
}
// Eq/PartialEq across variants.
#[test]
fn test_adapter_schedule_eq() {
    assert_eq!(AdapterSchedule::Synchronized, AdapterSchedule::Synchronized);
    assert_ne!(AdapterSchedule::Synchronized, AdapterSchedule::RoundRobin);
    assert_ne!(AdapterSchedule::RoundRobin, AdapterSchedule::PriorityValLoss);
}
// Synchronized always reports index 0.
#[test]
fn test_select_next_adapter_synchronized() {
    let pipeline = MultiAdapterPipeline {
        base_pipeline: create_dummy_pipeline(),
        adapters: vec![dummy_slot(), dummy_slot()],
        schedule: AdapterSchedule::Synchronized,
        global_step: 0,
    };
    assert_eq!(pipeline.select_next_adapter(), Some(0));
}
// ... regardless of the current global_step.
#[test]
fn test_select_next_adapter_synchronized_any_step() {
    let pipeline = MultiAdapterPipeline {
        base_pipeline: create_dummy_pipeline(),
        adapters: vec![dummy_slot(), dummy_slot()],
        schedule: AdapterSchedule::Synchronized,
        global_step: 42,
    };
    assert_eq!(pipeline.select_next_adapter(), Some(0));
}
// RoundRobin wraps back to index 0 after a full cycle.
#[test]
fn test_select_next_adapter_round_robin_wraps() {
    let pipeline = MultiAdapterPipeline {
        base_pipeline: create_dummy_pipeline(),
        adapters: vec![dummy_slot(), dummy_slot(), dummy_slot()],
        schedule: AdapterSchedule::RoundRobin,
        global_step: 3,
    };
    assert_eq!(pipeline.select_next_adapter(), Some(0));
}
// All-infinity losses still produce a valid selection.
#[test]
fn test_select_next_adapter_priority_all_infinity() {
    let pipeline = MultiAdapterPipeline {
        base_pipeline: create_dummy_pipeline(),
        adapters: vec![dummy_slot(), dummy_slot()],
        schedule: AdapterSchedule::PriorityValLoss,
        global_step: 0,
    };
    let result = pipeline.select_next_adapter();
    assert!(result.is_some());
}
// NaN losses compare as Equal and must not panic or yield None.
#[test]
fn test_select_next_adapter_priority_with_nan() {
    let mut slot0 = dummy_slot();
    slot0.best_val_loss = f32::NAN;
    let mut slot1 = dummy_slot();
    slot1.best_val_loss = 1.0;
    let pipeline = MultiAdapterPipeline {
        base_pipeline: create_dummy_pipeline(),
        adapters: vec![slot0, slot1],
        schedule: AdapterSchedule::PriorityValLoss,
        global_step: 0,
    };
    let result = pipeline.select_next_adapter();
    assert!(result.is_some());
}
// num_adapters reflects the number of registered slots.
#[test]
fn test_num_adapters() {
    let pipeline = MultiAdapterPipeline {
        base_pipeline: create_dummy_pipeline(),
        adapters: vec![dummy_slot(), dummy_slot(), dummy_slot()],
        schedule: AdapterSchedule::RoundRobin,
        global_step: 0,
    };
    assert_eq!(pipeline.num_adapters(), 3);
}
// ... and is zero for an empty pipeline.
#[test]
fn test_num_adapters_empty() {
    let pipeline = MultiAdapterPipeline {
        base_pipeline: create_dummy_pipeline(),
        adapters: vec![],
        schedule: AdapterSchedule::RoundRobin,
        global_step: 0,
    };
    assert_eq!(pipeline.num_adapters(), 0);
}
// Fresh slots with data are not exhausted.
#[test]
fn test_all_exhausted_with_data() {
    let pipeline = MultiAdapterPipeline {
        base_pipeline: create_dummy_pipeline(),
        adapters: vec![dummy_slot_with_data(3), dummy_slot_with_data(2)],
        schedule: AdapterSchedule::RoundRobin,
        global_step: 0,
    };
    assert!(!pipeline.all_exhausted());
}
// One exhausted slot is not enough: ALL must be exhausted.
#[test]
fn test_all_exhausted_partially() {
    let mut slot0 = dummy_slot_with_data(3);
    slot0.cursor = 3;
    let slot1 = dummy_slot_with_data(2);
    let pipeline = MultiAdapterPipeline {
        base_pipeline: create_dummy_pipeline(),
        adapters: vec![slot0, slot1],
        schedule: AdapterSchedule::RoundRobin,
        global_step: 0,
    };
    assert!(!pipeline.all_exhausted());
}
// Exhausted once every cursor reaches its sample count.
#[test]
fn test_all_exhausted_all_done() {
    let mut slot0 = dummy_slot_with_data(3);
    slot0.cursor = 3;
    let mut slot1 = dummy_slot_with_data(2);
    slot1.cursor = 2;
    let pipeline = MultiAdapterPipeline {
        base_pipeline: create_dummy_pipeline(),
        adapters: vec![slot0, slot1],
        schedule: AdapterSchedule::RoundRobin,
        global_step: 0,
    };
    assert!(pipeline.all_exhausted());
}
// reset_epoch rewinds every adapter's cursor to zero.
#[test]
fn test_reset_epoch() {
    let mut pipeline = MultiAdapterPipeline {
        base_pipeline: create_dummy_pipeline(),
        adapters: vec![dummy_slot_with_data(5), dummy_slot_with_data(3)],
        schedule: AdapterSchedule::RoundRobin,
        global_step: 0,
    };
    pipeline.adapters[0].cursor = 5;
    pipeline.adapters[1].cursor = 3;
    pipeline.reset_epoch(42);
    assert_eq!(pipeline.adapters[0].cursor, 0);
    assert_eq!(pipeline.adapters[1].cursor, 0);
}
// Same reset seed produces the same sample order across pipelines.
#[test]
fn test_reset_epoch_shuffle_deterministic() {
    let mut pipeline1 = MultiAdapterPipeline {
        base_pipeline: create_dummy_pipeline(),
        adapters: vec![dummy_slot_with_data(10)],
        schedule: AdapterSchedule::RoundRobin,
        global_step: 0,
    };
    let mut pipeline2 = MultiAdapterPipeline {
        base_pipeline: create_dummy_pipeline(),
        adapters: vec![dummy_slot_with_data(10)],
        schedule: AdapterSchedule::RoundRobin,
        global_step: 0,
    };
    pipeline1.reset_epoch(123);
    pipeline2.reset_epoch(123);
    for (s1, s2) in pipeline1.adapters[0]
        .train_samples
        .iter()
        .zip(pipeline2.adapters[0].train_samples.iter())
    {
        assert_eq!(s1.instruction, s2.instruction);
    }
}
// Shuffling an empty slice is a no-op.
#[test]
fn test_shuffle_samples_empty() {
    let mut samples: Vec<InstructSample> = vec![];
    shuffle_samples(&mut samples, 42);
    assert!(samples.is_empty());
}
// A single element cannot move.
#[test]
fn test_shuffle_samples_single() {
    let mut samples = vec![InstructSample {
        instruction: "only".into(),
        response: "one".into(),
        system: None,
        metadata: None,
    }];
    shuffle_samples(&mut samples, 42);
    assert_eq!(samples.len(), 1);
    assert_eq!(samples[0].instruction, "only");
}
// Different seeds should produce visibly different orders.
#[test]
fn test_shuffle_samples_different_seeds() {
    let mut samples1 = vec![
        InstructSample {
            instruction: "a".into(),
            response: "1".into(),
            system: None,
            metadata: None,
        },
        InstructSample {
            instruction: "b".into(),
            response: "2".into(),
            system: None,
            metadata: None,
        },
        InstructSample {
            instruction: "c".into(),
            response: "3".into(),
            system: None,
            metadata: None,
        },
        InstructSample {
            instruction: "d".into(),
            response: "4".into(),
            system: None,
            metadata: None,
        },
        InstructSample {
            instruction: "e".into(),
            response: "5".into(),
            system: None,
            metadata: None,
        },
    ];
    let mut samples2 = samples1.clone();
    shuffle_samples(&mut samples1, 1);
    shuffle_samples(&mut samples2, 999);
    let same =
        samples1.iter().zip(samples2.iter()).all(|(s1, s2)| s1.instruction == s2.instruction);
    assert!(!same, "Different seeds should produce different shuffles");
}
// Syntactically invalid TOML surfaces the parse-error prefix.
#[test]
fn test_adapters_config_from_toml_invalid_toml() {
    let toml = "this is not valid TOML {{{}}}";
    let result = AdaptersConfigFile::from_toml(toml);
    assert!(result.is_err());
    let err = result.unwrap_err();
    assert!(err.contains("failed to parse"), "Expected parse error, got: {err}");
}
// Valid TOML with no [[adapter]] tables is still rejected.
#[test]
fn test_adapters_config_from_toml_empty_adapters_array() {
    let toml = r#"
[settings]
foo = "bar"
"#;
    let result = AdaptersConfigFile::from_toml(toml);
    assert!(result.is_err());
}
// Missing file surfaces the read-error prefix.
#[test]
fn test_adapters_config_from_file_not_found() {
    let result = AdaptersConfigFile::from_file(Path::new("/tmp/nonexistent_adapters_xyz.toml"));
    assert!(result.is_err());
    let err = result.unwrap_err();
    assert!(err.contains("failed to read"), "Expected read error, got: {err}");
}
// Round-trips a config through a real temp file on disk.
#[test]
fn test_adapters_config_from_file_valid() {
    let dir = std::env::temp_dir().join("entrenar_adapter_cfg_test");
    std::fs::create_dir_all(&dir).expect("create dir");
    let path = dir.join("adapters.toml");
    std::fs::write(
        &path,
        r#"
[[adapter]]
data = "data/a.jsonl"
checkpoint = "ckpt/a"
label = "test-adapter"
"#,
    )
    .expect("write file");
    let config = AdaptersConfigFile::from_file(&path).expect("valid config");
    assert_eq!(config.adapters.len(), 1);
    assert_eq!(config.adapters[0].label, Some("test-adapter".to_string()));
    std::fs::remove_file(&path).expect("cleanup");
}
// Optional entry fields default to None via #[serde(default)].
#[test]
fn test_adapter_entry_defaults() {
    let toml = r#"
[[adapter]]
data = "data/x.jsonl"
checkpoint = "ckpt/x"
"#;
    let config = AdaptersConfigFile::from_toml(toml).expect("valid");
    let entry = &config.adapters[0];
    assert!(entry.label.is_none());
    assert!(entry.rank.is_none());
    assert!(entry.learning_rate.is_none());
    assert!(entry.epochs.is_none());
    assert!(entry.max_seq_len.is_none());
}
// All optional fields populated at once.
#[test]
fn test_adapter_entry_all_fields() {
    let toml = r#"
[[adapter]]
data = "data/full.jsonl"
checkpoint = "ckpt/full"
label = "full-adapter"
rank = 64
learning_rate = 0.001
epochs = 10
max_seq_len = 1024
"#;
    let config = AdaptersConfigFile::from_toml(toml).expect("valid");
    let entry = &config.adapters[0];
    assert_eq!(entry.data, PathBuf::from("data/full.jsonl"));
    assert_eq!(entry.checkpoint, PathBuf::from("ckpt/full"));
    assert_eq!(entry.label, Some("full-adapter".to_string()));
    assert_eq!(entry.rank, Some(64));
    assert_eq!(entry.learning_rate, Some(0.001));
    assert_eq!(entry.epochs, Some(10));
    assert_eq!(entry.max_seq_len, Some(1024));
}
// Overriding rank also rescales alpha to 2 * rank.
#[test]
fn test_to_adapter_configs_rank_sets_alpha() {
    let toml = r#"
[[adapter]]
data = "data/a.jsonl"
checkpoint = "ckpt/a"
rank = 32
"#;
    let config = AdaptersConfigFile::from_toml(toml).expect("valid");
    let base = InstructConfig::default();
    let adapters = config.to_adapter_configs(&base);
    assert_eq!(adapters[0].instruct_config.lora_rank, 32);
    assert!((adapters[0].instruct_config.lora_alpha - 64.0).abs() < f32::EPSILON);
}
// Mixed overrides across several entries; untouched fields come from base.
#[test]
fn test_to_adapter_configs_multiple() {
    let toml = r#"
[[adapter]]
data = "a.jsonl"
checkpoint = "ckpt/a"
rank = 8
learning_rate = 0.0001
[[adapter]]
data = "b.jsonl"
checkpoint = "ckpt/b"
epochs = 20
[[adapter]]
data = "c.jsonl"
checkpoint = "ckpt/c"
max_seq_len = 128
"#;
    let config = AdaptersConfigFile::from_toml(toml).expect("valid");
    let base = InstructConfig {
        lora_rank: 16,
        learning_rate: 0.0002,
        epochs: 3,
        max_seq_len: 512,
        ..Default::default()
    };
    let adapters = config.to_adapter_configs(&base);
    assert_eq!(adapters.len(), 3);
    assert_eq!(adapters[0].instruct_config.lora_rank, 8);
    assert!((adapters[0].instruct_config.learning_rate - 0.0001).abs() < f32::EPSILON);
    assert_eq!(adapters[0].instruct_config.epochs, 3);
    assert_eq!(adapters[1].instruct_config.lora_rank, 16);
    assert_eq!(adapters[1].instruct_config.epochs, 20);
    assert_eq!(adapters[2].instruct_config.max_seq_len, 128);
    assert_eq!(adapters[2].instruct_config.lora_rank, 16);
}
// PriorityValLoss batches still return one slot per adapter.
#[test]
fn test_batch_train_step_priority_val_loss() {
    let mut slot0 = dummy_slot();
    slot0.best_val_loss = 2.0;
    let mut slot1 = dummy_slot();
    slot1.best_val_loss = 5.0;
    let mut pipeline = MultiAdapterPipeline {
        base_pipeline: create_dummy_pipeline(),
        adapters: vec![slot0, slot1],
        schedule: AdapterSchedule::PriorityValLoss,
        global_step: 0,
    };
    let results = pipeline.batch_train_step();
    assert_eq!(results.len(), 2);
}
// Debug derive on AdapterConfig includes struct and field values.
#[test]
fn test_adapter_config_debug() {
    let config = AdapterConfig {
        data_path: PathBuf::from("test.jsonl"),
        checkpoint_dir: PathBuf::from("/tmp/ckpt"),
        instruct_config: InstructConfig::default(),
    };
    let debug = format!("{config:?}");
    assert!(debug.contains("AdapterConfig"));
    assert!(debug.contains("test.jsonl"));
}
// Clone derive produces an equal, independent copy.
#[test]
fn test_adapter_config_clone() {
    let config = AdapterConfig {
        data_path: PathBuf::from("test.jsonl"),
        checkpoint_dir: PathBuf::from("/tmp/ckpt"),
        instruct_config: InstructConfig::default(),
    };
    let cloned = config.clone();
    assert_eq!(cloned.data_path, PathBuf::from("test.jsonl"));
    assert_eq!(cloned.checkpoint_dir, PathBuf::from("/tmp/ckpt"));
}
// Debug derive on the file-level config type.
#[test]
fn test_adapters_config_file_debug() {
    let toml = r#"
[[adapter]]
data = "a.jsonl"
checkpoint = "ckpt/a"
"#;
    let config = AdaptersConfigFile::from_toml(toml).expect("valid");
    let debug = format!("{config:?}");
    assert!(debug.contains("AdaptersConfigFile"));
}
// Debug derive on a single entry.
#[test]
fn test_adapter_entry_debug() {
    let toml = r#"
[[adapter]]
data = "a.jsonl"
checkpoint = "ckpt/a"
label = "test"
"#;
    let config = AdaptersConfigFile::from_toml(toml).expect("valid");
    let debug = format!("{:?}", config.adapters[0]);
    assert!(debug.contains("AdapterEntry"));
    assert!(debug.contains("test"));
}
// Cursor tracks epoch progress relative to the training-sample count.
#[test]
fn test_adapter_slot_cursor_tracking() {
    let mut slot = dummy_slot_with_data(5);
    assert_eq!(slot.cursor, 0);
    slot.cursor = 3;
    assert_eq!(slot.cursor, 3);
    assert!(slot.cursor < slot.train_samples.len());
    slot.cursor = 5;
    assert!(slot.cursor >= slot.train_samples.len());
}
// best_val_loss starts at +infinity and can be lowered.
#[test]
fn test_adapter_slot_best_val_loss() {
    let mut slot = dummy_slot();
    assert_eq!(slot.best_val_loss, f32::INFINITY);
    slot.best_val_loss = 0.5;
    assert!((slot.best_val_loss - 0.5).abs() < f32::EPSILON);
}
// A fresh pipeline starts at global_step 0.
#[test]
fn test_multi_adapter_pipeline_global_step() {
    let pipeline = MultiAdapterPipeline {
        base_pipeline: create_dummy_pipeline(),
        adapters: vec![],
        schedule: AdapterSchedule::RoundRobin,
        global_step: 0,
    };
    assert_eq!(pipeline.global_step, 0);
}
// Stepping an exhausted adapter returns None immediately.
#[test]
fn test_train_step_adapter_exhausted() {
    let mut slot = dummy_slot_with_data(2);
    slot.cursor = 2;
    let mut pipeline = MultiAdapterPipeline {
        base_pipeline: create_dummy_pipeline(),
        adapters: vec![slot],
        schedule: AdapterSchedule::RoundRobin,
        global_step: 0,
    };
    let result = pipeline.train_step_adapter(0);
    assert!(result.is_none(), "Exhausted adapter should return None");
}
// With zero adapters the batch result vector is empty.
#[test]
fn test_batch_train_step_empty() {
    let mut pipeline = MultiAdapterPipeline {
        base_pipeline: create_dummy_pipeline(),
        adapters: vec![],
        schedule: AdapterSchedule::Synchronized,
        global_step: 0,
    };
    let results = pipeline.batch_train_step();
    assert!(results.is_empty());
}
// Constructor starts empty and trivially exhausted.
#[test]
fn test_multi_adapter_pipeline_new() {
    let pipeline =
        MultiAdapterPipeline::new(create_dummy_pipeline(), AdapterSchedule::Synchronized);
    assert_eq!(pipeline.num_adapters(), 0);
    assert_eq!(pipeline.global_step, 0);
    assert!(pipeline.all_exhausted());
}
// add_adapter registers a slot and makes the pipeline non-exhausted.
#[test]
fn test_multi_adapter_pipeline_add_adapter() {
    let mut pipeline =
        MultiAdapterPipeline::new(create_dummy_pipeline(), AdapterSchedule::RoundRobin);
    let config = AdapterConfig {
        data_path: PathBuf::from("data.jsonl"),
        checkpoint_dir: PathBuf::from("/tmp/ckpt"),
        instruct_config: InstructConfig::default(),
    };
    let samples = vec![InstructSample {
        instruction: "test".into(),
        response: "response".into(),
        system: None,
        metadata: None,
    }];
    pipeline.add_adapter(config, samples, vec![]);
    assert_eq!(pipeline.num_adapters(), 1);
    assert!(!pipeline.all_exhausted());
}
#[test]
fn test_train_step_adapter_no_tokenizer() {
let mut pipeline = MultiAdapterPipeline {
base_pipeline: create_dummy_pipeline(),
adapters: vec![dummy_slot_with_data(5)],
schedule: AdapterSchedule::RoundRobin,
global_step: 0,
};
let result = pipeline.train_step_adapter(0);
assert!(result.is_none());
assert_eq!(pipeline.adapters[0].cursor, 1);
}
#[test]
fn test_train_step_increments_global_step() {
let mut pipeline = MultiAdapterPipeline {
base_pipeline: create_dummy_pipeline(),
adapters: vec![dummy_slot_with_data(5)],
schedule: AdapterSchedule::RoundRobin,
global_step: 0,
};
let _ = pipeline.train_step_adapter(0);
}
#[test]
fn test_batch_train_step_synchronized_all_exhausted() {
let mut slot0 = dummy_slot_with_data(1);
slot0.cursor = 1;
let mut slot1 = dummy_slot_with_data(1);
slot1.cursor = 1;
let mut pipeline = MultiAdapterPipeline {
base_pipeline: create_dummy_pipeline(),
adapters: vec![slot0, slot1],
schedule: AdapterSchedule::Synchronized,
global_step: 0,
};
let results = pipeline.batch_train_step();
assert_eq!(results.len(), 2);
assert!(results.iter().all(Option::is_none));
}
#[test]
fn test_reset_epoch_different_seeds_different_orders() {
let mut pipeline1 = MultiAdapterPipeline {
base_pipeline: create_dummy_pipeline(),
adapters: vec![dummy_slot_with_data(20)],
schedule: AdapterSchedule::RoundRobin,
global_step: 0,
};
let mut pipeline2 = MultiAdapterPipeline {
base_pipeline: create_dummy_pipeline(),
adapters: vec![dummy_slot_with_data(20)],
schedule: AdapterSchedule::RoundRobin,
global_step: 0,
};
pipeline1.reset_epoch(1);
pipeline2.reset_epoch(999);
let same = pipeline1.adapters[0]
.train_samples
.iter()
.zip(pipeline2.adapters[0].train_samples.iter())
.all(|(s1, s2)| s1.instruction == s2.instruction);
assert!(!same, "Different seeds should produce different shuffles");
}
#[test]
fn test_shuffle_samples_preserves_elements() {
let mut samples: Vec<InstructSample> = (0..10)
.map(|i| InstructSample {
instruction: format!("inst_{i}"),
response: format!("resp_{i}"),
system: None,
metadata: None,
})
.collect();
let original_instructions: Vec<String> =
samples.iter().map(|s| s.instruction.clone()).collect();
shuffle_samples(&mut samples, 42);
let mut shuffled_instructions: Vec<String> =
samples.iter().map(|s| s.instruction.clone()).collect();
let mut sorted_original = original_instructions.clone();
sorted_original.sort();
shuffled_instructions.sort();
assert_eq!(sorted_original, shuffled_instructions);
}
#[test]
fn test_adapter_slot_metrics_empty() {
let slot = dummy_slot();
assert!(slot.metrics.is_empty());
}
#[test]
fn test_adapter_slot_val_samples() {
let slot = dummy_slot();
assert!(slot.val_samples.is_empty());
}
#[test]
fn test_adapter_slot_lora_layers_empty() {
let slot = dummy_slot();
assert!(slot.lora_layers.is_empty());
}
#[test]
fn test_adapters_config_label_propagation() {
let toml = r#"
[[adapter]]
data = "d1.jsonl"
checkpoint = "c1"
label = "adapter-one"
[[adapter]]
data = "d2.jsonl"
checkpoint = "c2"
"#;
let config = AdaptersConfigFile::from_toml(toml).expect("valid");
assert_eq!(config.adapters[0].label, Some("adapter-one".to_string()));
assert!(config.adapters[1].label.is_none());
}
#[test]
fn test_adapters_config_to_adapter_configs_alpha_calculation() {
let toml = r#"
[[adapter]]
data = "data.jsonl"
checkpoint = "ckpt"
rank = 64
"#;
let config = AdaptersConfigFile::from_toml(toml).expect("valid");
let base = InstructConfig::default();
let adapters = config.to_adapter_configs(&base);
assert!((adapters[0].instruct_config.lora_alpha - 128.0).abs() < f32::EPSILON);
}
#[test]
fn test_select_next_adapter_round_robin_large_step() {
    // Round-robin wraps by step count modulo adapter count, so very large
    // step values still alternate between the two adapters.
    let mut pipeline = MultiAdapterPipeline {
        base_pipeline: create_dummy_pipeline(),
        adapters: vec![dummy_slot(), dummy_slot()],
        schedule: AdapterSchedule::RoundRobin,
        global_step: 1000,
    };
    assert_eq!(pipeline.select_next_adapter(), Some(0));
    pipeline.global_step = 1001;
    assert_eq!(pipeline.select_next_adapter(), Some(1));
}
#[test]
fn test_select_next_adapter_priority_selects_worst() {
    // PriorityValLoss picks the adapter with the highest best_val_loss —
    // here index 1 (loss 10.0).
    let losses = [0.1_f32, 10.0, 5.0];
    let slots: Vec<_> = losses
        .iter()
        .map(|&loss| {
            let mut slot = dummy_slot();
            slot.best_val_loss = loss;
            slot
        })
        .collect();
    let pipeline = MultiAdapterPipeline {
        base_pipeline: create_dummy_pipeline(),
        adapters: slots,
        schedule: AdapterSchedule::PriorityValLoss,
        global_step: 0,
    };
    assert_eq!(pipeline.select_next_adapter(), Some(1));
}
#[test]
fn test_multi_adapter_multiple_add_adapter() {
    // Adding three adapters with no training data: all three are counted,
    // and with zero samples each the pipeline is immediately exhausted.
    let mut pipeline =
        MultiAdapterPipeline::new(create_dummy_pipeline(), AdapterSchedule::Synchronized);
    for idx in 0..3 {
        pipeline.add_adapter(
            AdapterConfig {
                data_path: PathBuf::from(format!("data{idx}.jsonl")),
                checkpoint_dir: PathBuf::from(format!("/tmp/ckpt{idx}")),
                instruct_config: InstructConfig::default(),
            },
            vec![],
            vec![],
        );
    }
    assert_eq!(pipeline.num_adapters(), 3);
    assert!(pipeline.all_exhausted());
}
// Saving an epoch checkpoint must create the directory and write both
// metadata.json and model.safetensors, with multi-adapter metadata fields.
#[test]
fn test_cov3_save_adapter_checkpoint_creates_dir_and_files() {
// Fixed temp path, wiped before the test so reruns start clean.
// NOTE(review): concurrent runs of the test binary would share this path.
let dir = std::env::temp_dir().join("entrenar_cov3_ckpt_test");
let _ = std::fs::remove_dir_all(&dir);
let mut pipeline =
MultiAdapterPipeline::new(create_dummy_pipeline(), AdapterSchedule::RoundRobin);
let config = AdapterConfig {
data_path: PathBuf::from("data.jsonl"),
checkpoint_dir: dir.clone(),
instruct_config: InstructConfig::default(),
};
// One training sample so the adapter has a non-empty dataset to record.
let samples = vec![InstructSample {
instruction: "test".into(),
response: "resp".into(),
system: None,
metadata: None,
}];
pipeline.add_adapter(config, samples, vec![]);
// Save checkpoint for adapter 0, epoch 1, with a train loss of 0.5.
let result = pipeline.save_adapter_checkpoint(0, 1, 0.5);
assert!(result.is_ok());
let ckpt_dir = result.unwrap();
assert!(ckpt_dir.join("metadata.json").exists());
assert!(ckpt_dir.join("model.safetensors").exists());
// Spot-check the JSON by substring; exact values are covered elsewhere.
let metadata_str = std::fs::read_to_string(ckpt_dir.join("metadata.json")).unwrap();
assert!(metadata_str.contains("\"mode\": \"multi_adapter\""));
assert!(metadata_str.contains("\"adapter_index\": 0"));
assert!(metadata_str.contains("\"epoch\": 1"));
// Cleanup is best-effort; skipped automatically if an assert above panics.
let _ = std::fs::remove_dir_all(&dir);
}
// The "best" checkpoint goes into a fixed `best/` subdirectory of the
// adapter's checkpoint_dir and contains metadata plus weights.
#[test]
fn test_cov3_save_best_checkpoint_creates_dir_and_files() {
// Fixed temp path, wiped before the test so reruns start clean.
let dir = std::env::temp_dir().join("entrenar_cov3_best_ckpt_test");
let _ = std::fs::remove_dir_all(&dir);
let mut pipeline =
MultiAdapterPipeline::new(create_dummy_pipeline(), AdapterSchedule::RoundRobin);
let config = AdapterConfig {
data_path: PathBuf::from("data.jsonl"),
checkpoint_dir: dir.clone(),
instruct_config: InstructConfig::default(),
};
pipeline.add_adapter(config, vec![], vec![]);
// Save best for adapter 0 at epoch 2 with val loss 0.3.
let result = pipeline.save_best_checkpoint(0, 2, 0.3);
assert!(result.is_ok());
let best_dir = result.unwrap();
// Best checkpoints always land in <checkpoint_dir>/best.
assert_eq!(best_dir, dir.join("best"));
assert!(best_dir.join("metadata.json").exists());
assert!(best_dir.join("model.safetensors").exists());
let metadata_str = std::fs::read_to_string(best_dir.join("metadata.json")).unwrap();
assert!(metadata_str.contains("\"mode\": \"multi_adapter\""));
assert!(metadata_str.contains("\"epoch\": 2"));
// Best-effort cleanup; skipped if an assert above panics.
let _ = std::fs::remove_dir_all(&dir);
}
// Saving best twice must replace the previous contents of `best/`,
// leaving only the most recent epoch's metadata on disk.
#[test]
fn test_cov3_save_best_checkpoint_overwrites_previous() {
let dir = std::env::temp_dir().join("entrenar_cov3_best_overwrite");
let _ = std::fs::remove_dir_all(&dir);
let mut pipeline =
MultiAdapterPipeline::new(create_dummy_pipeline(), AdapterSchedule::RoundRobin);
let config = AdapterConfig {
data_path: PathBuf::from("data.jsonl"),
checkpoint_dir: dir.clone(),
instruct_config: InstructConfig::default(),
};
pipeline.add_adapter(config, vec![], vec![]);
// First save at epoch 1, then a better one at epoch 5.
pipeline.save_best_checkpoint(0, 1, 1.0).unwrap();
pipeline.save_best_checkpoint(0, 5, 0.2).unwrap();
// The surviving metadata must be from the second save.
let metadata_str = std::fs::read_to_string(dir.join("best").join("metadata.json")).unwrap();
assert!(metadata_str.contains("\"epoch\": 5"));
let _ = std::fs::remove_dir_all(&dir);
}
#[test]
fn test_cov3_save_adapter_lora_weights_empty_layers() {
    // Zero LoRA layers is still a valid save: the function must succeed
    // and emit a (tensor-less) safetensors file.
    let dir = std::env::temp_dir().join("entrenar_cov3_empty_lora");
    let _ = std::fs::remove_dir_all(&dir);
    std::fs::create_dir_all(&dir).unwrap();
    assert!(save_adapter_lora_weights(&[], &dir).is_ok());
    assert!(dir.join("model.safetensors").exists());
    let _ = std::fs::remove_dir_all(&dir);
}
// End-to-end check: build real LoRA layers from a tiny transformer, save
// them, and verify the safetensors file round-trips with the expected
// tensor names (A/B matrices per layer, attached to q/v projections).
#[test]
fn test_cov3_save_adapter_lora_weights_with_real_layers() {
let dir = std::env::temp_dir().join("entrenar_cov3_real_lora");
let _ = std::fs::remove_dir_all(&dir);
std::fs::create_dir_all(&dir).unwrap();
// Smallest available model config keeps the test fast.
let model_config = crate::transformer::TransformerConfig::tiny();
let model = crate::transformer::Transformer::new(&model_config);
let instruct_config = InstructConfig { lora_rank: 4, ..InstructConfig::default() };
let layers = InstructPipeline::build_lora_layers(&model, &model_config, &instruct_config);
let result = save_adapter_lora_weights(&layers, &dir);
assert!(result.is_ok());
// Re-read the file through the safetensors crate to validate the format.
let st_bytes = std::fs::read(dir.join("model.safetensors")).unwrap();
let st = safetensors::SafeTensors::deserialize(&st_bytes).unwrap();
// Each LoRA layer serializes exactly two tensors: lora_a and lora_b.
assert_eq!(st.len(), layers.len() * 2);
let names: Vec<String> = st.names().iter().map(std::string::ToString::to_string).collect();
assert!(names.iter().any(|n| n.contains("lora_a")));
assert!(names.iter().any(|n| n.contains("lora_b")));
assert!(names.iter().any(|n| n.contains("q_proj")));
assert!(names.iter().any(|n| n.contains("v_proj")));
let _ = std::fs::remove_dir_all(&dir);
}
#[test]
fn test_cov3_shuffle_samples_large_input() {
    // With 100 elements an identity permutation is vanishingly unlikely,
    // so we require the order to change while the multiset stays equal.
    let mut pool: Vec<InstructSample> = (0..100)
        .map(|n| InstructSample {
            instruction: format!("inst_{n}"),
            response: format!("resp_{n}"),
            system: None,
            metadata: None,
        })
        .collect();
    let before: Vec<String> = pool.iter().map(|s| s.instruction.clone()).collect();
    shuffle_samples(&mut pool, 12345);
    let after: Vec<String> = pool.iter().map(|s| s.instruction.clone()).collect();
    assert_ne!(before, after, "100 samples should shuffle to different order");
    let mut lhs = before;
    let mut rhs = after;
    lhs.sort();
    rhs.sort();
    assert_eq!(lhs, rhs);
}
#[test]
fn test_cov3_shuffle_samples_two_elements() {
    // Minimal non-trivial input: shuffling two elements must not change
    // the vector's length.
    let mut pair: Vec<InstructSample> = [("a", "1"), ("b", "2")]
        .into_iter()
        .map(|(inst, resp)| InstructSample {
            instruction: inst.into(),
            response: resp.into(),
            system: None,
            metadata: None,
        })
        .collect();
    shuffle_samples(&mut pair, 42);
    assert_eq!(pair.len(), 2);
}
#[test]
fn test_cov3_adapters_config_toml_all_overrides() {
    // Every per-adapter override (rank, learning rate, epochs, seq len)
    // must be applied on top of the default base config; alpha follows rank.
    let source = r#"
[[adapter]]
data = "data/test.jsonl"
checkpoint = "ckpt/test"
label = "full-override"
rank = 64
learning_rate = 0.001
epochs = 20
max_seq_len = 2048
"#;
    let parsed = AdaptersConfigFile::from_toml(source).unwrap();
    let built = parsed.to_adapter_configs(&InstructConfig::default());
    let cfg = &built[0].instruct_config;
    assert_eq!(cfg.lora_rank, 64);
    assert!((cfg.lora_alpha - 128.0).abs() < f32::EPSILON);
    assert!((cfg.learning_rate - 0.001).abs() < f32::EPSILON);
    assert_eq!(cfg.epochs, 20);
    assert_eq!(cfg.max_seq_len, 2048);
}
#[test]
fn test_cov3_adapters_config_many_adapters() {
    // Generate a config with ten adapter tables, each carrying a distinct
    // rank override, and verify every one survives parsing.
    let mut source = String::new();
    for i in 0..10 {
        let rank = 4 + i * 2;
        source.push_str(&format!(
            r#"
[[adapter]]
data = "data/{i}.jsonl"
checkpoint = "ckpt/{i}"
rank = {rank}
"#
        ));
    }
    let parsed = AdaptersConfigFile::from_toml(&source).unwrap();
    assert_eq!(parsed.adapters.len(), 10);
    for (idx, entry) in parsed.adapters.iter().enumerate() {
        assert_eq!(entry.rank, Some(4 + idx * 2));
    }
}
#[test]
fn test_cov3_adapters_config_toml_missing_required_fields() {
    // `checkpoint` has no #[serde(default)], so omitting it must fail.
    let source = r#"
[[adapter]]
data = "data.jsonl"
"#;
    assert!(AdaptersConfigFile::from_toml(source).is_err());
}
#[test]
fn test_cov3_adapters_config_toml_missing_data_field() {
    // `data` has no #[serde(default)], so omitting it must fail.
    let source = r#"
[[adapter]]
checkpoint = "ckpt"
"#;
    assert!(AdaptersConfigFile::from_toml(source).is_err());
}
#[test]
fn test_cov3_adapters_config_toml_extra_fields_ignored() {
    // AdapterEntry does not use #[serde(deny_unknown_fields)], so serde's
    // default behavior is to skip unrecognized keys. The original test
    // discarded the result without asserting anything; pin the intended
    // behavior: parsing succeeds and known fields are still populated.
    let toml = r#"
[[adapter]]
data = "data.jsonl"
checkpoint = "ckpt"
unknown_field = "ignored"
"#;
    let config = AdaptersConfigFile::from_toml(toml)
        .expect("unknown keys should be ignored, not rejected");
    assert_eq!(config.adapters.len(), 1);
    assert_eq!(config.adapters[0].data, PathBuf::from("data.jsonl"));
    assert_eq!(config.adapters[0].checkpoint, PathBuf::from("ckpt"));
}
#[test]
fn test_cov3_adapters_config_rank_zero() {
    // An explicit rank of 0 is passed through verbatim and yields alpha = 0.
    let source = r#"
[[adapter]]
data = "data.jsonl"
checkpoint = "ckpt"
rank = 0
"#;
    let parsed = AdaptersConfigFile::from_toml(source).unwrap();
    let built = parsed.to_adapter_configs(&InstructConfig::default());
    let cfg = &built[0].instruct_config;
    assert_eq!(cfg.lora_rank, 0);
    assert!((cfg.lora_alpha - 0.0).abs() < f32::EPSILON);
}
#[test]
fn test_cov3_add_adapter_creates_lora_layers() {
    // add_adapter materializes LoRA layers for the slot; with the dummy
    // pipeline's model this yields 4 layers.
    // NOTE(review): the count of 4 depends on the dummy model's shape —
    // confirm against build_lora_layers if the tiny config changes.
    let mut pipeline =
        MultiAdapterPipeline::new(create_dummy_pipeline(), AdapterSchedule::RoundRobin);
    pipeline.add_adapter(
        AdapterConfig {
            data_path: PathBuf::from("data.jsonl"),
            checkpoint_dir: PathBuf::from("/tmp/ckpt"),
            instruct_config: InstructConfig { lora_rank: 4, ..InstructConfig::default() },
        },
        vec![],
        vec![],
    );
    assert_eq!(pipeline.adapters[0].lora_layers.len(), 4);
}
#[test]
fn test_cov3_add_adapter_with_val_samples() {
    // Validation samples handed to add_adapter are stored on the slot as-is.
    let mut pipeline =
        MultiAdapterPipeline::new(create_dummy_pipeline(), AdapterSchedule::RoundRobin);
    let holdout = vec![InstructSample {
        instruction: "val_q".into(),
        response: "val_a".into(),
        system: None,
        metadata: None,
    }];
    pipeline.add_adapter(
        AdapterConfig {
            data_path: PathBuf::from("data.jsonl"),
            checkpoint_dir: PathBuf::from("/tmp/ckpt"),
            instruct_config: InstructConfig::default(),
        },
        vec![],
        holdout,
    );
    let slot = &pipeline.adapters[0];
    assert_eq!(slot.val_samples.len(), 1);
    assert_eq!(slot.val_samples[0].instruction, "val_q");
}
#[test]
fn test_cov3_add_adapter_initial_state() {
    // A newly added slot starts at cursor 0 with no metrics and an
    // infinite best_val_loss, and retains the configured rank and path.
    let mut pipeline =
        MultiAdapterPipeline::new(create_dummy_pipeline(), AdapterSchedule::RoundRobin);
    pipeline.add_adapter(
        AdapterConfig {
            data_path: PathBuf::from("data.jsonl"),
            checkpoint_dir: PathBuf::from("/tmp/ckpt_initial"),
            instruct_config: InstructConfig { lora_rank: 8, ..InstructConfig::default() },
        },
        vec![],
        vec![],
    );
    let slot = &pipeline.adapters[0];
    assert_eq!(slot.cursor, 0);
    assert_eq!(slot.best_val_loss, f32::INFINITY);
    assert!(slot.metrics.is_empty());
    assert_eq!(slot.config.lora_rank, 8);
    assert_eq!(slot.checkpoint_dir, PathBuf::from("/tmp/ckpt_initial"));
}
#[test]
fn test_cov3_train_step_adapter_empty_tokens() {
    // A step that produces no result still advances the cursor past the
    // sample. NOTE(review): presumably the dummy samples tokenize to
    // nothing here, hence the None — confirm against dummy_slot_with_data.
    let mut pipeline = MultiAdapterPipeline {
        base_pipeline: create_dummy_pipeline(),
        adapters: vec![dummy_slot_with_data(5)],
        schedule: AdapterSchedule::RoundRobin,
        global_step: 0,
    };
    assert!(pipeline.train_step_adapter(0).is_none());
    assert_eq!(pipeline.adapters[0].cursor, 1);
}
#[test]
fn test_cov3_batch_train_step_synchronized_mixed_exhaustion() {
    // Synchronized stepping returns one entry per adapter; an adapter
    // whose cursor is already at the end contributes None.
    let fresh = dummy_slot_with_data(3);
    let mut drained = dummy_slot_with_data(1);
    drained.cursor = 1;
    let mut pipeline = MultiAdapterPipeline {
        base_pipeline: create_dummy_pipeline(),
        adapters: vec![fresh, drained],
        schedule: AdapterSchedule::Synchronized,
        global_step: 0,
    };
    let results = pipeline.batch_train_step();
    assert_eq!(results.len(), 2);
    assert!(results[1].is_none());
}
#[test]
fn test_cov3_batch_train_step_round_robin_cycling() {
    // Round-robin selection cycles 0, 1, 2, 0, ... as global_step grows.
    let mut pipeline = MultiAdapterPipeline {
        base_pipeline: create_dummy_pipeline(),
        adapters: vec![
            dummy_slot_with_data(10),
            dummy_slot_with_data(10),
            dummy_slot_with_data(10),
        ],
        schedule: AdapterSchedule::RoundRobin,
        global_step: 0,
    };
    for (step, expected) in [(0, 0), (1, 1), (2, 2), (3, 0)] {
        pipeline.global_step = step;
        assert_eq!(pipeline.select_next_adapter(), Some(expected));
    }
}
#[test]
fn test_cov3_reset_epoch_multiple_adapters_independent_seeds() {
    // reset_epoch must derive a distinct shuffle seed per adapter, so two
    // identical 20-sample datasets end up in different orders.
    let mut pipeline = MultiAdapterPipeline {
        base_pipeline: create_dummy_pipeline(),
        adapters: vec![dummy_slot_with_data(20), dummy_slot_with_data(20)],
        schedule: AdapterSchedule::RoundRobin,
        global_step: 0,
    };
    pipeline.reset_epoch(42);
    let orders: Vec<Vec<String>> = pipeline
        .adapters
        .iter()
        .map(|slot| slot.train_samples.iter().map(|s| s.instruction.clone()).collect())
        .collect();
    assert_ne!(orders[0], orders[1], "Different adapters should have different shuffle orders");
}
#[test]
fn test_cov3_adapter_schedule_copy() {
    // AdapterSchedule is Copy: assignment duplicates rather than moves,
    // so the original remains usable afterwards.
    let original = AdapterSchedule::PriorityValLoss;
    let duplicate = original;
    assert_eq!(original, duplicate);
}
#[test]
fn test_cov3_adapters_config_file_clone() {
    // Cloning the parsed config preserves entry count and labels.
    let source = r#"
[[adapter]]
data = "data.jsonl"
checkpoint = "ckpt"
label = "test"
"#;
    let parsed = AdaptersConfigFile::from_toml(source).unwrap();
    let duplicate = parsed.clone();
    assert_eq!(duplicate.adapters.len(), 1);
    assert_eq!(duplicate.adapters[0].label.as_deref(), Some("test"));
}
#[test]
fn test_cov3_adapter_entry_clone() {
    // Cloning a single entry keeps its numeric overrides intact.
    let source = r#"
[[adapter]]
data = "data.jsonl"
checkpoint = "ckpt"
rank = 32
learning_rate = 0.001
"#;
    let parsed = AdaptersConfigFile::from_toml(source).unwrap();
    let duplicate = parsed.adapters[0].clone();
    assert_eq!(duplicate.rank, Some(32));
    assert_eq!(duplicate.learning_rate, Some(0.001));
}
// Parses the written metadata.json as JSON and checks the exact values of
// adapter_index, epoch, lora_rank, train_samples and global_step.
#[test]
fn test_cov3_save_adapter_checkpoint_metadata_values() {
let dir = std::env::temp_dir().join("entrenar_cov3_ckpt_meta");
let _ = std::fs::remove_dir_all(&dir);
let mut pipeline =
MultiAdapterPipeline::new(create_dummy_pipeline(), AdapterSchedule::RoundRobin);
// A non-zero global step so we can verify it lands in the metadata.
pipeline.global_step = 42;
let config = AdapterConfig {
data_path: PathBuf::from("data.jsonl"),
checkpoint_dir: dir.clone(),
instruct_config: InstructConfig {
lora_rank: 8,
lora_alpha: 16.0,
..InstructConfig::default()
},
};
// Five training samples — expected back as "train_samples": 5.
let samples: Vec<InstructSample> = (0..5)
.map(|i| InstructSample {
instruction: format!("q{i}"),
response: format!("a{i}"),
system: None,
metadata: None,
})
.collect();
pipeline.add_adapter(config, samples, vec![]);
pipeline.adapters[0].best_val_loss = 0.75;
let ckpt_dir = pipeline.save_adapter_checkpoint(0, 3, 0.42).unwrap();
// Proper JSON parse (not substring matching) for exact field values.
let metadata_str = std::fs::read_to_string(ckpt_dir.join("metadata.json")).unwrap();
let metadata: serde_json::Value = serde_json::from_str(&metadata_str).unwrap();
assert_eq!(metadata["adapter_index"], 0);
assert_eq!(metadata["epoch"], 3);
assert_eq!(metadata["lora_rank"], 8);
assert_eq!(metadata["train_samples"], 5);
assert_eq!(metadata["global_step"], 42);
// Best-effort cleanup; skipped if an assert above panics.
let _ = std::fs::remove_dir_all(&dir);
}
// Saving checkpoints across epochs must produce a separate epoch-N
// directory for each, each with its own metadata and weights.
#[test]
fn test_cov3_save_adapter_checkpoint_multiple_epochs() {
let dir = std::env::temp_dir().join("entrenar_cov3_multi_epoch");
let _ = std::fs::remove_dir_all(&dir);
let mut pipeline =
MultiAdapterPipeline::new(create_dummy_pipeline(), AdapterSchedule::RoundRobin);
let config = AdapterConfig {
data_path: PathBuf::from("data.jsonl"),
checkpoint_dir: dir.clone(),
instruct_config: InstructConfig::default(),
};
pipeline.add_adapter(config, vec![], vec![]);
// Epochs 0..3 with a decreasing synthetic loss.
for epoch in 0..3 {
let ckpt_dir =
pipeline.save_adapter_checkpoint(0, epoch, 1.0 - epoch as f32 * 0.2).unwrap();
assert!(ckpt_dir.join("metadata.json").exists());
assert!(ckpt_dir.join("model.safetensors").exists());
}
// Directory naming convention: epoch-<N> under the checkpoint dir.
assert!(dir.join("epoch-0").exists());
assert!(dir.join("epoch-1").exists());
assert!(dir.join("epoch-2").exists());
let _ = std::fs::remove_dir_all(&dir);
}
#[test]
fn test_cov3_all_exhausted_single_adapter_one_sample() {
    // One unconsumed sample remains, so the pipeline is not yet exhausted.
    let pipeline = MultiAdapterPipeline {
        base_pipeline: create_dummy_pipeline(),
        adapters: vec![dummy_slot_with_data(1)],
        schedule: AdapterSchedule::RoundRobin,
        global_step: 0,
    };
    assert!(!pipeline.all_exhausted());
}
#[test]
fn test_cov3_all_exhausted_single_adapter_cursor_at_end() {
    // Once the cursor has passed every sample, the pipeline is exhausted.
    let mut spent = dummy_slot_with_data(1);
    spent.cursor = 1;
    let pipeline = MultiAdapterPipeline {
        base_pipeline: create_dummy_pipeline(),
        adapters: vec![spent],
        schedule: AdapterSchedule::RoundRobin,
        global_step: 0,
    };
    assert!(pipeline.all_exhausted());
}
#[test]
fn test_cov3_select_priority_single_adapter() {
    // With a single adapter, priority scheduling can only return index 0.
    let mut only = dummy_slot();
    only.best_val_loss = 3.0;
    let pipeline = MultiAdapterPipeline {
        base_pipeline: create_dummy_pipeline(),
        adapters: vec![only],
        schedule: AdapterSchedule::PriorityValLoss,
        global_step: 0,
    };
    assert_eq!(pipeline.select_next_adapter(), Some(0));
}
#[test]
fn test_cov3_select_priority_equal_losses() {
    // A tie in best_val_loss may resolve to either adapter; both are valid.
    let mut first = dummy_slot();
    first.best_val_loss = 1.0;
    let mut second = dummy_slot();
    second.best_val_loss = 1.0;
    let pipeline = MultiAdapterPipeline {
        base_pipeline: create_dummy_pipeline(),
        adapters: vec![first, second],
        schedule: AdapterSchedule::PriorityValLoss,
        global_step: 0,
    };
    assert!(matches!(pipeline.select_next_adapter(), Some(0) | Some(1)));
}
#[test]
fn test_cov3_to_adapter_configs_no_overrides() {
    // With no per-adapter overrides in the TOML, every field of the base
    // config must pass through to the built adapter config unchanged.
    let source = r#"
[[adapter]]
data = "d.jsonl"
checkpoint = "c"
"#;
    let base = InstructConfig {
        lora_rank: 32,
        lora_alpha: 64.0,
        learning_rate: 0.005,
        epochs: 7,
        max_seq_len: 1024,
        gradient_clip_norm: Some(2.0),
        quantize_nf4: true,
    };
    let parsed = AdaptersConfigFile::from_toml(source).unwrap();
    let built = parsed.to_adapter_configs(&base);
    let cfg = &built[0].instruct_config;
    assert_eq!(cfg.lora_rank, 32);
    assert!((cfg.lora_alpha - 64.0).abs() < f32::EPSILON);
    assert!((cfg.learning_rate - 0.005).abs() < f32::EPSILON);
    assert_eq!(cfg.epochs, 7);
    assert_eq!(cfg.max_seq_len, 1024);
    assert_eq!(cfg.gradient_clip_norm, Some(2.0));
    assert!(cfg.quantize_nf4);
}
#[test]
fn test_cov3_to_adapter_configs_preserves_data_and_checkpoint_paths() {
    // The data and checkpoint paths are copied into AdapterConfig
    // unmodified, whether absolute or relative.
    let source = r#"
[[adapter]]
data = "/absolute/path/data.jsonl"
checkpoint = "../relative/ckpt"
"#;
    let built = AdaptersConfigFile::from_toml(source)
        .unwrap()
        .to_adapter_configs(&InstructConfig::default());
    assert_eq!(built[0].data_path, PathBuf::from("/absolute/path/data.jsonl"));
    assert_eq!(built[0].checkpoint_dir, PathBuf::from("../relative/ckpt"));
}
#[test]
fn test_cov3_adapters_config_from_file_invalid_toml() {
    // A file that exists but holds malformed TOML must surface a parse
    // error (not an I/O error) through from_file.
    let dir = std::env::temp_dir().join("entrenar_cov3_invalid_toml");
    let _ = std::fs::create_dir_all(&dir);
    let path = dir.join("invalid.toml");
    std::fs::write(&path, "this {{ is not valid TOML").unwrap();
    let err = AdaptersConfigFile::from_file(&path).unwrap_err();
    assert!(err.contains("failed to parse"), "Expected parse error, got: {err}");
    let _ = std::fs::remove_dir_all(&dir);
}
// A manually constructed AdapterSlot stores its checkpoint_dir verbatim.
// The literal spells out every field, including the cuda-gated ones, so it
// stays in sync with the struct definition under both feature settings.
#[test]
fn test_cov3_adapter_slot_checkpoint_dir() {
let slot = AdapterSlot {
lora_layers: Vec::new(),
train_samples: Vec::new(),
val_samples: Vec::new(),
checkpoint_dir: PathBuf::from("/my/custom/ckpt"),
metrics: Vec::new(),
config: InstructConfig::default(),
cursor: 0,
best_val_loss: f32::INFINITY,
// These two fields only exist when the crate is built with CUDA.
#[cfg(feature = "cuda")]
optimizer_states: None,
#[cfg(feature = "cuda")]
lora_step: 0,
};
assert_eq!(slot.checkpoint_dir, PathBuf::from("/my/custom/ckpt"));
}
#[test]
fn test_cov3_multi_adapter_schedule_field() {
    // The schedule provided at construction is stored on the pipeline.
    let pipeline = MultiAdapterPipeline {
        base_pipeline: create_dummy_pipeline(),
        adapters: Vec::new(),
        schedule: AdapterSchedule::PriorityValLoss,
        global_step: 0,
    };
    assert_eq!(pipeline.schedule, AdapterSchedule::PriorityValLoss);
}
}