1use crate::custom_ops::{CustomOperation, CustomOperationBody};
3use crate::data_types::{array_type, Type};
4use crate::errors::Result;
5use crate::graphs::{Context, Graph, Node};
6
7use super::comparisons::GreaterThan;
8use super::multiplexer::Mux;
9
10use serde::{Deserialize, Serialize};
11
12#[derive(Debug, Serialize, Deserialize, Eq, PartialEq, Hash)]
46pub struct Min {
47 pub signed_comparison: bool,
49}
50
51fn normalize_cmp(cmp: Node) -> Result<Node> {
57 let cmp_type = cmp.get_type()?;
58 let normalized_cmp = if cmp_type.is_array() {
59 let mut new_shape = cmp_type.get_shape();
60 let st = cmp_type.get_scalar_type();
61 new_shape.push(1);
62 cmp.reshape(array_type(new_shape, st))?
63 } else {
64 cmp
65 };
66 Ok(normalized_cmp)
67}
68
69#[typetag::serde]
70impl CustomOperationBody for Min {
71 fn instantiate(&self, context: Context, arguments_types: Vec<Type>) -> Result<Graph> {
72 if arguments_types.len() != 2 {
73 return Err(runtime_error!("Invalid number of arguments for Min"));
74 }
75 let g = context.create_graph()?;
76 let i1 = g.input(arguments_types[0].clone())?;
77 let i2 = g.input(arguments_types[1].clone())?;
78 let cmp = g.custom_op(
79 CustomOperation::new(GreaterThan {
80 signed_comparison: self.signed_comparison,
81 }),
82 vec![i1.clone(), i2.clone()],
83 )?;
84 let normalized_cmp = normalize_cmp(cmp)?;
85 let o = g.custom_op(CustomOperation::new(Mux {}), vec![normalized_cmp, i2, i1])?;
86 g.set_output_node(o)?;
87 g.finalize()?;
88 Ok(g)
89 }
90
91 fn get_name(&self) -> String {
92 format!("Min(signed_comparison={})", self.signed_comparison)
93 }
94}
95
96#[derive(Debug, Serialize, Deserialize, Eq, PartialEq, Hash)]
130pub struct Max {
131 pub signed_comparison: bool,
133}
134
135#[typetag::serde]
136impl CustomOperationBody for Max {
137 fn instantiate(&self, context: Context, arguments_types: Vec<Type>) -> Result<Graph> {
138 if arguments_types.len() != 2 {
139 return Err(runtime_error!("Invalid number of arguments for Max"));
140 }
141 let g = context.create_graph()?;
142 let i1 = g.input(arguments_types[0].clone())?;
143 let i2 = g.input(arguments_types[1].clone())?;
144 let cmp = g.custom_op(
145 CustomOperation::new(GreaterThan {
146 signed_comparison: self.signed_comparison,
147 }),
148 vec![i1.clone(), i2.clone()],
149 )?;
150 let normalized_cmp = normalize_cmp(cmp)?;
151 let o = g.custom_op(CustomOperation::new(Mux {}), vec![normalized_cmp, i1, i2])?;
152 g.set_output_node(o)?;
153 g.finalize()?;
154 Ok(g)
155 }
156
157 fn get_name(&self) -> String {
158 format!("Max(signed_comparison={})", self.signed_comparison)
159 }
160}
161
#[cfg(test)]
mod tests {
    use crate::custom_ops::run_instantiation_pass;
    use crate::data_types::{array_type, scalar_type, BIT, INT64, UINT64};
    use crate::data_values::Value;
    use crate::evaluators::random_evaluate;
    use crate::graphs::create_context;
    use crate::graphs::util::simple_context;

    use super::*;

    use std::cmp::{max, min};

