use aprender::online::per_layer_merge::{
match_layer_rule, parse_merge_yaml, validate_merge_config, LayerMergeConfig, LayerRule,
};
fn main() {
println!("=== Per-Layer Merge Granularity (GH-452) ===\n");
println!("── 1. Layer Rule Matching ──");
let rules = vec![
LayerRule {
layer_pattern: "self_attn".to_string(),
strategy: "slerp".to_string(),
weights: Some(vec![0.7, 0.3]),
scale: None,
},
LayerRule {
layer_pattern: "mlp".to_string(),
strategy: "average".to_string(),
weights: None,
scale: None,
},
LayerRule {
layer_pattern: "embed".to_string(),
strategy: "passthrough".to_string(),
weights: None,
scale: Some(1.0),
},
];
let tensors = [
"model.layers.0.self_attn.q_proj.weight",
"model.layers.0.mlp.gate_proj.weight",
"model.embed_tokens.weight",
"model.layers.5.self_attn.v_proj.weight",
"model.norm.weight",
];
for tensor in &tensors {
match match_layer_rule(tensor, &rules) {
Some(rule) => println!(" {} -> {} ({})", tensor, rule.strategy, rule.layer_pattern),
None => println!(" {} -> default", tensor),
}
}
println!("\n── 2. YAML Parsing ──");
let yaml = r#"
models:
- path: llama-7b-base.apr
weight: 0.6
- path: llama-7b-chat.apr
weight: 0.4
output: merged-llama.apr
default_strategy: average
layers:
- layer_pattern: self_attn
strategy: slerp
weights: [0.7, 0.3]
- layer_pattern: mlp
strategy: ties
weights: [0.5, 0.5]
scale: 0.8
"#;
let config = parse_merge_yaml(yaml).expect("parse_merge_yaml");
println!(" Models: {}", config.models.len());
for m in &config.models {
println!(" {} (weight: {:?})", m.path, m.weight);
}
println!(" Output: {}", config.output);
println!(" Default strategy: {}", config.default_strategy);
if let Some(ref layers) = config.layers {
println!(" Layer rules: {}", layers.len());
for r in layers {
println!(
" pattern='{}' strategy='{}'",
r.layer_pattern, r.strategy
);
}
}
println!("\n── 3. Config Validation ──");
match validate_merge_config(&config) {
Ok(()) => println!(" Config is valid"),
Err(e) => println!(" Validation error: {}", e),
}
println!("\n── 4. Invalid Config Rejection ──");
let bad_yaml = "models:\n - path: only-one.apr\noutput: out.apr\ndefault_strategy: average\n";
let bad_config = parse_merge_yaml(bad_yaml).expect("parse_merge_yaml");
match validate_merge_config(&bad_config) {
Ok(()) => println!(" Unexpected: accepted"),
Err(e) => println!(" Correctly rejected: {}", e),
}
println!("\n── 5. LayerMergeConfig ──");
let layer_cfg = LayerMergeConfig {
layer_rules: rules,
default_strategy: "average".to_string(),
default_weights: vec![0.5, 0.5],
};
println!(" Rules: {}", layer_cfg.layer_rules.len());
println!(" Default: {}", layer_cfg.default_strategy);
println!(" Default weights: {:?}", layer_cfg.default_weights);
println!("\n=== Per-layer merge verified ===");
}