1use crate::autograd::Tensor;
4use crate::cli::logging::log;
5use crate::cli::LogLevel;
6use crate::config::{MergeArgs, MergeMethod};
7use crate::merge::{
8 dare_merge, ensemble_merge, slerp_merge, ties_merge, DareConfig, EnsembleConfig, Model,
9 SlerpConfig, TiesConfig,
10};
11use safetensors::SafeTensors;
12use std::collections::HashMap;
13use std::path::Path;
14
15pub fn run_merge(args: MergeArgs, level: LogLevel) -> Result<(), String> {
16 if args.method == MergeMethod::LoraAdapter {
18 return run_lora_adapter_merge(&args, level);
19 }
20
21 log_merge_start(&args, level);
22 validate_model_count(&args)?;
23
24 let models = load_all_models(&args.models, level)?;
25 let merged = perform_merge(&models, &args)?;
26 export_merged_model(&merged, &args)?;
27
28 log_merge_complete(&merged, &args, level);
29 Ok(())
30}
31
32fn log_merge_start(args: &MergeArgs, level: LogLevel) {
34 log(
35 level,
36 LogLevel::Normal,
37 &format!("Merging {} models using {:?}", args.models.len(), args.method),
38 );
39
40 for (i, model) in args.models.iter().enumerate() {
41 log(level, LogLevel::Verbose, &format!(" Model {}: {}", i + 1, model.display()));
42 }
43 log(level, LogLevel::Verbose, &format!(" Output: {}", args.output.display()));
44}
45
46fn validate_model_count(args: &MergeArgs) -> Result<(), String> {
48 if args.models.len() < 2 {
49 return Err("Need at least 2 models to merge".to_string());
50 }
51 Ok(())
52}
53
54fn load_all_models(paths: &[std::path::PathBuf], level: LogLevel) -> Result<Vec<Model>, String> {
56 let mut models: Vec<Model> = Vec::new();
57 for path in paths {
58 let model = load_single_model(path)?;
59 let tensor_count = model.len();
60 models.push(model);
61
62 log(
63 level,
64 LogLevel::Verbose,
65 &format!(" Loaded {} tensors from {}", tensor_count, path.display()),
66 );
67 }
68 Ok(models)
69}
70
71fn load_single_model(path: &Path) -> Result<Model, String> {
73 let data =
74 std::fs::read(path).map_err(|e| format!("Failed to read {}: {e}", path.display()))?;
75
76 let tensors = SafeTensors::deserialize(&data)
77 .map_err(|e| format!("Failed to parse {}: {e}", path.display()))?;
78
79 let mut model: Model = HashMap::new();
80 for name in tensors.names() {
81 if let Some(tensor) = extract_f32_tensor(&tensors, name)? {
82 model.insert((*name).to_string(), tensor);
83 }
84 }
85 Ok(model)
86}
87
88fn extract_f32_tensor(tensors: &SafeTensors<'_>, name: &str) -> Result<Option<Tensor>, String> {
90 let tensor = tensors.tensor(name).map_err(|e| format!("Failed to get tensor {name}: {e}"))?;
91
92 if tensor.dtype() != safetensors::tensor::Dtype::F32 {
93 return Ok(None);
94 }
95
96 let bytes = tensor.data();
97 let values: Vec<f32> = bytes
98 .chunks_exact(4)
99 .map(|chunk| f32::from_le_bytes([chunk[0], chunk[1], chunk[2], chunk[3]]))
100 .collect();
101
102 Ok(Some(Tensor::from_vec(values, false)))
103}
104
105fn perform_merge(models: &[Model], args: &MergeArgs) -> Result<Model, String> {
107 match args.method {
108 MergeMethod::Ties => perform_ties_merge(models, args),
109 MergeMethod::Dare => perform_dare_merge(models, args),
110 MergeMethod::Slerp => perform_slerp_merge(models, args),
111 MergeMethod::Average => perform_average_merge(models, args),
112 MergeMethod::LoraAdapter => {
113 Err("LoRA adapter merge uses dedicated path".to_string())
115 }
116 }
117}
118
119fn perform_ties_merge(models: &[Model], args: &MergeArgs) -> Result<Model, String> {
121 let config = TiesConfig { density: args.density.unwrap_or(0.2) };
122 let base = &models[0];
123 ties_merge(models.get(1..).unwrap_or_default(), base, &config)
124 .map_err(|e| format!("TIES merge failed: {e}"))
125}
126
127fn perform_dare_merge(models: &[Model], args: &MergeArgs) -> Result<Model, String> {
129 let config = DareConfig { drop_prob: 1.0 - args.density.unwrap_or(0.5), seed: None };
130 let base = &models[0];
131 dare_merge(models.get(1..).unwrap_or_default(), base, &config)
132 .map_err(|e| format!("DARE merge failed: {e}"))
133}
134
135fn perform_slerp_merge(models: &[Model], args: &MergeArgs) -> Result<Model, String> {
137 if models.len() != 2 {
138 return Err("SLERP requires exactly 2 models".to_string());
139 }
140 let config = SlerpConfig { t: args.weight.unwrap_or(0.5) };
141 slerp_merge(&models[0], &models[1], &config).map_err(|e| format!("SLERP merge failed: {e}"))
142}
143
144fn perform_average_merge(models: &[Model], args: &MergeArgs) -> Result<Model, String> {
146 let config = build_ensemble_config(args)?;
147 ensemble_merge(models, &config).map_err(|e| format!("Average merge failed: {e}"))
148}
149
150fn build_ensemble_config(args: &MergeArgs) -> Result<EnsembleConfig, String> {
152 if let Some(w_str) = &args.weights {
153 let weights: Vec<f32> = w_str
154 .split(',')
155 .map(|s| s.trim().parse::<f32>())
156 .collect::<Result<Vec<_>, _>>()
157 .map_err(|e| format!("Invalid weights: {e}"))?;
158 Ok(EnsembleConfig::weighted_average(weights))
159 } else {
160 Ok(EnsembleConfig::uniform_average())
161 }
162}
163
164fn export_merged_model(merged: &Model, args: &MergeArgs) -> Result<(), String> {
166 let output_ext = args.output.extension().and_then(|s| s.to_str()).unwrap_or("json");
167
168 if output_ext == "safetensors" {
169 export_safetensors(merged, args)
170 } else {
171 export_json(merged, args)
172 }
173}
174
175fn export_safetensors(merged: &Model, args: &MergeArgs) -> Result<(), String> {
177 use safetensors::tensor::{Dtype, TensorView};
178
179 let tensor_data: Vec<(String, Vec<u8>, Vec<usize>)> = merged
180 .iter()
181 .map(|(name, tensor)| {
182 let data = tensor.data();
183 let bytes: Vec<u8> = bytemuck::cast_slice(data.as_slice().unwrap_or(&[])).to_vec();
184 let shape = vec![tensor.len()];
185 (name.clone(), bytes, shape)
186 })
187 .collect();
188
189 let views: Vec<(&str, TensorView<'_>)> = tensor_data
190 .iter()
191 .filter_map(|(name, bytes, shape)| {
192 TensorView::new(Dtype::F32, shape.clone(), bytes).ok().map(|view| (name.as_str(), view))
193 })
194 .collect();
195
196 let metadata = build_safetensor_metadata(merged, args);
197 let safetensor_bytes = safetensors::serialize(views, Some(metadata))
198 .map_err(|e| format!("Failed to serialize SafeTensors: {e}"))?;
199
200 std::fs::write(&args.output, safetensor_bytes)
201 .map_err(|e| format!("Failed to write output: {e}"))
202}
203
204fn build_safetensor_metadata(merged: &Model, args: &MergeArgs) -> HashMap<String, String> {
206 let mut metadata = HashMap::new();
207 metadata.insert("name".to_string(), "merged-model".to_string());
208 metadata.insert("merge_method".to_string(), format!("{:?}", args.method));
209 metadata.insert("tensor_count".to_string(), merged.len().to_string());
210 metadata
211}
212
213fn export_json(merged: &Model, args: &MergeArgs) -> Result<(), String> {
215 let output_data: HashMap<String, Vec<f32>> =
216 merged.iter().map(|(name, tensor)| (name.clone(), tensor.data().to_vec())).collect();
217
218 let json_data =
219 serde_json::to_vec_pretty(&output_data).map_err(|e| format!("Failed to serialize: {e}"))?;
220
221 std::fs::write(&args.output, &json_data).map_err(|e| format!("Failed to write output: {e}"))
222}
223
224fn log_merge_complete(merged: &Model, args: &MergeArgs, level: LogLevel) {
226 log(
227 level,
228 LogLevel::Normal,
229 &format!("Merge complete: {} tensors written to {}", merged.len(), args.output.display()),
230 );
231}
232
233fn run_lora_adapter_merge(args: &MergeArgs, level: LogLevel) -> Result<(), String> {
238 let base_path = args.base.as_ref().ok_or("--base required for lora-adapter merge")?;
239 let adapter_dir = args.adapter.as_ref().ok_or("--adapter required for lora-adapter merge")?;
240
241 let config_path = adapter_dir.join("adapter_config.json");
242 let adapter_path = adapter_dir.join("adapter_model.safetensors");
243
244 if !base_path.exists() {
245 return Err(format!("Base model not found: {}", base_path.display()));
246 }
247 if !config_path.exists() {
248 return Err(format!("adapter_config.json not found in {}", adapter_dir.display()));
249 }
250 if !adapter_path.exists() {
251 return Err(format!("adapter_model.safetensors not found in {}", adapter_dir.display()));
252 }
253
254 log(level, LogLevel::Normal, "LoRA adapter merge:");
255 log(level, LogLevel::Normal, &format!(" Base: {}", base_path.display()));
256 log(level, LogLevel::Normal, &format!(" Adapter: {}", adapter_dir.display()));
257
258 let config_str =
260 std::fs::read_to_string(&config_path).map_err(|e| format!("Read adapter config: {e}"))?;
261 let config: serde_json::Value =
262 serde_json::from_str(&config_str).map_err(|e| format!("Parse adapter config: {e}"))?;
263
264 let rank = config.get("r").and_then(serde_json::Value::as_u64).unwrap_or(8) as usize;
265 let alpha =
266 config.get("lora_alpha").and_then(serde_json::Value::as_f64).unwrap_or(rank as f64 * 2.0);
267 let scale = alpha as f32 / rank as f32;
268
269 log(level, LogLevel::Normal, &format!(" Rank: {rank}, Alpha: {alpha}, Scale: {scale:.4}"));
270
271 let base_data = std::fs::read(base_path).map_err(|e| format!("Read base model: {e}"))?;
273 let base_tensors =
274 SafeTensors::deserialize(&base_data).map_err(|e| format!("Parse base model: {e}"))?;
275
276 let adapter_data = std::fs::read(&adapter_path).map_err(|e| format!("Read adapter: {e}"))?;
278 let adapter_tensors =
279 SafeTensors::deserialize(&adapter_data).map_err(|e| format!("Parse adapter: {e}"))?;
280
281 let adapter_names: Vec<String> =
283 adapter_tensors.names().iter().map(|s| (*s).to_string()).collect();
284 let base_names: Vec<String> = base_tensors.names().iter().map(|s| (*s).to_string()).collect();
285
286 let lora_pairs = build_lora_pairs(&adapter_names, &adapter_tensors)?;
288 let mut merged_count = 0usize;
289
290 let mut output_tensors: Vec<(String, Vec<u8>, Vec<usize>)> = Vec::new();
292
293 for name in &base_names {
294 let base_t = base_tensors.tensor(name).map_err(|e| format!("Get tensor {name}: {e}"))?;
295 let shape: Vec<usize> = base_t.shape().to_vec();
296
297 if let Some((a_data, b_data, a_shape, b_shape)) = lora_pairs.get(name.as_str()) {
299 let base_f32 = bytes_to_f32(base_t.data(), base_t.dtype());
301 let a_f32 = bytes_to_f32(a_data, safetensors::tensor::Dtype::F32);
302 let b_f32 = bytes_to_f32(b_data, safetensors::tensor::Dtype::F32);
303
304 let d_out = b_shape[0];
305 let r = b_shape[1];
306 let d_in = a_shape[1];
307
308 let mut ba = vec![0.0f32; d_out * d_in];
310 for i in 0..d_out {
311 for j in 0..d_in {
312 let mut sum = 0.0f32;
313 for k in 0..r {
314 sum += b_f32[i * r + k] * a_f32[k * d_in + j];
315 }
316 ba[i * d_in + j] = sum;
317 }
318 }
319
320 let mut merged: Vec<f32> = base_f32;
322 for (i, val) in merged.iter_mut().enumerate() {
323 *val += scale * ba[i];
324 }
325
326 let bytes: Vec<u8> = bytemuck::cast_slice(&merged).to_vec();
327 output_tensors.push((name.clone(), bytes, shape));
328 merged_count += 1;
329 } else {
330 output_tensors.push((name.clone(), base_t.data().to_vec(), shape));
332 }
333 }
334
335 let views: Vec<(&str, safetensors::tensor::TensorView<'_>)> = output_tensors
337 .iter()
338 .filter_map(|(name, bytes, shape)| {
339 safetensors::tensor::TensorView::new(
340 safetensors::tensor::Dtype::F32,
341 shape.clone(),
342 bytes,
343 )
344 .ok()
345 .map(|view| (name.as_str(), view))
346 })
347 .collect();
348
349 let mut metadata = HashMap::new();
350 metadata.insert("format".to_string(), "entrenar-merged-lora".to_string());
351 metadata.insert("lora_rank".to_string(), rank.to_string());
352 metadata.insert("lora_alpha".to_string(), format!("{alpha}"));
353
354 let safetensor_bytes = safetensors::serialize(views, Some(metadata))
355 .map_err(|e| format!("Serialize merged model: {e}"))?;
356
357 std::fs::write(&args.output, safetensor_bytes)
358 .map_err(|e| format!("Write merged model: {e}"))?;
359
360 let output_size = std::fs::metadata(&args.output).map(|m| m.len()).unwrap_or(0);
361 log(
362 level,
363 LogLevel::Normal,
364 &format!(" Merged {merged_count} adapter weights into base model"),
365 );
366 log(
367 level,
368 LogLevel::Normal,
369 &format!(" Output: {} ({:.2} MB)", args.output.display(), output_size as f64 / 1e6),
370 );
371
372 Ok(())
373}
374
375fn build_lora_pairs<'a>(
377 names: &[String],
378 tensors: &'a SafeTensors<'a>,
379) -> Result<HashMap<&'a str, (Vec<u8>, Vec<u8>, Vec<usize>, Vec<usize>)>, String> {
380 let mut pairs: HashMap<String, (Option<(Vec<u8>, Vec<usize>)>, Option<(Vec<u8>, Vec<usize>)>)> =
381 HashMap::new();
382
383 for name in names {
384 let (base_name, is_a) = if let Some(stripped) = name.strip_suffix(".lora_A.weight") {
386 (stripped.replace("base_model.model.", "") + ".weight", true)
387 } else if let Some(stripped) = name.strip_suffix(".lora_B.weight") {
388 (stripped.replace("base_model.model.", "") + ".weight", false)
389 } else {
390 continue;
391 };
392
393 let tensor = tensors.tensor(name).map_err(|e| format!("Get adapter tensor {name}: {e}"))?;
394 let data = tensor.data().to_vec();
395 let shape = tensor.shape().to_vec();
396
397 let entry = pairs.entry(base_name).or_insert((None, None));
398 if is_a {
399 entry.0 = Some((data, shape));
400 } else {
401 entry.1 = Some((data, shape));
402 }
403 }
404
405 let mut result = HashMap::new();
406 for (base_name, (a, b)) in &pairs {
407 if let (Some((a_data, a_shape)), Some((b_data, b_shape))) = (a, b) {
408 let key: &str = Box::leak(base_name.clone().into_boxed_str());
411 result.insert(key, (a_data.clone(), b_data.clone(), a_shape.clone(), b_shape.clone()));
412 }
413 }
414 Ok(result)
415}
416
417fn bytes_to_f32(data: &[u8], dtype: safetensors::tensor::Dtype) -> Vec<f32> {
419 match dtype {
420 safetensors::tensor::Dtype::F32 => {
421 data.chunks_exact(4).map(|c| f32::from_le_bytes([c[0], c[1], c[2], c[3]])).collect()
422 }
423 safetensors::tensor::Dtype::F16 => data
424 .chunks_exact(2)
425 .map(|c| {
426 let bits = u16::from_le_bytes([c[0], c[1]]);
427 half::f16::from_bits(bits).to_f32()
428 })
429 .collect(),
430 safetensors::tensor::Dtype::BF16 => data
431 .chunks_exact(2)
432 .map(|c| {
433 let bits = u16::from_le_bytes([c[0], c[1]]);
434 half::bf16::from_bits(bits).to_f32()
435 })
436 .collect(),
437 _ => {
438 data.chunks_exact(4).map(|c| f32::from_le_bytes([c[0], c[1], c[2], c[3]])).collect()
440 }
441 }
442}
443
444#[cfg(test)]
445mod tests {
446 #![allow(clippy::unwrap_used)]
447 use super::*;
448 use std::path::PathBuf;
449
450 #[test]
451 fn test_validate_model_count_zero() {
452 let args = MergeArgs {
453 models: vec![],
454 output: PathBuf::from("o.json"),
455 method: MergeMethod::Ties,
456 weight: None,
457 density: None,
458 weights: None,
459 base: None,
460 adapter: None,
461 };
462 assert!(validate_model_count(&args).is_err());
463 }
464
465 #[test]
466 fn test_validate_model_count_two_ok() {
467 let args = MergeArgs {
468 models: vec![PathBuf::from("a"), PathBuf::from("b")],
469 output: PathBuf::from("o.json"),
470 method: MergeMethod::Ties,
471 weight: None,
472 density: None,
473 weights: None,
474 base: None,
475 adapter: None,
476 };
477 assert!(validate_model_count(&args).is_ok());
478 }
479
480 #[test]
481 fn test_build_ensemble_config_no_weights() {
482 let args = MergeArgs {
483 models: vec![],
484 output: PathBuf::from("o.json"),
485 method: MergeMethod::Average,
486 weight: None,
487 density: None,
488 weights: None,
489 base: None,
490 adapter: None,
491 };
492 assert!(build_ensemble_config(&args).is_ok());
493 }
494
495 #[test]
496 fn test_build_ensemble_config_with_weights() {
497 let args = MergeArgs {
498 models: vec![],
499 output: PathBuf::from("o.json"),
500 method: MergeMethod::Average,
501 weight: None,
502 density: None,
503 weights: Some("0.3, 0.7".into()),
504 base: None,
505 adapter: None,
506 };
507 assert!(build_ensemble_config(&args).is_ok());
508 }
509
510 #[test]
511 fn test_build_ensemble_config_invalid() {
512 let args = MergeArgs {
513 models: vec![],
514 output: PathBuf::from("o.json"),
515 method: MergeMethod::Average,
516 weight: None,
517 density: None,
518 weights: Some("abc".into()),
519 base: None,
520 adapter: None,
521 };
522 assert!(build_ensemble_config(&args).unwrap_err().contains("Invalid weights"));
523 }
524
525 fn mk(keys: &[(&str, &[f32])]) -> Model {
526 keys.iter().map(|(n, v)| (n.to_string(), Tensor::from_vec(v.to_vec(), false))).collect()
527 }
528
529 #[test]
530 fn test_slerp_wrong_count() {
531 let ms = vec![mk(&[("w", &[1.0])]), mk(&[("w", &[2.0])]), mk(&[("w", &[3.0])])];
532 let a = MergeArgs {
533 models: vec![],
534 output: PathBuf::from("o"),
535 method: MergeMethod::Slerp,
536 weight: None,
537 density: None,
538 weights: None,
539 base: None,
540 adapter: None,
541 };
542 assert!(perform_slerp_merge(&ms, &a).unwrap_err().contains("SLERP requires exactly 2"));
543 }
544
545 #[test]
546 fn test_merge_lora_err() {
547 let a = MergeArgs {
548 models: vec![],
549 output: PathBuf::from("o"),
550 method: MergeMethod::LoraAdapter,
551 weight: None,
552 density: None,
553 weights: None,
554 base: None,
555 adapter: None,
556 };
557 assert!(perform_merge(&[], &a).is_err());
558 }
559
560 #[test]
561 fn test_bytes_to_f32_f32() {
562 let v = vec![1.0f32, 2.5];
563 let b: Vec<u8> = v.iter().flat_map(|x| x.to_le_bytes()).collect();
564 let r = bytes_to_f32(&b, safetensors::tensor::Dtype::F32);
565 assert!((r[0] - 1.0).abs() < 1e-6);
566 }
567
568 #[test]
569 fn test_bytes_to_f32_f16() {
570 let b = half::f16::from_f32(1.0).to_le_bytes().to_vec();
571 assert!((bytes_to_f32(&b, safetensors::tensor::Dtype::F16)[0] - 1.0).abs() < 0.01);
572 }
573
574 #[test]
575 fn test_bytes_to_f32_bf16() {
576 let b = half::bf16::from_f32(2.0).to_le_bytes().to_vec();
577 assert!((bytes_to_f32(&b, safetensors::tensor::Dtype::BF16)[0] - 2.0).abs() < 0.1);
578 }
579
580 #[test]
581 fn test_bytes_to_f32_fallback() {
582 let b: Vec<u8> = 42.0f32.to_le_bytes().to_vec();
583 assert!((bytes_to_f32(&b, safetensors::tensor::Dtype::I8)[0] - 42.0).abs() < 1e-6);
584 }
585
586 #[test]
587 fn test_bytes_to_f32_empty() {
588 assert!(bytes_to_f32(&[], safetensors::tensor::Dtype::F32).is_empty());
589 }
590
591 #[test]
592 fn test_safetensor_metadata() {
593 let m = mk(&[("a", &[1.0]), ("b", &[2.0])]);
594 let a = MergeArgs {
595 models: vec![],
596 output: PathBuf::from("o.st"),
597 method: MergeMethod::Dare,
598 weight: None,
599 density: None,
600 weights: None,
601 base: None,
602 adapter: None,
603 };
604 let md = build_safetensor_metadata(&m, &a);
605 assert_eq!(md["name"], "merged-model");
606 assert_eq!(md["tensor_count"], "2");
607 }
608
609 #[test]
610 fn test_export_json() {
611 let m = mk(&[("w", &[1.0])]);
612 let t = std::env::temp_dir().join("ent_merge_j.json");
613 let a = MergeArgs {
614 models: vec![],
615 output: t.clone(),
616 method: MergeMethod::Average,
617 weight: None,
618 density: None,
619 weights: None,
620 base: None,
621 adapter: None,
622 };
623 assert!(export_merged_model(&m, &a).is_ok());
624 let _ = std::fs::remove_file(&t);
625 }
626
627 #[test]
628 fn test_export_safetensors() {
629 let m = mk(&[("w", &[1.0])]);
630 let t = std::env::temp_dir().join("ent_merge_s.safetensors");
631 let a = MergeArgs {
632 models: vec![],
633 output: t.clone(),
634 method: MergeMethod::Average,
635 weight: None,
636 density: None,
637 weights: None,
638 base: None,
639 adapter: None,
640 };
641 assert!(export_merged_model(&m, &a).is_ok());
642 let _ = std::fs::remove_file(&t);
643 }
644
645 #[test]
646 fn test_ties_merge_ok() {
647 let a = MergeArgs {
648 models: vec![],
649 output: PathBuf::from("o"),
650 method: MergeMethod::Ties,
651 weight: None,
652 density: None,
653 weights: None,
654 base: None,
655 adapter: None,
656 };
657 assert!(perform_ties_merge(
659 &[mk(&[("w", &[1.0, 2.0])]), mk(&[("w", &[1.1, 2.1])]), mk(&[("w", &[1.2, 2.2])]),],
660 &a
661 )
662 .is_ok());
663 }
664
665 #[test]
666 fn test_dare_merge_ok() {
667 let a = MergeArgs {
668 models: vec![],
669 output: PathBuf::from("o"),
670 method: MergeMethod::Dare,
671 weight: None,
672 density: None,
673 weights: None,
674 base: None,
675 adapter: None,
676 };
677 assert!(
678 perform_dare_merge(&[mk(&[("w", &[1.0, 2.0])]), mk(&[("w", &[1.1, 2.1])])], &a).is_ok()
679 );
680 }
681
682 #[test]
683 fn test_average_merge() {
684 let a = MergeArgs {
685 models: vec![],
686 output: PathBuf::from("o"),
687 method: MergeMethod::Average,
688 weight: None,
689 density: None,
690 weights: None,
691 base: None,
692 adapter: None,
693 };
694 let r = perform_average_merge(&[mk(&[("w", &[2.0, 4.0])]), mk(&[("w", &[6.0, 8.0])])], &a)
695 .unwrap();
696 let s = r["w"].data().as_slice().unwrap().to_vec();
697 assert!((s[0] - 4.0).abs() < 1e-6);
698 }
699
700 #[test]
701 fn test_log_merge_no_panic() {
702 let a = MergeArgs {
703 models: vec![PathBuf::from("a"), PathBuf::from("b")],
704 output: PathBuf::from("o"),
705 method: MergeMethod::Ties,
706 weight: None,
707 density: None,
708 weights: None,
709 base: None,
710 adapter: None,
711 };
712 log_merge_start(&a, LogLevel::Quiet);
713 log_merge_start(&a, LogLevel::Verbose);
714 log_merge_complete(&mk(&[("w", &[1.0])]), &a, LogLevel::Normal);
715 }
716
717 #[test]
718 fn test_lora_missing_base() {
719 let a = MergeArgs {
720 models: vec![],
721 output: PathBuf::from("o"),
722 method: MergeMethod::LoraAdapter,
723 weight: None,
724 density: None,
725 weights: None,
726 base: None,
727 adapter: Some(PathBuf::from("/tmp")),
728 };
729 assert!(run_lora_adapter_merge(&a, LogLevel::Quiet)
730 .unwrap_err()
731 .contains("--base required"));
732 }
733
734 #[test]
735 fn test_lora_missing_adapter() {
736 let a = MergeArgs {
737 models: vec![],
738 output: PathBuf::from("o"),
739 method: MergeMethod::LoraAdapter,
740 weight: None,
741 density: None,
742 weights: None,
743 base: Some(PathBuf::from("/tmp/x")),
744 adapter: None,
745 };
746 assert!(run_lora_adapter_merge(&a, LogLevel::Quiet)
747 .unwrap_err()
748 .contains("--adapter required"));
749 }
750
751 #[test]
752 fn test_lora_base_not_found() {
753 let a = MergeArgs {
754 models: vec![],
755 output: PathBuf::from("o"),
756 method: MergeMethod::LoraAdapter,
757 weight: None,
758 density: None,
759 weights: None,
760 base: Some(PathBuf::from("/no/base")),
761 adapter: Some(PathBuf::from("/tmp")),
762 };
763 assert!(run_lora_adapter_merge(&a, LogLevel::Quiet)
764 .unwrap_err()
765 .contains("Base model not found"));
766 }
767
768 #[test]
769 fn test_load_nonexistent() {
770 assert!(load_single_model(std::path::Path::new("/no/m"))
771 .unwrap_err()
772 .contains("Failed to read"));
773 }
774
775 #[test]
776 fn test_run_merge_too_few() {
777 let a = MergeArgs {
778 models: vec![PathBuf::from("a")],
779 output: PathBuf::from("o"),
780 method: MergeMethod::Ties,
781 weight: None,
782 density: None,
783 weights: None,
784 base: None,
785 adapter: None,
786 };
787 assert!(run_merge(a, LogLevel::Quiet).unwrap_err().contains("Need at least 2"));
788 }
789
790 #[test]
791 fn test_run_merge_lora_routes() {
792 let a = MergeArgs {
793 models: vec![],
794 output: PathBuf::from("o"),
795 method: MergeMethod::LoraAdapter,
796 weight: None,
797 density: None,
798 weights: None,
799 base: None,
800 adapter: None,
801 };
802 assert!(run_merge(a, LogLevel::Quiet).unwrap_err().contains("--base required"));
803 }
804
805 #[test]
808 fn test_perform_merge_ties_route() {
809 let models =
810 vec![mk(&[("w", &[1.0, 2.0])]), mk(&[("w", &[1.1, 2.1])]), mk(&[("w", &[1.2, 2.2])])];
811 let a = MergeArgs {
812 models: vec![],
813 output: PathBuf::from("o"),
814 method: MergeMethod::Ties,
815 weight: None,
816 density: Some(0.5),
817 weights: None,
818 base: None,
819 adapter: None,
820 };
821 assert!(perform_merge(&models, &a).is_ok());
822 }
823
824 #[test]
825 fn test_perform_merge_dare_route() {
826 let models = vec![mk(&[("w", &[1.0, 2.0])]), mk(&[("w", &[1.5, 2.5])])];
827 let a = MergeArgs {
828 models: vec![],
829 output: PathBuf::from("o"),
830 method: MergeMethod::Dare,
831 weight: None,
832 density: Some(0.3),
833 weights: None,
834 base: None,
835 adapter: None,
836 };
837 assert!(perform_merge(&models, &a).is_ok());
838 }
839
840 #[test]
841 fn test_perform_merge_slerp_route() {
842 let models = vec![mk(&[("w", &[1.0, 0.0])]), mk(&[("w", &[0.0, 1.0])])];
843 let a = MergeArgs {
844 models: vec![],
845 output: PathBuf::from("o"),
846 method: MergeMethod::Slerp,
847 weight: Some(0.5),
848 density: None,
849 weights: None,
850 base: None,
851 adapter: None,
852 };
853 assert!(perform_merge(&models, &a).is_ok());
854 }
855
856 #[test]
857 fn test_perform_merge_average_route() {
858 let models = vec![mk(&[("w", &[2.0])]), mk(&[("w", &[4.0])])];
859 let a = MergeArgs {
860 models: vec![],
861 output: PathBuf::from("o"),
862 method: MergeMethod::Average,
863 weight: None,
864 density: None,
865 weights: None,
866 base: None,
867 adapter: None,
868 };
869 let result = perform_merge(&models, &a).unwrap();
870 let vals = result["w"].data().as_slice().unwrap().to_vec();
871 assert!((vals[0] - 3.0).abs() < 1e-6);
872 }
873
874 #[test]
877 fn test_slerp_merge_two_models_ok() {
878 let ms = vec![mk(&[("w", &[1.0, 0.0])]), mk(&[("w", &[0.0, 1.0])])];
879 let a = MergeArgs {
880 models: vec![],
881 output: PathBuf::from("o"),
882 method: MergeMethod::Slerp,
883 weight: Some(0.3),
884 density: None,
885 weights: None,
886 base: None,
887 adapter: None,
888 };
889 assert!(perform_slerp_merge(&ms, &a).is_ok());
890 }
891
892 #[test]
893 fn test_slerp_merge_default_weight() {
894 let ms = vec![mk(&[("w", &[1.0, 0.0])]), mk(&[("w", &[0.0, 1.0])])];
895 let a = MergeArgs {
896 models: vec![],
897 output: PathBuf::from("o"),
898 method: MergeMethod::Slerp,
899 weight: None, density: None,
901 weights: None,
902 base: None,
903 adapter: None,
904 };
905 assert!(perform_slerp_merge(&ms, &a).is_ok());
906 }
907
908 #[test]
911 fn test_ties_merge_with_density() {
912 let a = MergeArgs {
913 models: vec![],
914 output: PathBuf::from("o"),
915 method: MergeMethod::Ties,
916 weight: None,
917 density: Some(0.8),
918 weights: None,
919 base: None,
920 adapter: None,
921 };
922 let models =
923 vec![mk(&[("w", &[1.0, 2.0])]), mk(&[("w", &[1.5, 2.5])]), mk(&[("w", &[1.2, 2.2])])];
924 let result = perform_ties_merge(&models, &a);
925 assert!(result.is_ok());
926 }
927
928 #[test]
931 fn test_dare_merge_with_density() {
932 let a = MergeArgs {
933 models: vec![],
934 output: PathBuf::from("o"),
935 method: MergeMethod::Dare,
936 weight: None,
937 density: Some(0.9),
938 weights: None,
939 base: None,
940 adapter: None,
941 };
942 let models = vec![mk(&[("w", &[1.0, 2.0])]), mk(&[("w", &[1.5, 2.5])])];
943 assert!(perform_dare_merge(&models, &a).is_ok());
944 }
945
946 #[test]
949 fn test_average_merge_with_weights() {
950 let a = MergeArgs {
951 models: vec![],
952 output: PathBuf::from("o"),
953 method: MergeMethod::Average,
954 weight: None,
955 density: None,
956 weights: Some("0.8,0.2".to_string()),
957 base: None,
958 adapter: None,
959 };
960 let models = vec![mk(&[("w", &[10.0])]), mk(&[("w", &[0.0])])];
961 let result = perform_average_merge(&models, &a).unwrap();
962 let vals = result["w"].data().as_slice().unwrap().to_vec();
963 assert!((vals[0] - 8.0).abs() < 1e-4);
965 }
966
967 #[test]
970 fn test_build_ensemble_config_single_weight() {
971 let a = MergeArgs {
972 models: vec![],
973 output: PathBuf::from("o.json"),
974 method: MergeMethod::Average,
975 weight: None,
976 density: None,
977 weights: Some("1.0".to_string()),
978 base: None,
979 adapter: None,
980 };
981 let config = build_ensemble_config(&a);
982 assert!(config.is_ok());
983 }
984
985 #[test]
986 fn test_build_ensemble_config_three_weights() {
987 let a = MergeArgs {
988 models: vec![],
989 output: PathBuf::from("o.json"),
990 method: MergeMethod::Average,
991 weight: None,
992 density: None,
993 weights: Some("0.2, 0.3, 0.5".to_string()),
994 base: None,
995 adapter: None,
996 };
997 let config = build_ensemble_config(&a);
998 assert!(config.is_ok());
999 }
1000
1001 #[test]
1002 fn test_build_ensemble_config_empty_weights_string() {
1003 let a = MergeArgs {
1004 models: vec![],
1005 output: PathBuf::from("o.json"),
1006 method: MergeMethod::Average,
1007 weight: None,
1008 density: None,
1009 weights: Some(String::new()),
1010 base: None,
1011 adapter: None,
1012 };
1013 assert!(build_ensemble_config(&a).is_err());
1015 }
1016
1017 #[test]
1020 fn test_validate_model_count_one() {
1021 let a = MergeArgs {
1022 models: vec![PathBuf::from("a")],
1023 output: PathBuf::from("o"),
1024 method: MergeMethod::Ties,
1025 weight: None,
1026 density: None,
1027 weights: None,
1028 base: None,
1029 adapter: None,
1030 };
1031 assert!(validate_model_count(&a).is_err());
1032 }
1033
1034 #[test]
1035 fn test_validate_model_count_three() {
1036 let a = MergeArgs {
1037 models: vec![PathBuf::from("a"), PathBuf::from("b"), PathBuf::from("c")],
1038 output: PathBuf::from("o"),
1039 method: MergeMethod::Ties,
1040 weight: None,
1041 density: None,
1042 weights: None,
1043 base: None,
1044 adapter: None,
1045 };
1046 assert!(validate_model_count(&a).is_ok());
1047 }
1048
1049 #[test]
1052 fn test_export_merged_model_no_extension() {
1053 let m = mk(&[("w", &[1.0])]);
1054 let t = std::env::temp_dir().join("ent_merge_noext");
1055 let a = MergeArgs {
1056 models: vec![],
1057 output: t.clone(),
1058 method: MergeMethod::Average,
1059 weight: None,
1060 density: None,
1061 weights: None,
1062 base: None,
1063 adapter: None,
1064 };
1065 assert!(export_merged_model(&m, &a).is_ok());
1067 let _ = std::fs::remove_file(&t);
1068 }
1069
1070 #[test]
1073 fn test_bytes_to_f32_f32_multiple() {
1074 let vals = vec![1.0f32, 2.0, 3.5, -1.0];
1075 let bytes: Vec<u8> = vals.iter().flat_map(|x| x.to_le_bytes()).collect();
1076 let result = bytes_to_f32(&bytes, safetensors::tensor::Dtype::F32);
1077 assert_eq!(result.len(), 4);
1078 assert!((result[0] - 1.0).abs() < 1e-6);
1079 assert!((result[1] - 2.0).abs() < 1e-6);
1080 assert!((result[2] - 3.5).abs() < 1e-6);
1081 assert!((result[3] - (-1.0)).abs() < 1e-6);
1082 }
1083
1084 #[test]
1085 fn test_bytes_to_f32_f16_multiple() {
1086 let vals = vec![half::f16::from_f32(0.5), half::f16::from_f32(1.5)];
1087 let bytes: Vec<u8> = vals.iter().flat_map(|x| x.to_le_bytes()).collect();
1088 let result = bytes_to_f32(&bytes, safetensors::tensor::Dtype::F16);
1089 assert_eq!(result.len(), 2);
1090 assert!((result[0] - 0.5).abs() < 0.01);
1091 assert!((result[1] - 1.5).abs() < 0.01);
1092 }
1093
1094 #[test]
1095 fn test_bytes_to_f32_bf16_multiple() {
1096 let vals = vec![half::bf16::from_f32(3.0), half::bf16::from_f32(-1.0)];
1097 let bytes: Vec<u8> = vals.iter().flat_map(|x| x.to_le_bytes()).collect();
1098 let result = bytes_to_f32(&bytes, safetensors::tensor::Dtype::BF16);
1099 assert_eq!(result.len(), 2);
1100 assert!((result[0] - 3.0).abs() < 0.1);
1101 assert!((result[1] - (-1.0)).abs() < 0.1);
1102 }
1103
1104 #[test]
1107 fn test_safetensor_metadata_ties() {
1108 let m = mk(&[("a", &[1.0]), ("b", &[2.0]), ("c", &[3.0])]);
1109 let a = MergeArgs {
1110 models: vec![],
1111 output: PathBuf::from("o.st"),
1112 method: MergeMethod::Ties,
1113 weight: None,
1114 density: None,
1115 weights: None,
1116 base: None,
1117 adapter: None,
1118 };
1119 let md = build_safetensor_metadata(&m, &a);
1120 assert_eq!(md["name"], "merged-model");
1121 assert_eq!(md["tensor_count"], "3");
1122 assert!(md["merge_method"].contains("Ties"));
1123 }
1124
1125 #[test]
1126 fn test_safetensor_metadata_slerp() {
1127 let m = mk(&[("x", &[1.0])]);
1128 let a = MergeArgs {
1129 models: vec![],
1130 output: PathBuf::from("o.st"),
1131 method: MergeMethod::Slerp,
1132 weight: None,
1133 density: None,
1134 weights: None,
1135 base: None,
1136 adapter: None,
1137 };
1138 let md = build_safetensor_metadata(&m, &a);
1139 assert!(md["merge_method"].contains("Slerp"));
1140 }
1141
1142 #[test]
1145 fn test_log_merge_start_normal() {
1146 let a = MergeArgs {
1147 models: vec![PathBuf::from("m1"), PathBuf::from("m2")],
1148 output: PathBuf::from("out"),
1149 method: MergeMethod::Average,
1150 weight: None,
1151 density: None,
1152 weights: None,
1153 base: None,
1154 adapter: None,
1155 };
1156 log_merge_start(&a, LogLevel::Normal);
1157 }
1158
1159 #[test]
1160 fn test_log_merge_complete_verbose() {
1161 let m = mk(&[("a", &[1.0, 2.0])]);
1162 let a = MergeArgs {
1163 models: vec![],
1164 output: PathBuf::from("merged.json"),
1165 method: MergeMethod::Dare,
1166 weight: None,
1167 density: None,
1168 weights: None,
1169 base: None,
1170 adapter: None,
1171 };
1172 log_merge_complete(&m, &a, LogLevel::Verbose);
1173 }
1174
1175 #[test]
1178 fn test_lora_adapter_config_not_found() {
1179 let dir = tempfile::tempdir().unwrap();
1181 let base_file = dir.path().join("base.safetensors");
1183 std::fs::write(&base_file, b"fake").unwrap();
1184 let a = MergeArgs {
1185 models: vec![],
1186 output: PathBuf::from("o"),
1187 method: MergeMethod::LoraAdapter,
1188 weight: None,
1189 density: None,
1190 weights: None,
1191 base: Some(base_file),
1192 adapter: Some(dir.path().to_path_buf()),
1193 };
1194 let err = run_lora_adapter_merge(&a, LogLevel::Quiet).unwrap_err();
1195 assert!(err.contains("adapter_config.json"), "Error: {err}");
1196 }
1197
1198 #[test]
1199 fn test_lora_adapter_model_not_found() {
1200 let dir = tempfile::tempdir().unwrap();
1201 let base_file = dir.path().join("base.safetensors");
1202 std::fs::write(&base_file, b"fake").unwrap();
1203 std::fs::write(dir.path().join("adapter_config.json"), r#"{"r": 8, "lora_alpha": 16}"#)
1205 .unwrap();
1206 let a = MergeArgs {
1207 models: vec![],
1208 output: PathBuf::from("o"),
1209 method: MergeMethod::LoraAdapter,
1210 weight: None,
1211 density: None,
1212 weights: None,
1213 base: Some(base_file),
1214 adapter: Some(dir.path().to_path_buf()),
1215 };
1216 let err = run_lora_adapter_merge(&a, LogLevel::Quiet).unwrap_err();
1217 assert!(err.contains("adapter_model.safetensors"), "Error: {err}");
1218 }
1219
1220 #[test]
1223 fn test_run_merge_nonexistent_models() {
1224 let a = MergeArgs {
1225 models: vec![PathBuf::from("/no/m1"), PathBuf::from("/no/m2")],
1226 output: PathBuf::from("o"),
1227 method: MergeMethod::Ties,
1228 weight: None,
1229 density: None,
1230 weights: None,
1231 base: None,
1232 adapter: None,
1233 };
1234 let err = run_merge(a, LogLevel::Quiet).unwrap_err();
1235 assert!(err.contains("Failed to read"), "Error: {err}");
1236 }
1237
1238 #[test]
1241 fn test_mk_helper_creates_model() {
1242 let model = mk(&[("a", &[1.0, 2.0, 3.0]), ("b", &[4.0])]);
1243 assert_eq!(model.len(), 2);
1244 assert!(model.contains_key("a"));
1245 assert!(model.contains_key("b"));
1246 assert_eq!(model["a"].len(), 3);
1247 assert_eq!(model["b"].len(), 1);
1248 }
1249
1250 #[test]
1253 fn test_export_safetensors_multiple_tensors() {
1254 let m = mk(&[("w1", &[1.0, 2.0]), ("w2", &[3.0, 4.0, 5.0])]);
1255 let t = std::env::temp_dir().join("ent_merge_multi.safetensors");
1256 let a = MergeArgs {
1257 models: vec![],
1258 output: t.clone(),
1259 method: MergeMethod::Average,
1260 weight: None,
1261 density: None,
1262 weights: None,
1263 base: None,
1264 adapter: None,
1265 };
1266 assert!(export_merged_model(&m, &a).is_ok());
1267 assert!(t.exists());
1269 let _ = std::fs::remove_file(&t);
1270 }
1271
1272 #[test]
1275 fn test_export_json_roundtrip() {
1276 let m = mk(&[("w1", &[1.0, 2.0]), ("w2", &[3.0])]);
1277 let t = std::env::temp_dir().join("ent_merge_roundtrip.json");
1278 let a = MergeArgs {
1279 models: vec![],
1280 output: t.clone(),
1281 method: MergeMethod::Average,
1282 weight: None,
1283 density: None,
1284 weights: None,
1285 base: None,
1286 adapter: None,
1287 };
1288 assert!(export_merged_model(&m, &a).is_ok());
1289 let content = std::fs::read_to_string(&t).unwrap();
1290 let parsed: HashMap<String, Vec<f32>> = serde_json::from_str(&content).unwrap();
1291 assert!(parsed.contains_key("w1"));
1292 assert!(parsed.contains_key("w2"));
1293 assert_eq!(parsed["w1"].len(), 2);
1294 let _ = std::fs::remove_file(&t);
1295 }
1296
1297 fn mk_args(method: MergeMethod) -> MergeArgs {
1303 MergeArgs {
1304 models: vec![],
1305 output: PathBuf::from("out.json"),
1306 method,
1307 weight: None,
1308 density: None,
1309 weights: None,
1310 base: None,
1311 adapter: None,
1312 }
1313 }
1314
1315 #[test]
1318 fn test_cov2_bytes_to_f32_f32_zeros() {
1319 let zeros = vec![0.0f32; 10];
1320 let bytes: Vec<u8> = zeros.iter().flat_map(|x| x.to_le_bytes()).collect();
1321 let result = bytes_to_f32(&bytes, safetensors::tensor::Dtype::F32);
1322 assert_eq!(result.len(), 10);
1323 assert!(result.iter().all(|&v| v == 0.0));
1324 }
1325
1326 #[test]
1327 fn test_cov2_bytes_to_f32_f32_negative() {
1328 let vals = vec![-1.0f32, -100.0, -0.001];
1329 let bytes: Vec<u8> = vals.iter().flat_map(|x| x.to_le_bytes()).collect();
1330 let result = bytes_to_f32(&bytes, safetensors::tensor::Dtype::F32);
1331 assert_eq!(result.len(), 3);
1332 assert!((result[0] - (-1.0)).abs() < 1e-6);
1333 assert!((result[1] - (-100.0)).abs() < 1e-6);
1334 assert!((result[2] - (-0.001)).abs() < 1e-6);
1335 }
1336
1337 #[test]
1338 fn test_cov2_bytes_to_f32_f32_large() {
1339 let vals = vec![1e30f32, -1e30];
1340 let bytes: Vec<u8> = vals.iter().flat_map(|x| x.to_le_bytes()).collect();
1341 let result = bytes_to_f32(&bytes, safetensors::tensor::Dtype::F32);
1342 assert_eq!(result.len(), 2);
1343 assert!((result[0] - 1e30).abs() / 1e30 < 1e-6);
1344 }
1345
1346 #[test]
1347 fn test_cov2_bytes_to_f32_f16_zero() {
1348 let zero = half::f16::from_f32(0.0);
1349 let bytes = zero.to_le_bytes().to_vec();
1350 let result = bytes_to_f32(&bytes, safetensors::tensor::Dtype::F16);
1351 assert_eq!(result.len(), 1);
1352 assert!((result[0]).abs() < 1e-6);
1353 }
1354
1355 #[test]
1356 fn test_cov2_bytes_to_f32_bf16_zero() {
1357 let zero = half::bf16::from_f32(0.0);
1358 let bytes = zero.to_le_bytes().to_vec();
1359 let result = bytes_to_f32(&bytes, safetensors::tensor::Dtype::BF16);
1360 assert_eq!(result.len(), 1);
1361 assert!((result[0]).abs() < 1e-6);
1362 }
1363
1364 #[test]
1365 fn test_cov2_bytes_to_f32_f16_negative() {
1366 let neg = half::f16::from_f32(-3.14);
1367 let bytes = neg.to_le_bytes().to_vec();
1368 let result = bytes_to_f32(&bytes, safetensors::tensor::Dtype::F16);
1369 assert_eq!(result.len(), 1);
1370 assert!((result[0] - (-3.14)).abs() < 0.01);
1371 }
1372
1373 #[test]
1374 fn test_cov2_bytes_to_f32_bf16_negative() {
1375 let neg = half::bf16::from_f32(-5.0);
1376 let bytes = neg.to_le_bytes().to_vec();
1377 let result = bytes_to_f32(&bytes, safetensors::tensor::Dtype::BF16);
1378 assert_eq!(result.len(), 1);
1379 assert!((result[0] - (-5.0)).abs() < 0.5);
1380 }
1381
1382 #[test]
1385 fn test_cov2_bytes_to_f32_f32_truncated() {
1386 let bytes: Vec<u8> = vec![0, 0, 128, 63, 99];
1388 let result = bytes_to_f32(&bytes, safetensors::tensor::Dtype::F32);
1389 assert_eq!(result.len(), 1);
1390 assert!((result[0] - 1.0).abs() < 1e-6);
1391 }
1392
1393 #[test]
1394 fn test_cov2_bytes_to_f32_f16_truncated() {
1395 let val = half::f16::from_f32(2.0);
1397 let mut bytes = val.to_le_bytes().to_vec();
1398 bytes.push(0xFF);
1399 let result = bytes_to_f32(&bytes, safetensors::tensor::Dtype::F16);
1400 assert_eq!(result.len(), 1);
1401 assert!((result[0] - 2.0).abs() < 0.01);
1402 }
1403
1404 #[test]
1407 fn test_cov2_bytes_to_f32_i64_fallback() {
1408 let v = 3.14f32;
1409 let bytes = v.to_le_bytes().to_vec();
1410 let result = bytes_to_f32(&bytes, safetensors::tensor::Dtype::I64);
1411 assert_eq!(result.len(), 1);
1412 assert!((result[0] - 3.14).abs() < 1e-6);
1413 }
1414
1415 #[test]
1416 fn test_cov2_bytes_to_f32_u8_fallback() {
1417 let v = 7.0f32;
1418 let bytes = v.to_le_bytes().to_vec();
1419 let result = bytes_to_f32(&bytes, safetensors::tensor::Dtype::U8);
1420 assert_eq!(result.len(), 1);
1421 assert!((result[0] - 7.0).abs() < 1e-6);
1422 }
1423
1424 #[test]
1427 fn test_cov2_build_ensemble_config_whitespace_weights() {
1428 let a = MergeArgs {
1429 weights: Some(" 0.5 , 0.3 , 0.2 ".to_string()),
1430 ..mk_args(MergeMethod::Average)
1431 };
1432 let config = build_ensemble_config(&a);
1433 assert!(config.is_ok());
1434 }
1435
1436 #[test]
1439 fn test_cov2_build_ensemble_config_negative_weights() {
1440 let a =
1441 MergeArgs { weights: Some("-0.5, 1.5".to_string()), ..mk_args(MergeMethod::Average) };
1442 let config = build_ensemble_config(&a);
1443 assert!(config.is_ok());
1445 }
1446
1447 #[test]
1450 fn test_cov2_build_ensemble_config_many_weights() {
1451 let w_str = (0..10).map(|_| "0.1").collect::<Vec<_>>().join(",");
1452 let a = MergeArgs { weights: Some(w_str), ..mk_args(MergeMethod::Average) };
1453 let config = build_ensemble_config(&a);
1454 assert!(config.is_ok());
1455 }
1456
1457 #[test]
1460 fn test_cov2_safetensor_metadata_average() {
1461 let m = mk(&[("w", &[1.0])]);
1462 let a = mk_args(MergeMethod::Average);
1463 let md = build_safetensor_metadata(&m, &a);
1464 assert!(md["merge_method"].contains("Average"));
1465 assert_eq!(md["tensor_count"], "1");
1466 }
1467
1468 #[test]
1469 fn test_cov2_safetensor_metadata_dare() {
1470 let m = mk(&[("a", &[1.0]), ("b", &[2.0])]);
1471 let a = mk_args(MergeMethod::Dare);
1472 let md = build_safetensor_metadata(&m, &a);
1473 assert!(md["merge_method"].contains("Dare"));
1474 assert_eq!(md["tensor_count"], "2");
1475 }
1476
1477 #[test]
1478 fn test_cov2_safetensor_metadata_lora() {
1479 let m = mk(&[("w", &[1.0])]);
1480 let a = mk_args(MergeMethod::LoraAdapter);
1481 let md = build_safetensor_metadata(&m, &a);
1482 assert!(md["merge_method"].contains("LoraAdapter"));
1483 }
1484
1485 #[test]
1486 fn test_cov2_safetensor_metadata_empty_model() {
1487 let m: Model = HashMap::new();
1488 let a = mk_args(MergeMethod::Ties);
1489 let md = build_safetensor_metadata(&m, &a);
1490 assert_eq!(md["tensor_count"], "0");
1491 }
1492
1493 #[test]
1496 fn test_cov2_validate_model_count_exactly_2() {
1497 let a = MergeArgs {
1498 models: vec![PathBuf::from("a"), PathBuf::from("b")],
1499 ..mk_args(MergeMethod::Average)
1500 };
1501 assert!(validate_model_count(&a).is_ok());
1502 }
1503
1504 #[test]
1505 fn test_cov2_validate_model_count_large() {
1506 let models: Vec<PathBuf> = (0..100).map(|i| PathBuf::from(format!("m{i}"))).collect();
1507 let a = MergeArgs { models, ..mk_args(MergeMethod::Average) };
1508 assert!(validate_model_count(&a).is_ok());
1509 }
1510
1511 #[test]
1514 fn test_cov2_perform_merge_lora_error_msg() {
1515 let a = mk_args(MergeMethod::LoraAdapter);
1516 let err = perform_merge(&[], &a).unwrap_err();
1517 assert_eq!(err, "LoRA adapter merge uses dedicated path");
1518 }
1519
1520 #[test]
1523 fn test_cov2_export_json_bad_path() {
1524 let m = mk(&[("w", &[1.0])]);
1525 let a = MergeArgs {
1526 output: PathBuf::from("/nonexistent_dir_xxxx/output.json"),
1527 ..mk_args(MergeMethod::Average)
1528 };
1529 let result = export_merged_model(&m, &a);
1530 assert!(result.is_err());
1531 assert!(result.unwrap_err().contains("Failed to write"));
1532 }
1533
1534 #[test]
1535 fn test_cov2_export_safetensors_bad_path() {
1536 let m = mk(&[("w", &[1.0])]);
1537 let a = MergeArgs {
1538 output: PathBuf::from("/nonexistent_dir_xxxx/output.safetensors"),
1539 ..mk_args(MergeMethod::Average)
1540 };
1541 let result = export_merged_model(&m, &a);
1542 assert!(result.is_err());
1543 assert!(result.unwrap_err().contains("Failed to write"));
1544 }
1545
1546 #[test]
1549 fn test_cov2_export_safetensors_roundtrip() {
1550 let m = mk(&[("layer1", &[1.0, 2.0, 3.0]), ("layer2", &[4.0, 5.0])]);
1551 let t = std::env::temp_dir().join("ent_merge_cov2_rt.safetensors");
1552 let a = MergeArgs { output: t.clone(), ..mk_args(MergeMethod::Ties) };
1553 assert!(export_merged_model(&m, &a).is_ok());
1554 let data = std::fs::read(&t).unwrap();
1556 let tensors = SafeTensors::deserialize(&data).unwrap();
1557 let names: Vec<&str> = tensors.names().clone();
1558 assert!(names.contains(&"layer1"));
1559 assert!(names.contains(&"layer2"));
1560 let _ = std::fs::remove_file(&t);
1561 }
1562
1563 #[test]
1566 fn test_cov2_export_json_empty_model() {
1567 let m: Model = HashMap::new();
1568 let t = std::env::temp_dir().join("ent_merge_cov2_empty.json");
1569 let a = MergeArgs { output: t.clone(), ..mk_args(MergeMethod::Average) };
1570 assert!(export_merged_model(&m, &a).is_ok());
1571 let content = std::fs::read_to_string(&t).unwrap();
1572 let parsed: HashMap<String, Vec<f32>> = serde_json::from_str(&content).unwrap();
1573 assert!(parsed.is_empty());
1574 let _ = std::fs::remove_file(&t);
1575 }
1576
1577 #[test]
1580 fn test_cov2_export_safetensors_empty_model() {
1581 let m: Model = HashMap::new();
1582 let t = std::env::temp_dir().join("ent_merge_cov2_empty.safetensors");
1583 let a = MergeArgs { output: t.clone(), ..mk_args(MergeMethod::Average) };
1584 assert!(export_merged_model(&m, &a).is_ok());
1585 let _ = std::fs::remove_file(&t);
1586 }
1587
1588 #[test]
1591 fn test_cov2_ties_merge_default_density() {
1592 let a = MergeArgs { density: None, ..mk_args(MergeMethod::Ties) };
1593 let models =
1595 vec![mk(&[("w", &[1.0, 2.0])]), mk(&[("w", &[1.5, 2.5])]), mk(&[("w", &[1.2, 2.2])])];
1596 assert!(perform_ties_merge(&models, &a).is_ok());
1597 }
1598
1599 #[test]
1602 fn test_cov2_dare_merge_default_density() {
1603 let a = MergeArgs { density: None, ..mk_args(MergeMethod::Dare) };
1604 let models = vec![mk(&[("w", &[1.0, 2.0])]), mk(&[("w", &[1.5, 2.5])])];
1606 assert!(perform_dare_merge(&models, &a).is_ok());
1607 }
1608
1609 #[test]
1612 fn test_cov2_slerp_merge_default_weight() {
1613 let a = MergeArgs { weight: None, ..mk_args(MergeMethod::Slerp) };
1614 let models = vec![mk(&[("w", &[1.0, 0.0])]), mk(&[("w", &[0.0, 1.0])])];
1615 let result = perform_slerp_merge(&models, &a);
1616 assert!(result.is_ok());
1617 }
1618
1619 #[test]
1622 fn test_cov2_slerp_single_model() {
1623 let a = mk_args(MergeMethod::Slerp);
1624 let models = vec![mk(&[("w", &[1.0])])];
1625 let err = perform_slerp_merge(&models, &a).unwrap_err();
1626 assert!(err.contains("SLERP requires exactly 2"));
1627 }
1628
1629 #[test]
1632 fn test_cov2_run_merge_lora_missing_both() {
1633 let a = MergeArgs {
1634 method: MergeMethod::LoraAdapter,
1635 base: None,
1636 adapter: None,
1637 ..mk_args(MergeMethod::LoraAdapter)
1638 };
1639 let err = run_merge(a, LogLevel::Quiet).unwrap_err();
1640 assert!(err.contains("--base required"));
1641 }
1642
1643 #[test]
1644 fn test_cov2_run_merge_lora_has_base_no_adapter() {
1645 let a = MergeArgs {
1646 method: MergeMethod::LoraAdapter,
1647 base: Some(PathBuf::from("/tmp/some_base")),
1648 adapter: None,
1649 ..mk_args(MergeMethod::LoraAdapter)
1650 };
1651 let err = run_merge(a, LogLevel::Quiet).unwrap_err();
1652 assert!(err.contains("--adapter required"));
1653 }
1654
1655 #[test]
1658 fn test_cov2_load_single_model_empty_file() {
1659 let dir = tempfile::tempdir().unwrap();
1660 let path = dir.path().join("empty.safetensors");
1661 std::fs::write(&path, b"").unwrap();
1662 let err = load_single_model(&path).unwrap_err();
1663 assert!(err.contains("Failed to parse"));
1664 }
1665
1666 #[test]
1669 fn test_cov2_load_single_model_garbage() {
1670 let dir = tempfile::tempdir().unwrap();
1671 let path = dir.path().join("garbage.safetensors");
1672 std::fs::write(&path, b"this is not a safetensors file at all").unwrap();
1673 let err = load_single_model(&path).unwrap_err();
1674 assert!(err.contains("Failed to parse"));
1675 }
1676
1677 #[test]
1680 fn test_cov2_run_merge_first_model_missing() {
1681 let dir = tempfile::tempdir().unwrap();
1682 let a = MergeArgs {
1683 models: vec![
1684 dir.path().join("no_exist_1.safetensors"),
1685 dir.path().join("no_exist_2.safetensors"),
1686 ],
1687 output: dir.path().join("out.json"),
1688 method: MergeMethod::Average,
1689 ..mk_args(MergeMethod::Average)
1690 };
1691 let err = run_merge(a, LogLevel::Quiet).unwrap_err();
1692 assert!(err.contains("Failed to read"));
1693 }
1694
1695 #[test]
1698 fn test_cov2_log_merge_start_quiet() {
1699 let a = MergeArgs {
1700 models: vec![PathBuf::from("m1"), PathBuf::from("m2"), PathBuf::from("m3")],
1701 output: PathBuf::from("out.safetensors"),
1702 ..mk_args(MergeMethod::Dare)
1703 };
1704 log_merge_start(&a, LogLevel::Quiet);
1705 }
1706
1707 #[test]
1708 fn test_cov2_log_merge_complete_quiet() {
1709 let m = mk(&[("w", &[1.0, 2.0, 3.0])]);
1710 let a = MergeArgs {
1711 output: PathBuf::from("merged.safetensors"),
1712 ..mk_args(MergeMethod::Average)
1713 };
1714 log_merge_complete(&m, &a, LogLevel::Quiet);
1715 }
1716
1717 #[test]
1720 fn test_cov2_average_merge_multi_tensor() {
1721 let a = mk_args(MergeMethod::Average);
1722 let m1 = mk(&[("a", &[1.0, 2.0]), ("b", &[3.0])]);
1723 let m2 = mk(&[("a", &[3.0, 4.0]), ("b", &[5.0])]);
1724 let result = perform_average_merge(&[m1, m2], &a).unwrap();
1725 let a_vals = result["a"].data().as_slice().unwrap().to_vec();
1726 let b_vals = result["b"].data().as_slice().unwrap().to_vec();
1727 assert!((a_vals[0] - 2.0).abs() < 1e-6);
1728 assert!((a_vals[1] - 3.0).abs() < 1e-6);
1729 assert!((b_vals[0] - 4.0).abs() < 1e-6);
1730 }
1731
1732 #[test]
1735 fn test_cov2_slerp_merge_weight_0() {
1736 let a = MergeArgs { weight: Some(0.0), ..mk_args(MergeMethod::Slerp) };
1737 let models = vec![mk(&[("w", &[1.0, 0.0])]), mk(&[("w", &[0.0, 1.0])])];
1738 let result = perform_slerp_merge(&models, &a).unwrap();
1739 let vals = result["w"].data().as_slice().unwrap().to_vec();
1740 assert!((vals[0] - 1.0).abs() < 0.1);
1742 }
1743
1744 #[test]
1745 fn test_cov2_slerp_merge_weight_1() {
1746 let a = MergeArgs { weight: Some(1.0), ..mk_args(MergeMethod::Slerp) };
1747 let models = vec![mk(&[("w", &[1.0, 0.0])]), mk(&[("w", &[0.0, 1.0])])];
1748 let result = perform_slerp_merge(&models, &a).unwrap();
1749 let vals = result["w"].data().as_slice().unwrap().to_vec();
1750 assert!((vals[1] - 1.0).abs() < 0.1);
1752 }
1753
1754 #[test]
1757 fn test_cov2_ties_merge_high_density() {
1758 let a = MergeArgs { density: Some(0.99), ..mk_args(MergeMethod::Ties) };
1759 let models = vec![
1760 mk(&[("w", &[1.0, 2.0, 3.0])]),
1761 mk(&[("w", &[1.1, 2.1, 3.1])]),
1762 mk(&[("w", &[1.2, 2.2, 3.2])]),
1763 ];
1764 assert!(perform_ties_merge(&models, &a).is_ok());
1765 }
1766
1767 #[test]
1770 fn test_cov2_dare_merge_low_density() {
1771 let a = MergeArgs { density: Some(0.01), ..mk_args(MergeMethod::Dare) };
1772 let models = vec![mk(&[("w", &[1.0, 2.0])]), mk(&[("w", &[1.5, 2.5])])];
1773 assert!(perform_dare_merge(&models, &a).is_ok());
1774 }
1775
1776 #[test]
1779 fn test_cov2_lora_adapter_config_exists_no_model() {
1780 let dir = tempfile::tempdir().unwrap();
1781 let base = dir.path().join("base.safetensors");
1782 std::fs::write(&base, b"fake").unwrap();
1783 let adapter_dir = dir.path().join("adapter");
1784 std::fs::create_dir_all(&adapter_dir).unwrap();
1785 std::fs::write(adapter_dir.join("adapter_config.json"), r#"{"r":8}"#).unwrap();
1786 let a = MergeArgs {
1788 base: Some(base),
1789 adapter: Some(adapter_dir),
1790 ..mk_args(MergeMethod::LoraAdapter)
1791 };
1792 let err = run_lora_adapter_merge(&a, LogLevel::Quiet).unwrap_err();
1793 assert!(err.contains("adapter_model.safetensors"));
1794 }
1795
1796 #[test]
1799 fn test_cov2_lora_base_path_not_found() {
1800 let dir = tempfile::tempdir().unwrap();
1801 let adapter_dir = dir.path().join("adapter");
1802 std::fs::create_dir_all(&adapter_dir).unwrap();
1803 let a = MergeArgs {
1804 base: Some(PathBuf::from("/definitely/not/exist/base.st")),
1805 adapter: Some(adapter_dir),
1806 ..mk_args(MergeMethod::LoraAdapter)
1807 };
1808 let err = run_lora_adapter_merge(&a, LogLevel::Quiet).unwrap_err();
1809 assert!(err.contains("Base model not found"));
1810 }
1811
1812 #[test]
1815 fn test_cov2_export_extension_safetensors() {
1816 let m = mk(&[("w", &[1.0])]);
1817 let t = std::env::temp_dir().join("ent_merge_cov2_ext.safetensors");
1818 let a = MergeArgs { output: t.clone(), ..mk_args(MergeMethod::Average) };
1819 assert!(export_merged_model(&m, &a).is_ok());
1820 let data = std::fs::read(&t).unwrap();
1822 assert!(SafeTensors::deserialize(&data).is_ok());
1823 let _ = std::fs::remove_file(&t);
1824 }
1825
1826 #[test]
1827 fn test_cov2_export_extension_json() {
1828 let m = mk(&[("w", &[1.0])]);
1829 let t = std::env::temp_dir().join("ent_merge_cov2_ext.json");
1830 let a = MergeArgs { output: t.clone(), ..mk_args(MergeMethod::Average) };
1831 assert!(export_merged_model(&m, &a).is_ok());
1832 let content = std::fs::read_to_string(&t).unwrap();
1833 assert!(serde_json::from_str::<HashMap<String, Vec<f32>>>(&content).is_ok());
1834 let _ = std::fs::remove_file(&t);
1835 }
1836
1837 #[test]
1838 fn test_cov2_export_extension_unknown() {
1839 let m = mk(&[("w", &[1.0])]);
1840 let t = std::env::temp_dir().join("ent_merge_cov2_ext.bin");
1841 let a = MergeArgs { output: t.clone(), ..mk_args(MergeMethod::Average) };
1842 assert!(export_merged_model(&m, &a).is_ok());
1844 let content = std::fs::read_to_string(&t).unwrap();
1845 assert!(serde_json::from_str::<HashMap<String, Vec<f32>>>(&content).is_ok());
1846 let _ = std::fs::remove_file(&t);
1847 }
1848
1849 #[test]
1852 fn test_cov2_mk_empty_model() {
1853 let model = mk(&[]);
1854 assert!(model.is_empty());
1855 }
1856
1857 #[test]
1858 fn test_cov2_mk_single_empty_tensor() {
1859 let model = mk(&[("empty", &[])]);
1860 assert_eq!(model.len(), 1);
1861 assert_eq!(model["empty"].len(), 0);
1862 }
1863
1864 #[test]
1867 fn test_cov2_perform_merge_all_methods() {
1868 let models3 =
1870 vec![mk(&[("w", &[1.0, 2.0])]), mk(&[("w", &[1.1, 2.1])]), mk(&[("w", &[1.2, 2.2])])];
1871 assert!(perform_merge(&models3, &mk_args(MergeMethod::Ties)).is_ok());
1872
1873 let models2 = vec![mk(&[("w", &[1.0, 2.0])]), mk(&[("w", &[1.5, 2.5])])];
1875 assert!(perform_merge(&models2, &mk_args(MergeMethod::Dare)).is_ok());
1876
1877 let models_s = vec![mk(&[("w", &[1.0, 0.0])]), mk(&[("w", &[0.0, 1.0])])];
1879 assert!(perform_merge(
1880 &models_s,
1881 &MergeArgs { weight: Some(0.5), ..mk_args(MergeMethod::Slerp) }
1882 )
1883 .is_ok());
1884
1885 let models_a = vec![mk(&[("w", &[2.0])]), mk(&[("w", &[4.0])])];
1887 assert!(perform_merge(&models_a, &mk_args(MergeMethod::Average)).is_ok());
1888
1889 assert!(perform_merge(&[], &mk_args(MergeMethod::LoraAdapter)).is_err());
1891 }
1892
1893 #[test]
1896 fn test_cov2_large_tensor_roundtrip() {
1897 let large_data: Vec<f32> = (0..1000).map(|i| i as f32 * 0.001).collect();
1898 let m = mk(&[("big", large_data.as_slice())]);
1899 let t = std::env::temp_dir().join("ent_merge_cov2_large.json");
1900 let a = MergeArgs { output: t.clone(), ..mk_args(MergeMethod::Average) };
1901 assert!(export_merged_model(&m, &a).is_ok());
1902 let content = std::fs::read_to_string(&t).unwrap();
1903 let parsed: HashMap<String, Vec<f32>> = serde_json::from_str(&content).unwrap();
1904 assert_eq!(parsed["big"].len(), 1000);
1905 assert!((parsed["big"][500] - 0.5).abs() < 1e-3);
1906 let _ = std::fs::remove_file(&t);
1907 }
1908}