use std::collections::HashMap;
use crate::error::{AprenderError, Result};
/// Strategy names accepted by `is_valid_strategy`.
///
/// `validate_merge_config` also joins this list with `", "` to show the
/// allowed values in its error messages.
const VALID_STRATEGIES: &[&str] = &[
"average",
"weighted_average",
"slerp",
"ties",
"dare",
"passthrough",
];
/// Per-layer merge settings: a list of rules plus a fallback strategy and
/// fallback weights.
///
/// NOTE(review): nothing in this part of the file constructs or reads this
/// type; presumably the merge routine consumes it — confirm against callers.
#[derive(Debug, Clone)]
pub struct LayerMergeConfig {
/// Rules to try for each tensor.
pub layer_rules: Vec<LayerRule>,
/// Strategy name; presumably used when no rule matches — TODO confirm.
pub default_strategy: String,
/// Weights; presumably used when no rule supplies its own — TODO confirm.
pub default_weights: Vec<f64>,
}
/// One per-layer override: which tensors it applies to and how to merge them.
#[derive(Debug, Clone)]
pub struct LayerRule {
/// Pattern tested against tensor names by `match_layer_rule`. A pattern
/// without `*` matches as a substring; `*` separates fragments that must
/// appear in order.
pub layer_pattern: String,
/// Strategy name. `validate_merge_config` requires it to be one of
/// `VALID_STRATEGIES`; the YAML parser fills in `"average"` when the entry
/// does not specify one.
pub strategy: String,
/// Optional weights parsed from a `weights: [..]` line; validation only
/// requires each value to be finite.
pub weights: Option<Vec<f64>>,
/// Optional scale parsed from a `scale:` line; validation only requires it
/// to be finite.
pub scale: Option<f64>,
}
/// Merge configuration as produced by `parse_merge_yaml`.
#[derive(Debug, Clone)]
pub struct MergeYamlConfig {
/// Entries of the `models:` list, in file order.
pub models: Vec<ModelSource>,
/// Value of the `output:` key; the parser rejects an empty value.
pub output: String,
/// Value of the `default_strategy:` key; the parser rejects an empty value.
pub default_strategy: String,
/// Entries of the `layers:` list, or `None` when no rule was parsed.
pub layers: Option<Vec<LayerRule>>,
}
/// One entry of the `models:` list.
#[derive(Debug, Clone)]
pub struct ModelSource {
/// Value of the entry's `path:` key (surrounding quotes removed).
pub path: String,
/// Value of the entry's optional `weight:` key. `validate_merge_config`
/// requires it to be finite and non-negative when present.
pub weight: Option<f64>,
}
/// Counters describing how tensors were dispatched to layer rules.
///
/// Every call to [`LayerMergeReport::record_tensor`] counts one tensor; calls
/// that carry a pattern additionally bump that pattern's hit counter.
#[derive(Debug, Clone, Default)]
pub struct LayerMergeReport {
    /// Number of tensors recorded so far.
    pub tensors_processed: usize,
    /// Hit count per pattern string passed to `record_tensor`.
    pub rules_matched: HashMap<String, usize>,
}

