Skip to main content

reifydb_function/math/scalar/
gcd.rs

1// SPDX-License-Identifier: AGPL-3.0-or-later
2// Copyright (c) 2025 ReifyDB
3
4use reifydb_core::value::column::data::ColumnData;
5use reifydb_type::value::r#type::Type;
6
7use crate::{ScalarFunction, ScalarFunctionContext, error::ScalarFunctionError, propagate_options};
8
/// Scalar function computing the greatest common divisor of two numeric columns.
pub struct Gcd;

impl Gcd {
	/// Creates a new `Gcd` function instance.
	pub fn new() -> Self {
		Self
	}
}

impl Default for Gcd {
	/// Equivalent to [`Gcd::new`]; provided so the type works with `Default`-based construction.
	fn default() -> Self {
		Self::new()
	}
}
16
17fn numeric_to_i64(data: &ColumnData, i: usize) -> Option<i64> {
18	match data {
19		ColumnData::Int1(c) => c.get(i).map(|&v| v as i64),
20		ColumnData::Int2(c) => c.get(i).map(|&v| v as i64),
21		ColumnData::Int4(c) => c.get(i).map(|&v| v as i64),
22		ColumnData::Int8(c) => c.get(i).copied(),
23		ColumnData::Int16(c) => c.get(i).map(|&v| v as i64),
24		ColumnData::Uint1(c) => c.get(i).map(|&v| v as i64),
25		ColumnData::Uint2(c) => c.get(i).map(|&v| v as i64),
26		ColumnData::Uint4(c) => c.get(i).map(|&v| v as i64),
27		ColumnData::Uint8(c) => c.get(i).map(|&v| v as i64),
28		_ => None,
29	}
30}
31
/// Computes the greatest common divisor of `a` and `b` using the Euclidean algorithm.
///
/// The signs of the inputs are ignored, so the result is non-negative for every
/// representable answer. `compute_gcd(x, 0)` is `|x|` and `compute_gcd(0, 0)` is `0`.
///
/// The magnitudes are taken with `unsigned_abs`, so `i64::MIN` no longer triggers the
/// overflow panic that `i64::abs` raises in debug builds. The single result that does
/// not fit in an `i64` (2^63, reachable only when both inputs are `0` or `i64::MIN`
/// and at least one is `i64::MIN`) wraps to `i64::MIN`.
fn compute_gcd(a: i64, b: i64) -> i64 {
	let mut a = a.unsigned_abs();
	let mut b = b.unsigned_abs();
	while b != 0 {
		(a, b) = (b, a % b);
	}
	// Wrapping cast is intentional: only 2^63 exceeds i64::MAX, see the doc comment.
	a as i64
}
42
43impl ScalarFunction for Gcd {
44	fn scalar(&self, ctx: ScalarFunctionContext) -> crate::error::ScalarFunctionResult<ColumnData> {
45		if let Some(result) = propagate_options(self, &ctx) {
46			return result;
47		}
48		let columns = ctx.columns;
49		let row_count = ctx.row_count;
50
51		if columns.len() != 2 {
52			return Err(ScalarFunctionError::ArityMismatch {
53				function: ctx.fragment.clone(),
54				expected: 2,
55				actual: columns.len(),
56			});
57		}
58
59		let a_col = columns.get(0).unwrap();
60		let b_col = columns.get(1).unwrap();
61
62		if !a_col.data().get_type().is_number() {
63			return Err(ScalarFunctionError::InvalidArgumentType {
64				function: ctx.fragment.clone(),
65				argument_index: 0,
66				expected: vec![
67					Type::Int1,
68					Type::Int2,
69					Type::Int4,
70					Type::Int8,
71					Type::Uint1,
72					Type::Uint2,
73					Type::Uint4,
74					Type::Uint8,
75				],
76				actual: a_col.data().get_type(),
77			});
78		}
79
80		if !b_col.data().get_type().is_number() {
81			return Err(ScalarFunctionError::InvalidArgumentType {
82				function: ctx.fragment.clone(),
83				argument_index: 1,
84				expected: vec![
85					Type::Int1,
86					Type::Int2,
87					Type::Int4,
88					Type::Int8,
89					Type::Uint1,
90					Type::Uint2,
91					Type::Uint4,
92					Type::Uint8,
93				],
94				actual: b_col.data().get_type(),
95			});
96		}
97
98		let mut result = Vec::with_capacity(row_count);
99		let mut bitvec = Vec::with_capacity(row_count);
100
101		for i in 0..row_count {
102			match (numeric_to_i64(a_col.data(), i), numeric_to_i64(b_col.data(), i)) {
103				(Some(a), Some(b)) => {
104					result.push(compute_gcd(a, b));
105					bitvec.push(true);
106				}
107				_ => {
108					result.push(0);
109					bitvec.push(false);
110				}
111			}
112		}
113
114		Ok(ColumnData::int8_with_bitvec(result, bitvec))
115	}
116
117	fn return_type(&self, _input_types: &[Type]) -> Type {
118		Type::Int8
119	}
120}