// ciphercore_base/ops/min_max.rs

//! Minimum and maximum operations on integers represented as bitstrings. Comparisons are unsigned by default; set `signed_comparison` to compare the bitstrings as signed integers.
2use crate::custom_ops::{CustomOperation, CustomOperationBody};
3use crate::data_types::{array_type, Type};
4use crate::errors::Result;
5use crate::graphs::{Context, Graph, Node};
6
7use super::comparisons::GreaterThan;
8use super::multiplexer::Mux;
9
10use serde::{Deserialize, Serialize};
11
12/// A structure that defines the custom operation Min that computes the minimum of length-n bitstring arrays elementwise.
13///
14/// The last dimension of both inputs must be the same; it defines the length of input bitstrings.
15/// If input shapes are different, the broadcasting rules are applied (see [the NumPy broadcasting rules](https://numpy.org/doc/stable/user/basics.broadcasting.html)).
16/// For example, if input arrays are of shapes `[2,3]`, and `[1,3]`, the resulting array has shape `[2,3]`.
17///
18/// To compare signed numbers, `signed_comparison` should be set `true`.
19///
20/// To use this and other custom operations in computation graphs, see [Graph::custom_op].
21///
22/// # Custom operation arguments
23///
24/// - Node containing a binary array or scalar
25/// - Node containing a binary array or scalar
26///
27/// # Custom operation returns
28///
29/// New Min node
30///
31/// # Example
32///
33/// ```
34/// # use ciphercore_base::graphs::create_context;
35/// # use ciphercore_base::data_types::{array_type, BIT};
36/// # use ciphercore_base::custom_ops::{CustomOperation};
37/// # use ciphercore_base::ops::min_max::Min;
38/// let c = create_context().unwrap();
39/// let g = c.create_graph().unwrap();
40/// let t = array_type(vec![2, 3], BIT);
41/// let n1 = g.input(t.clone()).unwrap();
42/// let n2 = g.input(t.clone()).unwrap();
43/// let n3 = g.custom_op(CustomOperation::new(Min {signed_comparison: false}), vec![n1, n2]).unwrap();
44/// ```
45#[derive(Debug, Serialize, Deserialize, Eq, PartialEq, Hash)]
46pub struct Min {
47    /// Boolean value indicating whether input bitstring represent signed integers
48    pub signed_comparison: bool,
49}
50
51/// If `cmp` is an array, add `1` to the shape by reshaping,
52/// otherwise, do nothing.
53/// This helper function is necessary for min/max, since
54/// we need to pad the shape of the result of the comparison
55/// in order to be able to call mux later.
56fn normalize_cmp(cmp: Node) -> Result<Node> {
57    let cmp_type = cmp.get_type()?;
58    let normalized_cmp = if cmp_type.is_array() {
59        let mut new_shape = cmp_type.get_shape();
60        let st = cmp_type.get_scalar_type();
61        new_shape.push(1);
62        cmp.reshape(array_type(new_shape, st))?
63    } else {
64        cmp
65    };
66    Ok(normalized_cmp)
67}
68
69#[typetag::serde]
70impl CustomOperationBody for Min {
71    fn instantiate(&self, context: Context, arguments_types: Vec<Type>) -> Result<Graph> {
72        if arguments_types.len() != 2 {
73            return Err(runtime_error!("Invalid number of arguments for Min"));
74        }
75        let g = context.create_graph()?;
76        let i1 = g.input(arguments_types[0].clone())?;
77        let i2 = g.input(arguments_types[1].clone())?;
78        let cmp = g.custom_op(
79            CustomOperation::new(GreaterThan {
80                signed_comparison: self.signed_comparison,
81            }),
82            vec![i1.clone(), i2.clone()],
83        )?;
84        let normalized_cmp = normalize_cmp(cmp)?;
85        let o = g.custom_op(CustomOperation::new(Mux {}), vec![normalized_cmp, i2, i1])?;
86        g.set_output_node(o)?;
87        g.finalize()?;
88        Ok(g)
89    }
90
91    fn get_name(&self) -> String {
92        format!("Min(signed_comparison={})", self.signed_comparison)
93    }
94}
95
96/// A structure that defines the custom operation Max that computes the maximum of length-n bitstring arrays elementwise.
97///
98/// The last dimension of both inputs must be the same; it defines the length of input bitstrings.
99/// If input shapes are different, the broadcasting rules are applied (see [the NumPy broadcasting rules](https://numpy.org/doc/stable/user/basics.broadcasting.html)).
100/// For example, if input arrays are of shapes `[2,3]`, and `[1,3]`, the resulting array has shape `[2,3]`.
101///
102/// To compare signed numbers, `signed_comparison` should be set `true`.
103///
104/// To use this and other custom operations in computation graphs, see [Graph::custom_op].
105///
106/// # Custom operation arguments
107///
108/// - Node containing a binary array or scalar
109/// - Node containing a binary array or scalar
110///
111/// # Custom operation returns
112///
113/// New Max node
114///
115/// # Example
116///
117/// ```
118/// # use ciphercore_base::graphs::create_context;
119/// # use ciphercore_base::data_types::{array_type, BIT};
120/// # use ciphercore_base::custom_ops::{CustomOperation};
121/// # use ciphercore_base::ops::min_max::Max;
122/// let c = create_context().unwrap();
123/// let g = c.create_graph().unwrap();
124/// let t = array_type(vec![2, 3], BIT);
125/// let n1 = g.input(t.clone()).unwrap();
126/// let n2 = g.input(t.clone()).unwrap();
127/// let n3 = g.custom_op(CustomOperation::new(Max {signed_comparison: true}), vec![n1, n2]).unwrap();
128/// ```
129#[derive(Debug, Serialize, Deserialize, Eq, PartialEq, Hash)]
130pub struct Max {
131    /// Boolean value indicating whether input bitstring represent signed integers
132    pub signed_comparison: bool,
133}
134
135#[typetag::serde]
136impl CustomOperationBody for Max {
137    fn instantiate(&self, context: Context, arguments_types: Vec<Type>) -> Result<Graph> {
138        if arguments_types.len() != 2 {
139            return Err(runtime_error!("Invalid number of arguments for Max"));
140        }
141        let g = context.create_graph()?;
142        let i1 = g.input(arguments_types[0].clone())?;
143        let i2 = g.input(arguments_types[1].clone())?;
144        let cmp = g.custom_op(
145            CustomOperation::new(GreaterThan {
146                signed_comparison: self.signed_comparison,
147            }),
148            vec![i1.clone(), i2.clone()],
149        )?;
150        let normalized_cmp = normalize_cmp(cmp)?;
151        let o = g.custom_op(CustomOperation::new(Mux {}), vec![normalized_cmp, i1, i2])?;
152        g.set_output_node(o)?;
153        g.finalize()?;
154        Ok(g)
155    }
156
157    fn get_name(&self) -> String {
158        format!("Max(signed_comparison={})", self.signed_comparison)
159    }
160}
161
#[cfg(test)]
mod tests {

