1use crate::arity::*;
26use arrow_array::types::*;
27use arrow_array::*;
28use arrow_buffer::ArrowNativeType;
29use arrow_buffer::i256;
30use arrow_schema::*;
31use std::cmp::min;
32use std::sync::Arc;
33
34fn get_fixed_point_info(
37    left: (u8, i8),
38    right: (u8, i8),
39    required_scale: i8,
40) -> Result<(u8, i8, i256), ArrowError> {
41    let product_scale = left.1 + right.1;
42    let precision = min(left.0 + right.0 + 1, DECIMAL128_MAX_PRECISION);
43
44    if required_scale > product_scale {
45        return Err(ArrowError::ComputeError(format!(
46            "Required scale {required_scale} is greater than product scale {product_scale}",
47        )));
48    }
49
50    let divisor = i256::from_i128(10).pow_wrapping((product_scale - required_scale) as u32);
51
52    Ok((precision, product_scale, divisor))
53}
54
55pub fn multiply_fixed_point_dyn(
70    left: &dyn Array,
71    right: &dyn Array,
72    required_scale: i8,
73) -> Result<ArrayRef, ArrowError> {
74    match (left.data_type(), right.data_type()) {
75        (DataType::Decimal128(_, _), DataType::Decimal128(_, _)) => {
76            let left = left.as_any().downcast_ref::<Decimal128Array>().unwrap();
77            let right = right.as_any().downcast_ref::<Decimal128Array>().unwrap();
78
79            multiply_fixed_point(left, right, required_scale).map(|a| Arc::new(a) as ArrayRef)
80        }
81        (_, _) => Err(ArrowError::CastError(format!(
82            "Unsupported data type {}, {}",
83            left.data_type(),
84            right.data_type()
85        ))),
86    }
87}
88
89pub fn multiply_fixed_point_checked(
102    left: &PrimitiveArray<Decimal128Type>,
103    right: &PrimitiveArray<Decimal128Type>,
104    required_scale: i8,
105) -> Result<PrimitiveArray<Decimal128Type>, ArrowError> {
106    let (precision, product_scale, divisor) = get_fixed_point_info(
107        (left.precision(), left.scale()),
108        (right.precision(), right.scale()),
109        required_scale,
110    )?;
111
112    if required_scale == product_scale {
113        return try_binary::<_, _, _, Decimal128Type>(left, right, |a, b| a.mul_checked(b))?
114            .with_precision_and_scale(precision, required_scale);
115    }
116
117    try_binary::<_, _, _, Decimal128Type>(left, right, |a, b| {
118        let a = i256::from_i128(a);
119        let b = i256::from_i128(b);
120
121        let mut mul = a.wrapping_mul(b);
122        mul = divide_and_round::<Decimal256Type>(mul, divisor);
123        mul.to_i128().ok_or_else(|| {
124            ArrowError::ArithmeticOverflow(format!("Overflow happened on: {a:?} * {b:?}"))
125        })
126    })
127    .and_then(|a| a.with_precision_and_scale(precision, required_scale))
128}
129
130pub fn multiply_fixed_point(
146    left: &PrimitiveArray<Decimal128Type>,
147    right: &PrimitiveArray<Decimal128Type>,
148    required_scale: i8,
149) -> Result<PrimitiveArray<Decimal128Type>, ArrowError> {
150    let (precision, product_scale, divisor) = get_fixed_point_info(
151        (left.precision(), left.scale()),
152        (right.precision(), right.scale()),
153        required_scale,
154    )?;
155
156    if required_scale == product_scale {
157        return binary(left, right, |a, b| a.mul_wrapping(b))?
158            .with_precision_and_scale(precision, required_scale);
159    }
160
161    binary::<_, _, _, Decimal128Type>(left, right, |a, b| {
162        let a = i256::from_i128(a);
163        let b = i256::from_i128(b);
164
165        let mut mul = a.wrapping_mul(b);
166        mul = divide_and_round::<Decimal256Type>(mul, divisor);
167        mul.as_i128()
168    })
169    .and_then(|a| a.with_precision_and_scale(precision, required_scale))
170}
171
172fn divide_and_round<I>(input: I::Native, div: I::Native) -> I::Native
174where
175    I: DecimalType,
176    I::Native: ArrowNativeTypeOp,
177{
178    let d = input.div_wrapping(div);
179    let r = input.mod_wrapping(div);
180
181    let half = div.div_wrapping(I::Native::from_usize(2).unwrap());
182    let half_neg = half.neg_wrapping();
183
184    match input >= I::Native::ZERO {
186        true if r >= half => d.add_wrapping(I::Native::ONE),
187        false if r <= half_neg => d.sub_wrapping(I::Native::ONE),
188        _ => d,
189    }
190}
191
#[cfg(test)]
mod tests {
    use super::*;
    use crate::numeric::mul;

