use crate::error::{CliError, Result};
use crate::output;
use aprender::format::v2::{AprV2Header, AprV2Metadata, HEADER_SIZE_V2};
use std::io::{Read as _, Seek, SeekFrom};
use std::path::Path;
/// Pruning strategy selected on the command line.
///
/// Parsed from its string spelling via the `FromStr` impl below; see that
/// impl for the accepted aliases of each variant.
#[derive(Debug, Clone, Copy, Default)]
pub enum PruneMethod {
    /// Zero out the smallest-magnitude weights (the default method).
    #[default]
    Magnitude,
    /// Structured pruning (currently dispatched to magnitude pruning —
    /// see `apply_pruning`).
    Structured,
    /// Remove whole layers; requires `--remove-layers` (see `validate_depth_args`).
    Depth,
    /// Width/hidden-dimension pruning (currently dispatched to magnitude pruning).
    Width,
    /// Wanda pruning (currently dispatched to magnitude pruning).
    Wanda,
    /// SparseGPT pruning (currently dispatched to magnitude pruning).
    SparseGpt,
}
impl std::str::FromStr for PruneMethod {
    type Err = String;

    /// Parses a case-insensitive method name, accepting short aliases
    /// (e.g. `mag`, `struct`, `layer`, `hidden`, `sparse_gpt`).
    fn from_str(s: &str) -> std::result::Result<Self, Self::Err> {
        let normalized = s.to_lowercase();
        let method = match normalized.as_str() {
            "magnitude" | "mag" => Self::Magnitude,
            "structured" | "struct" => Self::Structured,
            "depth" | "layer" => Self::Depth,
            "width" | "hidden" => Self::Width,
            "wanda" => Self::Wanda,
            "sparsegpt" | "sparse_gpt" => Self::SparseGpt,
            _ => {
                return Err(format!(
                    "Unknown pruning method: {s}. Supported: magnitude, structured, depth, width, wanda, sparsegpt"
                ));
            }
        };
        Ok(method)
    }
}
/// Checks the shared preconditions for any prune invocation.
///
/// Verifies that `file` exists, parses `method` into a [`PruneMethod`],
/// and range-checks `target_ratio` (exclusive 0..1) and `sparsity`
/// (inclusive 0..=1). Returns the parsed method on success.
fn validate_prune_params(
    file: &Path,
    method: &str,
    target_ratio: f32,
    sparsity: f32,
) -> Result<PruneMethod> {
    if !file.exists() {
        return Err(CliError::FileNotFound(file.to_path_buf()));
    }
    let parsed: PruneMethod = method.parse().map_err(CliError::ValidationFailed)?;
    // Note: both checks reject NaN, matching the original comparison forms.
    if target_ratio <= 0.0 || target_ratio >= 1.0 {
        return Err(CliError::ValidationFailed(format!(
            "Target ratio must be between 0 and 1 (exclusive), got {target_ratio}"
        )));
    }
    let sparsity_in_range = sparsity >= 0.0 && sparsity <= 1.0;
    if !sparsity_in_range {
        return Err(CliError::ValidationFailed(format!(
            "Sparsity must be between 0 and 1, got {sparsity}"
        )));
    }
    Ok(parsed)
}
/// Renders the pre-run configuration table (human output mode only).
///
/// Optional rows (sparsity, layer removal spec, calibration path) are
/// appended only when their corresponding argument is set/non-zero.
#[allow(clippy::disallowed_methods)]
fn print_config_table(
    file: &Path,
    out: &Path,
    prune_method: PruneMethod,
    target_ratio: f32,
    sparsity: f32,
    remove_layers: Option<&str>,
    calibration: Option<&Path>,
) {
    output::header("APR Prune");
    let mut rows = Vec::with_capacity(7);
    rows.push(("Input", file.display().to_string()));
    rows.push(("Method", format!("{prune_method:?}")));
    rows.push(("Target ratio", format!("{target_ratio:.2}")));
    rows.push(("Output", out.display().to_string()));
    if sparsity > 0.0 {
        rows.push(("Sparsity", format!("{sparsity:.2}")));
    }
    if let Some(spec) = remove_layers {
        rows.push(("Remove layers", spec.to_string()));
    }
    if let Some(path) = calibration {
        rows.push(("Calibration", path.display().to_string()));
    }
    println!("{}", output::kv_table(&rows));
    println!();
}
/// Ensures depth pruning was given a layer range to remove; all other
/// methods pass unconditionally.
fn validate_depth_args(prune_method: PruneMethod, remove_layers: Option<&str>) -> Result<()> {
    match (prune_method, remove_layers) {
        (PruneMethod::Depth, None) => Err(CliError::ValidationFailed(
            "Depth pruning requires --remove-layers (e.g., --remove-layers 20-24)".to_string(),
        )),
        _ => Ok(()),
    }
}
/// Entry point for `apr prune`: validates arguments, then either analyzes,
/// plans, or executes a pruning run against the model at `file`.
///
/// `analyze_only` and `plan_only` short-circuit before an output path is
/// required; a full run writes the pruned model to `output_path` (which is
/// then mandatory). `calibration` is currently rejected as unimplemented.
/// The contract macros wrap the function in the project's exit-code
/// pre/post conditions.
#[allow(clippy::too_many_arguments)]
#[allow(clippy::disallowed_methods)]
#[provable_contracts_macros::contract(
    "apr-cli-operations-v1",
    equation = "mutating_output_contract"
)]
pub(crate) fn run(
    file: &Path,
    method: &str,
    target_ratio: f32,
    sparsity: f32,
    output_path: Option<&Path>,
    remove_layers: Option<&str>,
    analyze_only: bool,
    plan_only: bool,
    calibration: Option<&Path>,
    json_output: bool,
) -> Result<()> {
    contract_pre_exit_code_on_failure!();
    // Shared validation: file existence, method parse, ratio/sparsity ranges.
    let prune_method = validate_prune_params(file, method, target_ratio, sparsity)?;
    if calibration.is_some() {
        return Err(CliError::ValidationFailed(
            "--calibration is not yet implemented. Use magnitude or depth pruning without calibration data.".to_string(),
        ));
    }
    // Read-only modes return before requiring an output path.
    if analyze_only {
        return run_analyze(file, prune_method, json_output);
    }
    if plan_only {
        // run_plan is not defined in this chunk — presumably in the
        // include!d companion file (prune_include_01.rs).
        return run_plan(file, prune_method, target_ratio, sparsity, json_output);
    }
    let out = output_path.ok_or_else(|| {
        CliError::ValidationFailed(
            "Output path required. Use -o <path> to specify output.".to_string(),
        )
    })?;
    // Human mode prints the configuration up front; JSON mode stays quiet
    // until the final structured report.
    if !json_output {
        print_config_table(
            file,
            out,
            prune_method,
            target_ratio,
            sparsity,
            remove_layers,
            calibration,
        );
    }
    // Depth pruning needs an explicit layer range before any work starts.
    validate_depth_args(prune_method, remove_layers)?;
    if !json_output {
        output::pipeline_stage("Pruning", output::StageStatus::Running);
    }
    let prune_result = execute_pruning(
        file,
        prune_method,
        target_ratio,
        sparsity,
        remove_layers,
        out,
    )?;
    if !json_output {
        output::pipeline_stage("Pruning", output::StageStatus::Done);
    }
    print_prune_output(
        file,
        out,
        prune_method,
        target_ratio,
        sparsity,
        &prune_result,
        json_output,
    );
    contract_post_exit_code_on_failure!(&());
    Ok(())
}
/// Summary statistics from a completed pruning run, consumed by
/// `print_prune_output` for reporting.
struct PruneResult {
    /// Size of the input model file in bytes.
    file_size: u64,
    /// Size of the serialized output model in bytes.
    output_size: u64,
    /// Number of tensors before pruning.
    original_count: usize,
    /// Number of tensors after pruning.
    pruned_count: usize,
    /// Total parameter (element) count before pruning.
    original_params: usize,
    /// Total parameter (element) count after pruning.
    pruned_params: usize,
    /// Number of exactly-zero parameters in the pruned model.
    zeros: usize,
}
/// Loads the model, applies the selected pruning strategy, writes the
/// pruned model to `out`, and gathers before/after statistics.
fn execute_pruning(
    file: &Path,
    prune_method: PruneMethod,
    target_ratio: f32,
    sparsity: f32,
    remove_layers: Option<&str>,
    out: &Path,
) -> Result<PruneResult> {
    use aprender::format::converter::load_model_tensors;

    let file_size = std::fs::metadata(file)
        .map_err(|e| CliError::ValidationFailed(format!("Cannot read model: {e}")))?
        .len();
    let tensors = load_model_tensors(file)
        .map_err(|e| CliError::ValidationFailed(format!("Failed to load model: {e}")))?;

    // Total element count across all tensors in a name -> (data, shape) map.
    let count_params =
        |map: &std::collections::BTreeMap<String, (Vec<f32>, Vec<usize>)>| -> usize {
            map.values().map(|(data, _shape)| data.len()).sum()
        };

    let original_count = tensors.len();
    let original_params = count_params(&tensors);

    let pruned_tensors = apply_pruning(
        &tensors,
        prune_method,
        target_ratio,
        sparsity,
        remove_layers,
    )?;
    let pruned_count = pruned_tensors.len();
    let pruned_params = count_params(&pruned_tensors);
    // Count exact zeros so the report can show the achieved sparsity.
    let zeros: usize = pruned_tensors
        .values()
        .map(|(data, _shape)| data.iter().filter(|v| **v == 0.0).count())
        .sum();

    let bytes = write_pruned_model(
        file,
        prune_method,
        target_ratio,
        sparsity,
        &pruned_tensors,
        out,
    )?;

    Ok(PruneResult {
        file_size,
        output_size: bytes.len() as u64,
        original_count,
        pruned_count,
        original_params,
        pruned_params,
        zeros,
    })
}
#[allow(clippy::type_complexity)]
fn apply_pruning(
tensors: &std::collections::BTreeMap<String, (Vec<f32>, Vec<usize>)>,
prune_method: PruneMethod,
target_ratio: f32,
sparsity: f32,
remove_layers: Option<&str>,
) -> Result<std::collections::BTreeMap<String, (Vec<f32>, Vec<usize>)>> {
match prune_method {
PruneMethod::Magnitude => Ok(prune_magnitude(tensors, sparsity.max(target_ratio))),
PruneMethod::Depth => {
let layers = remove_layers.expect("validated above");
prune_depth(tensors, layers)
}
PruneMethod::Structured | PruneMethod::Width => {
Ok(prune_magnitude(tensors, sparsity.max(target_ratio)))
}
PruneMethod::Wanda | PruneMethod::SparseGpt => {
Ok(prune_magnitude(tensors, sparsity.max(target_ratio)))
}
}
}
/// Serializes the pruned tensors into an APR file at `out`, embedding
/// pruning provenance (method, ratio, sparsity, source path) in the
/// metadata. Returns the serialized bytes.
#[allow(clippy::disallowed_methods)]
fn write_pruned_model(
    source_file: &Path,
    prune_method: PruneMethod,
    target_ratio: f32,
    sparsity: f32,
    pruned_tensors: &std::collections::BTreeMap<String, (Vec<f32>, Vec<usize>)>,
    out: &Path,
) -> Result<Vec<u8>> {
    let mut writer = aprender::serialization::apr::AprWriter::new();

    // Record how this model was produced so downstream tools can inspect it.
    let provenance = [
        ("pruning_method", serde_json::json!(format!("{prune_method:?}"))),
        ("pruning_ratio", serde_json::json!(target_ratio)),
        ("pruning_sparsity", serde_json::json!(sparsity)),
        (
            "source_file",
            serde_json::json!(source_file.display().to_string()),
        ),
    ];
    for (key, value) in provenance {
        writer.set_metadata(key, value);
    }

    for (name, (data, shape)) in pruned_tensors {
        writer.add_tensor_f32(name, shape.clone(), data);
    }

    let bytes = writer.to_bytes().map_err(|e| {
        CliError::ValidationFailed(format!("Failed to serialize pruned model: {e}"))
    })?;
    std::fs::write(out, &bytes)
        .map_err(|e| CliError::ValidationFailed(format!("Failed to write output: {e}")))?;
    Ok(bytes)
}
/// Reports the pruning outcome, either as pretty-printed JSON or as a
/// human-readable summary table.
#[allow(clippy::disallowed_methods)]
fn print_prune_output(
    file: &Path,
    out: &Path,
    prune_method: PruneMethod,
    target_ratio: f32,
    sparsity: f32,
    result: &PruneResult,
    json_output: bool,
) {
    // Fraction of parameters that are exactly zero after pruning; guarded
    // against division by zero for empty models.
    let actual_sparsity = if result.pruned_params > 0 {
        result.zeros as f64 / result.pruned_params as f64
    } else {
        0.0
    };

    if json_output {
        let json = serde_json::json!({
            "status": "completed",
            "input": file.display().to_string(),
            "output": out.display().to_string(),
            "method": format!("{prune_method:?}"),
            "target_ratio": target_ratio,
            "sparsity": sparsity,
            "input_size": result.file_size,
            "output_size": result.output_size,
            "tensors": result.pruned_count,
            "original_params": result.original_params,
            "pruned_params": result.pruned_params,
            "zero_params": result.zeros,
            "actual_sparsity": actual_sparsity,
        });
        println!(
            "{}",
            serde_json::to_string_pretty(&json).unwrap_or_default()
        );
        return;
    }

    println!();
    output::subheader("Pruning Complete");
    let rows = [
        (
            "Input size",
            humansize::format_size(result.file_size, humansize::BINARY),
        ),
        (
            "Output size",
            humansize::format_size(result.output_size, humansize::BINARY),
        ),
        (
            "Tensors",
            format!("{} → {}", result.original_count, result.pruned_count),
        ),
        (
            "Parameters",
            format!("{} → {}", result.original_params, result.pruned_params),
        ),
        (
            "Zeros",
            format!("{} ({:.1}%)", result.zeros, actual_sparsity * 100.0),
        ),
        ("Output", out.display().to_string()),
    ];
    println!("{}", output::kv_table(&rows));
}
/// Prints a read-only pruning analysis for `file`: its size, an estimated
/// parameter count, and recommended pruning ratios. Never writes anything.
#[allow(clippy::disallowed_methods)]
fn run_analyze(file: &Path, method: PruneMethod, json_output: bool) -> Result<()> {
    let file_size = std::fs::metadata(file)
        .map_err(|e| CliError::ValidationFailed(format!("Cannot read model: {e}")))?
        .len();
    // Fall back to a crude f32-based estimate (4 bytes/param) when the
    // header metadata carries no parameter count.
    let estimated_params = read_param_count(file).unwrap_or(file_size / 4);

    if json_output {
        let json = serde_json::json!({
            "analysis": true,
            "input": file.display().to_string(),
            "file_size": file_size,
            "estimated_params": estimated_params,
            "method": format!("{method:?}"),
            "recommendations": [
                {"ratio": 0.2, "description": "Conservative (minimal quality loss)"},
                {"ratio": 0.5, "description": "Balanced (moderate compression)"},
                {"ratio": 0.8, "description": "Aggressive (significant quality loss)"},
            ],
        });
        println!(
            "{}",
            serde_json::to_string_pretty(&json).unwrap_or_default()
        );
        return Ok(());
    }

    output::header("APR Prune — Analysis");
    let rows = [
        ("Input", file.display().to_string()),
        (
            "File size",
            humansize::format_size(file_size, humansize::BINARY),
        ),
        ("Est. parameters", format_params(estimated_params)),
        ("Method", format!("{method:?}")),
    ];
    println!("{}", output::kv_table(&rows));
    println!();
    output::subheader("Pruning Recommendations");
    println!("  20% — Conservative (minimal quality loss)");
    println!("  50% — Balanced (moderate compression)");
    println!("  80% — Aggressive (significant quality loss)");
    println!();
    println!(
        "  {} Use --target-ratio <0-1> to set pruning target.",
        output::badge_info("INFO"),
    );
    Ok(())
}
/// Reads the APR v2 header (and, when present, the metadata block) from
/// `file` and returns the stored parameter count.
///
/// # Errors
/// Fails on I/O errors, an unparsable header, or when the metadata is
/// absent or carries no positive `param_count`.
fn read_param_count(file: &Path) -> Result<u64> {
    let handle = std::fs::File::open(file).map_err(CliError::Io)?;
    let mut reader = std::io::BufReader::new(handle);

    let mut header_bytes = [0u8; HEADER_SIZE_V2];
    reader.read_exact(&mut header_bytes)?;
    let header = AprV2Header::from_bytes(&header_bytes)
        .map_err(|e| CliError::InvalidFormat(format!("Failed to parse header: {e}")))?;

    // No metadata block means no param_count can possibly be stored.
    if header.metadata_size == 0 {
        return Err(CliError::ValidationFailed(
            "No param_count in metadata".into(),
        ));
    }

    reader
        .seek(SeekFrom::Start(header.metadata_offset))
        .map_err(CliError::Io)?;
    let mut meta_bytes = vec![0u8; header.metadata_size as usize];
    reader.read_exact(&mut meta_bytes)?;

    match AprV2Metadata::from_json(&meta_bytes) {
        Ok(meta) if meta.param_count > 0 => Ok(meta.param_count),
        // Unparsable metadata or a zero count both report the same error,
        // matching the original fall-through behavior.
        _ => Err(CliError::ValidationFailed(
            "No param_count in metadata".into(),
        )),
    }
}
include!("prune_include_01.rs");