1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
impl PruningPlan {
    /// Create a new empty pruning plan.
    ///
    /// The plan starts with no channel removals, no layer removals, and is
    /// not validated.
    #[must_use]
    pub fn new() -> Self {
        Self {
            channel_removals: HashMap::new(),
            layer_removals: Vec::new(),
            validated: false,
        }
    }

    /// Record the channels to remove from a layer.
    ///
    /// Any channel list previously recorded for `layer_id` is replaced (not
    /// merged). The plan is marked as not validated.
    pub fn remove_channels(&mut self, layer_id: impl Into<String>, channels: Vec<usize>) {
        self.channel_removals.insert(layer_id.into(), channels);
        self.validated = false;
    }

    /// Add a layer to be removed.
    ///
    /// The plan is marked as not validated.
    pub fn remove_layer(&mut self, layer_id: impl Into<String>) {
        self.layer_removals.push(layer_id.into());
        self.validated = false;
    }

    /// Get the channels to remove from a layer.
    ///
    /// Returns `None` when no channel removals were recorded for `layer_id`.
    #[must_use]
    pub fn channels_to_remove(&self, layer_id: &str) -> Option<&Vec<usize>> {
        self.channel_removals.get(layer_id)
    }

    /// Check whether a layer is scheduled for removal.
    #[must_use]
    pub fn is_layer_removed(&self, layer_id: &str) -> bool {
        // Compare against the borrowed `&str` directly instead of allocating
        // a temporary `String` for `Vec::contains`.
        self.layer_removals.iter().any(|removed| removed == layer_id)
    }

    /// Validate the plan against a dependency graph.
    ///
    /// Ensures:
    /// - All referenced layers exist in `graph`
    /// - Channel indices are below the layer's `output_dim`
    ///
    /// On success the plan is marked as validated. On failure the plan is
    /// left marked as *not* validated, even if an earlier call succeeded.
    ///
    /// # Errors
    /// - [`PruningError::InvalidPattern`] if a layer referenced by a channel
    ///   removal or a layer removal is not found in `graph`.
    /// - [`PruningError::InvalidSparsity`] if a channel index is
    ///   `>= output_dim` for a layer whose `output_dim` is non-zero.
    pub fn validate(&mut self, graph: &DependencyGraph) -> Result<(), PruningError> {
        // A successful validation against one graph must not survive a
        // failed re-validation (e.g. against a different graph).
        self.validated = false;

        // Check all referenced layers exist: channel-removal targets first,
        // then removed layers.
        let referenced_layers = self.channel_removals.keys().chain(&self.layer_removals);
        for layer_id in referenced_layers {
            if graph.get_node(layer_id).is_none() {
                return Err(PruningError::InvalidPattern {
                    message: format!("Layer '{layer_id}' not found in graph"),
                });
            }
        }

        // Check channel indices are valid.
        for (layer_id, channels) in &self.channel_removals {
            if let Some(node) = graph.get_node(layer_id) {
                // An `output_dim` of 0 disables the bounds check
                // (presumably "dimension unknown" -- confirm with the graph builder).
                if node.output_dim == 0 {
                    continue;
                }
                let out_of_range = channels.iter().find(|&&ch| ch >= node.output_dim);
                if let Some(&ch) = out_of_range {
                    return Err(PruningError::InvalidSparsity {
                        // The error payload is an `f32`; the exact index is
                        // also carried in the `constraint` text.
                        value: ch as f32,
                        constraint: format!(
                            "Channel {} >= output_dim {} for layer {}",
                            ch, node.output_dim, layer_id
                        ),
                    });
                }
            }
        }

        self.validated = true;
        Ok(())
    }

    /// Check if the plan is validated.
    ///
    /// `true` only when the most recent [`validate`](Self::validate) call
    /// succeeded and the plan has not been modified since.
    #[must_use]
    pub fn is_validated(&self) -> bool {
        self.validated
    }

    /// Get the total number of channel entries being removed across all layers.
    ///
    /// This is the sum of the lengths of the recorded channel lists.
    #[must_use]
    pub fn total_channels_removed(&self) -> usize {
        self.channel_removals.values().map(Vec::len).sum()
    }

    /// Get the number of layers being removed.
    #[must_use]
    pub fn total_layers_removed(&self) -> usize {
        self.layer_removals.len()
    }
}
impl Default for PruningPlan {
fn default() -> Self {
Self::new()
}
}
/// Compute the input-channel removals implied by pruning a layer's outputs.
///
/// Every edge returned by `graph.edges_from(source_layer)` whose dependency
/// type is [`DependencyType::Sequential`] contributes one entry: the edge's
/// target layer mapped to a copy of `pruned_channels`. Edges of any other
/// dependency type are ignored.
///
/// # Arguments
/// * `graph` - The dependency graph
/// * `source_layer` - Layer whose output channels are being pruned
/// * `pruned_channels` - Indices of the channels being removed
///
/// # Returns
/// Map of `layer_id` -> input channels to remove
#[must_use]
pub fn propagate_channel_pruning(
    graph: &DependencyGraph,
    source_layer: &str,
    pruned_channels: &[usize],
) -> HashMap<String, Vec<usize>> {
    graph
        .edges_from(source_layer)
        .into_iter()
        // Only sequential connections pass the pruned outputs on as inputs.
        .filter(|edge| edge.dep_type == DependencyType::Sequential)
        // The downstream layer drops the same channel indices.
        .map(|edge| (edge.to.clone(), pruned_channels.to_vec()))
        .collect()
}