impl LayerMergeReport {
    /// Creates an empty report (zero tensors, no pattern hits).
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }

    /// Records one tensor.
    ///
    /// `matched_pattern` is the pattern of the rule that applied, or `None`
    /// when no rule applied; only the `Some` case touches `rules_matched`.
    pub fn record_tensor(&mut self, matched_pattern: Option<&str>) {
        self.tensors_processed += 1;
        let Some(hit) = matched_pattern else {
            return;
        };
        let counter = self.rules_matched.entry(hit.to_owned()).or_default();
        *counter += 1;
    }

    /// Sum of all per-pattern hit counts.
    #[must_use]
    pub fn total_matched(&self) -> usize {
        self.rules_matched.values().copied().sum()
    }

    /// Tensors recorded without a pattern: processed minus matched, clamped
    /// at zero in case the public fields were edited inconsistently.
    #[must_use]
    pub fn total_defaulted(&self) -> usize {
        let matched = self.total_matched();
        self.tensors_processed.saturating_sub(matched)
    }
}
/// Returns the first rule in `rules` whose `layer_pattern` matches
/// `tensor_name`, or `None` when no rule matches.
///
/// Rules are tested in slice order, so earlier rules take precedence.
pub fn match_layer_rule<'a>(tensor_name: &str, rules: &'a [LayerRule]) -> Option<&'a LayerRule> {
    for candidate in rules {
        if pattern_matches(tensor_name, &candidate.layer_pattern) {
            return Some(candidate);
        }
    }
    None
}
/// Tests whether `name` matches `pattern`.
///
/// * `\.` in the pattern is treated as a plain `.`.
/// * Without `*`, the pattern matches when it occurs anywhere in `name`.
/// * With `*`, the non-empty fragments between the stars must all occur in
///   `name`, left to right and without overlapping. The match is not anchored
///   to the start or end of `name`.
fn pattern_matches(name: &str, pattern: &str) -> bool {
    let literal = pattern.replace("\\.", ".");
    if !literal.contains('*') {
        return name.contains(literal.as_str());
    }

    // Walk the fragments, each time continuing the search just past the
    // previous hit; a missing fragment aborts the fold with `None`.
    literal
        .split('*')
        .filter(|fragment| !fragment.is_empty())
        .try_fold(name, |remaining, fragment| {
            remaining
                .find(fragment)
                .map(|at| &remaining[at + fragment.len()..])
        })
        .is_some()
}
/// Where `parse_merge_yaml` currently is inside the document.
#[derive(Debug, PartialEq)]
enum ParserSection {
/// Top level; only `output` and `default_strategy` are read here.
Root,
/// After the `models:` key, before the first `-` list item.
Models,
/// Inside a `-` list item under `models:`.
ModelEntry,
/// After the `layers:` key, before the first `-` list item.
Layers,
/// Inside a `-` list item under `layers:`.
LayerEntry,
}
pub fn parse_merge_yaml(yaml_str: &str) -> Result<MergeYamlConfig> {
let mut models: Vec<ModelSource> = Vec::new();
let mut output = String::new();
let mut default_strategy = String::new();
let mut layer_rules: Vec<LayerRule> = Vec::new();
let mut section = ParserSection::Root;
let mut current_model_path = String::new();
let mut current_model_weight: Option<f64> = None;
let mut current_layer_pattern = String::new();
let mut current_layer_strategy = String::new();
let mut current_layer_weights: Option<Vec<f64>> = None;
let mut current_layer_scale: Option<f64> = None;
for line in yaml_str.lines() {
let trimmed = line.trim();
if trimmed.is_empty() || trimmed.starts_with('#') {
continue;
}
if !line.starts_with(' ') && !line.starts_with('\t') && !trimmed.starts_with('-') {
flush_model_entry(
§ion,
&mut models,
&mut current_model_path,
&mut current_model_weight,
);
flush_layer_entry(
§ion,
&mut layer_rules,
&mut current_layer_pattern,
&mut current_layer_strategy,
&mut current_layer_weights,
&mut current_layer_scale,
);
if trimmed.ends_with(':') && !trimmed.contains(": ") {
let key = trimmed.trim_end_matches(':');
match key {
"models" => {
section = ParserSection::Models;
continue;
}
"layers" => {
section = ParserSection::Layers;
continue;
}
_ => {}
}
}
if let Some(val) = parse_kv(trimmed, "output") {
output = unquote(val);
section = ParserSection::Root;
} else if let Some(val) = parse_kv(trimmed, "default_strategy") {
default_strategy = unquote(val);
section = ParserSection::Root;
}
continue;
}
match section {
ParserSection::Models | ParserSection::ModelEntry => {
if trimmed.starts_with("- ") || trimmed == "-" {
flush_model_entry(
§ion,
&mut models,
&mut current_model_path,
&mut current_model_weight,
);
section = ParserSection::ModelEntry;
let after_dash = trimmed.trim_start_matches('-').trim();
if let Some(val) = parse_kv(after_dash, "path") {
current_model_path = unquote(val);
}
} else if let Some(val) = parse_kv(trimmed, "path") {
current_model_path = unquote(val);
} else if let Some(val) = parse_kv(trimmed, "weight") {
current_model_weight = val.trim().parse::<f64>().ok();
}
}
ParserSection::Layers | ParserSection::LayerEntry => {
if trimmed.starts_with("- ") || trimmed == "-" {
flush_layer_entry(
§ion,
&mut layer_rules,
&mut current_layer_pattern,
&mut current_layer_strategy,
&mut current_layer_weights,
&mut current_layer_scale,
);
section = ParserSection::LayerEntry;
let after_dash = trimmed.trim_start_matches('-').trim();
if let Some(val) = parse_kv(after_dash, "layer_pattern") {
current_layer_pattern = unquote(val);
}
} else if let Some(val) = parse_kv(trimmed, "layer_pattern") {
current_layer_pattern = unquote(val);
} else if let Some(val) = parse_kv(trimmed, "strategy") {
current_layer_strategy = unquote(val);
} else if let Some(val) = parse_kv(trimmed, "weights") {
current_layer_weights = Some(parse_float_list(val));
} else if let Some(val) = parse_kv(trimmed, "scale") {
current_layer_scale = val.trim().parse::<f64>().ok();
}
}
ParserSection::Root => {
if let Some(val) = parse_kv(trimmed, "output") {
output = unquote(val);
} else if let Some(val) = parse_kv(trimmed, "default_strategy") {
default_strategy = unquote(val);
}
}
}
}
flush_model_entry(
§ion,
&mut models,
&mut current_model_path,
&mut current_model_weight,
);
flush_layer_entry(
§ion,
&mut layer_rules,
&mut current_layer_pattern,
&mut current_layer_strategy,
&mut current_layer_weights,
&mut current_layer_scale,
);
if output.is_empty() {
return Err(AprenderError::ValidationError {
message: "merge config missing required field: output".to_string(),
});
}
if default_strategy.is_empty() {
return Err(AprenderError::ValidationError {
message: "merge config missing required field: default_strategy".to_string(),
});
}
let layers = if layer_rules.is_empty() {
None
} else {
Some(layer_rules)
};
Ok(MergeYamlConfig {
models,
output,
default_strategy,
layers,
})
}
/// Checks a parsed merge configuration for semantic problems.
///
/// Checks run in this order and the first failure is returned:
/// 1. at least two models;
/// 2. `default_strategy` is a known strategy;
/// 3. per model: non-empty path, then weight (if any) finite and non-negative;
/// 4. non-empty `output`;
/// 5. per layer rule: non-empty pattern, known strategy, finite weights,
///    finite scale.
///
/// # Errors
///
/// Every failure is an `AprenderError::ValidationError` whose message names
/// the offending item.
pub fn validate_merge_config(config: &MergeYamlConfig) -> Result<()> {
    /// Wraps a message in the single error variant this validator uses.
    fn reject(message: String) -> Result<()> {
        Err(AprenderError::ValidationError { message })
    }

    let model_count = config.models.len();
    if model_count < 2 {
        return reject(format!(
            "merge requires at least 2 models, got {}",
            model_count
        ));
    }

    if !is_valid_strategy(&config.default_strategy) {
        return reject(format!(
            "unknown default strategy '{}', valid: {}",
            config.default_strategy,
            VALID_STRATEGIES.join(", ")
        ));
    }

    for (idx, model) in config.models.iter().enumerate() {
        if model.path.is_empty() {
            return reject(format!("model {} has empty path", idx));
        }
        match model.weight {
            Some(w) if !w.is_finite() || w < 0.0 => {
                return reject(format!(
                    "model {} weight must be non-negative and finite, got {}",
                    idx, w
                ));
            }
            _ => {}
        }
    }

    if config.output.is_empty() {
        return reject("output path is empty".to_string());
    }

    // `Option<Vec<_>>::iter().flatten()` yields nothing when there are no rules.
    for (idx, rule) in config.layers.iter().flatten().enumerate() {
        if rule.layer_pattern.is_empty() {
            return reject(format!("layer rule {} has empty pattern", idx));
        }
        if !is_valid_strategy(&rule.strategy) {
            return reject(format!(
                "layer rule {} has unknown strategy '{}', valid: {}",
                idx,
                rule.strategy,
                VALID_STRATEGIES.join(", ")
            ));
        }

        let first_bad_weight = rule
            .weights
            .iter()
            .flatten()
            .enumerate()
            .find(|(_, w)| !w.is_finite());
        if let Some((pos, bad)) = first_bad_weight {
            return reject(format!(
                "layer rule {} weight[{}] is not finite: {}",
                idx, pos, bad
            ));
        }

        if let Some(bad) = rule.scale.filter(|s| !s.is_finite()) {
            return reject(format!("layer rule {} scale is not finite: {}", idx, bad));
        }
    }

    Ok(())
}
/// Returns `true` when `name` is exactly one of the entries of
/// `VALID_STRATEGIES` (case-sensitive comparison).
fn is_valid_strategy(name: &str) -> bool {
    VALID_STRATEGIES.iter().any(|known| *known == name)
}
/// Extracts the value of a `key: value` line.
///
/// Surrounding whitespace on the line and around the value is ignored, and
/// the space after the colon is optional (`key:value` is accepted).
///
/// Returns `None` when the line does not start with `key` immediately
/// followed by `:`, or when nothing follows the colon. The returned slice
/// borrows from `line` and is never empty.
fn parse_kv<'a>(line: &'a str, key: &str) -> Option<&'a str> {
    // `strip_prefix` performs the match without building temporary strings.
    let value = line.trim().strip_prefix(key)?.strip_prefix(':')?.trim();
    if value.is_empty() {
        None
    } else {
        Some(value)
    }
}
/// Trims `s` and removes one pair of matching surrounding quotes (`"…"` or
/// `'…'`) if present.
///
/// A lone quote character or mismatched quotes are returned unchanged (after
/// trimming). No escape sequences are interpreted.
fn unquote(s: &str) -> String {
    let body = s.trim();
    ['"', '\'']
        .iter()
        // `strip_suffix` fails on the empty remainder of a lone quote, so a
        // single `"` is left as is.
        .find_map(|&quote| body.strip_prefix(quote)?.strip_suffix(quote))
        .unwrap_or(body)
        .to_string()
}
/// Parses an inline list such as `[0.5, 0.25, 1]` into floats.
///
/// Leading `[` and trailing `]` characters are optional and stripped; items
/// are separated by commas. Items that do not parse as `f64` (including empty
/// ones) are skipped rather than reported.
fn parse_float_list(s: &str) -> Vec<f64> {
    let body = s.trim();
    let body = body.trim_start_matches('[');
    let body = body.trim_end_matches(']');

    let mut values = Vec::new();
    // An empty body yields a single empty token, which fails to parse and is
    // skipped, so no separate emptiness check is needed.
    for token in body.trim().split(',') {
        if let Ok(number) = token.trim().parse::<f64>() {
            values.push(number);
        }
    }
    values
}
/// Finishes the `models:` entry currently being read.
///
/// Does nothing unless `section` is `ParserSection::ModelEntry`. Otherwise the
/// pending `path` and `weight` are always cleared, and a `ModelSource` is
/// appended to `models` only when the path is non-empty.
///
/// Clearing even when the entry is dropped matters: previously a `weight`
/// given on an entry without a `path` stayed pending and was attached to the
/// next entry.
fn flush_model_entry(
    section: &ParserSection,
    models: &mut Vec<ModelSource>,
    path: &mut String,
    weight: &mut Option<f64>,
) {
    if !matches!(section, ParserSection::ModelEntry) {
        return;
    }

    let entry_path = std::mem::take(path);
    let entry_weight = weight.take();
    if !entry_path.is_empty() {
        models.push(ModelSource {
            path: entry_path,
            weight: entry_weight,
        });
    }
}
/// Finishes the `layers:` entry currently being read.
///
/// Does nothing unless `section` is `ParserSection::LayerEntry`. Otherwise all
/// pending fields are always cleared, and a `LayerRule` is appended to `rules`
/// only when the pattern is non-empty. An entry without a `strategy` gets
/// `"average"`.
///
/// Clearing even when the entry is dropped matters: previously the
/// `strategy` / `weights` / `scale` of an entry without a `layer_pattern`
/// stayed pending and were attached to the next entry.
fn flush_layer_entry(
    section: &ParserSection,
    rules: &mut Vec<LayerRule>,
    pattern: &mut String,
    strategy: &mut String,
    weights: &mut Option<Vec<f64>>,
    scale: &mut Option<f64>,
) {
    if !matches!(section, ParserSection::LayerEntry) {
        return;
    }

    let entry_pattern = std::mem::take(pattern);
    let entry_strategy = std::mem::take(strategy);
    let entry_weights = weights.take();
    let entry_scale = scale.take();
    if entry_pattern.is_empty() {
        return;
    }

    rules.push(LayerRule {
        layer_pattern: entry_pattern,
        strategy: if entry_strategy.is_empty() {
            "average".to_string()
        } else {
            entry_strategy
        },
        weights: entry_weights,
        scale: entry_scale,
    });
}
// Unit tests are compiled only for test builds and live in the sibling file
// named by the `#[path]` attribute below.
#[cfg(test)]
#[path = "per_layer_merge_tests.rs"]
mod tests;