#![allow(unreachable_pub)]
use serde::{Deserialize, Serialize};
/// Latency budget, in milliseconds, for the `pruning_importance_latency` assertion.
const IMPORTANCE_LATENCY_MAX_MS: u64 = 5_000;
/// Upper bound for the `calibration_syscall_budget` span-count assertion.
const CALIBRATION_SYSCALL_BUDGET: u64 = 5_000;
/// A single performance or anti-pattern check with an upper bound.
///
/// NOTE: `name` is a `&'static str`. Serde borrows string fields, so the derived
/// `Deserialize` impl only accepts `'static` input (e.g. string literals).
/// Serialization is unaffected.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct PruningAssertion {
    /// Unique identifier of the assertion, e.g. `"pruning_importance_latency"`.
    pub name: &'static str,
    /// Category of the check; determines how `max_value` is interpreted.
    pub assertion_type: AssertionType,
    /// Upper bound for the measured quantity. Units depend on `assertion_type`.
    pub max_value: u64,
    /// Whether exceeding `max_value` should be treated as a failure
    /// (presumably a warning otherwise; confirm against the consumer).
    pub fail_on_violation: bool,
    /// Whether the assertion is active.
    pub enabled: bool,
}
/// Category of a [`PruningAssertion`], which determines how its `max_value`
/// bound is interpreted.
#[derive(Clone, Copy, Debug, Eq, PartialEq, Serialize, Deserialize)]
pub enum AssertionType {
    /// Elapsed-time bound (the latency budgets in this file are in milliseconds).
    Latency,
    /// Memory-usage bound.
    Memory,
    /// Bound on a count of recorded spans/events (e.g. the calibration syscall budget).
    SpanCount,
    /// Threshold for an anti-pattern detector.
    AntiPattern,
}
/// Expected behaviour of one pruning schedule.
///
/// Only `PartialEq` (not `Eq`) is derived because the sparsity curve holds `f32`.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct ScheduleGoldenTrace {
    /// Identifier of the schedule variant, e.g. `"cubic_0_100_50pct"`.
    pub schedule_type: &'static str,
    /// Training steps at which the schedule is expected to prune.
    pub expected_prune_steps: Vec<usize>,
    /// Sampled `(step, expected_sparsity)` points of the schedule's curve.
    pub expected_sparsity_curve: Vec<(usize, f32)>,
}
/// Expected properties of one named pruning configuration.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct ConfigGoldenTrace {
    /// Identifier of the configuration, e.g. `"wanda_nm24"`.
    pub config_id: &'static str,
    /// Whether the configuration is expected to need calibration data.
    pub requires_calibration: bool,
    /// Whether validating the configuration is expected to succeed.
    pub expected_valid: bool,
}
/// Namespace for the canonical ("golden") expectations the pruning module is
/// checked against: performance budgets, schedule curves and config properties.
pub struct PruningGoldenTraces;

impl PruningGoldenTraces {
    /// Latency budget, in milliseconds, for mask generation.
    const MASK_GENERATION_LATENCY_MAX_MS: u64 = 1_000;
    /// Memory budget for pruning: 2 GiB, assuming the unit is bytes.
    const MEMORY_BUDGET: u64 = 2 * 1024 * 1024 * 1024;
    /// Threshold for the redundant-computation detector.
    /// NOTE(review): the unit (score vs. percentage) is not visible here; confirm.
    const REDUNDANT_COMPUTATION_THRESHOLD: u64 = 70;
    /// Threshold for the memory-thrashing detector (same unit caveat as above).
    const MEMORY_THRASHING_THRESHOLD: u64 = 80;

