1use crate::custom_ops::CustomOperationBody;
3use crate::data_types::{scalar_type, Type, BIT};
4use crate::errors::Result;
5use crate::graphs::{Context, Graph};
6
7use serde::{Deserialize, Serialize};
8
/// Custom operation `Mux`: elementwise selection between two values driven by a bit flag.
///
/// The operation takes exactly three arguments, in this order:
/// - `flag`: a scalar or an array whose scalar type is `BIT`;
/// - `choice1`: the value selected where `flag` is 1;
/// - `choice0`: the value selected where `flag` is 0.
///
/// Both choices must have the same scalar type. The shapes of the three arguments are
/// broadcast together. For example, a `[3, 1]` flag with `[1, 5]` and `[6, 1, 1]` choices
/// yields a `[6, 3, 5]` result, while non-broadcastable shapes are rejected.
///
/// For `BIT` choices the selection is computed with bit additions/multiplications only.
/// For other scalar types each choice is combined with the flag through `mixed_multiply`.
///
/// Use it via [`CustomOperation`](crate::custom_ops::CustomOperation), e.g.
/// `g.custom_op(CustomOperation::new(Mux {}), vec![flag, choice1, choice0])`.
///
/// NOTE: this is kept as a braced empty struct (not a unit struct) on purpose; changing its
/// shape would change its serialized form.
#[derive(Debug, Serialize, Deserialize, Eq, PartialEq, Hash)]
pub struct Mux {}
46
47#[typetag::serde]
48impl CustomOperationBody for Mux {
49 fn instantiate(&self, context: Context, arguments_types: Vec<Type>) -> Result<Graph> {
50 if arguments_types.len() != 3 {
51 return Err(runtime_error!("Invalid number of arguments for Mux"));
52 }
53 let t = arguments_types[0].clone();
54 if !t.is_scalar() && !t.is_array() {
55 return Err(runtime_error!("Flag for Mux must be a scalar or an array"));
56 }
57 if t.get_scalar_type() != BIT {
58 return Err(runtime_error!("Flag for Mux must consist of bits"));
59 }
60 if arguments_types[1].get_scalar_type() != arguments_types[2].get_scalar_type() {
61 return Err(runtime_error!(
62 "Choices for Mux must have the same scalar type"
63 ));
64 }
65
66 let g = context.create_graph()?;
67 let i_flag = g.input(arguments_types[0].clone())?;
68 let i_choice1 = g.input(arguments_types[1].clone())?;
69 let i_choice0 = g.input(arguments_types[2].clone())?;
70 if arguments_types[1].get_scalar_type() == BIT {
71 i_choice0
72 .add(i_flag.multiply(i_choice0.add(i_choice1)?)?)?
73 .set_as_output()?;
74 } else {
75 let i_choice0 = i_choice0.mixed_multiply(i_flag.clone())?;
76 let i_choice1 = i_choice1.mixed_multiply(i_flag.add(g.ones(scalar_type(BIT))?)?)?;
77 i_choice0.add(i_choice1)?.set_as_output()?;
78 }
79 g.finalize()?;
80 Ok(g)
81 }
82
83 fn get_name(&self) -> String {
84 "Mux".to_owned()
85 }
86}
87
#[cfg(test)]
mod tests {
    use super::*;

    use crate::custom_ops::run_instantiation_pass;
    use crate::custom_ops::CustomOperation;
    use crate::data_types::INT32;
    use crate::data_types::UINT32;
    use crate::data_values::Value;
    use crate::evaluators::random_evaluate;
    use crate::graphs::create_context;

    /// Exhaustively checks `Mux` on scalar bits: the result must be `choice1` when the
    /// flag is set and `choice0` otherwise.
    fn check_mux_bits() -> Result<()> {
        let context = create_context()?;
        let graph = context.create_graph()?;
        let flag_node = graph.input(Type::Scalar(BIT))?;
        let choice1_node = graph.input(Type::Scalar(BIT))?;
        let choice0_node = graph.input(Type::Scalar(BIT))?;
        let mux_node = graph.custom_op(
            CustomOperation::new(Mux {}),
            vec![flag_node, choice1_node, choice0_node],
        )?;
        graph.set_output_node(mux_node)?;
        graph.finalize()?;
        context.set_main_graph(graph.clone())?;
        context.finalize()?;
        let instantiated = run_instantiation_pass(context)?;
        for flag in [0, 1] {
            for x1 in [0, 1] {
                for x0 in [0, 1] {
                    let inputs = vec![
                        Value::from_scalar(flag, BIT)?,
                        Value::from_scalar(x1, BIT)?,
                        Value::from_scalar(x0, BIT)?,
                    ];
                    let actual =
                        random_evaluate(instantiated.mappings.get_graph(graph.clone()), inputs)?
                            .to_u8(BIT)?;
                    let expected = if flag == 0 { x0 } else { x1 };
                    assert_eq!(actual, expected);
                }
            }
        }
        Ok(())
    }

