ciphercore_base/applications/
sort.rs1use crate::custom_ops::CustomOperation;
3use crate::data_types::{array_type, ScalarType, BIT};
4use crate::errors::Result;
5use crate::graphs::*;
6use crate::ops::integer_key_sort::SortByIntegerKey;
7
8pub fn create_sort_graph(context: Context, n: u64, st: ScalarType) -> Result<Graph> {
20 let graph = context.create_graph()?;
22 let input = graph.input(array_type(vec![n], st))?;
24 let key = "key".to_string();
26 let node = graph.create_named_tuple(vec![(key.clone(), input)])?;
27 let sorted_node = graph.custom_op(
29 CustomOperation::new(SortByIntegerKey { key: key.clone() }),
30 vec![node],
31 )?;
32 let output = sorted_node.named_tuple_get(key)?;
34
35 graph.set_output_node(output)?;
38 graph.finalize()?;
40
41 Ok(graph)
42}
43
44pub fn create_binary_sort_graph(context: Context, n: u64, b: u64) -> Result<Graph> {
56 let graph = context.create_graph()?;
58 let input = graph.input(array_type(vec![n, b], BIT))?;
60 let key = "key".to_string();
62 let node = graph.create_named_tuple(vec![(key.clone(), input)])?;
63 let sorted_node = node.sort(key.clone())?;
65 let output = sorted_node.named_tuple_get(key)?;
67
68 graph.set_output_node(output)?;
71 graph.finalize()?;
73
74 Ok(graph)
75}
76
#[cfg(test)]
mod tests {
    use super::*;
    use crate::custom_ops::run_instantiation_pass;
    use crate::data_types::{ScalarType, BIT, INT64, UINT16, UINT32, UINT64};
    use crate::data_values::Value;
    use crate::evaluators::random_evaluate;
    use crate::random::PRNG;
    use std::cmp::Reverse;

    /// Sorts `n` pseudo-random elements of type `st` with [`create_sort_graph`] and checks the
    /// result against `sort_unstable` applied to the same data.
    ///
    /// Signed types are compared as `i64`; all other types (including `BIT`) as `u64`.
    fn test_large_vec_sort(n: u64, st: ScalarType) -> Result<()> {
        let context = create_context()?;
        let graph: Graph = create_sort_graph(context.clone(), n, st)?;
        context.set_main_graph(graph.clone())?;
        context.finalize()?;

        // The sort graph contains a custom operation, so it must be instantiated first.
        let mapped_c = run_instantiation_pass(graph.get_context())?;
        let instantiated = mapped_c.mappings.get_graph(graph);

        // Fixed seed keeps the generated input (and thus the test) deterministic.
        let seed = b"\xB6\xD7\x1A\x2F\x88\xC1\x12\xBA\x3F\x2E\x17\xAB\xB7\x46\x15\x9A";
        let mut prng = PRNG::new(Some(*seed))?;
        let array_t = array_type(vec![n], st);
        let data = prng.get_random_value(array_t.clone())?;

        if st.is_signed() {
            let mut expected = data.to_flattened_array_i64(array_t.clone())?;
            let result =
                random_evaluate(instantiated, vec![data])?.to_flattened_array_i64(array_t)?;
            expected.sort_unstable();
            assert_eq!(expected, result);
        } else {
            let mut expected = data.to_flattened_array_u64(array_t.clone())?;
            let result =
                random_evaluate(instantiated, vec![data])?.to_flattened_array_u64(array_t)?;
            expected.sort_unstable();
            assert_eq!(expected, result);
        }
        Ok(())
    }

    /// Sorts the given unsigned `data` with a sort graph of length `n` and element type `st`,
    /// and checks the result against `sort_unstable`.
    fn test_sort_graph_helper(n: u64, st: ScalarType, data: Vec<u64>) -> Result<()> {
        let context = create_context()?;
        let graph: Graph = create_sort_graph(context.clone(), n, st)?;
        context.set_main_graph(graph.clone())?;
        context.finalize()?;

        let mapped_c = run_instantiation_pass(graph.get_context())?;

        let v_a = Value::from_flattened_array(&data, st)?;
        let result = random_evaluate(mapped_c.mappings.get_graph(graph), vec![v_a])?
            .to_flattened_array_u64(array_type(vec![data.len() as u64], st))?;
        let mut sorted_data = data;
        sorted_data.sort_unstable();
        assert_eq!(sorted_data, result);
        Ok(())
    }

    #[test]
    fn test_sort_graph() -> Result<()> {
        // UINT16 extremes, fed unsorted, already sorted, and reverse sorted.
        let mut data = vec![65535, 0, 2, 32768];
        test_sort_graph_helper(4, UINT16, data.clone())?;
        data.sort_unstable();
        test_sort_graph_helper(4, UINT16, data.clone())?;
        data.sort_by_key(|w| Reverse(*w));
        test_sort_graph_helper(4, UINT16, data)?;

        // Full-width values, including the maximum u64.
        test_sort_graph_helper(
            4,
            UINT64,
            vec![548890456, 402403639693304868, u64::MAX, 999790788],
        )?;

        // A single element is trivially sorted.
        test_sort_graph_helper(1, UINT32, vec![643082556])?;

        test_sort_graph_helper(4, BIT, vec![1, 0, 0, 1])?;

        test_large_vec_sort(1000, BIT)?;
        test_large_vec_sort(1000, UINT64)?;
        test_large_vec_sort(1000, INT64)?;

        Ok(())
    }
}