Skip to main content

entrenar/finetune/
multi_adapter_pipeline.rs

//! Multi-Adapter Pipeline (GPU-SHARE Phase 2.1, GH-203)
//!
//! Trains N LoRA adapter sets in a single process on one shared, frozen NF4 base
//! model (adapter steps are currently interleaved one at a time rather than fused).
//! The base model is loaded once to GPU, and each adapter maintains independent:
//! - LoRA A/B matrices (Q and V projections)
//! - AdamW optimizer state
//! - Training data iterator
//! - Checkpoint directory
//!
//! # VRAM Savings
//!
//! Compared to N separate processes (MPS), this saves (N-1) × base_model_vram:
//! - MPS (3 adapters on 7B): 3 × 7.3 GB = 21.9 GB
//! - Multi-adapter (3 adapters on 7B): 7.3 GB + 3 × 0.02 GB = 7.36 GB
//!
//! # Architecture
//!
//! ```text
//! ┌──────────────────────────────────────┐
//! │         Frozen NF4 Base Model        │ ← loaded once
//! │    (CudaNf4TransformerBlock × L)     │
//! └──────────┬───────────┬───────────┬───┘
//!            │           │           │
//!     ┌──────┴──┐ ┌──────┴──┐ ┌──────┴──┐
//!     │Adapter 0│ │Adapter 1│ │Adapter 2│
//!     │LoRA A/B │ │LoRA A/B │ │LoRA A/B │
//!     │Optimizer│ │Optimizer│ │Optimizer│
//!     │  Data   │ │  Data   │ │  Data   │
//!     └─────────┘ └─────────┘ └─────────┘
//! ```
31
32use super::instruct_corpus::InstructSample;
33use super::instruct_pipeline::{InstructConfig, InstructPipeline, InstructStepResult};
34use super::instruct_trainer::InstructEpochMetrics;
35use crate::lora::LoRALayer;
36use serde::Deserialize;
37use std::path::{Path, PathBuf};
38
/// Policy that decides which adapter(s) receive a training step.
///
/// `RoundRobin` is the default.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum AdapterSchedule {
    /// Every adapter trains on one of its own samples during each step.
    Synchronized,
    /// Exactly one adapter trains per step, chosen as `global_step % num_adapters`.
    #[default]
    RoundRobin,
    /// Exactly one adapter trains per step: the one whose `best_val_loss` is highest.
    PriorityValLoss,
}
50
51/// Configuration for a single adapter slot.
52#[derive(Debug, Clone)]
53pub struct AdapterConfig {
54    /// Path to training data (JSONL instruct corpus).
55    pub data_path: PathBuf,
56    /// Directory for adapter checkpoints.
57    pub checkpoint_dir: PathBuf,
58    /// Per-adapter hyperparameters (lora_rank, lr, epochs, etc.)
59    pub instruct_config: InstructConfig,
60}
61
62/// TOML file schema for `--adapters-config adapters.toml` (GPU-SHARE §2.4).
63///
64/// # Example
65///
66/// ```toml
67/// [[adapter]]
68/// data = "data/corpus-a.jsonl"
69/// checkpoint = "checkpoints/adapter-a"
70/// label = "code-review"
71/// rank = 16
72/// learning_rate = 0.0002
73///
74/// [[adapter]]
75/// data = "data/corpus-b.jsonl"
76/// checkpoint = "checkpoints/adapter-b"
77/// label = "bug-fixing"
78/// ```
79#[derive(Debug, Clone, Deserialize)]
80pub struct AdaptersConfigFile {
81    /// List of adapter configurations.
82    #[serde(rename = "adapter")]
83    pub adapters: Vec<AdapterEntry>,
84}
85
86/// A single adapter entry in the TOML config file.
87#[derive(Debug, Clone, Deserialize)]
88pub struct AdapterEntry {
89    /// Path to training data (JSONL instruct corpus).
90    pub data: PathBuf,
91    /// Directory for adapter checkpoints.
92    pub checkpoint: PathBuf,
93    /// Human-readable label for this adapter.
94    #[serde(default)]
95    pub label: Option<String>,
96    /// LoRA rank override (default: 16).
97    #[serde(default)]
98    pub rank: Option<usize>,
99    /// Learning rate override.
100    #[serde(default)]
101    pub learning_rate: Option<f32>,
102    /// Epochs override.
103    #[serde(default)]
104    pub epochs: Option<usize>,
105    /// Maximum sequence length override.
106    #[serde(default)]
107    pub max_seq_len: Option<usize>,
108}
109
110impl AdaptersConfigFile {
111    /// Parse an adapters config from a TOML file.
112    pub fn from_file(path: &Path) -> Result<Self, String> {
113        let contents = std::fs::read_to_string(path)
114            .map_err(|e| format!("failed to read {}: {e}", path.display()))?;
115        Self::from_toml(&contents)
116    }
117
118    /// Parse an adapters config from a TOML string.
119    pub fn from_toml(toml_str: &str) -> Result<Self, String> {
120        let config: Self =
121            toml::from_str(toml_str).map_err(|e| format!("failed to parse adapters TOML: {e}"))?;
122        if config.adapters.is_empty() {
123            return Err("adapters config must have at least one [[adapter]] entry".to_string());
124        }
125        Ok(config)
126    }
127
128    /// Convert to `Vec<AdapterConfig>` using a base `InstructConfig` for defaults.
129    pub fn to_adapter_configs(&self, base: &InstructConfig) -> Vec<AdapterConfig> {
130        self.adapters
131            .iter()
132            .map(|entry| {
133                let mut config = base.clone();
134                if let Some(rank) = entry.rank {
135                    config.lora_rank = rank;
136                    config.lora_alpha = rank as f32 * 2.0;
137                }
138                if let Some(lr) = entry.learning_rate {
139                    config.learning_rate = lr;
140                }
141                if let Some(epochs) = entry.epochs {
142                    config.epochs = epochs;
143                }
144                if let Some(seq_len) = entry.max_seq_len {
145                    config.max_seq_len = seq_len;
146                }
147                AdapterConfig {
148                    data_path: entry.data.clone(),
149                    checkpoint_dir: entry.checkpoint.clone(),
150                    instruct_config: config,
151                }
152            })
153            .collect()
154    }
155}
156
157/// Runtime state for one adapter during training.
158pub struct AdapterSlot {
159    /// Per-adapter LoRA layers (Q and V projections, per transformer layer).
160    pub lora_layers: Vec<LoRALayer>,
161    /// Training data for this adapter.
162    pub train_samples: Vec<InstructSample>,
163    /// Validation data for this adapter.
164    pub val_samples: Vec<InstructSample>,
165    /// Checkpoint directory for this adapter.
166    pub checkpoint_dir: PathBuf,
167    /// Per-adapter metrics history.
168    pub metrics: Vec<InstructEpochMetrics>,
169    /// Per-adapter config.
170    pub config: InstructConfig,
171    /// Current sample index within the training data.
172    pub cursor: usize,
173    /// Best validation loss (for early stopping / priority scheduling).
174    pub best_val_loss: f32,
175
176    /// Per-adapter GPU LoRA optimizer states.
177    #[cfg(feature = "cuda")]
178    #[allow(dead_code)]
179    pub(crate) optimizer_states: Option<Vec<crate::transformer::GpuLoraOptimizerState>>,
180    /// NF4 LoRA optimizer step counter.
181    #[cfg(feature = "cuda")]
182    pub lora_step: u32,
183}
184
185/// Multi-adapter training pipeline.
186///
187/// Trains N LoRA adapter sets on a shared frozen NF4 base model.
188/// GPU memory is dominated by the base model (~7 GB for 7B NF4);
189/// each adapter adds only ~20 MB (LoRA A/B matrices + optimizer state).
190pub struct MultiAdapterPipeline {
191    /// The base InstructPipeline (owns the frozen transformer + CUDA blocks).
192    pub base_pipeline: InstructPipeline,
193    /// Independent adapter slots.
194    pub adapters: Vec<AdapterSlot>,
195    /// Scheduling strategy.
196    pub schedule: AdapterSchedule,
197    /// Current step counter (global across all adapters).
198    pub global_step: usize,
199}
200
201impl MultiAdapterPipeline {
202    /// Create a new multi-adapter pipeline.
203    ///
204    /// The `base_pipeline` should be a fully initialized InstructPipeline
205    /// (with CUDA blocks uploaded if GPU training is desired). Adapter slots
206    /// are initially empty — call `add_adapter()` to register each one.
207    pub fn new(base_pipeline: InstructPipeline, schedule: AdapterSchedule) -> Self {
208        Self { base_pipeline, adapters: Vec::new(), schedule, global_step: 0 }
209    }
210
211    /// Add an adapter slot with its own training data and checkpoint directory.
212    pub fn add_adapter(
213        &mut self,
214        config: AdapterConfig,
215        train_samples: Vec<InstructSample>,
216        val_samples: Vec<InstructSample>,
217    ) {
218        let model_config = &self.base_pipeline.model.config;
219        let lora_layers = InstructPipeline::build_lora_layers(
220            &self.base_pipeline.model,
221            model_config,
222            &config.instruct_config,
223        );
224
225        let slot = AdapterSlot {
226            lora_layers,
227            train_samples,
228            val_samples,
229            checkpoint_dir: config.checkpoint_dir,
230            metrics: Vec::new(),
231            config: config.instruct_config,
232            cursor: 0,
233            best_val_loss: f32::INFINITY,
234            #[cfg(feature = "cuda")]
235            optimizer_states: None,
236            #[cfg(feature = "cuda")]
237            lora_step: 0,
238        };
239
240        self.adapters.push(slot);
241    }
242
243    /// Number of registered adapters.
244    pub fn num_adapters(&self) -> usize {
245        self.adapters.len()
246    }
247
248    /// Select which adapter index to train next based on the schedule.
249    pub fn select_next_adapter(&self) -> Option<usize> {
250        if self.adapters.is_empty() {
251            return None;
252        }
253        match self.schedule {
254            AdapterSchedule::Synchronized => {
255                // All adapters train — caller should iterate all
256                Some(0)
257            }
258            AdapterSchedule::RoundRobin => Some(self.global_step % self.adapters.len()),
259            AdapterSchedule::PriorityValLoss => {
260                // Pick adapter with highest (worst) validation loss
261                self.adapters
262                    .iter()
263                    .enumerate()
264                    .max_by(|(_, a), (_, b)| {
265                        a.best_val_loss
266                            .partial_cmp(&b.best_val_loss)
267                            .unwrap_or(std::cmp::Ordering::Equal)
268                    })
269                    .map(|(i, _)| i)
270            }
271        }
272    }
273
274    /// Train one step on the specified adapter.
275    ///
276    /// Swaps the adapter's LoRA layers into the base pipeline, runs one
277    /// training step, then swaps them back out.
278    ///
279    /// # Returns
280    ///
281    /// Training step result (loss, perplexity) or `None` if the adapter's
282    /// data is exhausted.
283    pub fn train_step_adapter(&mut self, adapter_idx: usize) -> Option<InstructStepResult> {
284        let slot = &mut self.adapters[adapter_idx];
285
286        // Check if data is exhausted
287        if slot.cursor >= slot.train_samples.len() {
288            return None;
289        }
290
291        let sample = &slot.train_samples[slot.cursor];
292        slot.cursor += 1;
293
294        // Tokenize prompt and response separately
295        if !self.base_pipeline.has_tokenizer() {
296            return None;
297        }
298        let prompt_ids = self.base_pipeline.tokenize(&sample.instruction);
299        let response_ids = self.base_pipeline.tokenize(&sample.response);
300
301        if prompt_ids.is_empty() || response_ids.is_empty() {
302            return None;
303        }
304
305        // Swap adapter's LoRA layers into the base pipeline
306        std::mem::swap(&mut slot.lora_layers, &mut self.base_pipeline.lora_layers);
307
308        // Run training step through base pipeline (uses shared CUDA blocks)
309        let result = self.base_pipeline.train_step(&prompt_ids, &response_ids);
310
311        // Swap LoRA layers back
312        std::mem::swap(&mut slot.lora_layers, &mut self.base_pipeline.lora_layers);
313
314        self.global_step += 1;
315
316        Some(result)
317    }
318
319    /// Reset all adapter cursors for a new epoch.
320    pub fn reset_epoch(&mut self, seed: u64) {
321        for (i, slot) in self.adapters.iter_mut().enumerate() {
322            slot.cursor = 0;
323            // Shuffle training data with per-adapter seed
324            shuffle_samples(&mut slot.train_samples, seed.wrapping_add(i as u64));
325        }
326    }
327
328    /// Check if all adapters have exhausted their training data.
329    pub fn all_exhausted(&self) -> bool {
330        self.adapters.iter().all(|s| s.cursor >= s.train_samples.len())
331    }
332
333    /// Batch training step across all non-exhausted adapters (GH-204).
334    ///
335    /// Trains each adapter that still has data, using the scheduling mode.
336    /// In `Synchronized` mode, all adapters train one sample each.
337    /// In `RoundRobin`, only the next scheduled adapter trains.
338    /// In `PriorityValLoss`, the adapter with highest val loss trains.
339    ///
340    /// Returns per-adapter step results (indexed by adapter, None if skipped/exhausted).
341    ///
342    /// NOTE: Current implementation runs sequential forward+backward per adapter
343    /// (swapping LoRA layers). Future optimization: fused BatchLoRA forward
344    /// through shared NF4 blocks with per-adapter LoRA deltas (arXiv:2510.00206).
345    pub fn batch_train_step(&mut self) -> Vec<Option<InstructStepResult>> {
346        let n = self.adapters.len();
347        let mut results = vec![None; n];
348
349        match self.schedule {
350            AdapterSchedule::Synchronized => {
351                // All adapters train one sample each
352                for i in 0..n {
353                    results[i] = self.train_step_adapter(i);
354                }
355            }
356            AdapterSchedule::RoundRobin | AdapterSchedule::PriorityValLoss => {
357                // Single adapter per step
358                if let Some(idx) = self.select_next_adapter() {
359                    results[idx] = self.train_step_adapter(idx);
360                }
361            }
362        }
363
364        results
365    }
366
367    /// Save a checkpoint for the specified adapter.
368    ///
369    /// Creates `{checkpoint_dir}/epoch-{epoch}/` with:
370    /// - `metadata.json`: adapter index, epoch, metrics
371    /// - `model.safetensors`: LoRA A/B weights for this adapter
372    pub fn save_adapter_checkpoint(
373        &self,
374        adapter_idx: usize,
375        epoch: usize,
376        avg_loss: f32,
377    ) -> Result<PathBuf, Box<dyn std::error::Error>> {
378        let slot = &self.adapters[adapter_idx];
379        let ckpt_dir = slot.checkpoint_dir.join(format!("epoch-{epoch}"));
380        std::fs::create_dir_all(&ckpt_dir)?;
381
382        // Metadata
383        let metadata = serde_json::json!({
384            "mode": "multi_adapter",
385            "adapter_index": adapter_idx,
386            "epoch": epoch,
387            "avg_loss": avg_loss,
388            "best_val_loss": slot.best_val_loss,
389            "lora_rank": slot.config.lora_rank,
390            "lora_alpha": slot.config.lora_alpha,
391            "train_samples": slot.train_samples.len(),
392            "global_step": self.global_step,
393        });
394        std::fs::write(ckpt_dir.join("metadata.json"), serde_json::to_string_pretty(&metadata)?)?;
395
396        // Save LoRA weights as SafeTensors
397        save_adapter_lora_weights(&slot.lora_layers, &ckpt_dir)?;
398
399        Ok(ckpt_dir)
400    }
401
402    /// Save best checkpoint for an adapter (overwrites previous best).
403    pub fn save_best_checkpoint(
404        &self,
405        adapter_idx: usize,
406        epoch: usize,
407        avg_loss: f32,
408    ) -> Result<PathBuf, Box<dyn std::error::Error>> {
409        let slot = &self.adapters[adapter_idx];
410        let best_dir = slot.checkpoint_dir.join("best");
411        std::fs::create_dir_all(&best_dir)?;
412
413        let metadata = serde_json::json!({
414            "mode": "multi_adapter",
415            "adapter_index": adapter_idx,
416            "epoch": epoch,
417            "avg_loss": avg_loss,
418            "lora_rank": slot.config.lora_rank,
419            "lora_alpha": slot.config.lora_alpha,
420            "global_step": self.global_step,
421        });
422        std::fs::write(best_dir.join("metadata.json"), serde_json::to_string_pretty(&metadata)?)?;
423
424        save_adapter_lora_weights(&slot.lora_layers, &best_dir)?;
425        Ok(best_dir)
426    }
427}
428
429/// Save LoRA A/B weights to a SafeTensors file in the given directory.
430fn save_adapter_lora_weights(
431    lora_layers: &[LoRALayer],
432    dir: &std::path::Path,
433) -> Result<(), Box<dyn std::error::Error>> {
434    let mut tensor_data: Vec<(String, Vec<u8>, Vec<usize>)> = Vec::new();
435
436    for (idx, lora) in lora_layers.iter().enumerate() {
437        let layer = idx / 2;
438        let proj = if idx % 2 == 0 { "q" } else { "v" };
439
440        // LoRA A: [rank, d_in]
441        let a_data = lora.lora_a().data();
442        let a_bytes: Vec<u8> =
443            bytemuck::cast_slice(a_data.as_slice().expect("contiguous lora_a")).to_vec();
444        let a_shape = vec![lora.rank(), lora.d_in()];
445        tensor_data.push((format!("lora.{layer}.{proj}_proj.lora_a"), a_bytes, a_shape));
446
447        // LoRA B: [d_out, rank]
448        let b_data = lora.lora_b().data();
449        let b_bytes: Vec<u8> =
450            bytemuck::cast_slice(b_data.as_slice().expect("contiguous lora_b")).to_vec();
451        let b_shape = vec![lora.d_out(), lora.rank()];
452        tensor_data.push((format!("lora.{layer}.{proj}_proj.lora_b"), b_bytes, b_shape));
453    }
454
455    let views: Vec<(&str, safetensors::tensor::TensorView<'_>)> = tensor_data
456        .iter()
457        .map(|(name, bytes, shape)| {
458            let view = safetensors::tensor::TensorView::new(
459                safetensors::tensor::Dtype::F32,
460                shape.clone(),
461                bytes,
462            )
463            .expect("valid tensor view");
464            (name.as_str(), view)
465        })
466        .collect();
467
468    let safetensor_bytes = safetensors::serialize(views, None)
469        .map_err(|e| format!("SafeTensors serialization failed: {e}"))?;
470    std::fs::write(dir.join("model.safetensors"), safetensor_bytes)?;
471    Ok(())
472}
473
/// Deterministic Fisher-Yates shuffle driven by an xorshift64 generator.
///
/// The same `seed` always yields the same permutation. xorshift64 maps the
/// all-zero state to itself, so a zero seed would make every swap index `0`
/// and produce a fixed rotation instead of a shuffle. A zero seed is therefore
/// replaced by a fixed non-zero constant. `reset_epoch` can reach seed 0 via
/// `seed.wrapping_add(i)`.
///
/// `rng % (i + 1)` carries a small modulo bias, which is acceptable for
/// ordering training data.
fn shuffle_samples<T>(samples: &mut [T], seed: u64) {
    // Any non-zero value escapes xorshift's fixed point; this is the 64-bit
    // golden-ratio constant.
    const ZERO_SEED_REPLACEMENT: u64 = 0x9E37_79B9_7F4A_7C15;

    let mut rng = if seed == 0 { ZERO_SEED_REPLACEMENT } else { seed };
    for i in (1..samples.len()).rev() {
        // xorshift64
        rng ^= rng << 13;
        rng ^= rng >> 7;
        rng ^= rng << 17;
        let j = (rng as usize) % (i + 1);
        samples.swap(i, j);
    }
}
486
487#[cfg(test)]
488mod tests {
489    use super::*;
490
491    #[test]
492    fn test_schedule_round_robin() {
493        let sched = AdapterSchedule::RoundRobin;
494        let pipeline = MultiAdapterPipeline {
495            base_pipeline: create_dummy_pipeline(),
496            adapters: vec![dummy_slot(), dummy_slot(), dummy_slot()],
497            schedule: sched,
498            global_step: 0,
499        };
500
501        assert_eq!(pipeline.select_next_adapter(), Some(0));
502
503        let pipeline = MultiAdapterPipeline { global_step: 1, ..pipeline };
504        assert_eq!(pipeline.select_next_adapter(), Some(1));
505
506        let pipeline = MultiAdapterPipeline { global_step: 5, ..pipeline };
507        assert_eq!(pipeline.select_next_adapter(), Some(2));
508    }
509
510    #[test]
511    fn test_schedule_priority_val_loss() {
512        let mut slot0 = dummy_slot();
513        slot0.best_val_loss = 1.0;
514        let mut slot1 = dummy_slot();
515        slot1.best_val_loss = 3.0; // worst
516        let mut slot2 = dummy_slot();
517        slot2.best_val_loss = 2.0;
518
519        let pipeline = MultiAdapterPipeline {
520            base_pipeline: create_dummy_pipeline(),
521            adapters: vec![slot0, slot1, slot2],
522            schedule: AdapterSchedule::PriorityValLoss,
523            global_step: 0,
524        };
525
526        assert_eq!(pipeline.select_next_adapter(), Some(1)); // highest loss
527    }
528
529    #[test]
530    fn test_empty_pipeline() {
531        let pipeline = MultiAdapterPipeline {
532            base_pipeline: create_dummy_pipeline(),
533            adapters: vec![],
534            schedule: AdapterSchedule::RoundRobin,
535            global_step: 0,
536        };
537        assert_eq!(pipeline.select_next_adapter(), None);
538        assert!(pipeline.all_exhausted());
539    }
540
541    #[test]
542    fn test_shuffle_deterministic() {
543        let mut samples1 = vec![
544            InstructSample {
545                instruction: "a".into(),
546                response: "1".into(),
547                system: None,
548                metadata: None,
549            },
550            InstructSample {
551                instruction: "b".into(),
552                response: "2".into(),
553                system: None,
554                metadata: None,
555            },
556            InstructSample {
557                instruction: "c".into(),
558                response: "3".into(),
559                system: None,
560                metadata: None,
561            },
562        ];
563        let mut samples2 = samples1.clone();
564
565        shuffle_samples(&mut samples1, 42);
566        shuffle_samples(&mut samples2, 42);
567
568        // Same seed → same order
569        for (s1, s2) in samples1.iter().zip(samples2.iter()) {
570            assert_eq!(s1.instruction, s2.instruction);
571        }
572    }
573
574    #[test]
575    fn test_batch_train_step_synchronized() {
576        let mut pipeline = MultiAdapterPipeline {
577            base_pipeline: create_dummy_pipeline(),
578            adapters: vec![dummy_slot(), dummy_slot()],
579            schedule: AdapterSchedule::Synchronized,
580            global_step: 0,
581        };
582
583        // No tokenizer → all results are None, but batch_train_step returns correct length
584        let results = pipeline.batch_train_step();
585        assert_eq!(results.len(), 2);
586    }
587
588    #[test]
589    fn test_batch_train_step_round_robin() {
590        let mut pipeline = MultiAdapterPipeline {
591            base_pipeline: create_dummy_pipeline(),
592            adapters: vec![dummy_slot(), dummy_slot(), dummy_slot()],
593            schedule: AdapterSchedule::RoundRobin,
594            global_step: 0,
595        };
596
597        let results = pipeline.batch_train_step();
598        assert_eq!(results.len(), 3);
599        // RoundRobin at step 0 → only adapter 0 would be trained
600        // (but no tokenizer, so all None)
601    }
602
603    #[test]
604    fn test_adapters_config_parse() {
605        let toml = r#"
606[[adapter]]
607data = "data/corpus-a.jsonl"
608checkpoint = "checkpoints/adapter-a"
609label = "code-review"
610rank = 16
611learning_rate = 0.0002
612
613[[adapter]]
614data = "data/corpus-b.jsonl"
615checkpoint = "checkpoints/adapter-b"
616label = "bug-fixing"
617rank = 8
618"#;
619        let config = AdaptersConfigFile::from_toml(toml).expect("valid TOML");
620        assert_eq!(config.adapters.len(), 2);
621        assert_eq!(config.adapters[0].data, PathBuf::from("data/corpus-a.jsonl"));
622        assert_eq!(config.adapters[0].rank, Some(16));
623        assert_eq!(config.adapters[0].learning_rate, Some(0.0002));
624        assert_eq!(config.adapters[1].rank, Some(8));
625        assert!(config.adapters[1].learning_rate.is_none());
626    }
627
628    #[test]
629    fn test_adapters_config_to_adapter_configs() {
630        let toml = r#"
631[[adapter]]
632data = "data/a.jsonl"
633checkpoint = "ckpt/a"
634rank = 32
635learning_rate = 0.001
636epochs = 5
637max_seq_len = 256
638"#;
639        let config = AdaptersConfigFile::from_toml(toml).expect("valid");
640        let base = InstructConfig::default();
641        let adapters = config.to_adapter_configs(&base);
642        assert_eq!(adapters.len(), 1);
643        assert_eq!(adapters[0].instruct_config.lora_rank, 32);
644        assert!((adapters[0].instruct_config.learning_rate - 0.001).abs() < f32::EPSILON);
645        assert_eq!(adapters[0].instruct_config.epochs, 5);
646        assert_eq!(adapters[0].instruct_config.max_seq_len, 256);
647    }
648
649    #[test]
650    fn test_adapters_config_empty_fails() {
651        let toml = "";
652        assert!(AdaptersConfigFile::from_toml(toml).is_err());
653    }
654
655    #[test]
656    fn test_adapters_config_defaults_from_base() {
657        let toml = r#"
658[[adapter]]
659data = "data/x.jsonl"
660checkpoint = "ckpt/x"
661"#;
662        let config = AdaptersConfigFile::from_toml(toml).expect("valid");
663        let base = InstructConfig {
664            lora_rank: 16,
665            learning_rate: 0.0002,
666            epochs: 3,
667            max_seq_len: 512,
668            ..Default::default()
669        };
670        let adapters = config.to_adapter_configs(&base);
671        // Should inherit base defaults when not overridden
672        assert_eq!(adapters[0].instruct_config.lora_rank, 16);
673        assert!((adapters[0].instruct_config.learning_rate - 0.0002).abs() < f32::EPSILON);
674        assert_eq!(adapters[0].instruct_config.epochs, 3);
675        assert_eq!(adapters[0].instruct_config.max_seq_len, 512);
676    }
677
678    fn create_dummy_pipeline() -> InstructPipeline {
679        use crate::transformer::TransformerConfig;
680        let config = TransformerConfig::tiny();
681        InstructPipeline::new(&config, InstructConfig::default())
682    }
683
684    fn dummy_slot() -> AdapterSlot {
685        AdapterSlot {
686            lora_layers: Vec::new(),
687            train_samples: Vec::new(),
688            val_samples: Vec::new(),
689            checkpoint_dir: PathBuf::from("/tmp/test"),
690            metrics: Vec::new(),
691            config: InstructConfig::default(),
692            cursor: 0,
693            best_val_loss: f32::INFINITY,
694            #[cfg(feature = "cuda")]
695            optimizer_states: None,
696            #[cfg(feature = "cuda")]
697            lora_step: 0,
698        }
699    }
700
701    fn dummy_slot_with_data(n_samples: usize) -> AdapterSlot {
702        let samples: Vec<InstructSample> = (0..n_samples)
703            .map(|i| InstructSample {
704                instruction: format!("inst_{i}"),
705                response: format!("resp_{i}"),
706                system: None,
707                metadata: None,
708            })
709            .collect();
710        AdapterSlot {
711            lora_layers: Vec::new(),
712            train_samples: samples,
713            val_samples: Vec::new(),
714            checkpoint_dir: PathBuf::from("/tmp/test"),
715            metrics: Vec::new(),
716            config: InstructConfig::default(),
717            cursor: 0,
718            best_val_loss: f32::INFINITY,
719            #[cfg(feature = "cuda")]
720            optimizer_states: None,
721            #[cfg(feature = "cuda")]
722            lora_step: 0,
723        }
724    }
725
726    // ── Coverage improvement tests ───────────────────────────────
727
728    #[test]
729    fn test_adapter_schedule_default() {
730        let sched: AdapterSchedule = Default::default();
731        assert_eq!(sched, AdapterSchedule::RoundRobin);
732    }
733
734    #[test]
735    fn test_adapter_schedule_debug() {
736        assert_eq!(format!("{:?}", AdapterSchedule::Synchronized), "Synchronized");
737        assert_eq!(format!("{:?}", AdapterSchedule::RoundRobin), "RoundRobin");
738        assert_eq!(format!("{:?}", AdapterSchedule::PriorityValLoss), "PriorityValLoss");
739    }
740
741    #[test]
742    fn test_adapter_schedule_clone() {
743        let sched = AdapterSchedule::PriorityValLoss;
744        let cloned = sched;
745        assert_eq!(sched, cloned);
746    }
747
748    #[test]
749    fn test_adapter_schedule_eq() {
750        assert_eq!(AdapterSchedule::Synchronized, AdapterSchedule::Synchronized);
751        assert_ne!(AdapterSchedule::Synchronized, AdapterSchedule::RoundRobin);
752        assert_ne!(AdapterSchedule::RoundRobin, AdapterSchedule::PriorityValLoss);
753    }
754
755    #[test]
756    fn test_select_next_adapter_synchronized() {
757        let pipeline = MultiAdapterPipeline {
758            base_pipeline: create_dummy_pipeline(),
759            adapters: vec![dummy_slot(), dummy_slot()],
760            schedule: AdapterSchedule::Synchronized,
761            global_step: 0,
762        };
763        // Synchronized always returns Some(0)
764        assert_eq!(pipeline.select_next_adapter(), Some(0));
765    }
766
767    #[test]
768    fn test_select_next_adapter_synchronized_any_step() {
769        let pipeline = MultiAdapterPipeline {
770            base_pipeline: create_dummy_pipeline(),
771            adapters: vec![dummy_slot(), dummy_slot()],
772            schedule: AdapterSchedule::Synchronized,
773            global_step: 42,
774        };
775        assert_eq!(pipeline.select_next_adapter(), Some(0));
776    }
777
778    #[test]
779    fn test_select_next_adapter_round_robin_wraps() {
780        let pipeline = MultiAdapterPipeline {
781            base_pipeline: create_dummy_pipeline(),
782            adapters: vec![dummy_slot(), dummy_slot(), dummy_slot()],
783            schedule: AdapterSchedule::RoundRobin,
784            global_step: 3,
785        };
786        assert_eq!(pipeline.select_next_adapter(), Some(0)); // 3 % 3 = 0
787    }
788
789    #[test]
790    fn test_select_next_adapter_priority_all_infinity() {
791        // All slots have INFINITY best_val_loss → first one wins (or any, but deterministic)
792        let pipeline = MultiAdapterPipeline {
793            base_pipeline: create_dummy_pipeline(),
794            adapters: vec![dummy_slot(), dummy_slot()],
795            schedule: AdapterSchedule::PriorityValLoss,
796            global_step: 0,
797        };
798        let result = pipeline.select_next_adapter();
799        assert!(result.is_some());
800    }
801
802    #[test]
803    fn test_select_next_adapter_priority_with_nan() {
804        let mut slot0 = dummy_slot();
805        slot0.best_val_loss = f32::NAN;
806        let mut slot1 = dummy_slot();
807        slot1.best_val_loss = 1.0;
808
809        let pipeline = MultiAdapterPipeline {
810            base_pipeline: create_dummy_pipeline(),
811            adapters: vec![slot0, slot1],
812            schedule: AdapterSchedule::PriorityValLoss,
813            global_step: 0,
814        };
815        // NaN comparison uses Ordering::Equal fallback, so result is deterministic
816        let result = pipeline.select_next_adapter();
817        assert!(result.is_some());
818    }
819
820    #[test]
821    fn test_num_adapters() {
822        let pipeline = MultiAdapterPipeline {
823            base_pipeline: create_dummy_pipeline(),
824            adapters: vec![dummy_slot(), dummy_slot(), dummy_slot()],
825            schedule: AdapterSchedule::RoundRobin,
826            global_step: 0,
827        };
828        assert_eq!(pipeline.num_adapters(), 3);
829    }
830
831    #[test]
832    fn test_num_adapters_empty() {
833        let pipeline = MultiAdapterPipeline {
834            base_pipeline: create_dummy_pipeline(),
835            adapters: vec![],
836            schedule: AdapterSchedule::RoundRobin,
837            global_step: 0,
838        };
839        assert_eq!(pipeline.num_adapters(), 0);
840    }
841
842    #[test]
843    fn test_all_exhausted_with_data() {
844        let pipeline = MultiAdapterPipeline {
845            base_pipeline: create_dummy_pipeline(),
846            adapters: vec![dummy_slot_with_data(3), dummy_slot_with_data(2)],
847            schedule: AdapterSchedule::RoundRobin,
848            global_step: 0,
849        };
850        assert!(!pipeline.all_exhausted());
851    }
852
853    #[test]
854    fn test_all_exhausted_partially() {
855        let mut slot0 = dummy_slot_with_data(3);
856        slot0.cursor = 3; // exhausted
857        let slot1 = dummy_slot_with_data(2); // not exhausted
858
859        let pipeline = MultiAdapterPipeline {
860            base_pipeline: create_dummy_pipeline(),
861            adapters: vec![slot0, slot1],
862            schedule: AdapterSchedule::RoundRobin,
863            global_step: 0,
864        };
865        assert!(!pipeline.all_exhausted());
866    }
867
868    #[test]
869    fn test_all_exhausted_all_done() {
870        let mut slot0 = dummy_slot_with_data(3);
871        slot0.cursor = 3;
872        let mut slot1 = dummy_slot_with_data(2);
873        slot1.cursor = 2;
874
875        let pipeline = MultiAdapterPipeline {
876            base_pipeline: create_dummy_pipeline(),
877            adapters: vec![slot0, slot1],
878            schedule: AdapterSchedule::RoundRobin,
879            global_step: 0,
880        };
881        assert!(pipeline.all_exhausted());
882    }
883
884    #[test]
885    fn test_reset_epoch() {
886        let mut pipeline = MultiAdapterPipeline {
887            base_pipeline: create_dummy_pipeline(),
888            adapters: vec![dummy_slot_with_data(5), dummy_slot_with_data(3)],
889            schedule: AdapterSchedule::RoundRobin,
890            global_step: 0,
891        };
892        pipeline.adapters[0].cursor = 5;
893        pipeline.adapters[1].cursor = 3;
894
895        pipeline.reset_epoch(42);
896
897        assert_eq!(pipeline.adapters[0].cursor, 0);
898        assert_eq!(pipeline.adapters[1].cursor, 0);
899    }
900
901    #[test]
902    fn test_reset_epoch_shuffle_deterministic() {
903        let mut pipeline1 = MultiAdapterPipeline {
904            base_pipeline: create_dummy_pipeline(),
905            adapters: vec![dummy_slot_with_data(10)],
906            schedule: AdapterSchedule::RoundRobin,
907            global_step: 0,
908        };
909        let mut pipeline2 = MultiAdapterPipeline {
910            base_pipeline: create_dummy_pipeline(),
911            adapters: vec![dummy_slot_with_data(10)],
912            schedule: AdapterSchedule::RoundRobin,
913            global_step: 0,
914        };
915
916        pipeline1.reset_epoch(123);
917        pipeline2.reset_epoch(123);
918
919        // Same seed should produce same shuffle
920        for (s1, s2) in pipeline1.adapters[0]
921            .train_samples
922            .iter()
923            .zip(pipeline2.adapters[0].train_samples.iter())
924        {
925            assert_eq!(s1.instruction, s2.instruction);
926        }
927    }
928
929    #[test]
930    fn test_shuffle_samples_empty() {
931        let mut samples: Vec<InstructSample> = vec![];
932        shuffle_samples(&mut samples, 42);
933        assert!(samples.is_empty());
934    }
935
936    #[test]
937    fn test_shuffle_samples_single() {
938        let mut samples = vec![InstructSample {
939            instruction: "only".into(),
940            response: "one".into(),
941            system: None,
942            metadata: None,
943        }];
944        shuffle_samples(&mut samples, 42);
945        assert_eq!(samples.len(), 1);
946        assert_eq!(samples[0].instruction, "only");
947    }
948
949    #[test]
950    fn test_shuffle_samples_different_seeds() {
951        let mut samples1 = vec![
952            InstructSample {
953                instruction: "a".into(),
954                response: "1".into(),
955                system: None,
956                metadata: None,
957            },
958            InstructSample {
959                instruction: "b".into(),
960                response: "2".into(),
961                system: None,
962                metadata: None,
963            },
964            InstructSample {
965                instruction: "c".into(),
966                response: "3".into(),
967                system: None,
968                metadata: None,
969            },
970            InstructSample {
971                instruction: "d".into(),
972                response: "4".into(),
973                system: None,
974                metadata: None,
975            },
976            InstructSample {
977                instruction: "e".into(),
978                response: "5".into(),
979                system: None,
980                metadata: None,
981            },
982        ];
983        let mut samples2 = samples1.clone();
984
985        shuffle_samples(&mut samples1, 1);
986        shuffle_samples(&mut samples2, 999);
987
988        // Different seeds should (very likely) produce different orderings
989        let same =
990            samples1.iter().zip(samples2.iter()).all(|(s1, s2)| s1.instruction == s2.instruction);
991        // With 5! = 120 permutations, probability of same is ~0.83%, so this is safe
992        assert!(!same, "Different seeds should produce different shuffles");
993    }
994
995    #[test]
996    fn test_adapters_config_from_toml_invalid_toml() {
997        let toml = "this is not valid TOML {{{}}}";
998        let result = AdaptersConfigFile::from_toml(toml);
999        assert!(result.is_err());
1000        let err = result.unwrap_err();
1001        assert!(err.contains("failed to parse"), "Expected parse error, got: {err}");
1002    }
1003
1004    #[test]
1005    fn test_adapters_config_from_toml_empty_adapters_array() {
1006        // Valid TOML but no [[adapter]] entries → should fail
1007        let toml = r#"
1008[settings]
1009foo = "bar"
1010"#;
1011        let result = AdaptersConfigFile::from_toml(toml);
1012        assert!(result.is_err());
1013    }
1014
1015    #[test]
1016    fn test_adapters_config_from_file_not_found() {
1017        let result = AdaptersConfigFile::from_file(Path::new("/tmp/nonexistent_adapters_xyz.toml"));
1018        assert!(result.is_err());
1019        let err = result.unwrap_err();
1020        assert!(err.contains("failed to read"), "Expected read error, got: {err}");
1021    }
1022
1023    #[test]
1024    fn test_adapters_config_from_file_valid() {
1025        let dir = std::env::temp_dir().join("entrenar_adapter_cfg_test");
1026        std::fs::create_dir_all(&dir).expect("create dir");
1027        let path = dir.join("adapters.toml");
1028        std::fs::write(
1029            &path,
1030            r#"
1031[[adapter]]
1032data = "data/a.jsonl"
1033checkpoint = "ckpt/a"
1034label = "test-adapter"
1035"#,
1036        )
1037        .expect("write file");
1038        let config = AdaptersConfigFile::from_file(&path).expect("valid config");
1039        assert_eq!(config.adapters.len(), 1);
1040        assert_eq!(config.adapters[0].label, Some("test-adapter".to_string()));
1041        std::fs::remove_file(&path).expect("cleanup");
1042    }
1043
1044    #[test]
1045    fn test_adapter_entry_defaults() {
1046        let toml = r#"
1047[[adapter]]
1048data = "data/x.jsonl"
1049checkpoint = "ckpt/x"
1050"#;
1051        let config = AdaptersConfigFile::from_toml(toml).expect("valid");
1052        let entry = &config.adapters[0];
1053        assert!(entry.label.is_none());
1054        assert!(entry.rank.is_none());
1055        assert!(entry.learning_rate.is_none());
1056        assert!(entry.epochs.is_none());
1057        assert!(entry.max_seq_len.is_none());
1058    }
1059
1060    #[test]
1061    fn test_adapter_entry_all_fields() {
1062        let toml = r#"
1063[[adapter]]
1064data = "data/full.jsonl"
1065checkpoint = "ckpt/full"
1066label = "full-adapter"
1067rank = 64
1068learning_rate = 0.001
1069epochs = 10
1070max_seq_len = 1024
1071"#;
1072        let config = AdaptersConfigFile::from_toml(toml).expect("valid");
1073        let entry = &config.adapters[0];
1074        assert_eq!(entry.data, PathBuf::from("data/full.jsonl"));
1075        assert_eq!(entry.checkpoint, PathBuf::from("ckpt/full"));
1076        assert_eq!(entry.label, Some("full-adapter".to_string()));
1077        assert_eq!(entry.rank, Some(64));
1078        assert_eq!(entry.learning_rate, Some(0.001));
1079        assert_eq!(entry.epochs, Some(10));
1080        assert_eq!(entry.max_seq_len, Some(1024));
1081    }
1082
1083    #[test]
1084    fn test_to_adapter_configs_rank_sets_alpha() {
1085        let toml = r#"
1086[[adapter]]
1087data = "data/a.jsonl"
1088checkpoint = "ckpt/a"
1089rank = 32
1090"#;
1091        let config = AdaptersConfigFile::from_toml(toml).expect("valid");
1092        let base = InstructConfig::default();
1093        let adapters = config.to_adapter_configs(&base);
1094        // rank=32 → alpha = 32*2.0 = 64.0
1095        assert_eq!(adapters[0].instruct_config.lora_rank, 32);
1096        assert!((adapters[0].instruct_config.lora_alpha - 64.0).abs() < f32::EPSILON);
1097    }
1098
1099    #[test]
1100    fn test_to_adapter_configs_multiple() {
1101        let toml = r#"
1102[[adapter]]
1103data = "a.jsonl"
1104checkpoint = "ckpt/a"
1105rank = 8
1106learning_rate = 0.0001
1107
1108[[adapter]]
1109data = "b.jsonl"
1110checkpoint = "ckpt/b"
1111epochs = 20
1112
1113[[adapter]]
1114data = "c.jsonl"
1115checkpoint = "ckpt/c"
1116max_seq_len = 128
1117"#;
1118        let config = AdaptersConfigFile::from_toml(toml).expect("valid");
1119        let base = InstructConfig {
1120            lora_rank: 16,
1121            learning_rate: 0.0002,
1122            epochs: 3,
1123            max_seq_len: 512,
1124            ..Default::default()
1125        };
1126        let adapters = config.to_adapter_configs(&base);
1127        assert_eq!(adapters.len(), 3);
1128
1129        // First adapter: rank=8, lr=0.0001
1130        assert_eq!(adapters[0].instruct_config.lora_rank, 8);
1131        assert!((adapters[0].instruct_config.learning_rate - 0.0001).abs() < f32::EPSILON);
1132        assert_eq!(adapters[0].instruct_config.epochs, 3); // inherited
1133
1134        // Second adapter: epochs=20
1135        assert_eq!(adapters[1].instruct_config.lora_rank, 16); // inherited
1136        assert_eq!(adapters[1].instruct_config.epochs, 20);
1137
1138        // Third adapter: max_seq_len=128
1139        assert_eq!(adapters[2].instruct_config.max_seq_len, 128);
1140        assert_eq!(adapters[2].instruct_config.lora_rank, 16); // inherited
1141    }
1142
1143    #[test]
1144    fn test_batch_train_step_priority_val_loss() {
1145        let mut slot0 = dummy_slot();
1146        slot0.best_val_loss = 2.0;
1147        let mut slot1 = dummy_slot();
1148        slot1.best_val_loss = 5.0; // worst → should be selected
1149
1150        let mut pipeline = MultiAdapterPipeline {
1151            base_pipeline: create_dummy_pipeline(),
1152            adapters: vec![slot0, slot1],
1153            schedule: AdapterSchedule::PriorityValLoss,
1154            global_step: 0,
1155        };
1156
1157        let results = pipeline.batch_train_step();
1158        assert_eq!(results.len(), 2);
1159        // No tokenizer → both None, but the function should not panic
1160    }
1161
1162    #[test]
1163    fn test_adapter_config_debug() {
1164        let config = AdapterConfig {
1165            data_path: PathBuf::from("test.jsonl"),
1166            checkpoint_dir: PathBuf::from("/tmp/ckpt"),
1167            instruct_config: InstructConfig::default(),
1168        };
1169        let debug = format!("{config:?}");
1170        assert!(debug.contains("AdapterConfig"));
1171        assert!(debug.contains("test.jsonl"));
1172    }
1173
1174    #[test]
1175    fn test_adapter_config_clone() {
1176        let config = AdapterConfig {
1177            data_path: PathBuf::from("test.jsonl"),
1178            checkpoint_dir: PathBuf::from("/tmp/ckpt"),
1179            instruct_config: InstructConfig::default(),
1180        };
1181        let cloned = config.clone();
1182        assert_eq!(cloned.data_path, PathBuf::from("test.jsonl"));
1183        assert_eq!(cloned.checkpoint_dir, PathBuf::from("/tmp/ckpt"));
1184    }
1185
1186    #[test]
1187    fn test_adapters_config_file_debug() {
1188        let toml = r#"
1189[[adapter]]
1190data = "a.jsonl"
1191checkpoint = "ckpt/a"
1192"#;
1193        let config = AdaptersConfigFile::from_toml(toml).expect("valid");
1194        let debug = format!("{config:?}");
1195        assert!(debug.contains("AdaptersConfigFile"));
1196    }
1197
1198    #[test]
1199    fn test_adapter_entry_debug() {
1200        let toml = r#"
1201[[adapter]]
1202data = "a.jsonl"
1203checkpoint = "ckpt/a"
1204label = "test"
1205"#;
1206        let config = AdaptersConfigFile::from_toml(toml).expect("valid");
1207        let debug = format!("{:?}", config.adapters[0]);
1208        assert!(debug.contains("AdapterEntry"));
1209        assert!(debug.contains("test"));
1210    }
1211
1212    #[test]
1213    fn test_adapter_slot_cursor_tracking() {
1214        let mut slot = dummy_slot_with_data(5);
1215        assert_eq!(slot.cursor, 0);
1216        slot.cursor = 3;
1217        assert_eq!(slot.cursor, 3);
1218        assert!(slot.cursor < slot.train_samples.len());
1219        slot.cursor = 5;
1220        assert!(slot.cursor >= slot.train_samples.len());
1221    }
1222
1223    #[test]
1224    fn test_adapter_slot_best_val_loss() {
1225        let mut slot = dummy_slot();
1226        assert_eq!(slot.best_val_loss, f32::INFINITY);
1227        slot.best_val_loss = 0.5;
1228        assert!((slot.best_val_loss - 0.5).abs() < f32::EPSILON);
1229    }
1230
1231    #[test]
1232    fn test_multi_adapter_pipeline_global_step() {
1233        let pipeline = MultiAdapterPipeline {
1234            base_pipeline: create_dummy_pipeline(),
1235            adapters: vec![],
1236            schedule: AdapterSchedule::RoundRobin,
1237            global_step: 0,
1238        };
1239        assert_eq!(pipeline.global_step, 0);
1240    }
1241
1242    #[test]
1243    fn test_train_step_adapter_exhausted() {
1244        let mut slot = dummy_slot_with_data(2);
1245        slot.cursor = 2; // already exhausted
1246
1247        let mut pipeline = MultiAdapterPipeline {
1248            base_pipeline: create_dummy_pipeline(),
1249            adapters: vec![slot],
1250            schedule: AdapterSchedule::RoundRobin,
1251            global_step: 0,
1252        };
1253
1254        let result = pipeline.train_step_adapter(0);
1255        assert!(result.is_none(), "Exhausted adapter should return None");
1256    }
1257
1258    #[test]
1259    fn test_batch_train_step_empty() {
1260        let mut pipeline = MultiAdapterPipeline {
1261            base_pipeline: create_dummy_pipeline(),
1262            adapters: vec![],
1263            schedule: AdapterSchedule::Synchronized,
1264            global_step: 0,
1265        };
1266        let results = pipeline.batch_train_step();
1267        assert!(results.is_empty());
1268    }
1269
1270    // ── Additional coverage tests ─────────────────────────────────
1271
1272    #[test]
1273    fn test_multi_adapter_pipeline_new() {
1274        let pipeline =
1275            MultiAdapterPipeline::new(create_dummy_pipeline(), AdapterSchedule::Synchronized);
1276        assert_eq!(pipeline.num_adapters(), 0);
1277        assert_eq!(pipeline.global_step, 0);
1278        assert!(pipeline.all_exhausted());
1279    }
1280
1281    #[test]
1282    fn test_multi_adapter_pipeline_add_adapter() {
1283        let mut pipeline =
1284            MultiAdapterPipeline::new(create_dummy_pipeline(), AdapterSchedule::RoundRobin);
1285        let config = AdapterConfig {
1286            data_path: PathBuf::from("data.jsonl"),
1287            checkpoint_dir: PathBuf::from("/tmp/ckpt"),
1288            instruct_config: InstructConfig::default(),
1289        };
1290        let samples = vec![InstructSample {
1291            instruction: "test".into(),
1292            response: "response".into(),
1293            system: None,
1294            metadata: None,
1295        }];
1296        pipeline.add_adapter(config, samples, vec![]);
1297        assert_eq!(pipeline.num_adapters(), 1);
1298        assert!(!pipeline.all_exhausted());
1299    }
1300
1301    #[test]
1302    fn test_train_step_adapter_no_tokenizer() {
1303        let mut pipeline = MultiAdapterPipeline {
1304            base_pipeline: create_dummy_pipeline(),
1305            adapters: vec![dummy_slot_with_data(5)],
1306            schedule: AdapterSchedule::RoundRobin,
1307            global_step: 0,
1308        };
1309        // No tokenizer loaded → should return None
1310        let result = pipeline.train_step_adapter(0);
1311        assert!(result.is_none());
1312        // Cursor should have advanced
1313        assert_eq!(pipeline.adapters[0].cursor, 1);
1314    }
1315
1316    #[test]
1317    fn test_train_step_increments_global_step() {
1318        let mut pipeline = MultiAdapterPipeline {
1319            base_pipeline: create_dummy_pipeline(),
1320            adapters: vec![dummy_slot_with_data(5)],
1321            schedule: AdapterSchedule::RoundRobin,
1322            global_step: 0,
1323        };
1324        // Even though result is None (no tokenizer), global_step should not increment
1325        // because early return happens before step increment
1326        let _ = pipeline.train_step_adapter(0);
1327        // Cursor advanced to 1, but no tokenizer so returns early before global_step increment
1328    }
1329
1330    #[test]
1331    fn test_batch_train_step_synchronized_all_exhausted() {
1332        let mut slot0 = dummy_slot_with_data(1);
1333        slot0.cursor = 1;
1334        let mut slot1 = dummy_slot_with_data(1);
1335        slot1.cursor = 1;
1336
1337        let mut pipeline = MultiAdapterPipeline {
1338            base_pipeline: create_dummy_pipeline(),
1339            adapters: vec![slot0, slot1],
1340            schedule: AdapterSchedule::Synchronized,
1341            global_step: 0,
1342        };
1343        let results = pipeline.batch_train_step();
1344        assert_eq!(results.len(), 2);
1345        assert!(results.iter().all(Option::is_none));
1346    }
1347
1348    #[test]
1349    fn test_reset_epoch_different_seeds_different_orders() {
1350        let mut pipeline1 = MultiAdapterPipeline {
1351            base_pipeline: create_dummy_pipeline(),
1352            adapters: vec![dummy_slot_with_data(20)],
1353            schedule: AdapterSchedule::RoundRobin,
1354            global_step: 0,
1355        };
1356        let mut pipeline2 = MultiAdapterPipeline {
1357            base_pipeline: create_dummy_pipeline(),
1358            adapters: vec![dummy_slot_with_data(20)],
1359            schedule: AdapterSchedule::RoundRobin,
1360            global_step: 0,
1361        };
1362
1363        pipeline1.reset_epoch(1);
1364        pipeline2.reset_epoch(999);
1365
1366        let same = pipeline1.adapters[0]
1367            .train_samples
1368            .iter()
1369            .zip(pipeline2.adapters[0].train_samples.iter())
1370            .all(|(s1, s2)| s1.instruction == s2.instruction);
1371        assert!(!same, "Different seeds should produce different shuffles");
1372    }
1373
1374    #[test]
1375    fn test_shuffle_samples_preserves_elements() {
1376        let mut samples: Vec<InstructSample> = (0..10)
1377            .map(|i| InstructSample {
1378                instruction: format!("inst_{i}"),
1379                response: format!("resp_{i}"),
1380                system: None,
1381                metadata: None,
1382            })
1383            .collect();
1384        let original_instructions: Vec<String> =
1385            samples.iter().map(|s| s.instruction.clone()).collect();
1386
1387        shuffle_samples(&mut samples, 42);
1388
1389        // All original elements should still be present
1390        let mut shuffled_instructions: Vec<String> =
1391            samples.iter().map(|s| s.instruction.clone()).collect();
1392        let mut sorted_original = original_instructions.clone();
1393        sorted_original.sort();
1394        shuffled_instructions.sort();
1395        assert_eq!(sorted_original, shuffled_instructions);
1396    }
1397
1398    #[test]
1399    fn test_adapter_slot_metrics_empty() {
1400        let slot = dummy_slot();
1401        assert!(slot.metrics.is_empty());
1402    }
1403
1404    #[test]
1405    fn test_adapter_slot_val_samples() {
1406        let slot = dummy_slot();
1407        assert!(slot.val_samples.is_empty());
1408    }
1409
1410    #[test]
1411    fn test_adapter_slot_lora_layers_empty() {
1412        let slot = dummy_slot();
1413        assert!(slot.lora_layers.is_empty());
1414    }
1415
1416    #[test]
1417    fn test_adapters_config_label_propagation() {
1418        let toml = r#"
1419[[adapter]]
1420data = "d1.jsonl"
1421checkpoint = "c1"
1422label = "adapter-one"
1423
1424[[adapter]]
1425data = "d2.jsonl"
1426checkpoint = "c2"
1427"#;
1428        let config = AdaptersConfigFile::from_toml(toml).expect("valid");
1429        assert_eq!(config.adapters[0].label, Some("adapter-one".to_string()));
1430        assert!(config.adapters[1].label.is_none());
1431    }
1432
1433    #[test]
1434    fn test_adapters_config_to_adapter_configs_alpha_calculation() {
1435        let toml = r#"
1436[[adapter]]
1437data = "data.jsonl"
1438checkpoint = "ckpt"
1439rank = 64
1440"#;
1441        let config = AdaptersConfigFile::from_toml(toml).expect("valid");
1442        let base = InstructConfig::default();
1443        let adapters = config.to_adapter_configs(&base);
1444        // alpha = rank * 2.0 = 128.0
1445        assert!((adapters[0].instruct_config.lora_alpha - 128.0).abs() < f32::EPSILON);
1446    }
1447
1448    #[test]
1449    fn test_select_next_adapter_round_robin_large_step() {
1450        let pipeline = MultiAdapterPipeline {
1451            base_pipeline: create_dummy_pipeline(),
1452            adapters: vec![dummy_slot(), dummy_slot()],
1453            schedule: AdapterSchedule::RoundRobin,
1454            global_step: 1000,
1455        };
1456        assert_eq!(pipeline.select_next_adapter(), Some(0)); // 1000 % 2 = 0
1457
1458        let pipeline = MultiAdapterPipeline { global_step: 1001, ..pipeline };
1459        assert_eq!(pipeline.select_next_adapter(), Some(1)); // 1001 % 2 = 1
1460    }
1461
1462    #[test]
1463    fn test_select_next_adapter_priority_selects_worst() {
1464        let mut slot0 = dummy_slot();
1465        slot0.best_val_loss = 0.1;
1466        let mut slot1 = dummy_slot();
1467        slot1.best_val_loss = 10.0;
1468        let mut slot2 = dummy_slot();
1469        slot2.best_val_loss = 5.0;
1470
1471        let pipeline = MultiAdapterPipeline {
1472            base_pipeline: create_dummy_pipeline(),
1473            adapters: vec![slot0, slot1, slot2],
1474            schedule: AdapterSchedule::PriorityValLoss,
1475            global_step: 0,
1476        };
1477        assert_eq!(pipeline.select_next_adapter(), Some(1)); // 10.0 is worst
1478    }
1479
1480    #[test]
1481    fn test_multi_adapter_multiple_add_adapter() {
1482        let mut pipeline =
1483            MultiAdapterPipeline::new(create_dummy_pipeline(), AdapterSchedule::Synchronized);
1484
1485        for i in 0..3 {
1486            let config = AdapterConfig {
1487                data_path: PathBuf::from(format!("data{i}.jsonl")),
1488                checkpoint_dir: PathBuf::from(format!("/tmp/ckpt{i}")),
1489                instruct_config: InstructConfig::default(),
1490            };
1491            pipeline.add_adapter(config, vec![], vec![]);
1492        }
1493        assert_eq!(pipeline.num_adapters(), 3);
1494        assert!(pipeline.all_exhausted()); // all empty
1495    }
1496
1497    // ── cov3: additional coverage tests ─────────────────────────────
1498
1499    #[test]
1500    fn test_cov3_save_adapter_checkpoint_creates_dir_and_files() {
1501        let dir = std::env::temp_dir().join("entrenar_cov3_ckpt_test");
1502        let _ = std::fs::remove_dir_all(&dir);
1503
1504        let mut pipeline =
1505            MultiAdapterPipeline::new(create_dummy_pipeline(), AdapterSchedule::RoundRobin);
1506        let config = AdapterConfig {
1507            data_path: PathBuf::from("data.jsonl"),
1508            checkpoint_dir: dir.clone(),
1509            instruct_config: InstructConfig::default(),
1510        };
1511        let samples = vec![InstructSample {
1512            instruction: "test".into(),
1513            response: "resp".into(),
1514            system: None,
1515            metadata: None,
1516        }];
1517        pipeline.add_adapter(config, samples, vec![]);
1518
1519        let result = pipeline.save_adapter_checkpoint(0, 1, 0.5);
1520        assert!(result.is_ok());
1521        let ckpt_dir = result.unwrap();
1522        assert!(ckpt_dir.join("metadata.json").exists());
1523        assert!(ckpt_dir.join("model.safetensors").exists());
1524
1525        // Verify metadata contents
1526        let metadata_str = std::fs::read_to_string(ckpt_dir.join("metadata.json")).unwrap();
1527        assert!(metadata_str.contains("\"mode\": \"multi_adapter\""));
1528        assert!(metadata_str.contains("\"adapter_index\": 0"));
1529        assert!(metadata_str.contains("\"epoch\": 1"));
1530
1531        let _ = std::fs::remove_dir_all(&dir);
1532    }
1533
1534    #[test]
1535    fn test_cov3_save_best_checkpoint_creates_dir_and_files() {
1536        let dir = std::env::temp_dir().join("entrenar_cov3_best_ckpt_test");
1537        let _ = std::fs::remove_dir_all(&dir);
1538
1539        let mut pipeline =
1540            MultiAdapterPipeline::new(create_dummy_pipeline(), AdapterSchedule::RoundRobin);
1541        let config = AdapterConfig {
1542            data_path: PathBuf::from("data.jsonl"),
1543            checkpoint_dir: dir.clone(),
1544            instruct_config: InstructConfig::default(),
1545        };
1546        pipeline.add_adapter(config, vec![], vec![]);
1547
1548        let result = pipeline.save_best_checkpoint(0, 2, 0.3);
1549        assert!(result.is_ok());
1550        let best_dir = result.unwrap();
1551        assert_eq!(best_dir, dir.join("best"));
1552        assert!(best_dir.join("metadata.json").exists());
1553        assert!(best_dir.join("model.safetensors").exists());
1554
1555        // Verify metadata
1556        let metadata_str = std::fs::read_to_string(best_dir.join("metadata.json")).unwrap();
1557        assert!(metadata_str.contains("\"mode\": \"multi_adapter\""));
1558        assert!(metadata_str.contains("\"epoch\": 2"));
1559
1560        let _ = std::fs::remove_dir_all(&dir);
1561    }
1562
1563    #[test]
1564    fn test_cov3_save_best_checkpoint_overwrites_previous() {
1565        let dir = std::env::temp_dir().join("entrenar_cov3_best_overwrite");
1566        let _ = std::fs::remove_dir_all(&dir);
1567
1568        let mut pipeline =
1569            MultiAdapterPipeline::new(create_dummy_pipeline(), AdapterSchedule::RoundRobin);
1570        let config = AdapterConfig {
1571            data_path: PathBuf::from("data.jsonl"),
1572            checkpoint_dir: dir.clone(),
1573            instruct_config: InstructConfig::default(),
1574        };
1575        pipeline.add_adapter(config, vec![], vec![]);
1576
1577        // First save
1578        pipeline.save_best_checkpoint(0, 1, 1.0).unwrap();
1579        // Second save should overwrite
1580        pipeline.save_best_checkpoint(0, 5, 0.2).unwrap();
1581
1582        let metadata_str = std::fs::read_to_string(dir.join("best").join("metadata.json")).unwrap();
1583        assert!(metadata_str.contains("\"epoch\": 5"));
1584
1585        let _ = std::fs::remove_dir_all(&dir);
1586    }
1587
1588    #[test]
1589    fn test_cov3_save_adapter_lora_weights_empty_layers() {
1590        let dir = std::env::temp_dir().join("entrenar_cov3_empty_lora");
1591        let _ = std::fs::remove_dir_all(&dir);
1592        std::fs::create_dir_all(&dir).unwrap();
1593
1594        let result = save_adapter_lora_weights(&[], &dir);
1595        assert!(result.is_ok());
1596        // SafeTensors file should exist even with empty layers
1597        assert!(dir.join("model.safetensors").exists());
1598
1599        let _ = std::fs::remove_dir_all(&dir);
1600    }
1601
1602    #[test]
1603    fn test_cov3_save_adapter_lora_weights_with_real_layers() {
1604        let dir = std::env::temp_dir().join("entrenar_cov3_real_lora");
1605        let _ = std::fs::remove_dir_all(&dir);
1606        std::fs::create_dir_all(&dir).unwrap();
1607
1608        // Create a pipeline with real LoRA layers
1609        let model_config = crate::transformer::TransformerConfig::tiny();
1610        let model = crate::transformer::Transformer::new(&model_config);
1611        let instruct_config = InstructConfig { lora_rank: 4, ..InstructConfig::default() };
1612        let layers = InstructPipeline::build_lora_layers(&model, &model_config, &instruct_config);
1613
1614        let result = save_adapter_lora_weights(&layers, &dir);
1615        assert!(result.is_ok());
1616
1617        // Verify SafeTensors can be read back
1618        let st_bytes = std::fs::read(dir.join("model.safetensors")).unwrap();
1619        let st = safetensors::SafeTensors::deserialize(&st_bytes).unwrap();
1620        // 4 LoRA layers → 4 * 2 (A and B) = 8 tensors
1621        assert_eq!(st.len(), layers.len() * 2);
1622
1623        // Verify naming convention
1624        let names: Vec<String> = st.names().iter().map(std::string::ToString::to_string).collect();
1625        assert!(names.iter().any(|n| n.contains("lora_a")));
1626        assert!(names.iter().any(|n| n.contains("lora_b")));
1627        assert!(names.iter().any(|n| n.contains("q_proj")));
1628        assert!(names.iter().any(|n| n.contains("v_proj")));
1629
1630        let _ = std::fs::remove_dir_all(&dir);
1631    }
1632
1633    #[test]
1634    fn test_cov3_shuffle_samples_large_input() {
1635        let mut samples: Vec<InstructSample> = (0..100)
1636            .map(|i| InstructSample {
1637                instruction: format!("inst_{i}"),
1638                response: format!("resp_{i}"),
1639                system: None,
1640                metadata: None,
1641            })
1642            .collect();
1643        let original: Vec<String> = samples.iter().map(|s| s.instruction.clone()).collect();
1644
1645        shuffle_samples(&mut samples, 12345);
1646
1647        let shuffled: Vec<String> = samples.iter().map(|s| s.instruction.clone()).collect();
1648        // Should be different order
1649        assert_ne!(original, shuffled, "100 samples should shuffle to different order");
1650        // But same elements
1651        let mut sorted_original = original;
1652        sorted_original.sort();
1653        let mut sorted_shuffled = shuffled;
1654        sorted_shuffled.sort();
1655        assert_eq!(sorted_original, sorted_shuffled);
1656    }
1657
1658    #[test]
1659    fn test_cov3_shuffle_samples_two_elements() {
1660        let mut samples = vec![
1661            InstructSample {
1662                instruction: "a".into(),
1663                response: "1".into(),
1664                system: None,
1665                metadata: None,
1666            },
1667            InstructSample {
1668                instruction: "b".into(),
1669                response: "2".into(),
1670                system: None,
1671                metadata: None,
1672            },
1673        ];
1674        // Verify no panic with 2 elements
1675        shuffle_samples(&mut samples, 42);
1676        assert_eq!(samples.len(), 2);
1677    }
1678
1679    #[test]
1680    fn test_cov3_adapters_config_toml_all_overrides() {
1681        let toml = r#"
1682[[adapter]]
1683data = "data/test.jsonl"
1684checkpoint = "ckpt/test"
1685label = "full-override"
1686rank = 64
1687learning_rate = 0.001
1688epochs = 20
1689max_seq_len = 2048
1690"#;
1691        let config = AdaptersConfigFile::from_toml(toml).unwrap();
1692        let base = InstructConfig::default();
1693        let adapters = config.to_adapter_configs(&base);
1694        assert_eq!(adapters[0].instruct_config.lora_rank, 64);
1695        assert!((adapters[0].instruct_config.lora_alpha - 128.0).abs() < f32::EPSILON);
1696        assert!((adapters[0].instruct_config.learning_rate - 0.001).abs() < f32::EPSILON);
1697        assert_eq!(adapters[0].instruct_config.epochs, 20);
1698        assert_eq!(adapters[0].instruct_config.max_seq_len, 2048);
1699    }
1700
1701    #[test]
1702    fn test_cov3_adapters_config_many_adapters() {
1703        let mut toml_str = String::new();
1704        for i in 0..10 {
1705            toml_str.push_str(&format!(
1706                r#"
1707[[adapter]]
1708data = "data/{i}.jsonl"
1709checkpoint = "ckpt/{i}"
1710rank = {rank}
1711"#,
1712                i = i,
1713                rank = 4 + i * 2,
1714            ));
1715        }
1716        let config = AdaptersConfigFile::from_toml(&toml_str).unwrap();
1717        assert_eq!(config.adapters.len(), 10);
1718        // Verify each has the right rank
1719        for (i, entry) in config.adapters.iter().enumerate() {
1720            assert_eq!(entry.rank, Some(4 + i * 2));
1721        }
1722    }
1723
1724    #[test]
1725    fn test_cov3_adapters_config_toml_missing_required_fields() {
1726        // Missing checkpoint field
1727        let toml = r#"
1728[[adapter]]
1729data = "data.jsonl"
1730"#;
1731        let result = AdaptersConfigFile::from_toml(toml);
1732        assert!(result.is_err());
1733    }
1734
1735    #[test]
1736    fn test_cov3_adapters_config_toml_missing_data_field() {
1737        let toml = r#"
1738[[adapter]]
1739checkpoint = "ckpt"
1740"#;
1741        let result = AdaptersConfigFile::from_toml(toml);
1742        assert!(result.is_err());
1743    }
1744
1745    #[test]
1746    fn test_cov3_adapters_config_toml_extra_fields_ignored() {
1747        // Extra fields should be ignored by serde
1748        let toml = r#"
1749[[adapter]]
1750data = "data.jsonl"
1751checkpoint = "ckpt"
1752unknown_field = "ignored"
1753"#;
1754        // Depending on serde config, this might fail or succeed
1755        // toml::from_str with #[serde(deny_unknown_fields)] would fail
1756        // Without it, it succeeds
1757        let result = AdaptersConfigFile::from_toml(toml);
1758        // Just verify no panic
1759        let _ = result;
1760    }
1761
1762    #[test]
1763    fn test_cov3_adapters_config_rank_zero() {
1764        let toml = r#"
1765[[adapter]]
1766data = "data.jsonl"
1767checkpoint = "ckpt"
1768rank = 0
1769"#;
1770        let config = AdaptersConfigFile::from_toml(toml).unwrap();
1771        let base = InstructConfig::default();
1772        let adapters = config.to_adapter_configs(&base);
1773        assert_eq!(adapters[0].instruct_config.lora_rank, 0);
1774        assert!((adapters[0].instruct_config.lora_alpha - 0.0).abs() < f32::EPSILON);
1775    }
1776
1777    #[test]
1778    fn test_cov3_add_adapter_creates_lora_layers() {
1779        let mut pipeline =
1780            MultiAdapterPipeline::new(create_dummy_pipeline(), AdapterSchedule::RoundRobin);
1781        let config = AdapterConfig {
1782            data_path: PathBuf::from("data.jsonl"),
1783            checkpoint_dir: PathBuf::from("/tmp/ckpt"),
1784            instruct_config: InstructConfig { lora_rank: 4, ..InstructConfig::default() },
1785        };
1786        pipeline.add_adapter(config, vec![], vec![]);
1787        // Adapter should have LoRA layers created from the base model
1788        // tiny model has 2 layers → 2 * 2 (Q+V) = 4 LoRA layers
1789        assert_eq!(pipeline.adapters[0].lora_layers.len(), 4);
1790    }
1791
1792    #[test]
1793    fn test_cov3_add_adapter_with_val_samples() {
1794        let mut pipeline =
1795            MultiAdapterPipeline::new(create_dummy_pipeline(), AdapterSchedule::RoundRobin);
1796        let config = AdapterConfig {
1797            data_path: PathBuf::from("data.jsonl"),
1798            checkpoint_dir: PathBuf::from("/tmp/ckpt"),
1799            instruct_config: InstructConfig::default(),
1800        };
1801        let val_samples = vec![InstructSample {
1802            instruction: "val_q".into(),
1803            response: "val_a".into(),
1804            system: None,
1805            metadata: None,
1806        }];
1807        pipeline.add_adapter(config, vec![], val_samples);
1808        assert_eq!(pipeline.adapters[0].val_samples.len(), 1);
1809        assert_eq!(pipeline.adapters[0].val_samples[0].instruction, "val_q");
1810    }
1811
1812    #[test]
1813    fn test_cov3_add_adapter_initial_state() {
1814        let mut pipeline =
1815            MultiAdapterPipeline::new(create_dummy_pipeline(), AdapterSchedule::RoundRobin);
1816        let config = AdapterConfig {
1817            data_path: PathBuf::from("data.jsonl"),
1818            checkpoint_dir: PathBuf::from("/tmp/ckpt_initial"),
1819            instruct_config: InstructConfig { lora_rank: 8, ..InstructConfig::default() },
1820        };
1821        pipeline.add_adapter(config, vec![], vec![]);
1822        let slot = &pipeline.adapters[0];
1823        assert_eq!(slot.cursor, 0);
1824        assert_eq!(slot.best_val_loss, f32::INFINITY);
1825        assert!(slot.metrics.is_empty());
1826        assert_eq!(slot.config.lora_rank, 8);
1827        assert_eq!(slot.checkpoint_dir, PathBuf::from("/tmp/ckpt_initial"));
1828    }
1829
1830    #[test]
1831    fn test_cov3_train_step_adapter_empty_tokens() {
1832        // Test with empty instruction and response after tokenization
1833        let mut pipeline = MultiAdapterPipeline {
1834            base_pipeline: create_dummy_pipeline(),
1835            adapters: vec![dummy_slot_with_data(5)],
1836            schedule: AdapterSchedule::RoundRobin,
1837            global_step: 0,
1838        };
1839        // No tokenizer → returns None early
1840        let result = pipeline.train_step_adapter(0);
1841        assert!(result.is_none());
1842        // Cursor should have advanced by 1
1843        assert_eq!(pipeline.adapters[0].cursor, 1);
1844    }
1845
1846    #[test]
1847    fn test_cov3_batch_train_step_synchronized_mixed_exhaustion() {
1848        let slot0 = dummy_slot_with_data(3); // has data
1849        let mut slot1 = dummy_slot_with_data(1);
1850        slot1.cursor = 1; // exhausted
1851
1852        let mut pipeline = MultiAdapterPipeline {
1853            base_pipeline: create_dummy_pipeline(),
1854            adapters: vec![slot0, slot1],
1855            schedule: AdapterSchedule::Synchronized,
1856            global_step: 0,
1857        };
1858
1859        let results = pipeline.batch_train_step();
1860        assert_eq!(results.len(), 2);
1861        // Adapter 1 is exhausted → None
1862        assert!(results[1].is_none());
1863    }
1864
1865    #[test]
1866    fn test_cov3_batch_train_step_round_robin_cycling() {
1867        // Verify round-robin cycling by manually advancing global_step
1868        let pipeline = MultiAdapterPipeline {
1869            base_pipeline: create_dummy_pipeline(),
1870            adapters: vec![
1871                dummy_slot_with_data(10),
1872                dummy_slot_with_data(10),
1873                dummy_slot_with_data(10),
1874            ],
1875            schedule: AdapterSchedule::RoundRobin,
1876            global_step: 0,
1877        };
1878
1879        // Step 0 → adapter 0
1880        assert_eq!(pipeline.select_next_adapter(), Some(0));
1881        // Simulate step 1
1882        let pipeline = MultiAdapterPipeline { global_step: 1, ..pipeline };
1883        assert_eq!(pipeline.select_next_adapter(), Some(1));
1884        // Simulate step 2
1885        let pipeline = MultiAdapterPipeline { global_step: 2, ..pipeline };
1886        assert_eq!(pipeline.select_next_adapter(), Some(2));
1887        // Step 3 wraps back to 0
1888        let pipeline = MultiAdapterPipeline { global_step: 3, ..pipeline };
1889        assert_eq!(pipeline.select_next_adapter(), Some(0));
1890    }
1891
1892    #[test]
1893    fn test_cov3_reset_epoch_multiple_adapters_independent_seeds() {
1894        let mut pipeline = MultiAdapterPipeline {
1895            base_pipeline: create_dummy_pipeline(),
1896            adapters: vec![dummy_slot_with_data(20), dummy_slot_with_data(20)],
1897            schedule: AdapterSchedule::RoundRobin,
1898            global_step: 0,
1899        };
1900
1901        pipeline.reset_epoch(42);
1902
1903        // Different adapters should get different shuffle orders
1904        // (seed + adapter index → different effective seed)
1905        let order0: Vec<String> =
1906            pipeline.adapters[0].train_samples.iter().map(|s| s.instruction.clone()).collect();
1907        let order1: Vec<String> =
1908            pipeline.adapters[1].train_samples.iter().map(|s| s.instruction.clone()).collect();
1909        // With different effective seeds, orderings should differ
1910        assert_ne!(order0, order1, "Different adapters should have different shuffle orders");
1911    }
1912
1913    #[test]
1914    fn test_cov3_adapter_schedule_copy() {
1915        let s1 = AdapterSchedule::PriorityValLoss;
1916        let s2 = s1; // Copy
1917        assert_eq!(s1, s2);
1918    }
1919
1920    #[test]
1921    fn test_cov3_adapters_config_file_clone() {
1922        let toml = r#"
1923[[adapter]]
1924data = "data.jsonl"
1925checkpoint = "ckpt"
1926label = "test"
1927"#;
1928        let config = AdaptersConfigFile::from_toml(toml).unwrap();
1929        let cloned = config.clone();
1930        assert_eq!(cloned.adapters.len(), 1);
1931        assert_eq!(cloned.adapters[0].label, Some("test".to_string()));
1932    }
1933
1934    #[test]
1935    fn test_cov3_adapter_entry_clone() {
1936        let toml = r#"
1937[[adapter]]
1938data = "data.jsonl"
1939checkpoint = "ckpt"
1940rank = 32
1941learning_rate = 0.001
1942"#;
1943        let config = AdaptersConfigFile::from_toml(toml).unwrap();
1944        let cloned = config.adapters[0].clone();
1945        assert_eq!(cloned.rank, Some(32));
1946        assert_eq!(cloned.learning_rate, Some(0.001));
1947    }
1948
1949    #[test]
1950    fn test_cov3_save_adapter_checkpoint_metadata_values() {
1951        let dir = std::env::temp_dir().join("entrenar_cov3_ckpt_meta");
1952        let _ = std::fs::remove_dir_all(&dir);
1953
1954        let mut pipeline =
1955            MultiAdapterPipeline::new(create_dummy_pipeline(), AdapterSchedule::RoundRobin);
1956        pipeline.global_step = 42;
1957        let config = AdapterConfig {
1958            data_path: PathBuf::from("data.jsonl"),
1959            checkpoint_dir: dir.clone(),
1960            instruct_config: InstructConfig {
1961                lora_rank: 8,
1962                lora_alpha: 16.0,
1963                ..InstructConfig::default()
1964            },
1965        };
1966        let samples: Vec<InstructSample> = (0..5)
1967            .map(|i| InstructSample {
1968                instruction: format!("q{i}"),
1969                response: format!("a{i}"),
1970                system: None,
1971                metadata: None,
1972            })
1973            .collect();
1974        pipeline.add_adapter(config, samples, vec![]);
1975        pipeline.adapters[0].best_val_loss = 0.75;
1976
1977        let ckpt_dir = pipeline.save_adapter_checkpoint(0, 3, 0.42).unwrap();
1978        let metadata_str = std::fs::read_to_string(ckpt_dir.join("metadata.json")).unwrap();
1979        let metadata: serde_json::Value = serde_json::from_str(&metadata_str).unwrap();
1980
1981        assert_eq!(metadata["adapter_index"], 0);
1982        assert_eq!(metadata["epoch"], 3);
1983        assert_eq!(metadata["lora_rank"], 8);
1984        assert_eq!(metadata["train_samples"], 5);
1985        assert_eq!(metadata["global_step"], 42);
1986
1987        let _ = std::fs::remove_dir_all(&dir);
1988    }
1989
1990    #[test]
1991    fn test_cov3_save_adapter_checkpoint_multiple_epochs() {
1992        let dir = std::env::temp_dir().join("entrenar_cov3_multi_epoch");
1993        let _ = std::fs::remove_dir_all(&dir);
1994
1995        let mut pipeline =
1996            MultiAdapterPipeline::new(create_dummy_pipeline(), AdapterSchedule::RoundRobin);
1997        let config = AdapterConfig {
1998            data_path: PathBuf::from("data.jsonl"),
1999            checkpoint_dir: dir.clone(),
2000            instruct_config: InstructConfig::default(),
2001        };
2002        pipeline.add_adapter(config, vec![], vec![]);
2003
2004        // Save multiple epochs
2005        for epoch in 0..3 {
2006            let ckpt_dir =
2007                pipeline.save_adapter_checkpoint(0, epoch, 1.0 - epoch as f32 * 0.2).unwrap();
2008            assert!(ckpt_dir.join("metadata.json").exists());
2009            assert!(ckpt_dir.join("model.safetensors").exists());
2010        }
2011
2012        // All three epoch directories should exist
2013        assert!(dir.join("epoch-0").exists());
2014        assert!(dir.join("epoch-1").exists());
2015        assert!(dir.join("epoch-2").exists());
2016
2017        let _ = std::fs::remove_dir_all(&dir);
2018    }
2019
2020    #[test]
2021    fn test_cov3_all_exhausted_single_adapter_one_sample() {
2022        let slot = dummy_slot_with_data(1);
2023        let pipeline = MultiAdapterPipeline {
2024            base_pipeline: create_dummy_pipeline(),
2025            adapters: vec![slot],
2026            schedule: AdapterSchedule::RoundRobin,
2027            global_step: 0,
2028        };
2029        assert!(!pipeline.all_exhausted());
2030    }
2031
2032    #[test]
2033    fn test_cov3_all_exhausted_single_adapter_cursor_at_end() {
2034        let mut slot = dummy_slot_with_data(1);
2035        slot.cursor = 1;
2036        let pipeline = MultiAdapterPipeline {
2037            base_pipeline: create_dummy_pipeline(),
2038            adapters: vec![slot],
2039            schedule: AdapterSchedule::RoundRobin,
2040            global_step: 0,
2041        };
2042        assert!(pipeline.all_exhausted());
2043    }
2044
2045    #[test]
2046    fn test_cov3_select_priority_single_adapter() {
2047        let mut slot = dummy_slot();
2048        slot.best_val_loss = 3.0;
2049        let pipeline = MultiAdapterPipeline {
2050            base_pipeline: create_dummy_pipeline(),
2051            adapters: vec![slot],
2052            schedule: AdapterSchedule::PriorityValLoss,
2053            global_step: 0,
2054        };
2055        assert_eq!(pipeline.select_next_adapter(), Some(0));
2056    }
2057
2058    #[test]
2059    fn test_cov3_select_priority_equal_losses() {
2060        let mut slot0 = dummy_slot();
2061        slot0.best_val_loss = 1.0;
2062        let mut slot1 = dummy_slot();
2063        slot1.best_val_loss = 1.0;
2064        let pipeline = MultiAdapterPipeline {
2065            base_pipeline: create_dummy_pipeline(),
2066            adapters: vec![slot0, slot1],
2067            schedule: AdapterSchedule::PriorityValLoss,
2068            global_step: 0,
2069        };
2070        let result = pipeline.select_next_adapter();
2071        // With equal losses, max_by picks the last one (stable)
2072        assert!(result == Some(0) || result == Some(1));
2073    }
2074
2075    #[test]
2076    fn test_cov3_to_adapter_configs_no_overrides() {
2077        let toml = r#"
2078[[adapter]]
2079data = "d.jsonl"
2080checkpoint = "c"
2081"#;
2082        let config = AdaptersConfigFile::from_toml(toml).unwrap();
2083        let base = InstructConfig {
2084            lora_rank: 32,
2085            lora_alpha: 64.0,
2086            learning_rate: 0.005,
2087            epochs: 7,
2088            max_seq_len: 1024,
2089            gradient_clip_norm: Some(2.0),
2090            quantize_nf4: true,
2091        };
2092        let adapters = config.to_adapter_configs(&base);
2093        // All base values should be inherited
2094        assert_eq!(adapters[0].instruct_config.lora_rank, 32);
2095        assert!((adapters[0].instruct_config.lora_alpha - 64.0).abs() < f32::EPSILON);
2096        assert!((adapters[0].instruct_config.learning_rate - 0.005).abs() < f32::EPSILON);
2097        assert_eq!(adapters[0].instruct_config.epochs, 7);
2098        assert_eq!(adapters[0].instruct_config.max_seq_len, 1024);
2099        assert_eq!(adapters[0].instruct_config.gradient_clip_norm, Some(2.0));
2100        assert!(adapters[0].instruct_config.quantize_nf4);
2101    }
2102
2103    #[test]
2104    fn test_cov3_to_adapter_configs_preserves_data_and_checkpoint_paths() {
2105        let toml = r#"
2106[[adapter]]
2107data = "/absolute/path/data.jsonl"
2108checkpoint = "../relative/ckpt"
2109"#;
2110        let config = AdaptersConfigFile::from_toml(toml).unwrap();
2111        let base = InstructConfig::default();
2112        let adapters = config.to_adapter_configs(&base);
2113        assert_eq!(adapters[0].data_path, PathBuf::from("/absolute/path/data.jsonl"));
2114        assert_eq!(adapters[0].checkpoint_dir, PathBuf::from("../relative/ckpt"));
2115    }
2116
2117    #[test]
2118    fn test_cov3_adapters_config_from_file_invalid_toml() {
2119        let dir = std::env::temp_dir().join("entrenar_cov3_invalid_toml");
2120        let _ = std::fs::create_dir_all(&dir);
2121        let path = dir.join("invalid.toml");
2122        std::fs::write(&path, "this {{ is not valid TOML").unwrap();
2123        let result = AdaptersConfigFile::from_file(&path);
2124        assert!(result.is_err());
2125        let err = result.unwrap_err();
2126        assert!(err.contains("failed to parse"), "Expected parse error, got: {err}");
2127        let _ = std::fs::remove_dir_all(&dir);
2128    }
2129
2130    #[test]
2131    fn test_cov3_adapter_slot_checkpoint_dir() {
2132        let slot = AdapterSlot {
2133            lora_layers: Vec::new(),
2134            train_samples: Vec::new(),
2135            val_samples: Vec::new(),
2136            checkpoint_dir: PathBuf::from("/my/custom/ckpt"),
2137            metrics: Vec::new(),
2138            config: InstructConfig::default(),
2139            cursor: 0,
2140            best_val_loss: f32::INFINITY,
2141            #[cfg(feature = "cuda")]
2142            optimizer_states: None,
2143            #[cfg(feature = "cuda")]
2144            lora_step: 0,
2145        };
2146        assert_eq!(slot.checkpoint_dir, PathBuf::from("/my/custom/ckpt"));
2147    }
2148
2149    #[test]
2150    fn test_cov3_multi_adapter_schedule_field() {
2151        let pipeline = MultiAdapterPipeline {
2152            base_pipeline: create_dummy_pipeline(),
2153            adapters: vec![],
2154            schedule: AdapterSchedule::PriorityValLoss,
2155            global_step: 0,
2156        };
2157        assert_eq!(pipeline.schedule, AdapterSchedule::PriorityValLoss);
2158    }
2159}