Skip to main content

reifydb_function/math/scalar/
gcd.rs

// SPDX-License-Identifier: AGPL-3.0-or-later
// Copyright (c) 2025 ReifyDB

4use reifydb_core::value::column::data::ColumnData;
5use reifydb_type::value::r#type::Type;
6
7use crate::{
8	ScalarFunction, ScalarFunctionContext,
9	error::{ScalarFunctionError, ScalarFunctionResult},
10	propagate_options,
11};
12
/// Scalar function computing the greatest common divisor of two integer arguments.
pub struct Gcd;

impl Gcd {
	/// Returns a new `Gcd`; the type is a unit struct and carries no state.
	pub fn new() -> Self {
		Gcd
	}
}

21fn numeric_to_i64(data: &ColumnData, i: usize) -> Option<i64> {
22	match data {
23		ColumnData::Int1(c) => c.get(i).map(|&v| v as i64),
24		ColumnData::Int2(c) => c.get(i).map(|&v| v as i64),
25		ColumnData::Int4(c) => c.get(i).map(|&v| v as i64),
26		ColumnData::Int8(c) => c.get(i).copied(),
27		ColumnData::Int16(c) => c.get(i).map(|&v| v as i64),
28		ColumnData::Uint1(c) => c.get(i).map(|&v| v as i64),
29		ColumnData::Uint2(c) => c.get(i).map(|&v| v as i64),
30		ColumnData::Uint4(c) => c.get(i).map(|&v| v as i64),
31		ColumnData::Uint8(c) => c.get(i).map(|&v| v as i64),
32		_ => None,
33	}
34}
35
36fn compute_gcd(mut a: i64, mut b: i64) -> i64 {
37	a = a.abs();
38	b = b.abs();
39	while b != 0 {
40		let t = b;
41		b = a % b;
42		a = t;
43	}
44	a
45}
46
47impl ScalarFunction for Gcd {
48	fn scalar(&self, ctx: ScalarFunctionContext) -> ScalarFunctionResult<ColumnData> {
49		if let Some(result) = propagate_options(self, &ctx) {
50			return result;
51		}
52		let columns = ctx.columns;
53		let row_count = ctx.row_count;
54
55		if columns.len() != 2 {
56			return Err(ScalarFunctionError::ArityMismatch {
57				function: ctx.fragment.clone(),
58				expected: 2,
59				actual: columns.len(),
60			});
61		}
62
63		let a_col = columns.get(0).unwrap();
64		let b_col = columns.get(1).unwrap();
65
66		if !a_col.data().get_type().is_number() {
67			return Err(ScalarFunctionError::InvalidArgumentType {
68				function: ctx.fragment.clone(),
69				argument_index: 0,
70				expected: vec![
71					Type::Int1,
72					Type::Int2,
73					Type::Int4,
74					Type::Int8,
75					Type::Uint1,
76					Type::Uint2,
77					Type::Uint4,
78					Type::Uint8,
79				],
80				actual: a_col.data().get_type(),
81			});
82		}
83
84		if !b_col.data().get_type().is_number() {
85			return Err(ScalarFunctionError::InvalidArgumentType {
86				function: ctx.fragment.clone(),
87				argument_index: 1,
88				expected: vec![
89					Type::Int1,
90					Type::Int2,
91					Type::Int4,
92					Type::Int8,
93					Type::Uint1,
94					Type::Uint2,
95					Type::Uint4,
96					Type::Uint8,
97				],
98				actual: b_col.data().get_type(),
99			});
100		}
101
102		let mut result = Vec::with_capacity(row_count);
103		let mut bitvec = Vec::with_capacity(row_count);
104
105		for i in 0..row_count {
106			match (numeric_to_i64(a_col.data(), i), numeric_to_i64(b_col.data(), i)) {
107				(Some(a), Some(b)) => {
108					result.push(compute_gcd(a, b));
109					bitvec.push(true);
110				}
111				_ => {
112					result.push(0);
113					bitvec.push(false);
114				}
115			}
116		}
117
118		Ok(ColumnData::int8_with_bitvec(result, bitvec))
119	}
120
121	fn return_type(&self, _input_types: &[Type]) -> Type {
122		Type::Int8
123	}
124}