1use crate::custom_ops::CustomOperationBody;
3use crate::data_types::{Type, BIT, INT64};
4use crate::errors::Result;
5use crate::graphs::{Context, Graph, Node, SliceElement};
6use crate::ops::utils::{ones_like, reduce_mul, unsqueeze};
7use crate::typed_value::TypedValue;
8use crate::typed_value_operations::TypedValueArrayOperations;
9
10use serde::{Deserialize, Serialize};
11
12use super::fixed_precision_config::FixedPrecisionConfig;
13
14#[derive(Debug, Serialize, Deserialize, Eq, PartialEq, Hash)]
44pub struct FixedMultiply {
45 pub config: FixedPrecisionConfig,
46}
47
48#[typetag::serde]
49impl CustomOperationBody for FixedMultiply {
50 fn instantiate(&self, context: Context, arguments_types: Vec<Type>) -> Result<Graph> {
51 if arguments_types.len() != 2 {
52 return Err(runtime_error!("FixedMultiply takes two arguments"));
53 }
54 for arg in arguments_types.iter() {
55 if !arg.is_array() && !arg.is_scalar() {
56 return Err(runtime_error!(
57 "FixedMultiply expects scalar or array, got {:?}",
58 arg
59 ));
60 }
61 if arg.get_scalar_type() != INT64 {
62 return Err(runtime_error!("FixedMultiply expects INT64, got {:?}", arg));
63 }
64 }
65
66 let g = context.create_graph()?;
67 let a = g.input(arguments_types[0].clone())?;
68 let b = g.input(arguments_types[1].clone())?;
69 let mut a_times_b = a.multiply(b.clone())?;
70 if self.config.debug {
71 a_times_b = g.assert(
72 "Integer overflow".into(),
73 is_multiplication_safe_from_overflow(a, b)?,
74 a_times_b,
75 )?
76 }
77 let a_times_b_shifted = a_times_b.truncate(self.config.denominator())?;
78 a_times_b_shifted.set_as_output()?;
79 g.finalize()?;
80 Ok(g)
81 }
82
83 fn get_name(&self) -> String {
84 format!("FixedMultiply({})", self.config.fractional_bits)
85 }
86}
87
88pub fn is_multiplication_safe_from_overflow(x: Node, y: Node) -> Result<Node> {
91 let x_bits = x.a2b()?;
92 let y_bits = y.a2b()?;
93 let msb_x = x_bits.get_slice(vec![
97 SliceElement::Ellipsis,
98 SliceElement::SubArray(Some(-1), None, None),
99 ])?;
100 let x_bits = x_bits.add(msb_x)?;
101 let msb_y = y_bits.get_slice(vec![
102 SliceElement::Ellipsis,
103 SliceElement::SubArray(Some(-1), None, None),
104 ])?;
105 let y_bits = y_bits.add(msb_y)?;
106 let xy_bits = unsqueeze(x_bits, -1)?.multiply(unsqueeze(y_bits, -2)?)?;
116 let mut mask_arr = ndarray::Array2::zeros((64, 64));
118 for i in 0..64 {
119 for j in 0..64 {
120 if i + j >= 56 {
121 mask_arr[[i, j]] = 1;
122 }
123 }
124 }
125 let mask_tv = TypedValue::from_ndarray(mask_arr.into_dyn(), BIT)?;
126 let g = x.get_graph();
127 let mask = g.constant(mask_tv.t, mask_tv.value)?;
128 let xy_bits = xy_bits.multiply(mask)?;
129 let one = ones_like(xy_bits.clone())?;
132 let not_xy_bits = xy_bits.add(one)?;
133 let mut reduction_result = not_xy_bits;
134 while reduction_result.get_type()?.is_array() {
135 reduction_result = reduce_mul(reduction_result)?;
136 }
137 Ok(reduction_result)
139}
140
#[cfg(test)]
mod tests {
    use ndarray::array;

    use super::*;
    use crate::custom_ops::run_instantiation_pass;
    use crate::custom_ops::CustomOperation;
    use crate::evaluators::random_evaluate;
    use crate::graphs::create_context;
    use crate::typed_value_operations::ToNdarray;
    use crate::typed_value_operations::TypedValueArrayOperations;

    /// Plain integer arithmetic: no fractional bits, no overflow assertion.
    const INTEGER: FixedPrecisionConfig = FixedPrecisionConfig {
        fractional_bits: 0,
        debug: false,
    };