    /// Builds and instantiates a context whose main graph takes two `UINT64`
    /// scalars, converts both with `a2b`, and returns the tuple
    /// `(Min unsigned, Max signed, Min signed, Max unsigned)`.
    fn scalar_min_max_context() -> Result<Context> {
        let c = simple_context(|g| {
            let i1 = g.input(scalar_type(UINT64))?.a2b()?;
            let i2 = g.input(scalar_type(UINT64))?.a2b()?;
            let apply = |op| g.custom_op(op, vec![i1.clone(), i2.clone()]);
            g.create_tuple(vec![
                apply(CustomOperation::new(Min {
                    signed_comparison: false,
                }))?,
                apply(CustomOperation::new(Max {
                    signed_comparison: true,
                }))?,
                apply(CustomOperation::new(Min {
                    signed_comparison: true,
                }))?,
                apply(CustomOperation::new(Max {
                    signed_comparison: false,
                }))?,
            ])
        })?;
        let mapped_c = run_instantiation_pass(c)?;
        Ok(mapped_c.get_context())
    }

    #[test]
    fn test_well_formed() -> Result<()> {
        let test_data: Vec<(u64, u64)> = vec![
            (31, 32),
            (76543, 76544),
            (0, 1),
            (0, 0),
            (761523, 761523),
            (u64::MAX, u64::MAX - 1),
            (u64::MAX - 761522, u64::MAX - 761523),
        ];
        let context = scalar_min_max_context()?;
        for (u, v) in test_data {
            let outputs = random_evaluate(
                context.get_main_graph()?,
                vec![
                    Value::from_scalar(u, UINT64)?,
                    Value::from_scalar(v, UINT64)?,
                ],
            )?
            .to_vector()?;
            // Signed results are checked by reinterpreting the same 64 bits as i64.
            let (signed_u, signed_v) = (u as i64, v as i64);
            assert_eq!(min(u, v), outputs[0].to_u64(UINT64)?);
            assert_eq!(max(signed_u, signed_v), outputs[1].to_i64(INT64)?);
            assert_eq!(min(signed_u, signed_v), outputs[2].to_i64(INT64)?);
            assert_eq!(max(u, v), outputs[3].to_u64(UINT64)?);
        }
        Ok(())
    }

    #[test]
    fn test_malformed() -> Result<()> {
        let c = create_context()?;
        let g = c.create_graph()?;
        let i1 = g.input(scalar_type(UINT64))?.a2b()?;
        // Both operations require exactly two arguments: too few and too many must fail.
        let too_few = vec![i1.clone()];
        let too_many = vec![i1.clone(), i1.clone(), i1.clone()];
        for args in vec![too_few, too_many] {
            assert!(g
                .custom_op(
                    CustomOperation::new(Min {
                        signed_comparison: false
                    }),
                    args.clone()
                )
                .is_err());
            assert!(g
                .custom_op(
                    CustomOperation::new(Max {
                        signed_comparison: false
                    }),
                    args.clone()
                )
                .is_err());
        }
        Ok(())
    }

    #[test]
    fn test_vector() -> Result<()> {
        let c = simple_context(|g| {
            let i1 = g.input(array_type(vec![1, 3, 64], BIT))?;
            let i2 = g.input(array_type(vec![3, 1, 64], BIT))?;
            g.create_tuple(vec![
                g.custom_op(
                    CustomOperation::new(Min {
                        signed_comparison: false,
                    }),
                    vec![i1.clone(), i2.clone()],
                )?,
                g.custom_op(
                    CustomOperation::new(Max {
                        signed_comparison: false,
                    }),
                    vec![i1.clone(), i2.clone()],
                )?,
            ])
        })?;
        let mapped_c = run_instantiation_pass(c)?;
        let context = mapped_c.get_context();
        let a = vec![0, 30, 100];
        let b = vec![10, 50, 150];
        let outputs = random_evaluate(
            context.get_main_graph()?,
            vec![
                Value::from_flattened_array(&a, UINT64)?,
                Value::from_flattened_array(&b, UINT64)?,
            ],
        )?
        .to_vector()?;
        // Broadcasting [1, 3] against [3, 1] gives entry [i][j] = op(a[j], b[i]).
        let min_a_b = outputs[0].to_flattened_array_u64(array_type(vec![3, 3], UINT64))?;
        let max_a_b = outputs[1].to_flattened_array_u64(array_type(vec![3, 3], UINT64))?;
        assert_eq!(min_a_b, vec![0, 10, 10, 0, 30, 50, 0, 30, 100]);
        assert_eq!(max_a_b, vec![10, 30, 100, 50, 50, 100, 150, 150, 150]);
        Ok(())
    }
}