ciphercore_base/ops/
multiplexer.rs

//! Multiplexer (Mux) operation that takes three inputs a, b, c and returns b if a is 1 or c if a is 0.
2use crate::custom_ops::CustomOperationBody;
3use crate::data_types::{scalar_type, Type, BIT};
4use crate::errors::Result;
5use crate::graphs::{Context, Graph};
6
7use serde::{Deserialize, Serialize};
8
9/// A structure that defines the custom operation Mux that takes three inputs a, b, c and returns b if a is 1 or c if a is 0.
10///
11/// The input `a` should be arrays of bitstrings. The inputs `b` and `c` must have the same type. This operation is applied elementwise.
12///
13/// If input shapes are different, the broadcasting rules are applied (see [the NumPy broadcasting rules](https://numpy.org/doc/stable/user/basics.broadcasting.html)).
14/// For example, if a,b,c are of shapes `[2,3]`, `[1,3]` and `[2,1]`, the resulting array has shape `[2,3]`.
15///
16/// To use this and other custom operations in computation graphs, see [Graph::custom_op].
17///
18/// # Custom operation arguments
19///
20/// - Node containing a binary array or scalar
21/// - Node containing a binary array or scalar that will be chosen if the first input is 1
22/// - Node containing a binary array or scalar that will be chosen if the first input is 0
23///
24/// # Custom operation returns
25///
26/// New Mux node
27///
28/// # Example
29///
30/// ```
31/// # use ciphercore_base::graphs::create_context;
32/// # use ciphercore_base::data_types::{array_type, BIT, INT32};
33/// # use ciphercore_base::custom_ops::{CustomOperation};
34/// # use ciphercore_base::ops::multiplexer::Mux;
35/// let c = create_context().unwrap();
36/// let g = c.create_graph().unwrap();
37/// let t1 = array_type(vec![2, 3], BIT);
38/// let t2 = array_type(vec![3], BIT);
39/// let n1 = g.input(t1.clone()).unwrap();
40/// let n2 = g.input(t1.clone()).unwrap();
41/// let n3 = g.input(t2.clone()).unwrap();
42/// let n4 = g.custom_op(CustomOperation::new(Mux {}), vec![n1, n2, n3]).unwrap();
43/// ```
44#[derive(Debug, Serialize, Deserialize, Eq, PartialEq, Hash)]
45pub struct Mux {}
46
47#[typetag::serde]
48impl CustomOperationBody for Mux {
49    fn instantiate(&self, context: Context, arguments_types: Vec<Type>) -> Result<Graph> {
50        if arguments_types.len() != 3 {
51            return Err(runtime_error!("Invalid number of arguments for Mux"));
52        }
53        let t = arguments_types[0].clone();
54        if !t.is_scalar() && !t.is_array() {
55            return Err(runtime_error!("Flag for Mux must be a scalar or an array"));
56        }
57        if t.get_scalar_type() != BIT {
58            return Err(runtime_error!("Flag for Mux must consist of bits"));
59        }
60        if arguments_types[1].get_scalar_type() != arguments_types[2].get_scalar_type() {
61            return Err(runtime_error!(
62                "Choices for Mux must have the same scalar type"
63            ));
64        }
65
66        let g = context.create_graph()?;
67        let i_flag = g.input(arguments_types[0].clone())?;
68        let i_choice1 = g.input(arguments_types[1].clone())?;
69        let i_choice0 = g.input(arguments_types[2].clone())?;
70        if arguments_types[1].get_scalar_type() == BIT {
71            i_choice0
72                .add(i_flag.multiply(i_choice0.add(i_choice1)?)?)?
73                .set_as_output()?;
74        } else {
75            let i_choice0 = i_choice0.mixed_multiply(i_flag.clone())?;
76            let i_choice1 = i_choice1.mixed_multiply(i_flag.add(g.ones(scalar_type(BIT))?)?)?;
77            i_choice0.add(i_choice1)?.set_as_output()?;
78        }
79        g.finalize()?;
80        Ok(g)
81    }
82
83    fn get_name(&self) -> String {
84        "Mux".to_owned()
85    }
86}
87
88#[cfg(test)]
89mod tests {
90    use super::*;
91
92    use crate::custom_ops::run_instantiation_pass;
93    use crate::custom_ops::CustomOperation;
94    use crate::data_types::INT32;
95    use crate::data_types::UINT32;
96    use crate::data_values::Value;
97    use crate::evaluators::random_evaluate;
98    use crate::graphs::create_context;
99
100    #[test]
101    fn test_mux_bits() {
102        || -> Result<()> {
103            let c = create_context()?;
104            let g = c.create_graph()?;
105            let i_flag = g.input(Type::Scalar(BIT))?;
106            let i_choice1 = g.input(Type::Scalar(BIT))?;
107            let i_choice0 = g.input(Type::Scalar(BIT))?;
108            let o = g.custom_op(
109                CustomOperation::new(Mux {}),
110                vec![i_flag, i_choice1, i_choice0],
111            )?;
112            g.set_output_node(o)?;
113            g.finalize()?;
114            c.set_main_graph(g.clone())?;
115            c.finalize()?;
116            let mapped_c = run_instantiation_pass(c)?;
117            for flag in vec![0, 1] {
118                for x1 in vec![0, 1] {
119                    for x0 in vec![0, 1] {
120                        let expected_result = if flag != 0 { x1 } else { x0 };
121                        let result = random_evaluate(
122                            mapped_c.mappings.get_graph(g.clone()),
123                            vec![
124                                Value::from_scalar(flag, BIT)?,
125                                Value::from_scalar(x1, BIT)?,
126                                Value::from_scalar(x0, BIT)?,
127                            ],
128                        )?
129                        .to_u8(BIT)?;
130                        assert_eq!(result, expected_result);
131                    }
132                }
133            }
134            Ok(())
135        }()
136        .unwrap();
137    }
138
139    #[test]
140    fn test_mux_broadcast() {
141        || -> Result<()> {
142            let c = create_context()?;
143            let g = c.create_graph()?;
144            let i_flag = g.input(Type::Array(vec![3, 1], BIT))?;
145            let i_choice1 = g.input(Type::Array(vec![1, 5], BIT))?;
146            let i_choice0 = g.input(Type::Array(vec![6, 1, 1], BIT))?;
147            let o = g.custom_op(
148                CustomOperation::new(Mux {}),
149                vec![i_flag, i_choice1, i_choice0],
150            )?;
151            g.set_output_node(o)?;
152            g.finalize()?;
153            c.set_main_graph(g.clone())?;
154            c.finalize()?;
155            let mapped_c = run_instantiation_pass(c)?;
156            let a_flag = vec![0, 1, 1];
157            let a_1 = vec![0, 1, 0, 0, 1];
158            let a_0 = vec![1, 0, 0, 1, 0, 1];
159            let v_flag = Value::from_flattened_array(&a_flag, BIT)?;
160            let v_1 = Value::from_flattened_array(&a_1, BIT)?;
161            let v_0 = Value::from_flattened_array(&a_0, BIT)?;
162            let result = random_evaluate(mapped_c.mappings.get_graph(g), vec![v_flag, v_1, v_0])?
163                .to_flattened_array_u64(Type::Array(vec![6, 3, 5], BIT))?;
164            for i in 0..6 {
165                for j in 0..3 {
166                    for k in 0..5 {
167                        let r = result[i * 15 + j * 5 + k];
168                        let u = a_flag[j];
169                        let v = a_1[k];
170                        let w = a_0[i];
171                        let er = if u != 0 { v } else { w };
172                        assert_eq!(r, er);
173                    }
174                }
175            }
176            Ok(())
177        }()
178        .unwrap();
179    }
180
181    #[test]
182    fn test_malformed() {
183        || -> Result<()> {
184            let c = create_context()?;
185            let g = c.create_graph()?;
186            let i_flag = g.input(Type::Array(vec![3, 1], BIT))?;
187            let i_choice1 = g.input(Type::Array(vec![1, 5, 1], UINT32))?;
188            let i_choice0 = g.input(Type::Array(vec![6, 1, 1], INT32))?;
189            assert!(g
190                .custom_op(
191                    CustomOperation::new(Mux {}),
192                    vec![i_flag, i_choice1, i_choice0]
193                )
194                .is_err());
195            Ok(())
196        }()
197        .unwrap();
198
199        || -> Result<()> {
200            let c = create_context()?;
201            let g = c.create_graph()?;
202            let i_flag = g.input(Type::Array(vec![3, 1], INT32))?;
203            let i_choice1 = g.input(Type::Array(vec![1, 5], BIT))?;
204            let i_choice0 = g.input(Type::Array(vec![6, 1, 1], BIT))?;
205            assert!(g
206                .custom_op(
207                    CustomOperation::new(Mux {}),
208                    vec![i_flag, i_choice1, i_choice0]
209                )
210                .is_err());
211            Ok(())
212        }()
213        .unwrap();
214
215        || -> Result<()> {
216            let c = create_context()?;
217            let g = c.create_graph()?;
218            let i_flag = g.input(Type::Array(vec![3, 7], BIT))?;
219            let i_choice1 = g.input(Type::Array(vec![1, 5], BIT))?;
220            let i_choice0 = g.input(Type::Array(vec![6, 1, 1], BIT))?;
221            assert!(g
222                .custom_op(
223                    CustomOperation::new(Mux {}),
224                    vec![i_flag, i_choice1, i_choice0]
225                )
226                .is_err());
227            Ok(())
228        }()
229        .unwrap();
230
231        || -> Result<()> {
232            let c = create_context()?;
233            let g = c.create_graph()?;
234            let i_flag = g.input(Type::Array(vec![3, 7], BIT))?;
235            let i_choice1 = g.input(Type::Array(vec![1, 5], BIT))?;
236            let _i_choice0 = g.input(Type::Array(vec![6, 1, 1], BIT))?;
237            assert!(g
238                .custom_op(CustomOperation::new(Mux {}), vec![i_flag, i_choice1])
239                .is_err());
240            Ok(())
241        }()
242        .unwrap();
243
244        || -> Result<()> {
245            let c = create_context()?;
246            let g = c.create_graph()?;
247            let i_flag = g.input(Type::Tuple(vec![]))?;
248            let i_choice1 = g.input(Type::Array(vec![1, 5], BIT))?;
249            let i_choice0 = g.input(Type::Array(vec![6, 1, 1], BIT))?;
250            assert!(g
251                .custom_op(
252                    CustomOperation::new(Mux {}),
253                    vec![i_flag, i_choice1, i_choice0]
254                )
255                .is_err());
256            Ok(())
257        }()
258        .unwrap();
259
260        || -> Result<()> {
261            let c = create_context()?;
262            let g = c.create_graph()?;
263            let i_flag = g.input(Type::Array(vec![3, 1], BIT))?;
264            let i_choice1 = g.input(Type::Array(vec![1, 5], BIT))?;
265            let i_choice0 = g.input(Type::Array(vec![6, 1, 1], INT32))?;
266            assert!(g
267                .custom_op(
268                    CustomOperation::new(Mux {}),
269                    vec![i_flag, i_choice1, i_choice0]
270                )
271                .is_err());
272            Ok(())
273        }()
274        .unwrap();
275    }
276}