1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
//! Advanced edge case tests for rangeify.
//!
//! Tests for IR-level edge cases that aren't covered by basic tests:
//! - Symbolic (variable-sized) ranges
//! - Nested BUFFERIZE operations
//! - Multi-consumer patterns
//! - Complex indexing scenarios
use morok_ir::{DType, Op, UOp};
use crate::rangeify::transforms::rangeify;
use super::helpers::{create_bufferize, create_const, create_range, create_range_symbolic};
// ============================================================================
// Symbolic Range Size Tests
// ============================================================================
#[test]
fn test_symbolic_range_size() {
    // One BUFFERIZE axis whose extent is a variable instead of a constant.
    // The only expectation is that rangeify returns `Ok` for such an input.
    let extent = UOp::var("size", DType::Index, 0, 1024);

    // Constant payload wrapped in a BUFFERIZE over the variable-sized range (axis 0).
    let buffered = create_bufferize(
        UOp::native_const(1.0f32),
        vec![create_range_symbolic(extent, 0)],
    );

    // The output itself is not inspected; `unwrap` fails the test on any error.
    let (_out, _state) = rangeify(buffered, None).unwrap();

    // Dead-axis detection on symbolic ranges (vmax analysis) is exercised directly
    // by `test_is_dead_axis_symbolic_bounded`.
}
#[test]
fn test_symbolic_range_multiple() {
    // Two BUFFERIZE axes, each sized by its own variable.
    let size1 = UOp::var("size1", DType::Index, 0, 1024);
    let size2 = UOp::var("size2", DType::Index, 0, 1024);
    let compute = UOp::native_const(2.0f32);
    let range1 = create_range_symbolic(size1, 0);
    let range2 = create_range_symbolic(size2, 1);

    // `compute` has no later use, so it is moved into the helper rather than cloned.
    let bufferized = create_bufferize(compute, vec![range1, range2]);

    // Success criterion: rangeify returns `Ok`; the result is not inspected.
    let (_result, _ctx) = rangeify(bufferized, None).unwrap();

    // Dead-axis detection for symbolic ranges is checked separately in
    // `test_is_dead_axis_symbolic_bounded`; this test only guards against errors
    // when several symbolic dimensions are present.
}
#[test]
fn test_symbolic_range_with_arithmetic() {
    // The range extent is an arithmetic expression (variable * 2), not a bare variable.
    let base = UOp::var("n", DType::Index, 0, 512);
    let factor = create_const(2);
    let doubled = base.try_mul(&factor).unwrap();

    // Constant payload buffered over the expression-sized range (axis 0).
    let buffered = create_bufferize(
        UOp::native_const(3.0f32),
        vec![create_range_symbolic(doubled, 0)],
    );

    // Passing means rangeify accepted the expression as a range size; the output
    // is not inspected.
    let (_out, _state) = rangeify(buffered, None).unwrap();
}
// ============================================================================
// Nested BUFFERIZE Tests
// ============================================================================
#[test]
fn test_nested_bufferize_different_ranges() {
    // Shape under test: BUFFERIZE(BUFFERIZE(x, R1), R2) with R1 != R2, i.e. two
    // buffering levels over different iteration spaces.

    // Inner level: constant buffered over a range of size 10 on axis 0.
    let inner = create_bufferize(UOp::native_const(1.0f32), vec![create_range(10, 0)]);

    // Outer level: the inner buffer buffered again over a range of size 20 on axis 1.
    let outer = create_bufferize(inner, vec![create_range(20, 1)]);

    // Robustness check only: rangeify must return `Ok` for the nested input.
    let (_out, _state) = rangeify(outer, None).unwrap();
}
#[test]
fn test_deeply_nested_bufferize() {
    // Three buffering levels: BUFFERIZE(BUFFERIZE(BUFFERIZE(x))).
    // Each level uses its own range size and axis id.
    let seed = UOp::native_const(1.0f32);
    let level1 = create_bufferize(seed, vec![create_range(5, 0)]);
    let level2 = create_bufferize(level1, vec![create_range(10, 1)]);
    let level3 = create_bufferize(level2, vec![create_range(15, 2)]);

    // Robustness check only: the deep nesting must be processed without an error
    // or panic. The output is not inspected.
    let (_out, _state) = rangeify(level3, None).unwrap();
}
// ============================================================================
// Multi-Consumer Pattern Tests
// ============================================================================
#[test]
fn test_bufferize_multiple_consumers() {
    use morok_ir::shape::Shape;
    use morok_ir::SInt;

    // Pattern under test: buf = BUFFERIZE(x); y = f(buf); z = g(buf)
    let shared = create_bufferize(UOp::native_const(1.0f32), vec![create_range(10, 0)]);

    // The scalar operands must match the buffer's shape before they can be combined
    // with it: reshape to an all-ones shape of the same rank, then expand.
    let target_shape = shared.shape().unwrap().unwrap();
    let unit_shape: Shape = target_shape.iter().map(|_| SInt::Const(1)).collect();
    let broadcast_scalar = |value: f32| {
        UOp::native_const(value)
            .try_reshape(&unit_shape)
            .unwrap()
            .try_expand(target_shape)
            .unwrap()
    };

    // Two independent consumers reading the same buffer.
    let added = shared.try_add(&broadcast_scalar(2.0f32)).unwrap();
    let scaled = shared.try_mul(&broadcast_scalar(3.0f32)).unwrap();

    // A SINK joins both consumers into one graph root.
    let root = UOp::sink(vec![added, scaled]);

    // Robustness check only: rangeify must return `Ok` for the multi-consumer graph.
    let (_out, _state) = rangeify(root, None).unwrap();
}
#[test]
fn test_operation_with_multiple_uses() {
    // Pattern under test: x = CONST; buf1 = BUFFERIZE(x); buf2 = BUFFERIZE(x)
    // The same compute node is buffered twice over different ranges.
    let compute = UOp::native_const(1.0f32);

    let r1 = create_range(10, 0);
    // Cloned here because `compute` is needed again for the second buffer.
    let buf1 = create_bufferize(compute.clone(), vec![r1]);

    let r2 = create_range(20, 1);
    // Last use of `compute`: moved instead of cloned.
    let buf2 = create_bufferize(compute, vec![r2]);

    // A SINK joins both buffers into one graph root.
    let sink = UOp::sink(vec![buf1, buf2]);

    // Robustness check only: rangeify must return `Ok`; the output is not inspected.
    let (_result, _ctx) = rangeify(sink, None).unwrap();
}
// ============================================================================
// Complex Indexing Tests
// ============================================================================
#[test]
fn test_index_with_multiple_ranges() {
    // INDEX over a BUFFERIZE, using the same three ranges for buffering and indexing.
    let compute = UOp::native_const(1.0f32);
    let ranges = vec![create_range(10, 0), create_range(20, 1), create_range(5, 2)];

    // The ranges are needed twice, so the vector is cloned once for the buffer and
    // the original is handed to the INDEX below.
    let bufferized = create_bufferize(compute, ranges.clone());

    // `bufferized` has no later use, so it is moved into the INDEX op.
    let index_op = UOp::new(
        Op::Index { buffer: bufferized, indices: ranges.into(), gate: None },
        DType::Float32,
    );

    // Robustness check only: rangeify must return `Ok`; the output is not inspected.
    let (_result, _ctx) = rangeify(index_op, None).unwrap();
}
#[test]
fn test_range_size_mismatch() {
    // One BUFFERIZE mixing a constant-sized range with a non-constant one.
    let fixed_axis = create_range(10, 0);

    // The second range takes its size from a `UOp::param` node.
    // NOTE(review): the meaning of the `0, 1` arguments is not visible here —
    // confirm against `UOp::param`.
    let param_size = UOp::param(0, 1, DType::Index, None);
    let param_axis = create_range_symbolic(param_size, 1);

    let buffered = create_bufferize(UOp::native_const(1.0f32), vec![fixed_axis, param_axis]);

    // Passing means rangeify returned `Ok` for the mixed ranges; the output is not
    // inspected.
    let (_out, _state) = rangeify(buffered, None).unwrap();
}
// ============================================================================
// Dead Axis Detection Tests (is_dead_axis with vmax analysis)
// ============================================================================
#[test]
fn test_is_dead_axis_constant_ranges() {
    use crate::rangeify::indexing::is_dead_axis;

    // Builds a constant-sized range on axis 0 and reports whether it is dead.
    let dead = |size| is_dead_axis(&create_range(size, 0));

    // Sizes 0 and 1 are expected to be dead (noted upstream as vmax = -1 and vmax = 0).
    assert!(dead(0));
    assert!(dead(1));

    // Sizes 2 and 10 are expected to be live (noted upstream as vmax = 1 and vmax = 9).
    assert!(!dead(2));
    assert!(!dead(10));
}
#[test]
fn test_is_dead_axis_symbolic_bounded() {
    use crate::rangeify::indexing::is_dead_axis;

    // Expected dead: the size variable is built with trailing arguments (0, 1).
    // NOTE(review): presumably these are the variable's (min, max) bounds, so the
    // range can never have more than one iteration — confirm against `UOp::var`.
    let unit_size = UOp::var("size", DType::Index, 0, 1);
    let unit_range = create_range_symbolic(unit_size, 0);
    assert!(is_dead_axis(&unit_range));

    // Expected live: same construction with trailing arguments (0, 1024).
    let large_size = UOp::var("size", DType::Index, 0, 1024);
    let large_range = create_range_symbolic(large_size, 0);
    assert!(!is_dead_axis(&large_range));

    // Expected live: same construction with trailing arguments (0, 100).
    // NOTE(review): this case used to be described as "min > 1", but the lower
    // argument is 0, so it currently mirrors the previous case with a smaller
    // upper value. TODO: decide whether a variable with min > 1 was intended.
    let medium_size = UOp::var("size", DType::Index, 0, 100);
    let medium_range = create_range_symbolic(medium_size, 0);
    assert!(!is_dead_axis(&medium_range));
}
#[test]
fn test_is_dead_axis_non_range() {
    use crate::rangeify::indexing::is_dead_axis;

    // `is_dead_axis` is expected to answer `false` for anything that is not a RANGE.

    // Case 1: a plain index constant.
    let zero = UOp::index_const(0);
    assert!(!is_dead_axis(&zero));

    // Case 2: an addition built from that constant.
    let sum = zero.try_add(&zero).unwrap();
    assert!(!is_dead_axis(&sum));
}
#[test]
fn test_symbolic_dead_range_smoke_test() {
    use crate::rangeify::indexing::is_dead_axis;

    // Smoke test: a BUFFERIZE combining a symbolic range expected to be dead with a
    // live constant range must pass through rangeify without an error.
    //
    // It does NOT check that the dead axis is actually eliminated from the result
    // (e.g. that the kernel ends up with 1D instead of 2D ranges); that may depend
    // on elimination passes running in later optimization stages.

    // Size variable built with trailing arguments (0, 1).
    // NOTE(review): the assertion message below calls this "var[1,1]"; the
    // arguments suggest bounds of 0 and 1 instead — confirm against `UOp::var`.
    let size = UOp::var("size", DType::Index, 0, 1);
    let compute = UOp::native_const(1.0f32);

    let dead_range = create_range_symbolic(size, 0);
    let live_range = create_range(10, 1);

    // The helper takes ownership of its ranges, so clones go in and the originals
    // stay available for the assertions at the end.
    let bufferized = create_bufferize(compute, vec![dead_range.clone(), live_range.clone()]);

    // Rangeify should process this without errors; a clone is passed so the input
    // node can still be compared against the output.
    let (result, _ctx) = rangeify(bufferized.clone(), None).unwrap();

    // The returned node must not be the very same allocation as the input.
    assert!(!std::sync::Arc::ptr_eq(&result, &bufferized), "Result should be transformed");

    // `is_dead_axis` must classify both ranges as expected.
    assert!(is_dead_axis(&dead_range), "var[1,1] range should be detected as dead");
    assert!(!is_dead_axis(&live_range), "Range(10) should be detected as live");
}