    use crate::custom_ops::run_instantiation_pass;
    use crate::data_types::{array_type, scalar_type, BIT, INT64, UINT64};
    use crate::data_values::Value;
    use crate::evaluators::random_evaluate;
    use crate::graphs::create_context;
    use crate::graphs::util::simple_context;

    use super::*;

    use std::cmp::{max, min};

    /// Scalar input pairs: neighbours, zeros, equal values, and values near `u64::MAX`
    /// (negative once reinterpreted as `i64`, which exercises signed comparison).
    const SCALAR_TEST_DATA: [(u64, u64); 7] = [
        (31, 32),
        (76543, 76544),
        (0, 1),
        (0, 0),
        (761523, 761523),
        (u64::MAX, u64::MAX - 1),
        (u64::MAX - 761522, u64::MAX - 761523),
    ];

    /// Unsigned `Min` and signed `Max` on 64-bit scalars.
    #[test]
    fn test_well_formed() {
        || -> Result<()> {
            let context = || -> Result<Context> {
                let c = simple_context(|g| {
                    let i1 = g.input(scalar_type(UINT64))?.a2b()?;
                    let i2 = g.input(scalar_type(UINT64))?.a2b()?;
                    g.create_tuple(vec![
                        g.custom_op(
                            CustomOperation::new(Min {
                                signed_comparison: false,
                            }),
                            vec![i1.clone(), i2.clone()],
                        )?,
                        g.custom_op(
                            CustomOperation::new(Max {
                                signed_comparison: true,
                            }),
                            vec![i1.clone(), i2.clone()],
                        )?,
                    ])
                })?;
                let mapped_c = run_instantiation_pass(c)?;
                Ok(mapped_c.get_context())
            }()?;
            for &(u, v) in SCALAR_TEST_DATA.iter() {
                let minmax = random_evaluate(
                    context.get_main_graph()?,
                    vec![
                        Value::from_scalar(u, UINT64)?,
                        Value::from_scalar(v, UINT64)?,
                    ],
                )?
                .to_vector()?;
                let computed_min = minmax[0].to_u64(UINT64)?;
                let computed_max = minmax[1].to_i64(INT64)?;
                assert_eq!(min(u, v), computed_min);
                // `as i64` deliberately reinterprets the bit pattern as two's complement.
                assert_eq!(max(u as i64, v as i64), computed_max);
            }
            Ok(())
        }()
        .unwrap();
    }

