/// Prints a dry-run summary of a prune operation.
///
/// Only the metadata of `file` is read here. The estimated output size is
/// `input_size * (1 - target_ratio)` and the reported peak memory is the input
/// size plus that estimate.
///
/// * `file` - model file whose size is inspected.
/// * `method` - pruning method, shown through its `Debug` representation.
/// * `target_ratio` - fraction subtracted from the input size for the estimate.
/// * `sparsity` - reported in the JSON output only.
/// * `json_output` - print a pretty-printed JSON object instead of the table.
///
/// # Errors
///
/// Returns [`CliError::ValidationFailed`] when the file metadata cannot be read.
#[allow(clippy::disallowed_methods)]
fn run_plan(
    file: &Path,
    method: PruneMethod,
    target_ratio: f32,
    sparsity: f32,
    json_output: bool,
) -> Result<()> {
    let metadata = std::fs::metadata(file)
        .map_err(|e| CliError::ValidationFailed(format!("Cannot read model: {e}")))?;
    let file_size = metadata.len();

    let kept_fraction = 1.0 - f64::from(target_ratio);
    let estimated_output = (file_size as f64 * kept_fraction) as u64;
    let peak_memory = file_size + estimated_output;

    let input_label = file.display().to_string();
    let method_label = format!("{method:?}");

    if json_output {
        let plan = serde_json::json!({
            "plan": true,
            "input": input_label,
            "input_size": file_size,
            "method": method_label,
            "target_ratio": target_ratio,
            "sparsity": sparsity,
            "estimated_output_size": estimated_output,
            "peak_memory": peak_memory,
        });
        // A serialization failure degrades to an empty line rather than an error.
        let rendered = serde_json::to_string_pretty(&plan).unwrap_or_default();
        println!("{rendered}");
        return Ok(());
    }

    let human = |bytes: u64| humansize::format_size(bytes, humansize::BINARY);

    output::header("APR Prune — Plan");
    let rows = [
        ("Input", input_label),
        ("Input size", human(file_size)),
        ("Method", method_label),
        ("Target ratio", format!("{target_ratio:.2}")),
        ("Est. output", human(estimated_output)),
        ("Peak memory", human(peak_memory)),
    ];
    println!("{}", output::kv_table(&rows));
    println!();
    println!(
        "  {} Run without --plan to execute.",
        output::badge_info("INFO"),
    );
    Ok(())
}
/// Zeroes the smallest-magnitude values of every tensor.
///
/// For each tensor the threshold is the magnitude found at index
/// `floor(len * sparsity)` (clamped to the last index) when the absolute
/// values are ordered ascending. Every element whose magnitude is strictly
/// below that threshold becomes `0.0`; all other elements are copied as-is.
/// Shapes are copied unchanged and the input map is not modified.
///
/// Empty tensors are passed through untouched. Magnitudes are ordered with
/// [`f32::total_cmp`], so NaN values rank above every finite magnitude and are
/// never zeroed.
fn prune_magnitude(
    tensors: &std::collections::BTreeMap<String, (Vec<f32>, Vec<usize>)>,
    sparsity: f32,
) -> std::collections::BTreeMap<String, (Vec<f32>, Vec<usize>)> {
    let mut result = std::collections::BTreeMap::new();
    for (name, (data, shape)) in tensors {
        // Nothing to rank in an empty tensor; indexing below would panic.
        if data.is_empty() {
            result.insert(name.clone(), (Vec::new(), shape.clone()));
            continue;
        }

        let mut magnitudes: Vec<f32> = data.iter().map(|v| v.abs()).collect();
        let cutoff_idx = ((magnitudes.len() as f64 * f64::from(sparsity)) as usize)
            .min(magnitudes.len() - 1);

        // Only one order statistic is needed, so a selection (O(n)) replaces a
        // full sort. `total_cmp` is a total order even when NaN is present.
        let (_, nth, _) = magnitudes.select_nth_unstable_by(cutoff_idx, f32::total_cmp);
        let threshold = *nth;

        let pruned: Vec<f32> = data
            .iter()
            .map(|&v| if v.abs() < threshold { 0.0 } else { v })
            .collect();
        result.insert(name.clone(), (pruned, shape.clone()));
    }
    result
}
/// Drops every tensor that belongs to one of the layers selected by `layer_spec`.
///
/// `layer_spec` is a comma-separated list whose items are either a single
/// layer index (`"3"`) or an inclusive range (`"2-5"`). Items may be mixed
/// (`"0,2-4,7"`) and whitespace around items and range endpoints is ignored.
///
/// A tensor is removed when its name contains `layers.<i>.`, `blk.<i>.` or
/// `h.<i>.` for any selected index `i`. All remaining tensors are cloned into
/// the returned map; the input map is not modified.
///
/// # Errors
///
/// Returns [`CliError::ValidationFailed`] when an item is not a valid index,
/// when a range does not have exactly two endpoints, or when a range is
/// reversed (start greater than end).
#[allow(clippy::type_complexity)]
fn prune_depth(
    tensors: &std::collections::BTreeMap<String, (Vec<f32>, Vec<usize>)>,
    layer_spec: &str,
) -> Result<std::collections::BTreeMap<String, (Vec<f32>, Vec<usize>)>> {
    let parse_index = |text: &str| -> Result<usize> {
        text.trim()
            .parse::<usize>()
            .map_err(|_| CliError::ValidationFailed(format!("Invalid layer number: {text}")))
    };

    let mut layers_to_remove: Vec<usize> = Vec::new();
    for item in layer_spec.split(',') {
        let item = item.trim();
        if !item.contains('-') {
            layers_to_remove.push(parse_index(item)?);
            continue;
        }

        let bounds: Vec<&str> = item.split('-').collect();
        let (start, end) = match bounds.as_slice() {
            &[start, end] => (parse_index(start)?, parse_index(end)?),
            _ => {
                return Err(CliError::ValidationFailed(format!(
                    "Invalid layer range: {item}"
                )));
            }
        };
        // A reversed range would expand to nothing and silently prune no
        // layers, so it is rejected instead.
        if start > end {
            return Err(CliError::ValidationFailed(format!(
                "Invalid layer range: {item} (start must not exceed end)"
            )));
        }
        layers_to_remove.extend(start..=end);
    }

    // Build the name fragments once rather than once per tensor.
    let mut patterns: Vec<String> = Vec::with_capacity(layers_to_remove.len() * 3);
    for layer_idx in &layers_to_remove {
        patterns.push(format!("layers.{layer_idx}."));
        patterns.push(format!("blk.{layer_idx}."));
        patterns.push(format!("h.{layer_idx}."));
    }

    let kept: std::collections::BTreeMap<_, _> = tensors
        .iter()
        .filter(|(name, _)| !patterns.iter().any(|p| name.contains(p.as_str())))
        .map(|(name, (data, shape))| (name.clone(), (data.clone(), shape.clone())))
        .collect();
    Ok(kept)
}
/// Formats a parameter count as a short human-readable string.
///
/// Counts that display as at least one billion get a `B` suffix and counts of
/// at least one million get an `M` suffix, both with one decimal place.
/// Anything smaller is printed as the exact integer.
fn format_params(params: u64) -> String {
    const MILLION: u64 = 1_000_000;
    // Values from here up to one billion would round to "1000.0M" at one
    // decimal place, so they are shown in the billions tier as "1.0B".
    const BILLION_DISPLAY_FLOOR: u64 = 999_950_000;

    if params >= BILLION_DISPLAY_FLOOR {
        format!("{:.1}B", params as f64 / 1_000_000_000.0)
    } else if params >= MILLION {
        format!("{:.1}M", params as f64 / 1_000_000.0)
    } else {
        params.to_string()
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::io::Write;
    use tempfile::NamedTempFile;

    /// Output path passed by the tests that expect `run` to return an error.
    const ERROR_CASE_OUTPUT: &str = "/tmp/out.apr";

    /// Calls `run`, filling in the 6th and 9th arguments, which every test
    /// leaves as `None`. The last three parameters are named after the tests
    /// that set them (analyze / plan / JSON) - confirm against `run`'s signature.
    macro_rules! run_prune {
        (
            $input:expr,
            $method:expr,
            $target_ratio:expr,
            $sparsity:expr,
            $output:expr,
            $analyze:expr,
            $plan:expr,
            $json:expr $(,)?
        ) => {
            run(
                $input,
                $method,
                $target_ratio,
                $sparsity,
                $output,
                None,
                $analyze,
                $plan,
                None,
                $json,
            )
        };
    }

    /// Asserts that the result is `Err(CliError::ValidationFailed(_))` and that
    /// its message contains `$needle`.
    macro_rules! assert_validation_failed {
        ($result:expr, $needle:expr) => {{
            let outcome = $result;
            assert!(outcome.is_err());
            match outcome {
                Err(CliError::ValidationFailed(msg)) => assert!(msg.contains($needle)),
                _ => panic!("Expected ValidationFailed"),
            }
        }};
    }

    /// Asserts that `$text` parses into the given `PruneMethod` variant.
    macro_rules! assert_parses_to {
        ($text:literal => $variant:ident) => {
            assert!(matches!(
                $text.parse::<PruneMethod>(),
                Ok(PruneMethod::$variant)
            ));
        };
    }

    /// Creates a temporary `.apr` file containing `len` zero bytes.
    fn zeroed_apr(len: usize) -> NamedTempFile {
        let mut file = NamedTempFile::with_suffix(".apr").expect("create input");
        if len > 0 {
            file.write_all(&vec![0u8; len]).expect("write");
        }
        file
    }

    #[test]
    fn test_prune_method_parse() {
        assert_parses_to!("magnitude" => Magnitude);
        assert_parses_to!("mag" => Magnitude);
        assert_parses_to!("structured" => Structured);
        assert_parses_to!("depth" => Depth);
        assert_parses_to!("width" => Width);
        assert_parses_to!("wanda" => Wanda);
        assert_parses_to!("sparsegpt" => SparseGpt);
        assert!("unknown".parse::<PruneMethod>().is_err());
    }

    #[test]
    fn test_run_file_not_found() {
        let outcome = run_prune!(
            Path::new("/nonexistent.apr"),
            "magnitude",
            0.5,
            0.0,
            Some(Path::new(ERROR_CASE_OUTPUT)),
            false,
            false,
            false,
        );
        assert!(outcome.is_err());
        assert!(matches!(outcome, Err(CliError::FileNotFound(_))));
    }

    #[test]
    fn test_run_invalid_target_ratio_zero() {
        let model = zeroed_apr(0);
        let outcome = run_prune!(
            model.path(),
            "magnitude",
            0.0,
            0.0,
            Some(Path::new(ERROR_CASE_OUTPUT)),
            false,
            false,
            false,
        );
        assert_validation_failed!(outcome, "Target ratio");
    }

    #[test]
    fn test_run_invalid_target_ratio_one() {
        let model = zeroed_apr(0);
        let outcome = run_prune!(
            model.path(),
            "magnitude",
            1.0,
            0.0,
            Some(Path::new(ERROR_CASE_OUTPUT)),
            false,
            false,
            false,
        );
        assert!(outcome.is_err());
    }

    #[test]
    fn test_run_invalid_sparsity() {
        let model = zeroed_apr(0);
        let outcome = run_prune!(
            model.path(),
            "magnitude",
            0.5,
            1.5,
            Some(Path::new(ERROR_CASE_OUTPUT)),
            false,
            false,
            false,
        );
        assert_validation_failed!(outcome, "Sparsity");
    }

    #[test]
    fn test_run_depth_requires_layers() {
        let model = zeroed_apr(512);
        let outcome = run_prune!(
            model.path(),
            "depth",
            0.5,
            0.0,
            Some(Path::new(ERROR_CASE_OUTPUT)),
            false,
            false,
            false,
        );
        assert_validation_failed!(outcome, "remove-layers");
    }

    #[test]
    fn test_run_no_output() {
        let model = zeroed_apr(512);
        let outcome = run_prune!(
            model.path(),
            "magnitude",
            0.5,
            0.0,
            None,
            false,
            false,
            false,
        );
        assert_validation_failed!(outcome, "Output path");
    }

    #[test]
    fn test_analyze_mode() {
        let model = zeroed_apr(1024);
        let outcome = run_prune!(
            model.path(),
            "magnitude",
            0.5,
            0.0,
            None,
            true,
            false,
            false,
        );
        assert!(outcome.is_ok());
    }

    #[test]
    fn test_analyze_json() {
        let model = zeroed_apr(1024);
        let outcome = run_prune!(
            model.path(),
            "magnitude",
            0.5,
            0.0,
            None,
            true,
            false,
            true,
        );
        assert!(outcome.is_ok());
    }

    #[test]
    fn test_plan_mode() {
        let model = zeroed_apr(2048);
        let outcome = run_prune!(
            model.path(),
            "structured",
            0.3,
            0.0,
            None,
            false,
            true,
            false,
        );
        assert!(outcome.is_ok());
    }

    #[test]
    fn test_plan_json() {
        let model = zeroed_apr(2048);
        let outcome = run_prune!(
            model.path(),
            "magnitude",
            0.5,
            0.2,
            None,
            false,
            true,
            true,
        );
        assert!(outcome.is_ok());
    }

    #[test]
    fn test_run_with_valid_input() {
        // Build a small two-tensor APR model in memory.
        let mut apr = aprender::serialization::apr::AprWriter::new();
        apr.set_metadata("model_type", serde_json::json!("test"));
        let q_weight: Vec<f32> = (0..64).map(|i| (i as f32) * 0.1).collect();
        apr.add_tensor_f32("layers.0.self_attn.q_proj.weight", vec![8, 8], &q_weight);
        let q_bias: Vec<f32> = vec![0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8];
        apr.add_tensor_f32("layers.0.self_attn.q_proj.bias", vec![8], &q_bias);
        let serialized = apr.to_bytes().expect("serialize");

        let model = NamedTempFile::with_suffix(".apr").expect("create input");
        std::fs::write(model.path(), &serialized).expect("write apr");
        let pruned = NamedTempFile::with_suffix(".apr").expect("create output");

        let outcome = run_prune!(
            model.path(),
            "magnitude",
            0.5,
            0.0,
            Some(pruned.path()),
            false,
            false,
            false,
        );
        assert!(outcome.is_ok(), "prune failed: {:?}", outcome.err());

        let written = std::fs::metadata(pruned.path()).expect("output exists");
        assert!(written.len() > 0, "Output file should not be empty");
    }
}