    /// Returns the performance and anti-pattern assertions for the pruning pipeline.
    ///
    /// Latency and memory assertions fail on violation. The span-count and
    /// anti-pattern assertions have `fail_on_violation: false`. All are enabled.
    pub fn performance_assertions() -> Vec<PruningAssertion> {
        vec![
            PruningAssertion {
                name: "pruning_importance_latency",
                assertion_type: AssertionType::Latency,
                max_value: IMPORTANCE_LATENCY_MAX_MS,
                fail_on_violation: true,
                enabled: true,
            },
            PruningAssertion {
                name: "pruning_mask_generation_latency",
                assertion_type: AssertionType::Latency,
                max_value: Self::MASK_GENERATION_LATENCY_MAX_MS,
                fail_on_violation: true,
                enabled: true,
            },
            PruningAssertion {
                name: "pruning_memory_budget",
                assertion_type: AssertionType::Memory,
                max_value: Self::MEMORY_BUDGET,
                fail_on_violation: true,
                enabled: true,
            },
            PruningAssertion {
                name: "calibration_syscall_budget",
                assertion_type: AssertionType::SpanCount,
                max_value: CALIBRATION_SYSCALL_BUDGET,
                fail_on_violation: false,
                enabled: true,
            },
            PruningAssertion {
                name: "detect_redundant_computation",
                assertion_type: AssertionType::AntiPattern,
                max_value: Self::REDUNDANT_COMPUTATION_THRESHOLD,
                fail_on_violation: false,
                enabled: true,
            },
            PruningAssertion {
                name: "detect_memory_thrashing",
                assertion_type: AssertionType::AntiPattern,
                max_value: Self::MEMORY_THRASHING_THRESHOLD,
                fail_on_violation: false,
                enabled: true,
            },
        ]
    }

    /// Returns golden traces for the three reference schedules.
    ///
    /// - `oneshot_1000`: sparsity jumps from 0.0 to 1.0 at step 1000.
    /// - `gradual_0_100_50pct`: linear ramp from 0.0 to 0.5 over steps 0..=100,
    ///   pruning every 10 steps.
    /// - `cubic_0_100_50pct`: `0.5 * (1 - (1 - t/100)^3)`, pruning at every step.
    pub fn schedule_traces() -> Vec<ScheduleGoldenTrace> {
        vec![
            ScheduleGoldenTrace {
                schedule_type: "oneshot_1000",
                expected_prune_steps: vec![1000],
                expected_sparsity_curve: vec![(0, 0.0), (999, 0.0), (1000, 1.0), (1001, 1.0)],
            },
            ScheduleGoldenTrace {
                schedule_type: "gradual_0_100_50pct",
                // 0, 10, 20, ..., 100 (11 pruning steps).
                expected_prune_steps: (0..=100).step_by(10).collect(),
                expected_sparsity_curve: vec![
                    (0, 0.0),
                    (25, 0.125),
                    (50, 0.25),
                    (75, 0.375),
                    (100, 0.5),
                ],
            },
            ScheduleGoldenTrace {
                schedule_type: "cubic_0_100_50pct",
                expected_prune_steps: (0..=100).collect(),
                // Exact binary fractions of 0.5 * (1 - (1 - t/100)^3).
                expected_sparsity_curve: vec![
                    (0, 0.0),
                    (25, 0.2890625),
                    (50, 0.4375),
                    (75, 0.4921875),
                    (100, 0.5),
                ],
            },
        ]
    }