    /// The opposite flag combination: signed `Min` and unsigned `Max` on 64-bit scalars.
    #[test]
    fn test_well_formed_signed_min_unsigned_max() {
        || -> Result<()> {
            let context = || -> Result<Context> {
                let c = simple_context(|g| {
                    let i1 = g.input(scalar_type(UINT64))?.a2b()?;
                    let i2 = g.input(scalar_type(UINT64))?.a2b()?;
                    g.create_tuple(vec![
                        g.custom_op(
                            CustomOperation::new(Min {
                                signed_comparison: true,
                            }),
                            vec![i1.clone(), i2.clone()],
                        )?,
                        g.custom_op(
                            CustomOperation::new(Max {
                                signed_comparison: false,
                            }),
                            vec![i1.clone(), i2.clone()],
                        )?,
                    ])
                })?;
                let mapped_c = run_instantiation_pass(c)?;
                Ok(mapped_c.get_context())
            }()?;
            for &(u, v) in SCALAR_TEST_DATA.iter() {
                let minmax = random_evaluate(
                    context.get_main_graph()?,
                    vec![
                        Value::from_scalar(u, UINT64)?,
                        Value::from_scalar(v, UINT64)?,
                    ],
                )?
                .to_vector()?;
                let computed_min = minmax[0].to_i64(INT64)?;
                let computed_max = minmax[1].to_u64(UINT64)?;
                // `as i64` deliberately reinterprets the bit pattern as two's complement.
                assert_eq!(min(u as i64, v as i64), computed_min);
                assert_eq!(max(u, v), computed_max);
            }
            Ok(())
        }()
        .unwrap();
    }

    /// Both operations must reject argument lists that do not have exactly two nodes.
    #[test]
    fn test_malformed() {
        || -> Result<()> {
            let c = create_context()?;
            let g = c.create_graph()?;
            let i1 = g.input(scalar_type(UINT64))?.a2b()?;
            let i2 = g.input(scalar_type(UINT64))?.a2b()?;
            let i3 = g.input(scalar_type(UINT64))?.a2b()?;
            assert!(g
                .custom_op(
                    CustomOperation::new(Min {
                        signed_comparison: false
                    }),
                    vec![i1.clone()]
                )
                .is_err());
            assert!(g
                .custom_op(
                    CustomOperation::new(Max {
                        signed_comparison: false
                    }),
                    vec![i1.clone()]
                )
                .is_err());
            assert!(g
                .custom_op(
                    CustomOperation::new(Min {
                        signed_comparison: false
                    }),
                    vec![i1.clone(), i2.clone(), i3.clone()]
                )
                .is_err());
            assert!(g
                .custom_op(
                    CustomOperation::new(Max {
                        signed_comparison: false
                    }),
                    vec![i1.clone(), i2.clone(), i3.clone()]
                )
                .is_err());
            Ok(())
        }()
        .unwrap();
    }

    /// Broadcasting `[1, 3]` against `[3, 1]` (plus the 64-bit dimension) gives `[3, 3]` outputs.
    #[test]
    fn test_vector() {
        || -> Result<()> {
            let context = || -> Result<Context> {
                let c = simple_context(|g| {
                    let i1 = g.input(array_type(vec![1, 3, 64], BIT))?;
                    let i2 = g.input(array_type(vec![3, 1, 64], BIT))?;
                    g.create_tuple(vec![
                        g.custom_op(
                            CustomOperation::new(Min {
                                signed_comparison: false,
                            }),
                            vec![i1.clone(), i2.clone()],
                        )?,
                        g.custom_op(
                            CustomOperation::new(Max {
                                signed_comparison: false,
                            }),
                            vec![i1.clone(), i2.clone()],
                        )?,
                    ])
                })?;
                let mapped_c = run_instantiation_pass(c)?;
                Ok(mapped_c.get_context())
            }()?;
            let a = vec![0, 30, 100];
            let b = vec![10, 50, 150];
            let v = random_evaluate(
                context.get_main_graph()?,
                vec![
                    Value::from_flattened_array(&a, UINT64)?,
                    Value::from_flattened_array(&b, UINT64)?,
                ],
            )?
            .to_vector()?;
            let min_a_b = v[0].to_flattened_array_u64(array_type(vec![3, 3], UINT64))?;
            let max_a_b = v[1].to_flattened_array_u64(array_type(vec![3, 3], UINT64))?;
            assert_eq!(min_a_b, vec![0, 10, 10, 0, 30, 50, 0, 30, 100]);
            assert_eq!(max_a_b, vec![10, 30, 100, 50, 50, 100, 150, 150, 150]);
            Ok(())
        }()
        .unwrap();
    }
}