datafusion_functions/math/
gcd.rs1use arrow::array::{new_null_array, ArrayRef, AsArray, Int64Array, PrimitiveArray};
19use arrow::compute::try_binary;
20use arrow::datatypes::{DataType, Int64Type};
21use arrow::error::ArrowError;
22use std::any::Any;
23use std::mem::swap;
24use std::sync::Arc;
25
26use datafusion_common::{exec_err, internal_datafusion_err, Result, ScalarValue};
27use datafusion_expr::{
28 ColumnarValue, Documentation, ScalarFunctionArgs, ScalarUDFImpl, Signature,
29 Volatility,
30};
31use datafusion_macros::user_doc;
32
33#[user_doc(
34 doc_section(label = "Math Functions"),
35 description = "Returns the greatest common divisor of `expression_x` and `expression_y`. Returns 0 if both inputs are zero.",
36 syntax_example = "gcd(expression_x, expression_y)",
37 sql_example = r#"```sql
38> SELECT gcd(48, 18);
39+------------+
40| gcd(48,18) |
41+------------+
42| 6 |
43+------------+
44```"#,
45 standard_argument(name = "expression_x", prefix = "First numeric"),
46 standard_argument(name = "expression_y", prefix = "Second numeric")
47)]
48#[derive(Debug, PartialEq, Eq, Hash)]
49pub struct GcdFunc {
50 signature: Signature,
51}
52
53impl Default for GcdFunc {
54 fn default() -> Self {
55 Self::new()
56 }
57}
58
59impl GcdFunc {
60 pub fn new() -> Self {
61 Self {
62 signature: Signature::uniform(
63 2,
64 vec![DataType::Int64],
65 Volatility::Immutable,
66 ),
67 }
68 }
69}
70
71impl ScalarUDFImpl for GcdFunc {
72 fn as_any(&self) -> &dyn Any {
73 self
74 }
75
76 fn name(&self) -> &str {
77 "gcd"
78 }
79
80 fn signature(&self) -> &Signature {
81 &self.signature
82 }
83
84 fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
85 Ok(DataType::Int64)
86 }
87
88 fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
89 let args: [ColumnarValue; 2] = args.args.try_into().map_err(|_| {
90 internal_datafusion_err!("Expected 2 arguments for function gcd")
91 })?;
92
93 match args {
94 [ColumnarValue::Array(a), ColumnarValue::Array(b)] => {
95 compute_gcd_for_arrays(&a, &b)
96 }
97 [ColumnarValue::Scalar(ScalarValue::Int64(a)), ColumnarValue::Scalar(ScalarValue::Int64(b))] => {
98 match (a, b) {
99 (Some(a), Some(b)) => Ok(ColumnarValue::Scalar(ScalarValue::Int64(
100 Some(compute_gcd(a, b)?),
101 ))),
102 _ => Ok(ColumnarValue::Scalar(ScalarValue::Int64(None))),
103 }
104 }
105 [ColumnarValue::Array(a), ColumnarValue::Scalar(ScalarValue::Int64(b))] => {
106 compute_gcd_with_scalar(&a, b)
107 }
108 [ColumnarValue::Scalar(ScalarValue::Int64(a)), ColumnarValue::Array(b)] => {
109 compute_gcd_with_scalar(&b, a)
110 }
111 _ => exec_err!("Unsupported argument types for function gcd"),
112 }
113 }
114
115 fn documentation(&self) -> Option<&Documentation> {
116 self.doc()
117 }
118}
119
120fn compute_gcd_for_arrays(a: &ArrayRef, b: &ArrayRef) -> Result<ColumnarValue> {
121 let a = a.as_primitive::<Int64Type>();
122 let b = b.as_primitive::<Int64Type>();
123 try_binary(a, b, compute_gcd)
124 .map(|arr: PrimitiveArray<Int64Type>| {
125 ColumnarValue::Array(Arc::new(arr) as ArrayRef)
126 })
127 .map_err(Into::into) }
129
130fn compute_gcd_with_scalar(arr: &ArrayRef, scalar: Option<i64>) -> Result<ColumnarValue> {
131 match scalar {
132 Some(scalar_value) => {
133 let result: Result<Int64Array> = arr
134 .as_primitive::<Int64Type>()
135 .iter()
136 .map(|val| match val {
137 Some(val) => Ok(Some(compute_gcd(val, scalar_value)?)),
138 _ => Ok(None),
139 })
140 .collect();
141
142 result.map(|arr| ColumnarValue::Array(Arc::new(arr) as ArrayRef))
143 }
144 None => Ok(ColumnarValue::Array(new_null_array(
145 &DataType::Int64,
146 arr.len(),
147 ))),
148 }
149}
150
151pub(super) fn unsigned_gcd(mut a: u64, mut b: u64) -> u64 {
153 if a == 0 {
154 return b;
155 }
156 if b == 0 {
157 return a;
158 }
159
160 let shift = (a | b).trailing_zeros();
161 a >>= a.trailing_zeros();
162 loop {
163 b >>= b.trailing_zeros();
164 if a > b {
165 swap(&mut a, &mut b);
166 }
167 b -= a;
168 if b == 0 {
169 return a << shift;
170 }
171 }
172}
173
174pub fn compute_gcd(x: i64, y: i64) -> Result<i64, ArrowError> {
176 let a = x.unsigned_abs();
177 let b = y.unsigned_abs();
178 let r = unsigned_gcd(a, b);
179 r.try_into().map_err(|_| {
181 ArrowError::ComputeError(format!("Signed integer overflow in GCD({x}, {y})"))
182 })
183}