    /// 15 fractional bits, no overflow assertion.
    const FIXED_15: FixedPrecisionConfig = FixedPrecisionConfig {
        fractional_bits: 15,
        debug: false,
    };

    /// 15 fractional bits with the debug-mode overflow assertion enabled.
    const FIXED_15_DEBUG: FixedPrecisionConfig = FixedPrecisionConfig {
        fractional_bits: 15,
        debug: true,
    };

    /// Builds a main graph containing one `FixedMultiply` node, instantiates
    /// it and evaluates it on the two given values.
    fn multiply_helper(
        a: TypedValue,
        b: TypedValue,
        config: FixedPrecisionConfig,
    ) -> Result<TypedValue> {
        let context = create_context()?;
        let graph = context.create_graph()?;
        let inputs = vec![graph.input(a.t.clone())?, graph.input(b.t.clone())?];
        let product = graph.custom_op(CustomOperation::new(FixedMultiply { config }), inputs)?;
        let output_type = product.get_type()?;
        product.set_as_output()?;
        graph.finalize()?;
        graph.set_as_main()?;
        context.finalize()?;
        let instantiated = run_instantiation_pass(context)?;
        let raw = random_evaluate(
            instantiated.get_context().get_main_graph()?,
            vec![a.value, b.value],
        )?;
        TypedValue::new(output_type, raw)
    }

    /// Multiplies two INT64 scalars with `FixedMultiply` and returns the raw bits.
    fn multiply_scalars(a: i64, b: i64, config: FixedPrecisionConfig) -> Result<u64> {
        let lhs = TypedValue::from_scalar(a, INT64)?;
        let rhs = TypedValue::from_scalar(b, INT64)?;
        multiply_helper(lhs, rhs, config)?.to_u64()
    }

    /// Runs `FixedMultiply` and flattens the result into a `Vec<i64>`.
    fn multiply_to_vec(
        a: TypedValue,
        b: TypedValue,
        config: FixedPrecisionConfig,
    ) -> Result<Vec<i64>> {
        let product = multiply_helper(a, b, config)?;
        Ok(ToNdarray::<i64>::to_ndarray(&product)?.into_raw_vec())
    }

    #[test]
    fn test_multiply_scalars() -> Result<()> {
        assert_eq!(multiply_scalars(2, 2, INTEGER)?, 4);
        assert_eq!(multiply_scalars(5, 6, INTEGER)?, 30);
        assert_eq!(multiply_scalars(2 << 15, 2 << 15, FIXED_15)?, 4 << 15);
        assert_eq!(multiply_scalars(5 << 15, 6 << 15, FIXED_15)?, 30 << 15);
        Ok(())
    }

    #[test]
    fn test_multiply_negative() -> Result<()> {
        let raw = multiply_scalars(2 << 15, -3 << 15, FIXED_15)?;
        assert_eq!(raw as i64, -6 << 15);
        Ok(())
    }

    #[test]
    fn test_multiply_arrays() -> Result<()> {
        let doubled = vec![2 << 15, 4 << 15, 6 << 15];

        let two_times_x = multiply_to_vec(
            TypedValue::from_scalar(2 << 15, INT64)?,
            TypedValue::from_ndarray(array![1 << 15, 2 << 15, 3 << 15].into_dyn(), INT64)?,
            FIXED_15,
        )?;
        assert_eq!(two_times_x, doubled);

        let x_times_two = multiply_to_vec(
            TypedValue::from_ndarray(array![1 << 15, 2 << 15, 3 << 15].into_dyn(), INT64)?,
            TypedValue::from_scalar(2 << 15, INT64)?,
            FIXED_15,
        )?;
        assert_eq!(x_times_two, doubled);

        let x_times_y = multiply_to_vec(
            TypedValue::from_ndarray(array![1 << 15, 2 << 15, 3 << 15].into_dyn(), INT64)?,
            TypedValue::from_ndarray(array![4 << 15, 5 << 15, 6 << 15].into_dyn(), INT64)?,
            FIXED_15,
        )?;
        assert_eq!(x_times_y, vec![4 << 15, 10 << 15, 18 << 15]);
        Ok(())
    }