    /// Returns golden traces for named configurations.
    ///
    /// `invalid_nm_n_gte_m` describes an N:M pattern whose N is not smaller
    /// than M, which must fail validation.
    pub fn config_traces() -> Vec<ConfigGoldenTrace> {
        vec![
            ConfigGoldenTrace {
                config_id: "default",
                requires_calibration: false,
                expected_valid: true,
            },
            ConfigGoldenTrace {
                config_id: "wanda_nm24",
                requires_calibration: true,
                expected_valid: true,
            },
            ConfigGoldenTrace {
                config_id: "sparsegpt_unstructured",
                requires_calibration: true,
                expected_valid: true,
            },
            ConfigGoldenTrace {
                config_id: "invalid_nm_n_gte_m",
                requires_calibration: false,
                expected_valid: false,
            },
        ]
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::prune::{PruneMethod, PruningConfig, PruningSchedule, SparsityPatternConfig};

    /// Absolute tolerance for comparing sparsity values against golden curves.
    const SPARSITY_TOLERANCE: f32 = 1e-6;

    /// Returns the schedule golden trace with the given id, panicking if absent.
    fn schedule_golden(schedule_type: &str) -> ScheduleGoldenTrace {
        PruningGoldenTraces::schedule_traces()
            .into_iter()
            .find(|t| t.schedule_type == schedule_type)
            .unwrap_or_else(|| panic!("Golden trace not found: {schedule_type}"))
    }

    /// Returns the config golden trace with the given id, panicking if absent.
    fn config_golden(config_id: &str) -> ConfigGoldenTrace {
        PruningGoldenTraces::config_traces()
            .into_iter()
            .find(|t| t.config_id == config_id)
            .unwrap_or_else(|| panic!("Golden trace not found: {config_id}"))
    }

    #[test]
    fn test_oneshot_schedule_matches_golden_trace() {
        let schedule = PruningSchedule::OneShot { step: 1000 };
        let golden = schedule_golden("oneshot_1000");
        for &(step, expected) in &golden.expected_sparsity_curve {
            let actual = schedule.sparsity_at_step(step);
            assert!(
                (actual - expected).abs() < SPARSITY_TOLERANCE,
                "GOLD-001 FALSIFIED: OneShot at step {step} expected {expected}, got {actual}"
            );
        }
        for &step in &golden.expected_prune_steps {
            assert!(
                schedule.should_prune_at_step(step),
                "GOLD-001 FALSIFIED: OneShot should prune at step {step}"
            );
        }
    }

    #[test]
    fn test_gradual_schedule_matches_golden_trace() {
        let schedule = PruningSchedule::Gradual {
            start_step: 0,
            end_step: 100,
            initial_sparsity: 0.0,
            final_sparsity: 0.5,
            frequency: 10,
        };
        let golden = schedule_golden("gradual_0_100_50pct");
        for &(step, expected) in &golden.expected_sparsity_curve {
            let actual = schedule.sparsity_at_step(step);
            assert!(
                (actual - expected).abs() < SPARSITY_TOLERANCE,
                "GOLD-002 FALSIFIED: Gradual at step {step} expected {expected}, got {actual}"
            );
        }
        assert_eq!(
            schedule.num_pruning_steps(),
            golden.expected_prune_steps.len(),
            "GOLD-002 FALSIFIED: Gradual num_pruning_steps mismatch"
        );
    }

    #[test]
    fn test_cubic_schedule_matches_golden_trace() {
        let schedule = PruningSchedule::Cubic { start_step: 0, end_step: 100, final_sparsity: 0.5 };
        let golden = schedule_golden("cubic_0_100_50pct");
        for &(step, expected) in &golden.expected_sparsity_curve {
            let actual = schedule.sparsity_at_step(step);
            assert!(
                (actual - expected).abs() < SPARSITY_TOLERANCE,
                "GOLD-003 FALSIFIED: Cubic at step {step} expected {expected}, got {actual}"
            );
        }
    }

    #[test]
    fn test_config_calibration_matches_golden_trace() {
        let configs = [
            ("default", PruningConfig::default()),
            (
                "wanda_nm24",
                PruningConfig::new()
                    .with_method(PruneMethod::Wanda)
                    .with_pattern(SparsityPatternConfig::nm_2_4()),
            ),
            (
                "sparsegpt_unstructured",
                PruningConfig::new()
                    .with_method(PruneMethod::SparseGpt)
                    .with_pattern(SparsityPatternConfig::Unstructured),
            ),
        ];
        for (id, config) in &configs {
            let golden = config_golden(id);
            assert_eq!(
                config.requires_calibration(),
                golden.requires_calibration,
                "GOLD-004 FALSIFIED: Config {id} calibration requirement mismatch"
            );
        }
    }

    /// Validity expectations come from the golden traces (`expected_valid`)
    /// rather than being hard-coded, so the fixture data is actually exercised.
    #[test]
    fn test_config_validation_matches_golden_trace() {
        let configs = [
            ("default", PruningConfig::default()),
            (
                "invalid_nm_n_gte_m",
                PruningConfig::new().with_pattern(SparsityPatternConfig::NM { n: 5, m: 4 }),
            ),
        ];
        // Iterate by value so this compiles regardless of `validate`'s receiver.
        for (id, config) in configs {
            let golden = config_golden(id);
            assert_eq!(
                config.validate().is_ok(),
                golden.expected_valid,
                "GOLD-005 FALSIFIED: Config {id} validity mismatch (expected_valid = {})",
                golden.expected_valid
            );
        }
    }

    #[test]
    fn test_performance_assertions_defined() {
        let assertions = PruningGoldenTraces::performance_assertions();
        let expected_names = [
            "pruning_importance_latency",
            "pruning_mask_generation_latency",
            "pruning_memory_budget",
            "calibration_syscall_budget",
            "detect_redundant_computation",
            "detect_memory_thrashing",
        ];
        for name in &expected_names {
            assert!(
                assertions.iter().any(|a| a.name == *name),
                "GOLD-010 FALSIFIED: Missing assertion: {name}"
            );
        }
    }

    #[test]
    fn test_latency_assertions_reasonable() {
        let assertions = PruningGoldenTraces::performance_assertions();
        let latency: Vec<_> = assertions
            .iter()
            .filter(|a| a.assertion_type == AssertionType::Latency)
            .collect();
        // Guard against the loop below passing vacuously.
        assert!(!latency.is_empty(), "GOLD-011 FALSIFIED: No latency assertions defined");
        for assertion in latency {
            assert!(
                assertion.max_value <= 60000,
                "GOLD-011 FALSIFIED: Latency {} unreasonably high: {}ms",
                assertion.name,
                assertion.max_value
            );
            assert!(
                assertion.max_value >= 100,
                "GOLD-011 FALSIFIED: Latency {} unreasonably low: {}ms",
                assertion.name,
                assertion.max_value
            );
        }
    }

    #[test]
    fn test_memory_assertion_reasonable() {
        let assertions = PruningGoldenTraces::performance_assertions();
        let memory_assertion = assertions
            .iter()
            .find(|a| a.assertion_type == AssertionType::Memory)
            .expect("Memory assertion not found");
        assert!(
            memory_assertion.max_value >= 100_000_000,
            "GOLD-012 FALSIFIED: Memory budget too low: {}",
            memory_assertion.max_value
        );
        assert!(
            memory_assertion.max_value <= 17_179_869_184,
            "GOLD-012 FALSIFIED: Memory budget too high: {}",
            memory_assertion.max_value
        );
    }

    /// Lookups use `find`, so a duplicated id would silently shadow a later trace.
    #[test]
    fn test_golden_trace_ids_unique() {
        use std::collections::HashSet;

        let mut schedule_ids = HashSet::new();
        for trace in PruningGoldenTraces::schedule_traces() {
            assert!(
                schedule_ids.insert(trace.schedule_type),
                "GOLD-021 FALSIFIED: Duplicate schedule trace {}",
                trace.schedule_type
            );
        }
        let mut config_ids = HashSet::new();
        for trace in PruningGoldenTraces::config_traces() {
            assert!(
                config_ids.insert(trace.config_id),
                "GOLD-021 FALSIFIED: Duplicate config trace {}",
                trace.config_id
            );
        }
        let mut assertion_names = HashSet::new();
        for assertion in PruningGoldenTraces::performance_assertions() {
            assert!(
                assertion_names.insert(assertion.name),
                "GOLD-021 FALSIFIED: Duplicate assertion {}",
                assertion.name
            );
        }
    }

    #[test]
    fn test_golden_trace_serialization() {
        let schedule_traces = PruningGoldenTraces::schedule_traces();
        let json = serde_json::to_string(&schedule_traces)
            .expect("GOLD-020 FALSIFIED: Failed to serialize schedule traces");
        assert!(
            json.contains("oneshot_1000"),
            "GOLD-020 FALSIFIED: Serialized trace should contain schedule type"
        );
        let assertions = PruningGoldenTraces::performance_assertions();
        let json = serde_json::to_string(&assertions)
            .expect("GOLD-020 FALSIFIED: Failed to serialize assertions");
        assert!(
            json.contains("pruning_importance_latency"),
            "GOLD-020 FALSIFIED: Serialized assertion should contain name"
        );
        let config_traces = PruningGoldenTraces::config_traces();
        let json = serde_json::to_string(&config_traces)
            .expect("GOLD-020 FALSIFIED: Failed to serialize config traces");
        assert!(
            json.contains("invalid_nm_n_gte_m"),
            "GOLD-020 FALSIFIED: Serialized config trace should contain config id"
        );
    }
}