datafusion_functions/math/
gcd.rs1use arrow::array::{ArrayRef, AsArray, Int64Array, PrimitiveArray, new_null_array};
19use arrow::compute::try_binary;
20use arrow::datatypes::{DataType, Int64Type};
21use arrow::error::ArrowError;
22use std::any::Any;
23use std::mem::swap;
24use std::sync::Arc;
25
26use datafusion_common::{Result, ScalarValue, exec_err, internal_datafusion_err};
27use datafusion_expr::{
28 ColumnarValue, Documentation, ScalarFunctionArgs, ScalarUDFImpl, Signature,
29 Volatility,
30};
31use datafusion_macros::user_doc;
32
33#[user_doc(
34 doc_section(label = "Math Functions"),
35 description = "Returns the greatest common divisor of `expression_x` and `expression_y`. Returns 0 if both inputs are zero.",
36 syntax_example = "gcd(expression_x, expression_y)",
37 sql_example = r#"```sql
38> SELECT gcd(48, 18);
39+------------+
40| gcd(48,18) |
41+------------+
42| 6 |
43+------------+
44```"#,
45 standard_argument(name = "expression_x", prefix = "First numeric"),
46 standard_argument(name = "expression_y", prefix = "Second numeric")
47)]
48#[derive(Debug, PartialEq, Eq, Hash)]
49pub struct GcdFunc {
50 signature: Signature,
51}
52
53impl Default for GcdFunc {
54 fn default() -> Self {
55 Self::new()
56 }
57}
58
59impl GcdFunc {
60 pub fn new() -> Self {
61 Self {
62 signature: Signature::uniform(
63 2,
64 vec![DataType::Int64],
65 Volatility::Immutable,
66 ),
67 }
68 }
69}
70
71impl ScalarUDFImpl for GcdFunc {
72 fn as_any(&self) -> &dyn Any {
73 self
74 }
75
76 fn name(&self) -> &str {
77 "gcd"
78 }
79
80 fn signature(&self) -> &Signature {
81 &self.signature
82 }
83
84 fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
85 Ok(DataType::Int64)
86 }
87
88 fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
89 let args: [ColumnarValue; 2] = args.args.try_into().map_err(|_| {
90 internal_datafusion_err!("Expected 2 arguments for function gcd")
91 })?;
92
93 match args {
94 [ColumnarValue::Array(a), ColumnarValue::Array(b)] => {
95 compute_gcd_for_arrays(&a, &b)
96 }
97 [
98 ColumnarValue::Scalar(ScalarValue::Int64(a)),
99 ColumnarValue::Scalar(ScalarValue::Int64(b)),
100 ] => match (a, b) {
101 (Some(a), Some(b)) => Ok(ColumnarValue::Scalar(ScalarValue::Int64(
102 Some(compute_gcd(a, b)?),
103 ))),
104 _ => Ok(ColumnarValue::Scalar(ScalarValue::Int64(None))),
105 },
106 [
107 ColumnarValue::Array(a),
108 ColumnarValue::Scalar(ScalarValue::Int64(b)),
109 ] => compute_gcd_with_scalar(&a, b),
110 [
111 ColumnarValue::Scalar(ScalarValue::Int64(a)),
112 ColumnarValue::Array(b),
113 ] => compute_gcd_with_scalar(&b, a),
114 _ => exec_err!("Unsupported argument types for function gcd"),
115 }
116 }
117
118 fn documentation(&self) -> Option<&Documentation> {
119 self.doc()
120 }
121}
122
123fn compute_gcd_for_arrays(a: &ArrayRef, b: &ArrayRef) -> Result<ColumnarValue> {
124 let a = a.as_primitive::<Int64Type>();
125 let b = b.as_primitive::<Int64Type>();
126 try_binary(a, b, compute_gcd)
127 .map(|arr: PrimitiveArray<Int64Type>| {
128 ColumnarValue::Array(Arc::new(arr) as ArrayRef)
129 })
130 .map_err(Into::into) }
132
133fn compute_gcd_with_scalar(arr: &ArrayRef, scalar: Option<i64>) -> Result<ColumnarValue> {
134 match scalar {
135 Some(scalar_value) => {
136 let result: Result<Int64Array> = arr
137 .as_primitive::<Int64Type>()
138 .iter()
139 .map(|val| match val {
140 Some(val) => Ok(Some(compute_gcd(val, scalar_value)?)),
141 _ => Ok(None),
142 })
143 .collect();
144
145 result.map(|arr| ColumnarValue::Array(Arc::new(arr) as ArrayRef))
146 }
147 None => Ok(ColumnarValue::Array(new_null_array(
148 &DataType::Int64,
149 arr.len(),
150 ))),
151 }
152}
153
154pub(super) fn unsigned_gcd(mut a: u64, mut b: u64) -> u64 {
156 if a == 0 {
157 return b;
158 }
159 if b == 0 {
160 return a;
161 }
162
163 let shift = (a | b).trailing_zeros();
164 a >>= a.trailing_zeros();
165 loop {
166 b >>= b.trailing_zeros();
167 if a > b {
168 swap(&mut a, &mut b);
169 }
170 b -= a;
171 if b == 0 {
172 return a << shift;
173 }
174 }
175}
176
177pub fn compute_gcd(x: i64, y: i64) -> Result<i64, ArrowError> {
179 let a = x.unsigned_abs();
180 let b = y.unsigned_abs();
181 let r = unsigned_gcd(a, b);
182 r.try_into().map_err(|_| {
184 ArrowError::ComputeError(format!("Signed integer overflow in GCD({x}, {y})"))
185 })
186}