Skip to main content

datafusion_functions/math/
gcd.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18use arrow::array::{ArrayRef, AsArray, PrimitiveArray};
19use arrow::compute::try_binary;
20use arrow::datatypes::{DataType, Int64Type};
21use arrow::error::ArrowError;
22use std::mem::swap;
23use std::sync::Arc;
24
25use datafusion_common::{Result, ScalarValue, exec_err, internal_datafusion_err};
26use datafusion_expr::{
27    ColumnarValue, Documentation, ScalarFunctionArgs, ScalarUDFImpl, Signature,
28    Volatility,
29};
30use datafusion_macros::user_doc;
31
#[user_doc(
    doc_section(label = "Math Functions"),
    description = "Returns the greatest common divisor of `expression_x` and `expression_y`. Returns 0 if both inputs are zero.",
    syntax_example = "gcd(expression_x, expression_y)",
    sql_example = r#"```sql
> SELECT gcd(48, 18);
+------------+
| gcd(48,18) |
+------------+
| 6          |
+------------+
```"#,
    standard_argument(name = "expression_x", prefix = "First numeric"),
    standard_argument(name = "expression_y", prefix = "Second numeric")
)]
/// Scalar UDF implementing SQL `gcd(x, y)` over `Int64` values.
///
/// The `#[user_doc]` attribute supplies the user-facing SQL documentation that
/// `ScalarUDFImpl::documentation` returns via `self.doc()`.
#[derive(Debug, PartialEq, Eq, Hash)]
pub struct GcdFunc {
    /// Argument types and volatility reported to the planner (built in `new`).
    signature: Signature,
}
51
52impl Default for GcdFunc {
53    fn default() -> Self {
54        Self::new()
55    }
56}
57
58impl GcdFunc {
59    pub fn new() -> Self {
60        Self {
61            signature: Signature::uniform(
62                2,
63                vec![DataType::Int64],
64                Volatility::Immutable,
65            ),
66        }
67    }
68}
69
70impl ScalarUDFImpl for GcdFunc {
71    fn name(&self) -> &str {
72        "gcd"
73    }
74
75    fn signature(&self) -> &Signature {
76        &self.signature
77    }
78
79    fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
80        Ok(DataType::Int64)
81    }
82
83    fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
84        let args: [ColumnarValue; 2] = args.args.try_into().map_err(|_| {
85            internal_datafusion_err!("Expected 2 arguments for function gcd")
86        })?;
87
88        match args {
89            [ColumnarValue::Array(a), ColumnarValue::Array(b)] => {
90                compute_gcd_for_arrays(&a, &b)
91            }
92            [
93                ColumnarValue::Scalar(ScalarValue::Int64(a)),
94                ColumnarValue::Scalar(ScalarValue::Int64(b)),
95            ] => match (a, b) {
96                (Some(a), Some(b)) => Ok(ColumnarValue::Scalar(ScalarValue::Int64(
97                    Some(compute_gcd(a, b)?),
98                ))),
99                _ => Ok(ColumnarValue::Scalar(ScalarValue::Int64(None))),
100            },
101            [
102                ColumnarValue::Array(a),
103                ColumnarValue::Scalar(ScalarValue::Int64(b)),
104            ] => compute_gcd_with_scalar(&a, b),
105            [
106                ColumnarValue::Scalar(ScalarValue::Int64(a)),
107                ColumnarValue::Array(b),
108            ] => compute_gcd_with_scalar(&b, a),
109            _ => exec_err!("Unsupported argument types for function gcd"),
110        }
111    }
112
113    fn documentation(&self) -> Option<&Documentation> {
114        self.doc()
115    }
116}
117
118fn compute_gcd_for_arrays(a: &ArrayRef, b: &ArrayRef) -> Result<ColumnarValue> {
119    let a = a.as_primitive::<Int64Type>();
120    let b = b.as_primitive::<Int64Type>();
121    try_binary(a, b, compute_gcd)
122        .map(|arr: PrimitiveArray<Int64Type>| {
123            ColumnarValue::Array(Arc::new(arr) as ArrayRef)
124        })
125        .map_err(Into::into) // convert ArrowError to DataFusionError
126}
127
128fn compute_gcd_with_scalar(arr: &ArrayRef, scalar: Option<i64>) -> Result<ColumnarValue> {
129    let prim = arr.as_primitive::<Int64Type>();
130    match scalar {
131        Some(scalar_value) if scalar_value != 0 && scalar_value != i64::MIN => {
132            // The gcd result divides both inputs' absolute values. When the
133            // scalar is neither 0 nor i64::MIN, the gcd's absolute value fits
134            // in i64, so the cast to i64 below cannot overflow. This allows us
135            // to use `unary` instead of `try_unary`, which allows LLVM to
136            // vectorize more effectively.
137            let sv = scalar_value.unsigned_abs();
138            let result: PrimitiveArray<Int64Type> =
139                prim.unary(|val| unsigned_gcd(val.unsigned_abs(), sv) as i64);
140            Ok(ColumnarValue::Array(Arc::new(result) as ArrayRef))
141        }
142        Some(scalar_value) => {
143            let result: PrimitiveArray<Int64Type> =
144                prim.try_unary(|val| compute_gcd(val, scalar_value))?;
145            Ok(ColumnarValue::Array(Arc::new(result) as ArrayRef))
146        }
147        None => Ok(ColumnarValue::Scalar(ScalarValue::Int64(None))),
148    }
149}
150
151/// Computes gcd of two unsigned integers using Binary GCD algorithm.
152pub(super) fn unsigned_gcd(mut a: u64, mut b: u64) -> u64 {
153    if a == 0 {
154        return b;
155    }
156    if b == 0 {
157        return a;
158    }
159
160    let shift = (a | b).trailing_zeros();
161    a >>= a.trailing_zeros();
162    loop {
163        b >>= b.trailing_zeros();
164        if a > b {
165            swap(&mut a, &mut b);
166        }
167        b -= a;
168        if b == 0 {
169            return a << shift;
170        }
171    }
172}
173
174/// Computes greatest common divisor using Binary GCD algorithm.
175pub fn compute_gcd(x: i64, y: i64) -> Result<i64, ArrowError> {
176    let a = x.unsigned_abs();
177    let b = y.unsigned_abs();
178    let r = unsigned_gcd(a, b);
179    // The result can be up to 2^63 (e.g. gcd(i64::MIN, 0) or
180    // gcd(i64::MIN, i64::MIN)), which does not fit into i64.
181    r.try_into().map_err(|_| {
182        ArrowError::ComputeError(format!("Signed integer overflow in GCD({x}, {y})"))
183    })
184}