1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130
//! Minimum of an integer array
use crate::custom_ops::CustomOperation;
use crate::data_types::{array_type, scalar_size_in_bits, ScalarType, BIT};
use crate::errors::Result;
use crate::graphs::{Context, Graph, SliceElement};
use crate::ops::min_max::Min;
/// Creates a graph that finds the minimum of an array.
///
/// Minimum works only for arrays with unsigned scalar values.
///
/// The graph has a single input given in the binary form: an array of bits of shape
/// `[2^n]` if `st` is `BIT`, and of shape `[2^n, s]` otherwise, where `s` is the size of `st` in bits.
/// The output is converted back to the arithmetic form of `st` unless `st` is `BIT`.
///
/// # Arguments
///
/// * `context` - context where a minimum graph should be created
/// * `n` - binary logarithm of the number of array elements (i.e., the array contains 2<sup>n</sup> elements)
/// * `st` - scalar type of array elements
///
/// # Returns
///
/// Graph that finds the minimum of an array
///
/// # Errors
///
/// Returns a runtime error if `st` is a signed scalar type or if `n` is so large that
/// 2<sup>n</sup> does not fit in 64 bits. Errors of the underlying graph operations are propagated.
pub fn create_minimum_graph(context: Context, n: u64, st: ScalarType) -> Result<Graph> {
    // First, check that the input scalar type is unsigned
    if st.get_signed() {
        return Err(runtime_error!("Minimum doesn't accept signed scalar types"));
    }
    // The array length is computed as `1 << n`; shifting a 64-bit integer by 64 or more bits overflows
    // (panic in debug builds, silently wrong length in release builds), so reject such `n` explicitly.
    if n >= 64 {
        return Err(runtime_error!(
            "Number of array elements 2^n doesn't fit in 64 bits"
        ));
    }
    let num_elements: u64 = 1 << n;
    // Create a graph in a given context that will be used for finding the minimum
    let g = context.create_graph()?;
    // Create the type of the input array with 2^n elements.
    // To find the minimum of an array, we resort to the custom operation Min that accepts only binary input.
    // Thus, every non-bit element is represented by a row of its bits.
    let input_type = if st == BIT {
        array_type(vec![num_elements], BIT)
    } else {
        let st_size = scalar_size_in_bits(st.clone());
        array_type(vec![num_elements, st_size], BIT)
    };
    // Add an input node to the empty graph g created above.
    // This input node requires the input array type generated previously.
    let mut binary_array = g.input(input_type)?;
    // We find the minimum using the tournament method. This allows to reduce the graph size to O(n) from O(2<sup>n</sup>) nodes.
    // Namely, we split the input array into two halves, find elementwise minima of these halves and create a new array from these minima.
    // Then, we repeat this procedure for the new array.
    // For example, let [2,7,0,3,11,5,0,4] be an input array.
    // The 1st iteration yields [min(2,11), min(7,5), min(0,0), min(3,4)] = [2,5,0,3]
    // The 2nd iteration results in [min(2,0), min(5,3)] = [0,3]
    // The 3rd iteration returns [min(0,3)] = [0]
    for level in (0..n).rev() {
        // At this point the array has 2^(level + 1) elements, so each half has 2^level elements.
        // Extract the first half of the array using the [Graph::get_slice] operation.
        // Our slicing conventions follow [the NumPy rules](https://numpy.org/doc/stable/user/basics.indexing.html).
        let half1 =
            binary_array.get_slice(vec![SliceElement::SubArray(None, Some(1 << level), None)])?;
        // Extract the second half of the array using the [Graph::get_slice] operation.
        let half2 =
            binary_array.get_slice(vec![SliceElement::SubArray(Some(1 << level), None, None)])?;
        // Compare the first half with the second half elementwise to find minimums.
        // This is done via the custom operation Min.
        binary_array = g.custom_op(CustomOperation::new(Min {}), vec![half1, half2])?;
    }
    // Convert output from the binary form to the arithmetic form
    let output = if st != BIT {
        binary_array.b2a(st)?
    } else {
        binary_array
    };
    // Before computation every graph should be finalized, which means that it should have a designated output node.
    // This can be done by calling `g.set_output_node(output)?` or as below.
    output.set_as_output()?;
    // Finalization checks that the output node of the graph g is set. After finalization the graph can't be changed.
    g.finalize()?;
    Ok(g)
}
#[cfg(test)]
mod tests {
    use super::*;

    use crate::custom_ops::run_instantiation_pass;
    use crate::data_types::UINT32;
    use crate::data_values::Value;
    use crate::evaluators::random_evaluate;
    use crate::graphs::create_context;
    use std::ops::Not;

    /// Builds the minimum graph for 2^`n` elements of type `st`, instantiates its custom operations
    /// and evaluates it on `input_value`, propagating any error to the caller.
    fn evaluate_minimum<T>(input_value: &[T], n: u64, st: ScalarType) -> Result<Value>
    where
        T: TryInto<u64> + Not<Output = T> + TryInto<u8> + Copy,
    {
        let context = create_context()?;
        let graph = create_minimum_graph(context.clone(), n, st.clone())?;
        graph.set_as_main()?;
        context.finalize()?;
        // Custom operations must be instantiated before the graph can be evaluated.
        let instantiated_context = run_instantiation_pass(context)?.get_context();
        let instantiated_graph = instantiated_context.get_main_graph()?;
        let input = Value::from_flattened_array(input_value, st)?;
        random_evaluate(instantiated_graph, vec![input])
    }

    /// Evaluates the minimum graph on `input_value` and returns its output; panics on any error.
    fn test_minimum_helper<T>(input_value: &[T], n: u64, st: ScalarType) -> Value
    where
        T: TryInto<u64> + Not<Output = T> + TryInto<u8> + Copy,
    {
        evaluate_minimum(input_value, n, st).unwrap()
    }

    #[test]
    fn test_minimum() {
        // 16 = 2^4 elements of type UINT32; the smallest one is 1.
        let uint_input = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16];
        let uint_result = test_minimum_helper(&uint_input, 4, UINT32);
        let uint_expected = Value::from_flattened_array(&[1], UINT32).unwrap();
        assert!(uint_result == uint_expected);

        // 8 = 2^3 bits; the smallest one is 0.
        let bit_input = [0, 1, 1, 0, 1, 1, 0, 0];
        let bit_result = test_minimum_helper(&bit_input, 3, BIT);
        let bit_expected = Value::from_flattened_array(&[0], BIT).unwrap();
        assert!(bit_result == bit_expected);
    }
}