//! Sorting of an array using Batcher's algorithm
use crate::custom_ops::CustomOperation;
use crate::data_types::{array_type, scalar_size_in_bits, vector_type, ScalarType, Type, BIT};
use crate::errors::Result;
use crate::graphs::SliceElement::SubArray;
use crate::graphs::*;
use crate::ops::min_max::{Max, Min};
/// Creates a graph that sorts an array using [Batcher's algorithm](https://math.mit.edu/~shor/18.310/batcher.pdf).
///
/// The graph has a single input: a binary array of shape `[2^k, b]`, where `b` is
/// the size in bits of `st`, i.e. every row holds the bits of one element.
/// Elements are ordered with the [Min] and [Max] custom operations: the smaller
/// element of every compared pair is moved towards the beginning of the array.
/// If `st` is not `BIT`, the sorted bits are converted back to the arithmetic
/// form of type `st`; otherwise the binary array itself is the output.
///
/// # Arguments
///
/// * `context` - context where the sorting graph should be created
/// * `k` - binary logarithm of the array length, i.e. the array contains 2<sup>k</sup> elements
/// * `st` - scalar type of array elements
///
/// # Returns
///
/// Finalized graph that sorts an array
///
/// # Errors
///
/// Returns an error if 2<sup>k</sup> does not fit into 64 bits or if any of the
/// underlying graph operations fails.
pub fn create_batchers_sorting_graph(context: Context, k: u32, st: ScalarType) -> Result<Graph> {
    let b = scalar_size_in_bits(st.clone());
    // NOTE: The implementation is based on the 'bottom up' approach as described in
    // https://math.mit.edu/~shor/18.310/batcher.pdf.
    // The comments below use a 16-element array (k == 4) as a running example and
    // refer to elements by their index in the input array.

    // Number of elements to sort. 2^k overflows u64 for k >= 64, which is reported
    // as an error instead of panicking (debug builds) or wrapping (release builds).
    let n = match 2_u64.checked_pow(k) {
        Some(len) => len,
        None => return Err(runtime_error!("Array length 2^k does not fit into 64 bits")),
    };
    // Create a graph in a given context that will be used for sorting
    let b_graph = context.create_graph()?;
    // Create an input node accepting binary arrays of shape [n, b]
    let i_a = b_graph.input(Type::Array(vec![n, b], BIT))?;
    // Output of the most recently built stage; it is the input of the next one
    let mut sorted_so_far = i_a;
    // The following loop, over 'it', corresponds to the sorting (SORT()) operation
    // in https://math.mit.edu/~shor/18.310/batcher.pdf unrolled bottom up:
    // stage 'it' merges sorted blocks of 2^{it-1} consecutive elements into
    // sorted blocks of 2^{it} consecutive elements (pairs for it==1, groups of 4
    // for it==2, groups of 8 for it==3 and so on).
    for it in 1..=k {
        // Size of a block sorted by this stage and the number of such blocks
        let num_classes: u64 = 2_u64.pow(it);
        let num_class_reps = n / num_classes;
        // For the purposes of the discussion, the final (bit) dimension is
        // disregarded; it is only consumed by Min{} and Max{}.
        // Put every block of 2^{it} consecutive elements into its own row ...
        let global_a_reshape = b_graph.reshape(
            sorted_so_far,
            array_type(vec![num_class_reps, num_classes, b], BIT),
        )?;
        // ... and transpose, so that row p (a 'class') holds the elements at
        // offset p of every block and column r corresponds to block r.
        // In terms of the indices of the stage input:
        //   it == 1, shape [2, 8]: class0 = 0, 2, ..., 14; class1 = 1, 3, ..., 15
        //   it == 2, shape [4, 4]: class p = p, p + 4, p + 8, p + 12
        //   it == 3, shape [8, 2]: class p = p, p + 8
        // All compare-exchange operations below act on whole rows, i.e. they are
        // performed for all blocks in parallel.
        let mut global_chunks_a = b_graph.permute_axes(global_a_reshape, vec![1, 0, 2])?;
        // The below loop, over 'i', corresponds to the MERGE() operation in
        // https://math.mit.edu/~shor/18.310/batcher.pdf
        // - In the 'bottom up' approach, each iteration is one 'round of
        //   comparisons'; merging into blocks of size 2^{it} takes 'it' rounds
        // - In round 'i' the rows are split along a new outermost axis into
        //   2^{it-i} chunks Z_0, Z_1, ... of 2^{i} rows each
        for i in (0..it).rev() {
            // Number of rows per chunk
            let chunks_a_sz_y = 2_u64.pow(i);
            // Number of chunks
            let chunks_a_sz_z = 2_u64.pow(it - i);
            let chunks_a = b_graph.reshape(
                global_chunks_a,
                array_type(vec![chunks_a_sz_z, chunks_a_sz_y, num_class_reps, b], BIT),
            )?;
            let (chunks_a_shape, chunks_a_scalar_t) = match chunks_a.get_type()? {
                Type::Array(shape, scalar_type) => (shape, scalar_type),
                _ => return Err(runtime_error!("Array Type not found")),
            };
            // Type of a single chunk Z_j: [chunks_a_sz_y, num_class_reps, b]
            let chunk_t = array_type(chunks_a_shape[1..].to_vec(), chunks_a_scalar_t);
            // Get the first chunk, i.e., Z_0
            let first_chunk = b_graph.get(chunks_a.clone(), vec![0])?;
            // Build a tuple containing the chunks after this round of comparisons
            let exchanged_chunks = if i == it - 1 {
                // First round: there are exactly two chunks, Z_0 holding the
                // first (sorted) half of every block and Z_1 holding the second
                // (sorted) half. Offset y is compared with offset y + 2^{it-1}.
                // For it==1, i==0: elements (0, 1), (2, 3), ..., (14, 15) are compared
                // For it==2, i==1: offsets (0, 2) and (1, 3) of every block of 4
                let uu = first_chunk;
                let vv = b_graph.get(chunks_a, vec![1])?;
                // Elementwise minimums replace Z_0
                let chunks_a_0 = b_graph
                    .custom_op(CustomOperation::new(Min {}), vec![uu.clone(), vv.clone()])?;
                // Elementwise maximums replace Z_1
                let chunks_a_1 =
                    b_graph.custom_op(CustomOperation::new(Max {}), vec![uu, vv])?;
                b_graph.create_tuple(vec![chunks_a_0, chunks_a_1])?
            } else {
                // This else block corresponds to the COMP() operations
                // specified within the MERGE() function in
                // (https://math.mit.edu/~shor/18.310/batcher.pdf, p. 3):
                // COMP(x_2, x_3), COMP(x_4, x_5), ..., COMP(x_{n-2}, x_{n-1}),
                // i.e. chunk Z_1 is compared with Z_2, Z_3 with Z_4 and so on,
                // while the terminal chunks Z_0 and Z_{2^{it-i}-1} stay as they are.
                // For it==2, i==0: offsets (1, 2) of every block of 4 are compared

                // Set the shape of Z_0
                let a_single_first_elem = b_graph.reshape(first_chunk, chunk_t.clone())?;
                // Odd chunks except the last one: Z_1, Z_3, ..., Z_{2^{it-i}-3}
                let uu = b_graph
                    .get_slice(chunks_a.clone(), vec![SubArray(Some(1), Some(-1), Some(2))])?;
                // Even chunks except the first one: Z_2, Z_4, ..., Z_{2^{it-i}-2}
                let vv =
                    b_graph.get_slice(chunks_a.clone(), vec![SubArray(Some(2), None, Some(2))])?;
                // Elementwise minimums of the compared chunk pairs
                let chunks_a_evens = b_graph
                    .custom_op(CustomOperation::new(Min {}), vec![uu.clone(), vv.clone()])?;
                // Elementwise maximums of the compared chunk pairs
                let chunks_a_odds =
                    b_graph.custom_op(CustomOperation::new(Max {}), vec![uu, vv])?;
                // Turn the outermost axis into a vector of chunks
                let v_non_terminal_evens = b_graph.array_to_vector(chunks_a_evens)?;
                let v_non_terminal_odds = b_graph.array_to_vector(chunks_a_odds)?;
                // Interleave the results: <(min_1, max_1), (min_2, max_2), ...>
                let v_non_term_elems =
                    b_graph.zip(vec![v_non_terminal_evens, v_non_terminal_odds])?;
                // In a similar way to Z_0, extract the last chunk
                let single_last_elem = b_graph.get(chunks_a, vec![chunks_a_shape[0] - 1])?;
                // Set the shape of Z_{2^{it-i}-1}
                let a_single_last_elem = b_graph.reshape(single_last_elem, chunk_t.clone())?;
                // (Z_0, <(min_1, max_1), (min_2, max_2), ...>, Z_{2^{it-i}-1})
                b_graph.create_tuple(vec![
                    a_single_first_elem,
                    v_non_term_elems,
                    a_single_last_elem,
                ])?
            };
            // Reshape the tuple into a vector of 2^{it-i} chunks
            let v_chunk_a =
                b_graph.reshape(exchanged_chunks, vector_type(chunks_a_sz_z, chunk_t))?;
            // Convert the vector form to the array form of shape
            // [chunks_a_sz_z, chunks_a_sz_y, num_class_reps, b]
            let exchanged_array = b_graph.vector_to_array(v_chunk_a)?;
            // Merge the two outermost axes back into the axis of classes
            global_chunks_a = b_graph.reshape(
                exchanged_array,
                array_type(vec![num_classes, num_class_reps, b], BIT),
            )?;
        }
        // Permute axes to revert the original transpose: rows are blocks again
        // For it==1 the shape is [8, 2, b] and the rows are
        // [min(0, 1), max(0, 1)], [min(2, 3), max(2, 3)], ..., [min(14, 15), max(14, 15)]
        let aa_transposed = b_graph.permute_axes(global_chunks_a, vec![1, 0, 2])?;
        // Reshape data to flatten into shape [n, b] for further processing
        // E.g. the data [100, 99, 98, 97, ..., 86, 85] becomes
        // [99, 100, 97, 98, ..., 85, 86] after the stage it==1
        sorted_so_far = b_graph.reshape(aa_transposed, array_type(vec![n, b], BIT))?;
    }
    // Convert output from the binary form to the arithmetic form
    let output = if st != BIT {
        sorted_so_far.b2a(st)?
    } else {
        sorted_so_far
    };
    // Before computation every graph should be finalized, which means that it should have a designated output node
    b_graph.set_output_node(output)?;
    // Finalization checks that the output node of the graph is set. After finalization the graph can't be changed
    b_graph.finalize()?;
    Ok(b_graph)
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::custom_ops::run_instantiation_pass;
    use crate::data_types::{ScalarType, BIT, INT16, INT32, INT64};
    use crate::data_values::Value;
    use crate::evaluators::random_evaluate;
    use crate::random::PRNG;
    use std::cmp::Reverse;

    /// Builds the Batcher's sorting graph, runs the instantiation pass on its
    /// context and evaluates the instantiated graph on the given input
    ///
    /// # Arguments
    ///
    /// * `k` - binary logarithm of the array length (i.e., the array has 2<sup>k</sup> elements)
    /// * `st` - scalar type of array elements
    /// * `input` - value to be sorted by the graph
    ///
    /// # Returns
    ///
    /// Output value of the sorting graph
    fn evaluate_sorting_graph(k: u32, st: ScalarType, input: Value) -> Result<Value> {
        let context = create_context()?;
        let graph: Graph = create_batchers_sorting_graph(context.clone(), k, st)?;
        context.set_main_graph(graph.clone())?;
        context.finalize()?;
        let mapped_c = run_instantiation_pass(graph.get_context())?;
        random_evaluate(mapped_c.mappings.get_graph(graph), vec![input])
    }

    /// Helper function to test the sorting network graph for large inputs
    /// Testing is done by sorting a pseudo-random array (generated from a fixed
    /// seed) with the graph and comparing its result with the non-graph-sorted result
    ///
    /// # Arguments
    ///
    /// * `k` - binary logarithm of the array length (i.e., the array has 2<sup>k</sup> elements)
    /// * `st` - scalar type of array elements
    fn test_large_vec_unsigned_batchers_sorting(k: u32, st: ScalarType) -> Result<()> {
        let seed = b"\xB6\xD7\x1A\x2F\x88\xC1\x12\xBA\x3F\x2E\x17\xAB\xB7\x46\x15\x9A";
        let mut prng = PRNG::new(Some(*seed))?;
        let array_t: Type = array_type(vec![2_u64.pow(k)], st.clone());
        let data = prng.get_random_value(array_t.clone())?;
        // Reference result: the same data sorted as unsigned 64-bit integers
        let mut sorted_data = data.to_flattened_array_u64(array_t.clone())?;
        sorted_data.sort_unstable();
        let result = evaluate_sorting_graph(k, st, data)?.to_flattened_array_u64(array_t)?;
        assert_eq!(sorted_data, result);
        Ok(())
    }

    /// Helper function to test the sorting network graph on explicitly given data
    /// Testing is done by first sorting it with the given graph and then
    /// comparing its result with the non-graph-sorted result
    ///
    /// # Arguments
    ///
    /// * `k` - binary logarithm of the array length (i.e., the array has 2<sup>k</sup> elements)
    /// * `st` - scalar type of array elements
    /// * `data` - elements to be sorted
    fn test_unsigned_batchers_sorting_graph_helper(
        k: u32,
        st: ScalarType,
        data: Vec<u64>,
    ) -> Result<()> {
        let v_a = Value::from_flattened_array(&data, st.clone())?;
        let result = evaluate_sorting_graph(k, st.clone(), v_a)?
            .to_flattened_array_u64(array_type(vec![data.len() as u64], st))?;
        // Reference result: the same data sorted as unsigned 64-bit integers
        let mut sorted_data = data;
        sorted_data.sort_unstable();
        assert_eq!(sorted_data, result);
        Ok(())
    }

    /// This function tests the well-formed sorting graph for its correctness
    /// Parameters varied are k, st and the input data could be unsorted,
    /// sorted or sorted in a decreasing order.
    #[test]
    fn test_wellformed_unsigned_batchers_sorting_graph() -> Result<()> {
        // Unsorted, sorted and reversely sorted 16-bit data
        let mut data = vec![65535, 0, 2, 32768];
        test_unsigned_batchers_sorting_graph_helper(2, INT16, data.clone())?;
        data.sort_unstable();
        test_unsigned_batchers_sorting_graph_helper(2, INT16, data.clone())?;
        data.sort_by_key(|w| Reverse(*w));
        test_unsigned_batchers_sorting_graph_helper(2, INT16, data)?;
        // 64-bit data including the maximal value
        test_unsigned_batchers_sorting_graph_helper(
            2,
            INT64,
            vec![548890456, 402403639693304868, u64::MAX, 999790788],
        )?;
        // Single-element array (k == 0)
        test_unsigned_batchers_sorting_graph_helper(0, INT32, vec![643082556])?;
        // Array of bits
        test_unsigned_batchers_sorting_graph_helper(2, BIT, vec![1, 0, 0, 1])?;
        // Larger pseudo-random arrays
        test_large_vec_unsigned_batchers_sorting(7, BIT)?;
        test_large_vec_unsigned_batchers_sorting(4, INT64)?;
        Ok(())
    }
}