    /// Builds a `Decimal128Array` from raw unscaled values with the given precision/scale.
    fn decimal128(values: Vec<i128>, precision: u8, scale: i8) -> Decimal128Array {
        Decimal128Array::from(values)
            .with_precision_and_scale(precision, scale)
            .unwrap()
    }

    #[test]
    fn test_decimal_multiply_allow_precision_loss() {
        // The plain kernel overflows, while the fixed-point kernel trades scale for range.
        let a = decimal128(vec![123456789000000000000000000], 38, 18);
        let b = decimal128(vec![10000000000000000000], 38, 18);

        let overflow = mul(&a, &b).unwrap_err();
        assert!(overflow.to_string().contains(
            "Overflow happened on: 123456789000000000000000000 * 10000000000000000000"
        ));

        let result = multiply_fixed_point_checked(&a, &b, 28).unwrap();
        let expected = decimal128(vec![12345678900000000000000000000000000000], 38, 28);
        assert_eq!(&expected, &result);
        assert_eq!(
            result.value_as_string(0),
            "1234567890.0000000000000000000000000000"
        );

        // Rounding of discarded digits, including tiny values.
        let a = decimal128(
            vec![1, 123456789555555555555555555, 1555555555555555555],
            38,
            18,
        );
        let b = decimal128(
            vec![1555555555555555555, 11222222222222222222, 1],
            38,
            18,
        );

        let result = multiply_fixed_point_checked(&a, &b, 28).unwrap();
        let expected = decimal128(
            vec![
                15555555556,
                13854595272345679012071330528765432099,
                15555555556,
            ],
            38,
            28,
        );
        assert_eq!(&expected, &result);
        assert_eq!(
            result.value_as_string(1),
            "1385459527.2345679012071330528765432099"
        );
        assert_eq!(result.value_as_string(0), "0.0000000000000000015555555556");
        assert_eq!(result.value_as_string(2), "0.0000000000000000015555555556");

        // When no scale is dropped, the result matches the regular multiply kernel.
        let a = decimal128(vec![1230], 4, 2);
        let b = decimal128(vec![1000], 4, 2);

        let result = multiply_fixed_point_checked(&a, &b, 4).unwrap();
        assert_eq!(result.precision(), 9);
        assert_eq!(result.scale(), 4);

        let expected = mul(&a, &b).unwrap();
        assert_eq!(expected.as_ref(), &result);

        // Asking for more scale than the product has is rejected.
        let too_precise = multiply_fixed_point_checked(&a, &b, 5).unwrap_err();
        assert!(too_precise
            .to_string()
            .contains("Required scale 5 is greater than product scale 4"));
    }

    #[test]
    fn test_decimal_multiply_allow_precision_loss_overflow() {
        let a = decimal128(vec![99999999999123456789000000000000000000], 38, 18);
        let b = decimal128(vec![9999999999910000000000000000000], 38, 18);

        // The checked kernel reports the overflow...
        let overflow = multiply_fixed_point_checked(&a, &b, 28).unwrap_err();
        assert!(overflow.to_string().contains(
            "Overflow happened on: 99999999999123456789000000000000000000 * 9999999999910000000000000000000"
        ));

        // ...while the unchecked kernel truncates the result.
        let result = multiply_fixed_point(&a, &b, 28).unwrap();
        let expected = decimal128(vec![62946009661555981610246871926660136960], 38, 28);
        assert_eq!(&expected, &result);
    }

    #[test]
    fn test_decimal_multiply_fixed_point() {
        let a = decimal128(vec![123456789000000000000000000], 38, 18);
        let b = decimal128(vec![10000000000000000000], 38, 18);

        let overflow = mul(&a, &b).unwrap_err();
        assert_eq!(
            overflow.to_string(),
            "Arithmetic overflow: Overflow happened on: 123456789000000000000000000 * 10000000000000000000"
        );

        let result = multiply_fixed_point(&a, &b, 28).unwrap();
        let expected = decimal128(vec![12345678900000000000000000000000000000], 38, 28);
        assert_eq!(&expected, &result);
        assert_eq!(
            result.value_as_string(0),
            "1234567890.0000000000000000000000000000"
        );
    }
}