datafusion_functions/math/
gcd.rs1use arrow::array::{ArrayRef, AsArray, PrimitiveArray};
19use arrow::compute::try_binary;
20use arrow::datatypes::{DataType, Int64Type};
21use arrow::error::ArrowError;
22use std::mem::swap;
23use std::sync::Arc;
24
25use datafusion_common::{Result, ScalarValue, exec_err, internal_datafusion_err};
26use datafusion_expr::{
27 ColumnarValue, Documentation, ScalarFunctionArgs, ScalarUDFImpl, Signature,
28 Volatility,
29};
30use datafusion_macros::user_doc;
31
#[user_doc(
    doc_section(label = "Math Functions"),
    description = "Returns the greatest common divisor of `expression_x` and `expression_y`. Returns 0 if both inputs are zero.",
    syntax_example = "gcd(expression_x, expression_y)",
    sql_example = r#"```sql
> SELECT gcd(48, 18);
+------------+
| gcd(48,18) |
+------------+
| 6          |
+------------+
```"#,
    standard_argument(name = "expression_x", prefix = "First numeric"),
    standard_argument(name = "expression_y", prefix = "Second numeric")
)]
/// Scalar UDF implementing SQL `gcd(x, y)`: the greatest common divisor of two
/// `Int64` values.
///
/// A null in either argument yields null. The result is the non-negative GCD of
/// the absolute values of the inputs. Evaluation fails with an overflow error when
/// that GCD is `2^63`, which does not fit in `Int64` (this requires an input to be
/// `i64::MIN`).
#[derive(Debug, PartialEq, Eq, Hash)]
pub struct GcdFunc {
    /// Two arguments, each of type `Int64`, with immutable volatility.
    signature: Signature,
}
51
52impl Default for GcdFunc {
53 fn default() -> Self {
54 Self::new()
55 }
56}
57
58impl GcdFunc {
59 pub fn new() -> Self {
60 Self {
61 signature: Signature::uniform(
62 2,
63 vec![DataType::Int64],
64 Volatility::Immutable,
65 ),
66 }
67 }
68}
69
70impl ScalarUDFImpl for GcdFunc {
71 fn name(&self) -> &str {
72 "gcd"
73 }
74
75 fn signature(&self) -> &Signature {
76 &self.signature
77 }
78
79 fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
80 Ok(DataType::Int64)
81 }
82
83 fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
84 let args: [ColumnarValue; 2] = args.args.try_into().map_err(|_| {
85 internal_datafusion_err!("Expected 2 arguments for function gcd")
86 })?;
87
88 match args {
89 [ColumnarValue::Array(a), ColumnarValue::Array(b)] => {
90 compute_gcd_for_arrays(&a, &b)
91 }
92 [
93 ColumnarValue::Scalar(ScalarValue::Int64(a)),
94 ColumnarValue::Scalar(ScalarValue::Int64(b)),
95 ] => match (a, b) {
96 (Some(a), Some(b)) => Ok(ColumnarValue::Scalar(ScalarValue::Int64(
97 Some(compute_gcd(a, b)?),
98 ))),
99 _ => Ok(ColumnarValue::Scalar(ScalarValue::Int64(None))),
100 },
101 [
102 ColumnarValue::Array(a),
103 ColumnarValue::Scalar(ScalarValue::Int64(b)),
104 ] => compute_gcd_with_scalar(&a, b),
105 [
106 ColumnarValue::Scalar(ScalarValue::Int64(a)),
107 ColumnarValue::Array(b),
108 ] => compute_gcd_with_scalar(&b, a),
109 _ => exec_err!("Unsupported argument types for function gcd"),
110 }
111 }
112
113 fn documentation(&self) -> Option<&Documentation> {
114 self.doc()
115 }
116}
117
118fn compute_gcd_for_arrays(a: &ArrayRef, b: &ArrayRef) -> Result<ColumnarValue> {
119 let a = a.as_primitive::<Int64Type>();
120 let b = b.as_primitive::<Int64Type>();
121 try_binary(a, b, compute_gcd)
122 .map(|arr: PrimitiveArray<Int64Type>| {
123 ColumnarValue::Array(Arc::new(arr) as ArrayRef)
124 })
125 .map_err(Into::into) }
127
128fn compute_gcd_with_scalar(arr: &ArrayRef, scalar: Option<i64>) -> Result<ColumnarValue> {
129 let prim = arr.as_primitive::<Int64Type>();
130 match scalar {
131 Some(scalar_value) if scalar_value != 0 && scalar_value != i64::MIN => {
132 let sv = scalar_value.unsigned_abs();
138 let result: PrimitiveArray<Int64Type> =
139 prim.unary(|val| unsigned_gcd(val.unsigned_abs(), sv) as i64);
140 Ok(ColumnarValue::Array(Arc::new(result) as ArrayRef))
141 }
142 Some(scalar_value) => {
143 let result: PrimitiveArray<Int64Type> =
144 prim.try_unary(|val| compute_gcd(val, scalar_value))?;
145 Ok(ColumnarValue::Array(Arc::new(result) as ArrayRef))
146 }
147 None => Ok(ColumnarValue::Scalar(ScalarValue::Int64(None))),
148 }
149}
150
151pub(super) fn unsigned_gcd(mut a: u64, mut b: u64) -> u64 {
153 if a == 0 {
154 return b;
155 }
156 if b == 0 {
157 return a;
158 }
159
160 let shift = (a | b).trailing_zeros();
161 a >>= a.trailing_zeros();
162 loop {
163 b >>= b.trailing_zeros();
164 if a > b {
165 swap(&mut a, &mut b);
166 }
167 b -= a;
168 if b == 0 {
169 return a << shift;
170 }
171 }
172}
173
174pub fn compute_gcd(x: i64, y: i64) -> Result<i64, ArrowError> {
176 let a = x.unsigned_abs();
177 let b = y.unsigned_abs();
178 let r = unsigned_gcd(a, b);
179 r.try_into().map_err(|_| {
182 ArrowError::ComputeError(format!("Signed integer overflow in GCD({x}, {y})"))
183 })
184}