    #[test]
    fn test_multiply_broadcast() -> Result<()> {
        let x_times_y = multiply_to_vec(
            TypedValue::from_ndarray(array![1 << 15, 2 << 15, 3 << 15].into_dyn(), INT64)?,
            TypedValue::from_ndarray(array![2 << 15].into_dyn(), INT64)?,
            FIXED_15,
        )?;
        assert_eq!(x_times_y, vec![2 << 15, 4 << 15, 6 << 15]);
        Ok(())
    }

    #[test]
    fn test_multiply_debug_mode_success() -> Result<()> {
        assert_eq!(multiply_scalars(2 << 15, 2 << 15, FIXED_15_DEBUG)?, 4 << 15);
        Ok(())
    }

    #[test]
    fn test_multiply_debug_mode_fail() -> Result<()> {
        let lhs = TypedValue::from_scalar(1 << 30, INT64)?;
        let rhs = TypedValue::from_scalar(1 << 30, INT64)?;
        assert!(multiply_helper(lhs, rhs, FIXED_15_DEBUG).is_err());
        Ok(())
    }

    /// Evaluates `is_multiplication_safe_from_overflow` on the given values and
    /// reports whether the product was judged safe.
    fn overflow_helper(a: TypedValue, b: TypedValue) -> Result<bool> {
        let context = create_context()?;
        let graph = context.create_graph()?;
        let lhs = graph.input(a.t.clone())?;
        let rhs = graph.input(b.t.clone())?;
        let verdict = is_multiplication_safe_from_overflow(lhs, rhs)?;
        let verdict_type = verdict.get_type()?;
        verdict.set_as_output()?;
        graph.finalize()?;
        graph.set_as_main()?;
        context.finalize()?;
        let instantiated = run_instantiation_pass(context)?;
        let raw = random_evaluate(
            instantiated.get_context().get_main_graph()?,
            vec![a.value, b.value],
        )?;
        let flag = TypedValue::new(verdict_type, raw)?.to_u64()?;
        Ok(flag > 0)
    }

    /// Runs the overflow check on two INT64 scalars.
    fn scalars_are_safe(a: i64, b: i64) -> Result<bool> {
        overflow_helper(
            TypedValue::from_scalar(a, INT64)?,
            TypedValue::from_scalar(b, INT64)?,
        )
    }

    #[test]
    fn test_overflow_check_success() -> Result<()> {
        let safe_pairs: [(i64, i64); 10] = [
            (1, 2),
            (1, -2),
            (-1, -1),
            (4243, 4243),
            (4243, 71479832),
            (1 << 25, 1 << 25),
            (71479832, 71479832),
            (2, 1 << 30),
            (-2, 1 << 30),
            (1, 1 << 50),
        ];
        for &(a, b) in safe_pairs.iter() {
            assert!(scalars_are_safe(a, b)?, "{} * {} should be safe", a, b);
        }
        Ok(())
    }

    #[test]
    fn test_overflow_check_fail() -> Result<()> {
        let large = 2363897937439121_i64;
        let unsafe_pairs: [(i64, i64); 6] = [
            (1 << 25, 1 << 50),
            (1 << 30, 1 << 30),
            (1 << 30, large),
            (large, large),
            (-1 << 30, 1 << 30),
            (1 << 30, 1 << 50),
        ];
        for &(a, b) in unsafe_pairs.iter() {
            assert!(!scalars_are_safe(a, b)?, "{} * {} should be unsafe", a, b);
        }
        Ok(())
    }

    #[test]
    fn test_overflow_check_success_arrays() -> Result<()> {
        let x = TypedValue::from_ndarray(array![1 << 15, 2 << 15, 3 << 15].into_dyn(), INT64)?;
        assert!(overflow_helper(
            x.clone(),
            TypedValue::from_scalar(2, INT64)?
        )?);
        let y = TypedValue::from_ndarray(array![10 << 15, 20 << 15, 30 << 15].into_dyn(), INT64)?;
        assert!(overflow_helper(x, y)?);
        Ok(())
    }

    #[test]
    fn test_overflow_check_fail_arrays() -> Result<()> {
        let x = TypedValue::from_ndarray(array![1 << 25, 1 << 26, 1 << 27].into_dyn(), INT64)?;
        assert!(!overflow_helper(
            x.clone(),
            TypedValue::from_scalar(1 << 30, INT64)?
        )?);
        let y = TypedValue::from_ndarray(array![1 << 28, 1 << 29, 1 << 30].into_dyn(), INT64)?;
        assert!(!overflow_helper(x, y)?);
        Ok(())
    }
}