    /// Checks that a `[3, 1]` flag with `[1, 5]` and `[6, 1, 1]` choices broadcasts into a
    /// `[6, 3, 5]` result, selected elementwise.
    fn check_mux_broadcast() -> Result<()> {
        let context = create_context()?;
        let graph = context.create_graph()?;
        let flag_node = graph.input(Type::Array(vec![3, 1], BIT))?;
        let choice1_node = graph.input(Type::Array(vec![1, 5], BIT))?;
        let choice0_node = graph.input(Type::Array(vec![6, 1, 1], BIT))?;
        let mux_node = graph.custom_op(
            CustomOperation::new(Mux {}),
            vec![flag_node, choice1_node, choice0_node],
        )?;
        graph.set_output_node(mux_node)?;
        graph.finalize()?;
        context.set_main_graph(graph.clone())?;
        context.finalize()?;
        let instantiated = run_instantiation_pass(context)?;

        let flags = vec![0, 1, 1];
        let ones_side = vec![0, 1, 0, 0, 1];
        let zeros_side = vec![1, 0, 0, 1, 0, 1];
        let inputs = vec![
            Value::from_flattened_array(&flags, BIT)?,
            Value::from_flattened_array(&ones_side, BIT)?,
            Value::from_flattened_array(&zeros_side, BIT)?,
        ];
        let output = random_evaluate(instantiated.mappings.get_graph(graph), inputs)?
            .to_flattened_array_u64(Type::Array(vec![6, 3, 5], BIT))?;
        for i in 0..6 {
            for j in 0..3 {
                for k in 0..5 {
                    let expected = if flags[j] != 0 {
                        ones_side[k]
                    } else {
                        zeros_side[i]
                    };
                    assert_eq!(output[(i * 3 + j) * 5 + k], expected);
                }
            }
        }
        Ok(())
    }

    /// Declares one graph input per entry of `input_types` (in order), passes the first
    /// `arg_count` of them to `Mux` and asserts that the custom operation is rejected.
    fn assert_mux_rejected(input_types: Vec<Type>, arg_count: usize) -> Result<()> {
        let context = create_context()?;
        let graph = context.create_graph()?;
        let mut args = Vec::with_capacity(input_types.len());
        for input_type in input_types {
            args.push(graph.input(input_type)?);
        }
        args.truncate(arg_count);
        assert!(graph.custom_op(CustomOperation::new(Mux {}), args).is_err());
        Ok(())
    }

    #[test]
    fn test_mux_bits() {
        check_mux_bits().unwrap();
    }

    #[test]
    fn test_mux_broadcast() {
        check_mux_broadcast().unwrap();
    }

    #[test]
    fn test_malformed() {
        let cases: Vec<(Vec<Type>, usize)> = vec![
            // Choices with different integer scalar types.
            (
                vec![
                    Type::Array(vec![3, 1], BIT),
                    Type::Array(vec![1, 5, 1], UINT32),
                    Type::Array(vec![6, 1, 1], INT32),
                ],
                3,
            ),
            // Flag that is not made of bits.
            (
                vec![
                    Type::Array(vec![3, 1], INT32),
                    Type::Array(vec![1, 5], BIT),
                    Type::Array(vec![6, 1, 1], BIT),
                ],
                3,
            ),
            // Shapes that cannot be broadcast together.
            (
                vec![
                    Type::Array(vec![3, 7], BIT),
                    Type::Array(vec![1, 5], BIT),
                    Type::Array(vec![6, 1, 1], BIT),
                ],
                3,
            ),
            // Only two of the three declared inputs are passed.
            (
                vec![
                    Type::Array(vec![3, 7], BIT),
                    Type::Array(vec![1, 5], BIT),
                    Type::Array(vec![6, 1, 1], BIT),
                ],
                2,
            ),
            // Flag that is neither a scalar nor an array.
            (
                vec![
                    Type::Tuple(vec![]),
                    Type::Array(vec![1, 5], BIT),
                    Type::Array(vec![6, 1, 1], BIT),
                ],
                3,
            ),
            // Bit choice paired with an integer choice.
            (
                vec![
                    Type::Array(vec![3, 1], BIT),
                    Type::Array(vec![1, 5], BIT),
                    Type::Array(vec![6, 1, 1], INT32),
                ],
                3,
            ),
        ];
        for (input_types, arg_count) in cases {
            assert_mux_rejected(input_types, arg_count).unwrap();
        }
    }
}