Skip to main content

entrenar/cli/commands/
merge.rs

1//! Merge command implementation
2
3use crate::autograd::Tensor;
4use crate::cli::logging::log;
5use crate::cli::LogLevel;
6use crate::config::{MergeArgs, MergeMethod};
7use crate::merge::{
8    dare_merge, ensemble_merge, slerp_merge, ties_merge, DareConfig, EnsembleConfig, Model,
9    SlerpConfig, TiesConfig,
10};
11use safetensors::SafeTensors;
12use std::collections::HashMap;
13use std::path::Path;
14
15pub fn run_merge(args: MergeArgs, level: LogLevel) -> Result<(), String> {
16    // ENT-LoRA-017: LoRA adapter merge path
17    if args.method == MergeMethod::LoraAdapter {
18        return run_lora_adapter_merge(&args, level);
19    }
20
21    log_merge_start(&args, level);
22    validate_model_count(&args)?;
23
24    let models = load_all_models(&args.models, level)?;
25    let merged = perform_merge(&models, &args)?;
26    export_merged_model(&merged, &args)?;
27
28    log_merge_complete(&merged, &args, level);
29    Ok(())
30}
31
32/// Log merge operation start
33fn log_merge_start(args: &MergeArgs, level: LogLevel) {
34    log(
35        level,
36        LogLevel::Normal,
37        &format!("Merging {} models using {:?}", args.models.len(), args.method),
38    );
39
40    for (i, model) in args.models.iter().enumerate() {
41        log(level, LogLevel::Verbose, &format!("  Model {}: {}", i + 1, model.display()));
42    }
43    log(level, LogLevel::Verbose, &format!("  Output: {}", args.output.display()));
44}
45
46/// Validate we have enough models
47fn validate_model_count(args: &MergeArgs) -> Result<(), String> {
48    if args.models.len() < 2 {
49        return Err("Need at least 2 models to merge".to_string());
50    }
51    Ok(())
52}
53
54/// Load all models from paths
55fn load_all_models(paths: &[std::path::PathBuf], level: LogLevel) -> Result<Vec<Model>, String> {
56    let mut models: Vec<Model> = Vec::new();
57    for path in paths {
58        let model = load_single_model(path)?;
59        let tensor_count = model.len();
60        models.push(model);
61
62        log(
63            level,
64            LogLevel::Verbose,
65            &format!("  Loaded {} tensors from {}", tensor_count, path.display()),
66        );
67    }
68    Ok(models)
69}
70
71/// Load a single model from a SafeTensors file
72fn load_single_model(path: &Path) -> Result<Model, String> {
73    let data =
74        std::fs::read(path).map_err(|e| format!("Failed to read {}: {e}", path.display()))?;
75
76    let tensors = SafeTensors::deserialize(&data)
77        .map_err(|e| format!("Failed to parse {}: {e}", path.display()))?;
78
79    let mut model: Model = HashMap::new();
80    for name in tensors.names() {
81        if let Some(tensor) = extract_f32_tensor(&tensors, name)? {
82            model.insert((*name).to_string(), tensor);
83        }
84    }
85    Ok(model)
86}
87
88/// Extract a tensor as f32 values (returns None for non-F32 tensors)
89fn extract_f32_tensor(tensors: &SafeTensors<'_>, name: &str) -> Result<Option<Tensor>, String> {
90    let tensor = tensors.tensor(name).map_err(|e| format!("Failed to get tensor {name}: {e}"))?;
91
92    if tensor.dtype() != safetensors::tensor::Dtype::F32 {
93        return Ok(None);
94    }
95
96    let bytes = tensor.data();
97    let values: Vec<f32> = bytes
98        .chunks_exact(4)
99        .map(|chunk| f32::from_le_bytes([chunk[0], chunk[1], chunk[2], chunk[3]]))
100        .collect();
101
102    Ok(Some(Tensor::from_vec(values, false)))
103}
104
105/// Perform the merge based on the specified method
106fn perform_merge(models: &[Model], args: &MergeArgs) -> Result<Model, String> {
107    match args.method {
108        MergeMethod::Ties => perform_ties_merge(models, args),
109        MergeMethod::Dare => perform_dare_merge(models, args),
110        MergeMethod::Slerp => perform_slerp_merge(models, args),
111        MergeMethod::Average => perform_average_merge(models, args),
112        MergeMethod::LoraAdapter => {
113            // Handled by early return in run_merge; shouldn't reach here
114            Err("LoRA adapter merge uses dedicated path".to_string())
115        }
116    }
117}
118
119/// TIES merge: first model is base, rest are task-specific
120fn perform_ties_merge(models: &[Model], args: &MergeArgs) -> Result<Model, String> {
121    let config = TiesConfig { density: args.density.unwrap_or(0.2) };
122    let base = &models[0];
123    ties_merge(models.get(1..).unwrap_or_default(), base, &config)
124        .map_err(|e| format!("TIES merge failed: {e}"))
125}
126
127/// DARE merge with dropout
128fn perform_dare_merge(models: &[Model], args: &MergeArgs) -> Result<Model, String> {
129    let config = DareConfig { drop_prob: 1.0 - args.density.unwrap_or(0.5), seed: None };
130    let base = &models[0];
131    dare_merge(models.get(1..).unwrap_or_default(), base, &config)
132        .map_err(|e| format!("DARE merge failed: {e}"))
133}
134
135/// SLERP merge (requires exactly 2 models)
136fn perform_slerp_merge(models: &[Model], args: &MergeArgs) -> Result<Model, String> {
137    if models.len() != 2 {
138        return Err("SLERP requires exactly 2 models".to_string());
139    }
140    let config = SlerpConfig { t: args.weight.unwrap_or(0.5) };
141    slerp_merge(&models[0], &models[1], &config).map_err(|e| format!("SLERP merge failed: {e}"))
142}
143
144/// Average/ensemble merge with optional weights
145fn perform_average_merge(models: &[Model], args: &MergeArgs) -> Result<Model, String> {
146    let config = build_ensemble_config(args)?;
147    ensemble_merge(models, &config).map_err(|e| format!("Average merge failed: {e}"))
148}
149
150/// Build ensemble config from args
151fn build_ensemble_config(args: &MergeArgs) -> Result<EnsembleConfig, String> {
152    if let Some(w_str) = &args.weights {
153        let weights: Vec<f32> = w_str
154            .split(',')
155            .map(|s| s.trim().parse::<f32>())
156            .collect::<Result<Vec<_>, _>>()
157            .map_err(|e| format!("Invalid weights: {e}"))?;
158        Ok(EnsembleConfig::weighted_average(weights))
159    } else {
160        Ok(EnsembleConfig::uniform_average())
161    }
162}
163
164/// Export merged model to file
165fn export_merged_model(merged: &Model, args: &MergeArgs) -> Result<(), String> {
166    let output_ext = args.output.extension().and_then(|s| s.to_str()).unwrap_or("json");
167
168    if output_ext == "safetensors" {
169        export_safetensors(merged, args)
170    } else {
171        export_json(merged, args)
172    }
173}
174
175/// Export to SafeTensors format
176fn export_safetensors(merged: &Model, args: &MergeArgs) -> Result<(), String> {
177    use safetensors::tensor::{Dtype, TensorView};
178
179    let tensor_data: Vec<(String, Vec<u8>, Vec<usize>)> = merged
180        .iter()
181        .map(|(name, tensor)| {
182            let data = tensor.data();
183            let bytes: Vec<u8> = bytemuck::cast_slice(data.as_slice().unwrap_or(&[])).to_vec();
184            let shape = vec![tensor.len()];
185            (name.clone(), bytes, shape)
186        })
187        .collect();
188
189    let views: Vec<(&str, TensorView<'_>)> = tensor_data
190        .iter()
191        .filter_map(|(name, bytes, shape)| {
192            TensorView::new(Dtype::F32, shape.clone(), bytes).ok().map(|view| (name.as_str(), view))
193        })
194        .collect();
195
196    let metadata = build_safetensor_metadata(merged, args);
197    let safetensor_bytes = safetensors::serialize(views, Some(metadata))
198        .map_err(|e| format!("Failed to serialize SafeTensors: {e}"))?;
199
200    std::fs::write(&args.output, safetensor_bytes)
201        .map_err(|e| format!("Failed to write output: {e}"))
202}
203
204/// Build SafeTensors metadata
205fn build_safetensor_metadata(merged: &Model, args: &MergeArgs) -> HashMap<String, String> {
206    let mut metadata = HashMap::new();
207    metadata.insert("name".to_string(), "merged-model".to_string());
208    metadata.insert("merge_method".to_string(), format!("{:?}", args.method));
209    metadata.insert("tensor_count".to_string(), merged.len().to_string());
210    metadata
211}
212
213/// Export to JSON format
214fn export_json(merged: &Model, args: &MergeArgs) -> Result<(), String> {
215    let output_data: HashMap<String, Vec<f32>> =
216        merged.iter().map(|(name, tensor)| (name.clone(), tensor.data().to_vec())).collect();
217
218    let json_data =
219        serde_json::to_vec_pretty(&output_data).map_err(|e| format!("Failed to serialize: {e}"))?;
220
221    std::fs::write(&args.output, &json_data).map_err(|e| format!("Failed to write output: {e}"))
222}
223
224/// Log merge completion
225fn log_merge_complete(merged: &Model, args: &MergeArgs, level: LogLevel) {
226    log(
227        level,
228        LogLevel::Normal,
229        &format!("Merge complete: {} tensors written to {}", merged.len(), args.output.display()),
230    );
231}
232
233/// Merge LoRA adapter into base model (ENT-LoRA-017)
234///
235/// Computes W_merged = W_base + scale * B @ A for each adapted module,
236/// producing a standard safetensors model with no LoRA tensors.
237fn run_lora_adapter_merge(args: &MergeArgs, level: LogLevel) -> Result<(), String> {
238    let base_path = args.base.as_ref().ok_or("--base required for lora-adapter merge")?;
239    let adapter_dir = args.adapter.as_ref().ok_or("--adapter required for lora-adapter merge")?;
240
241    let config_path = adapter_dir.join("adapter_config.json");
242    let adapter_path = adapter_dir.join("adapter_model.safetensors");
243
244    if !base_path.exists() {
245        return Err(format!("Base model not found: {}", base_path.display()));
246    }
247    if !config_path.exists() {
248        return Err(format!("adapter_config.json not found in {}", adapter_dir.display()));
249    }
250    if !adapter_path.exists() {
251        return Err(format!("adapter_model.safetensors not found in {}", adapter_dir.display()));
252    }
253
254    log(level, LogLevel::Normal, "LoRA adapter merge:");
255    log(level, LogLevel::Normal, &format!("  Base: {}", base_path.display()));
256    log(level, LogLevel::Normal, &format!("  Adapter: {}", adapter_dir.display()));
257
258    // Read adapter config
259    let config_str =
260        std::fs::read_to_string(&config_path).map_err(|e| format!("Read adapter config: {e}"))?;
261    let config: serde_json::Value =
262        serde_json::from_str(&config_str).map_err(|e| format!("Parse adapter config: {e}"))?;
263
264    let rank = config.get("r").and_then(serde_json::Value::as_u64).unwrap_or(8) as usize;
265    let alpha =
266        config.get("lora_alpha").and_then(serde_json::Value::as_f64).unwrap_or(rank as f64 * 2.0);
267    let scale = alpha as f32 / rank as f32;
268
269    log(level, LogLevel::Normal, &format!("  Rank: {rank}, Alpha: {alpha}, Scale: {scale:.4}"));
270
271    // Load base model
272    let base_data = std::fs::read(base_path).map_err(|e| format!("Read base model: {e}"))?;
273    let base_tensors =
274        SafeTensors::deserialize(&base_data).map_err(|e| format!("Parse base model: {e}"))?;
275
276    // Load adapter
277    let adapter_data = std::fs::read(&adapter_path).map_err(|e| format!("Read adapter: {e}"))?;
278    let adapter_tensors =
279        SafeTensors::deserialize(&adapter_data).map_err(|e| format!("Parse adapter: {e}"))?;
280
281    // Merge: copy all base tensors, apply LoRA delta where adapters exist
282    let adapter_names: Vec<String> =
283        adapter_tensors.names().iter().map(|s| (*s).to_string()).collect();
284    let base_names: Vec<String> = base_tensors.names().iter().map(|s| (*s).to_string()).collect();
285
286    // Build map of adapter A/B pairs grouped by module path
287    let lora_pairs = build_lora_pairs(&adapter_names, &adapter_tensors)?;
288    let mut merged_count = 0usize;
289
290    // Prepare output tensors
291    let mut output_tensors: Vec<(String, Vec<u8>, Vec<usize>)> = Vec::new();
292
293    for name in &base_names {
294        let base_t = base_tensors.tensor(name).map_err(|e| format!("Get tensor {name}: {e}"))?;
295        let shape: Vec<usize> = base_t.shape().to_vec();
296
297        // Check if this weight has a LoRA adapter
298        if let Some((a_data, b_data, a_shape, b_shape)) = lora_pairs.get(name.as_str()) {
299            // W_merged = W_base + scale * B @ A
300            let base_f32 = bytes_to_f32(base_t.data(), base_t.dtype());
301            let a_f32 = bytes_to_f32(a_data, safetensors::tensor::Dtype::F32);
302            let b_f32 = bytes_to_f32(b_data, safetensors::tensor::Dtype::F32);
303
304            let d_out = b_shape[0];
305            let r = b_shape[1];
306            let d_in = a_shape[1];
307
308            // Compute B @ A: [d_out, r] @ [r, d_in] -> [d_out, d_in]
309            let mut ba = vec![0.0f32; d_out * d_in];
310            for i in 0..d_out {
311                for j in 0..d_in {
312                    let mut sum = 0.0f32;
313                    for k in 0..r {
314                        sum += b_f32[i * r + k] * a_f32[k * d_in + j];
315                    }
316                    ba[i * d_in + j] = sum;
317                }
318            }
319
320            // W_merged = W_base + scale * BA
321            let mut merged: Vec<f32> = base_f32;
322            for (i, val) in merged.iter_mut().enumerate() {
323                *val += scale * ba[i];
324            }
325
326            let bytes: Vec<u8> = bytemuck::cast_slice(&merged).to_vec();
327            output_tensors.push((name.clone(), bytes, shape));
328            merged_count += 1;
329        } else {
330            // Pass through base tensor unchanged
331            output_tensors.push((name.clone(), base_t.data().to_vec(), shape));
332        }
333    }
334
335    // Serialize to safetensors
336    let views: Vec<(&str, safetensors::tensor::TensorView<'_>)> = output_tensors
337        .iter()
338        .filter_map(|(name, bytes, shape)| {
339            safetensors::tensor::TensorView::new(
340                safetensors::tensor::Dtype::F32,
341                shape.clone(),
342                bytes,
343            )
344            .ok()
345            .map(|view| (name.as_str(), view))
346        })
347        .collect();
348
349    let mut metadata = HashMap::new();
350    metadata.insert("format".to_string(), "entrenar-merged-lora".to_string());
351    metadata.insert("lora_rank".to_string(), rank.to_string());
352    metadata.insert("lora_alpha".to_string(), format!("{alpha}"));
353
354    let safetensor_bytes = safetensors::serialize(views, Some(metadata))
355        .map_err(|e| format!("Serialize merged model: {e}"))?;
356
357    std::fs::write(&args.output, safetensor_bytes)
358        .map_err(|e| format!("Write merged model: {e}"))?;
359
360    let output_size = std::fs::metadata(&args.output).map(|m| m.len()).unwrap_or(0);
361    log(
362        level,
363        LogLevel::Normal,
364        &format!("  Merged {merged_count} adapter weights into base model"),
365    );
366    log(
367        level,
368        LogLevel::Normal,
369        &format!("  Output: {} ({:.2} MB)", args.output.display(), output_size as f64 / 1e6),
370    );
371
372    Ok(())
373}
374
375/// Build a map of base weight name -> (A_data, B_data, A_shape, B_shape)
376fn build_lora_pairs<'a>(
377    names: &[String],
378    tensors: &'a SafeTensors<'a>,
379) -> Result<HashMap<&'a str, (Vec<u8>, Vec<u8>, Vec<usize>, Vec<usize>)>, String> {
380    let mut pairs: HashMap<String, (Option<(Vec<u8>, Vec<usize>)>, Option<(Vec<u8>, Vec<usize>)>)> =
381        HashMap::new();
382
383    for name in names {
384        // PEFT naming: base_model.model.{path}.lora_A.weight / lora_B.weight
385        let (base_name, is_a) = if let Some(stripped) = name.strip_suffix(".lora_A.weight") {
386            (stripped.replace("base_model.model.", "") + ".weight", true)
387        } else if let Some(stripped) = name.strip_suffix(".lora_B.weight") {
388            (stripped.replace("base_model.model.", "") + ".weight", false)
389        } else {
390            continue;
391        };
392
393        let tensor = tensors.tensor(name).map_err(|e| format!("Get adapter tensor {name}: {e}"))?;
394        let data = tensor.data().to_vec();
395        let shape = tensor.shape().to_vec();
396
397        let entry = pairs.entry(base_name).or_insert((None, None));
398        if is_a {
399            entry.0 = Some((data, shape));
400        } else {
401            entry.1 = Some((data, shape));
402        }
403    }
404
405    let mut result = HashMap::new();
406    for (base_name, (a, b)) in &pairs {
407        if let (Some((a_data, a_shape)), Some((b_data, b_shape))) = (a, b) {
408            // Leak the base_name string to get a &'a str — safe in this context
409            // as the result lives only for the merge duration
410            let key: &str = Box::leak(base_name.clone().into_boxed_str());
411            result.insert(key, (a_data.clone(), b_data.clone(), a_shape.clone(), b_shape.clone()));
412        }
413    }
414    Ok(result)
415}
416
417/// Convert tensor bytes to f32 based on dtype
418fn bytes_to_f32(data: &[u8], dtype: safetensors::tensor::Dtype) -> Vec<f32> {
419    match dtype {
420        safetensors::tensor::Dtype::F32 => {
421            data.chunks_exact(4).map(|c| f32::from_le_bytes([c[0], c[1], c[2], c[3]])).collect()
422        }
423        safetensors::tensor::Dtype::F16 => data
424            .chunks_exact(2)
425            .map(|c| {
426                let bits = u16::from_le_bytes([c[0], c[1]]);
427                half::f16::from_bits(bits).to_f32()
428            })
429            .collect(),
430        safetensors::tensor::Dtype::BF16 => data
431            .chunks_exact(2)
432            .map(|c| {
433                let bits = u16::from_le_bytes([c[0], c[1]]);
434                half::bf16::from_bits(bits).to_f32()
435            })
436            .collect(),
437        _ => {
438            // For other dtypes, treat as f32 (best effort)
439            data.chunks_exact(4).map(|c| f32::from_le_bytes([c[0], c[1], c[2], c[3]])).collect()
440        }
441    }
442}
443
444#[cfg(test)]
445mod tests {
446    #![allow(clippy::unwrap_used)]
447    use super::*;
448    use std::path::PathBuf;
449
450    #[test]
451    fn test_validate_model_count_zero() {
452        let args = MergeArgs {
453            models: vec![],
454            output: PathBuf::from("o.json"),
455            method: MergeMethod::Ties,
456            weight: None,
457            density: None,
458            weights: None,
459            base: None,
460            adapter: None,
461        };
462        assert!(validate_model_count(&args).is_err());
463    }
464
465    #[test]
466    fn test_validate_model_count_two_ok() {
467        let args = MergeArgs {
468            models: vec![PathBuf::from("a"), PathBuf::from("b")],
469            output: PathBuf::from("o.json"),
470            method: MergeMethod::Ties,
471            weight: None,
472            density: None,
473            weights: None,
474            base: None,
475            adapter: None,
476        };
477        assert!(validate_model_count(&args).is_ok());
478    }
479
480    #[test]
481    fn test_build_ensemble_config_no_weights() {
482        let args = MergeArgs {
483            models: vec![],
484            output: PathBuf::from("o.json"),
485            method: MergeMethod::Average,
486            weight: None,
487            density: None,
488            weights: None,
489            base: None,
490            adapter: None,
491        };
492        assert!(build_ensemble_config(&args).is_ok());
493    }
494
495    #[test]
496    fn test_build_ensemble_config_with_weights() {
497        let args = MergeArgs {
498            models: vec![],
499            output: PathBuf::from("o.json"),
500            method: MergeMethod::Average,
501            weight: None,
502            density: None,
503            weights: Some("0.3, 0.7".into()),
504            base: None,
505            adapter: None,
506        };
507        assert!(build_ensemble_config(&args).is_ok());
508    }
509
510    #[test]
511    fn test_build_ensemble_config_invalid() {
512        let args = MergeArgs {
513            models: vec![],
514            output: PathBuf::from("o.json"),
515            method: MergeMethod::Average,
516            weight: None,
517            density: None,
518            weights: Some("abc".into()),
519            base: None,
520            adapter: None,
521        };
522        assert!(build_ensemble_config(&args).unwrap_err().contains("Invalid weights"));
523    }
524
525    fn mk(keys: &[(&str, &[f32])]) -> Model {
526        keys.iter().map(|(n, v)| (n.to_string(), Tensor::from_vec(v.to_vec(), false))).collect()
527    }
528
529    #[test]
530    fn test_slerp_wrong_count() {
531        let ms = vec![mk(&[("w", &[1.0])]), mk(&[("w", &[2.0])]), mk(&[("w", &[3.0])])];
532        let a = MergeArgs {
533            models: vec![],
534            output: PathBuf::from("o"),
535            method: MergeMethod::Slerp,
536            weight: None,
537            density: None,
538            weights: None,
539            base: None,
540            adapter: None,
541        };
542        assert!(perform_slerp_merge(&ms, &a).unwrap_err().contains("SLERP requires exactly 2"));
543    }
544
545    #[test]
546    fn test_merge_lora_err() {
547        let a = MergeArgs {
548            models: vec![],
549            output: PathBuf::from("o"),
550            method: MergeMethod::LoraAdapter,
551            weight: None,
552            density: None,
553            weights: None,
554            base: None,
555            adapter: None,
556        };
557        assert!(perform_merge(&[], &a).is_err());
558    }
559
560    #[test]
561    fn test_bytes_to_f32_f32() {
562        let v = vec![1.0f32, 2.5];
563        let b: Vec<u8> = v.iter().flat_map(|x| x.to_le_bytes()).collect();
564        let r = bytes_to_f32(&b, safetensors::tensor::Dtype::F32);
565        assert!((r[0] - 1.0).abs() < 1e-6);
566    }
567
568    #[test]
569    fn test_bytes_to_f32_f16() {
570        let b = half::f16::from_f32(1.0).to_le_bytes().to_vec();
571        assert!((bytes_to_f32(&b, safetensors::tensor::Dtype::F16)[0] - 1.0).abs() < 0.01);
572    }
573
574    #[test]
575    fn test_bytes_to_f32_bf16() {
576        let b = half::bf16::from_f32(2.0).to_le_bytes().to_vec();
577        assert!((bytes_to_f32(&b, safetensors::tensor::Dtype::BF16)[0] - 2.0).abs() < 0.1);
578    }
579
580    #[test]
581    fn test_bytes_to_f32_fallback() {
582        let b: Vec<u8> = 42.0f32.to_le_bytes().to_vec();
583        assert!((bytes_to_f32(&b, safetensors::tensor::Dtype::I8)[0] - 42.0).abs() < 1e-6);
584    }
585
586    #[test]
587    fn test_bytes_to_f32_empty() {
588        assert!(bytes_to_f32(&[], safetensors::tensor::Dtype::F32).is_empty());
589    }
590
591    #[test]
592    fn test_safetensor_metadata() {
593        let m = mk(&[("a", &[1.0]), ("b", &[2.0])]);
594        let a = MergeArgs {
595            models: vec![],
596            output: PathBuf::from("o.st"),
597            method: MergeMethod::Dare,
598            weight: None,
599            density: None,
600            weights: None,
601            base: None,
602            adapter: None,
603        };
604        let md = build_safetensor_metadata(&m, &a);
605        assert_eq!(md["name"], "merged-model");
606        assert_eq!(md["tensor_count"], "2");
607    }
608
609    #[test]
610    fn test_export_json() {
611        let m = mk(&[("w", &[1.0])]);
612        let t = std::env::temp_dir().join("ent_merge_j.json");
613        let a = MergeArgs {
614            models: vec![],
615            output: t.clone(),
616            method: MergeMethod::Average,
617            weight: None,
618            density: None,
619            weights: None,
620            base: None,
621            adapter: None,
622        };
623        assert!(export_merged_model(&m, &a).is_ok());
624        let _ = std::fs::remove_file(&t);
625    }
626
627    #[test]
628    fn test_export_safetensors() {
629        let m = mk(&[("w", &[1.0])]);
630        let t = std::env::temp_dir().join("ent_merge_s.safetensors");
631        let a = MergeArgs {
632            models: vec![],
633            output: t.clone(),
634            method: MergeMethod::Average,
635            weight: None,
636            density: None,
637            weights: None,
638            base: None,
639            adapter: None,
640        };
641        assert!(export_merged_model(&m, &a).is_ok());
642        let _ = std::fs::remove_file(&t);
643    }
644
645    #[test]
646    fn test_ties_merge_ok() {
647        let a = MergeArgs {
648            models: vec![],
649            output: PathBuf::from("o"),
650            method: MergeMethod::Ties,
651            weight: None,
652            density: None,
653            weights: None,
654            base: None,
655            adapter: None,
656        };
657        // ties_merge needs base + at least 2 delta models (3 total)
658        assert!(perform_ties_merge(
659            &[mk(&[("w", &[1.0, 2.0])]), mk(&[("w", &[1.1, 2.1])]), mk(&[("w", &[1.2, 2.2])]),],
660            &a
661        )
662        .is_ok());
663    }
664
665    #[test]
666    fn test_dare_merge_ok() {
667        let a = MergeArgs {
668            models: vec![],
669            output: PathBuf::from("o"),
670            method: MergeMethod::Dare,
671            weight: None,
672            density: None,
673            weights: None,
674            base: None,
675            adapter: None,
676        };
677        assert!(
678            perform_dare_merge(&[mk(&[("w", &[1.0, 2.0])]), mk(&[("w", &[1.1, 2.1])])], &a).is_ok()
679        );
680    }
681
682    #[test]
683    fn test_average_merge() {
684        let a = MergeArgs {
685            models: vec![],
686            output: PathBuf::from("o"),
687            method: MergeMethod::Average,
688            weight: None,
689            density: None,
690            weights: None,
691            base: None,
692            adapter: None,
693        };
694        let r = perform_average_merge(&[mk(&[("w", &[2.0, 4.0])]), mk(&[("w", &[6.0, 8.0])])], &a)
695            .unwrap();
696        let s = r["w"].data().as_slice().unwrap().to_vec();
697        assert!((s[0] - 4.0).abs() < 1e-6);
698    }
699
700    #[test]
701    fn test_log_merge_no_panic() {
702        let a = MergeArgs {
703            models: vec![PathBuf::from("a"), PathBuf::from("b")],
704            output: PathBuf::from("o"),
705            method: MergeMethod::Ties,
706            weight: None,
707            density: None,
708            weights: None,
709            base: None,
710            adapter: None,
711        };
712        log_merge_start(&a, LogLevel::Quiet);
713        log_merge_start(&a, LogLevel::Verbose);
714        log_merge_complete(&mk(&[("w", &[1.0])]), &a, LogLevel::Normal);
715    }
716
717    #[test]
718    fn test_lora_missing_base() {
719        let a = MergeArgs {
720            models: vec![],
721            output: PathBuf::from("o"),
722            method: MergeMethod::LoraAdapter,
723            weight: None,
724            density: None,
725            weights: None,
726            base: None,
727            adapter: Some(PathBuf::from("/tmp")),
728        };
729        assert!(run_lora_adapter_merge(&a, LogLevel::Quiet)
730            .unwrap_err()
731            .contains("--base required"));
732    }
733
734    #[test]
735    fn test_lora_missing_adapter() {
736        let a = MergeArgs {
737            models: vec![],
738            output: PathBuf::from("o"),
739            method: MergeMethod::LoraAdapter,
740            weight: None,
741            density: None,
742            weights: None,
743            base: Some(PathBuf::from("/tmp/x")),
744            adapter: None,
745        };
746        assert!(run_lora_adapter_merge(&a, LogLevel::Quiet)
747            .unwrap_err()
748            .contains("--adapter required"));
749    }
750
751    #[test]
752    fn test_lora_base_not_found() {
753        let a = MergeArgs {
754            models: vec![],
755            output: PathBuf::from("o"),
756            method: MergeMethod::LoraAdapter,
757            weight: None,
758            density: None,
759            weights: None,
760            base: Some(PathBuf::from("/no/base")),
761            adapter: Some(PathBuf::from("/tmp")),
762        };
763        assert!(run_lora_adapter_merge(&a, LogLevel::Quiet)
764            .unwrap_err()
765            .contains("Base model not found"));
766    }
767
768    #[test]
769    fn test_load_nonexistent() {
770        assert!(load_single_model(std::path::Path::new("/no/m"))
771            .unwrap_err()
772            .contains("Failed to read"));
773    }
774
775    #[test]
776    fn test_run_merge_too_few() {
777        let a = MergeArgs {
778            models: vec![PathBuf::from("a")],
779            output: PathBuf::from("o"),
780            method: MergeMethod::Ties,
781            weight: None,
782            density: None,
783            weights: None,
784            base: None,
785            adapter: None,
786        };
787        assert!(run_merge(a, LogLevel::Quiet).unwrap_err().contains("Need at least 2"));
788    }
789
790    #[test]
791    fn test_run_merge_lora_routes() {
792        let a = MergeArgs {
793            models: vec![],
794            output: PathBuf::from("o"),
795            method: MergeMethod::LoraAdapter,
796            weight: None,
797            density: None,
798            weights: None,
799            base: None,
800            adapter: None,
801        };
802        assert!(run_merge(a, LogLevel::Quiet).unwrap_err().contains("--base required"));
803    }
804
805    // ── perform_merge routing tests ─────────────────────────────────────
806
807    #[test]
808    fn test_perform_merge_ties_route() {
809        let models =
810            vec![mk(&[("w", &[1.0, 2.0])]), mk(&[("w", &[1.1, 2.1])]), mk(&[("w", &[1.2, 2.2])])];
811        let a = MergeArgs {
812            models: vec![],
813            output: PathBuf::from("o"),
814            method: MergeMethod::Ties,
815            weight: None,
816            density: Some(0.5),
817            weights: None,
818            base: None,
819            adapter: None,
820        };
821        assert!(perform_merge(&models, &a).is_ok());
822    }
823
824    #[test]
825    fn test_perform_merge_dare_route() {
826        let models = vec![mk(&[("w", &[1.0, 2.0])]), mk(&[("w", &[1.5, 2.5])])];
827        let a = MergeArgs {
828            models: vec![],
829            output: PathBuf::from("o"),
830            method: MergeMethod::Dare,
831            weight: None,
832            density: Some(0.3),
833            weights: None,
834            base: None,
835            adapter: None,
836        };
837        assert!(perform_merge(&models, &a).is_ok());
838    }
839
840    #[test]
841    fn test_perform_merge_slerp_route() {
842        let models = vec![mk(&[("w", &[1.0, 0.0])]), mk(&[("w", &[0.0, 1.0])])];
843        let a = MergeArgs {
844            models: vec![],
845            output: PathBuf::from("o"),
846            method: MergeMethod::Slerp,
847            weight: Some(0.5),
848            density: None,
849            weights: None,
850            base: None,
851            adapter: None,
852        };
853        assert!(perform_merge(&models, &a).is_ok());
854    }
855
856    #[test]
857    fn test_perform_merge_average_route() {
858        let models = vec![mk(&[("w", &[2.0])]), mk(&[("w", &[4.0])])];
859        let a = MergeArgs {
860            models: vec![],
861            output: PathBuf::from("o"),
862            method: MergeMethod::Average,
863            weight: None,
864            density: None,
865            weights: None,
866            base: None,
867            adapter: None,
868        };
869        let result = perform_merge(&models, &a).unwrap();
870        let vals = result["w"].data().as_slice().unwrap().to_vec();
871        assert!((vals[0] - 3.0).abs() < 1e-6);
872    }
873
874    // ── slerp merge with exactly 2 models ───────────────────────────────
875
876    #[test]
877    fn test_slerp_merge_two_models_ok() {
878        let ms = vec![mk(&[("w", &[1.0, 0.0])]), mk(&[("w", &[0.0, 1.0])])];
879        let a = MergeArgs {
880            models: vec![],
881            output: PathBuf::from("o"),
882            method: MergeMethod::Slerp,
883            weight: Some(0.3),
884            density: None,
885            weights: None,
886            base: None,
887            adapter: None,
888        };
889        assert!(perform_slerp_merge(&ms, &a).is_ok());
890    }
891
892    #[test]
893    fn test_slerp_merge_default_weight() {
894        let ms = vec![mk(&[("w", &[1.0, 0.0])]), mk(&[("w", &[0.0, 1.0])])];
895        let a = MergeArgs {
896            models: vec![],
897            output: PathBuf::from("o"),
898            method: MergeMethod::Slerp,
899            weight: None, // defaults to 0.5
900            density: None,
901            weights: None,
902            base: None,
903            adapter: None,
904        };
905        assert!(perform_slerp_merge(&ms, &a).is_ok());
906    }
907
908    // ── ties merge with density ─────────────────────────────────────────
909
910    #[test]
911    fn test_ties_merge_with_density() {
912        let a = MergeArgs {
913            models: vec![],
914            output: PathBuf::from("o"),
915            method: MergeMethod::Ties,
916            weight: None,
917            density: Some(0.8),
918            weights: None,
919            base: None,
920            adapter: None,
921        };
922        let models =
923            vec![mk(&[("w", &[1.0, 2.0])]), mk(&[("w", &[1.5, 2.5])]), mk(&[("w", &[1.2, 2.2])])];
924        let result = perform_ties_merge(&models, &a);
925        assert!(result.is_ok());
926    }
927
928    // ── dare merge with density ─────────────────────────────────────────
929
930    #[test]
931    fn test_dare_merge_with_density() {
932        let a = MergeArgs {
933            models: vec![],
934            output: PathBuf::from("o"),
935            method: MergeMethod::Dare,
936            weight: None,
937            density: Some(0.9),
938            weights: None,
939            base: None,
940            adapter: None,
941        };
942        let models = vec![mk(&[("w", &[1.0, 2.0])]), mk(&[("w", &[1.5, 2.5])])];
943        assert!(perform_dare_merge(&models, &a).is_ok());
944    }
945
946    // ── average merge with explicit weights ─────────────────────────────
947
948    #[test]
949    fn test_average_merge_with_weights() {
950        let a = MergeArgs {
951            models: vec![],
952            output: PathBuf::from("o"),
953            method: MergeMethod::Average,
954            weight: None,
955            density: None,
956            weights: Some("0.8,0.2".to_string()),
957            base: None,
958            adapter: None,
959        };
960        let models = vec![mk(&[("w", &[10.0])]), mk(&[("w", &[0.0])])];
961        let result = perform_average_merge(&models, &a).unwrap();
962        let vals = result["w"].data().as_slice().unwrap().to_vec();
963        // 0.8 * 10.0 + 0.2 * 0.0 = 8.0
964        assert!((vals[0] - 8.0).abs() < 1e-4);
965    }
966
967    // ── build_ensemble_config edge cases ────────────────────────────────
968
969    #[test]
970    fn test_build_ensemble_config_single_weight() {
971        let a = MergeArgs {
972            models: vec![],
973            output: PathBuf::from("o.json"),
974            method: MergeMethod::Average,
975            weight: None,
976            density: None,
977            weights: Some("1.0".to_string()),
978            base: None,
979            adapter: None,
980        };
981        let config = build_ensemble_config(&a);
982        assert!(config.is_ok());
983    }
984
985    #[test]
986    fn test_build_ensemble_config_three_weights() {
987        let a = MergeArgs {
988            models: vec![],
989            output: PathBuf::from("o.json"),
990            method: MergeMethod::Average,
991            weight: None,
992            density: None,
993            weights: Some("0.2, 0.3, 0.5".to_string()),
994            base: None,
995            adapter: None,
996        };
997        let config = build_ensemble_config(&a);
998        assert!(config.is_ok());
999    }
1000
1001    #[test]
1002    fn test_build_ensemble_config_empty_weights_string() {
1003        let a = MergeArgs {
1004            models: vec![],
1005            output: PathBuf::from("o.json"),
1006            method: MergeMethod::Average,
1007            weight: None,
1008            density: None,
1009            weights: Some(String::new()),
1010            base: None,
1011            adapter: None,
1012        };
1013        // Empty string should fail to parse as f32
1014        assert!(build_ensemble_config(&a).is_err());
1015    }
1016
1017    // ── validate_model_count edge cases ─────────────────────────────────
1018
1019    #[test]
1020    fn test_validate_model_count_one() {
1021        let a = MergeArgs {
1022            models: vec![PathBuf::from("a")],
1023            output: PathBuf::from("o"),
1024            method: MergeMethod::Ties,
1025            weight: None,
1026            density: None,
1027            weights: None,
1028            base: None,
1029            adapter: None,
1030        };
1031        assert!(validate_model_count(&a).is_err());
1032    }
1033
1034    #[test]
1035    fn test_validate_model_count_three() {
1036        let a = MergeArgs {
1037            models: vec![PathBuf::from("a"), PathBuf::from("b"), PathBuf::from("c")],
1038            output: PathBuf::from("o"),
1039            method: MergeMethod::Ties,
1040            weight: None,
1041            density: None,
1042            weights: None,
1043            base: None,
1044            adapter: None,
1045        };
1046        assert!(validate_model_count(&a).is_ok());
1047    }
1048
1049    // ── export_merged_model extension detection ─────────────────────────
1050
1051    #[test]
1052    fn test_export_merged_model_no_extension() {
1053        let m = mk(&[("w", &[1.0])]);
1054        let t = std::env::temp_dir().join("ent_merge_noext");
1055        let a = MergeArgs {
1056            models: vec![],
1057            output: t.clone(),
1058            method: MergeMethod::Average,
1059            weight: None,
1060            density: None,
1061            weights: None,
1062            base: None,
1063            adapter: None,
1064        };
1065        // Should fall through to JSON export (default)
1066        assert!(export_merged_model(&m, &a).is_ok());
1067        let _ = std::fs::remove_file(&t);
1068    }
1069
1070    // ── bytes_to_f32 additional edge cases ──────────────────────────────
1071
1072    #[test]
1073    fn test_bytes_to_f32_f32_multiple() {
1074        let vals = vec![1.0f32, 2.0, 3.5, -1.0];
1075        let bytes: Vec<u8> = vals.iter().flat_map(|x| x.to_le_bytes()).collect();
1076        let result = bytes_to_f32(&bytes, safetensors::tensor::Dtype::F32);
1077        assert_eq!(result.len(), 4);
1078        assert!((result[0] - 1.0).abs() < 1e-6);
1079        assert!((result[1] - 2.0).abs() < 1e-6);
1080        assert!((result[2] - 3.5).abs() < 1e-6);
1081        assert!((result[3] - (-1.0)).abs() < 1e-6);
1082    }
1083
1084    #[test]
1085    fn test_bytes_to_f32_f16_multiple() {
1086        let vals = vec![half::f16::from_f32(0.5), half::f16::from_f32(1.5)];
1087        let bytes: Vec<u8> = vals.iter().flat_map(|x| x.to_le_bytes()).collect();
1088        let result = bytes_to_f32(&bytes, safetensors::tensor::Dtype::F16);
1089        assert_eq!(result.len(), 2);
1090        assert!((result[0] - 0.5).abs() < 0.01);
1091        assert!((result[1] - 1.5).abs() < 0.01);
1092    }
1093
1094    #[test]
1095    fn test_bytes_to_f32_bf16_multiple() {
1096        let vals = vec![half::bf16::from_f32(3.0), half::bf16::from_f32(-1.0)];
1097        let bytes: Vec<u8> = vals.iter().flat_map(|x| x.to_le_bytes()).collect();
1098        let result = bytes_to_f32(&bytes, safetensors::tensor::Dtype::BF16);
1099        assert_eq!(result.len(), 2);
1100        assert!((result[0] - 3.0).abs() < 0.1);
1101        assert!((result[1] - (-1.0)).abs() < 0.1);
1102    }
1103
1104    // ── build_safetensor_metadata tests ─────────────────────────────────
1105
1106    #[test]
1107    fn test_safetensor_metadata_ties() {
1108        let m = mk(&[("a", &[1.0]), ("b", &[2.0]), ("c", &[3.0])]);
1109        let a = MergeArgs {
1110            models: vec![],
1111            output: PathBuf::from("o.st"),
1112            method: MergeMethod::Ties,
1113            weight: None,
1114            density: None,
1115            weights: None,
1116            base: None,
1117            adapter: None,
1118        };
1119        let md = build_safetensor_metadata(&m, &a);
1120        assert_eq!(md["name"], "merged-model");
1121        assert_eq!(md["tensor_count"], "3");
1122        assert!(md["merge_method"].contains("Ties"));
1123    }
1124
1125    #[test]
1126    fn test_safetensor_metadata_slerp() {
1127        let m = mk(&[("x", &[1.0])]);
1128        let a = MergeArgs {
1129            models: vec![],
1130            output: PathBuf::from("o.st"),
1131            method: MergeMethod::Slerp,
1132            weight: None,
1133            density: None,
1134            weights: None,
1135            base: None,
1136            adapter: None,
1137        };
1138        let md = build_safetensor_metadata(&m, &a);
1139        assert!(md["merge_method"].contains("Slerp"));
1140    }
1141
1142    // ── log_merge_start and log_merge_complete with different levels ────
1143
1144    #[test]
1145    fn test_log_merge_start_normal() {
1146        let a = MergeArgs {
1147            models: vec![PathBuf::from("m1"), PathBuf::from("m2")],
1148            output: PathBuf::from("out"),
1149            method: MergeMethod::Average,
1150            weight: None,
1151            density: None,
1152            weights: None,
1153            base: None,
1154            adapter: None,
1155        };
1156        log_merge_start(&a, LogLevel::Normal);
1157    }
1158
1159    #[test]
1160    fn test_log_merge_complete_verbose() {
1161        let m = mk(&[("a", &[1.0, 2.0])]);
1162        let a = MergeArgs {
1163            models: vec![],
1164            output: PathBuf::from("merged.json"),
1165            method: MergeMethod::Dare,
1166            weight: None,
1167            density: None,
1168            weights: None,
1169            base: None,
1170            adapter: None,
1171        };
1172        log_merge_complete(&m, &a, LogLevel::Verbose);
1173    }
1174
1175    // ── LoRA merge error paths ──────────────────────────────────────────
1176
1177    #[test]
1178    fn test_lora_adapter_config_not_found() {
1179        // adapter dir exists but no adapter_config.json inside
1180        let dir = tempfile::tempdir().unwrap();
1181        // Create a fake base file
1182        let base_file = dir.path().join("base.safetensors");
1183        std::fs::write(&base_file, b"fake").unwrap();
1184        let a = MergeArgs {
1185            models: vec![],
1186            output: PathBuf::from("o"),
1187            method: MergeMethod::LoraAdapter,
1188            weight: None,
1189            density: None,
1190            weights: None,
1191            base: Some(base_file),
1192            adapter: Some(dir.path().to_path_buf()),
1193        };
1194        let err = run_lora_adapter_merge(&a, LogLevel::Quiet).unwrap_err();
1195        assert!(err.contains("adapter_config.json"), "Error: {err}");
1196    }
1197
1198    #[test]
1199    fn test_lora_adapter_model_not_found() {
1200        let dir = tempfile::tempdir().unwrap();
1201        let base_file = dir.path().join("base.safetensors");
1202        std::fs::write(&base_file, b"fake").unwrap();
1203        // Create adapter_config.json but not adapter_model.safetensors
1204        std::fs::write(dir.path().join("adapter_config.json"), r#"{"r": 8, "lora_alpha": 16}"#)
1205            .unwrap();
1206        let a = MergeArgs {
1207            models: vec![],
1208            output: PathBuf::from("o"),
1209            method: MergeMethod::LoraAdapter,
1210            weight: None,
1211            density: None,
1212            weights: None,
1213            base: Some(base_file),
1214            adapter: Some(dir.path().to_path_buf()),
1215        };
1216        let err = run_lora_adapter_merge(&a, LogLevel::Quiet).unwrap_err();
1217        assert!(err.contains("adapter_model.safetensors"), "Error: {err}");
1218    }
1219
1220    // ── run_merge with nonexistent model files ──────────────────────────
1221
1222    #[test]
1223    fn test_run_merge_nonexistent_models() {
1224        let a = MergeArgs {
1225            models: vec![PathBuf::from("/no/m1"), PathBuf::from("/no/m2")],
1226            output: PathBuf::from("o"),
1227            method: MergeMethod::Ties,
1228            weight: None,
1229            density: None,
1230            weights: None,
1231            base: None,
1232            adapter: None,
1233        };
1234        let err = run_merge(a, LogLevel::Quiet).unwrap_err();
1235        assert!(err.contains("Failed to read"), "Error: {err}");
1236    }
1237
1238    // ── mk helper verify ────────────────────────────────────────────────
1239
1240    #[test]
1241    fn test_mk_helper_creates_model() {
1242        let model = mk(&[("a", &[1.0, 2.0, 3.0]), ("b", &[4.0])]);
1243        assert_eq!(model.len(), 2);
1244        assert!(model.contains_key("a"));
1245        assert!(model.contains_key("b"));
1246        assert_eq!(model["a"].len(), 3);
1247        assert_eq!(model["b"].len(), 1);
1248    }
1249
1250    // ── export safetensors with multiple tensors ────────────────────────
1251
1252    #[test]
1253    fn test_export_safetensors_multiple_tensors() {
1254        let m = mk(&[("w1", &[1.0, 2.0]), ("w2", &[3.0, 4.0, 5.0])]);
1255        let t = std::env::temp_dir().join("ent_merge_multi.safetensors");
1256        let a = MergeArgs {
1257            models: vec![],
1258            output: t.clone(),
1259            method: MergeMethod::Average,
1260            weight: None,
1261            density: None,
1262            weights: None,
1263            base: None,
1264            adapter: None,
1265        };
1266        assert!(export_merged_model(&m, &a).is_ok());
1267        // Verify file was created and has content
1268        assert!(t.exists());
1269        let _ = std::fs::remove_file(&t);
1270    }
1271
1272    // ── export json roundtrip ───────────────────────────────────────────
1273
1274    #[test]
1275    fn test_export_json_roundtrip() {
1276        let m = mk(&[("w1", &[1.0, 2.0]), ("w2", &[3.0])]);
1277        let t = std::env::temp_dir().join("ent_merge_roundtrip.json");
1278        let a = MergeArgs {
1279            models: vec![],
1280            output: t.clone(),
1281            method: MergeMethod::Average,
1282            weight: None,
1283            density: None,
1284            weights: None,
1285            base: None,
1286            adapter: None,
1287        };
1288        assert!(export_merged_model(&m, &a).is_ok());
1289        let content = std::fs::read_to_string(&t).unwrap();
1290        let parsed: HashMap<String, Vec<f32>> = serde_json::from_str(&content).unwrap();
1291        assert!(parsed.contains_key("w1"));
1292        assert!(parsed.contains_key("w2"));
1293        assert_eq!(parsed["w1"].len(), 2);
1294        let _ = std::fs::remove_file(&t);
1295    }
1296
1297    // =========================================================================
1298    // test_cov2_* — Additional coverage tests
1299    // =========================================================================
1300
1301    /// Helper to build MergeArgs easily
1302    fn mk_args(method: MergeMethod) -> MergeArgs {
1303        MergeArgs {
1304            models: vec![],
1305            output: PathBuf::from("out.json"),
1306            method,
1307            weight: None,
1308            density: None,
1309            weights: None,
1310            base: None,
1311            adapter: None,
1312        }
1313    }
1314
1315    // ── bytes_to_f32 with zero-value data ────────────────────────────────
1316
1317    #[test]
1318    fn test_cov2_bytes_to_f32_f32_zeros() {
1319        let zeros = vec![0.0f32; 10];
1320        let bytes: Vec<u8> = zeros.iter().flat_map(|x| x.to_le_bytes()).collect();
1321        let result = bytes_to_f32(&bytes, safetensors::tensor::Dtype::F32);
1322        assert_eq!(result.len(), 10);
1323        assert!(result.iter().all(|&v| v == 0.0));
1324    }
1325
1326    #[test]
1327    fn test_cov2_bytes_to_f32_f32_negative() {
1328        let vals = vec![-1.0f32, -100.0, -0.001];
1329        let bytes: Vec<u8> = vals.iter().flat_map(|x| x.to_le_bytes()).collect();
1330        let result = bytes_to_f32(&bytes, safetensors::tensor::Dtype::F32);
1331        assert_eq!(result.len(), 3);
1332        assert!((result[0] - (-1.0)).abs() < 1e-6);
1333        assert!((result[1] - (-100.0)).abs() < 1e-6);
1334        assert!((result[2] - (-0.001)).abs() < 1e-6);
1335    }
1336
1337    #[test]
1338    fn test_cov2_bytes_to_f32_f32_large() {
1339        let vals = vec![1e30f32, -1e30];
1340        let bytes: Vec<u8> = vals.iter().flat_map(|x| x.to_le_bytes()).collect();
1341        let result = bytes_to_f32(&bytes, safetensors::tensor::Dtype::F32);
1342        assert_eq!(result.len(), 2);
1343        assert!((result[0] - 1e30).abs() / 1e30 < 1e-6);
1344    }
1345
1346    #[test]
1347    fn test_cov2_bytes_to_f32_f16_zero() {
1348        let zero = half::f16::from_f32(0.0);
1349        let bytes = zero.to_le_bytes().to_vec();
1350        let result = bytes_to_f32(&bytes, safetensors::tensor::Dtype::F16);
1351        assert_eq!(result.len(), 1);
1352        assert!((result[0]).abs() < 1e-6);
1353    }
1354
1355    #[test]
1356    fn test_cov2_bytes_to_f32_bf16_zero() {
1357        let zero = half::bf16::from_f32(0.0);
1358        let bytes = zero.to_le_bytes().to_vec();
1359        let result = bytes_to_f32(&bytes, safetensors::tensor::Dtype::BF16);
1360        assert_eq!(result.len(), 1);
1361        assert!((result[0]).abs() < 1e-6);
1362    }
1363
1364    #[test]
1365    fn test_cov2_bytes_to_f32_f16_negative() {
1366        let neg = half::f16::from_f32(-3.14);
1367        let bytes = neg.to_le_bytes().to_vec();
1368        let result = bytes_to_f32(&bytes, safetensors::tensor::Dtype::F16);
1369        assert_eq!(result.len(), 1);
1370        assert!((result[0] - (-3.14)).abs() < 0.01);
1371    }
1372
1373    #[test]
1374    fn test_cov2_bytes_to_f32_bf16_negative() {
1375        let neg = half::bf16::from_f32(-5.0);
1376        let bytes = neg.to_le_bytes().to_vec();
1377        let result = bytes_to_f32(&bytes, safetensors::tensor::Dtype::BF16);
1378        assert_eq!(result.len(), 1);
1379        assert!((result[0] - (-5.0)).abs() < 0.5);
1380    }
1381
1382    // ── bytes_to_f32 with truncated data (not aligned) ──────────────────
1383
1384    #[test]
1385    fn test_cov2_bytes_to_f32_f32_truncated() {
1386        // 5 bytes → only 1 full f32 chunk (4 bytes), remainder ignored
1387        let bytes: Vec<u8> = vec![0, 0, 128, 63, 99];
1388        let result = bytes_to_f32(&bytes, safetensors::tensor::Dtype::F32);
1389        assert_eq!(result.len(), 1);
1390        assert!((result[0] - 1.0).abs() < 1e-6);
1391    }
1392
1393    #[test]
1394    fn test_cov2_bytes_to_f32_f16_truncated() {
1395        // 3 bytes → only 1 full f16 chunk (2 bytes), remainder ignored
1396        let val = half::f16::from_f32(2.0);
1397        let mut bytes = val.to_le_bytes().to_vec();
1398        bytes.push(0xFF);
1399        let result = bytes_to_f32(&bytes, safetensors::tensor::Dtype::F16);
1400        assert_eq!(result.len(), 1);
1401        assert!((result[0] - 2.0).abs() < 0.01);
1402    }
1403
1404    // ── bytes_to_f32 with I64 fallback (other dtype) ────────────────────
1405
1406    #[test]
1407    fn test_cov2_bytes_to_f32_i64_fallback() {
1408        let v = 3.14f32;
1409        let bytes = v.to_le_bytes().to_vec();
1410        let result = bytes_to_f32(&bytes, safetensors::tensor::Dtype::I64);
1411        assert_eq!(result.len(), 1);
1412        assert!((result[0] - 3.14).abs() < 1e-6);
1413    }
1414
1415    #[test]
1416    fn test_cov2_bytes_to_f32_u8_fallback() {
1417        let v = 7.0f32;
1418        let bytes = v.to_le_bytes().to_vec();
1419        let result = bytes_to_f32(&bytes, safetensors::tensor::Dtype::U8);
1420        assert_eq!(result.len(), 1);
1421        assert!((result[0] - 7.0).abs() < 1e-6);
1422    }
1423
1424    // ── build_ensemble_config with whitespace-padded weights ────────────
1425
1426    #[test]
1427    fn test_cov2_build_ensemble_config_whitespace_weights() {
1428        let a = MergeArgs {
1429            weights: Some("  0.5 , 0.3 , 0.2  ".to_string()),
1430            ..mk_args(MergeMethod::Average)
1431        };
1432        let config = build_ensemble_config(&a);
1433        assert!(config.is_ok());
1434    }
1435
1436    // ── build_ensemble_config with negative weights ─────────────────────
1437
1438    #[test]
1439    fn test_cov2_build_ensemble_config_negative_weights() {
1440        let a =
1441            MergeArgs { weights: Some("-0.5, 1.5".to_string()), ..mk_args(MergeMethod::Average) };
1442        let config = build_ensemble_config(&a);
1443        // Parsing should succeed (negative floats are valid f32)
1444        assert!(config.is_ok());
1445    }
1446
1447    // ── build_ensemble_config with large number of weights ──────────────
1448
1449    #[test]
1450    fn test_cov2_build_ensemble_config_many_weights() {
1451        let w_str = (0..10).map(|_| "0.1").collect::<Vec<_>>().join(",");
1452        let a = MergeArgs { weights: Some(w_str), ..mk_args(MergeMethod::Average) };
1453        let config = build_ensemble_config(&a);
1454        assert!(config.is_ok());
1455    }
1456
1457    // ── build_safetensor_metadata for each method ───────────────────────
1458
1459    #[test]
1460    fn test_cov2_safetensor_metadata_average() {
1461        let m = mk(&[("w", &[1.0])]);
1462        let a = mk_args(MergeMethod::Average);
1463        let md = build_safetensor_metadata(&m, &a);
1464        assert!(md["merge_method"].contains("Average"));
1465        assert_eq!(md["tensor_count"], "1");
1466    }
1467
1468    #[test]
1469    fn test_cov2_safetensor_metadata_dare() {
1470        let m = mk(&[("a", &[1.0]), ("b", &[2.0])]);
1471        let a = mk_args(MergeMethod::Dare);
1472        let md = build_safetensor_metadata(&m, &a);
1473        assert!(md["merge_method"].contains("Dare"));
1474        assert_eq!(md["tensor_count"], "2");
1475    }
1476
1477    #[test]
1478    fn test_cov2_safetensor_metadata_lora() {
1479        let m = mk(&[("w", &[1.0])]);
1480        let a = mk_args(MergeMethod::LoraAdapter);
1481        let md = build_safetensor_metadata(&m, &a);
1482        assert!(md["merge_method"].contains("LoraAdapter"));
1483    }
1484
1485    #[test]
1486    fn test_cov2_safetensor_metadata_empty_model() {
1487        let m: Model = HashMap::new();
1488        let a = mk_args(MergeMethod::Ties);
1489        let md = build_safetensor_metadata(&m, &a);
1490        assert_eq!(md["tensor_count"], "0");
1491    }
1492
1493    // ── validate_model_count edge: exactly 2 ────────────────────────────
1494
1495    #[test]
1496    fn test_cov2_validate_model_count_exactly_2() {
1497        let a = MergeArgs {
1498            models: vec![PathBuf::from("a"), PathBuf::from("b")],
1499            ..mk_args(MergeMethod::Average)
1500        };
1501        assert!(validate_model_count(&a).is_ok());
1502    }
1503
1504    #[test]
1505    fn test_cov2_validate_model_count_large() {
1506        let models: Vec<PathBuf> = (0..100).map(|i| PathBuf::from(format!("m{i}"))).collect();
1507        let a = MergeArgs { models, ..mk_args(MergeMethod::Average) };
1508        assert!(validate_model_count(&a).is_ok());
1509    }
1510
1511    // ── perform_merge LoRA early error message ──────────────────────────
1512
1513    #[test]
1514    fn test_cov2_perform_merge_lora_error_msg() {
1515        let a = mk_args(MergeMethod::LoraAdapter);
1516        let err = perform_merge(&[], &a).unwrap_err();
1517        assert_eq!(err, "LoRA adapter merge uses dedicated path");
1518    }
1519
1520    // ── export_merged_model to bad path ─────────────────────────────────
1521
1522    #[test]
1523    fn test_cov2_export_json_bad_path() {
1524        let m = mk(&[("w", &[1.0])]);
1525        let a = MergeArgs {
1526            output: PathBuf::from("/nonexistent_dir_xxxx/output.json"),
1527            ..mk_args(MergeMethod::Average)
1528        };
1529        let result = export_merged_model(&m, &a);
1530        assert!(result.is_err());
1531        assert!(result.unwrap_err().contains("Failed to write"));
1532    }
1533
1534    #[test]
1535    fn test_cov2_export_safetensors_bad_path() {
1536        let m = mk(&[("w", &[1.0])]);
1537        let a = MergeArgs {
1538            output: PathBuf::from("/nonexistent_dir_xxxx/output.safetensors"),
1539            ..mk_args(MergeMethod::Average)
1540        };
1541        let result = export_merged_model(&m, &a);
1542        assert!(result.is_err());
1543        assert!(result.unwrap_err().contains("Failed to write"));
1544    }
1545
1546    // ── export safetensors roundtrip ────────────────────────────────────
1547
1548    #[test]
1549    fn test_cov2_export_safetensors_roundtrip() {
1550        let m = mk(&[("layer1", &[1.0, 2.0, 3.0]), ("layer2", &[4.0, 5.0])]);
1551        let t = std::env::temp_dir().join("ent_merge_cov2_rt.safetensors");
1552        let a = MergeArgs { output: t.clone(), ..mk_args(MergeMethod::Ties) };
1553        assert!(export_merged_model(&m, &a).is_ok());
1554        // Read back and verify
1555        let data = std::fs::read(&t).unwrap();
1556        let tensors = SafeTensors::deserialize(&data).unwrap();
1557        let names: Vec<&str> = tensors.names().clone();
1558        assert!(names.contains(&"layer1"));
1559        assert!(names.contains(&"layer2"));
1560        let _ = std::fs::remove_file(&t);
1561    }
1562
1563    // ── export json with empty model ────────────────────────────────────
1564
1565    #[test]
1566    fn test_cov2_export_json_empty_model() {
1567        let m: Model = HashMap::new();
1568        let t = std::env::temp_dir().join("ent_merge_cov2_empty.json");
1569        let a = MergeArgs { output: t.clone(), ..mk_args(MergeMethod::Average) };
1570        assert!(export_merged_model(&m, &a).is_ok());
1571        let content = std::fs::read_to_string(&t).unwrap();
1572        let parsed: HashMap<String, Vec<f32>> = serde_json::from_str(&content).unwrap();
1573        assert!(parsed.is_empty());
1574        let _ = std::fs::remove_file(&t);
1575    }
1576
1577    // ── export safetensors with empty model ─────────────────────────────
1578
1579    #[test]
1580    fn test_cov2_export_safetensors_empty_model() {
1581        let m: Model = HashMap::new();
1582        let t = std::env::temp_dir().join("ent_merge_cov2_empty.safetensors");
1583        let a = MergeArgs { output: t.clone(), ..mk_args(MergeMethod::Average) };
1584        assert!(export_merged_model(&m, &a).is_ok());
1585        let _ = std::fs::remove_file(&t);
1586    }
1587
1588    // ── perform_ties_merge with default density ─────────────────────────
1589
1590    #[test]
1591    fn test_cov2_ties_merge_default_density() {
1592        let a = MergeArgs { density: None, ..mk_args(MergeMethod::Ties) };
1593        // density defaults to 0.2
1594        let models =
1595            vec![mk(&[("w", &[1.0, 2.0])]), mk(&[("w", &[1.5, 2.5])]), mk(&[("w", &[1.2, 2.2])])];
1596        assert!(perform_ties_merge(&models, &a).is_ok());
1597    }
1598
1599    // ── perform_dare_merge with default density ─────────────────────────
1600
1601    #[test]
1602    fn test_cov2_dare_merge_default_density() {
1603        let a = MergeArgs { density: None, ..mk_args(MergeMethod::Dare) };
1604        // density defaults to 0.5 → drop_prob = 0.5
1605        let models = vec![mk(&[("w", &[1.0, 2.0])]), mk(&[("w", &[1.5, 2.5])])];
1606        assert!(perform_dare_merge(&models, &a).is_ok());
1607    }
1608
1609    // ── perform_slerp_merge with default weight ─────────────────────────
1610
1611    #[test]
1612    fn test_cov2_slerp_merge_default_weight() {
1613        let a = MergeArgs { weight: None, ..mk_args(MergeMethod::Slerp) };
1614        let models = vec![mk(&[("w", &[1.0, 0.0])]), mk(&[("w", &[0.0, 1.0])])];
1615        let result = perform_slerp_merge(&models, &a);
1616        assert!(result.is_ok());
1617    }
1618
1619    // ── slerp with single model → error ─────────────────────────────────
1620
1621    #[test]
1622    fn test_cov2_slerp_single_model() {
1623        let a = mk_args(MergeMethod::Slerp);
1624        let models = vec![mk(&[("w", &[1.0])])];
1625        let err = perform_slerp_merge(&models, &a).unwrap_err();
1626        assert!(err.contains("SLERP requires exactly 2"));
1627    }
1628
1629    // ── run_merge with LoRA routes to lora function ─────────────────────
1630
1631    #[test]
1632    fn test_cov2_run_merge_lora_missing_both() {
1633        let a = MergeArgs {
1634            method: MergeMethod::LoraAdapter,
1635            base: None,
1636            adapter: None,
1637            ..mk_args(MergeMethod::LoraAdapter)
1638        };
1639        let err = run_merge(a, LogLevel::Quiet).unwrap_err();
1640        assert!(err.contains("--base required"));
1641    }
1642
1643    #[test]
1644    fn test_cov2_run_merge_lora_has_base_no_adapter() {
1645        let a = MergeArgs {
1646            method: MergeMethod::LoraAdapter,
1647            base: Some(PathBuf::from("/tmp/some_base")),
1648            adapter: None,
1649            ..mk_args(MergeMethod::LoraAdapter)
1650        };
1651        let err = run_merge(a, LogLevel::Quiet).unwrap_err();
1652        assert!(err.contains("--adapter required"));
1653    }
1654
1655    // ── load_single_model with empty file ───────────────────────────────
1656
1657    #[test]
1658    fn test_cov2_load_single_model_empty_file() {
1659        let dir = tempfile::tempdir().unwrap();
1660        let path = dir.path().join("empty.safetensors");
1661        std::fs::write(&path, b"").unwrap();
1662        let err = load_single_model(&path).unwrap_err();
1663        assert!(err.contains("Failed to parse"));
1664    }
1665
1666    // ── load_single_model with garbage data ─────────────────────────────
1667
1668    #[test]
1669    fn test_cov2_load_single_model_garbage() {
1670        let dir = tempfile::tempdir().unwrap();
1671        let path = dir.path().join("garbage.safetensors");
1672        std::fs::write(&path, b"this is not a safetensors file at all").unwrap();
1673        let err = load_single_model(&path).unwrap_err();
1674        assert!(err.contains("Failed to parse"));
1675    }
1676
1677    // ── run_merge models don't exist → load error ───────────────────────
1678
1679    #[test]
1680    fn test_cov2_run_merge_first_model_missing() {
1681        let dir = tempfile::tempdir().unwrap();
1682        let a = MergeArgs {
1683            models: vec![
1684                dir.path().join("no_exist_1.safetensors"),
1685                dir.path().join("no_exist_2.safetensors"),
1686            ],
1687            output: dir.path().join("out.json"),
1688            method: MergeMethod::Average,
1689            ..mk_args(MergeMethod::Average)
1690        };
1691        let err = run_merge(a, LogLevel::Quiet).unwrap_err();
1692        assert!(err.contains("Failed to read"));
1693    }
1694
1695    // ── log functions with all log levels ────────────────────────────────
1696
1697    #[test]
1698    fn test_cov2_log_merge_start_quiet() {
1699        let a = MergeArgs {
1700            models: vec![PathBuf::from("m1"), PathBuf::from("m2"), PathBuf::from("m3")],
1701            output: PathBuf::from("out.safetensors"),
1702            ..mk_args(MergeMethod::Dare)
1703        };
1704        log_merge_start(&a, LogLevel::Quiet);
1705    }
1706
1707    #[test]
1708    fn test_cov2_log_merge_complete_quiet() {
1709        let m = mk(&[("w", &[1.0, 2.0, 3.0])]);
1710        let a = MergeArgs {
1711            output: PathBuf::from("merged.safetensors"),
1712            ..mk_args(MergeMethod::Average)
1713        };
1714        log_merge_complete(&m, &a, LogLevel::Quiet);
1715    }
1716
1717    // ── average merge with multiple tensors per model ───────────────────
1718
1719    #[test]
1720    fn test_cov2_average_merge_multi_tensor() {
1721        let a = mk_args(MergeMethod::Average);
1722        let m1 = mk(&[("a", &[1.0, 2.0]), ("b", &[3.0])]);
1723        let m2 = mk(&[("a", &[3.0, 4.0]), ("b", &[5.0])]);
1724        let result = perform_average_merge(&[m1, m2], &a).unwrap();
1725        let a_vals = result["a"].data().as_slice().unwrap().to_vec();
1726        let b_vals = result["b"].data().as_slice().unwrap().to_vec();
1727        assert!((a_vals[0] - 2.0).abs() < 1e-6);
1728        assert!((a_vals[1] - 3.0).abs() < 1e-6);
1729        assert!((b_vals[0] - 4.0).abs() < 1e-6);
1730    }
1731
1732    // ── slerp merge with custom weight ──────────────────────────────────
1733
1734    #[test]
1735    fn test_cov2_slerp_merge_weight_0() {
1736        let a = MergeArgs { weight: Some(0.0), ..mk_args(MergeMethod::Slerp) };
1737        let models = vec![mk(&[("w", &[1.0, 0.0])]), mk(&[("w", &[0.0, 1.0])])];
1738        let result = perform_slerp_merge(&models, &a).unwrap();
1739        let vals = result["w"].data().as_slice().unwrap().to_vec();
1740        // t=0 should give model 0's values
1741        assert!((vals[0] - 1.0).abs() < 0.1);
1742    }
1743
1744    #[test]
1745    fn test_cov2_slerp_merge_weight_1() {
1746        let a = MergeArgs { weight: Some(1.0), ..mk_args(MergeMethod::Slerp) };
1747        let models = vec![mk(&[("w", &[1.0, 0.0])]), mk(&[("w", &[0.0, 1.0])])];
1748        let result = perform_slerp_merge(&models, &a).unwrap();
1749        let vals = result["w"].data().as_slice().unwrap().to_vec();
1750        // t=1 should give model 1's values
1751        assert!((vals[1] - 1.0).abs() < 0.1);
1752    }
1753
1754    // ── ties merge with explicit density close to 1.0 ───────────────────
1755
1756    #[test]
1757    fn test_cov2_ties_merge_high_density() {
1758        let a = MergeArgs { density: Some(0.99), ..mk_args(MergeMethod::Ties) };
1759        let models = vec![
1760            mk(&[("w", &[1.0, 2.0, 3.0])]),
1761            mk(&[("w", &[1.1, 2.1, 3.1])]),
1762            mk(&[("w", &[1.2, 2.2, 3.2])]),
1763        ];
1764        assert!(perform_ties_merge(&models, &a).is_ok());
1765    }
1766
1767    // ── dare merge with density close to 0 ──────────────────────────────
1768
1769    #[test]
1770    fn test_cov2_dare_merge_low_density() {
1771        let a = MergeArgs { density: Some(0.01), ..mk_args(MergeMethod::Dare) };
1772        let models = vec![mk(&[("w", &[1.0, 2.0])]), mk(&[("w", &[1.5, 2.5])])];
1773        assert!(perform_dare_merge(&models, &a).is_ok());
1774    }
1775
1776    // ── LoRA merge: adapter dir exists, has config, but no model.safetensors ─
1777
1778    #[test]
1779    fn test_cov2_lora_adapter_config_exists_no_model() {
1780        let dir = tempfile::tempdir().unwrap();
1781        let base = dir.path().join("base.safetensors");
1782        std::fs::write(&base, b"fake").unwrap();
1783        let adapter_dir = dir.path().join("adapter");
1784        std::fs::create_dir_all(&adapter_dir).unwrap();
1785        std::fs::write(adapter_dir.join("adapter_config.json"), r#"{"r":8}"#).unwrap();
1786        // No adapter_model.safetensors
1787        let a = MergeArgs {
1788            base: Some(base),
1789            adapter: Some(adapter_dir),
1790            ..mk_args(MergeMethod::LoraAdapter)
1791        };
1792        let err = run_lora_adapter_merge(&a, LogLevel::Quiet).unwrap_err();
1793        assert!(err.contains("adapter_model.safetensors"));
1794    }
1795
1796    // ── LoRA merge: base path doesn't exist ─────────────────────────────
1797
1798    #[test]
1799    fn test_cov2_lora_base_path_not_found() {
1800        let dir = tempfile::tempdir().unwrap();
1801        let adapter_dir = dir.path().join("adapter");
1802        std::fs::create_dir_all(&adapter_dir).unwrap();
1803        let a = MergeArgs {
1804            base: Some(PathBuf::from("/definitely/not/exist/base.st")),
1805            adapter: Some(adapter_dir),
1806            ..mk_args(MergeMethod::LoraAdapter)
1807        };
1808        let err = run_lora_adapter_merge(&a, LogLevel::Quiet).unwrap_err();
1809        assert!(err.contains("Base model not found"));
1810    }
1811
1812    // ── export extension detection ──────────────────────────────────────
1813
1814    #[test]
1815    fn test_cov2_export_extension_safetensors() {
1816        let m = mk(&[("w", &[1.0])]);
1817        let t = std::env::temp_dir().join("ent_merge_cov2_ext.safetensors");
1818        let a = MergeArgs { output: t.clone(), ..mk_args(MergeMethod::Average) };
1819        assert!(export_merged_model(&m, &a).is_ok());
1820        // Verify file exists and is valid safetensors
1821        let data = std::fs::read(&t).unwrap();
1822        assert!(SafeTensors::deserialize(&data).is_ok());
1823        let _ = std::fs::remove_file(&t);
1824    }
1825
1826    #[test]
1827    fn test_cov2_export_extension_json() {
1828        let m = mk(&[("w", &[1.0])]);
1829        let t = std::env::temp_dir().join("ent_merge_cov2_ext.json");
1830        let a = MergeArgs { output: t.clone(), ..mk_args(MergeMethod::Average) };
1831        assert!(export_merged_model(&m, &a).is_ok());
1832        let content = std::fs::read_to_string(&t).unwrap();
1833        assert!(serde_json::from_str::<HashMap<String, Vec<f32>>>(&content).is_ok());
1834        let _ = std::fs::remove_file(&t);
1835    }
1836
1837    #[test]
1838    fn test_cov2_export_extension_unknown() {
1839        let m = mk(&[("w", &[1.0])]);
1840        let t = std::env::temp_dir().join("ent_merge_cov2_ext.bin");
1841        let a = MergeArgs { output: t.clone(), ..mk_args(MergeMethod::Average) };
1842        // Unknown extension → falls through to JSON
1843        assert!(export_merged_model(&m, &a).is_ok());
1844        let content = std::fs::read_to_string(&t).unwrap();
1845        assert!(serde_json::from_str::<HashMap<String, Vec<f32>>>(&content).is_ok());
1846        let _ = std::fs::remove_file(&t);
1847    }
1848
1849    // ── mk helper edge cases ────────────────────────────────────────────
1850
1851    #[test]
1852    fn test_cov2_mk_empty_model() {
1853        let model = mk(&[]);
1854        assert!(model.is_empty());
1855    }
1856
1857    #[test]
1858    fn test_cov2_mk_single_empty_tensor() {
1859        let model = mk(&[("empty", &[])]);
1860        assert_eq!(model.len(), 1);
1861        assert_eq!(model["empty"].len(), 0);
1862    }
1863
1864    // ── perform_merge dispatch coverage ─────────────────────────────────
1865
1866    #[test]
1867    fn test_cov2_perform_merge_all_methods() {
1868        // Ties
1869        let models3 =
1870            vec![mk(&[("w", &[1.0, 2.0])]), mk(&[("w", &[1.1, 2.1])]), mk(&[("w", &[1.2, 2.2])])];
1871        assert!(perform_merge(&models3, &mk_args(MergeMethod::Ties)).is_ok());
1872
1873        // Dare
1874        let models2 = vec![mk(&[("w", &[1.0, 2.0])]), mk(&[("w", &[1.5, 2.5])])];
1875        assert!(perform_merge(&models2, &mk_args(MergeMethod::Dare)).is_ok());
1876
1877        // Slerp
1878        let models_s = vec![mk(&[("w", &[1.0, 0.0])]), mk(&[("w", &[0.0, 1.0])])];
1879        assert!(perform_merge(
1880            &models_s,
1881            &MergeArgs { weight: Some(0.5), ..mk_args(MergeMethod::Slerp) }
1882        )
1883        .is_ok());
1884
1885        // Average
1886        let models_a = vec![mk(&[("w", &[2.0])]), mk(&[("w", &[4.0])])];
1887        assert!(perform_merge(&models_a, &mk_args(MergeMethod::Average)).is_ok());
1888
1889        // LoraAdapter → error
1890        assert!(perform_merge(&[], &mk_args(MergeMethod::LoraAdapter)).is_err());
1891    }
1892
1893    // ── large tensor export/import roundtrip ────────────────────────────
1894
1895    #[test]
1896    fn test_cov2_large_tensor_roundtrip() {
1897        let large_data: Vec<f32> = (0..1000).map(|i| i as f32 * 0.001).collect();
1898        let m = mk(&[("big", large_data.as_slice())]);
1899        let t = std::env::temp_dir().join("ent_merge_cov2_large.json");
1900        let a = MergeArgs { output: t.clone(), ..mk_args(MergeMethod::Average) };
1901        assert!(export_merged_model(&m, &a).is_ok());
1902        let content = std::fs::read_to_string(&t).unwrap();
1903        let parsed: HashMap<String, Vec<f32>> = serde_json::from_str(&content).unwrap();
1904        assert_eq!(parsed["big"].len(), 1000);
1905        assert!((parsed["big"][500] - 0.5).abs() < 1e-3);
1906        let _ = std::fs::remove_file(&t);
1907    }
1908}