1use crate::arity::*;
26use arrow_array::types::*;
27use arrow_array::*;
28use arrow_buffer::i256;
29use arrow_buffer::ArrowNativeType;
30use arrow_schema::*;
31use std::cmp::min;
32use std::sync::Arc;
33
34fn get_fixed_point_info(
37    left: (u8, i8),
38    right: (u8, i8),
39    required_scale: i8,
40) -> Result<(u8, i8, i256), ArrowError> {
41    let product_scale = left.1 + right.1;
42    let precision = min(left.0 + right.0 + 1, DECIMAL128_MAX_PRECISION);
43
44    if required_scale > product_scale {
45        return Err(ArrowError::ComputeError(format!(
46            "Required scale {} is greater than product scale {}",
47            required_scale, product_scale
48        )));
49    }
50
51    let divisor = i256::from_i128(10).pow_wrapping((product_scale - required_scale) as u32);
52
53    Ok((precision, product_scale, divisor))
54}
55
56pub fn multiply_fixed_point_dyn(
71    left: &dyn Array,
72    right: &dyn Array,
73    required_scale: i8,
74) -> Result<ArrayRef, ArrowError> {
75    match (left.data_type(), right.data_type()) {
76        (DataType::Decimal128(_, _), DataType::Decimal128(_, _)) => {
77            let left = left.as_any().downcast_ref::<Decimal128Array>().unwrap();
78            let right = right.as_any().downcast_ref::<Decimal128Array>().unwrap();
79
80            multiply_fixed_point(left, right, required_scale).map(|a| Arc::new(a) as ArrayRef)
81        }
82        (_, _) => Err(ArrowError::CastError(format!(
83            "Unsupported data type {}, {}",
84            left.data_type(),
85            right.data_type()
86        ))),
87    }
88}
89
90pub fn multiply_fixed_point_checked(
103    left: &PrimitiveArray<Decimal128Type>,
104    right: &PrimitiveArray<Decimal128Type>,
105    required_scale: i8,
106) -> Result<PrimitiveArray<Decimal128Type>, ArrowError> {
107    let (precision, product_scale, divisor) = get_fixed_point_info(
108        (left.precision(), left.scale()),
109        (right.precision(), right.scale()),
110        required_scale,
111    )?;
112
113    if required_scale == product_scale {
114        return try_binary::<_, _, _, Decimal128Type>(left, right, |a, b| a.mul_checked(b))?
115            .with_precision_and_scale(precision, required_scale);
116    }
117
118    try_binary::<_, _, _, Decimal128Type>(left, right, |a, b| {
119        let a = i256::from_i128(a);
120        let b = i256::from_i128(b);
121
122        let mut mul = a.wrapping_mul(b);
123        mul = divide_and_round::<Decimal256Type>(mul, divisor);
124        mul.to_i128().ok_or_else(|| {
125            ArrowError::ArithmeticOverflow(format!("Overflow happened on: {:?} * {:?}", a, b))
126        })
127    })
128    .and_then(|a| a.with_precision_and_scale(precision, required_scale))
129}
130
131pub fn multiply_fixed_point(
147    left: &PrimitiveArray<Decimal128Type>,
148    right: &PrimitiveArray<Decimal128Type>,
149    required_scale: i8,
150) -> Result<PrimitiveArray<Decimal128Type>, ArrowError> {
151    let (precision, product_scale, divisor) = get_fixed_point_info(
152        (left.precision(), left.scale()),
153        (right.precision(), right.scale()),
154        required_scale,
155    )?;
156
157    if required_scale == product_scale {
158        return binary(left, right, |a, b| a.mul_wrapping(b))?
159            .with_precision_and_scale(precision, required_scale);
160    }
161
162    binary::<_, _, _, Decimal128Type>(left, right, |a, b| {
163        let a = i256::from_i128(a);
164        let b = i256::from_i128(b);
165
166        let mut mul = a.wrapping_mul(b);
167        mul = divide_and_round::<Decimal256Type>(mul, divisor);
168        mul.as_i128()
169    })
170    .and_then(|a| a.with_precision_and_scale(precision, required_scale))
171}
172
173fn divide_and_round<I>(input: I::Native, div: I::Native) -> I::Native
175where
176    I: DecimalType,
177    I::Native: ArrowNativeTypeOp,
178{
179    let d = input.div_wrapping(div);
180    let r = input.mod_wrapping(div);
181
182    let half = div.div_wrapping(I::Native::from_usize(2).unwrap());
183    let half_neg = half.neg_wrapping();
184
185    match input >= I::Native::ZERO {
187        true if r >= half => d.add_wrapping(I::Native::ONE),
188        false if r <= half_neg => d.sub_wrapping(I::Native::ONE),
189        _ => d,
190    }
191}
192
#[cfg(test)]
mod tests {
    use super::*;
    use crate::numeric::mul;

