1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
//! Memory optimization pass
//!
//! This module provides memory optimization capabilities including tensor lifetime
//! analysis, in-place operation identification, and memory layout optimization.
use super::passes::{get_node_inputs, get_node_outputs, OptimizationPass};
use crate::graph::{Graph, NodeId};
use crate::Result;
use std::collections::{HashMap, HashSet};
/// Memory optimization pass.
///
/// A stateless pass that annotates graph nodes with memory-planning hints: tensor lifetimes,
/// in-place operation candidates and memory-sharing groups. The hints are intended to let
/// later stages reduce memory usage. The pass itself only writes node attributes.
#[derive(Debug, Clone, Copy)]
pub struct MemoryOptimizationPass;
impl Default for MemoryOptimizationPass {
fn default() -> Self {
Self::new()
}
}
impl MemoryOptimizationPass {
pub fn new() -> Self {
Self
}
/// Analyzes tensor lifetimes in the graph
fn analyze_tensor_lifetimes(&self, graph: &Graph) -> HashMap<NodeId, (usize, usize)> {
let mut lifetimes = HashMap::new();
// Perform topological sort to determine execution order
let execution_order = self.topological_sort(graph);
for (position, &node_id) in execution_order.iter().enumerate() {
// Track when this tensor is first created (birth)
lifetimes.insert(node_id, (position, position));
// Update lifetime end for all inputs (last usage)
let inputs = get_node_inputs(graph, node_id);
for input_id in inputs {
if let Some((birth, _)) = lifetimes.get(&input_id) {
lifetimes.insert(input_id, (*birth, position));
}
}
}
lifetimes
}
/// Performs topological sort on the graph
fn topological_sort(&self, graph: &Graph) -> Vec<NodeId> {
let mut visited = HashSet::new();
let mut result = Vec::new();
let mut temp_visited = HashSet::new();
// Start from nodes with no inputs (source nodes)
for node in graph.nodes() {
let node_id = node.id;
if !visited.contains(&node_id) && get_node_inputs(graph, node_id).is_empty() {
self.dfs_visit(graph, node_id, &mut visited, &mut temp_visited, &mut result);
}
}
// Handle remaining nodes (in case of cycles or disconnected components)
for node in graph.nodes() {
let node_id = node.id;
if !visited.contains(&node_id) {
self.dfs_visit(graph, node_id, &mut visited, &mut temp_visited, &mut result);
}
}
result
}
/// DFS visit for topological sort
#[allow(clippy::only_used_in_recursion)]
fn dfs_visit(
&self,
graph: &Graph,
node_id: NodeId,
visited: &mut HashSet<NodeId>,
temp_visited: &mut HashSet<NodeId>,
result: &mut Vec<NodeId>,
) {
if temp_visited.contains(&node_id) {
// Cycle detected, skip for now
return;
}
if visited.contains(&node_id) {
return;
}
temp_visited.insert(node_id);
let outputs = get_node_outputs(graph, node_id);
for output_id in outputs {
self.dfs_visit(graph, output_id, visited, temp_visited, result);
}
temp_visited.remove(&node_id);
visited.insert(node_id);
result.push(node_id);
}
/// Identifies opportunities for in-place operations
fn identify_inplace_opportunities(&self, graph: &Graph) -> Vec<(NodeId, NodeId)> {
let mut opportunities = Vec::new();
for node in graph.nodes() {
let node_id = node.id;
// Check if this operation can be done in-place
if let crate::graph::NodeType::Operation(ref op_name) = node.op_type {
if self.can_be_inplace(op_name) {
let inputs = get_node_inputs(graph, node_id);
let outputs = get_node_outputs(graph, node_id);
// For in-place operation, we need exactly one input and output
// and the input should not be used elsewhere
if inputs.len() == 1 && outputs.len() == 1 {
let input_id = inputs[0];
// Check if input is only used by this node
if self.is_only_consumer(graph, input_id, node_id) {
opportunities.push((node_id, input_id));
}
}
}
}
}
opportunities
}
/// Checks if an operation can be performed in-place
fn can_be_inplace(&self, operation: &str) -> bool {
matches!(
operation,
"add"
| "sub"
| "mul"
| "div"
| "relu"
| "sigmoid"
| "tanh"
| "abs"
| "neg"
| "exp"
| "log"
| "sqrt"
| "square"
)
}
/// Checks if a node is the only consumer of an input
fn is_only_consumer(&self, graph: &Graph, input_id: NodeId, consumer_id: NodeId) -> bool {
let outputs = get_node_outputs(graph, input_id);
outputs.len() == 1 && outputs[0] == consumer_id
}
/// Optimizes memory layout by grouping operations
fn optimize_memory_layout(&self, graph: &mut Graph) -> bool {
let mut changed = false;
// Group operations that can share memory
let memory_groups = self.find_memory_sharing_groups(graph);
for group in memory_groups {
if group.len() > 1 {
// Mark operations in this group for memory sharing
for &node_id in &group {
if let Some(node) = graph.get_node_mut(node_id) {
// Add memory sharing hint to node attributes
node.attributes.insert(
"memory_group".to_string(),
crate::graph::AttributeValue::String(format!("{group:?}")),
);
changed = true;
}
}
}
}
changed
}
/// Finds groups of operations that can share memory
fn find_memory_sharing_groups(&self, graph: &Graph) -> Vec<Vec<NodeId>> {
let mut groups = Vec::new();
let mut visited = HashSet::new();
for node in graph.nodes() {
let node_id = node.id;
if visited.contains(&node_id) {
continue;
}
let mut group = Vec::new();
self.build_memory_group(graph, node_id, &mut group, &mut visited);
if group.len() > 1 {
groups.push(group);
}
}
groups
}
/// Recursively builds a memory sharing group
fn build_memory_group(
&self,
graph: &Graph,
node_id: NodeId,
group: &mut Vec<NodeId>,
visited: &mut HashSet<NodeId>,
) {
if visited.contains(&node_id) {
return;
}
visited.insert(node_id);
group.push(node_id);
// Only group with immediate successors that can share memory
let outputs = get_node_outputs(graph, node_id);
for output_id in outputs {
if let Some(node) = graph.get_node(output_id) {
// Only group operations that can safely share memory
if let crate::graph::NodeType::Operation(ref op_name) = node.op_type {
if self.can_share_memory(op_name) {
self.build_memory_group(graph, output_id, group, visited);
}
}
}
}
}
/// Checks if operations can safely share memory
fn can_share_memory(&self, operation: &str) -> bool {
// Operations that don't change tensor shape can often share memory
matches!(
operation,
"add"
| "sub"
| "mul"
| "div"
| "relu"
| "sigmoid"
| "tanh"
| "abs"
| "neg"
| "exp"
| "log"
| "sqrt"
| "square"
| "dropout"
)
}
}
impl OptimizationPass for MemoryOptimizationPass {
fn apply(&self, graph: &mut Graph) -> Result<bool> {
let mut changed = false;
// Step 1: Analyze tensor lifetimes
let lifetimes = self.analyze_tensor_lifetimes(graph);
// Step 2: Identify in-place operation opportunities
let inplace_ops = self.identify_inplace_opportunities(graph);
// Step 3: Apply in-place optimizations
for (node_id, input_id) in inplace_ops {
if let Some(node) = graph.get_node_mut(node_id) {
// Mark this operation as in-place
node.attributes.insert(
"inplace".to_string(),
crate::graph::AttributeValue::Bool(true),
);
node.attributes.insert(
"inplace_input".to_string(),
crate::graph::AttributeValue::String(input_id.to_string()),
);
changed = true;
}
}
// Step 4: Optimize memory layout
if self.optimize_memory_layout(graph) {
changed = true;
}
// Step 5: Add memory management hints
for (node_id, (birth, death)) in lifetimes {
if let Some(node) = graph.get_node_mut(node_id) {
node.attributes.insert(
"lifetime_start".to_string(),
crate::graph::AttributeValue::Int(birth as i64),
);
node.attributes.insert(
"lifetime_end".to_string(),
crate::graph::AttributeValue::Int(death as i64),
);
changed = true;
}
}
Ok(changed)
}
fn name(&self) -> &str {
"MemoryOptimization"
}
fn is_applicable(&self, graph: &Graph) -> bool {
// Memory optimization is applicable if there are at least 2 nodes
graph.nodes().count() >= 2
}
fn priority(&self) -> u32 {
80 // Run after CSE and fusion but before dead code elimination
}
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_memory_optimization_pass() {
        let pass = MemoryOptimizationPass::new();
        assert_eq!(pass.name(), "MemoryOptimization");
        assert_eq!(pass.priority(), 80);

        // A freshly created graph is below the two-node applicability threshold.
        assert!(!pass.is_applicable(&Graph::new()));
    }

    #[test]
    fn test_inplace_operations() {
        let pass = MemoryOptimizationPass::new();

        // Names on the in-place allowlist are accepted...
        for op in ["add", "relu", "exp"] {
            assert!(pass.can_be_inplace(op));
        }
        // ...names outside it are rejected.
        for op in ["matmul", "conv2d"] {
            assert!(!pass.can_be_inplace(op));
        }
    }

    #[test]
    fn test_memory_sharing() {
        let pass = MemoryOptimizationPass::new();

        // Names on the memory-sharing allowlist are accepted...
        for op in ["add", "relu", "sigmoid"] {
            assert!(pass.can_share_memory(op));
        }
        // ...names outside it are rejected.
        for op in ["matmul", "conv2d", "reshape"] {
            assert!(!pass.can_share_memory(op));
        }
    }
}