ciphercore_base/applications/
sort.rs

1//! Sorting of an array
2use crate::custom_ops::CustomOperation;
3use crate::data_types::{array_type, ScalarType, BIT};
4use crate::errors::Result;
5use crate::graphs::*;
6use crate::ops::integer_key_sort::SortByIntegerKey;
7
8/// Creates a graph that sorts an array using [Radix Sort MPC protocol](https://eprint.iacr.org/2019/695.pdf).
9///
10/// # Arguments
11///
12/// * `context` - context where a sorting graph should be created
13/// * `n` - number of elements of an array
14/// * `st` - scalar type of array elements
15///
16/// # Returns
17///
18/// Graph that sorts an array
19pub fn create_sort_graph(context: Context, n: u64, st: ScalarType) -> Result<Graph> {
20    // Create a graph in a given context that will be used for sorting
21    let graph = context.create_graph()?;
22    // Define the input node with an array of n integers
23    let input = graph.input(array_type(vec![n], st))?;
24    // Create named a tuple as required by the interface.
25    let key = "key".to_string();
26    let node = graph.create_named_tuple(vec![(key.clone(), input)])?;
27    // Sort an array
28    let sorted_node = graph.custom_op(
29        CustomOperation::new(SortByIntegerKey { key: key.clone() }),
30        vec![node],
31    )?;
32    // Extract result from tuple by key.
33    let output = sorted_node.named_tuple_get(key)?;
34
35    // Before computation every graph should be finalized, which means that it should have a designated output node
36    // This can be done by calling `g.set_output_node(output)?` or as below
37    graph.set_output_node(output)?;
38    // Finalization checks that the output node of the graph g is set. After finalization the graph can't be changed
39    graph.finalize()?;
40
41    Ok(graph)
42}
43
44/// Creates a graph that sorts an array of bitstrings using [Radix Sort MPC protocol](https://eprint.iacr.org/2019/695.pdf).
45///
46/// # Arguments
47///
48/// * `context` - context where a sorting graph should be created
49/// * `n` - number of elements of an array
50/// * `b` - length of bitstrings
51///
52/// # Returns
53///
54/// Graph that sorts an array
55pub fn create_binary_sort_graph(context: Context, n: u64, b: u64) -> Result<Graph> {
56    // Create a graph in a given context that will be used for sorting
57    let graph = context.create_graph()?;
58    // Define the input node with an array of n integers
59    let input = graph.input(array_type(vec![n, b], BIT))?;
60    // Create a named tuple as required by the interface.
61    let key = "key".to_string();
62    let node = graph.create_named_tuple(vec![(key.clone(), input)])?;
63    // Sort an array
64    let sorted_node = node.sort(key.clone())?;
65    // Extract the result from the tuple by key.
66    let output = sorted_node.named_tuple_get(key)?;
67
68    // Before computation every graph should be finalized, which means that it should have a designated output node
69    // This can be done by calling `g.set_output_node(output)?` or as below
70    graph.set_output_node(output)?;
71    // Finalization checks that the output node of the graph g is set. After finalization the graph can't be changed
72    graph.finalize()?;
73
74    Ok(graph)
75}
76
#[cfg(test)]
mod tests {
    use super::*;
    use crate::custom_ops::run_instantiation_pass;
    use crate::data_types::{ScalarType, BIT, INT64, UINT16, UINT32, UINT64};
    use crate::data_values::Value;
    use crate::evaluators::random_evaluate;
    use crate::random::PRNG;
    use std::cmp::Reverse;

    /// Sorts a pseudo-random array with the graph from [`create_sort_graph`] and
    /// compares the result against the standard library sort of the same data.
    ///
    /// The PRNG seed is fixed, so the generated input is the same on every run.
    ///
    /// # Arguments
    ///
    /// * `n` - number of elements of an array
    /// * `st` - scalar type of array elements
    fn test_large_vec_sort(n: u64, st: ScalarType) -> Result<()> {
        let context = create_context()?;
        let graph: Graph = create_sort_graph(context.clone(), n, st)?;
        context.set_main_graph(graph.clone())?;
        context.finalize()?;

        let mapped_c = run_instantiation_pass(graph.get_context())?;

        let seed = b"\xB6\xD7\x1A\x2F\x88\xC1\x12\xBA\x3F\x2E\x17\xAB\xB7\x46\x15\x9A";
        // `seed` is a `&[u8; 16]`; arrays are `Copy`, so dereference instead of cloning.
        let mut prng = PRNG::new(Some(*seed))?;
        let array_t = array_type(vec![n], st);
        let data = prng.get_random_value(array_t.clone())?;
        // Signed types must be decoded as i64 so that negative values sort correctly.
        if st.is_signed() {
            let data_v_i64 = data.to_flattened_array_i64(array_t.clone())?;
            let result = random_evaluate(mapped_c.mappings.get_graph(graph), vec![data])?
                .to_flattened_array_i64(array_t)?;
            let mut sorted_data = data_v_i64;
            sorted_data.sort_unstable();
            assert_eq!(sorted_data, result);
        } else {
            let data_v_u64 = data.to_flattened_array_u64(array_t.clone())?;
            let result = random_evaluate(mapped_c.mappings.get_graph(graph), vec![data])?
                .to_flattened_array_u64(array_t)?;
            let mut sorted_data = data_v_u64;
            sorted_data.sort_unstable();
            assert_eq!(sorted_data, result);
        }
        Ok(())
    }

    /// Sorts the given `data` with the graph from [`create_sort_graph`] and compares
    /// the result against the standard library sort of the same data.
    ///
    /// # Arguments
    ///
    /// * `st` - scalar type of array elements
    /// * `data` - array elements; the array length used for the graph is `data.len()`
    fn test_sort_graph_helper(st: ScalarType, data: Vec<u64>) -> Result<()> {
        // Derive the length from the data so the graph shape and the decoded result
        // shape can never disagree.
        let n = data.len() as u64;
        let context = create_context()?;
        let graph: Graph = create_sort_graph(context.clone(), n, st)?;
        context.set_main_graph(graph.clone())?;
        context.finalize()?;

        let mapped_c = run_instantiation_pass(graph.get_context())?;

        let v_a = Value::from_flattened_array(&data, st)?;
        let result = random_evaluate(mapped_c.mappings.get_graph(graph), vec![v_a])?
            .to_flattened_array_u64(array_type(vec![n], st))?;
        let mut sorted_data = data;
        sorted_data.sort_unstable();
        assert_eq!(sorted_data, result);
        Ok(())
    }

    /// Tests the sorting graph for correctness.
    /// Varies the array length `n` and the scalar type `st`. The small inputs are
    /// unsorted, already sorted, or sorted in decreasing order, followed by large
    /// pseudo-random arrays.
    #[test]
    fn test_sort_graph() -> Result<()> {
        let mut data = vec![65535, 0, 2, 32768];
        test_sort_graph_helper(UINT16, data.clone())?;
        data.sort_unstable();
        test_sort_graph_helper(UINT16, data.clone())?;
        data.sort_by_key(|w| Reverse(*w));
        test_sort_graph_helper(UINT16, data)?;

        test_sort_graph_helper(
            UINT64,
            vec![548890456, 402403639693304868, u64::MAX, 999790788],
        )?;

        test_sort_graph_helper(UINT32, vec![643082556])?;

        test_sort_graph_helper(BIT, vec![1, 0, 0, 1])?;

        test_large_vec_sort(1000, BIT)?;
        test_large_vec_sort(1000, UINT64)?;
        test_large_vec_sort(1000, INT64)?;

        Ok(())
    }
}