    /// Builds a `Decimal128Array` from raw `i128` values with the given
    /// precision and scale, panicking if the combination is rejected.
    fn decimal128(values: Vec<i128>, precision: u8, scale: i8) -> Decimal128Array {
        Decimal128Array::from(values)
            .with_precision_and_scale(precision, scale)
            .unwrap()
    }

    #[test]
    fn test_decimal_multiply_allow_precision_loss() {
        // Raw values at scale 18: 123456789 and 10.
        let lhs = decimal128(vec![123456789000000000000000000], 38, 18);
        let rhs = decimal128(vec![10000000000000000000], 38, 18);

        // The plain multiplication kernel reports an overflow for these inputs.
        let message = mul(&lhs, &rhs).unwrap_err().to_string();
        assert!(message
            .contains("Overflow happened on: 123456789000000000000000000 * 10000000000000000000"));

        // Reducing the scale of the product to 28 makes it representable.
        let actual = multiply_fixed_point_checked(&lhs, &rhs, 28).unwrap();
        let wanted = decimal128(vec![12345678900000000000000000000000000000], 38, 28);

        assert_eq!(&wanted, &actual);
        assert_eq!(
            actual.value_as_string(0),
            "1234567890.0000000000000000000000000000"
        );

        // Products whose trailing digits have to be rounded away.
        let lhs = decimal128(
            vec![1, 123456789555555555555555555, 1555555555555555555],
            38,
            18,
        );
        let rhs = decimal128(
            vec![1555555555555555555, 11222222222222222222, 1],
            38,
            18,
        );

        let actual = multiply_fixed_point_checked(&lhs, &rhs, 28).unwrap();
        let wanted = decimal128(
            vec![
                15555555556,
                13854595272345679012071330528765432099,
                15555555556,
            ],
            38,
            28,
        );

        assert_eq!(&wanted, &actual);

        assert_eq!(
            actual.value_as_string(1),
            "1385459527.2345679012071330528765432099"
        );
        assert_eq!(actual.value_as_string(0), "0.0000000000000000015555555556");
        assert_eq!(actual.value_as_string(2), "0.0000000000000000015555555556");

        // Small operands where the required scale equals the product scale.
        let lhs = decimal128(vec![1230], 4, 2);
        let rhs = decimal128(vec![1000], 4, 2);

        let actual = multiply_fixed_point_checked(&lhs, &rhs, 4).unwrap();
        assert_eq!(actual.precision(), 9);
        assert_eq!(actual.scale(), 4);

        // With no rescaling the result must match the regular kernel.
        let reference = mul(&lhs, &rhs).unwrap();
        assert_eq!(reference.as_ref(), &actual);

        // A required scale above the product scale is rejected.
        let failure = multiply_fixed_point_checked(&lhs, &rhs, 5).unwrap_err();
        assert!(failure
            .to_string()
            .contains("Required scale 5 is greater than product scale 4"));
    }

    #[test]
    fn test_decimal_multiply_allow_precision_loss_overflow() {
        let lhs = decimal128(vec![99999999999123456789000000000000000000], 38, 18);
        let rhs = decimal128(vec![9999999999910000000000000000000], 38, 18);

        // Even at scale 28 the product does not fit: the checked variant errors.
        let failure = multiply_fixed_point_checked(&lhs, &rhs, 28).unwrap_err();
        assert!(failure.to_string().contains(
            "Overflow happened on: 99999999999123456789000000000000000000 * 9999999999910000000000000000000"
        ));

        // The unchecked variant succeeds and yields this value instead.
        let actual = multiply_fixed_point(&lhs, &rhs, 28).unwrap();
        let wanted = decimal128(vec![62946009661555981610246871926660136960], 38, 28);

        assert_eq!(&wanted, &actual);
    }

    #[test]
    fn test_decimal_multiply_fixed_point() {
        // Raw values at scale 18: 123456789 and 10.
        let lhs = decimal128(vec![123456789000000000000000000], 38, 18);
        let rhs = decimal128(vec![10000000000000000000], 38, 18);

        // The plain multiplication kernel reports an overflow for these inputs.
        let failure = mul(&lhs, &rhs).unwrap_err();
        assert_eq!(failure.to_string(), "Arithmetic overflow: Overflow happened on: 123456789000000000000000000 * 10000000000000000000");

        // Reducing the scale of the product to 28 makes it representable.
        let actual = multiply_fixed_point(&lhs, &rhs, 28).unwrap();
        let wanted = decimal128(vec![12345678900000000000000000000000000000], 38, 28);

        assert_eq!(&wanted, &actual);
        assert_eq!(
            actual.value_as_string(0),
            "1234567890.0000000000000000000000000000"
        );
    }
}