1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
//! Sorting of an array
use crate::custom_ops::CustomOperation;
use crate::data_types::{array_type, scalar_size_in_bits, vector_type, ScalarType, Type, BIT};
use crate::errors::Result;
use crate::graphs::SliceElement::SubArray;
use crate::graphs::*;
use crate::ops::min_max::{Max, Min};

/// Creates a graph that sorts an array using [Batcher's algorithm](https://math.mit.edu/~shor/18.310/batcher.pdf).
///
/// The graph takes a single input: a bit array of shape `[2^k, b]`, where `b` is the
/// bit size of `st` (one row of `b` bits per element). The network is built
/// 'bottom up': stage `it` (for `it` in `1..=k`) merges sorted groups of size
/// 2<sup>it-1</sup> into sorted groups of size 2<sup>it</sup> using vectorized
/// `Min`/`Max` custom operations. If `st` is not `BIT`, the result of the last stage
/// is converted back to the arithmetic form of `st` before being set as the output.
/// The returned graph is already finalized.
///
/// # Arguments
///
/// * `context` - context where the sorting graph should be created
/// * `k` - base-2 logarithm of the number of elements, i.e., the array has 2<sup>k</sup> elements
///   (2<sup>k</sup> is computed as a `u64`, so `k` must be below 64)
/// * `st` - scalar type of array elements
///
/// # Returns
///
/// Graph that sorts an array
///
/// # Errors
///
/// Propagates any error returned by the graph-building operations, and returns a
/// runtime error if an intermediate node unexpectedly has a non-array type.
pub fn create_batchers_sorting_graph(context: Context, k: u32, st: ScalarType) -> Result<Graph> {
    // Number of bits used to represent one element of type `st`
    let b = scalar_size_in_bits(st.clone());

    // NOTE: The implementation is based on the 'bottom up' approach as described in
    // https://math.mit.edu/~shor/18.310/batcher.pdf.
    // Commenting about the initial few shape changes done with the help of a
    // 16 element array example
    let n = 2_u64.pow(k);
    // Create a graph in a given context that will be used for sorting
    let b_graph = context.create_graph()?;
    // Create an input node accepting binary arrays of shape [n, b]
    let i_a = b_graph.input(Type::Array(vec![n, b], BIT))?;
    // Stash of nodes used as input of each iteration of the following loop;
    // stage_ops[it - 1] is the input of stage 'it' and stage_ops[k] is the final result
    let mut stage_ops = vec![i_a];
    // The following loop, over 'it', corresponds to sorting (SORT()) operation
    // in https://math.mit.edu/~shor/18.310/batcher.pdf.
    for it in 1..(k + 1) {
        let num_classes: u64 = 2_u64.pow(it);
        let num_class_reps = n / num_classes;
        let data_to_sort = stage_ops[(it - 1) as usize].clone();
        // For it==1, we are sorting into pairs i.e. we will have pairs of sorted keys
        // For it==2, we are creating sorted groups of size 4
        // For it==3, we are creating sorted groups of size 8
        // For it==4, we are creating sorted groups of size 16 and so on

        // For the purposes of the discussion, we will temporarily disregard the
        // final dimension i.e. the bit dimension so as to understand how the
        // jiggling of array shape is happening for the elements involved

        // Divide the keys into 2^{it} classes or groups
        let global_a_reshape = b_graph.reshape(
            data_to_sort.clone(),
            array_type(vec![num_class_reps, num_classes, b], BIT),
        )?;

        // 1-D Array Indices:                   0  1  2  3      14 15
        // At it==1, we would have 2^1 classes: 0, 1, 0, 1, ..., 0, 1
        // Now, global_a_reshape shape (2-D shape), in terms of indices, looks like:
        // class0|  class1|
        // ______|________|
        //      0|       1|
        //      2|       3|
        //      .|       .|
        //      .|       .|
        //     12|      13|
        //     14|      15|

        // 1-D Array Indices:                   0  1  2  3  4  5      10 11 12 13 14 15
        // At it==2, we would have 2^2 classes: 0, 1, 2, 3, 0, 1, ..., 2, 3, 0, 1, 2, 3
        // Now, 2-D global_a_reshape shape, in terms of indices, looks like:
        // class0  class1  class2  class3
        //      0       1       2       3
        //      4       5       6       7
        //      8       9      10      11
        //      12     13      14      15

        // 1-D Array Indices:     0  1  2  3  4  5  6  7  8  9 10 11 12 13 14 15
        // At it==3, 2^3 classes: 0, 1, 2, 3, 4, 5, 6, 7, 0, 1, 2, 3, 4, 5, 6, 7
        // Now, 2-D global_a_reshape shape, in terms of indices, looks like:
        // class0  class1  class2  class3 class4  class5  class6  class7
        //      0       1       2       3      4       5       6       7
        //      8       9      10      11     12      13      14      15

        // Permute the axes to perform the transpose operation
        // This is done so that each row now corresponds to a single class or group
        let mut global_chunks_a = b_graph.permute_axes(global_a_reshape, vec![1, 0, 2])?;
        // Based on 'it' the global_chunks_a shape looks like
        // For it == 1, locations of flat (1-D) indices
        // class0:  0, 2, ..., 12, 14
        // class1:  1, 3, ..., 13, 15

        // For it == 2, locations of flat (1-D) indices
        // class0:  [[0, 4, 8, 12]
        // class1:   [1, 5, 9, 13]
        // class2:   [2, 6, 10, 14]
        // class3:   [3, 7, 11, 15]]

        // For it == 3, locations of flat (1-D) indices
        // class0:  [[0, 8]
        // class1:   [1, 9]
        // class2:   [2, 10]
        // class3:   [3, 11]
        // class4:   [4, 12]
        // class5:   [5, 13]
        // class6:   [6, 14]
        // class7:   [7, 15]]

        // Results of every comparison round of this stage; only the most recent
        // entry is read back (see the reshape at the end of the inner loop)
        let mut intermediate_chunks_a: Vec<Node> = vec![];
        // The below loop, over 'i', corresponds to the MERGE() operation in https://math.mit.edu/~shor/18.310/batcher.pdf
        // - In the 'bottom up' approach, the operations contained in loop are
        // also referred as 'round(s) of comparisons' in https://math.mit.edu/~shor/18.310/batcher.pdf
        // - For groups or classes of size 2^{it}, you would require 'it' rounds
        // of comparisons
        // - The operations are vectorized to leverage the inherent parallelism
        // - For each group or class to be sorted, initially pairs are formed for
        // sorting then groups of 4 are formed for sorting, likewise for 8, 16 and so on.
        // - Technically, here, the number of dimensions are 4, however, we will
        // ignore the innermost dimension that corresponds to bits as it would
        // be handled by the custom_operations Min{} and Max{} and is not as relevant
        // to the Batcher's algorithm logic
        // - Formation of sub-groups of key sizes 2 or 4, 8, 16, ... for each group
        // of size 2^{it} happens along the outermost axis, whose size is
        // referenced here by 'chunks_a_sz_z'
        for i in (0..it).rev() {
            let chunks_a_sz_y = 2_u64.pow(i);
            let chunks_a_sz_z = 2_u64.pow(it - i); // == n / (chunks_a_sz_y * num_class_reps)

            // Reshape to create an additional dimension that corresponds to
            // each sub-group (of sizes 2, 4, 8, 16, ..., 2^{it-i}) within
            // original group of 2^{it} keys, which is to be sorted
            let chunks_a = b_graph.reshape(
                global_chunks_a.clone(),
                array_type(vec![chunks_a_sz_z, chunks_a_sz_y, num_class_reps, b], BIT),
            )?;

            // For it==1 and i==0,
            // the two sub-groups are placed side-by-side along the outermost (Z) axis
            // for sorting pairs; Y-axis (height) is 1 and X-axis (breadth) is 8
            // i.e. Z_0 corresponds to class0, Z_1 corresponds to class1 and so on

            // For it==2 and i==1,
            // sorting the groups of 4 by first sorting the pairs within them
            // chunks_a:
            // [ [[0, 4, 8, 12],    Values    [ [[min(0, 1), min(4, 5), min(8, 9), min(12, 13)],
            //    [1, 5, 9, 13]],   =====>       [max(0, 1), max(4, 5), max(8, 9), max(12, 13)]],
            //   [[2, 6, 10, 14],               [[min(2, 3), min(6, 7), min(10, 11), min(14, 15)],
            //    [3, 7, 11, 15]]                [max(2, 3), max(6, 7), max(10, 11), max(14, 15)]]
            // ]                              ]
            //
            // For it==2 and i==0,
            // sorting the groups of 4 by sorting all 4 elements,
            // i.e., Z_0 corresponds to class0, Z_1 corresponds to class1 and so on
            // chunks_a:
            // [
            //      [[min(0, 2), max(0, 2), min(4, 6), max(4, 6)]],
            //      [[min(8, 10), max(8, 10), min(12, 14), max(12, 14)]],
            //      [[min(1, 3), max(1, 3), min(5, 7), max(5, 7)]],
            //      [[min(9, 11), max(9, 11), min(13, 15), max(13, 15)]]
            //
            // ]

            // Recover the shape and scalar type of chunks_a so that the pieces
            // extracted below can be reshaped back consistently
            let (chunks_a_shape, chunks_a_scalar_t) = match chunks_a.get_type()? {
                Type::Array(shape, scalar_type) => (shape, scalar_type),
                _ => return Err(runtime_error!("Array Type not found")),
            };

            // Get the first class elements i.e. Z_0
            let single_first_element = b_graph.get(chunks_a.clone(), vec![0])?;
            // For it==1, i==0, single_first_element shape would be [1, 8, x]
            // For it==2, i==1, single_first_element shape would be [2, 4, x]
            // For it==2, i==0, single_first_element shape would be [1, 4, x]

            // If first step, then arrange odd-even adjacent pairs of keys into ordered pairs
            if i == it - 1 {
                // Code to sort the only two chunks/halves

                // Here, we are dealing with just two classes: Z_0 (odds) and Z_1 (evens)
                // For it==1, i==0,
                // Z_0 and Z_1 shapes are [1, 8, x]

                // For it==2, i==1,
                // Z_0 and Z_1 shapes are [2, 4, x]

                // Get the group of odd indexed keys from each group or class, i.e., Z_{0}
                let uu = single_first_element;

                // Get the group of even indexed keys from each group or class, i.e., Z_{1}
                let vv = b_graph.get(chunks_a.clone(), vec![1])?;

                // Get minimums from both the classes
                let chunks_a_0 = b_graph
                    .custom_op(CustomOperation::new(Min {}), vec![uu.clone(), vv.clone()])?;
                // For it==1, i==0, chunks_a_0 = [[min(0, 1), min(2, 3), ..., min(12, 13), min(14, 15)]]
                // For it==2, i==1, chunks_a_0 = [[min(0, 2), min(4, 6), min(8, 10), min(12, 14)],
                //                                [min(1, 3), min(5, 7), min(9, 11), min(13, 15)]]

                // Get maximums from both the classes
                let chunks_a_1 = b_graph
                    .custom_op(CustomOperation::new(Max {}), vec![uu.clone(), vv.clone()])?;
                // For it==1, i==0, chunks_a_1 = [[max(0, 1), max(2, 3), ..., max(12, 13), max(14, 15)]]
                // For it==2, i==1, chunks_a_1 = [[max(0, 2), max(4, 6), max(8, 10), max(12, 14)],
                //                                [max(1, 3), max(5, 7), max(9, 11), max(13, 15)]]

                // Collect these maximums and minimums together for reshaping later
                let a_combined = b_graph.create_tuple(vec![chunks_a_0, chunks_a_1])?;
                // For it==1, i==0,
                // a_combined = [(min(0, 1), max(0, 1)), (min(2, 3), max(2, 3)), ..., (min(12, 13), max(12, 13)), (min(14, 15), max(14, 15))]
                // For it==2, i==1,
                // a_combined = [[(min(0, 2), max(0, 2)), (min(4, 6), max(4, 6)), (min(8, 10), max(8, 10)), (min(12, 14), max(12, 14))],
                //               [(min(1, 3), max(1, 3)), (min(5, 7), max(5, 7)), (min(9, 11), max(9, 11)), (min(13, 15), max(13, 15))]]

                // Reshape these combined elements back into a vector shape
                let interm_chunks_a = b_graph.reshape(
                    a_combined,
                    vector_type(
                        chunks_a_sz_z,
                        array_type(vec![chunks_a_sz_y, num_class_reps, b], chunks_a_scalar_t),
                    ),
                )?;
                // For it==1, i==0,
                // i.e., chunks_a's shape [2, 1, 8, x] for further processing
                // interm_chunks_a =    <[min(0, 1), max(0, 1), min(2, 3), max(2, 3), min(4, 5), max(4, 5), min(6, 7), max(6, 7)]>,
                //                      <[min(8, 9), max(8, 9), min(10, 11), max(10, 11), min(12, 13), max(12, 13), min(14, 15), max(14, 15)]>

                // For it==2, i==1,
                // chunks_a's shape [2, 2, 4, x]
                // interm_chunks_a = <[ [min(0, 2), max(0, 2), min(4, 6), max(4, 6)],
                //                      [min(8, 10), max(8, 10), min(12, 14), max(12, 14)] ],
                //                    [ [min(1, 3), max(1, 3), min(5, 7), max(5, 7)],
                //                      [min(9, 11), max(9, 11), min(13, 15), max(13, 15)] ]>

                // Convert these combined elements back to an array of original shape
                intermediate_chunks_a.push(b_graph.vector_to_array(interm_chunks_a)?);
                // For it==1, i==0,
                // i.e., into the chunks_a's shape [2, 1, 8, x] for further processing
                // intermediate_chunks_a[intermediate_chunks_a.len()-1] =
                // [
                //  [[min(0, 1), max(0, 1), min(2, 3), max(2, 3), min(4, 5), max(4, 5), min(6, 7), max(6, 7)]],
                //
                //  [[min(8, 9), max(8, 9), min(10, 11), max(10, 11), min(12, 13), max(12, 13), min(14, 15), max(14, 15)]]
                // ]

                // For it==2, i==1,
                // i.e., into the chunks_a's shape [2, 2, 4, x] for further processing
                // intermediate_chunks_a[intermediate_chunks_a.len()-1] =
                //  [
                //      [[min(0, 2), max(0, 2), min(4, 6), max(4, 6)],
                //       [min(8, 10), max(8, 10), min(12, 14), max(12, 14)]],
                //      [[min(1, 3), max(1, 3), min(5, 7), max(5, 7)],
                //       [min(9, 11), max(9, 11), min(13, 15), max(13, 15)]]
                //  ]
            } else {
                // This else block corresponds to the COMP() operations
                // specified within the MERGE() function in (https://math.mit.edu/~shor/18.310/batcher.pdf, p. 3) and
                // if x_{1}, x_{2}, ..., x_{n} are the keys to be sorted then
                // this COMP is operated as COMP(x_{2}, x_{3}), COMP(x_{4}, x_{5}), ...,
                // COMP(x_{n-2}, x_{n-1}).
                // In this case, we would not be considering terminal sub-groups
                // i.e. Z_{0} and Z_{2^{it-i}-1}

                // Set the shape of Z_0
                let a_single_first_elem = b_graph.reshape(
                    single_first_element,
                    array_type(
                        chunks_a_shape[1..chunks_a_shape.len()].to_vec(),
                        chunks_a_scalar_t.clone(),
                    ),
                )?;
                // For it==2, i==0,
                // a_single_first_elem =
                // [min(0, 2), max(0, 2), min(4, 6), max(4, 6)]

                // Obtain all the odd components of Z, except the first and last one,
                // i.e., Z_{i} s.t. 1 <=i < 2^{it-i}-1 && i % 2 == 1
                let uu = b_graph
                    .get_slice(chunks_a.clone(), vec![SubArray(Some(1), Some(-1), Some(2))])?;
                // For it==2, i==0, uu shape = [1, 1, 4, x], uu =
                // [
                //      [[min(8, 10), max(8, 10), min(12, 14), max(12, 14)]],
                // ]

                // Obtain all the even components of Z, except the first one i.e.
                // Z_{i} s.t. 2 <= i < 2^{it-i} && i % 2 == 0
                let vv =
                    b_graph.get_slice(chunks_a.clone(), vec![SubArray(Some(2), None, Some(2))])?;
                // For it==2, i==0, vv shape = [1, 1, 4, x], vv =
                // [
                //      [[min(1, 3), max(1, 3), min(5, 7), max(5, 7)]],
                // ]

                // Obtain the minimum of these two arrays - uu and vv
                let chunks_a_evens = b_graph
                    .custom_op(CustomOperation::new(Min {}), vec![uu.clone(), vv.clone()])?;
                // For it==2, i==0, chunks_a_evens shape = [1, 1, 4, x], chunks_a_evens =
                // [
                //      [[min(8, 10, 1, 3), min(max(8, 10), max(1, 3)), min(12, 14, 5, 7), min(max(12, 14), max(5, 7))]]
                // ]

                // Obtain the maximum of these two arrays - uu and vv
                let chunks_a_odds = b_graph
                    .custom_op(CustomOperation::new(Max {}), vec![uu.clone(), vv.clone()])?;
                // For it==2, i==0, chunks_a_odds shape = [1, 1, 4, x], chunks_a_odds =
                // [
                //      [[max(min(8, 10), min(1, 3)), max(8, 10, 1, 3), max(min(12, 14), min(5, 7)), max(12, 14, 5, 7)]]
                // ]

                // Convert the array to vector and remove the extra Z-dimension
                let v_non_terminal_evens = b_graph.array_to_vector(chunks_a_evens)?;
                // For it==2, i==0, v_non_terminal_evens shape = [1, 4, x]<1>
                // v_non_terminal_evens =
                // <[min(8, 10, 1, 3), min(max(8, 10), max(1, 3)), min(12, 14, 5, 7), min(max(12, 14), max(5, 7))]>

                // Convert the array to vector and remove the extra Z-dimension
                let v_non_terminal_odds = b_graph.array_to_vector(chunks_a_odds)?;
                // For it==2, i==0, v_non_terminal_odds shape = [1, 4, x]<1>
                // v_non_terminal_odds =
                // <[max(min(8, 10), min(1, 3)), max(8, 10, 1, 3), max(min(12, 14), min(5, 7)), max(12, 14, 5, 7)]>

                // Zip both the results together
                let v_non_term_elems =
                    b_graph.zip(vec![v_non_terminal_evens, v_non_terminal_odds])?;
                // For it==2, i==0, v_non_term_elems shape = ((1, 4, x)(1, 4, x))<1>
                // v_non_term_elems =
                // <(min(8, 10, 1, 3), max(min(8, 10), min(1, 3))),
                //  (min(max(8, 10), max(1, 3)), max(8, 10, 1, 3)),
                //  (min(12, 14, 5, 7), max(min(12, 14), min(5, 7))),
                //  (min(max(12, 14), max(5, 7)), max(12, 14, 5, 7))>

                // In a similar way to the first element i.e. Z_{0}, extract the last element
                let single_last_elem =
                    b_graph.get(chunks_a.clone(), vec![chunks_a_shape[0] - 1])?;
                // For it==2, i==0, single_last_element shape would be [1, 4, x]

                // Set the shape of Z_{2^{it-i}-1} to [1, 4, x]
                let a_single_last_elem = b_graph.reshape(
                    single_last_elem,
                    array_type(
                        chunks_a_shape[1..chunks_a_shape.len()].to_vec(),
                        chunks_a_scalar_t.clone(),
                    ),
                )?;
                // For it==2, i==0,
                // a_single_last_elem =
                //  [min(9, 11), max(9, 11), min(13, 15), max(13, 15)]

                // Create a tuple of Z: (first element-Z_{0}, vector, last element-Z_{2^{it-i}-1})
                let v_combined = b_graph.create_tuple(vec![
                    a_single_first_elem,
                    v_non_term_elems,
                    a_single_last_elem,
                ])?;
                // For it==2, i==0,
                // v_combined =
                // ([min(0, 2), max(0, 2), min(4, 6), max(4, 6)],
                //  <(min(8, 10, 1, 3), max(min(8, 10), min(1, 3))),
                //  (min(max(8, 10), max(1, 3)), max(8, 10, 1, 3)),
                //  (min(12, 14, 5, 7), max(min(12, 14), min(5, 7))),
                //  (min(max(12, 14), max(5, 7)), max(12, 14, 5, 7))>,
                //  [min(9, 11), max(9, 11), min(13, 15), max(13, 15)]
                // )

                // Reshape the tuple back into vector form
                let v_chunk_a = b_graph.reshape(
                    v_combined,
                    vector_type(
                        chunks_a_shape[0],
                        array_type(
                            chunks_a_shape[1..chunks_a_shape.len()].to_vec(),
                            chunks_a_scalar_t,
                        ),
                    ),
                )?;
                // For it==2, i==0,
                // v_chunk_a's shape is {[1, 4, x]}<4> i.e. 4 components, each an
                // array of size [1, 4, x]
                // v_chunk_a =
                // <
                //   [min(0, 2), max(0, 2), min(4, 6), max(4, 6)],
                //   [min(8, 10, 1, 3), max(min(8, 10), min(1, 3)), min(max(8, 10), max(1, 3)), max(8, 10, 1, 3)],
                //   [min(12, 14, 5, 7), max(min(12, 14), min(5, 7)), min(max(12, 14), max(5, 7)), max(12, 14, 5, 7)],
                //   [min(9, 11), max(9, 11), min(13, 15), max(13, 15)]
                // >
                //

                // Convert the vector form to the array form
                intermediate_chunks_a.push(b_graph.vector_to_array(v_chunk_a)?);
                // For it==2, i==0,
                // intermediate_chunks_a[intermediate_chunks_a.len()-1] =
                // [
                //   [min(0, 2), max(0, 2), min(4, 6), max(4, 6)],
                //   [min(8, 10, 1, 3), max(min(8, 10), min(1, 3)), min(max(8, 10), max(1, 3)), max(8, 10, 1, 3)],
                //   [min(12, 14, 5, 7), max(min(12, 14), min(5, 7)), min(max(12, 14), max(5, 7)), max(12, 14, 5, 7)],
                //   [min(9, 11), max(9, 11), min(13, 15), max(13, 15)]
                // ]
            }

            // Reshape/Merge it back into 2-D from the 3-D we created for performing
            // the Min/Max compare and switches
            global_chunks_a = b_graph.reshape(
                intermediate_chunks_a[(intermediate_chunks_a.len() - 1) as usize].clone(),
                array_type(vec![num_classes, num_class_reps, b], BIT),
            )?;
            // For it==1, i==0, reshape latest intermediate_chunk_a from [2, 1, 8, x] -> [2, 8, x] for next global_chunks_a
            // global_chunks_a:
            //  [
            //    [min(0, 1), max(0, 1), min(2, 3), max(2, 3), min(4, 5), max(4, 5), min(6, 7), max(6, 7)],
            //    [min(8, 9), max(8, 9), min(10, 11), max(10, 11), min(12, 13), max(12, 13), min(14, 15), max(14, 15)]
            //  ]

            // For it==2, i==1, reshape latest intermediate_chunk_a from [2, 2, 4, x] -> [4, 4, x] for next global_chunks_a
            // global_chunks_a:
            //  [
            //      [min(0, 2), max(0, 2), min(4, 6), max(4, 6)],
            //      [min(8, 10), max(8, 10), min(12, 14), max(12, 14)],
            //      [min(1, 3), max(1, 3), min(5, 7), max(5, 7)],
            //      [min(9, 11), max(9, 11), min(13, 15), max(13, 15)]
            //  ]
        }

        // Permute axes to revert original transpose
        let aa_transposed = b_graph.permute_axes(global_chunks_a.clone(), vec![1, 0, 2])?;
        // For it==1, i==0, aa_transposed shape: [8, 2, x], with X-axis representing the classes
        // aa_transposed:
        // [
        //  [min(0, 1), max(0, 1)],
        //  [min(2, 3), max(2, 3)],
        //  [min(4, 5), max(4, 5)],
        //  [min(6, 7), max(6, 7)],
        //  [min(8, 9), max(8, 9)],
        //  [min(10, 11), max(10, 11)],
        //  [min(12, 13), max(12, 13)],
        //  [min(14, 15), max(14, 15)]
        // ]

        // Reshape data to flatten into shape [n, x] for further processing
        stage_ops.push(b_graph.reshape(aa_transposed, array_type(vec![n, b], BIT))?)
        // In terms of the initial index positions of elements, this looks like:

        // data idx:          0   1   2   3   4   5   6   7   8   9  10  11  12  13  14  15
        // data:           [100, 99, 98, 97, 96, 95, 94, 93, 92, 91, 90, 89, 88, 87, 86, 85]
        // For it==1, i==0, aa_transposed ==
        // idx:                  0         1          2          3          4          5          6          7          8          9         10           11           12           13           14           15
        // Current round idx: [min(0, 1), max(0, 1), min(2, 3), max(2, 3), min(4, 5), max(4, 5), min(6, 7), max(6, 7), min(8, 9), max(8, 9), min(10, 11), max(10, 11), min(12, 13), max(12, 13), min(14, 15), max(14, 15)]
        // data idx:          0    1   2   3   4   5   6   7   8   9  10  11  12  13  14  15
        // data post ops.:  [99, 100, 97, 98, 95, 96, 93, 94, 91, 92, 89, 90, 87, 88, 85, 86]
    }
    // Convert output from the binary form to the arithmetic form
    // (skipped for BIT, where the binary form already is the result)
    let output = if st != BIT {
        stage_ops[k as usize].b2a(st)?
    } else {
        stage_ops[k as usize].clone()
    };
    // Before computation every graph should be finalized, which means that it should have a designated output node
    // Here this is done by calling `set_output_node` on the graph
    b_graph.set_output_node(output)?;
    // Finalization checks that the output node of the graph is set. After finalization the graph can't be changed
    b_graph.finalize()?;

    Ok(b_graph)
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::custom_ops::run_instantiation_pass;
    use crate::data_types::{ScalarType, BIT, INT16, INT32, INT64};
    use crate::data_values::Value;
    use crate::evaluators::random_evaluate;
    use crate::random::PRNG;
    use std::cmp::Reverse;

    /// Checks the sorting network graph on a pseudo-randomly generated input.
    ///
    /// An array of 2<sup>k</sup> elements of type `st` is drawn from a PRNG with
    /// a fixed seed, evaluated through the instantiated sorting graph, and the
    /// result is compared with the same data sorted by the standard library.
    ///
    /// # Arguments
    ///
    /// * `k` - base-2 logarithm of the number of elements (the array has 2<sup>k</sup> elements)
    /// * `st` - scalar type of array elements
    fn test_large_vec_unsigned_batchers_sorting(k: u32, st: ScalarType) -> Result<()> {
        let ctx = create_context()?;
        let sorting_graph = create_batchers_sorting_graph(ctx.clone(), k, st.clone())?;
        ctx.set_main_graph(sorting_graph.clone())?;
        ctx.finalize()?;

        // Custom operations (Min/Max) must be instantiated before evaluation
        let instantiated = run_instantiation_pass(sorting_graph.get_context())?;

        // Fixed seed keeps the generated input reproducible between runs
        let seed = b"\xB6\xD7\x1A\x2F\x88\xC1\x12\xBA\x3F\x2E\x17\xAB\xB7\x46\x15\x9A";
        let mut prng = PRNG::new(Some(*seed))?;
        let input_t = array_type(vec![2_u64.pow(k)], st);
        let input = prng.get_random_value(input_t.clone())?;

        // Reference answer: the same values sorted with the standard library
        let mut expected = input.to_flattened_array_u64(input_t.clone())?;
        expected.sort_unstable();

        let evaluated = random_evaluate(
            instantiated.mappings.get_graph(sorting_graph),
            vec![input],
        )?;
        let actual = evaluated.to_flattened_array_u64(input_t)?;
        assert_eq!(expected, actual);
        Ok(())
    }

    /// Checks the sorting network graph on explicitly provided data.
    ///
    /// The data is evaluated through the instantiated sorting graph and the
    /// result is compared with the same data sorted by the standard library.
    ///
    /// # Arguments
    ///
    /// * `k` - base-2 logarithm of the number of elements (the array has 2<sup>k</sup> elements)
    /// * `st` - scalar type of array elements
    /// * `data` - flattened input values; callers pass exactly 2<sup>k</sup> of them
    fn test_unsigned_batchers_sorting_graph_helper(
        k: u32,
        st: ScalarType,
        data: Vec<u64>,
    ) -> Result<()> {
        let ctx = create_context()?;
        let sorting_graph = create_batchers_sorting_graph(ctx.clone(), k, st.clone())?;
        ctx.set_main_graph(sorting_graph.clone())?;
        ctx.finalize()?;

        // Custom operations (Min/Max) must be instantiated before evaluation
        let instantiated = run_instantiation_pass(sorting_graph.get_context())?;

        let input = Value::from_flattened_array(&data, st.clone())?;
        let output_t = array_type(vec![data.len() as u64], st);
        let evaluated = random_evaluate(
            instantiated.mappings.get_graph(sorting_graph),
            vec![input],
        )?;
        let actual = evaluated.to_flattened_array_u64(output_t)?;

        // Reference answer: the same values sorted with the standard library
        let mut expected = data;
        expected.sort_unstable();
        assert_eq!(expected, actual);
        Ok(())
    }

    /// Tests well-formed sorting graphs for correctness.
    /// The parameters varied are `k` and `st`; the input data is tried unsorted,
    /// already sorted and sorted in decreasing order.
    #[test]
    fn test_wellformed_unsigned_batchers_sorting_graph() -> Result<()> {
        // Same INT16 values in three orderings: shuffled, ascending, descending
        let shuffled: Vec<u64> = vec![65535, 0, 2, 32768];
        let mut ascending = shuffled.clone();
        ascending.sort_unstable();
        let mut descending = ascending.clone();
        descending.sort_by_key(|w| Reverse(*w));
        for input in vec![shuffled, ascending, descending] {
            test_unsigned_batchers_sorting_graph_helper(2, INT16, input)?;
        }

        // 64-bit values including the maximum representable one
        test_unsigned_batchers_sorting_graph_helper(
            2,
            INT64,
            vec![548890456, 402403639693304868, u64::MAX, 999790788],
        )?;

        // Degenerate case: a single element (k == 0)
        test_unsigned_batchers_sorting_graph_helper(0, INT32, vec![643082556])?;

        // Single-bit elements
        test_unsigned_batchers_sorting_graph_helper(2, BIT, vec![1, 0, 0, 1])?;

        // Larger pseudo-random inputs
        test_large_vec_unsigned_batchers_sorting(7, BIT)?;
        test_large_vec_unsigned_batchers_sorting(4, INT64)?;

        Ok(())
    }
}